EYEDOL committed
Commit b300f26 · verified · 1 Parent(s): 830c848

Update app.py

Files changed (1)
app.py +172 -130
app.py CHANGED
@@ -1,42 +1,55 @@
  # app.py
- # Streamlit Chat UI with robust model + PEFT loading (English interface)
  # Requirements:
  # pip install streamlit torch transformers peft accelerate safetensors huggingface_hub

  import os
- import threading
  import time
  import streamlit as st
  import torch
- import importlib
  from huggingface_hub import login
- from transformers import (
-     AutoTokenizer,
-     AutoModelForCausalLM,
-     TextIteratorStreamer,
- )
  from peft import PeftModel, PeftConfig

- # -------------------- Configuration --------------------
- # Edit these to the model/adapter you want. Adapter repo can be adapter-only (PEFT).
  BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "unsloth/Llama-3.2-3B-Instruct-bnb-4bit")
- ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", None)  # set to adapter repo id or leave None
  HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")

  MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", 256))
  TEMPERATURE = float(os.environ.get("TEMP", 0.6))
  TOP_P = float(os.environ.get("TOP_P", 0.9))

- # -------------------- Helpers --------------------
  def is_package_installed(name: str) -> bool:
      """Return True if distribution metadata or importable."""
      try:
          import importlib.metadata as md
          try:
              md.distribution(name)
              return True
          except Exception:
-             return False
      except Exception:
          try:
              importlib.import_module(name)
@@ -46,191 +59,223 @@ def is_package_installed(name: str) -> bool:

  def try_login_hf(token: str):
      if not token:
-         st.info("HF_TOKEN not provided — private models may fail.")
          return
      try:
          login(token=token)
-         st.success("Logged into Hugging Face Hub")
      except Exception as e:
          st.warning(f"Hugging Face login failed: {e}")

- # -------------------- Streamlit Page --------------------
  st.set_page_config(page_title="AI Chatbot Assistant", page_icon="🤖", layout="wide")
  st.title("🤖 AI Chatbot Assistant")
- st.write("Type your message in English and get a response from the AI model. Keep messages short for better results.")

- # Sidebar for status/config
  with st.sidebar:
-     st.header("Model / Environment")
      st.text(f"BASE_MODEL_ID: {BASE_MODEL_ID}")
      st.text(f"ADAPTER_REPO_ID: {ADAPTER_REPO_ID or 'None'}")
-     st.text(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
      st.text(f"bitsandbytes installed: {is_package_installed('bitsandbytes')}")

- # Attempt HF login (for private repos)
  try_login_hf(HF_TOKEN)

- # -------------------- Model loader (cached) --------------------
  @st.cache_resource(show_spinner=False)
- def load_models():
-     """Loads tokenizer, base model, and optional adapter; returns (tokenizer, model, device)."""
      device = "cuda" if torch.cuda.is_available() else "cpu"
-     torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

      BNB_AVAILABLE = is_package_installed("bitsandbytes")
-     st.write(f"bitsandbytes available: {BNB_AVAILABLE}")
-
      # Load tokenizer (prefer base)
      try:
          tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
          st.write("Tokenizer loaded from base model.")
      except Exception as e:
-         st.write(f"Warning: failed to load tokenizer from base: {e}")
          if ADAPTER_REPO_ID:
              tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)
              st.write("Tokenizer loaded from adapter repo.")
          else:
-             raise RuntimeError("Failed to load tokenizer from base and no adapter set.")

-     # Prepare device_map (never None)
      if torch.cuda.is_available():
          device_map = "auto"
      else:
-         device_map = {"": "cpu"}  # force all weights on CPU to avoid NoneType iteration
      st.write(f"Using device_map = {device_map}")

-     # Build kwargs for from_pretrained
      base_kwargs = dict(
-         torch_dtype=torch_dtype,
          low_cpu_mem_usage=True,
          device_map=device_map,
          trust_remote_code=True,
      )

-     # Only request load_in_4bit if bitsandbytes present and CUDA available
-     if BNB_AVAILABLE and torch.cuda.is_available():
-         base_kwargs["load_in_4bit"] = True
-         st.write("Attempting to load base model in 4-bit (bitsandbytes + CUDA detected).")
-     else:
-         st.write("Not using 4-bit load (either no CUDA or bitsandbytes not available).")
-
-     # Load base model
      try:
          model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **base_kwargs)
-         st.write("Base model loaded.")
      except Exception as e:
-         raise RuntimeError(f"Failed to load base model {BASE_MODEL_ID}: {e}")

