dream2589632147 commited on
Commit
042a8b6
·
verified ·
1 Parent(s): dccbdc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -13
app.py CHANGED
@@ -37,7 +37,7 @@ MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
37
  # =========================================================
38
  print("Loading pipeline components...")
39
 
40
- # تحميل المكونات أولاً بدون نقلها للـ GPU لتوفير الذاكرة أثناء التحميل
41
  transformer = WanTransformer3DModel.from_pretrained(
42
  MODEL_ID,
43
  subfolder="transformer",
@@ -61,7 +61,6 @@ pipe = WanImageToVideoPipeline.from_pretrained(
61
  token=HF_TOKEN
62
  )
63
 
64
- # نقل الموديل إلى CUDA الآن
65
  print("Moving to CUDA...")
66
  pipe = pipe.to("cuda")
67
 
@@ -94,7 +93,6 @@ except Exception as e:
94
  # QUANTIZATION & AOT OPTIMIZATION
95
  # =========================================================
96
  print("Applying quantization...")
97
- # تنظيف الذاكرة قبل العمليات الثقيلة
98
  torch.cuda.empty_cache()
99
  gc.collect()
100
 
@@ -164,10 +162,9 @@ def get_num_frames(duration_seconds: float):
164
  # =========================================================
165
  # MAIN GENERATION FUNCTION
166
  # =========================================================
167
- # زيادة الوقت المسموح به إلى 180 ثانية لتجنب التايم أوت
168
  @spaces.GPU(duration=180)
169
  def generate_video(
170
- input_image,
171
  prompt,
172
  steps=4,
173
  negative_prompt=default_negative_prompt,
@@ -178,14 +175,22 @@ def generate_video(
178
  randomize_seed=False,
179
  progress=gr.Progress(track_tqdm=True),
180
  ):
181
- # تنظيف الذاكرة في بداية الدالة
182
  gc.collect()
183
  torch.cuda.empty_cache()
184
 
185
  try:
186
- if input_image is None:
 
187
  raise gr.Error("Please upload an input image.")
188
-
 
 
 
 
 
 
 
189
  num_frames = get_num_frames(duration_seconds)
190
  current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
191
  resized_image = resize_image(input_image)
@@ -210,23 +215,23 @@ def generate_video(
210
 
211
  export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
212
 
213
- # تنظيف الذاكرة بعد الانتهاء
214
  del output_frames_list
 
 
215
  torch.cuda.empty_cache()
216
 
217
  return video_path, current_seed
218
 
219
  except Exception as e:
220
- # طباعة الخطأ الحقيقي في الكونسول
221
  print(f"Error during generation: {e}")
222
- # إعادة رفع الخطأ ليظهر للمستخدم
223
  raise gr.Error(f"Generation failed: {str(e)}")
224
 
225
  # =========================================================
226
  # GRADIO UI
227
  # =========================================================
228
 
229
- # تعريف كود جوجل أناليتكس ليتم حقنه
230
  ga_script = """
231
  <script async src="https://www.googletagmanager.com/gtag/js?id=G-1TD40BVM04"></script>
232
  <script>
@@ -281,7 +286,8 @@ with gr.Blocks(theme=gr.themes.Soft(), head=ga_script) as demo:
281
 
282
  with gr.Row():
283
  with gr.Column():
284
- input_image_component = gr.Image(type="pil", label="Input Image")
 
285
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
286
 
287
  duration_seconds_input = gr.Slider(
 
37
  # =========================================================
38
  print("Loading pipeline components...")
39
 
40
+ # Load models in bfloat16
41
  transformer = WanTransformer3DModel.from_pretrained(
42
  MODEL_ID,
43
  subfolder="transformer",
 
61
  token=HF_TOKEN
62
  )
63
 
 
64
  print("Moving to CUDA...")
65
  pipe = pipe.to("cuda")
66
 
 
93
  # QUANTIZATION & AOT OPTIMIZATION
94
  # =========================================================
95
  print("Applying quantization...")
 
96
  torch.cuda.empty_cache()
97
  gc.collect()
98
 
 
162
  # =========================================================
163
  # MAIN GENERATION FUNCTION
164
  # =========================================================
 
165
  @spaces.GPU(duration=180)
166
  def generate_video(
167
+ input_image_path, # Receives file path now, not PIL object
168
  prompt,
169
  steps=4,
170
  negative_prompt=default_negative_prompt,
 
175
  randomize_seed=False,
176
  progress=gr.Progress(track_tqdm=True),
177
  ):
178
+ # Cleanup memory
179
  gc.collect()
180
  torch.cuda.empty_cache()
181
 
182
  try:
183
+ # 1. Validation checks
184
+ if not input_image_path:
185
  raise gr.Error("Please upload an input image.")
186
+
187
+ if not os.path.exists(input_image_path):
188
+ raise gr.Error("Image file not found! Please re-upload the image.")
189
+
190
+ # 2. Manual Image Opening
191
+ # We open it inside the function to avoid connection timeouts
192
+ input_image = Image.open(input_image_path).convert("RGB")
193
+
194
  num_frames = get_num_frames(duration_seconds)
195
  current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
196
  resized_image = resize_image(input_image)
 
215
 
216
  export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
217
 
218
+ # Cleanup
219
  del output_frames_list
220
+ del input_image
221
+ del resized_image
222
  torch.cuda.empty_cache()
223
 
224
  return video_path, current_seed
225
 
226
  except Exception as e:
 
227
  print(f"Error during generation: {e}")
 
228
  raise gr.Error(f"Generation failed: {str(e)}")
229
 
230
  # =========================================================
231
  # GRADIO UI
232
  # =========================================================
233
 
234
+ # Google Analytics Script
235
  ga_script = """
236
  <script async src="https://www.googletagmanager.com/gtag/js?id=G-1TD40BVM04"></script>
237
  <script>
 
286
 
287
  with gr.Row():
288
  with gr.Column():
289
+ # CHANGE: type="filepath" fixes the file not found error
290
+ input_image_component = gr.Image(type="filepath", label="Input Image")
291
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
292
 
293
  duration_seconds_input = gr.Slider(