EYEDOL committed · verified
Commit 813792b · 1 Parent(s): 485e894

Update app.py

Files changed (1):
  app.py  (+92 -30)
app.py CHANGED
@@ -1,14 +1,12 @@
 # -*- coding: utf-8 -*-
 """
-Refactored Salama Assistant: text-only chatbot (STT and TTS removed)
+YOUR FOIA CHAT ASSISTANCE - Text-only chatbot (STT and TTS removed)
 Drop this file into your Hugging Face Space (replace existing app.py) or run locally.
 
-Performance-focused tweaks:
-- lower max_new_tokens
-- use greedy decoding (do_sample=False) for speed
-- call generate() under torch.no_grad()
-- set model.config.use_cache = True
-- other minor safe optimizations
+Notes:
+- Dark UI via custom CSS (works even if Gradio theme API differs)
+- Performance-focused: greedy generation, lower max_new_tokens, use_cache, no_grad, streaming
+- Keeps bitsandbytes / 4-bit logic intact when available
 """
 
 import os
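For reference, the decoding setup described in the new docstring (greedy decoding, use_cache, torch.no_grad, a capped max_new_tokens budget) corresponds to plain transformers usage along these lines. This is a minimal sketch only; "gpt2" is a stand-in model id, not the Space's actual base model.

# Sketch only: greedy, cache-enabled generation under no_grad ("gpt2" is a placeholder model id).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.config.use_cache = True   # reuse past key/values between decoding steps
model.eval()

inputs = tokenizer("FOIA stands for", return_tensors="pt")
with torch.no_grad():           # inference only, no gradients
    output_ids = model.generate(
        **inputs,
        max_new_tokens=32,      # small budget keeps latency low
        do_sample=False,        # greedy decoding
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))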
@@ -61,14 +59,15 @@ def is_package_installed(name: str) -> bool:
 
 class WeeboAssistant:
     def __init__(self):
+        # system prompt instructs the assistant to answer concisely in English
         self.SYSTEM_PROMPT = (
             "You are an intelligent assistant. Answer questions briefly and accurately. "
             "Respond only in English. No long answers.\n"
         )
-        # set sensible defaults for generation speed
+        # generation defaults tuned for speed (adjust if you need different behavior)
         self.MAX_NEW_TOKENS = 256  # lowered from 512 for speed
-        self.DO_SAMPLE = False  # greedy = faster; set True if you need randomness
-        self.NUM_BEAMS = 1  # keep 1 for greedy; increase for beam search (slower)
+        self.DO_SAMPLE = False  # greedy = faster; set True if you want sampling
+        self.NUM_BEAMS = 1  # keep 1 for greedy (increase >1 for beam search)
         self._init_models()
 
     def _init_models(self):
@@ -80,6 +79,7 @@ class WeeboAssistant:
         BNB_AVAILABLE = is_package_installed("bitsandbytes")
         print("bitsandbytes available:", BNB_AVAILABLE)
 
+        # load tokenizer (prefer base tokenizer)
         try:
             self.llm_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
             print("Loaded tokenizer from BASE_MODEL_ID")
@@ -88,15 +88,15 @@ class WeeboAssistant:
             self.llm_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)
             print("Loaded tokenizer from ADAPTER_REPO_ID")
 
-        # ensure tokenizer has pad_token_id (some HF models lack it)
+        # ensure tokenizer has pad_token_id to avoid generation stalls
         if getattr(self.llm_tokenizer, "pad_token_id", None) is None:
-            # try to set eos_token_id as pad if pad missing
            if getattr(self.llm_tokenizer, "eos_token_id", None) is not None:
                self.llm_tokenizer.pad_token_id = self.llm_tokenizer.eos_token_id
            else:
-                # fallback to 0 (not ideal but prevents crashes)
+                # fallback to 0 to prevent crashes (not ideal but safe)
                self.llm_tokenizer.pad_token_id = 0
 
+        # decide device_map (never pass None)
         if torch.cuda.is_available():
             device_map = "auto"
         else:
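The pad-token fallback in this hunk exists because generate() needs a valid pad_token_id and some tokenizers ship without one. The same logic in isolation, as a sketch; "gpt2" is only a placeholder tokenizer id:

# Sketch of the pad_token_id fallback shown above ("gpt2" is a placeholder).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token_id is None:
    if tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id  # common, safe default
    else:
        tokenizer.pad_token_id = 0  # last resort; avoids crashes but is not a real pad token
print("pad_token_id:", tokenizer.pad_token_id)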
@@ -121,7 +121,7 @@ class WeeboAssistant:
                 BASE_MODEL_ID,
                 **base_model_kwargs,
             )
-            # make sure use_cache is enabled for faster autoregressive generation
+            # ensure use_cache set for faster autoregressive generation
             try:
                 self.llm_model.config.use_cache = True
             except Exception:
@@ -133,6 +133,7 @@ class WeeboAssistant:
                 + str(e)
             )
 
+        # load and apply PEFT adapter
         try:
             try:
                 peft_config = PeftConfig.from_pretrained(ADAPTER_REPO_ID)
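For context, attaching a LoRA/PEFT adapter to a base causal LM usually follows the pattern below. A minimal sketch; the repo ids are placeholders, not the Space's BASE_MODEL_ID / ADAPTER_REPO_ID:

# Sketch only: wrap a base model with a PEFT adapter (placeholder repo ids).
from transformers import AutoModelForCausalLM
from peft import PeftConfig, PeftModel

ADAPTER_ID = "some-user/some-lora-adapter"           # placeholder
peft_config = PeftConfig.from_pretrained(ADAPTER_ID)
base = AutoModelForCausalLM.from_pretrained(peft_config.base_model_name_or_path)
model = PeftModel.from_pretrained(base, ADAPTER_ID)  # exposes generate() like the base model
model.eval()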
@@ -164,6 +165,7 @@ class WeeboAssistant:
                 + str(e)
             )
 
+        # optional non-streaming pipeline (useful for quick tests)
         try:
             device_index = 0 if torch.cuda.is_available() else -1
             self.llm_pipeline = pipeline(
@@ -181,6 +183,7 @@ class WeeboAssistant:
         print("LLM base + adapter loaded successfully.")
 
     def get_llm_response(self, chat_history):
+        # Build prompt (system + conversation)
         prompt_lines = [self.SYSTEM_PROMPT]
         for user_msg, assistant_msg in chat_history:
             if user_msg:
@@ -190,7 +193,7 @@ class WeeboAssistant:
         prompt_lines.append("Assistant: ")
         prompt = "\n".join(prompt_lines)
 
-        # Tokenize
+        # Tokenize inputs
         inputs = self.llm_tokenizer(prompt, return_tensors="pt", padding=False)
         try:
             model_device = next(self.llm_model.parameters()).device
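The prompt builder above prepends the system prompt and closes with an open "Assistant: " line. The per-turn lines appended inside the loop fall outside the diff context, so the "User:"/"Assistant:" formatting in this standalone sketch is an assumption:

# Hypothetical reconstruction of the prompt format; the exact per-turn lines are not shown in the diff.
SYSTEM_PROMPT = "You are an intelligent assistant. Answer questions briefly and accurately.\n"

def build_prompt(chat_history):
    lines = [SYSTEM_PROMPT]
    for user_msg, assistant_msg in chat_history:
        if user_msg:
            lines.append(f"User: {user_msg}")            # assumed format
        if assistant_msg:
            lines.append(f"Assistant: {assistant_msg}")  # assumed format
    lines.append("Assistant: ")
    return "\n".join(lines)

print(build_prompt([("What does FOIA stand for?", "")]))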
@@ -198,10 +201,10 @@ class WeeboAssistant:
             model_device = torch.device("cpu")
         inputs = {k: v.to(model_device) for k, v in inputs.items()}
 
-        # Streamer unchanged (still yields chunks)
+        # Use TextIteratorStreamer for streaming outputs to Gradio
         streamer = TextIteratorStreamer(self.llm_tokenizer, skip_prompt=True, skip_special_tokens=True)
 
-        # Prefill some generation kwargs optimized for speed
+        # Prefill generation kwargs optimized for speed
         input_len = inputs["input_ids"].shape[1]
         max_new = self.MAX_NEW_TOKENS
         max_length = input_len + max_new
@@ -209,10 +212,10 @@ class WeeboAssistant:
         generation_kwargs = dict(
             input_ids=inputs["input_ids"],
             attention_mask=inputs.get("attention_mask", None),
-            max_length=max_length,  # prefer max_length = input_len + max_new_tokens
-            max_new_tokens=max_new,  # kept for clarity / compatibility
+            max_length=max_length,  # input_len + max_new
+            max_new_tokens=max_new,  # explicit
             do_sample=self.DO_SAMPLE,  # greedy if False -> faster
-            num_beams=self.NUM_BEAMS,  # beam search >1 slows down; keep 1 for speed
+            num_beams=self.NUM_BEAMS,  # keep 1 for speed
             streamer=streamer,
             eos_token_id=getattr(self.llm_tokenizer, "eos_token_id", None),
             pad_token_id=getattr(self.llm_tokenizer, "pad_token_id", None),
@@ -220,15 +223,12 @@ class WeeboAssistant:
             early_stopping=True,
         )
 
-        # Run generate under no_grad for speed / memory
+        # Run generate under no_grad to save memory and time
        def _generate_thread():
            with torch.no_grad():
                try:
-                    # call generate on model (PEFT-wrapped)
                    self.llm_model.generate(**generation_kwargs)
                except Exception as e:
-                    # if streaming fails, put an error chunk into streamer by raising
-                    # streamer does not provide a direct API to inject text; print to log
                    print("Generation error:", e)
 
        gen_thread = threading.Thread(target=_generate_thread, daemon=True)
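This is the standard threaded-streaming pattern: generate() runs on a daemon thread while the TextIteratorStreamer is consumed as an iterator of decoded text chunks (and note that max_length here is simply input_len + max_new_tokens, so both kwargs express the same budget). A minimal sketch; "gpt2" is a stand-in model id:

# Sketch of threaded streaming generation ("gpt2" is a placeholder model id).
import threading
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("User: What is FOIA?\nAssistant: ", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

def _generate():
    with torch.no_grad():
        model.generate(**inputs, streamer=streamer, max_new_tokens=32,
                       do_sample=False, pad_token_id=tokenizer.eos_token_id)

threading.Thread(target=_generate, daemon=True).start()

reply = ""
for chunk in streamer:      # yields decoded text pieces as they are generated
    reply += chunk
print(reply)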
@@ -237,12 +237,14 @@ class WeeboAssistant:
         return streamer
 
 
+# create assistant instance (loads model once at startup)
 assistant = WeeboAssistant()
 
 
+# -------------------- Gradio pipeline functions --------------------
 def t2t_pipeline(text_input, chat_history):
     chat_history = chat_history or []
-    chat_history.append((text_input, ""))
+    chat_history.append((text_input, ""))  # placeholder for assistant reply
     yield chat_history
 
     response_stream = assistant.get_llm_response(chat_history)
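t2t_pipeline is a generator: each yield re-renders the Chatbot, which is what produces the incremental typing effect. A self-contained sketch of the same pattern with a fake streamer in place of the model, assuming the tuple-style chat history used elsewhere in this file (Gradio 3.x-style API):

# Sketch of the streaming Gradio callback pattern; fake_stream stands in for the LLM streamer.
import time
import gradio as gr

def fake_stream():
    for piece in ("This ", "is ", "a ", "streamed ", "reply."):
        time.sleep(0.1)
        yield piece

def chat_fn(message, history):
    history = history or []
    history.append((message, ""))                      # placeholder for the assistant reply
    yield history
    for chunk in fake_stream():
        history[-1] = (message, history[-1][1] + chunk)
        yield history                                  # each yield updates the Chatbot

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    box = gr.Textbox()
    box.submit(chat_fn, inputs=[box, chatbot], outputs=chatbot)

demo.queue().launch()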
@@ -257,12 +259,71 @@ def clear_textbox():
     return gr.Textbox.update(value="")
 
 
-# -------------------- English UI --------------------
-with gr.Blocks(theme=gr.themes.Soft(), title="Swahili Assistant - Text Chat") as demo:
-    gr.Markdown("# 🤖 Swahili Assistant (Text Chat)")
-    gr.Markdown("Chat (text-based) with the assistant in English. Use the box below to type your question.")
+# -------------------- Dark UI CSS --------------------
+DARK_CSS = """
+/* Base background & text */
+body, .gradio-container {
+    background: linear-gradient(180deg, #04060a 0%, #0b1220 100%) !important;
+    color: #E6EEF8 !important;
+}
+
+/* Header / Markdown text */
+h1, h2, h3, .markdown {
+    color: #E6EEF8 !important;
+}
+
+/* Card backgrounds */
+.gr-block, .gr-box, .gr-row, .gr-column, .gradio-container .container {
+    background-color: transparent !important;
+}
+
+/* Chatbot area */
+.gr-chatbot {
+    background: rgba(10, 14, 22, 0.6) !important;
+    border: 1px solid rgba(255,255,255,0.04) !important;
+    color: #E6EEF8 !important;
+}
+
+/* Chat messages - user and assistant bubbles */
+.gr-chatbot .message.user, .gr-chatbot .message.user p {
+    background: linear-gradient(180deg, #0f1724, #0b1220) !important;
+    color: #CFE7FF !important;
+    border: 1px solid rgba(255,255,255,0.04) !important;
+}
+.gr-chatbot .message.bot, .gr-chatbot .message.bot p {
+    background: linear-gradient(180deg, #071126, #081426) !important;
+    color: #E6EEF8 !important;
+    border: 1px solid rgba(255,255,255,0.03) !important;
+}
+
+/* Input textbox and button */
+.gr-textbox, .gr-textbox textarea {
+    background: #071226 !important;
+    color: #E6EEF8 !important;
+    border: 1px solid rgba(255,255,255,0.04) !important;
+}
+.gr-button, .gr-button:hover {
+    background: linear-gradient(180deg, #0b63ff, #0a4ad6) !important;
+    color: white !important;
+    border: none !important;
+    box-shadow: 0 6px 18px rgba(6, 18, 55, 0.5) !important;
+}
+
+/* Small UI tweaks */
+footer, .footer {
+    display: none;
+}
+.gradio-container * {
+    font-family: Inter, ui-sans-serif, system-ui, -apple-system, "Segoe UI", Roboto, "Helvetica Neue", Arial;
+}
+"""
+
+# -------------------- Gradio UI (dark) --------------------
+with gr.Blocks(css=DARK_CSS, title="YOUR FOIA CHAT ASSISTANCE") as demo:
+    gr.Markdown("# YOUR FOIA CHAT ASSISTANCE")
+    gr.Markdown("Chat (text-based) with the FOIA assistant. Use the box below to type your question.")
 
-    t2t_chatbot = gr.Chatbot(label="Conversation", bubble_full_width=False, height=500)
+    t2t_chatbot = gr.Chatbot(label="Conversation", bubble_full_width=False, height=520)
     with gr.Row():
         t2t_text_in = gr.Textbox(show_label=False, placeholder="Type your message here...", scale=4, container=False)
         t2t_submit_btn = gr.Button("Send", variant="primary", scale=1)
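Passing css= to gr.Blocks injects the stylesheet into the page, which is how the dark theme above is applied. One caveat: internal class names such as .gr-chatbot or .gr-button are not a stable public API and change between Gradio versions, so these selectors may need adjusting. A minimal sketch:

# Sketch only: custom CSS on a Blocks app; selectors are illustrative and version-dependent.
import gradio as gr

CSS = """
body, .gradio-container { background: #0b1220 !important; color: #E6EEF8 !important; }
"""

with gr.Blocks(css=CSS, title="CSS demo") as demo:
    gr.Markdown("# Dark-styled demo")
    gr.Textbox(label="Type here")

demo.launch()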
@@ -289,4 +350,5 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Swahili Assistant - Text Chat") as demo:
         outputs=t2t_text_in,
     )
 
+# launch
 demo.queue().launch(debug=True)
 