OthnnyEL committed · Commit d3083ec · verified · 1 Parent(s): 24253fe

Update app.py

Files changed (1): app.py (+356 −631)
app.py CHANGED
@@ -1,706 +1,431 @@
  # -*- coding: utf-8 -*-
  """
- FOIA CHAT ASSISTANCE - Text-only chatbot (STT and TTS removed)
  Drop this file into your Hugging Face Space (replace existing app.py) or run locally.

  Notes:
  - Dark UI via custom CSS (works even if Gradio theme API differs)
  - Performance-focused: greedy generation, lower max_new_tokens, use_cache, no_grad, streaming
  - Keeps bitsandbytes / 4-bit logic intact when available
  """

  import os
  import threading
  import gradio as gr
  import importlib
  import importlib.util
  import torch

  from huggingface_hub import login
  from transformers import (
-     AutoTokenizer,
-     AutoModelForCausalLM,
-     pipeline,
-     TextIteratorStreamer,
  )
  from peft import PeftModel, PeftConfig

  # -------------------- Configuration --------------------
- ADAPTER_REPO_ID = "EYEDOL/FOIA"  # adapter-only repo
- BASE_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"    # full base model referenced by adapter

  HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")
  if HF_TOKEN:
-     try:
-         login(token=HF_TOKEN)
-         print("Successfully logged into Hugging Face Hub!")
-     except Exception as e:
-         print("Warning: huggingface_hub.login() failed:", e)
  else:
-     print("Warning: HF_TOKEN not found in env. Private repos may fail to load.")


  def is_package_installed(name: str) -> bool:
-     """Return True if installed (distribution metadata present)."""
-     try:
-         import importlib.metadata as md
-         try:
-             md.distribution(name)
-             return True
-         except Exception:
-             return False
-     except Exception:
-         try:
-             importlib.import_module(name)
-             return True
-         except Exception:
-             return False


  class WeeboAssistant:
-     def __init__(self):
-         # system prompt instructs the assistant to answer concisely in English
-         self.SYSTEM_PROMPT = (
-             "You are an intelligent assistant. Answer questions briefly and accurately. "
-             "Respond only in English. No long answers.\n"
-         )
-         # generation defaults tuned for speed (adjust if you need different behavior)
-         self.MAX_NEW_TOKENS = 256   # lowered from 512 for speed
-         self.DO_SAMPLE = False      # greedy = faster; set True if you want sampling
-         self.NUM_BEAMS = 1          # keep 1 for greedy (increase >1 for beam search)
-         self._init_models()

-     def _init_models(self):
-         print("Initializing models...")
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-         self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
-         print(f"Using device: {self.device}, torch_dtype: {self.torch_dtype}")

-         BNB_AVAILABLE = is_package_installed("bitsandbytes")
-         print("bitsandbytes available:", BNB_AVAILABLE)

-         # load tokenizer (prefer base tokenizer)
-         try:
-             self.llm_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
-             print("Loaded tokenizer from BASE_MODEL_ID")
-         except Exception as e:
-             print("Warning: could not load base tokenizer, falling back to adapter tokenizer. Error:", e)
-             self.llm_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)
-             print("Loaded tokenizer from ADAPTER_REPO_ID")

-         # ensure tokenizer has pad_token_id to avoid generation stalls
-         if getattr(self.llm_tokenizer, "pad_token_id", None) is None:
-             if getattr(self.llm_tokenizer, "eos_token_id", None) is not None:
-                 self.llm_tokenizer.pad_token_id = self.llm_tokenizer.eos_token_id
-             else:
-                 # fallback to 0 to prevent crashes (not ideal but safe)
-                 self.llm_tokenizer.pad_token_id = 0

-         # decide device_map (never pass None)
-         if torch.cuda.is_available():
-             device_map = "auto"
-         else:
-             device_map = {"": "cpu"}
-         print("device_map being used for model load:", device_map)

-         base_model_kwargs = dict(
-             torch_dtype=self.torch_dtype,
-             low_cpu_mem_usage=True,
-             device_map=device_map,
-             trust_remote_code=True,
-         )

