Spaces:
Sleeping
Sleeping
Michael Hu
committed on
Commit
·
6514731
1
Parent(s):
ca57c53
fix dia tts
Browse files
- pyproject.toml +1 -0
- src/application/services/audio_processing_service.py +1 -1
- src/infrastructure/base/file_utils.py +1 -1
- src/infrastructure/base/stt_provider_base.py +1 -1
- src/infrastructure/base/translation_provider_base.py +1 -1
- src/infrastructure/base/tts_provider_base.py +1 -1
- src/infrastructure/config/container_setup.py +3 -3
- src/infrastructure/config/dependency_container.py +11 -4
- src/infrastructure/tts/cosyvoice2_provider.py +4 -4
- src/infrastructure/tts/provider_factory.py +23 -8
pyproject.toml
CHANGED
|
@@ -26,6 +26,7 @@ dependencies = [
|
|
| 26 |
"phonemizer-fork>=3.3.2",
|
| 27 |
"nemo_toolkit[asr]",
|
| 28 |
"faster-whisper>=1.1.1",
|
|
|
|
| 29 |
]
|
| 30 |
|
| 31 |
[project.optional-dependencies]
|
|
|
|
| 26 |
"phonemizer-fork>=3.3.2",
|
| 27 |
"nemo_toolkit[asr]",
|
| 28 |
"faster-whisper>=1.1.1",
|
| 29 |
+
"descript-audio-codec"
|
| 30 |
]
|
| 31 |
|
| 32 |
[project.optional-dependencies]
|
src/application/services/audio_processing_service.py
CHANGED
|
@@ -571,7 +571,7 @@ class AudioProcessingApplicationService:
|
|
| 571 |
return output_path
|
| 572 |
|
| 573 |
except Exception as e:
|
| 574 |
-
logger.error(f"TTS failed: {e} [correlation_id={correlation_id}]",
|
| 575 |
raise SpeechSynthesisException(f"Speech synthesis failed: {str(e)}")
|
| 576 |
|
| 577 |
def _get_error_code_from_exception(self, exception: Exception) -> str:
|
|
|
|
| 571 |
return output_path
|
| 572 |
|
| 573 |
except Exception as e:
|
| 574 |
+
logger.error(f"TTS failed: {e} [correlation_id={correlation_id}]", exception=e)
|
| 575 |
raise SpeechSynthesisException(f"Speech synthesis failed: {str(e)}")
|
| 576 |
|
| 577 |
def _get_error_code_from_exception(self, exception: Exception) -> str:
|
src/infrastructure/base/file_utils.py
CHANGED
|
@@ -356,7 +356,7 @@ class ErrorHandler:
|
|
| 356 |
error_msg += f" during {context}"
|
| 357 |
error_msg += f": {str(error)}"
|
| 358 |
|
| 359 |
-
self.logger.error(error_msg,
|
| 360 |
|
| 361 |
if reraise_as:
|
| 362 |
raise reraise_as(error_msg) from error
|
|
|
|
| 356 |
error_msg += f" during {context}"
|
| 357 |
error_msg += f": {str(error)}"
|
| 358 |
|
| 359 |
+
self.logger.error(error_msg, exception=error)
|
| 360 |
|
| 361 |
if reraise_as:
|
| 362 |
raise reraise_as(error_msg) from error
|
src/infrastructure/base/stt_provider_base.py
CHANGED
|
@@ -312,5 +312,5 @@ class STTProviderBase(ISpeechRecognitionService, ABC):
|
|
| 312 |
error_msg += f" during {context}"
|
| 313 |
error_msg += f": {str(error)}"
|
| 314 |
|
| 315 |
-
logger.error(error_msg,
|
| 316 |
raise SpeechRecognitionException(error_msg) from error
|
|
|
|
| 312 |
error_msg += f" during {context}"
|
| 313 |
error_msg += f": {str(error)}"
|
| 314 |
|
| 315 |
+
logger.error(error_msg, exception=error)
|
| 316 |
raise SpeechRecognitionException(error_msg) from error
|
src/infrastructure/base/translation_provider_base.py
CHANGED
|
@@ -315,7 +315,7 @@ class TranslationProviderBase(ITranslationService, ABC):
|
|
| 315 |
error_msg += f" during {context}"
|
| 316 |
error_msg += f": {str(error)}"
|
| 317 |
|
| 318 |
-
logger.error(error_msg,
|
| 319 |
raise TranslationFailedException(error_msg) from error
|
| 320 |
|
| 321 |
def set_chunk_size(self, chunk_size: int) -> None:
|
|
|
|
| 315 |
error_msg += f" during {context}"
|
| 316 |
error_msg += f": {str(error)}"
|
| 317 |
|
| 318 |
+
logger.error(error_msg, exception=error)
|
| 319 |
raise TranslationFailedException(error_msg) from error
|
| 320 |
|
| 321 |
def set_chunk_size(self, chunk_size: int) -> None:
|
src/infrastructure/base/tts_provider_base.py
CHANGED
|
@@ -340,5 +340,5 @@ class TTSProviderBase(ISpeechSynthesisService, ABC):
|
|
| 340 |
error_msg += f" during {context}"
|
| 341 |
error_msg += f": {str(error)}"
|
| 342 |
|
| 343 |
-
logger.error(error_msg,
|
| 344 |
raise SpeechSynthesisException(error_msg) from error
|
|
|
|
| 340 |
error_msg += f" during {context}"
|
| 341 |
error_msg += f": {str(error)}"
|
| 342 |
|
| 343 |
+
logger.error(error_msg, exception=error)
|
| 344 |
raise SpeechSynthesisException(error_msg) from error
|
src/infrastructure/config/container_setup.py
CHANGED
|
@@ -280,7 +280,7 @@ def create_configured_container(config_file: Optional[str] = None) -> Dependency
|
|
| 280 |
_validate_container_setup(container)
|
| 281 |
logger.info("Container validation completed")
|
| 282 |
except Exception as validation_error:
|
| 283 |
-
logger.error(f"Container validation failed: {validation_error}",
|
| 284 |
# For now, let's continue even if validation fails to see if the app works
|
| 285 |
logger.warning("Continuing despite validation failure...")
|
| 286 |
|
|
@@ -288,7 +288,7 @@ def create_configured_container(config_file: Optional[str] = None) -> Dependency
|
|
| 288 |
return container
|
| 289 |
|
| 290 |
except Exception as e:
|
| 291 |
-
logger.error(f"Failed to create configured container: {e}",
|
| 292 |
raise
|
| 293 |
|
| 294 |
|
|
@@ -352,7 +352,7 @@ def _validate_container_setup(container: DependencyContainer) -> None:
|
|
| 352 |
|
| 353 |
except Exception as e:
|
| 354 |
error_msg = f"Container validation failed during service resolution: {e}"
|
| 355 |
-
logger.error(error_msg,
|
| 356 |
raise RuntimeError(error_msg)
|
| 357 |
|
| 358 |
|
|
|
|
| 280 |
_validate_container_setup(container)
|
| 281 |
logger.info("Container validation completed")
|
| 282 |
except Exception as validation_error:
|
| 283 |
+
logger.error(f"Container validation failed: {validation_error}", exception=validation_error)
|
| 284 |
# For now, let's continue even if validation fails to see if the app works
|
| 285 |
logger.warning("Continuing despite validation failure...")
|
| 286 |
|
|
|
|
| 288 |
return container
|
| 289 |
|
| 290 |
except Exception as e:
|
| 291 |
+
logger.error(f"Failed to create configured container: {e}", exception=e)
|
| 292 |
raise
|
| 293 |
|
| 294 |
|
|
|
|
| 352 |
|
| 353 |
except Exception as e:
|
| 354 |
error_msg = f"Container validation failed during service resolution: {e}"
|
| 355 |
+
logger.error(error_msg, exception=e)
|
| 356 |
raise RuntimeError(error_msg)
|
| 357 |
|
| 358 |
|
src/infrastructure/config/dependency_container.py
CHANGED
|
@@ -214,7 +214,7 @@ class DependencyContainer:
|
|
| 214 |
return result
|
| 215 |
|
| 216 |
except Exception as e:
|
| 217 |
-
logger.error(f"Failed to resolve service {service_type.__name__}: {e}",
|
| 218 |
raise
|
| 219 |
|
| 220 |
def _create_singleton(self, service_type: Type[T], descriptor: ServiceDescriptor) -> T:
|
|
@@ -260,7 +260,7 @@ class DependencyContainer:
|
|
| 260 |
logger.info(f"Factory function completed for {descriptor.service_type.__name__}")
|
| 261 |
return result
|
| 262 |
except Exception as e:
|
| 263 |
-
logger.error(f"Factory function failed for {descriptor.service_type.__name__}: {e}",
|
| 264 |
raise
|
| 265 |
|
| 266 |
# If implementation is a class
|
|
@@ -271,7 +271,7 @@ class DependencyContainer:
|
|
| 271 |
logger.info(f"Class instantiation completed for {descriptor.service_type.__name__}")
|
| 272 |
return result
|
| 273 |
except Exception as e:
|
| 274 |
-
logger.error(f"Class instantiation failed for {descriptor.service_type.__name__}: {e}",
|
| 275 |
raise
|
| 276 |
|
| 277 |
logger.error(f"Invalid implementation type for {descriptor.service_type.__name__}: {type(implementation)}")
|
|
@@ -312,7 +312,14 @@ class DependencyContainer:
|
|
| 312 |
factory = self.resolve(TTSProviderFactory)
|
| 313 |
|
| 314 |
if provider_name:
|
| 315 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
else:
|
| 317 |
preferred_providers = self._config.tts.preferred_providers
|
| 318 |
return factory.get_provider_with_fallback(preferred_providers, **kwargs)
|
|
|
|
| 214 |
return result
|
| 215 |
|
| 216 |
except Exception as e:
|
| 217 |
+
logger.error(f"Failed to resolve service {service_type.__name__}: {e}", exception=e)
|
| 218 |
raise
|
| 219 |
|
| 220 |
def _create_singleton(self, service_type: Type[T], descriptor: ServiceDescriptor) -> T:
|
|
|
|
| 260 |
logger.info(f"Factory function completed for {descriptor.service_type.__name__}")
|
| 261 |
return result
|
| 262 |
except Exception as e:
|
| 263 |
+
logger.error(f"Factory function failed for {descriptor.service_type.__name__}: {e}", exception=e)
|
| 264 |
raise
|
| 265 |
|
| 266 |
# If implementation is a class
|
|
|
|
| 271 |
logger.info(f"Class instantiation completed for {descriptor.service_type.__name__}")
|
| 272 |
return result
|
| 273 |
except Exception as e:
|
| 274 |
+
logger.error(f"Class instantiation failed for {descriptor.service_type.__name__}: {e}", exception=e)
|
| 275 |
raise
|
| 276 |
|
| 277 |
logger.error(f"Invalid implementation type for {descriptor.service_type.__name__}: {type(implementation)}")
|
|
|
|
| 312 |
factory = self.resolve(TTSProviderFactory)
|
| 313 |
|
| 314 |
if provider_name:
|
| 315 |
+
try:
|
| 316 |
+
return factory.create_provider(provider_name, **kwargs)
|
| 317 |
+
except Exception as e:
|
| 318 |
+
logger.warning(f"Failed to create specific TTS provider {provider_name}: {e}")
|
| 319 |
+
logger.info("Falling back to default provider selection")
|
| 320 |
+
# Fall back to default provider selection
|
| 321 |
+
preferred_providers = self._config.tts.preferred_providers
|
| 322 |
+
return factory.get_provider_with_fallback(preferred_providers, **kwargs)
|
| 323 |
else:
|
| 324 |
preferred_providers = self._config.tts.preferred_providers
|
| 325 |
return factory.get_provider_with_fallback(preferred_providers, **kwargs)
|
src/infrastructure/tts/cosyvoice2_provider.py
CHANGED
|
@@ -61,13 +61,13 @@ class CosyVoice2TTSProvider(TTSProviderBase):
|
|
| 61 |
self.model = CosyVoice('pretrained_models/CosyVoice-300M')
|
| 62 |
logger.info("CosyVoice2 model successfully loaded")
|
| 63 |
except ImportError as e:
|
| 64 |
-
logger.error(f"Failed to import CosyVoice2 dependencies: {str(e)}",
|
| 65 |
self.model = None
|
| 66 |
except FileNotFoundError as e:
|
| 67 |
-
logger.error(f"Failed to load CosyVoice2 model files: {str(e)}",
|
| 68 |
self.model = None
|
| 69 |
except Exception as e:
|
| 70 |
-
logger.error(f"Failed to initialize CosyVoice2 model: {str(e)}",
|
| 71 |
self.model = None
|
| 72 |
|
| 73 |
model_available = self.model is not None
|
|
@@ -144,7 +144,7 @@ class CosyVoice2TTSProvider(TTSProviderBase):
|
|
| 144 |
return audio_bytes, DEFAULT_SAMPLE_RATE
|
| 145 |
|
| 146 |
except Exception as e:
|
| 147 |
-
logger.error(f"CosyVoice2 audio generation failed: {str(e)}",
|
| 148 |
self._handle_provider_error(e, "audio generation")
|
| 149 |
|
| 150 |
def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
|
|
|
|
| 61 |
self.model = CosyVoice('pretrained_models/CosyVoice-300M')
|
| 62 |
logger.info("CosyVoice2 model successfully loaded")
|
| 63 |
except ImportError as e:
|
| 64 |
+
logger.error(f"Failed to import CosyVoice2 dependencies: {str(e)}", exception=e)
|
| 65 |
self.model = None
|
| 66 |
except FileNotFoundError as e:
|
| 67 |
+
logger.error(f"Failed to load CosyVoice2 model files: {str(e)}", exception=e)
|
| 68 |
self.model = None
|
| 69 |
except Exception as e:
|
| 70 |
+
logger.error(f"Failed to initialize CosyVoice2 model: {str(e)}", exception=e)
|
| 71 |
self.model = None
|
| 72 |
|
| 73 |
model_available = self.model is not None
|
|
|
|
| 144 |
return audio_bytes, DEFAULT_SAMPLE_RATE
|
| 145 |
|
| 146 |
except Exception as e:
|
| 147 |
+
logger.error(f"CosyVoice2 audio generation failed: {str(e)}", exception=e)
|
| 148 |
self._handle_provider_error(e, "audio generation")
|
| 149 |
|
| 150 |
def _generate_audio_stream(self, request: 'SpeechSynthesisRequest') -> Iterator[tuple[bytes, int, bool]]:
|
src/infrastructure/tts/provider_factory.py
CHANGED
|
@@ -20,7 +20,7 @@ class TTSProviderFactory:
|
|
| 20 |
def _register_default_providers(self):
|
| 21 |
"""Register all available TTS providers."""
|
| 22 |
# Import providers dynamically to avoid import errors if dependencies are missing
|
| 23 |
-
|
| 24 |
# Always register dummy provider as fallback
|
| 25 |
from .dummy_provider import DummyTTSProvider
|
| 26 |
self._providers['dummy'] = DummyTTSProvider
|
|
@@ -39,7 +39,16 @@ class TTSProviderFactory:
|
|
| 39 |
self._providers['dia'] = DiaTTSProvider
|
| 40 |
logger.info("Registered Dia TTS provider")
|
| 41 |
except ImportError as e:
|
| 42 |
-
logger.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
# Try to register CosyVoice2 provider
|
| 45 |
try:
|
|
@@ -68,10 +77,10 @@ class TTSProviderFactory:
|
|
| 68 |
# Check if provider is available
|
| 69 |
if self._provider_instances[name].is_available():
|
| 70 |
available.append(name)
|
| 71 |
-
|
| 72 |
except Exception as e:
|
| 73 |
logger.warning(f"Failed to check availability of {name} provider: {e}")
|
| 74 |
-
|
| 75 |
return available
|
| 76 |
|
| 77 |
def create_provider(self, provider_name: str, **kwargs) -> TTSProviderBase:
|
|
@@ -94,9 +103,15 @@ class TTSProviderFactory:
|
|
| 94 |
f"Unknown TTS provider: {provider_name}. Available providers: {available}"
|
| 95 |
)
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
try:
|
| 98 |
provider_class = self._providers[provider_name]
|
| 99 |
-
|
| 100 |
# Create instance with appropriate parameters
|
| 101 |
if provider_name in ['kokoro', 'dia', 'cosyvoice2']:
|
| 102 |
lang_code = kwargs.get('lang_code', 'z')
|
|
@@ -133,7 +148,7 @@ class TTSProviderFactory:
|
|
| 133 |
preferred_providers = ['kokoro', 'dia', 'cosyvoice2', 'dummy']
|
| 134 |
|
| 135 |
available_providers = self.get_available_providers()
|
| 136 |
-
|
| 137 |
# Try preferred providers in order
|
| 138 |
for provider_name in preferred_providers:
|
| 139 |
if provider_name in available_providers:
|
|
@@ -177,7 +192,7 @@ class TTSProviderFactory:
|
|
| 177 |
self._provider_instances[provider_name] = provider_class()
|
| 178 |
|
| 179 |
provider = self._provider_instances[provider_name]
|
| 180 |
-
|
| 181 |
return {
|
| 182 |
"available": provider.is_available(),
|
| 183 |
"name": provider.provider_name,
|
|
@@ -199,6 +214,6 @@ class TTSProviderFactory:
|
|
| 199 |
provider._cleanup_temp_files()
|
| 200 |
except Exception as e:
|
| 201 |
logger.warning(f"Failed to cleanup provider {provider.provider_name}: {e}")
|
| 202 |
-
|
| 203 |
self._provider_instances.clear()
|
| 204 |
logger.info("Cleaned up TTS provider instances")
|
|
|
|
| 20 |
def _register_default_providers(self):
|
| 21 |
"""Register all available TTS providers."""
|
| 22 |
# Import providers dynamically to avoid import errors if dependencies are missing
|
| 23 |
+
|
| 24 |
# Always register dummy provider as fallback
|
| 25 |
from .dummy_provider import DummyTTSProvider
|
| 26 |
self._providers['dummy'] = DummyTTSProvider
|
|
|
|
| 39 |
self._providers['dia'] = DiaTTSProvider
|
| 40 |
logger.info("Registered Dia TTS provider")
|
| 41 |
except ImportError as e:
|
| 42 |
+
logger.warning(f"Dia TTS provider not available: {e}")
|
| 43 |
+
# Still register it so it can attempt installation later
|
| 44 |
+
try:
|
| 45 |
+
from .dia_provider import DiaTTSProvider
|
| 46 |
+
self._providers['dia'] = DiaTTSProvider
|
| 47 |
+
logger.info("Registered Dia TTS provider (dependencies may be installed on demand)")
|
| 48 |
+
except Exception:
|
| 49 |
+
logger.warning("Failed to register Dia TTS provider")
|
| 50 |
+
except Exception as e:
|
| 51 |
+
logger.warning(f"Failed to register Dia TTS provider: {e}")
|
| 52 |
|
| 53 |
# Try to register CosyVoice2 provider
|
| 54 |
try:
|
|
|
|
| 77 |
# Check if provider is available
|
| 78 |
if self._provider_instances[name].is_available():
|
| 79 |
available.append(name)
|
| 80 |
+
|
| 81 |
except Exception as e:
|
| 82 |
logger.warning(f"Failed to check availability of {name} provider: {e}")
|
| 83 |
+
|
| 84 |
return available
|
| 85 |
|
| 86 |
def create_provider(self, provider_name: str, **kwargs) -> TTSProviderBase:
|
|
|
|
| 103 |
f"Unknown TTS provider: {provider_name}. Available providers: {available}"
|
| 104 |
)
|
| 105 |
|
| 106 |
+
# Check if provider is actually available before creating
|
| 107 |
+
available_providers = self.get_available_providers()
|
| 108 |
+
if provider_name not in available_providers:
|
| 109 |
+
logger.warning(f"TTS provider {provider_name} is registered but not available")
|
| 110 |
+
raise SpeechSynthesisException(f"TTS provider {provider_name} is not available")
|
| 111 |
+
|
| 112 |
try:
|
| 113 |
provider_class = self._providers[provider_name]
|
| 114 |
+
|
| 115 |
# Create instance with appropriate parameters
|
| 116 |
if provider_name in ['kokoro', 'dia', 'cosyvoice2']:
|
| 117 |
lang_code = kwargs.get('lang_code', 'z')
|
|
|
|
| 148 |
preferred_providers = ['kokoro', 'dia', 'cosyvoice2', 'dummy']
|
| 149 |
|
| 150 |
available_providers = self.get_available_providers()
|
| 151 |
+
|
| 152 |
# Try preferred providers in order
|
| 153 |
for provider_name in preferred_providers:
|
| 154 |
if provider_name in available_providers:
|
|
|
|
| 192 |
self._provider_instances[provider_name] = provider_class()
|
| 193 |
|
| 194 |
provider = self._provider_instances[provider_name]
|
| 195 |
+
|
| 196 |
return {
|
| 197 |
"available": provider.is_available(),
|
| 198 |
"name": provider.provider_name,
|
|
|
|
| 214 |
provider._cleanup_temp_files()
|
| 215 |
except Exception as e:
|
| 216 |
logger.warning(f"Failed to cleanup provider {provider.provider_name}: {e}")
|
| 217 |
+
|
| 218 |
self._provider_instances.clear()
|
| 219 |
logger.info("Cleaned up TTS provider instances")
|