ACloudCenter committed
Commit 5dc3e05 · Parent: 34c49bb

Add real-time progress streaming from Modal to UI

Files changed (2)
  1. app.py +8 -5
  2. backend_modal/modal_runner.py +41 -17
app.py CHANGED
@@ -312,8 +312,8 @@ def create_demo_interface():
             speakers = speakers_and_params[:4]
             cfg_scale_val = speakers_and_params[4]

-            # This is the call to the remote Modal function
-            result, log = remote_generate_function.remote(
+            # Stream updates from the Modal function
+            for update in remote_generate_function.remote_gen(
                 num_speakers=int(num_speakers_val),
                 script=script,
                 speaker_1=speakers[0],
@@ -322,12 +322,15 @@
                 speaker_4=speakers[3],
                 cfg_scale=cfg_scale_val,
                 model_name=model_choice
-            )
-            yield result, log
+            ):
+                # Each update is a tuple (audio_or_none, log_message)
+                if update:
+                    audio, log = update
+                    yield audio, log
         except Exception as e:
             tb = traceback.format_exc()
             print(f"Error calling Modal: {e}")
-            yield None, f"An error occurred in the Gradio wrapper: {e}\n\n{tb}"
+            yield None, f"An error occurred: {e}\n\n{tb}"

     generate_btn.click(
         fn=generate_podcast_wrapper,
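
The client-side change swaps Modal's blocking `.remote()` call for `.remote_gen()`, which returns an iterator over everything the remote generator yields; since the Gradio wrapper is itself a generator, each re-yielded `(audio, log)` tuple streams straight to the UI. A minimal sketch of the pattern, with hypothetical names (`stream-demo`, `generate`) standing in for this repo's app and function:

```python
import modal

app = modal.App("stream-demo")  # hypothetical app name

@app.function()
def generate(n: int):
    """Remote generator: yields progress tuples, then the final payload."""
    for i in range(n):
        yield None, f"step {i + 1}/{n}"  # progress-only update
    yield (24000, [0.0]), "done"         # final (audio, log) tuple

@app.local_entrypoint()
def main():
    # .remote_gen() streams each yield back as it is produced; a Gradio
    # generator callback would re-yield these to update the UI live.
    for audio, log in generate.remote_gen(3):
        print(audio, log)
```

Run with `modal run`: each progress line prints as it arrives rather than after the function returns.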
backend_modal/modal_runner.py CHANGED
@@ -39,7 +39,7 @@ app = modal.App(
 )


-@app.cls(gpu="T4", scaledown_window=300)
+@app.cls(gpu="A100-40GB", scaledown_window=300)
 class VibeVoiceModel:
     def __init__(self):
         self.model_paths = {
@@ -53,13 +53,13 @@ class VibeVoiceModel:
     def load_models(self):
         """
         This method is run once when the container starts.
-        It downloads and loads all models onto the GPU.
+        With the A100 (40GB), we can load both models onto the GPU.
         """
         # Project-specific imports are moved here to run inside the container
         from modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
         from processor.vibevoice_processor import VibeVoiceProcessor

-        print("Entering container and loading models to GPU...")
+        print("Entering container and loading models to GPU (A100 with 40GB)...")

         # Set compiler flags for better performance
         if torch.cuda.is_available() and hasattr(torch, '_inductor'):
@@ -71,7 +71,9 @@ class VibeVoiceModel:

         self.models = {}
         self.processors = {}
+        self.current_model_name = None

+        # Load all models directly to GPU (the A100 has enough memory)
         for name, path in self.model_paths.items():
             print(f"  - Loading {name} from {path}")
             proc = VibeVoiceProcessor.from_pretrained(path)
@@ -79,14 +81,24 @@
                 path,
                 torch_dtype=torch.bfloat16,
                 attn_implementation="sdpa"
-            ).to(self.device)
+            ).to(self.device)  # Load directly to GPU
             mdl.eval()
             print(f"  {name} loaded to {self.device}")
             self.processors[name] = proc
             self.models[name] = mdl

+        # Set default model
+        self.current_model_name = "VibeVoice-1.5B"
+
         self.setup_voice_presets()
         print("Model loading complete.")
+
+    def _place_model(self, target_name: str):
+        """
+        With the A100, both models stay on GPU; just record the active model.
+        """
+        self.current_model_name = target_name
+        print(f"Switched to model {target_name}")

     def setup_voice_presets(self):
         self.available_voices = {}
@@ -178,11 +190,18 @@ class VibeVoiceModel:
                          speaker_4: str = None):
         """
         This is the main inference function that will be called from the Gradio app.
+        Yields progress updates during generation.
         """
         try:
+            # Yield initial status
+            yield None, "🔄 Initializing generation..."
             if model_name not in self.models:
                 raise ValueError(f"Unknown model: {model_name}")

+            # Mark the selected model as active (both models stay on GPU)
+            yield None, "🔄 Loading model to GPU..."
+            self._place_model(model_name)
+
             model = self.models[model_name]
             processor = self.processors[model_name]
             model.set_ddpm_inference_steps(num_steps=self.inference_steps)
@@ -192,7 +211,7 @@
             if not script.strip():
                 raise ValueError("Error: Please provide a script.")

-            script = script.replace("", "'")
+            script = script.replace("’", "'")

             if not 1 <= num_speakers <= 4:
                 raise ValueError("Error: Number of speakers must be between 1 and 4.")
@@ -206,16 +225,19 @@
             log += f"Model: {model_name}\n"
             log += f"Parameters: CFG Scale={cfg_scale}\n"
             log += f"Speakers: {', '.join(selected_speakers)}\n"
+
+            yield None, log + "\n🔄 Loading voice samples..."

             voice_samples = []
-            for speaker_name in selected_speakers:
+            for i, speaker_name in enumerate(selected_speakers):
                 audio_path = self.available_voices[speaker_name]
                 audio_data = self.read_audio(audio_path)
                 if len(audio_data) == 0:
                     raise ValueError(f"Error: Failed to load audio for {speaker_name}")
                 voice_samples.append(audio_data)
+                yield None, log + f"\n✓ Loaded voice {i+1}/{len(selected_speakers)}: {speaker_name}"

-            log += f"Loaded {len(voice_samples)} voice samples\n"
+            log += f"\nLoaded {len(voice_samples)} voice samples"

             lines = script.strip().split('\n')
             formatted_script_lines = []
@@ -229,8 +251,8 @@
                 formatted_script_lines.append(f"Speaker {speaker_id}: {line}")

             formatted_script = '\n'.join(formatted_script_lines)
-            log += f"Formatted script with {len(formatted_script_lines)} turns\n"
-            log += "Processing with VibeVoice...\n"
+            log += f"\nFormatted script with {len(formatted_script_lines)} turns"
+            yield None, log + "\n🔄 Processing script with VibeVoice..."

             inputs = processor(
                 text=[formatted_script],
@@ -240,6 +262,7 @@
                 return_attention_mask=True,
             ).to(self.device)

+            yield None, log + "\n🎯 Starting audio generation (this may take 1-2 minutes)..."
             start_time = time.time()

             with torch.inference_mode():
@@ -253,6 +276,8 @@
             )
             generation_time = time.time() - start_time

+            yield None, log + f"\n✓ Generation completed in {generation_time:.2f} seconds\n🔄 Processing audio..."
+
             if hasattr(outputs, 'speech_outputs') and outputs.speech_outputs[0] is not None:
                 audio_tensor = outputs.speech_outputs[0]
                 audio = audio_tensor.cpu().float().numpy()
@@ -264,16 +289,15 @@

             sample_rate = 24000
             total_duration = len(audio) / sample_rate
-            log += f"Generation completed in {generation_time:.2f} seconds\n"
-            log += f"Final audio duration: {total_duration:.2f} seconds\n"
+            log += f"\n✓ Generation completed in {generation_time:.2f} seconds"
+            log += f"\n✓ Audio duration: {total_duration:.2f} seconds"

-            # Return the raw audio data and sample rate, Gradio will handle the rest
-            return (sample_rate, audio), log
+            # Final yield with both audio and complete log
+            yield (sample_rate, audio), log + "\n✅ Complete!"

         except Exception as e:
             import traceback
             error_msg = f"An unexpected error occurred on Modal: {str(e)}\n{traceback.format_exc()}"
             print(error_msg)
-            # Return a special value or raise an exception that the client can handle
-            # For Gradio, returning a log message is often best.
-            return None, error_msg
+            # Yield error state
+            yield None, error_msg
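
On the Modal side, the streaming works because the inference method is now a generator, and Modal exposes yielding callables through `.remote_gen()` rather than `.remote()`. A hedged sketch of the class shape used here, with illustrative names (`Speech`, `load`, `run`) and a stand-in for real model loading; this assumes `load_models` is registered with `@modal.enter()`, consistent with its 'run once when the container starts' docstring:

```python
import modal

app = modal.App("vibevoice-sketch")  # hypothetical app name

@app.cls(gpu="A100-40GB", scaledown_window=300)
class Speech:
    @modal.enter()
    def load(self):
        # Stand-in for load_models(): runs once per container start, so
        # weights persist across calls until the container scales down.
        self.ready = True

    @modal.method()
    def run(self, script: str):
        yield None, "🔄 Initializing generation..."  # progress update
        audio = [0.0] * 24000                        # stand-in for inference
        yield (24000, audio), "✅ Complete!"         # final result + log
```

A caller iterates with `Speech().run.remote_gen(script)`, mirroring the `remote_gen` loop in app.py above.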