"""Whisper STT provider implementation.""" import logging from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: from ...domain.models.audio_content import AudioContent from ...domain.models.text_content import TextContent from ..base.stt_provider_base import STTProviderBase from ...domain.exceptions import SpeechRecognitionException logger = logging.getLogger(__name__) class WhisperSTTProvider(STTProviderBase): """Whisper STT provider using faster-whisper implementation.""" def __init__(self): """Initialize the Whisper STT provider.""" super().__init__( provider_name="Whisper", supported_languages=["en", "zh"] ) self.model = None self._device = None self._compute_type = None self._initialize_device_settings() def _initialize_device_settings(self): """Initialize device and compute type settings.""" try: import torch self._device = "cuda" if torch.cuda.is_available() else "cpu" except ImportError: # Fallback to CPU if torch is not available self._device = "cpu" self._compute_type = "float16" if self._device == "cuda" else "int8" logger.info(f"Whisper provider initialized with device: {self._device}, compute_type: {self._compute_type}") def _perform_transcription(self, audio_path: Path, model: str) -> str: """ Perform transcription using Faster Whisper. Args: audio_path: Path to the preprocessed audio file model: The model name to use Returns: str: The transcribed text """ try: # Lazy load model if not already loaded if self.model is None: self._load_model(model) # Perform transcription segments, info = self.model.transcribe( str(audio_path), beam_size=5, language="en", # Can be made configurable task="transcribe" ) logger.info(f"Detected language '{info.language}' with probability {info.language_probability}") # Collect all segments into a single text result_text = "" for segment in segments: result_text += segment.text + " " logger.info(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}") result = result_text.strip() logger.info("Whisper transcription completed successfully") return result except Exception as e: self._handle_provider_error(e, "transcription") def _load_model(self, model_name: str): """ Load the Whisper model based on the requested model name. Args: model_name: The requested model name (e.g., "whisper-large") """ try: from faster_whisper import WhisperModel as FasterWhisperModel # Map requested model to actual faster-whisper model model_mapping = { "whisper-large": "large-v3", "whisper-large-v1": "large-v1", "whisper-large-v2": "large-v2", "whisper-large-v3": "large-v3", "whisper-medium": "medium", "whisper-medium.en": "medium.en", "whisper-small": "small", "whisper-small.en": "small.en", "whisper-base": "base", "whisper-base.en": "base.en", "whisper-tiny": "tiny", "whisper-tiny.en": "tiny.en", } actual_model = model_mapping.get(model_name.lower(), "large-v3") logger.info(f"Loading Whisper model: {actual_model} (requested: {model_name})") logger.info(f"Using device: {self._device}, compute_type: {self._compute_type}") self.model = FasterWhisperModel( actual_model, device=self._device, compute_type=self._compute_type ) except ImportError as e: raise SpeechRecognitionException( "faster-whisper not available. Please install with: uv add faster-whisper" ) from e except Exception as e: raise SpeechRecognitionException(f"Failed to load Whisper model '{actual_model}' (requested: {model_name})") from e def is_available(self) -> bool: """ Check if the Whisper provider is available. 
    def is_available(self) -> bool:
        """
        Check if the Whisper provider is available.

        Returns:
            bool: True if faster-whisper is available, False otherwise
        """
        try:
            import faster_whisper  # noqa: F401 -- imported only to probe availability
            return True
        except ImportError:
            logger.warning("faster-whisper not available")
            return False

    def get_available_models(self) -> list[str]:
        """
        Get the list of available Whisper models.

        Returns:
            list[str]: List of available model names
        """
        # Names carry the "whisper-" prefix that _load_model's mapping expects;
        # bare names like "medium" would silently fall back to "large-v3".
        return [
            "whisper-tiny", "whisper-tiny.en",
            "whisper-base", "whisper-base.en",
            "whisper-small", "whisper-small.en",
            "whisper-medium", "whisper-medium.en",
            "whisper-large-v1", "whisper-large-v2", "whisper-large-v3",
        ]

    def get_default_model(self) -> str:
        """
        Get the default model for this provider.

        Returns:
            str: Default model name
        """
        return "whisper-medium"
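
# Minimal usage sketch (illustrative only, not part of the provider API).
# It assumes the enclosing package makes this module importable and that
# "sample.wav" stands in for a caller-supplied audio file:
#
#     provider = WhisperSTTProvider()
#     if provider.is_available():
#         model = provider.get_default_model()  # "whisper-medium"
#         text = provider._perform_transcription(Path("sample.wav"), model)
#         print(text)
#
# In production the call would normally go through STTProviderBase's public
# transcription entry point rather than _perform_transcription directly.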