-     # If adapter specified, load PEFT
      if ADAPTER_REPO_ID:
          try:
-             # attempt to read peft config (optional)
              try:
                  _ = PeftConfig.from_pretrained(ADAPTER_REPO_ID)
                  st.write("PEFT config loaded from adapter repo.")
              except Exception:
-                 st.write("Warning: could not load PeftConfig (continuing to attempt adapter load).")
-
-             model = PeftModel.from_pretrained(
-                 model,
-                 ADAPTER_REPO_ID,
-                 device_map=device_map,
-                 torch_dtype=torch_dtype,
-                 low_cpu_mem_usage=True,
-             )
              st.write("PEFT adapter loaded and applied.")
          except Exception as e:
              raise RuntimeError(f"Failed to load/apply PEFT adapter from {ADAPTER_REPO_ID}: {e}")

      return tokenizer, model, device

- # Load models (blocking; shows spinner)
- with st.spinner("Loading model(s), this may take a minute..."):
      try:
-         tokenizer, model, device = load_models()
      except Exception as e:
-         st.error(f"Model loading failed: {e}")
          st.stop()

- # -------------------- Chat state --------------------
  if "chat_history" not in st.session_state:
-     # list of tuples (user, assistant)
-     st.session_state.chat_history = []

- # Input area
- user_input = st.text_area("Your message (English):", height=120, key="user_input")
- col1, col2 = st.columns([1, 1])
- with col1:
-     send_btn = st.button("Send")
- with col2:
-     clear_btn = st.button("Clear chat")
-
- # Chat display container
- chat_container = st.container()

- def stream_generate_and_stream_to_ui(prompt, tokenizer, model, max_new_tokens=MAX_NEW_TOKENS):
      """
-     Uses TextIteratorStreamer and a thread to stream tokens to the UI.
-     Returns the final generated string.
      """
-     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = dict(
-         input_ids=prompt["input_ids"].to(next(model.parameters()).device),
-         attention_mask=prompt.get("attention_mask", None),
          max_new_tokens=max_new_tokens,
          do_sample=True,
-         temperature=TEMPERATURE,
-         top_p=TOP_P,
-         streamer=streamer,
          eos_token_id=getattr(tokenizer, "eos_token_id", None),
      )

-     # start generation in background thread
-     gen_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs, daemon=True)
-     gen_thread.start()

-     # stream into UI
-     output_text = ""
-     placeholder = chat_container.empty()
-     # show current conversation and streaming answer
      while True:
-         try:
-             token = next(streamer)
-         except StopIteration:
              break
-         output_text += token
-         # Display chat history with the current streaming token appended
          with placeholder:
-             for user_msg, assistant_msg in st.session_state.chat_history[:-1]:
-                 st.markdown(f"**🧑 You:** {user_msg}")
-                 st.markdown(f"**🤖 Assistant:** {assistant_msg}")
-             # Current user (last) and streaming assistant
              last_user, _ = st.session_state.chat_history[-1]
              st.markdown(f"**🧑 You:** {last_user}")
-             st.markdown(f"**🤖 Assistant:** {output_text}")
-         # small sleep to allow UI update
-         time.sleep(0.01)
-
-     # finish: ensure final display
-     with chat_container:
-         for user_msg, assistant_msg in st.session_state.chat_history[:-1]:
-             st.markdown(f"**🧑 You:** {user_msg}")
-             st.markdown(f"**🤖 Assistant:** {assistant_msg}")
          last_user, _ = st.session_state.chat_history[-1]
          st.markdown(f"**🧑 You:** {last_user}")
-         st.markdown(f"**🤖 Assistant:** {output_text}")

-     return output_text

- # Handle Send
  if send_btn:
      if not user_input or not user_input.strip():
          st.warning("Please type a message before sending.")
      else:
-         # Add user message and placeholder assistant reply
-         st.session_state.chat_history.append((user_input.strip(), ""))

-         # Build prompt from history (system prompt + conversation)
          system_prompt = "You are a helpful assistant. Answer briefly and accurately in English."
          prompt_lines = [system_prompt]
          for u, a in st.session_state.chat_history:
@@ -241,36 +286,33 @@ if send_btn:
          prompt_lines.append("Assistant: ")
          final_prompt = "\n".join(prompt_lines)

-         # tokenize
          inputs = tokenizer(final_prompt, return_tensors="pt")
-         # move to model device
-         model_device = next(model.parameters()).device
-         inputs = {k: v.to(model_device) for k, v in inputs.items()}

-         # Stream generate and update UI
          try:
