C4G-HKUST commited on
Commit
831eccf
·
1 Parent(s): ef30698

Fix CUDA device access in T5EncoderModel and add GPU check

Browse files
Files changed (2) hide show
  1. app.py +9 -0
  2. wan/modules/t5.py +7 -1
app.py CHANGED
@@ -16,6 +16,15 @@ import torch.distributed as dist
16
  from PIL import Image
17
  from huggingface_hub import snapshot_download
18
 
 
 
 
 
 
 
 
 
 
19
  # 导入 AnyTalker 相关的模块
20
  import wan
21
  from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS, MAX_AREA_CONFIGS
 
16
  from PIL import Image
17
  from huggingface_hub import snapshot_download
18
 
19
+ # 检查 GPU 可用性(参考 Meigen-MultiTalk)
20
+ is_gpu_available = torch.cuda.is_available()
21
+ if is_gpu_available:
22
+ # 初始化 CUDA,确保设备可用
23
+ try:
24
+ _ = torch.cuda.current_device()
25
+ except RuntimeError:
26
+ is_gpu_available = False
27
+
28
  # 导入 AnyTalker 相关的模块
29
  import wan
30
  from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS, MAX_AREA_CONFIGS
wan/modules/t5.py CHANGED
@@ -475,11 +475,17 @@ class T5EncoderModel:
475
  self,
476
  text_len,
477
  dtype=torch.bfloat16,
478
- device=torch.cuda.current_device(),
479
  checkpoint_path=None,
480
  tokenizer_path=None,
481
  shard_fn=None,
482
  ):
 
 
 
 
 
 
483
  self.text_len = text_len
484
  self.dtype = dtype
485
  self.device = device
 
475
  self,
476
  text_len,
477
  dtype=torch.bfloat16,
478
+ device=None,
479
  checkpoint_path=None,
480
  tokenizer_path=None,
481
  shard_fn=None,
482
  ):
483
+ # 延迟获取 CUDA 设备,避免在导入时访问
484
+ if device is None:
485
+ if torch.cuda.is_available():
486
+ device = torch.cuda.current_device()
487
+ else:
488
+ device = torch.device('cpu')
489
  self.text_len = text_len
490
  self.dtype = dtype
491
  self.device = device