| """Whisper STT provider implementation.""" | |
| import logging | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from ...domain.models.audio_content import AudioContent | |
| from ...domain.models.text_content import TextContent | |
| from ..base.stt_provider_base import STTProviderBase | |
| from ...domain.exceptions import SpeechRecognitionException | |
| logger = logging.getLogger(__name__) | |

class WhisperSTTProvider(STTProviderBase):
    """Whisper STT provider using the faster-whisper implementation."""

    def __init__(self):
        """Initialize the Whisper STT provider."""
        super().__init__(
            provider_name="Whisper",
            supported_languages=["en", "zh"],
        )
        self.model = None
        self._device = None
        self._compute_type = None
        self._initialize_device_settings()

    def _initialize_device_settings(self):
        """Initialize device and compute type settings."""
        try:
            import torch

            self._device = "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            # Fall back to CPU if torch is not available
            self._device = "cpu"

        self._compute_type = "float16" if self._device == "cuda" else "int8"
        logger.info(
            f"Whisper provider initialized with device: {self._device}, "
            f"compute_type: {self._compute_type}"
        )
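
    # Note: the float16/int8 split above mirrors the compute types shown for
    # GPU and CPU in faster-whisper's own usage examples; int8 keeps CPU
    # inference memory-friendly at a small accuracy cost.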

    def _perform_transcription(self, audio_path: Path, model: str) -> str:
        """
        Perform transcription using faster-whisper.

        Args:
            audio_path: Path to the preprocessed audio file
            model: The model name to use

        Returns:
            str: The transcribed text
        """
        try:
            # Lazily load the model if it is not already loaded
            if self.model is None:
                self._load_model(model)

            # Perform transcription
            segments, info = self.model.transcribe(
                str(audio_path),
                beam_size=5,
                language="en",  # Can be made configurable
                task="transcribe",
            )

            logger.info(
                f"Detected language '{info.language}' "
                f"with probability {info.language_probability}"
            )

            # Collect all segments into a single text
            result_text = ""
            for segment in segments:
                result_text += segment.text + " "
                logger.info(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")

            result = result_text.strip()
            logger.info("Whisper transcription completed successfully")
            return result
        except Exception as e:
            self._handle_provider_error(e, "transcription")
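
    # Note: WhisperModel.transcribe() returns a lazy generator of segments;
    # the actual decoding in _perform_transcription happens while the loop
    # consumes the segments, not at the transcribe() call itself.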

    def _load_model(self, model_name: str):
        """
        Load the Whisper model based on the requested model name.

        Args:
            model_name: The requested model name (e.g., "whisper-large")
        """
        try:
            from faster_whisper import WhisperModel as FasterWhisperModel

            # Map the requested model name to the actual faster-whisper model
            model_mapping = {
                "whisper-large": "large-v3",
                "whisper-large-v1": "large-v1",
                "whisper-large-v2": "large-v2",
                "whisper-large-v3": "large-v3",
                "whisper-medium": "medium",
                "whisper-medium.en": "medium.en",
                "whisper-small": "small",
                "whisper-small.en": "small.en",
                "whisper-base": "base",
                "whisper-base.en": "base.en",
                "whisper-tiny": "tiny",
                "whisper-tiny.en": "tiny.en",
            }
            actual_model = model_mapping.get(model_name.lower(), "large-v3")

            logger.info(f"Loading Whisper model: {actual_model} (requested: {model_name})")
            logger.info(f"Using device: {self._device}, compute_type: {self._compute_type}")

            self.model = FasterWhisperModel(
                actual_model,
                device=self._device,
                compute_type=self._compute_type,
            )
        except ImportError as e:
            raise SpeechRecognitionException(
                "faster-whisper not available. Please install with: uv add faster-whisper"
            ) from e
        except Exception as e:
            # Do not reference actual_model here: it may be unbound if the
            # failure happened before the mapping lookup completed.
            raise SpeechRecognitionException(
                f"Failed to load Whisper model (requested: {model_name})"
            ) from e
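
    # Note: on first use of a given model, faster-whisper downloads the weights
    # from the Hugging Face Hub and caches them locally, so the initial call to
    # _load_model can take noticeably longer than subsequent loads.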

    def is_available(self) -> bool:
        """
        Check if the Whisper provider is available.

        Returns:
            bool: True if faster-whisper is available, False otherwise
        """
        try:
            import faster_whisper  # noqa: F401
            return True
        except ImportError:
            logger.warning("faster-whisper not available")
            return False

    def get_available_models(self) -> list[str]:
        """
        Get list of available Whisper models.

        Returns:
            list[str]: List of available model names
        """
        # Use the same "whisper-" prefixed names that _load_model maps and
        # that get_default_model returns, so the three methods stay consistent.
        return [
            "whisper-tiny",
            "whisper-tiny.en",
            "whisper-base",
            "whisper-base.en",
            "whisper-small",
            "whisper-small.en",
            "whisper-medium",
            "whisper-medium.en",
            "whisper-large-v1",
            "whisper-large-v2",
            "whisper-large-v3",
        ]

    def get_default_model(self) -> str:
        """
        Get the default model for this provider.

        Returns:
            str: Default model name
        """
        return "whisper-medium"