# -*- coding: utf-8 -*-
"""
YOUR FOIA CHAT ASSISTANT - text-only chatbot (STT and TTS removed)
Drop this file into your Hugging Face Space (replace existing app.py) or run locally.

Notes:
- Dark UI via custom CSS (works even if Gradio theme API differs)
- Performance-focused: greedy generation, lower max_new_tokens, use_cache, no_grad, streaming
- Keeps bitsandbytes / 4-bit logic intact when available
"""

import os
import threading
import gradio as gr
import importlib
import torch

from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    TextIteratorStreamer,
)
from peft import PeftModel, PeftConfig

# -------------------- Configuration --------------------
ADAPTER_REPO_ID = "EYEDOL/FOIA"  # adapter-only repo
BASE_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"    # full base model referenced by adapter

HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("Successfully logged into Hugging Face Hub!")
    except Exception as e:
        print("Warning: huggingface_hub.login() failed:", e)
else:
    print("Warning: HF_TOKEN not found in env. Private repos may fail to load.")


def is_package_installed(name: str) -> bool:
    """Return True if installed (distribution metadata present)."""
    try:
        import importlib.metadata as md
        try:
            md.distribution(name)
            return True
        except Exception:
            return False
    except Exception:
        try:
            importlib.import_module(name)
            return True
        except Exception:
            return False


class WeeboAssistant:
    def __init__(self):
        # system prompt instructs the assistant to answer concisely in English
        self.SYSTEM_PROMPT = (
            "You are an intelligent assistant. Answer questions briefly and accurately. "
            "Respond only in English. No long answers.\n"
        )
        # generation defaults tuned for speed (adjust if you need different behavior)
        self.MAX_NEW_TOKENS = 256    # lowered from 512 for speed
        self.DO_SAMPLE = False       # greedy = faster; set True if you want sampling
        self.NUM_BEAMS = 1           # keep 1 for greedy (increase >1 for beam search)
        self._init_models()

    def _init_models(self):
        print("Initializing models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        print(f"Using device: {self.device}, torch_dtype: {self.torch_dtype}")

        BNB_AVAILABLE = is_package_installed("bitsandbytes")
        print("bitsandbytes available:", BNB_AVAILABLE)

        # load tokenizer (prefer base tokenizer)
        try:
            self.llm_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
            print("Loaded tokenizer from BASE_MODEL_ID")
        except Exception as e:
            print("Warning: could not load base tokenizer, falling back to adapter tokenizer. Error:", e)
            self.llm_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)
            print("Loaded tokenizer from ADAPTER_REPO_ID")

        # ensure tokenizer has pad_token_id to avoid generation stalls
        if getattr(self.llm_tokenizer, "pad_token_id", None) is None:
            if getattr(self.llm_tokenizer, "eos_token_id", None) is not None:
                self.llm_tokenizer.pad_token_id = self.llm_tokenizer.eos_token_id
            else:
                # fallback to 0 to prevent crashes (not ideal but safe)
                self.llm_tokenizer.pad_token_id = 0

        # decide device_map (never pass None)
        if torch.cuda.is_available():
            device_map = "auto"
        else:
            device_map = {"": "cpu"}
        print("device_map being used for model load:", device_map)

        base_model_kwargs = dict(
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            device_map=device_map,
            trust_remote_code=True,
        )

        if BNB_AVAILABLE and torch.cuda.is_available():
            base_model_kwargs["load_in_4bit"] = True
            print("Will attempt to load base model in 4-bit (bitsandbytes + CUDA detected).")
        else:
            print("bitsandbytes not usable or no CUDA: loading model normally (no 4-bit).")

        try:
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_ID,
                **base_model_kwargs,
            )
            # ensure use_cache set for faster autoregressive generation
            try:
                self.llm_model.config.use_cache = True
            except Exception:
                pass
            print("Base model loaded from", BASE_MODEL_ID)
        except Exception as e:
            raise RuntimeError(
                "Failed to load base model. Ensure the base model ID is correct and HF_TOKEN has access if private. Error: "
                + str(e)
            )

        # load and apply PEFT adapter
        try:
            try:
                peft_config = PeftConfig.from_pretrained(ADAPTER_REPO_ID)
                print("Loaded PEFT config from", ADAPTER_REPO_ID)
            except Exception:
                peft_config = None
                print("Warning: could not load PeftConfig; continuing to attempt adapter load.")

            peft_kwargs = dict(
                device_map=device_map,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
            )

            self.llm_model = PeftModel.from_pretrained(
                self.llm_model,
                ADAPTER_REPO_ID,
                **peft_kwargs,
            )
            # ensure adapter-wrapped model also has use_cache
            try:
                self.llm_model.config.use_cache = True
            except Exception:
                pass
            print("PEFT adapter applied from", ADAPTER_REPO_ID)
        except Exception as e:
            raise RuntimeError(
                "Failed to load/apply PEFT adapter from adapter repo. Make sure adapter files are present and HF_TOKEN has access if private. Error: "
                + str(e)
            )

        # optional non-streaming pipeline (useful for quick tests)
        try:
            device_index = 0 if torch.cuda.is_available() else -1
            self.llm_pipeline = pipeline(
                "text-generation",
                model=self.llm_model,
                tokenizer=self.llm_tokenizer,
                device=device_index,
                model_kwargs={"torch_dtype": self.torch_dtype},
            )
            print("Created text-generation pipeline (non-streaming).")
        except Exception as e:
            print("Warning: could not create text-generation pipeline. Streaming generate will still work. Error:", e)
            self.llm_pipeline = None

        print("LLM base + adapter loaded successfully.")

    def get_llm_response(self, chat_history):
        # Build prompt (system + conversation)
        prompt_lines = [self.SYSTEM_PROMPT]
        for user_msg, assistant_msg in chat_history:
            if user_msg:
                prompt_lines.append("User: " + user_msg)
            if assistant_msg:
                prompt_lines.append("Assistant: " + assistant_msg)
        prompt_lines.append("Assistant: ")
        prompt = "\n".join(prompt_lines)

        # Tokenize inputs
        inputs = self.llm_tokenizer(prompt, return_tensors="pt", padding=False)
        try:
            model_device = next(self.llm_model.parameters()).device
        except StopIteration:
            model_device = torch.device("cpu")
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        # Use TextIteratorStreamer for streaming outputs to Gradio
        streamer = TextIteratorStreamer(self.llm_tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Generation kwargs tuned for speed. max_new_tokens alone bounds the output
        # (passing max_length as well is redundant and triggers a transformers
        # warning), and early_stopping only applies to beam search, so neither is set.
        generation_kwargs = dict(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask", None),
            max_new_tokens=self.MAX_NEW_TOKENS,
            do_sample=self.DO_SAMPLE,           # greedy if False -> faster
            num_beams=self.NUM_BEAMS,           # keep 1 for speed
            streamer=streamer,
            eos_token_id=getattr(self.llm_tokenizer, "eos_token_id", None),
            pad_token_id=getattr(self.llm_tokenizer, "pad_token_id", None),
            use_cache=True,
        )

        # Run generate under no_grad to save memory and time
        def _generate_thread():
            with torch.no_grad():
                try:
                    self.llm_model.generate(**generation_kwargs)
                except Exception as e:
                    print("Generation error:", e)

        gen_thread = threading.Thread(target=_generate_thread, daemon=True)
        gen_thread.start()

        return streamer


# create assistant instance (loads model once at startup)
assistant = WeeboAssistant()


# -------------------- Gradio pipeline functions --------------------
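# t2t_pipeline is a generator: each `yield` sends the partially-completed chat
# history back to Gradio, so the assistant reply streams into the Chatbot as the
# model generates tokens.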
def t2t_pipeline(text_input, chat_history):
    chat_history = chat_history or []
    chat_history.append((text_input, ""))  # placeholder for assistant reply
    yield chat_history

    response_stream = assistant.get_llm_response(chat_history)
    llm_response_text = ""
    for text_chunk in response_stream:
        llm_response_text += text_chunk
        chat_history[-1] = (text_input, llm_response_text)
        yield chat_history


def clear_textbox():
    # Returning an empty string clears the textbox and works across Gradio
    # versions (gr.Textbox.update was removed in Gradio 4.x).
    return ""


# -------------------- MODIFIED: Modern Dark UI CSS --------------------
MODERN_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;600;700&display=swap');

:root {
    --body-bg: linear-gradient(135deg, #10141a 0%, #06090f 100%);
    --chat-bg: #0b0f19;
    --border-color: rgba(255, 255, 255, 0.08);
    --text-color: #E6EEF8;
    --input-bg: #131926;
    --user-msg-bg: #1B2336;
    --bot-msg-bg: #0F1522;
    --primary-color: #0084ff;
    --primary-hover: #006fdb;
    --font-family: 'Poppins', sans-serif;
}

body, .gradio-container {
    background: var(--body-bg) !important;
    color: var(--text-color) !important;
    font-family: var(--font-family) !important;
}

.gradio-container * {
    font-family: var(--font-family) !important;
}

h1, h2, h3, .markdown {
    color: var(--text-color) !important;
}

.gr-block, .gr-box, .gr-row, .gr-column {
    background: transparent !important;
    border: none !important;
    box-shadow: none !important;
}

.gr-chatbot {
    background: var(--chat-bg) !important;
    border: 1px solid var(--border-color) !important;
    border-radius: 12px !important;
    box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2) !important;
}

.gr-chatbot .message {
    border-radius: 8px !important;
    padding: 12px !important;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1) !important;
    border: none !important;
}

.gr-chatbot .message.user {
    background: var(--user-msg-bg) !important;
    color: var(--text-color) !important;
}
.gr-chatbot .message.bot {
    background: var(--bot-msg-bg) !important;
    color: var(--text-color) !important;
}

.gr-chatbot .message p { margin: 0; }

.gr-textbox, .gr-textbox textarea {
    background: var(--input-bg) !important;
    color: var(--text-color) !important;
    border: 1px solid var(--border-color) !important;
    border-radius: 8px !important;
    transition: all 0.2s ease-in-out;
}
.gr-textbox:focus, .gr-textbox textarea:focus {
    border-color: var(--primary-color) !important;
    box-shadow: 0 0 0 2px rgba(0, 132, 255, 0.3) !important;
}

.gr-button {
    background: var(--primary-color) !important;
    color: white !important;
    border: none !important;
    border-radius: 8px !important;
    box-shadow: 0 4px 12px rgba(0, 132, 255, 0.2) !important;
    transition: all 0.2s ease-in-out !important;
    font-weight: 500 !important;
    display: flex;
    justify-content: center;
    align-items: center;
    gap: 8px; /* Space between icon and text */
}

.gr-button:hover {
    background: var(--primary-hover) !important;
    transform: translateY(-2px);
    box-shadow: 0 6px 16px rgba(0, 132, 255, 0.3) !important;
}
/* Keep the button label legible next to the injected icon */
.send-btn span {
    font-size: 1rem;
}
/* Add a send icon to the button */
.send-btn::before {
    content: '';
    display: inline-block;
    width: 20px;
    height: 20px;
    background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='white'%3E%3Cpath d='M2.01 21L23 12 2.01 3 2 10l15 2-15 2z'/%3E%3C/svg%3E");
    background-size: contain;
    background-repeat: no-repeat;
    background-position: center;
}

footer, .footer {
    display: none !important;
}
"""

# -------------------- MODIFIED: Gradio UI with Logo --------------------
with gr.Blocks(css=MODERN_CSS, title="DimChi FOIA Assistant") as demo:
    # NEW: Centered header with logo
    with gr.Row():
        gr.Markdown(
            """
            <div style="text-align: center; display: flex; flex-direction: column; align-items: center; justify-content: center; padding: 20px;">
                <img src="file/logo.png" alt="DimChi Logo" style="max-width: 120px; margin-bottom: 15px;">
                <h1 style="margin: 0; font-size: 2.5rem; font-weight: 700;">DimChi FOIA Assistant</h1>
                <p style="margin: 5px 0 0 0; font-size: 1.1rem; color: #a0b0c0;">Your intelligent chat partner for FOIA inquiries.</p>
            </div>
            """
        )

    t2t_chatbot = gr.Chatbot(label="Conversation", bubble_full_width=False, height=520)

    # NEW: Added elem_classes for specific button styling
    with gr.Row():
        t2t_text_in = gr.Textbox(
            show_label=False,
            placeholder="Type your message here...",
            scale=4,
            container=False
        )
        t2t_submit_btn = gr.Button(
            "Send",
            variant="primary",
            scale=1,
            elem_classes="send-btn" # NEW: Class for CSS targeting
        )

    t2t_submit_btn.click(
        fn=t2t_pipeline,
        inputs=[t2t_text_in, t2t_chatbot],
        outputs=[t2t_chatbot],
        queue=True,
    ).then(
        fn=clear_textbox,
        inputs=None,
        outputs=t2t_text_in,
    )

    t2t_text_in.submit(
        fn=t2t_pipeline,
        inputs=[t2t_text_in, t2t_chatbot],
        outputs=[t2t_chatbot],
        queue=True,
    ).then(
        fn=clear_textbox,
        inputs=None,
        outputs=t2t_text_in,
    )

# launch
# MODIFIED: Removed debug=True for a cleaner console in production
demo.queue().launch()
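
# Local run note (assumption): executing `python app.py` should start the app and
# print a local URL (typically http://127.0.0.1:7860); on a Space this file is
# executed automatically, so launch() runs on startup.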