Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -37,7 +37,7 @@ MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
 # =========================================================
 print("Loading pipeline components...")
 
-#
+# Load models in bfloat16
 transformer = WanTransformer3DModel.from_pretrained(
     MODEL_ID,
     subfolder="transformer",
@@ -61,7 +61,6 @@ pipe = WanImageToVideoPipeline.from_pretrained(
     token=HF_TOKEN
 )
 
-# Move the model to CUDA now
 print("Moving to CUDA...")
 pipe = pipe.to("cuda")
 
@@ -94,7 +93,6 @@ except Exception as e:
 # QUANTIZATION & AOT OPTIMIZATION
 # =========================================================
 print("Applying quantization...")
-# Clean up memory before the heavy operations
 torch.cuda.empty_cache()
 gc.collect()
 
@@ -164,10 +162,9 @@ def get_num_frames(duration_seconds: float):
 # =========================================================
 # MAIN GENERATION FUNCTION
 # =========================================================
-# Increase the allowed time to 180 seconds to avoid timeouts
 @spaces.GPU(duration=180)
 def generate_video(
-
+    input_image_path,  # Receives file path now, not PIL object
     prompt,
     steps=4,
     negative_prompt=default_negative_prompt,
@@ -178,14 +175,22 @@ def generate_video(
     randomize_seed=False,
     progress=gr.Progress(track_tqdm=True),
 ):
-    #
+    # Cleanup memory
     gc.collect()
     torch.cuda.empty_cache()
 
     try:
-
+        # 1. Validation checks
+        if not input_image_path:
             raise gr.Error("Please upload an input image.")
-
+
+        if not os.path.exists(input_image_path):
+            raise gr.Error("Image file not found! Please re-upload the image.")
+
+        # 2. Manual Image Opening
+        # We open it inside the function to avoid connection timeouts
+        input_image = Image.open(input_image_path).convert("RGB")
+
         num_frames = get_num_frames(duration_seconds)
         current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
         resized_image = resize_image(input_image)
@@ -210,23 +215,23 @@
 
         export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
 
-        #
+        # Cleanup
         del output_frames_list
+        del input_image
+        del resized_image
         torch.cuda.empty_cache()
 
         return video_path, current_seed
 
     except Exception as e:
-        # Print the real error in the console
         print(f"Error during generation: {e}")
-        # Re-raise the error so the user sees it
         raise gr.Error(f"Generation failed: {str(e)}")
 
 # =========================================================
 # GRADIO UI
 # =========================================================
 
-#
+# Google Analytics Script
 ga_script = """
 <script async src="https://www.googletagmanager.com/gtag/js?id=G-1TD40BVM04"></script>
 <script>
@@ -281,7 +286,8 @@ with gr.Blocks(theme=gr.themes.Soft(), head=ga_script) as demo:
 
     with gr.Row():
         with gr.Column():
-
+            # CHANGE: type="filepath" fixes the file not found error
+            input_image_component = gr.Image(type="filepath", label="Input Image")
             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
 
     duration_seconds_input = gr.Slider(
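
For context, the change set above boils down to one pattern: the Gradio image input hands the handler a file path (type="filepath"), and the handler validates and opens that file itself inside the @spaces.GPU-decorated function instead of receiving a PIL object. Below is a minimal sketch of that pattern under stated assumptions: the actual WanImageToVideoPipeline call is replaced by a placeholder, and the component names (image_in, prompt_in, video_out) are invented for the example rather than taken from the Space's app.py.

import gc
import os

import gradio as gr
import spaces
import torch
from PIL import Image


@spaces.GPU(duration=180)  # allow up to 180 s on ZeroGPU before the worker times out
def generate_video(input_image_path, prompt):
    # Free whatever the previous request left behind before doing heavy work
    gc.collect()
    torch.cuda.empty_cache()

    # Validate the uploaded file path before touching the GPU
    if not input_image_path:
        raise gr.Error("Please upload an input image.")
    if not os.path.exists(input_image_path):
        raise gr.Error("Image file not found! Please re-upload the image.")

    # Open the image inside the handler instead of receiving a PIL object
    input_image = Image.open(input_image_path).convert("RGB")

    # ... run the image-to-video pipeline on input_image and export the frames here ...
    video_path = "output.mp4"  # placeholder for the exported video file
    return video_path


with gr.Blocks() as demo:
    image_in = gr.Image(type="filepath", label="Input Image")
    prompt_in = gr.Textbox(label="Prompt")
    video_out = gr.Video(label="Generated Video")
    gr.Button("Generate").click(generate_video, inputs=[image_in, prompt_in], outputs=video_out)

demo.launch()

The design choice here is that only a short-lived temp-file path travels through the request, so the handler must check os.path.exists before opening it; that is why the commit pairs type="filepath" with the explicit validation and Image.open inside the GPU function.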