-         if BNB_AVAILABLE and torch.cuda.is_available():
-             base_model_kwargs["load_in_4bit"] = True
-             print("Will attempt to load base model in 4-bit (bitsandbytes + CUDA detected).")
-         else:
-             print("bitsandbytes not usable or no CUDA: loading model normally (no 4-bit).")

-         try:
-             self.llm_model = AutoModelForCausalLM.from_pretrained(
-                 BASE_MODEL_ID,
-                 **base_model_kwargs,
-             )
-             # ensure use_cache set for faster autoregressive generation
-             try:
-                 self.llm_model.config.use_cache = True
-             except Exception:
-                 pass
-             print("Base model loaded from", BASE_MODEL_ID)
-         except Exception as e:
-             raise RuntimeError(
-                 "Failed to load base model. Ensure the base model ID is correct and HF_TOKEN has access if private. Error: "
-                 + str(e)
-             )

-         # load and apply PEFT adapter
-         try:
-             try:
-                 peft_config = PeftConfig.from_pretrained(ADAPTER_REPO_ID)
-                 print("Loaded PEFT config from", ADAPTER_REPO_ID)
-             except Exception:
-                 peft_config = None
-                 print("Warning: could not load PeftConfig; continuing to attempt adapter load.")

-             peft_kwargs = dict(
-                 device_map=device_map,
-                 torch_dtype=self.torch_dtype,
-                 low_cpu_mem_usage=True,
-             )

-             self.llm_model = PeftModel.from_pretrained(
-                 self.llm_model,
-                 ADAPTER_REPO_ID,
-                 **peft_kwargs,
-             )
-             # ensure adapter-wrapped model also has use_cache
-             try:
-                 self.llm_model.config.use_cache = True
-             except Exception:
-                 pass
-             print("PEFT adapter applied from", ADAPTER_REPO_ID)
-         except Exception as e:
-             raise RuntimeError(
-                 "Failed to load/apply PEFT adapter from adapter repo. Make sure adapter files are present and HF_TOKEN has access if private. Error: "
-                 + str(e)
-             )

-         # optional non-streaming pipeline (useful for quick tests)
-         try:
-             device_index = 0 if torch.cuda.is_available() else -1
-             self.llm_pipeline = pipeline(
-                 "text-generation",
-                 model=self.llm_model,
-                 tokenizer=self.llm_tokenizer,
-                 device=device_index,
-                 model_kwargs={"torch_dtype": self.torch_dtype},
-             )
-             print("Created text-generation pipeline (non-streaming).")
-         except Exception as e:
-             print("Warning: could not create text-generation pipeline. Streaming generate will still work. Error:", e)
-             self.llm_pipeline = None

-         print("LLM base + adapter loaded successfully.")

-     def get_llm_response(self, chat_history):
-         # Build prompt (system + conversation)
-         prompt_lines = [self.SYSTEM_PROMPT]
-         for user_msg, assistant_msg in chat_history:
-             if user_msg:
-                 prompt_lines.append("User: " + user_msg)
-             if assistant_msg:
-                 prompt_lines.append("Assistant: " + assistant_msg)
-         prompt_lines.append("Assistant: ")
-         prompt = "\n".join(prompt_lines)

-         # Tokenize inputs
-         inputs = self.llm_tokenizer(prompt, return_tensors="pt", padding=False)
-         try:
-             model_device = next(self.llm_model.parameters()).device
-         except StopIteration:
-             model_device = torch.device("cpu")
-         inputs = {k: v.to(model_device) for k, v in inputs.items()}

-         # Use TextIteratorStreamer for streaming outputs to Gradio
-         streamer = TextIteratorStreamer(self.llm_tokenizer, skip_prompt=True, skip_special_tokens=True)

-         # Prefill generation kwargs optimized for speed
-         input_len = inputs["input_ids"].shape[1]
-         max_new = self.MAX_NEW_TOKENS
-         max_length = input_len + max_new

