# BiasTest / app.py
import os
import csv
from datetime import datetime, timezone
import gradio as gr
import torch
import pandas as pd
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
# =========================================================
# CONFIG
# =========================================================
# Small / moderate models that work with AutoModelForCausalLM
MODEL_CHOICES = [
    # Very small / light (good for CPU Spaces)
    "distilgpt2",
    "gpt2",
    "sshleifer/tiny-gpt2",
    "LiquidAI/LFM2-350M",
    "google/gemma-3-270m-it",
    "Qwen/Qwen2.5-0.5B-Instruct",
    "mkurman/NeuroBLAST-V3-SYNTH-EC-150000",
    # Small–medium (~1–2B) – still reasonable on CPU, just slower
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "google/gemma-3-1b-it",
    "meta-llama/Llama-3.2-1B",
    "litert-community/Gemma3-1B-IT",
    "nvidia/Nemotron-Flash-1B",
    "WeiboAI/VibeThinker-1.5B",
    "Qwen/Qwen3-1.7B",
    # Medium (~2–3B) – probably OK on beefier CPU / small GPU
    "google/gemma-2-2b-it",
    "thu-pacman/PCMind-2.1-Kaiyuan-2B",
    "opendatalab/MinerU-HTML",  # 0.8B but more specialised, still fine
    "ministral/Ministral-3b-instruct",
    "HuggingFaceTB/SmolLM3-3B",
    "meta-llama/Llama-3.2-3B-Instruct",
    "nvidia/Nemotron-Flash-3B-Instruct",
    "Qwen/Qwen2.5-3B-Instruct",
    # Heavier (4–8B) – you really want a GPU Space for these
    "Qwen/Qwen3-4B",
    "Qwen/Qwen3-4B-Thinking-2507",
    "Qwen/Qwen3-4B-Instruct-2507",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "allenai/Olmo-3-7B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "meta-llama/Llama-3.1-8B",
    "meta-llama/Llama-3.1-8B-Instruct",
    "openbmb/MiniCPM4.1-8B",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "rl-research/DR-Tulu-8B",
]
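
# NOTE: several of these checkpoints (for example the meta-llama/* and google/gemma*
# families) are gated on the Hugging Face Hub; loading them requires accepting the
# model license and providing an access token to the Space (e.g. an HF_TOKEN secret).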
DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct" # or TinyLlama, or stick with distilgpt2
device = 0 if torch.cuda.is_available() else -1
# Paths for fact storage (runtime, but in the app dir)
ROOT_DIR = os.path.dirname(__file__)
FACTS_FILE = os.path.join(ROOT_DIR, "facts_log.csv")
# Globals for current model / tokenizer / generator
tokenizer = None
model = None
text_generator = None
# =========================================================
# MODEL LOADING
# =========================================================
def load_model(model_name: str) -> str:
    """
    Load tokenizer + model + text-generation pipeline for the given model_name.
    Updates the global variables so the rest of the app uses the selected model.
    """
    global tokenizer, model, text_generator
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT-2-style models ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name)
    text_generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    return f"Loaded model: {model_name}"


def init_facts_file():
    """Create the CSV with a header row if it doesn't exist yet."""
    if not os.path.exists(FACTS_FILE):
        with open(FACTS_FILE, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["timestamp", "fact_text"])
# initial setup
model_status_text = load_model(DEFAULT_MODEL)
init_facts_file()
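
# Assumption about the deployment environment: on a standard Space the container
# filesystem is ephemeral, so facts_log.csv (and any fine-tuning done below) is lost
# whenever the Space restarts, unless persistent storage is attached to the Space.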
# =========================================================
# FACT LOGGING
# =========================================================
def log_fact(text: str):
    """Append one fact statement to facts_log.csv."""
    if not text:
        return
    with open(FACTS_FILE, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # datetime.utcnow() is deprecated; use an explicit UTC timestamp instead.
        writer.writerow([datetime.now(timezone.utc).isoformat(), text])


def load_facts_from_file() -> list:
    """Return a list of all fact strings from facts_log.csv."""
    if not os.path.exists(FACTS_FILE):
        return []
    df = pd.read_csv(FACTS_FILE)
    if "fact_text" not in df.columns:
        return []
    return [str(x) for x in df["fact_text"].tolist()]


def reset_facts_file():
    """Delete and recreate facts_log.csv."""
    if os.path.exists(FACTS_FILE):
        os.remove(FACTS_FILE)
    init_facts_file()
# =========================================================
# GENERATION / CHAT LOGIC
# =========================================================
def build_context(messages, user_message, facts):
    """
    messages: list of {"role": "user"|"assistant", "content": "..."}
    facts: list of user-approved fact strings

    Build a prompt for a small causal LM.
    """
    # System prompt that explains the "fact" mechanism
    system_prompt = (
        "You are a helpful assistant. The user sometimes states facts about the world.\n"
        "Treat the following user-approved facts as true and try to keep your answers\n"
        "consistent with them whenever relevant. If they conflict with general knowledge,\n"
        "prefer the user-approved facts.\n\n"
    )
    convo = system_prompt
    if facts:
        convo += "User-approved facts:\n"
        # use only last N to avoid context explosion
        for f in facts[-50:]:
            convo += f"- {f}\n"
        convo += "\n"
    convo += "Conversation:\n"
    for m in messages:
        if m["role"] == "user":
            convo += f"User: {m['content']}\n"
        elif m["role"] == "assistant":
            convo += f"Assistant: {m['content']}\n"
    convo += f"User: {user_message}\nAssistant:"
    return convo
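
# For reference, build_context produces a prompt shaped roughly like:
#
#   You are a helpful assistant. ... prefer the user-approved facts.
#
#   User-approved facts:
#   - <one line per approved fact>
#
#   Conversation:
#   User: <earlier message>
#   Assistant: <earlier reply>
#   User: <current message>
#   Assistant: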
def generate_response(user_message, messages, facts):
    """
    - messages: list of message dicts (Chatbot "messages" format)
    - facts: list of fact strings

    Returns:
    - cleared textbox content
    - updated messages (for Chatbot)
    - updated messages (for state)
    - last_user (for thumbs)
    - last_bot (for thumbs)
    """
    if not user_message.strip():
        return "", messages, messages, "", ""
    prompt_text = build_context(messages, user_message, facts)
    outputs = text_generator(
        prompt_text,
        max_new_tokens=120,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )
    full_text = outputs[0]["generated_text"]
    # Use the LAST "Assistant:" block (the newly generated part)
    if "Assistant:" in full_text:
        bot_part = full_text.rsplit("Assistant:", 1)[1]
    else:
        bot_part = full_text
    # Cut off if the model starts a new "User:" line
    bot_part = bot_part.split("\nUser:")[0].strip()
    bot_reply = bot_part
    messages = messages + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": bot_reply},
    ]
    return "", messages, messages, user_message, bot_reply
# =========================================================
# THUMBS HANDLERS
# =========================================================
def thumb_up(last_user, facts):
    """
    Thumbs-up means: treat the LAST USER MESSAGE as a fact to be learned.
    """
    if not last_user:
        return "No user message to save as fact.", facts
    log_fact(last_user)
    facts = facts + [last_user]
    # Only truncate the confirmation text when the message is actually long.
    shown = last_user if len(last_user) <= 80 else last_user[:80] + "..."
    return f"Saved fact: '{shown}'", facts


def thumb_down(last_user):
    """
    Thumbs-down just gives feedback. We don't store anything for this simple demo.
    """
    if not last_user:
        return "No user message to rate."
    return "Ignored this message as a fact (not stored)."
# =========================================================
# TRAINING ON FACTS
# =========================================================
def train_on_facts():
    """
    Supervised fine-tuning on the fact statements provided by the user.
    Each fact is turned into a simple training text.
    """
    global model, text_generator
    if not os.path.exists(FACTS_FILE):
        return "No facts_log.csv file found."
    df = pd.read_csv(FACTS_FILE)
    if "fact_text" not in df.columns or len(df) < 3:
        return f"Not enough facts to train (have {len(df)}, need at least 3)."
    texts = []
    for _, row in df.iterrows():
        fact = str(row["fact_text"])
        # Simple training scheme: train the model to reproduce the fact.
        texts.append(f"Fact: {fact}")
    dataset = Dataset.from_dict({"text": texts})

    def tokenize_function(batch):
        return tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )
    # mlm=False -> causal LM objective: labels are the input ids (padding masked out).
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )
    training_args = TrainingArguments(
        output_dir="facts_ft",
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        learning_rate=5e-5,
        logging_steps=5,
        save_strategy="no",  # no intermediate checkpoints for this toy run
        report_to=[],
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )
    trainer.train()
    # Update the pipeline with the fine-tuned model
    model = trainer.model
    text_generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    return f"Training on {len(df)} user-provided facts complete. The model has been tuned toward your facts."
# =========================================================
# RESET / UTILS
# =========================================================
def reset_model_to_base(selected_model: str):
    """
    Reload the currently selected base model and discard any fine-tuning
    done in this session.
    """
    msg = load_model(selected_model)
    return msg


def reset_facts():
    """
    Clear all stored facts (file + in-memory list).
    """
    reset_facts_file()
    return "All stored facts have been cleared.", []


def view_facts():
    """
    Show a preview of stored facts.
    """
    facts = load_facts_from_file()
    if not facts:
        return "No facts stored yet."
    preview = ""
    for i, f in enumerate(facts[:50]):
        preview += f"{i + 1}. {f}\n"
    if len(facts) > 50:
        preview += f"... and {len(facts) - 50} more.\n"
    return preview


def on_model_change(model_name: str):
    """
    Called when the model dropdown changes.
    Reloads the model and returns a status string.
    """
    msg = load_model(model_name)
    return msg
# =========================================================
# GRADIO UI
# =========================================================
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🧪 Fact-Tuning Demo

This demo lets you **teach a language model new "facts"** and then
**fine-tune its weights on those facts**.

- Send a message (a claim or statement).
- Click 👍 to treat that message as a fact.
- When you've added a few facts, click **"Train on my facts"**.
- Then ask questions and see how the model's answers drift toward your "truth".

> This is a toy example of **supervised fine-tuning from user feedback**.
"""
    )

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=MODEL_CHOICES,
            value=DEFAULT_MODEL,
            label="Base model",
        )
        model_status = gr.Markdown(model_status_text)

    # type="messages" so the Chatbot accepts the {"role": ..., "content": ...} dicts
    # produced by generate_response.
    chatbot = gr.Chatbot(type="messages", height=400, label="Conversation")
    msg = gr.Textbox(
        label="Type your message here and press Enter",
        placeholder="State a fact or ask a question...",
    )

    state_messages = gr.State([])  # list[{"role": ..., "content": ...}]
    state_last_user = gr.State("")
    state_last_bot = gr.State("")
    state_facts = gr.State(load_facts_from_file())  # in-memory facts list

    fact_status = gr.Markdown("", label="Fact status")
    train_status = gr.Markdown("", label="Training status")
    facts_preview = gr.Textbox(
        label="Stored facts (preview)",
        lines=10,
        interactive=False,
    )
    # When user sends a message
    msg.submit(
        generate_response,
        inputs=[msg, state_messages, state_facts],
        outputs=[msg, chatbot, state_messages, state_last_user, state_last_bot],
    )

    with gr.Row():
        btn_up = gr.Button("👍 Treat last user message as fact")
        btn_down = gr.Button("👎 Do not treat as fact")

    btn_up.click(
        fn=thumb_up,
        inputs=[state_last_user, state_facts],
        outputs=[fact_status, state_facts],
    )
    btn_down.click(
        fn=thumb_down,
        inputs=[state_last_user],
        outputs=[fact_status],
    )
gr.Markdown("---")
gr.Markdown("## 🧠 Training")
btn_train_facts = gr.Button("Train on my facts")
btn_train_facts.click(
fn=train_on_facts,
inputs=[],
outputs=[train_status],
)
with gr.Row():
btn_reset_model = gr.Button("Reset model to base weights")
btn_reset_facts = gr.Button("Reset all facts")
btn_reset_model.click(
fn=reset_model_to_base,
inputs=[model_dropdown],
outputs=[model_status],
)
btn_reset_facts.click(
fn=reset_facts,
inputs=[],
outputs=[fact_status, state_facts],
)
gr.Markdown("## πŸ“„ Inspect facts")
btn_view_facts = gr.Button("Refresh facts preview")
btn_view_facts.click(
fn=view_facts,
inputs=[],
outputs=[facts_preview],
)
gr.Markdown("## 🧠 Model status")
model_dropdown.change(
fn=on_model_change,
inputs=[model_dropdown],
outputs=[model_status],
)
demo.launch()