Spaces:
Sleeping
Sleeping
Michael Hu
committed on
Commit
·
c8d736e
1
Parent(s):
520c315
perf(stt): optimize device handling and use model.generate for inference
Browse files
This change improves performance by:
- Automatically detecting and using CUDA when available
- Moving inputs to the appropriate device before inference
- Using the model's built-in generate method instead of manual inference
- Loading models with proper device mapping and dtype configuration
src/infrastructure/stt/parakeet_provider.py
CHANGED
|
@@ -28,6 +28,7 @@ class ParakeetSTTProvider(STTProviderBase):
|
|
| 28 |
self.model = None
|
| 29 |
self.processor = None
|
| 30 |
self.current_model_name = None
|
|
|
|
| 31 |
|
| 32 |
def _perform_transcription(self, audio_path: Path, model: str) -> str:
|
| 33 |
"""
|
|
@@ -57,13 +58,11 @@ class ParakeetSTTProvider(STTProviderBase):
|
|
| 57 |
return_tensors="pt"
|
| 58 |
)
|
| 59 |
|
| 60 |
-
|
| 61 |
-
with torch.no_grad():
|
| 62 |
-
logits = self.model(inputs.input_features).logits
|
| 63 |
|
| 64 |
# Decode the predictions
|
| 65 |
-
|
| 66 |
-
transcription = self.processor.batch_decode(
|
| 67 |
|
| 68 |
logger.info("Parakeet transcription completed successfully")
|
| 69 |
return transcription
|
|
@@ -93,8 +92,10 @@ class ParakeetSTTProvider(STTProviderBase):
|
|
| 93 |
|
| 94 |
# Load processor and model
|
| 95 |
self.processor = AutoProcessor.from_pretrained(actual_model_name)
|
| 96 |
-
self.model = AutoModelForCTC.from_pretrained(actual_model_name)
|
| 97 |
self.current_model_name = model_name
|
|
|
|
|
|
|
| 98 |
|
| 99 |
# Set model to evaluation mode
|
| 100 |
self.model.eval()
|
|
|
|
| 28 |
self.model = None
|
| 29 |
self.processor = None
|
| 30 |
self.current_model_name = None
|
| 31 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 32 |
|
| 33 |
def _perform_transcription(self, audio_path: Path, model: str) -> str:
|
| 34 |
"""
|
|
|
|
| 58 |
return_tensors="pt"
|
| 59 |
)
|
| 60 |
|
| 61 |
+
inputs.to(self.device, dtype="auto")
|
|
|
|
|
|
|
| 62 |
|
| 63 |
# Decode the predictions
|
| 64 |
+
outputs = this.model.generate(**inputs)
|
| 65 |
+
transcription = self.processor.batch_decode(outputs)
|
| 66 |
|
| 67 |
logger.info("Parakeet transcription completed successfully")
|
| 68 |
return transcription
|
|
|
|
| 92 |
|
| 93 |
# Load processor and model
|
| 94 |
self.processor = AutoProcessor.from_pretrained(actual_model_name)
|
| 95 |
+
self.model = AutoModelForCTC.from_pretrained(actual_model_name, dtype="auto", device_map=self.device)
|
| 96 |
self.current_model_name = model_name
|
| 97 |
+
logger.info(f"Parakeet processor {processor}")
|
| 98 |
+
logger.info(f"Parakeet model {model}")
|
| 99 |
|
| 100 |
# Set model to evaluation mode
|
| 101 |
self.model.eval()
|