# DimChi / app.py
# Last commit: "Update app.py" (d3083ec, verified) by OthnnyEL
# -*- coding: utf-8 -*-
"""
YOUR FOIA CHAT ASSISTANT - Text-only chatbot (speech-to-text and text-to-speech removed)
Drop this file into your Hugging Face Space (replace existing app.py) or run locally.
Notes:
- Dark UI via custom CSS (works even if Gradio theme API differs)
- Performance-focused: greedy generation, lower max_new_tokens, use_cache, no_grad, streaming
- Keeps bitsandbytes / 4-bit logic intact when available
"""
import os
import threading
import gradio as gr
import importlib
import importlib.util
import torch
from huggingface_hub import login
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
TextIteratorStreamer,
)
from peft import PeftModel, PeftConfig
# -------------------- Configuration --------------------
ADAPTER_REPO_ID = "EYEDOL/FOIA"  # adapter-only repo
BASE_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"  # full base model referenced by adapter

# Token lookup: HF_TOKEN first, then the alternative "hugface" variable.
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")

if not HF_TOKEN:
    print("Warning: HF_TOKEN not found in env. Private repos may fail to load.")
else:
    # Best-effort login: a failure is reported but does not stop the app;
    # later downloads of private repos may then fail instead.
    try:
        login(token=HF_TOKEN)
    except Exception as login_error:
        print("Warning: huggingface_hub.login() failed:", login_error)
    else:
        print("Successfully logged into Hugging Face Hub!")
def is_package_installed(name: str) -> bool:
    """Tell whether the distribution *name* is installed.

    The check looks up installed distribution metadata through
    ``importlib.metadata``. Only when that module itself cannot be imported
    does it fall back to ``importlib.import_module(name)``. Any exception
    raised by whichever probe is used counts as "not installed".
    """
    try:
        import importlib.metadata as metadata
    except Exception:
        metadata = None

    # Pick the probe once, then run it under a single guard.
    probe = importlib.import_module if metadata is None else metadata.distribution
    try:
        probe(name)
    except Exception:
        return False
    return True