-         generation_kwargs = dict(
-             input_ids=inputs["input_ids"],
-             attention_mask=inputs.get("attention_mask", None),
-             max_length=max_length,               # input_len + max_new
-             max_new_tokens=max_new,              # explicit
-             do_sample=self.DO_SAMPLE,            # greedy if False -> faster
-             num_beams=self.NUM_BEAMS,            # keep 1 for speed
-             streamer=streamer,
-             eos_token_id=getattr(self.llm_tokenizer, "eos_token_id", None),
-             pad_token_id=getattr(self.llm_tokenizer, "pad_token_id", None),
-             use_cache=True,
-             early_stopping=True,
-         )

-         # Run generate under no_grad to save memory and time
-         def _generate_thread():
-             with torch.no_grad():
-                 try:
-                     self.llm_model.generate(**generation_kwargs)
-                 except Exception as e:
-                     print("Generation error:", e)

-         gen_thread = threading.Thread(target=_generate_thread, daemon=True)
-         gen_thread.start()

-         return streamer


  # create assistant instance (loads model once at startup)
  assistant = WeeboAssistant()


  # -------------------- Gradio pipeline functions --------------------
  def t2t_pipeline(text_input, chat_history):
-     chat_history = chat_history or []
-     chat_history.append((text_input, ""))  # placeholder for assistant reply
-     yield chat_history

-     response_stream = assistant.get_llm_response(chat_history)
-     llm_response_text = ""
-     for text_chunk in response_stream:
-         llm_response_text += text_chunk
-         chat_history[-1] = (text_input, llm_response_text)
-         yield chat_history


  def clear_textbox():
-     return gr.Textbox.update(value="")


- # -------------------- Dark UI CSS --------------------
- DARK_CSS = """
- /* Base background & text */
  body, .gradio-container {
-   background: linear-gradient(180deg, #04060a 0%, #0b1220 100%) !important;
-   color: #E6EEF8 !important;
  }

- /* Header / Markdown text */
  h1, h2, h3, .markdown {
-   color: #E6EEF8 !important;
  }

- /* Card backgrounds */
- .gr-block, .gr-box, .gr-row, .gr-column, .gradio-container .container {
-   background-color: transparent !important;
  }

- /* Chatbot area */
  .gr-chatbot {
-   background: rgba(10, 14, 22, 0.6) !important;
-   border: 1px solid rgba(255,255,255,0.04) !important;
-   color: #E6EEF8 !important;
  }

- /* Chat messages - user and assistant bubbles */
- .gr-chatbot .message.user, .gr-chatbot .message.user p {
-   background: linear-gradient(180deg, #0f1724, #0b1220) !important;
-   color: #CFE7FF !important;
-   border: 1px solid rgba(255,255,255,0.04) !important;
  }
- .gr-chatbot .message.bot, .gr-chatbot .message.bot p {
-   background: linear-gradient(180deg, #071126, #081426) !important;
-   color: #E6EEF8 !important;
-   border: 1px solid rgba(255,255,255,0.03) !important;
  }

- /* Input textbox and button */
  .gr-textbox, .gr-textbox textarea {
-   background: #071226 !important;
-   color: #E6EEF8 !important;
-   border: 1px solid rgba(255,255,255,0.04) !important;
  }
- .gr-button, .gr-button:hover {
-   background: linear-gradient(180deg, #0b63ff, #0a4ad6) !important;
-   color: white !important;
-   border: none !important;
-   box-shadow: 0 6px 18px rgba(6, 18, 55, 0.5) !important;
  }

- /* Small UI tweaks */
- footer, .footer {
-   display: none;
  }
- .gradio-container * {
-   font-family: Inter, ui-sans-serif, system-ui, -apple-system, "Segoe UI", Roboto, "Helvetica Neue", Arial;
  }
  """

- # -------------------- Gradio UI (dark) --------------------
- with gr.Blocks(css=DARK_CSS, title="YOUR FOIA CHAT ASSISTANCE") as demo:
-     gr.Markdown("# YOUR FOIA CHAT ASSISTANCE")
-     gr.Markdown("Chat (text-based) with the FOIA assistant. Use the box below to type your question.")