-             reply_text = stream_generate_and_stream_to_ui(inputs, tokenizer, model, max_new_tokens=MAX_NEW_TOKENS)
          except Exception as e:
              st.error(f"Generation failed: {e}")
              reply_text = "Error generating response."

-         # replace the last placeholder assistant reply
-         st.session_state.chat_history[-1] = (user_input.strip(), reply_text)
-         # clear input box
          st.session_state.user_input = ""

- # Handle Clear
  if clear_btn:
      st.session_state.chat_history = []
      st.experimental_rerun()

- # If there is chat history but user didn't just send (page load), display it
- if st.session_state.chat_history and not send_btn:
-     with chat_container:
-         for user_msg, assistant_msg in st.session_state.chat_history:
-             st.markdown(f"**🧑 You:** {user_msg}")
-             st.markdown(f"**🤖 Assistant:** {assistant_msg}")

- # Footer / tips
  st.markdown("---")
- st.caption("Tip: Keep prompts short. If model loading fails, check HF_TOKEN, CUDA availability and install bitsandbytes for 4-bit models.")
 
  # app.py
+ # Full Streamlit chat app (English interface)
+ # - Safe device_map handling (no device_map=None)
+ # - Uses queue-based streaming so Streamlit UI is only updated from main thread
+ # - Detects bitsandbytes and attempts 4-bit only when safe (CUDA + bitsandbytes)
+ # - Supports optional PEFT adapter repo (set ADAPTER_REPO_ID env)
+ # - Uses `dtype=` for from_pretrained where supported (silences deprecation)
+ #
  # Requirements:
  # pip install streamlit torch transformers peft accelerate safetensors huggingface_hub
+ # (if using 4-bit on GPU: pip install bitsandbytes matched to your CUDA)
+ #
+ # Run:
+ # streamlit run app.py --server.headless true --server.port 8501

  import os
  import time
+ import threading
+ import queue
+ import importlib
+ import importlib.util
  import streamlit as st
  import torch
  from huggingface_hub import login
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
  from peft import PeftModel, PeftConfig

+ # -------------------- Configuration (via env or defaults) --------------------
  BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "unsloth/Llama-3.2-3B-Instruct-bnb-4bit")
+ ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", None)  # e.g. "EYEDOL/FOIA" or None
  HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")

  MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", 256))
  TEMPERATURE = float(os.environ.get("TEMP", 0.6))
  TOP_P = float(os.environ.get("TOP_P", 0.9))

+ # -------------------- Utilities --------------------
  def is_package_installed(name: str) -> bool:
      """Return True if distribution metadata or importable."""
      try:
          import importlib.metadata as md
          try:
+             # this checks for distribution metadata (preferred)
              md.distribution(name)
              return True
          except Exception:
+             # fallback to plain import
+             try:
+                 importlib.import_module(name)
+                 return True
+             except Exception:
+                 return False
      except Exception:
          try:
              importlib.import_module(name)
 

  def try_login_hf(token: str):
      if not token:
+         st.info("HF_TOKEN not provided — private models may fail to load.")
          return
      try:
          login(token=token)
+         st.success("Logged into Hugging Face Hub.")
      except Exception as e:
          st.warning(f"Hugging Face login failed: {e}")

+ # -------------------- Streamlit UI Setup --------------------
  st.set_page_config(page_title="AI Chatbot Assistant", page_icon="🤖", layout="wide")
  st.title("🤖 AI Chatbot Assistant")
+ st.write("Type your message in English and get a brief, accurate response from the AI model.")

  with st.sidebar:
+     st.header("Settings & Environment")
+     st.write("Configure these via environment variables before starting the app.")
      st.text(f"BASE_MODEL_ID: {BASE_MODEL_ID}")
      st.text(f"ADAPTER_REPO_ID: {ADAPTER_REPO_ID or 'None'}")
+     st.text(f"Device available: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
      st.text(f"bitsandbytes installed: {is_package_installed('bitsandbytes')}")
+     st.markdown("---")
+     st.caption("Run with: streamlit run app.py")

+ # Attempt HF login for private repos
  try_login_hf(HF_TOKEN)

+ # -------------------- Model Loading (cached) --------------------
  @st.cache_resource(show_spinner=False)
+ def load_tokenizer_and_model():
+     """
+     Loads tokenizer and model (plus optional PEFT adapter).
+     Returns (tokenizer, model, device).
+     """
      device = "cuda" if torch.cuda.is_available() else "cpu"
+     dtype = torch.bfloat16 if device == "cuda" else torch.float32  # use dtype param where supported

      BNB_AVAILABLE = is_package_installed("bitsandbytes")
      # Load tokenizer (prefer base)
      try:
          tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
          st.write("Tokenizer loaded from base model.")
      except Exception as e:
+         st.warning(f"Failed to load tokenizer from base model: {e}")
          if ADAPTER_REPO_ID:
              tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)
              st.write("Tokenizer loaded from adapter repo.")
          else:
+             raise RuntimeError("Failed to load tokenizer and no adapter repo provided.")

+     # prepare device_map (never None)
      if torch.cuda.is_available():
          device_map = "auto"
      else:
+         device_map = {"": "cpu"}  # forces entire model to CPU (avoids NoneType iteration)
      st.write(f"Using device_map = {device_map}")

      base_kwargs = dict(
          low_cpu_mem_usage=True,
          device_map=device_map,
          trust_remote_code=True,
      )

+     # Use dtype param if supported: Transformers accepts dtype or torch_dtype depending on version.
+     # Try dtype first; if it fails, fall back to torch_dtype in the exception handler.
+     tried_dtype = False
      try:
+         base_kwargs["dtype"] = dtype
+         if BNB_AVAILABLE and torch.cuda.is_available():
+             base_kwargs["load_in_4bit"] = True
+             st.write("Attempting 4-bit load (bitsandbytes + CUDA detected).")
+         st.write("Loading base model (attempt using dtype)...")
          model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **base_kwargs)
+         tried_dtype = True
+     except TypeError as e:
+         # Older transformers versions might not accept the dtype kwarg; fall back to torch_dtype
+         st.write("dtype param not accepted by this transformers version; falling back to torch_dtype.")
+         base_kwargs.pop("dtype", None)
+         base_kwargs["torch_dtype"] = dtype
+         try:
+             if BNB_AVAILABLE and torch.cuda.is_available():
+                 base_kwargs["load_in_4bit"] = True
+                 st.write("Attempting 4-bit load (bitsandbytes + CUDA detected).")
+             st.write("Loading base model (fallback using torch_dtype)...")
+             model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **base_kwargs)
+         except Exception as e2:
+             raise RuntimeError(f"Failed to load base model (both dtype and torch_dtype attempts failed): {e2}")
      except Exception as e:
+         raise RuntimeError(f"Failed to load base model: {e}")
+
+     st.write("Base model loaded successfully.")

+     # Optional: apply PEFT adapter
      if ADAPTER_REPO_ID:
          try:
              try:
                  _ = PeftConfig.from_pretrained(ADAPTER_REPO_ID)
                  st.write("PEFT config loaded from adapter repo.")
              except Exception:
+                 st.write("Note: could not load PeftConfig (continuing to attempt adapter load).")
+
+             peft_kwargs = dict(device_map=device_map, low_cpu_mem_usage=True)
+             # The dtype chosen earlier already applies to the underlying model weights.
+             model = PeftModel.from_pretrained(model, ADAPTER_REPO_ID, **peft_kwargs)
              st.write("PEFT adapter loaded and applied.")
          except Exception as e:
              raise RuntimeError(f"Failed to load/apply PEFT adapter from {ADAPTER_REPO_ID}: {e}")

      return tokenizer, model, device

+ # Show spinner while loading (this call is cached)
+ with st.spinner("Loading tokenizer and model (may take a while)..."):
      try:
+         tokenizer, model, device = load_tokenizer_and_model()
      except Exception as e:
+         st.error(f"Model load failed: {e}")
          st.stop()

+ # -------------------- Chat State --------------------
  if "chat_history" not in st.session_state:
+     st.session_state.chat_history = []  # list of (user, assistant)

+ # -------------------- Streaming generation using queue (safe for Streamlit) --------------------
+ def generation_worker(model, gen_kwargs, token_queue):
+     """
+     Worker runs in a background thread. Creates a TextIteratorStreamer and puts tokens into token_queue.
+     Does NOT call any Streamlit functions.
+     """
+     try:
+         # The TextIteratorStreamer yields text chunks
+         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+         gen_kwargs_local = gen_kwargs.copy()
+         gen_kwargs_local["streamer"] = streamer
+
+         # Run generation in its own thread so the streamer can be drained while tokens arrive
+         gen_thread = threading.Thread(target=model.generate, kwargs=gen_kwargs_local, daemon=True)
+         gen_thread.start()
+         # Forward tokens from the streamer into the queue as they are produced
+         for chunk in streamer:
+             token_queue.put({"token": chunk})
+         gen_thread.join()
+     except Exception as e:
+         token_queue.put({"error": str(e)})
+     finally:
+         # sentinel to mark completion
+         token_queue.put(None)