class WeeboAssistant:
    """Text-only chat assistant: base causal LM + PEFT adapter with streaming output.

    Construction loads the tokenizer, the base model (BASE_MODEL_ID) and the
    adapter (ADAPTER_REPO_ID) once. get_llm_response() then streams replies.
    """

    def __init__(self):
        """Set generation defaults, then load tokenizer and models (slow, network/disk)."""
        # System prompt instructs the assistant to answer concisely in English.
        self.SYSTEM_PROMPT = (
            "You are an intelligent assistant. Answer questions briefly and accurately. "
            "Respond only in English. No long answers.\n"
        )
        # Generation defaults tuned for speed (adjust if you need different behavior).
        self.MAX_NEW_TOKENS = 256  # lowered from 512 for speed
        self.DO_SAMPLE = False  # greedy = faster; set True if you want sampling
        # Keep at 1: transformers' generate() rejects a streamer when num_beams > 1.
        self.NUM_BEAMS = 1
        self._init_models()

    def _enable_kv_cache(self):
        """Best-effort: turn on the KV cache in the model config for faster decoding."""
        try:
            self.llm_model.config.use_cache = True
        except Exception:
            # Deliberately ignored: get_llm_response also passes use_cache=True
            # to generate(), so a non-writable config is not fatal.
            pass

    def _init_models(self):
        """Load tokenizer, base model and PEFT adapter, then optionally build a pipeline.

        Sets self.device, self.torch_dtype, self.llm_tokenizer, self.llm_model
        and self.llm_pipeline (None if the pipeline could not be created).

        Raises:
            RuntimeError: if the base model or the PEFT adapter fails to load.
        """
        print("Initializing models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        print(f"Using device: {self.device}, torch_dtype: {self.torch_dtype}")

        bnb_available = is_package_installed("bitsandbytes")
        print("bitsandbytes available:", bnb_available)

        # Tokenizer: prefer the base model's; fall back to the adapter repo's.
        try:
            self.llm_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
            print("Loaded tokenizer from BASE_MODEL_ID")
        except Exception as e:
            print("Warning: could not load base tokenizer, falling back to adapter tokenizer. Error:", e)
            self.llm_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)
            print("Loaded tokenizer from ADAPTER_REPO_ID")

        # generate() needs a pad_token_id: reuse EOS, else 0 as a last resort.
        if getattr(self.llm_tokenizer, "pad_token_id", None) is None:
            if getattr(self.llm_tokenizer, "eos_token_id", None) is not None:
                self.llm_tokenizer.pad_token_id = self.llm_tokenizer.eos_token_id
            else:
                self.llm_tokenizer.pad_token_id = 0

        # Always pass an explicit device_map (never None).
        device_map = "auto" if torch.cuda.is_available() else {"": "cpu"}
        print("device_map being used for model load:", device_map)

        base_model_kwargs = dict(
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            device_map=device_map,
            # NOTE(review): trust_remote_code allows executing Python shipped in
            # the model repo; confirm it is actually needed for this base model.
            trust_remote_code=True,
        )
        if bnb_available and torch.cuda.is_available():
            try:
                from transformers import BitsAndBytesConfig
            except ImportError:
                # Older transformers without BitsAndBytesConfig: legacy flag.
                base_model_kwargs["load_in_4bit"] = True
            else:
                # Passing load_in_4bit directly to from_pretrained is deprecated;
                # this config object expresses the same request.
                base_model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
            print("Will attempt to load base model in 4-bit (bitsandbytes + CUDA detected).")
        else:
            print("bitsandbytes not usable or no CUDA: loading model normally (no 4-bit).")

        try:
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_ID,
                **base_model_kwargs,
            )
        except Exception as e:
            raise RuntimeError(
                "Failed to load base model. Ensure the base model ID is correct and HF_TOKEN has access if private. Error: "
                + str(e)
            ) from e
        self._enable_kv_cache()
        print("Base model loaded from", BASE_MODEL_ID)

        # Diagnostic probe only: the config is not reused below, but a failure
        # here produces an early log line before the adapter load is attempted.
        try:
            PeftConfig.from_pretrained(ADAPTER_REPO_ID)
            print("Loaded PEFT config from", ADAPTER_REPO_ID)
        except Exception:
            print("Warning: could not load PeftConfig; continuing to attempt adapter load.")

        try:
            self.llm_model = PeftModel.from_pretrained(
                self.llm_model,
                ADAPTER_REPO_ID,
                device_map=device_map,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            raise RuntimeError(
                "Failed to load/apply PEFT adapter from adapter repo. Make sure adapter files are present and HF_TOKEN has access if private. Error: "
                + str(e)
            ) from e
        self._enable_kv_cache()  # the adapter-wrapped model should cache too
        print("PEFT adapter applied from", ADAPTER_REPO_ID)

        # Optional non-streaming pipeline (quick tests; get_llm_response does not use it).
        try:
            pipeline_kwargs = dict(
                model=self.llm_model,
                tokenizer=self.llm_tokenizer,
                model_kwargs={"torch_dtype": self.torch_dtype},
            )
            # A model dispatched via device_map already has its placement, and
            # transformers refuses an explicit `device` for such models, so only
            # pass one when no hf_device_map is present.
            if getattr(self.llm_model, "hf_device_map", None) is None:
                pipeline_kwargs["device"] = 0 if torch.cuda.is_available() else -1
            self.llm_pipeline = pipeline("text-generation", **pipeline_kwargs)
            print("Created text-generation pipeline (non-streaming).")
        except Exception as e:
            print("Warning: could not create text-generation pipeline. Streaming generate will still work. Error:", e)
            self.llm_pipeline = None
        print("LLM base + adapter loaded successfully.")

    def get_llm_response(self, chat_history):
        """Start generating a reply to *chat_history* and return a text stream.

        Args:
            chat_history: iterable of ``(user_msg, assistant_msg)`` pairs; falsy
                messages are skipped. The caller normally passes the latest user
                message paired with ``""``.

        Returns:
            A ``TextIteratorStreamer`` that yields decoded text chunks. Generation
            runs on a daemon thread; the stream is always ended, even when
            generation raises (the error is printed).
        """
        # Prompt: system text, then the transcript, then an open assistant turn.
        # NOTE(review): this is a plain "User:/Assistant:" transcript rather than
        # the tokenizer's chat template; confirm it matches the adapter's training format.
        prompt_lines = [self.SYSTEM_PROMPT]
        for user_msg, assistant_msg in chat_history:
            if user_msg:
                prompt_lines.append("User: " + user_msg)
            if assistant_msg:
                prompt_lines.append("Assistant: " + assistant_msg)
        prompt_lines.append("Assistant: ")
        prompt = "\n".join(prompt_lines)

        inputs = self.llm_tokenizer(prompt, return_tensors="pt", padding=False)
        try:
            model_device = next(self.llm_model.parameters()).device
        except StopIteration:
            model_device = torch.device("cpu")
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        streamer = TextIteratorStreamer(self.llm_tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = dict(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask", None),
            # max_new_tokens alone bounds the reply; also passing max_length
            # only produces a "both set" precedence warning.
            max_new_tokens=self.MAX_NEW_TOKENS,
            do_sample=self.DO_SAMPLE,  # greedy if False -> faster
            num_beams=self.NUM_BEAMS,
            streamer=streamer,
            eos_token_id=getattr(self.llm_tokenizer, "eos_token_id", None),
            pad_token_id=getattr(self.llm_tokenizer, "pad_token_id", None),
            use_cache=True,
        )
        # early_stopping only affects beam search; with one beam it merely
        # triggers a generation-config warning.
        if self.NUM_BEAMS > 1:
            generation_kwargs["early_stopping"] = True

        def _generate_thread():
            try:
                with torch.no_grad():
                    self.llm_model.generate(**generation_kwargs)
            except Exception as e:
                print("Generation error:", e)
                # generate() ends the stream only when it finishes; on failure we
                # must end it ourselves or the consumer's loop blocks forever.
                streamer.end()

        gen_thread = threading.Thread(target=_generate_thread, daemon=True)
        gen_thread.start()
        return streamer
# Create the single shared assistant at import time. WeeboAssistant.__init__
# loads the tokenizer, base model and adapter, so this line is slow and may
# raise (e.g. RuntimeError from _init_models) if loading fails.
assistant = WeeboAssistant()
# -------------------- Gradio pipeline functions --------------------
def t2t_pipeline(text_input, chat_history):
    """Stream the assistant's reply to *text_input* into the chat history.

    Gradio generator handler. Appends ``(text_input, "")`` to *chat_history*
    (a list of ``(user, assistant)`` pairs, or None for an empty chat). It
    yields the history immediately so the user's message appears, then
    re-yields it each time a new chunk of the reply arrives.

    Blank or whitespace-only input is ignored: the history is yielded once,
    unchanged, and no generation is started.
    """
    chat_history = chat_history or []
    # Guard: an empty message would otherwise prompt the model with no user turn
    # (get_llm_response skips falsy user messages).
    if not text_input or not text_input.strip():
        yield chat_history
        return

    chat_history.append((text_input, ""))  # placeholder for assistant reply
    yield chat_history

    llm_response_text = ""
    for text_chunk in assistant.get_llm_response(chat_history):
        llm_response_text += text_chunk
        chat_history[-1] = (text_input, llm_response_text)
        yield chat_history
def clear_textbox():
    """Return a Gradio update that empties the input textbox.

    Uses the generic ``gr.update`` (available in Gradio 3.x and later) rather
    than the component-specific ``gr.Textbox.update``, which Gradio 4 removed.
    """
    return gr.update(value="")
# -------------------- MODIFIED: Modern Dark UI CSS --------------------
# Custom stylesheet handed to gr.Blocks(css=MODERN_CSS). It defines a dark
# palette via CSS custom properties and loads the Poppins web font from Google
# Fonts (@import, fetched by the browser). It also styles the chat, textbox and
# buttons, draws a white send icon before the `.send-btn` button via ::before,
# and hides the footer.
# NOTE(review): most selectors target `.gr-*` class names; whether the installed
# Gradio version still emits those classes cannot be confirmed here. If it
# does not, those rules silently have no effect.
# NOTE(review): the CSS comment "Hide default Gradio button text..." is
# inaccurate. The `.send-btn span` rule only sets font-size, so the "Send"
# label stays visible next to the icon.
MODERN_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;600;700&display=swap');
:root {
--body-bg: linear-gradient(135deg, #10141a 0%, #06090f 100%);
--chat-bg: #0b0f19;
--border-color: rgba(255, 255, 255, 0.08);
--text-color: #E6EEF8;
--input-bg: #131926;
--user-msg-bg: #1B2336;
--bot-msg-bg: #0F1522;
--primary-color: #0084ff;
--primary-hover: #006fdb;
--font-family: 'Poppins', sans-serif;
}
body, .gradio-container {
background: var(--body-bg) !important;
color: var(--text-color) !important;
font-family: var(--font-family) !important;
}
.gradio-container * {
font-family: var(--font-family) !important;
}
h1, h2, h3, .markdown {
color: var(--text-color) !important;
}
.gr-block, .gr-box, .gr-row, .gr-column {
background: transparent !important;
border: none !important;
box-shadow: none !important;
}
.gr-chatbot {
background: var(--chat-bg) !important;
border: 1px solid var(--border-color) !important;
border-radius: 12px !important;
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2) !important;
}
.gr-chatbot .message {
border-radius: 8px !important;
padding: 12px !important;
box-shadow: 0 2px 4px rgba(0,0,0,0.1) !important;
border: none !important;
}
.gr-chatbot .message.user {
background: var(--user-msg-bg) !important;
color: var(--text-color) !important;
}
.gr-chatbot .message.bot {
background: var(--bot-msg-bg) !important;
color: var(--text-color) !important;
}
.gr-chatbot .message p { margin: 0; }
.gr-textbox, .gr-textbox textarea {
background: var(--input-bg) !important;
color: var(--text-color) !important;
border: 1px solid var(--border-color) !important;
border-radius: 8px !important;
transition: all 0.2s ease-in-out;
}
.gr-textbox:focus, .gr-textbox textarea:focus {
border-color: var(--primary-color) !important;
box-shadow: 0 0 0 2px rgba(0, 132, 255, 0.3) !important;
}
.gr-button {
background: var(--primary-color) !important;
color: white !important;
border: none !important;
border-radius: 8px !important;
box-shadow: 0 4px 12px rgba(0, 132, 255, 0.2) !important;
transition: all 0.2s ease-in-out !important;
font-weight: 500 !important;
display: flex;
justify-content: center;
align-items: center;
gap: 8px; /* Space between icon and text */
}
.gr-button:hover {
background: var(--primary-hover) !important;
transform: translateY(-2px);
box-shadow: 0 6px 16px rgba(0, 132, 255, 0.3) !important;
}
/* Hide default Gradio button text when we add our own */
.send-btn span {
font-size: 1rem;
}
/* Add a send icon to the button */
.send-btn::before {
content: '';
display: inline-block;
width: 20px;
height: 20px;
background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='white'%3E%3Cpath d='M2.01 21L23 12 2.01 3 2 10l15 2-15 2z'/%3E%3C/svg%3E");
background-size: contain;
background-repeat: no-repeat;
background-position: center;
}
footer, .footer {
display: none !important;
}
"""
# -------------------- MODIFIED: Gradio UI with Logo --------------------
# Header markup (logo, title, tagline) rendered through gr.Markdown below.
# NOTE(review): how "file/logo.png" is served depends on the Gradio version's
# file-serving rules; confirm the logo actually loads.
_HEADER_HTML = """
<div style="text-align: center; display: flex; flex-direction: column; align-items: center; justify-content: center; padding: 20px;">
<img src="file/logo.png" alt="DimChi Logo" style="max-width: 120px; margin-bottom: 15px;">
<h1 style="margin: 0; font-size: 2.5rem; font-weight: 700;">DimChi FOIA Assistant</h1>
<p style="margin: 5px 0 0 0; font-size: 1.1rem; color: #a0b0c0;">Your intelligent chat partner for FOIA inquiries.</p>
</div>
"""

with gr.Blocks(css=MODERN_CSS, title="DimChi FOIA Assistant") as demo:
    # Centered header row.
    with gr.Row():
        gr.Markdown(_HEADER_HTML)

    t2t_chatbot = gr.Chatbot(label="Conversation", bubble_full_width=False, height=520)

    # Input row: wide textbox plus a "Send" button styled via the .send-btn class.
    with gr.Row():
        t2t_text_in = gr.Textbox(
            show_label=False,
            placeholder="Type your message here...",
            scale=4,
            container=False,
        )
        t2t_submit_btn = gr.Button("Send", variant="primary", scale=1, elem_classes="send-btn")

    # Clicking Send and pressing Enter share one chain: stream the reply into
    # the chatbot, then clear the textbox once streaming has finished.
    for _register in (t2t_submit_btn.click, t2t_text_in.submit):
        _register(
            fn=t2t_pipeline,
            inputs=[t2t_text_in, t2t_chatbot],
            outputs=[t2t_chatbot],
            queue=True,
        ).then(fn=clear_textbox, inputs=None, outputs=t2t_text_in)

# Launch with request queuing enabled (required for the streaming generator).
demo.queue().launch()