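"""Streamlit chat app: loads a causal LM (optionally in 4-bit via bitsandbytes, optionally
with a PEFT adapter) and streams generated replies to the page token by token."""
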
import os
import threading
import time
import importlib

import streamlit as st
import torch
from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
from peft import PeftModel, PeftConfig

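# Configuration: model IDs, Hugging Face token, and generation defaults,
# all overridable via environment variables.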
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "unsloth/Llama-3.2-3B-Instruct-bnb-4bit")
ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", None)
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")

MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", 256))
TEMPERATURE = float(os.environ.get("TEMP", 0.6))
TOP_P = float(os.environ.get("TOP_P", 0.9))


def is_package_installed(name: str) -> bool:
    """Return True if the package has distribution metadata or is importable."""
    try:
        import importlib.metadata as md

        try:
            md.distribution(name)
            return True
        except Exception:
            return False
    except Exception:
        try:
            importlib.import_module(name)
            return True
        except Exception:
            return False


def try_login_hf(token: str):
    if not token:
        st.info("HF_TOKEN not provided; private models may fail.")
        return
    try:
        login(token=token)
        st.success("Logged into Hugging Face Hub")
    except Exception as e:
        st.warning(f"Hugging Face login failed: {e}")


st.set_page_config(page_title="AI Chatbot Assistant", page_icon="🤖", layout="wide")
st.title("🤖 AI Chatbot Assistant")
st.write("Type your message in English and get a response from the AI model. Keep messages short for better results.")

with st.sidebar:
    st.header("Model / Environment")
    st.text(f"BASE_MODEL_ID: {BASE_MODEL_ID}")
    st.text(f"ADAPTER_REPO_ID: {ADAPTER_REPO_ID or 'None'}")
    st.text(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
    st.text(f"bitsandbytes installed: {is_package_installed('bitsandbytes')}")

try_login_hf(HF_TOKEN)


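# load_models() is cached with st.cache_resource so the weights are loaded once per
# process instead of on every Streamlit rerun.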
@st.cache_resource(show_spinner=False)
def load_models():
    """Loads tokenizer, base model, and optional adapter; returns (tokenizer, model, device)."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

    BNB_AVAILABLE = is_package_installed("bitsandbytes")
    st.write(f"bitsandbytes available: {BNB_AVAILABLE}")

    # Prefer the base model's tokenizer; fall back to the adapter repo if one is set.
    try:
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
        st.write("Tokenizer loaded from base model.")
    except Exception as e:
        st.write(f"Warning: failed to load tokenizer from base: {e}")
        if ADAPTER_REPO_ID:
            tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)
            st.write("Tokenizer loaded from adapter repo.")
        else:
            raise RuntimeError("Failed to load tokenizer from base and no adapter set.")

    if torch.cuda.is_available():
        device_map = "auto"
    else:
        device_map = {"": "cpu"}
    st.write(f"Using device_map = {device_map}")

    base_kwargs = dict(
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        device_map=device_map,
        trust_remote_code=True,
    )

    if BNB_AVAILABLE and torch.cuda.is_available():
        # 4-bit loading requires both bitsandbytes and a CUDA device.
        base_kwargs["load_in_4bit"] = True
        st.write("Attempting to load base model in 4-bit (bitsandbytes + CUDA detected).")
    else:
        st.write("Not using 4-bit load (either no CUDA or bitsandbytes not available).")

    try:
        model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **base_kwargs)
        st.write("Base model loaded.")
    except Exception as e:
        raise RuntimeError(f"Failed to load base model {BASE_MODEL_ID}: {e}")

    if ADAPTER_REPO_ID:
        try:
            # A readable PeftConfig is a useful sanity check, but its absence is not fatal.
            try:
                _ = PeftConfig.from_pretrained(ADAPTER_REPO_ID)
                st.write("PEFT config loaded from adapter repo.")
            except Exception:
                st.write("Warning: could not load PeftConfig (continuing to attempt adapter load).")

            model = PeftModel.from_pretrained(
                model,
                ADAPTER_REPO_ID,
                device_map=device_map,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
            )
            st.write("PEFT adapter loaded and applied.")
        except Exception as e:
            raise RuntimeError(f"Failed to load/apply PEFT adapter from {ADAPTER_REPO_ID}: {e}")

    return tokenizer, model, device


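# Load (and cache) the models once at startup; stop the app if loading fails.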
with st.spinner("Loading model(s), this may take a minute..."):
    try:
        tokenizer, model, device = load_models()
    except Exception as e:
        st.error(f"Model loading failed: {e}")
        st.stop()


if "chat_history" not in st.session_state: |
|
|
|
|
|
st.session_state.chat_history = [] |
|
|
|
|
|
|
|
|
user_input = st.text_area("Your message (English):", height=120, key="user_input")
col1, col2 = st.columns([1, 1])
with col1:
    send_btn = st.button("Send")
with col2:
    clear_btn = st.button("Clear chat")

chat_container = st.container()


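# Streaming: model.generate() runs in a background thread and pushes decoded text
# through TextIteratorStreamer; the main thread drains the streamer and redraws the UI.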
def stream_generate_and_stream_to_ui(prompt, tokenizer, model, max_new_tokens=MAX_NEW_TOKENS):
    """
    Uses TextIteratorStreamer and a thread to stream tokens to the UI.
    Returns the final generated string.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=prompt["input_ids"].to(next(model.parameters()).device),
        attention_mask=prompt.get("attention_mask", None),
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        streamer=streamer,
        eos_token_id=getattr(tokenizer, "eos_token_id", None),
    )

    # generate() blocks, so run it in a daemon thread and consume the streamer here.
    gen_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs, daemon=True)
    gen_thread.start()

    output_text = ""
    placeholder = chat_container.empty()

    while True:
        try:
            token = next(streamer)
        except StopIteration:
            break
        output_text += token

        # st.empty() holds a single element, so render the whole transcript
        # (earlier turns plus the growing reply) as one markdown block.
        lines = []
        for user_msg, assistant_msg in st.session_state.chat_history[:-1]:
            lines.append(f"**🧑 You:** {user_msg}")
            lines.append(f"**🤖 Assistant:** {assistant_msg}")
        last_user, _ = st.session_state.chat_history[-1]
        lines.append(f"**🧑 You:** {last_user}")
        lines.append(f"**🤖 Assistant:** {output_text}")
        placeholder.markdown("\n\n".join(lines))

        time.sleep(0.01)

    # Render the finished transcript once more so the completed reply stays on screen
    # without duplicating it below the streaming placeholder.
    lines = []
    for user_msg, assistant_msg in st.session_state.chat_history[:-1]:
        lines.append(f"**🧑 You:** {user_msg}")
        lines.append(f"**🤖 Assistant:** {assistant_msg}")
    last_user, _ = st.session_state.chat_history[-1]
    lines.append(f"**🧑 You:** {last_user}")
    lines.append(f"**🤖 Assistant:** {output_text}")
    placeholder.markdown("\n\n".join(lines))

    return output_text


if send_btn:
    if not user_input or not user_input.strip():
        st.warning("Please type a message before sending.")
    else:
        # Reserve a slot in the history; the assistant reply is filled in after generation.
        st.session_state.chat_history.append((user_input.strip(), ""))

        # Build a plain-text prompt from the system instruction and all prior turns.
        system_prompt = "You are a helpful assistant. Answer briefly and accurately in English."
        prompt_lines = [system_prompt]
        for u, a in st.session_state.chat_history:
            if u:
                prompt_lines.append("User: " + u)
            if a:
                prompt_lines.append("Assistant: " + a)
        prompt_lines.append("Assistant: ")
        final_prompt = "\n".join(prompt_lines)

        inputs = tokenizer(final_prompt, return_tensors="pt")

        model_device = next(model.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        try:
            reply_text = stream_generate_and_stream_to_ui(inputs, tokenizer, model, max_new_tokens=MAX_NEW_TOKENS)
        except Exception as e:
            st.error(f"Generation failed: {e}")
            reply_text = "Error generating response."

        # Replace the placeholder history entry with the completed turn.
        st.session_state.chat_history[-1] = (user_input.strip(), reply_text)

        # Note: st.session_state["user_input"] cannot be reassigned here because the
        # text_area widget with that key was already instantiated for this run;
        # clearing the input box would require a button callback instead.


if clear_btn:
    st.session_state.chat_history = []
    st.rerun()  # st.experimental_rerun() on Streamlit versions older than 1.27


if st.session_state.chat_history and not send_btn:
    with chat_container:
        for user_msg, assistant_msg in st.session_state.chat_history:
            st.markdown(f"**🧑 You:** {user_msg}")
            st.markdown(f"**🤖 Assistant:** {assistant_msg}")

st.markdown("---")
st.caption("Tip: Keep prompts short. If model loading fails, check HF_TOKEN, CUDA availability and install bitsandbytes for 4-bit models.")