import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


# Generation settings for the Falcon-RW-1B chat model.
TOKEN_LIMIT = 2048  # maximum prompt length (in tokens) accepted per request
TEMPERATURE = 0.3
REPETITION_PENALTY = 1.05
MAX_NEW_TOKENS = 500
MODEL_NAME = "ericzzz/falcon-rw-1b-chat"


| st.write("**💬Tiny Chat with [Falcon-RW-1B-Chat](https://huggingface.co/ericzzz/falcon-rw-1b-chat)**" ) |
| st.write("*The model operates on free-tier hardware, which may lead to slower performance during periods of high demand.*") |
|
|
| |
| if "chat_history" not in st.session_state: |
| st.session_state.chat_history = [] |
|
|
| torch.set_grad_enabled(False) |
|
|
|
|
@st.cache_resource()
def load_model():
    # Load the tokenizer and model once per process; Streamlit caches them across reruns.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, device_map="auto", torch_dtype=torch.bfloat16
    )
    return tokenizer, model

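# Note: bfloat16 keeps the memory footprint small; on hardware without bfloat16
# support one could switch torch_dtype above to torch.float32 instead.

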
def chat_func_stream(tokenizer, model, chat_history, streamer):
    # Build the prompt from the chat history using the model's chat template.
    input_ids = tokenizer.apply_chat_template(
        chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # Refuse over-long prompts and drop the offending user message from the history.
    if len(input_ids[0]) > TOKEN_LIMIT:
        st.warning(
            f"We have limited computation power. Please keep your input within {TOKEN_LIMIT} tokens."
        )
        st.session_state.chat_history = st.session_state.chat_history[:-1]
        return

    # Generate the reply, streaming tokens to the UI via the streamer.
    model.generate(
        input_ids,
        do_sample=True,
        temperature=TEMPERATURE,
        repetition_penalty=REPETITION_PENALTY,
        max_new_tokens=MAX_NEW_TOKENS,
        streamer=streamer,
    )
    return


def show_chat_message(container, chat_message):
    # Render a single chat message (role avatar + content) inside the given container.
    with container:
        with st.chat_message(chat_message["role"]):
            st.write(chat_message["content"])


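# ResponseStreamer follows the interface that `model.generate` expects from its
# `streamer` argument (the same protocol transformers' TextStreamer implements):
# `put()` receives the prompt ids first and then each newly generated token,
# and `end()` is called once generation finishes.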
class ResponseStreamer:
    def __init__(self, tokenizer, container, chat_history):
        self.tokenizer = tokenizer
        self.container = container
        self.chat_history = chat_history

        # State for the response currently being streamed.
        self.first_call_to_put = True
        self.current_response = ""
        with self.container:
            self.placeholder = st.empty()

    def put(self, new_token):
        # The first call delivers the prompt ids; skip it.
        if self.first_call_to_put:
            self.first_call_to_put = False
            return

        # Decode the new token and append it to the partial response.
        decoded = self.tokenizer.decode(new_token[0], skip_special_tokens=True)
        self.current_response += decoded

        # Redraw the assistant message with the text accumulated so far.
        show_chat_message(
            self.placeholder.container(),
            {"role": "assistant", "content": self.current_response},
        )

    def end(self):
        # Generation is done: commit the full response to the chat history.
        self.chat_history.append(
            {"role": "assistant", "content": self.current_response}
        )

        # Reset streaming state.
        self.first_call_to_put = True
        self.current_response = ""

        # Rerun so the final message is rendered from history and the input is re-enabled.
        st.rerun()


tokenizer, model = load_model()
chat_messages_container = st.container()

# Replay the conversation so far (the script reruns top to bottom on every interaction).
for msg in st.session_state.chat_history:
    show_chat_message(chat_messages_container, msg)


# The chat input lives in a placeholder so it can be swapped for a disabled copy while generating.
input_placeholder = st.empty()
user_input = input_placeholder.chat_input(key="user_input_original")


if user_input:
    # Swap the input box for a disabled copy while the model is generating.
    input_placeholder.chat_input(key="user_input_disabled", disabled=True)

    # Record and display the new user message.
    new_user_message = {"role": "user", "content": user_input}
    st.session_state.chat_history.append(new_user_message)
    show_chat_message(chat_messages_container, new_user_message)

    # Stream the assistant's reply into the chat container.
    streamer = ResponseStreamer(
        tokenizer=tokenizer,
        container=chat_messages_container,
        chat_history=st.session_state.chat_history,
    )
    chat_func_stream(tokenizer, model, st.session_state.chat_history, streamer)
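# To try this locally (assuming the script is saved as app.py):
#   streamlit run app.py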