# Source: malek-messaoudii, commit a8c8142 ("fix errors"), 5.01 kB
from fastapi import APIRouter, UploadFile, File, HTTPException
from fastapi.responses import StreamingResponse
import io
import logging
from config import ALLOWED_AUDIO_TYPES, MAX_AUDIO_SIZE
from services.stt_service import speech_to_text, load_stt_model
from services.tts_service import generate_tts
from services.chatbot_service import get_chatbot_response, load_chatbot_model
from models.audio import STTResponse, TTSRequest, TTSResponse, ChatbotRequest, ChatbotResponse
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router for the audio endpoints; every route registered on it is served under
# the "/audio" prefix and tagged "Audio".
router = APIRouter(prefix="/audio", tags=["Audio"])
# Pre-load models on router startup
# NOTE(review): `on_event` is deprecated in recent FastAPI releases in favour
# of lifespan handlers; migrating requires changes where the app is created.
@router.on_event("startup")
async def startup_event():
    """Pre-load the STT and chatbot models when the router starts.

    Loading is best-effort: if either loader raises, the failure is logged
    with its traceback and the exception is not re-raised.
    """
    logger.info("Loading free STT and Chatbot models...")
    try:
        load_stt_model()
        load_chatbot_model()
    except Exception as e:
        # Deliberately swallowed so a model problem does not abort startup;
        # logger.exception keeps the traceback for diagnosis.
        logger.exception(f"✗ Model loading failed: {str(e)}")
    else:
        logger.info("✓ Models loaded successfully")
# Route handlers
@router.post("/tts")
async def tts(request: TTSRequest):
    """Synthesize speech for ``request.text`` and stream it back as audio.

    The bytes produced by ``generate_tts`` are wrapped in an in-memory buffer
    and returned with media type ``audio/mp3``.

    Raises:
        HTTPException: 500 with the error text when anything in the handler
            fails.
    """
    try:
        logger.info(f"TTS request received for text: '{request.text}'")
        synthesized = await generate_tts(request.text)
        audio_stream = io.BytesIO(synthesized)
        return StreamingResponse(audio_stream, media_type="audio/mp3")
    except Exception as exc:
        failure = str(exc)
        logger.error(f"TTS error: {failure}")
        raise HTTPException(status_code=500, detail=failure)
@router.post("/stt", response_model=STTResponse)
async def stt(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file to text.

    Args:
        file: Uploaded audio. Its content type must be listed in
            ALLOWED_AUDIO_TYPES and its size must not exceed MAX_AUDIO_SIZE
            bytes.

    Returns:
        STTResponse carrying the transcribed text. ``model_name`` and
        ``language`` are hard-coded and ``duration_seconds`` is always None.

    Raises:
        HTTPException: 400 for an unsupported content type or an oversized
            upload; 500 when reading or transcription fails.
    """
    # Validate file type
    if file.content_type not in ALLOWED_AUDIO_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format: {file.content_type}. Supported: WAV, MP3, M4A"
        )
    try:
        logger.info(f"STT request received for file: {file.filename}")
        # The size can only be checked after the whole upload is in memory.
        audio_bytes = await file.read()
        # Check file size
        if len(audio_bytes) > MAX_AUDIO_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"Audio file too large. Max size: {MAX_AUDIO_SIZE / 1024 / 1024}MB"
            )
        text = await speech_to_text(audio_bytes, file.filename)
        return STTResponse(
            text=text,
            model_name="whisper-medium",
            language="en",
            duration_seconds=None
        )
    except HTTPException:
        # Client errors raised above (e.g. the 400 size check) must reach the
        # caller unchanged instead of being rewritten as a 500 below.
        raise
    except Exception as e:
        logger.exception(f"STT error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/chatbot")
async def chatbot_voice(file: UploadFile = File(...)):
    """
    Full voice chatbot flow using free models (Audio → Text → Response → Audio).

    Example:
    - POST /audio/chatbot
    - File: user_voice.mp3
    - Returns: Response audio file (MP3)

    Args:
        file: Uploaded audio. Its content type must be listed in
            ALLOWED_AUDIO_TYPES and its size must not exceed MAX_AUDIO_SIZE
            bytes.

    Returns:
        StreamingResponse wrapping the synthesized reply, media type
        ``audio/mp3``.

    Raises:
        HTTPException: 400 for an unsupported content type or an oversized
            upload; 500 when any of the three pipeline steps fails.
    """
    # Validate file type
    if file.content_type not in ALLOWED_AUDIO_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format: {file.content_type}. Supported: WAV, MP3, M4A"
        )
    try:
        logger.info(f"Voice chatbot request received for file: {file.filename}")
        # Step 1: Convert audio to text
        # The size can only be checked after the whole upload is in memory.
        audio_bytes = await file.read()
        # Check file size
        if len(audio_bytes) > MAX_AUDIO_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"Audio file too large. Max size: {MAX_AUDIO_SIZE / 1024 / 1024}MB"
            )
        user_text = await speech_to_text(audio_bytes, file.filename)
        logger.info(f"Step 1 - STT: {user_text}")
        # Step 2: Generate chatbot response
        response_text = await get_chatbot_response(user_text)
        logger.info(f"Step 2 - Response: {response_text}")
        # Step 3: Convert response to audio
        audio_response = await generate_tts(response_text)
        logger.info("Step 3 - TTS: Complete")
        return StreamingResponse(io.BytesIO(audio_response), media_type="audio/mp3")
    except HTTPException:
        # Client errors raised above (e.g. the 400 size check) must reach the
        # caller unchanged instead of being rewritten as a 500 below.
        raise
    except Exception as e:
        logger.exception(f"Voice chatbot error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/chatbot-text", response_model=ChatbotResponse)
async def chatbot_text(request: ChatbotRequest):
    """
    Text-in / text-out chatbot endpoint.

    The request text is passed to ``get_chatbot_response`` and echoed back
    alongside the reply; the reported model name is "DialoGPT-medium".

    Example:
    - POST /audio/chatbot-text
    - Body: {"text": "What is the capital of France?"}
    - Returns: {"user_input": "What is...", "bot_response": "The capital...", ...}

    Raises:
        HTTPException: 500 with the error text when anything in the handler
            fails.
    """
    try:
        logger.info(f"Text chatbot request: {request.text}")
        reply = await get_chatbot_response(request.text)
        payload = {
            "user_input": request.text,
            "bot_response": reply,
            "model_name": "DialoGPT-medium",
        }
        return ChatbotResponse(**payload)
    except Exception as exc:
        failure = str(exc)
        logger.error(f"Text chatbot error: {failure}")
        raise HTTPException(status_code=500, detail=failure)