C4G-HKUST commited on
Commit
7c629b9
·
1 Parent(s): 831eccf

Fix CUDA device fallback in WanVAE and WanAF2V

Browse files
Files changed (2) hide show
  1. wan/audio2video_multiID.py +5 -1
  2. wan/modules/vae.py +6 -0
wan/audio2video_multiID.py CHANGED
@@ -67,7 +67,11 @@ class WanAF2V:
67
  use_half (`bool`, *optional*, defaults to False):
68
  Whether to use half precision (float16/bfloat16) for model inference. Reduces memory usage.
69
  """
70
- self.device = torch.device(f"cuda:{device_id}")
 
 
 
 
71
  self.config = config
72
  self.rank = rank
73
  self.t5_cpu = t5_cpu
 
67
  use_half (`bool`, *optional*, defaults to False):
68
  Whether to use half precision (float16/bfloat16) for model inference. Reduces memory usage.
69
  """
70
+ # 如果 CUDA 不可用,自动回退到 CPU
71
+ if torch.cuda.is_available():
72
+ self.device = torch.device(f"cuda:{device_id}")
73
+ else:
74
+ self.device = torch.device("cpu")
75
  self.config = config
76
  self.rank = rank
77
  self.t5_cpu = t5_cpu
wan/modules/vae.py CHANGED
@@ -624,6 +624,12 @@ class WanVAE:
624
  dtype=torch.float,
625
  device="cuda"):
626
  self.dtype = dtype
 
 
 
 
 
 
627
  self.device = device
628
 
629
  mean = [
 
624
  dtype=torch.float,
625
  device="cuda"):
626
  self.dtype = dtype
627
+ # 如果 device 是 "cuda" 但 CUDA 不可用,自动回退到 CPU
628
+ if device == "cuda" and not torch.cuda.is_available():
629
+ device = "cpu"
630
+ elif isinstance(device, str) and device.startswith("cuda"):
631
+ if not torch.cuda.is_available():
632
+ device = "cpu"
633
  self.device = device
634
 
635
  mean = [