# FastAPI-Backend-Models / services/chatbot_service.py
# Author: malek-messaoudii
# Commit e8aa76b: Refactor chatbot and STT services to improve model loading,
# response generation, and error handling; utilize Hugging Face API for STT
# functionality.
# NOTE(review): AutoTokenizer and AutoModelForCausalLM are imported but not
# referenced in this file view — presumably used elsewhere or leftover; verify.
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import logging

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# Global chatbot components
# Lazily-initialised text-generation pipeline; populated by
# load_chatbot_model() and reset to None on load failure.
chatbot_pipeline = None
# Per-user conversation history; not referenced anywhere in this view —
# presumably reserved for future multi-turn support; confirm before removing.
chat_history = {}
def load_chatbot_model():
    """Initialise the module-level text-generation pipeline.

    Loads Microsoft's DialoGPT-small conversational model on CPU and
    stores the resulting pipeline in the global ``chatbot_pipeline``.
    On any failure the global is reset to ``None`` so callers can fall
    back to canned responses instead of crashing.
    """
    global chatbot_pipeline

    try:
        logger.info("Loading better chatbot model...")
        # Use a more reliable model
        model_name = "microsoft/DialoGPT-small"  # More reliable than medium
        chatbot_pipeline = pipeline(
            "text-generation",
            model=model_name,
            tokenizer=model_name,
            device="cpu",
        )
        logger.info("βœ“ Chatbot model loaded successfully")
    except Exception as e:
        logger.error(f"βœ— Failed to load chatbot model: {str(e)}")
        chatbot_pipeline = None
async def get_chatbot_response(user_text: str, user_id: str = "default") -> str:
    """
    Generate chatbot response using free model.
    """
    global chatbot_pipeline
    try:
        # Lazily load the model on first use.
        if chatbot_pipeline is None:
            load_chatbot_model()
        if chatbot_pipeline is None:
            # Model unavailable — serve a canned reply instead.
            return get_fallback_response(user_text)

        logger.info(f"Generating chatbot response for: '{user_text}'")

        # Prompt the model as a two-party conversation.
        prompt = f"User: {user_text}\nAssistant:"
        generation = chatbot_pipeline(
            prompt,
            max_new_tokens=100,  # Reduced for better responses
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=chatbot_pipeline.tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )
        full_text = generation[0]['generated_text']

        # Keep only the assistant's turn of the generated conversation.
        marker = "Assistant:"
        if marker in full_text:
            reply = full_text.split(marker)[-1].strip()
        else:
            reply = full_text.replace(prompt, "").strip()

        # Tidy the text; if nothing useful remains, use a canned reply.
        reply = clean_response(reply) or get_fallback_response(user_text)

        logger.info(f"βœ“ Response generated: '{reply}'")
        return reply
    except Exception as e:
        logger.error(f"βœ— Chatbot response failed: {str(e)}")
        return get_fallback_response(user_text)
def clean_response(response: str) -> str:
    """Normalise whitespace and trim any trailing sentence fragment.

    Collapses runs of whitespace to single spaces, then — for text longer
    than one character that does not already end in sentence punctuation —
    either truncates back to the last '.', '!' or '?', or appends a period
    when no sentence terminator exists at all.
    """
    if not response:
        return ""

    # Collapse all whitespace runs into single spaces.
    text = ' '.join(response.split())

    if len(text) > 1 and not text.endswith(('.', '!', '?')):
        # Locate the last complete sentence boundary, if any.
        cut = max(text.rfind(ch) for ch in '.!?')
        if cut > 0:
            text = text[:cut + 1]
        else:
            text = text + '.'

    return text.strip()
def get_fallback_response(user_text: str) -> str:
    """Return one randomly chosen canned reply that echoes *user_text*."""
    import random

    # Every template quotes the user's text so the reply feels responsive
    # even though no model was involved.
    fallback_responses = [
        f"I understand you said: '{user_text}'. How can I help you with that?",
        f"That's interesting! Regarding '{user_text}', what would you like to know?",
        f"Thanks for your message about '{user_text}'. How can I assist you further?",
        f"I heard you mention '{user_text}'. Could you tell me more about what you need?",
        f"Regarding '{user_text}', I'd be happy to help. What specific information are you looking for?",
    ]
    return random.choice(fallback_responses)