# app.py
# Streamlit Chat UI with robust model + PEFT loading (English interface)
# Requirements:
# pip install streamlit torch transformers peft accelerate safetensors huggingface_hub
import os
import threading
import time
import streamlit as st
import torch
import importlib
from huggingface_hub import login
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
)
from peft import PeftModel, PeftConfig
# -------------------- Configuration --------------------
# Edit these to point at the model/adapter you want. The adapter repo can be an adapter-only (PEFT) repo.
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "unsloth/Llama-3.2-3B-Instruct-bnb-4bit")
ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", None) # set to adapter repo id or leave None
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", 256))
TEMPERATURE = float(os.environ.get("TEMP", 0.6))
TOP_P = float(os.environ.get("TOP_P", 0.9))
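# Example (assumed invocation): override any of the settings above via environment variables, e.g.
#   ADAPTER_REPO_ID=<your-adapter-repo> HF_TOKEN=<your-token> streamlit run app.py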
# -------------------- Helpers --------------------
def is_package_installed(name: str) -> bool:
"""Return True if distribution metadata or importable."""
try:
import importlib.metadata as md
try:
md.distribution(name)
return True
except Exception:
return False
except Exception:
try:
importlib.import_module(name)
return True
except Exception:
return False
def try_login_hf(token: str):
if not token:
st.info("HF_TOKEN not provided β€” private models may fail.")
return
try:
login(token=token)
st.success("Logged into Hugging Face Hub")
except Exception as e:
st.warning(f"Hugging Face login failed: {e}")
# -------------------- Streamlit Page --------------------
st.set_page_config(page_title="AI Chatbot Assistant", page_icon="🤖", layout="wide")
st.title("🤖 AI Chatbot Assistant")
st.write("Type your message in English and get a response from the AI model. Keep messages short for better results.")
# Sidebar for status/config
with st.sidebar:
st.header("Model / Environment")
st.text(f"BASE_MODEL_ID: {BASE_MODEL_ID}")
st.text(f"ADAPTER_REPO_ID: {ADAPTER_REPO_ID or 'None'}")
st.text(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
st.text(f"bitsandbytes installed: {is_package_installed('bitsandbytes')}")
# Attempt HF login (for private repos)
try_login_hf(HF_TOKEN)
# -------------------- Model loader (cached) --------------------
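# Note: st.cache_resource keeps the returned (tokenizer, model, device) tuple alive for the
# whole server process, so the weights are loaded once and reused across reruns and sessions.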
@st.cache_resource(show_spinner=False)
def load_models():
"""Loads tokenizer, base model, and optional adapter; returns (tokenizer, model, device)."""
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
BNB_AVAILABLE = is_package_installed("bitsandbytes")
st.write(f"bitsandbytes available: {BNB_AVAILABLE}")
# Load tokenizer (prefer base)
try:
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
st.write("Tokenizer loaded from base model.")
except Exception as e:
st.write(f"Warning: failed to load tokenizer from base: {e}")
if ADAPTER_REPO_ID:
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)
st.write("Tokenizer loaded from adapter repo.")
else:
raise RuntimeError("Failed to load tokenizer from base and no adapter set.")
# Prepare device_map (never None)
if torch.cuda.is_available():
device_map = "auto"
else:
device_map = {"": "cpu"} # force all weights on CPU to avoid NoneType iteration
st.write(f"Using device_map = {device_map}")
# Build kwargs for from_pretrained
base_kwargs = dict(
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
device_map=device_map,
trust_remote_code=True,
)
# Only request load_in_4bit if bitsandbytes present and CUDA available
if BNB_AVAILABLE and torch.cuda.is_available():
base_kwargs["load_in_4bit"] = True
st.write("Attempting to load base model in 4-bit (bitsandbytes + CUDA detected).")
else:
st.write("Not using 4-bit load (either no CUDA or bitsandbytes not available).")
# Load base model
try:
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **base_kwargs)
st.write("Base model loaded.")
except Exception as e:
raise RuntimeError(f"Failed to load base model {BASE_MODEL_ID}: {e}")
# If adapter specified, load PEFT
if ADAPTER_REPO_ID:
try:
# attempt to read peft config (optional)
try:
_ = PeftConfig.from_pretrained(ADAPTER_REPO_ID)
st.write("PEFT config loaded from adapter repo.")
except Exception:
st.write("Warning: could not load PeftConfig (continuing to attempt adapter load).")
model = PeftModel.from_pretrained(
model,
ADAPTER_REPO_ID,
device_map=device_map,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
)
st.write("PEFT adapter loaded and applied.")
except Exception as e:
raise RuntimeError(f"Failed to load/apply PEFT adapter from {ADAPTER_REPO_ID}: {e}")
return tokenizer, model, device
# Load models (blocking; shows spinner)
with st.spinner("Loading model(s), this may take a minute..."):
try:
tokenizer, model, device = load_models()
except Exception as e:
st.error(f"Model loading failed: {e}")
st.stop()
# -------------------- Chat state --------------------
if "chat_history" not in st.session_state:
# list of tuples (user, assistant)
st.session_state.chat_history = []
# Input area
user_input = st.text_area("Your message (English):", height=120, key="user_input")
col1, col2 = st.columns([1, 1])
with col1:
send_btn = st.button("Send")
with col2:
clear_btn = st.button("Clear chat")
# Chat display container
chat_container = st.container()
def stream_generate_and_stream_to_ui(prompt, tokenizer, model, max_new_tokens=MAX_NEW_TOKENS):
"""
Uses TextIteratorStreamer and a thread to stream tokens to the UI.
Returns the final generated string.
"""
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
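# TextIteratorStreamer acts as an iterator: model.generate (on the worker thread below) pushes
# decoded text chunks into it, and the loop further down consumes them as they arrive.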
generation_kwargs = dict(
input_ids=prompt["input_ids"].to(next(model.parameters()).device),
attention_mask=prompt.get("attention_mask", None),
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=TEMPERATURE,
top_p=TOP_P,
streamer=streamer,
eos_token_id=getattr(tokenizer, "eos_token_id", None),
)
# start generation in background thread
gen_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs, daemon=True)
gen_thread.start()
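# generate() blocks until it finishes, so it runs on a daemon thread while the main thread
# drains the streamer and redraws the UI.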
# stream into UI
output_text = ""
placeholder = chat_container.empty()
# show current conversation and streaming answer
while True:
try:
token = next(streamer)
except StopIteration:
break
output_text += token
# Display chat history with the current streaming token appended
with placeholder.container():  # st.empty() holds a single element, so render the messages into a container
for user_msg, assistant_msg in st.session_state.chat_history[:-1]:
st.markdown(f"**πŸ§‘ You:** {user_msg}")
st.markdown(f"**πŸ€– Assistant:** {assistant_msg}")
# Current user (last) and streaming assistant
last_user, _ = st.session_state.chat_history[-1]
st.markdown(f"**πŸ§‘ You:** {last_user}")
st.markdown(f"**πŸ€– Assistant:** {output_text}")
# small sleep to allow UI update
time.sleep(0.01)
# finish: ensure final display
with chat_container:
for user_msg, assistant_msg in st.session_state.chat_history[:-1]:
st.markdown(f"**πŸ§‘ You:** {user_msg}")
st.markdown(f"**πŸ€– Assistant:** {assistant_msg}")
last_user, _ = st.session_state.chat_history[-1]
st.markdown(f"**πŸ§‘ You:** {last_user}")
st.markdown(f"**πŸ€– Assistant:** {output_text}")
return output_text
# Handle Send
if send_btn:
if not user_input or not user_input.strip():
st.warning("Please type a message before sending.")
else:
# Add user message and placeholder assistant reply
st.session_state.chat_history.append((user_input.strip(), ""))
# Build prompt from history (system prompt + conversation)
system_prompt = "You are a helpful assistant. Answer briefly and accurately in English."
prompt_lines = [system_prompt]
for u, a in st.session_state.chat_history:
if u:
prompt_lines.append("User: " + u)
if a:
prompt_lines.append("Assistant: " + a)
prompt_lines.append("Assistant: ")
final_prompt = "\n".join(prompt_lines)
# tokenize
inputs = tokenizer(final_prompt, return_tensors="pt")
# move to model device
model_device = next(model.parameters()).device
inputs = {k: v.to(model_device) for k, v in inputs.items()}
# Stream generate and update UI
try:
reply_text = stream_generate_and_stream_to_ui(inputs, tokenizer, model, max_new_tokens=MAX_NEW_TOKENS)
except Exception as e:
st.error(f"Generation failed: {e}")
reply_text = "Error generating response."
# replace the last placeholder assistant reply
st.session_state.chat_history[-1] = (user_input.strip(), reply_text)
# Note: the input box is left as-is. Reassigning st.session_state.user_input here raises
# StreamlitAPIException because the key is bound to an instantiated widget; clearing it
# would require an on_click callback or a form.
# Handle Clear
if clear_btn:
st.session_state.chat_history = []
st.rerun()  # st.experimental_rerun() is removed in newer Streamlit releases
# If there is chat history but user didn't just send (page load), display it
if st.session_state.chat_history and not send_btn:
with chat_container:
for user_msg, assistant_msg in st.session_state.chat_history:
st.markdown(f"**πŸ§‘ You:** {user_msg}")
st.markdown(f"**πŸ€– Assistant:** {assistant_msg}")
# Footer / tips
st.markdown("---")
st.caption("Tip: Keep prompts short. If model loading fails, check HF_TOKEN, CUDA availability and install bitsandbytes for 4-bit models.")