-     t2t_chatbot = gr.Chatbot(label="Conversation", bubble_full_width=False, height=520)
-     with gr.Row():
-         t2t_text_in = gr.Textbox(show_label=False, placeholder="Type your message here...", scale=4, container=False)
-         t2t_submit_btn = gr.Button("Send", variant="primary", scale=1)

-     t2t_submit_btn.click(
-         fn=t2t_pipeline,
-         inputs=[t2t_text_in, t2t_chatbot],
-         outputs=[t2t_chatbot],
-         queue=True,
-     ).then(
-         fn=clear_textbox,
-         inputs=None,
-         outputs=t2t_text_in,
-     )

-     t2t_text_in.submit(
-         fn=t2t_pipeline,
-         inputs=[t2t_text_in, t2t_chatbot],
-         outputs=[t2t_chatbot],
-         queue=True,
-     ).then(
-         fn=clear_textbox,
-         inputs=None,
-         outputs=t2t_text_in,
-     )

  # launch
- demo.queue().launch(debug=True)
 
  # -*- coding: utf-8 -*-
  """
+ YOUR FOIA CHAT ASSISTANCE - Text-only chatbot (STT and TTS removed)
  Drop this file into your Hugging Face Space (replace existing app.py) or run locally.

  Notes:
  - Dark UI via custom CSS (works even if Gradio theme API differs)
  - Performance-focused: greedy generation, lower max_new_tokens, use_cache, no_grad, streaming
  - Keeps bitsandbytes / 4-bit logic intact when available
  """

  import os
  import threading
  import gradio as gr
  import importlib
  import importlib.util
  import torch

  from huggingface_hub import login
  from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     pipeline,
+     TextIteratorStreamer,
  )
  from peft import PeftModel, PeftConfig

  # -------------------- Configuration --------------------
+ ADAPTER_REPO_ID = "EYEDOL/FOIA"  # adapter-only repo
+ BASE_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"    # full base model referenced by adapter

  HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")
  if HF_TOKEN:
+     try:
+         login(token=HF_TOKEN)
+         print("Successfully logged into Hugging Face Hub!")
+     except Exception as e:
+         print("Warning: huggingface_hub.login() failed:", e)
  else:
+     print("Warning: HF_TOKEN not found in env. Private repos may fail to load.")


  def is_package_installed(name: str) -> bool:
+     """Return True if installed (distribution metadata present)."""
+     try:
+         import importlib.metadata as md
+         try:
+             md.distribution(name)
+             return True
+         except Exception:
+             return False
+     except Exception:
+         try:
+             importlib.import_module(name)
+             return True
+         except Exception:
+             return False


  class WeeboAssistant:
+     def __init__(self):
+         # system prompt instructs the assistant to answer concisely in English
+         self.SYSTEM_PROMPT = (
+             "You are an intelligent assistant. Answer questions briefly and accurately. "
+             "Respond only in English. No long answers.\n"
+         )
+         # generation defaults tuned for speed (adjust if you need different behavior)
+         self.MAX_NEW_TOKENS = 256   # lowered from 512 for speed
+         self.DO_SAMPLE = False      # greedy = faster; set True if you want sampling
+         self.NUM_BEAMS = 1          # keep 1 for greedy (increase >1 for beam search)
+         self._init_models()
+
+     def _init_models(self):
+         print("Initializing models...")
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
+         print(f"Using device: {self.device}, torch_dtype: {self.torch_dtype}")
+
+         BNB_AVAILABLE = is_package_installed("bitsandbytes")
+         print("bitsandbytes available:", BNB_AVAILABLE)
+
+         # load tokenizer (prefer base tokenizer)
+         try:
+             self.llm_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
+             print("Loaded tokenizer from BASE_MODEL_ID")
+         except Exception as e:
+             print("Warning: could not load base tokenizer, falling back to adapter tokenizer. Error:", e)
+             self.llm_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)
+             print("Loaded tokenizer from ADAPTER_REPO_ID")
+
+         # ensure tokenizer has pad_token_id to avoid generation stalls
+         if getattr(self.llm_tokenizer, "pad_token_id", None) is None:
+             if getattr(self.llm_tokenizer, "eos_token_id", None) is not None:
+                 self.llm_tokenizer.pad_token_id = self.llm_tokenizer.eos_token_id
+             else:
+                 # fallback to 0 to prevent crashes (not ideal but safe)
+                 self.llm_tokenizer.pad_token_id = 0
+
+         # decide device_map (never pass None)
+         if torch.cuda.is_available():
+             device_map = "auto"
+         else:
+             device_map = {"": "cpu"}
+         print("device_map being used for model load:", device_map)
+
+         base_model_kwargs = dict(
+             torch_dtype=self.torch_dtype,
+             low_cpu_mem_usage=True,
+             device_map=device_map,
+             trust_remote_code=True,
+         )
+
+         if BNB_AVAILABLE and torch.cuda.is_available():
+             base_model_kwargs["load_in_4bit"] = True
+             print("Will attempt to load base model in 4-bit (bitsandbytes + CUDA detected).")
+         else:
+             print("bitsandbytes not usable or no CUDA: loading model normally (no 4-bit).")
+
+         try:
+             self.llm_model = AutoModelForCausalLM.from_pretrained(
+                 BASE_MODEL_ID,
+                 **base_model_kwargs,
+             )
+             # ensure use_cache set for faster autoregressive generation
+             try:
+                 self.llm_model.config.use_cache = True
+             except Exception:
+                 pass
+             print("Base model loaded from", BASE_MODEL_ID)
+         except Exception as e:
+             raise RuntimeError(
+                 "Failed to load base model. Ensure the base model ID is correct and HF_TOKEN has access if private. Error: "
+                 + str(e)
+             )
+
+         # load and apply PEFT adapter
+         try:
+             try:
+                 peft_config = PeftConfig.from_pretrained(ADAPTER_REPO_ID)
+                 print("Loaded PEFT config from", ADAPTER_REPO_ID)
+             except Exception:
+                 peft_config = None
+                 print("Warning: could not load PeftConfig; continuing to attempt adapter load.")
+
+             peft_kwargs = dict(
+                 device_map=device_map,
+                 torch_dtype=self.torch_dtype,
+                 low_cpu_mem_usage=True,
+             )
+
+             self.llm_model = PeftModel.from_pretrained(
+                 self.llm_model,
+                 ADAPTER_REPO_ID,
+                 **peft_kwargs,
+             )
+             # ensure adapter-wrapped model also has use_cache
+             try:
+                 self.llm_model.config.use_cache = True
+             except Exception:
+                 pass
+             print("PEFT adapter applied from", ADAPTER_REPO_ID)
+         except Exception as e:
+             raise RuntimeError(
+                 "Failed to load/apply PEFT adapter from adapter repo. Make sure adapter files are present and HF_TOKEN has access if private. Error: "
+                 + str(e)
+             )
+
+         # optional non-streaming pipeline (useful for quick tests)
+         try:
+             device_index = 0 if torch.cuda.is_available() else -1
+             self.llm_pipeline = pipeline(
+                 "text-generation",
+                 model=self.llm_model,
+                 tokenizer=self.llm_tokenizer,
+                 device=device_index,
+                 model_kwargs={"torch_dtype": self.torch_dtype},
+             )
+             print("Created text-generation pipeline (non-streaming).")
+         except Exception as e:
+             print("Warning: could not create text-generation pipeline. Streaming generate will still work. Error:", e)
+             self.llm_pipeline = None
+
+         print("LLM base + adapter loaded successfully.")
+
+     def get_llm_response(self, chat_history):
+         # Build prompt (system + conversation)
+         prompt_lines = [self.SYSTEM_PROMPT]
+         for user_msg, assistant_msg in chat_history:
+             if user_msg:
+                 prompt_lines.append("User: " + user_msg)
+             if assistant_msg:
+                 prompt_lines.append("Assistant: " + assistant_msg)
+         prompt_lines.append("Assistant: ")
+         prompt = "\n".join(prompt_lines)
+
+         # Tokenize inputs
+         inputs = self.llm_tokenizer(prompt, return_tensors="pt", padding=False)
+         try:
+             model_device = next(self.llm_model.parameters()).device
+         except StopIteration:
+             model_device = torch.device("cpu")
+         inputs = {k: v.to(model_device) for k, v in inputs.items()}
+
+         # Use TextIteratorStreamer for streaming outputs to Gradio
+         streamer = TextIteratorStreamer(self.llm_tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+         # Prefill generation kwargs optimized for speed
+         input_len = inputs["input_ids"].shape[1]
+         max_new = self.MAX_NEW_TOKENS
+         max_length = input_len + max_new
+
+         generation_kwargs = dict(
+             input_ids=inputs["input_ids"],
+             attention_mask=inputs.get("attention_mask", None),
+             max_length=max_length,               # input_len + max_new
+             max_new_tokens=max_new,              # explicit
+             do_sample=self.DO_SAMPLE,            # greedy if False -> faster
+             num_beams=self.NUM_BEAMS,            # keep 1 for speed
+             streamer=streamer,
+             eos_token_id=getattr(self.llm_tokenizer, "eos_token_id", None),
+             pad_token_id=getattr(self.llm_tokenizer, "pad_token_id", None),
+             use_cache=True,
+             early_stopping=True,
+         )
+
+         # Run generate under no_grad to save memory and time
+         def _generate_thread():
+             with torch.no_grad():
+                 try:
+                     self.llm_model.generate(**generation_kwargs)
+                 except Exception as e:
+                     print("Generation error:", e)
+
+         gen_thread = threading.Thread(target=_generate_thread, daemon=True)
+         gen_thread.start()
+
+         return streamer


  # create assistant instance (loads model once at startup)
  assistant = WeeboAssistant()


  # -------------------- Gradio pipeline functions --------------------
  def t2t_pipeline(text_input, chat_history):
+     chat_history = chat_history or []
+     chat_history.append((text_input, ""))  # placeholder for assistant reply
+     yield chat_history

+     response_stream = assistant.get_llm_response(chat_history)
+     llm_response_text = ""
+     for text_chunk in response_stream:
+         llm_response_text += text_chunk
+         chat_history[-1] = (text_input, llm_response_text)
+         yield chat_history


  def clear_textbox():
+     return gr.Textbox.update(value="")
+
+
+ # -------------------- MODIFIED: Modern Dark UI CSS --------------------
+ MODERN_CSS = """
+ @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;600;700&display=swap');
+
+ :root {
+   --body-bg: linear-gradient(135deg, #10141a 0%, #06090f 100%);
+   --chat-bg: #0b0f19;
+   --border-color: rgba(255, 255, 255, 0.08);
+   --text-color: #E6EEF8;
+   --input-bg: #131926;
+   --user-msg-bg: #1B2336;
+   --bot-msg-bg: #0F1522;
+   --primary-color: #0084ff;
+   --primary-hover: #006fdb;
+   --font-family: 'Poppins', sans-serif;
+ }

  body, .gradio-container {
+   background: var(--body-bg) !important;
+   color: var(--text-color) !important;
+   font-family: var(--font-family) !important;
  }

+ .gradio-container * {
+   font-family: var(--font-family) !important;
+ }

  h1, h2, h3, .markdown {
+   color: var(--text-color) !important;
  }

+ .gr-block, .gr-box, .gr-row, .gr-column {
+   background: transparent !important;
+   border: none !important;
+   box-shadow: none !important;
  }

  .gr-chatbot {
+   background: var(--chat-bg) !important;
+   border: 1px solid var(--border-color) !important;
+   border-radius: 12px !important;
+   box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2) !important;
  }

+ .gr-chatbot .message {
+   border-radius: 8px !important;
+   padding: 12px !important;
+   box-shadow: 0 2px 4px rgba(0,0,0,0.1) !important;
+   border: none !important;
  }

+ .gr-chatbot .message.user {
+   background: var(--user-msg-bg) !important;
+   color: var(--text-color) !important;
+ }
+ .gr-chatbot .message.bot {
+   background: var(--bot-msg-bg) !important;
+   color: var(--text-color) !important;
  }

+ .gr-chatbot .message p { margin: 0; }

  .gr-textbox, .gr-textbox textarea {
+   background: var(--input-bg) !important;
+   color: var(--text-color) !important;
+   border: 1px solid var(--border-color) !important;
+   border-radius: 8px !important;
+   transition: all 0.2s ease-in-out;
  }
+ .gr-textbox:focus, .gr-textbox textarea:focus {
+   border-color: var(--primary-color) !important;
+   box-shadow: 0 0 0 2px rgba(0, 132, 255, 0.3) !important;
  }

+ .gr-button {
+   background: var(--primary-color) !important;
+   color: white !important;
+   border: none !important;
+   border-radius: 8px !important;
+   box-shadow: 0 4px 12px rgba(0, 132, 255, 0.2) !important;
+   transition: all 0.2s ease-in-out !important;
+   font-weight: 500 !important;
+   display: flex;
+   justify-content: center;
+   align-items: center;
+   gap: 8px; /* Space between icon and text */
  }

+ .gr-button:hover {
+   background: var(--primary-hover) !important;
+   transform: translateY(-2px);
+   box-shadow: 0 6px 16px rgba(0, 132, 255, 0.3) !important;
+ }
+ /* Hide default Gradio button text when we add our own */
+ .send-btn span {
+   font-size: 1rem;
+ }
+ /* Add a send icon to the button */
+ .send-btn::before {
+   content: '';
+   display: inline-block;
+   width: 20px;
+   height: 20px;
+   background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='white'%3E%3Cpath d='M2.01 21L23 12 2.01 3 2 10l15 2-15 2z'/%3E%3C/svg%3E");
+   background-size: contain;
+   background-repeat: no-repeat;
+   background-position: center;
  }

+ footer, .footer {
+   display: none !important;
+ }
  """

+ # -------------------- MODIFIED: Gradio UI with Logo --------------------
+ with gr.Blocks(css=MODERN_CSS, title="DimChi FOIA Assistant") as demo:
+     # NEW: Centered header with logo
+     with gr.Row():
+         gr.Markdown(
+             """
+             <div style="text-align: center; display: flex; flex-direction: column; align-items: center; justify-content: center; padding: 20px;">
+                 <img src="file/logo.png" alt="DimChi Logo" style="max-width: 120px; margin-bottom: 15px;">
+                 <h1 style="margin: 0; font-size: 2.5rem; font-weight: 700;">DimChi FOIA Assistant</h1>
+                 <p style="margin: 5px 0 0 0; font-size: 1.1rem; color: #a0b0c0;">Your intelligent chat partner for FOIA inquiries.</p>
+             </div>
+             """
+         )
+
+     t2t_chatbot = gr.Chatbot(label="Conversation", bubble_full_width=False, height=520)
+
+     # NEW: Added elem_classes for specific button styling
+     with gr.Row():
+         t2t_text_in = gr.Textbox(
+             show_label=False,
+             placeholder="Type your message here...",
+             scale=4,
+             container=False
+         )
+         t2t_submit_btn = gr.Button(
+             "Send",
+             variant="primary",
+             scale=1,
+             elem_classes="send-btn"  # NEW: Class for CSS targeting
+         )
+
+     t2t_submit_btn.click(
+         fn=t2t_pipeline,
+         inputs=[t2t_text_in, t2t_chatbot],
+         outputs=[t2t_chatbot],
+         queue=True,
+     ).then(
+         fn=clear_textbox,
+         inputs=None,
+         outputs=t2t_text_in,
+     )
+
+     t2t_text_in.submit(
+         fn=t2t_pipeline,
+         inputs=[t2t_text_in, t2t_chatbot],
+         outputs=[t2t_chatbot],
+         queue=True,
+     ).then(
+         fn=clear_textbox,
+         inputs=None,
+         outputs=t2t_text_in,
+     )

  # launch
+ # MODIFIED: Removed debug=True for a cleaner console in production
+ demo.queue().launch()