Spaces:
Running
on
Zero
Running
on
Zero
Fix CUDA device fallback in WanVAE and WanAF2V
Browse files- wan/audio2video_multiID.py +5 -1
- wan/modules/vae.py +6 -0
wan/audio2video_multiID.py
CHANGED
|
@@ -67,7 +67,11 @@ class WanAF2V:
|
|
| 67 |
use_half (`bool`, *optional*, defaults to False):
|
| 68 |
Whether to use half precision (float16/bfloat16) for model inference. Reduces memory usage.
|
| 69 |
"""
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
self.config = config
|
| 72 |
self.rank = rank
|
| 73 |
self.t5_cpu = t5_cpu
|
|
|
|
| 67 |
use_half (`bool`, *optional*, defaults to False):
|
| 68 |
Whether to use half precision (float16/bfloat16) for model inference. Reduces memory usage.
|
| 69 |
"""
|
| 70 |
+
# 如果 CUDA 不可用,自动回退到 CPU
|
| 71 |
+
if torch.cuda.is_available():
|
| 72 |
+
self.device = torch.device(f"cuda:{device_id}")
|
| 73 |
+
else:
|
| 74 |
+
self.device = torch.device("cpu")
|
| 75 |
self.config = config
|
| 76 |
self.rank = rank
|
| 77 |
self.t5_cpu = t5_cpu
|
wan/modules/vae.py
CHANGED
|
@@ -624,6 +624,12 @@ class WanVAE:
|
|
| 624 |
dtype=torch.float,
|
| 625 |
device="cuda"):
|
| 626 |
self.dtype = dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 627 |
self.device = device
|
| 628 |
|
| 629 |
mean = [
|
|
|
|
| 624 |
dtype=torch.float,
|
| 625 |
device="cuda"):
|
| 626 |
self.dtype = dtype
|
| 627 |
+
# 如果 device 是 "cuda" 但 CUDA 不可用,自动回退到 CPU
|
| 628 |
+
if device == "cuda" and not torch.cuda.is_available():
|
| 629 |
+
device = "cpu"
|
| 630 |
+
elif isinstance(device, str) and device.startswith("cuda"):
|
| 631 |
+
if not torch.cuda.is_available():
|
| 632 |
+
device = "cpu"
|
| 633 |
self.device = device
|
| 634 |
|
| 635 |
mean = [
|