"""
FOIA CHAT ASSISTANT - Text-only chatbot (STT and TTS removed)

Drop this file into your Hugging Face Space (replace the existing app.py) or run it locally.

Notes:
- Dark UI via custom CSS (works even if the Gradio theme API differs)
- Performance-focused: greedy generation, lower max_new_tokens, use_cache, no_grad, streaming
- Keeps the bitsandbytes / 4-bit logic intact when available
"""

import os
import threading
import gradio as gr
import importlib
import importlib.util
import torch

from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    TextIteratorStreamer,
)
from peft import PeftModel, PeftConfig


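# PEFT adapter repo and the base model it is applied to at startup.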
ADAPTER_REPO_ID = "EYEDOL/FOIA"
BASE_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"

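# Optional Hugging Face Hub login; the token comes from the HF_TOKEN (or "hugface") environment variable so private repos can be downloaded.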
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("Successfully logged into Hugging Face Hub!")
    except Exception as e:
        print("Warning: huggingface_hub.login() failed:", e)
else:
    print("Warning: HF_TOKEN not found in env. Private repos may fail to load.")


def is_package_installed(name: str) -> bool:
    """Return True if `name` is installed (distribution metadata present, or importable as a fallback)."""
    try:
        import importlib.metadata as md
        try:
            md.distribution(name)
            return True
        except Exception:
            return False
    except Exception:
        try:
            importlib.import_module(name)
            return True
        except Exception:
            return False


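# WeeboAssistant wraps the tokenizer, the 4-bit base model plus PEFT adapter, and the streaming generate call.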
class WeeboAssistant:
    def __init__(self):
        self.SYSTEM_PROMPT = (
            "You are an intelligent assistant. Answer questions briefly and accurately. "
            "Respond only in English. No long answers.\n"
        )

        self.MAX_NEW_TOKENS = 256
        self.DO_SAMPLE = False
        self.NUM_BEAMS = 1
        self._init_models()

    def _init_models(self):
        print("Initializing models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        print(f"Using device: {self.device}, torch_dtype: {self.torch_dtype}")

        BNB_AVAILABLE = is_package_installed("bitsandbytes")
        print("bitsandbytes available:", BNB_AVAILABLE)

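        # Prefer the base-model tokenizer; fall back to the adapter repo's tokenizer if it cannot be fetched.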
        try:
            self.llm_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
            print("Loaded tokenizer from BASE_MODEL_ID")
        except Exception as e:
            print("Warning: could not load base tokenizer, falling back to adapter tokenizer. Error:", e)
            self.llm_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)
            print("Loaded tokenizer from ADAPTER_REPO_ID")

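        # Llama tokenizers usually ship without a pad token; reuse EOS (or token id 0 as a last resort).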
        if getattr(self.llm_tokenizer, "pad_token_id", None) is None:
            if getattr(self.llm_tokenizer, "eos_token_id", None) is not None:
                self.llm_tokenizer.pad_token_id = self.llm_tokenizer.eos_token_id
            else:
                self.llm_tokenizer.pad_token_id = 0

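        # Let accelerate place layers automatically on GPU; pin everything to CPU otherwise.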
        if torch.cuda.is_available():
            device_map = "auto"
        else:
            device_map = {"": "cpu"}
        print("device_map being used for model load:", device_map)

        base_model_kwargs = dict(
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            device_map=device_map,
            trust_remote_code=True,
        )

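        # Request 4-bit weights only when bitsandbytes is importable and CUDA is available; otherwise load normally.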
        if BNB_AVAILABLE and torch.cuda.is_available():
            base_model_kwargs["load_in_4bit"] = True
            print("Will attempt to load base model in 4-bit (bitsandbytes + CUDA detected).")
        else:
            print("bitsandbytes not usable or no CUDA: loading model normally (no 4-bit).")

        try:
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_ID,
                **base_model_kwargs,
            )
            try:
                self.llm_model.config.use_cache = True
            except Exception:
                pass
            print("Base model loaded from", BASE_MODEL_ID)
        except Exception as e:
            raise RuntimeError(
                "Failed to load base model. Ensure the base model ID is correct and HF_TOKEN has access if private. Error: "
                + str(e)
            )

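        # Apply the PEFT adapter from ADAPTER_REPO_ID on top of the base model.
        # PeftConfig is loaded only as a sanity check; the PeftModel.from_pretrained call below is what matters.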
        try:
            try:
                peft_config = PeftConfig.from_pretrained(ADAPTER_REPO_ID)
                print("Loaded PEFT config from", ADAPTER_REPO_ID)
            except Exception:
                peft_config = None
                print("Warning: could not load PeftConfig; continuing to attempt adapter load.")

            peft_kwargs = dict(
                device_map=device_map,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
            )

            self.llm_model = PeftModel.from_pretrained(
                self.llm_model,
                ADAPTER_REPO_ID,
                **peft_kwargs,
            )

            try:
                self.llm_model.config.use_cache = True
            except Exception:
                pass
            print("PEFT adapter applied from", ADAPTER_REPO_ID)
        except Exception as e:
            raise RuntimeError(
                "Failed to load/apply PEFT adapter from adapter repo. Make sure adapter files are present and HF_TOKEN has access if private. Error: "
                + str(e)
            )

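        # A non-streaming text-generation pipeline is created for convenience but is not used by the chat path;
        # streaming responses go through model.generate() with a TextIteratorStreamer instead.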
        try:
            device_index = 0 if torch.cuda.is_available() else -1
            self.llm_pipeline = pipeline(
                "text-generation",
                model=self.llm_model,
                tokenizer=self.llm_tokenizer,
                device=device_index,
                model_kwargs={"torch_dtype": self.torch_dtype},
            )
            print("Created text-generation pipeline (non-streaming).")
        except Exception as e:
            print("Warning: could not create text-generation pipeline. Streaming generate will still work. Error:", e)
            self.llm_pipeline = None

        print("LLM base + adapter loaded successfully.")

    def get_llm_response(self, chat_history):
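        # Build a plain "User:/Assistant:" transcript from the Gradio (user, assistant) tuples.
        # An alternative (not used here) would be the tokenizer's chat template, e.g.:
        #   messages = [{"role": "system", "content": self.SYSTEM_PROMPT}, {"role": "user", "content": "..."}]
        #   prompt = self.llm_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)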
        prompt_lines = [self.SYSTEM_PROMPT]
        for user_msg, assistant_msg in chat_history:
            if user_msg:
                prompt_lines.append("User: " + user_msg)
            if assistant_msg:
                prompt_lines.append("Assistant: " + assistant_msg)
        prompt_lines.append("Assistant: ")
        prompt = "\n".join(prompt_lines)

        inputs = self.llm_tokenizer(prompt, return_tensors="pt", padding=False)
        try:
            model_device = next(self.llm_model.parameters()).device
        except StopIteration:
            model_device = torch.device("cpu")
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

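        # The streamer yields decoded text chunks as generate() produces tokens; skip_prompt drops the echoed input prompt.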
        streamer = TextIteratorStreamer(self.llm_tokenizer, skip_prompt=True, skip_special_tokens=True)

        input_len = inputs["input_ids"].shape[1]
        max_new = self.MAX_NEW_TOKENS
        max_length = input_len + max_new

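        # Greedy, cache-enabled generation. max_new_tokens takes precedence over max_length when both are given,
        # and early_stopping only affects beam search (num_beams > 1), so both extra kwargs are harmless here.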
        generation_kwargs = dict(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask", None),
            max_length=max_length,
            max_new_tokens=max_new,
            do_sample=self.DO_SAMPLE,
            num_beams=self.NUM_BEAMS,
            streamer=streamer,
            eos_token_id=getattr(self.llm_tokenizer, "eos_token_id", None),
            pad_token_id=getattr(self.llm_tokenizer, "pad_token_id", None),
            use_cache=True,
            early_stopping=True,
        )

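        # Run generate() on a daemon thread so the caller can iterate the streamer while tokens are produced.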
        def _generate_thread():
            with torch.no_grad():
                try:
                    self.llm_model.generate(**generation_kwargs)
                except Exception as e:
                    print("Generation error:", e)

        gen_thread = threading.Thread(target=_generate_thread, daemon=True)
        gen_thread.start()

        return streamer


assistant = WeeboAssistant()


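# Gradio streaming handler: append the user turn immediately, then yield the growing chat history as chunks arrive.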
def t2t_pipeline(text_input, chat_history):
    chat_history = chat_history or []
    chat_history.append((text_input, ""))
    yield chat_history

    response_stream = assistant.get_llm_response(chat_history)
    llm_response_text = ""
    for text_chunk in response_stream:
        llm_response_text += text_chunk
        chat_history[-1] = (text_input, llm_response_text)
        yield chat_history


def clear_textbox():
    # Returning a plain value clears the textbox on every Gradio version;
    # gr.Textbox.update() was removed in Gradio 4.x.
    return ""


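# Dark-theme CSS injected via Blocks(css=...); selectors use Gradio's .gr-* component classes plus the custom .send-btn class.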
MODERN_CSS = """ |
|
|
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;600;700&display=swap'); |
|
|
|
|
|
:root { |
|
|
--body-bg: linear-gradient(135deg, #10141a 0%, #06090f 100%); |
|
|
--chat-bg: #0b0f19; |
|
|
--border-color: rgba(255, 255, 255, 0.08); |
|
|
--text-color: #E6EEF8; |
|
|
--input-bg: #131926; |
|
|
--user-msg-bg: #1B2336; |
|
|
--bot-msg-bg: #0F1522; |
|
|
--primary-color: #0084ff; |
|
|
--primary-hover: #006fdb; |
|
|
--font-family: 'Poppins', sans-serif; |
|
|
} |
|
|
|
|
|
body, .gradio-container { |
|
|
background: var(--body-bg) !important; |
|
|
color: var(--text-color) !important; |
|
|
font-family: var(--font-family) !important; |
|
|
} |
|
|
|
|
|
.gradio-container * { |
|
|
font-family: var(--font-family) !important; |
|
|
} |
|
|
|
|
|
h1, h2, h3, .markdown { |
|
|
color: var(--text-color) !important; |
|
|
} |
|
|
|
|
|
.gr-block, .gr-box, .gr-row, .gr-column { |
|
|
background: transparent !important; |
|
|
border: none !important; |
|
|
box-shadow: none !important; |
|
|
} |
|
|
|
|
|
.gr-chatbot { |
|
|
background: var(--chat-bg) !important; |
|
|
border: 1px solid var(--border-color) !important; |
|
|
border-radius: 12px !important; |
|
|
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2) !important; |
|
|
} |
|
|
|
|
|
.gr-chatbot .message { |
|
|
border-radius: 8px !important; |
|
|
padding: 12px !important; |
|
|
box-shadow: 0 2px 4px rgba(0,0,0,0.1) !important; |
|
|
border: none !important; |
|
|
} |
|
|
|
|
|
.gr-chatbot .message.user { |
|
|
background: var(--user-msg-bg) !important; |
|
|
color: var(--text-color) !important; |
|
|
} |
|
|
.gr-chatbot .message.bot { |
|
|
background: var(--bot-msg-bg) !important; |
|
|
color: var(--text-color) !important; |
|
|
} |
|
|
|
|
|
.gr-chatbot .message p { margin: 0; } |
|
|
|
|
|
.gr-textbox, .gr-textbox textarea { |
|
|
background: var(--input-bg) !important; |
|
|
color: var(--text-color) !important; |
|
|
border: 1px solid var(--border-color) !important; |
|
|
border-radius: 8px !important; |
|
|
transition: all 0.2s ease-in-out; |
|
|
} |
|
|
.gr-textbox:focus, .gr-textbox textarea:focus { |
|
|
border-color: var(--primary-color) !important; |
|
|
box-shadow: 0 0 0 2px rgba(0, 132, 255, 0.3) !important; |
|
|
} |
|
|
|
|
|
.gr-button { |
|
|
background: var(--primary-color) !important; |
|
|
color: white !important; |
|
|
border: none !important; |
|
|
border-radius: 8px !important; |
|
|
box-shadow: 0 4px 12px rgba(0, 132, 255, 0.2) !important; |
|
|
transition: all 0.2s ease-in-out !important; |
|
|
font-weight: 500 !important; |
|
|
display: flex; |
|
|
justify-content: center; |
|
|
align-items: center; |
|
|
gap: 8px; /* Space between icon and text */ |
|
|
} |
|
|
|
|
|
.gr-button:hover { |
|
|
background: var(--primary-hover) !important; |
|
|
transform: translateY(-2px); |
|
|
box-shadow: 0 6px 16px rgba(0, 132, 255, 0.3) !important; |
|
|
} |
|
|
/* Hide default Gradio button text when we add our own */ |
|
|
.send-btn span { |
|
|
font-size: 1rem; |
|
|
} |
|
|
/* Add a send icon to the button */ |
|
|
.send-btn::before { |
|
|
content: ''; |
|
|
display: inline-block; |
|
|
width: 20px; |
|
|
height: 20px; |
|
|
background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='white'%3E%3Cpath d='M2.01 21L23 12 2.01 3 2 10l15 2-15 2z'/%3E%3C/svg%3E"); |
|
|
background-size: contain; |
|
|
background-repeat: no-repeat; |
|
|
background-position: center; |
|
|
} |
|
|
|
|
|
footer, .footer { |
|
|
display: none !important; |
|
|
} |
|
|
""" |
|
|
|
|
|
|
|
|
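# UI layout: header with logo, the conversation window, and an input row with a styled Send button.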
with gr.Blocks(css=MODERN_CSS, title="DimChi FOIA Assistant") as demo:

    with gr.Row():
        gr.Markdown(
            """
            <div style="text-align: center; display: flex; flex-direction: column; align-items: center; justify-content: center; padding: 20px;">
                <img src="file/logo.png" alt="DimChi Logo" style="max-width: 120px; margin-bottom: 15px;">
                <h1 style="margin: 0; font-size: 2.5rem; font-weight: 700;">DimChi FOIA Assistant</h1>
                <p style="margin: 5px 0 0 0; font-size: 1.1rem; color: #a0b0c0;">Your intelligent chat partner for FOIA inquiries.</p>
            </div>
            """
        )

    t2t_chatbot = gr.Chatbot(label="Conversation", bubble_full_width=False, height=520)

    with gr.Row():
        t2t_text_in = gr.Textbox(
            show_label=False,
            placeholder="Type your message here...",
            scale=4,
            container=False,
        )
        t2t_submit_btn = gr.Button(
            "Send",
            variant="primary",
            scale=1,
            elem_classes="send-btn",
        )

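    # Wire both the Send button and pressing Enter in the textbox to the same streaming handler, then clear the input.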
    t2t_submit_btn.click(
        fn=t2t_pipeline,
        inputs=[t2t_text_in, t2t_chatbot],
        outputs=[t2t_chatbot],
        queue=True,
    ).then(
        fn=clear_textbox,
        inputs=None,
        outputs=t2t_text_in,
    )

    t2t_text_in.submit(
        fn=t2t_pipeline,
        inputs=[t2t_text_in, t2t_chatbot],
        outputs=[t2t_chatbot],
        queue=True,
    ).then(
        fn=clear_textbox,
        inputs=None,
        outputs=t2t_text_in,
    )


demo.queue().launch()