Spaces: Running on Zero
Fix CUDA device access in T5EncoderModel and add GPU check
Browse files- app.py +9 -0
- wan/modules/t5.py +7 -1
app.py
CHANGED
|
@@ -16,6 +16,15 @@ import torch.distributed as dist
|
|
| 16 |
from PIL import Image
|
| 17 |
from huggingface_hub import snapshot_download
|
| 18 |
|
| 19 |
# 导入 AnyTalker 相关的模块
|
| 20 |
import wan
|
| 21 |
from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS, MAX_AREA_CONFIGS
|
|
|
|
| 16 |
from PIL import Image
|
| 17 |
from huggingface_hub import snapshot_download
|
| 18 |
|
| 19 |
+
# 检查 GPU 可用性(参考 Meigen-MultiTalk)
|
| 20 |
+
is_gpu_available = torch.cuda.is_available()
|
| 21 |
+
if is_gpu_available:
|
| 22 |
+
# 初始化 CUDA,确保设备可用
|
| 23 |
+
try:
|
| 24 |
+
_ = torch.cuda.current_device()
|
| 25 |
+
except RuntimeError:
|
| 26 |
+
is_gpu_available = False
|
| 27 |
+
|
| 28 |
# 导入 AnyTalker 相关的模块
|
| 29 |
import wan
|
| 30 |
from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS, MAX_AREA_CONFIGS
|
wan/modules/t5.py
CHANGED
|
@@ -475,11 +475,17 @@ class T5EncoderModel:
|
|
| 475 |
self,
|
| 476 |
text_len,
|
| 477 |
dtype=torch.bfloat16,
|
| 478 |
-
device=
|
| 479 |
checkpoint_path=None,
|
| 480 |
tokenizer_path=None,
|
| 481 |
shard_fn=None,
|
| 482 |
):
|
|
|
| 483 |
self.text_len = text_len
|
| 484 |
self.dtype = dtype
|
| 485 |
self.device = device
|
|
|
|
| 475 |
self,
|
| 476 |
text_len,
|
| 477 |
dtype=torch.bfloat16,
|
| 478 |
+
device=None,
|
| 479 |
checkpoint_path=None,
|
| 480 |
tokenizer_path=None,
|
| 481 |
shard_fn=None,
|
| 482 |
):
|
| 483 |
+
# 延迟获取 CUDA 设备,避免在导入时访问
|
| 484 |
+
if device is None:
|
| 485 |
+
if torch.cuda.is_available():
|
| 486 |
+
device = torch.cuda.current_device()
|
| 487 |
+
else:
|
| 488 |
+
device = torch.device('cpu')
|
| 489 |
self.text_len = text_len
|
| 490 |
self.dtype = dtype
|
| 491 |
self.device = device
|