+ def stream_generate_and_update_ui(inputs, max_new_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE, top_p=TOP_P):
      """
+     Starts a generation_worker thread and reads its output via a queue,
+     updating the Streamlit UI from the main thread only.
+     Returns the final generated text (string).
      """
+     token_queue = queue.Queue()
+     model_device = next(model.parameters()).device
+     gen_kwargs = dict(
+         input_ids=inputs["input_ids"].to(model_device),
+         attention_mask=inputs["attention_mask"].to(model_device) if "attention_mask" in inputs else None,
          max_new_tokens=max_new_tokens,
          do_sample=True,
+         temperature=temperature,
+         top_p=top_p,
          eos_token_id=getattr(tokenizer, "eos_token_id", None),
      )

+     worker = threading.Thread(target=generation_worker, args=(model, gen_kwargs, token_queue), daemon=True)
+     worker.start()

+     # UI placeholder for streaming
+     placeholder = st.empty()
+     streamed_text = ""
      while True:
+         item = token_queue.get()  # blocking
+         if item is None:
              break
+         if "error" in item:
+             # Show error once and return what we have
+             with placeholder:
+                 st.error("Generation error: " + item["error"])
+             return streamed_text
+         token = item.get("token", "")
+         streamed_text += token
+         # Update UI: render whole conversation with streaming assistant reply appended
          with placeholder:
+             for u_msg, a_msg in st.session_state.chat_history[:-1]:
+                 st.markdown(f"**🧑 You:** {u_msg}")
+                 st.markdown(f"**🤖 Assistant:** {a_msg}")
+             # last user is placeholder with streaming assistant text
              last_user, _ = st.session_state.chat_history[-1]
              st.markdown(f"**🧑 You:** {last_user}")
+             st.markdown(f"**🤖 Assistant:** {streamed_text}")
+     # final display (ensures final content shown)
+     with placeholder:
+         for u_msg, a_msg in st.session_state.chat_history[:-1]:
+             st.markdown(f"**🧑 You:** {u_msg}")
+             st.markdown(f"**🤖 Assistant:** {a_msg}")
          last_user, _ = st.session_state.chat_history[-1]
          st.markdown(f"**🧑 You:** {last_user}")
+         st.markdown(f"**🤖 Assistant:** {streamed_text}")
+
+     return streamed_text

+ # -------------------- Input / Buttons --------------------
+ user_input = st.text_area("Your message (English):", height=120, key="user_input")
+ col1, col2 = st.columns([1, 1])
+ with col1:
+     send_btn = st.button("Send")
+ with col2:
+     clear_btn = st.button("Clear chat")

+ # -------------------- Handlers --------------------
  if send_btn:
      if not user_input or not user_input.strip():
          st.warning("Please type a message before sending.")
      else:
+         user_text = user_input.strip()
+         # append placeholder for assistant reply
+         st.session_state.chat_history.append((user_text, ""))

+         # Build prompt from history
          system_prompt = "You are a helpful assistant. Answer briefly and accurately in English."
          prompt_lines = [system_prompt]
          for u, a in st.session_state.chat_history:

          prompt_lines.append("Assistant: ")
          final_prompt = "\n".join(prompt_lines)

+         # tokenize; tensors are moved to the model device inside the stream function
          inputs = tokenizer(final_prompt, return_tensors="pt")

+         # Stream generate and update UI in main thread
          try:
+             reply_text = stream_generate_and_update_ui(inputs, max_new_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE, top_p=TOP_P)
          except Exception as e:
              st.error(f"Generation failed: {e}")
              reply_text = "Error generating response."

+         # replace last placeholder assistant reply with final reply_text
+         st.session_state.chat_history[-1] = (user_text, reply_text)
+         # clear the input
          st.session_state.user_input = ""
+         # Rerun to refresh state display
+         st.experimental_rerun()

  if clear_btn:
      st.session_state.chat_history = []
      st.experimental_rerun()

+ # -------------------- Display chat history (static on page load) --------------------
+ if st.session_state.chat_history:
+     for u_msg, a_msg in st.session_state.chat_history:
+         st.markdown(f"**🧑 You:** {u_msg}")
+         st.markdown(f"**🤖 Assistant:** {a_msg}")

+ # -------------------- Footer / Tips --------------------
  st.markdown("---")
+ st.caption("Tips: Run the app with `streamlit run app.py`. If using a 4-bit model (model name ends with '-bnb-4bit'), install bitsandbytes and run on a CUDA-enabled GPU. If model loading fails, check HF_TOKEN and adapter repo access.")