Michael Hu committed
Commit c8d736e · 1 Parent(s): 520c315

perf(stt): optimize device handling and use model.generate for inference


This change improves performance by:
- Automatically detecting and using CUDA when available
- Moving inputs to the appropriate device before inference
- Using the model's built-in generate method instead of manual inference
- Loading models with proper device mapping and dtype configuration
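
Taken together, these changes follow the usual Hugging Face load-and-infer pattern. The sketch below is illustrative only and is not the repository's code: the checkpoint name and dummy audio are placeholders, and it assumes the checkpoint's model class exposes generate(), which is what the commit relies on.

# Illustrative sketch of the pattern described above (placeholder names throughout).
import numpy as np
import torch
from transformers import AutoModelForCTC, AutoProcessor

model_name = "some/parakeet-checkpoint"  # placeholder, not the provider's real model id
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForCTC.from_pretrained(
    model_name,
    dtype="auto",        # recent transformers releases; older ones spell this torch_dtype
    device_map=device,
)
model.eval()

speech = np.zeros(16000, dtype=np.float32)      # 1 s of silence stands in for real audio
inputs = processor(speech, sampling_rate=16000, return_tensors="pt")
inputs = inputs.to(device, dtype=model.dtype)   # match the model's device and dtype

with torch.no_grad():
    outputs = model.generate(**inputs)          # assumes the model class supports generate()
transcription = processor.batch_decode(outputs)[0]
print(transcription)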

src/infrastructure/stt/parakeet_provider.py CHANGED
@@ -28,6 +28,7 @@ class ParakeetSTTProvider(STTProviderBase):
         self.model = None
         self.processor = None
         self.current_model_name = None
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
     def _perform_transcription(self, audio_path: Path, model: str) -> str:
         """
@@ -57,13 +58,11 @@ class ParakeetSTTProvider(STTProviderBase):
             return_tensors="pt"
         )
 
-        # Perform inference
-        with torch.no_grad():
-            logits = self.model(inputs.input_features).logits
+        inputs = inputs.to(self.device, dtype=self.model.dtype)
 
         # Decode the predictions
-        predicted_ids = torch.argmax(logits, dim=-1)
-        transcription = self.processor.batch_decode(predicted_ids)[0]
+        outputs = self.model.generate(**inputs)
+        transcription = self.processor.batch_decode(outputs)[0]
 
         logger.info("Parakeet transcription completed successfully")
         return transcription
@@ -93,8 +92,10 @@ class ParakeetSTTProvider(STTProviderBase):
 
         # Load processor and model
         self.processor = AutoProcessor.from_pretrained(actual_model_name)
-        self.model = AutoModelForCTC.from_pretrained(actual_model_name)
+        self.model = AutoModelForCTC.from_pretrained(actual_model_name, dtype="auto", device_map=self.device)
         self.current_model_name = model_name
+        logger.info(f"Parakeet processor {self.processor}")
+        logger.info(f"Parakeet model {self.model}")
 
         # Set model to evaluation mode
         self.model.eval()
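
For reference, a quick standalone check of what loading with device_map and dtype="auto" actually yields, which is roughly the information the two new logger calls surface. The checkpoint name is again a placeholder.

# Hedged sketch: inspect device placement and dtype after auto-configured loading.
import torch
from transformers import AutoModelForCTC

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCTC.from_pretrained(
    "some/parakeet-checkpoint",  # placeholder checkpoint id
    dtype="auto",
    device_map=device,
)
first_param = next(model.parameters())
print(first_param.device, first_param.dtype)  # e.g. cuda:0 torch.bfloat16 on a GPU host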