import base64
import ctypes
import gc
import inspect
import json
import mmap
import os
import shutil
import signal
import sys
import time
import warnings
from collections import defaultdict
from concurrent.futures import as_completed, ThreadPoolExecutor
from contextlib import contextmanager, nullcontext
from contextvars import copy_context
from dataclasses import dataclass
from datetime import timedelta
from functools import lru_cache as cache, partial, wraps
from importlib import metadata
import importlib
from queue import Empty, Queue as ThreadQueue
from threading import Thread
from types import ModuleType, SimpleNamespace
from typing import (
    Any, Callable, Dict, Generator, Generic, List, Literal, NamedTuple, Optional,
    Set, Tuple, Type, TypedDict, TypeVar, Union, overload
)
from typing_extensions import (
    assert_never, ParamSpec, TypeAlias, Unpack, get_args
)
from pathlib import Path

from packaging import version
import gradio as gr
import httpx
from gradio.context import Context, LocalContext
from gradio.helpers import Progress, TrackedIterable
from gradio.queueing import Queue
from pydantic import BaseModel

warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML")

try:
    import torch
    from torch.utils.weak import WeakTensorKeyDictionary
except ImportError:
    torch = None
    WeakTensorKeyDictionary = dict

if torch and "weights_only" in inspect.signature(torch.load).parameters:
    _original_torch_load = torch.load

    @wraps(_original_torch_load)
    def patched_torch_load(*args, **kwargs):
        kwargs.setdefault("weights_only", False)
        return _original_torch_load(*args, **kwargs)

    torch.load = patched_torch_load

try:
    from tqdm import tqdm as _tqdm
except ImportError:
    _tqdm = None


def boolean(value: str | None) -> bool:
    return value is not None and value.lower() in ("1", "t", "true")


class Settings:
    def __init__(self):
        self.zero_gpu = boolean(os.getenv('SPACES_ZERO_GPU'))
        self.zero_device_api_url = os.getenv('SPACES_ZERO_DEVICE_API_URL')
        self.gradio_auto_wrap = boolean(os.getenv('SPACES_GRADIO_AUTO_WRAP'))
        self.zero_patch_torch_device = boolean(os.getenv('ZERO_GPU_PATCH_TORCH_DEVICE'))
        self.zero_gpu_v2 = boolean(os.getenv('ZEROGPU_V2'))
        self.zerogpu_size: Union[Literal['medium', 'large'], Literal['auto']] = os.getenv('ZEROGPU_SIZE', 'large')
        self.zerogpu_medium_size_threshold = int(os.getenv('ZEROGPU_MEDIUM_SIZE_THRESHOLD', 30 * 2**30))
        ZEROGPU_OFFLOAD_DIR_DEFAULT = str(Path.home() / '.zerogpu' / 'tensors')
        self.zerogpu_offload_dir = os.getenv('ZEROGPU_OFFLOAD_DIR', ZEROGPU_OFFLOAD_DIR_DEFAULT)
        self.zerogpu_proc_self_cgroup_path = os.getenv('ZEROGPU_PROC_SELF_CGROUP_PATH', '/proc/self/cgroup')
        self.zerogpu_cuda_device_name = os.getenv('ZEROGPU_CUDA_DEVICE_NAME', "NVIDIA H200 MIG 3g.71gb")
        self.zerogpu_cuda_total_memory = int(os.getenv('ZEROGPU_CUDA_TOTAL_MEMORY', 74625056768))
        self.zerogpu_cuda_reserved_memory = int(os.getenv('ZEROGPU_CUDA_RESERVED_MEMORY', 0))
        self.zerogpu_cuda_capability_major = int(os.getenv('ZEROGPU_CUDA_CAPABILITY_MAJOR', 9))
        self.zerogpu_cuda_capability_minor = int(os.getenv('ZEROGPU_CUDA_CAPABILITY_MINOR', 0))
        self.zerogpu_cuda_multi_processor_count = int(os.getenv('ZEROGPU_CUDA_MULTI_PROCESSOR_COUNT', 60))


Config = Settings()

if Config.zero_gpu:
    if Config.zero_device_api_url is None:
        print("Error: SPACES_ZERO_DEVICE_API_URL environment variable must be set on ZeroGPU Spaces.", file=sys.stderr)

GPUSizeConfig = Literal['auto', 'medium', 'large']

if Config.zerogpu_size not in get_args(GPUSizeConfig):
print(f"Error: ZEROGPU_SIZE should be one of {', '.join(get_args(GPUSizeConfig))}", file=sys.stderr) T = TypeVar('T') @cache def self_cgroup_device_path() -> str: try: cgroup_content = Path(Config.zerogpu_proc_self_cgroup_path).read_text() for line in cgroup_content.strip().split('\n'): contents = line.split(':devices:') if len(contents) == 2: return contents[1] except Exception as e: print(f"Could not determine cgroup device path: {e}", file=sys.stderr) return "" class SimpleQueue(ThreadQueue[T]): def put(self, obj: T): try: super().put(obj) except Exception as e: print(f"Error in SimpleQueue.put: {e}", file=sys.stderr) def close(self): try: pass except Exception as e: print(f"Error closing SimpleQueue: {e}", file=sys.stderr) def wlock_release(self): try: pass except (ValueError, Exception): pass def drop_params(fn: Callable[[], T]) -> Callable[..., T]: def drop(*args, **kwargs): return fn() return drop def gradio_request_var(): try: from gradio.context import LocalContext return LocalContext.request except ImportError: print("Could not import Gradio LocalContext. Ensure Gradio version is at least 3.46.", file=sys.stderr) return None def malloc_trim(): try: ctypes.CDLL("libc.so.6").malloc_trim(0) except (OSError, AttributeError) as e: print(f"malloc_trim not available on this system: {e}", file=sys.stderr) debug = partial(print, 'SPACES_ZERO_GPU_DEBUG') def jwt_payload(token: str) -> dict[str, Any]: try: _, payload, _ = token.split('.') return json.loads(base64.urlsafe_b64decode(f'{payload}==')) except Exception as e: print(f"Error decoding JWT payload: {e}", file=sys.stderr) return {} if torch: @wraps(torch.empty_like) def empty_like_raw_alloc(tensor: torch.Tensor, **kwargs) -> torch.Tensor: empty = torch.empty_like(tensor, **{**kwargs, 'requires_grad': False}) if (nbytes := empty.untyped_storage().nbytes()) > 0: try: buffer = mmap.mmap(-1, nbytes, prot=mmap.PROT_READ | mmap.PROT_WRITE) buffer_torch = torch.frombuffer(buffer, dtype=torch.uint8) empty.set_(buffer_torch.untyped_storage(), 0, empty.shape, empty.stride()) except Exception as e: print(f"Failed to create mmap buffer for tensor: {e}", file=sys.stderr) empty.requires_grad_(kwargs.get('requires_grad', False)) return empty Params = Tuple[Tuple[object, ...], Dict[str, Any]] Res = TypeVar('Res') Param = ParamSpec('Param') class EmptyKwargs(TypedDict): pass @dataclass class OkResult(Generic[Res]): value: Res @dataclass class ExceptionResult: traceback: str error_cls: str @dataclass class AbortedResult: pass @dataclass class EndResult: pass @dataclass class GradioQueueEvent: method_name: str args: tuple[Any, ...] kwargs: dict[str, Any] RegularResQueueResult: TypeAlias = Union["OkResult[Res]", "ExceptionResult", "GradioQueueEvent"] GeneratorResQueueResult: TypeAlias = Union["OkResult[Res]", "ExceptionResult", "EndResult", "GradioQueueEvent"] YieldQueueResult: TypeAlias = Union["OkResult[Res]", "ExceptionResult", "EndResult", "AbortedResult"] Duration: TypeAlias = Union[int, timedelta] DynamicDuration: TypeAlias = Union[Duration, Callable[Param, Duration], None] if torch: class AliasId(NamedTuple): data_ptr: int dtype: torch.dtype shape: tuple[int, ...] stride: tuple[int, ...] 
@classmethod def from_tensor(cls, tensor: torch.Tensor): return cls( tensor.data_ptr(), tensor.dtype, tensor.shape, tensor.stride(), ) AllowToken = str NvidiaIndex = int NvidiaUUID = str CGroupPath = str TaskId = int GPUSize = Literal['medium', 'large'] AuthLevel = Literal['regular', 'pro'] QueuingReason = Literal['node', 'concurrency'] AUTHENTICATED_HEADER = 'X-Authenticated' QUEUING_REASON_HEADER = 'X-Queuing-Reason' class ScheduleResponse(BaseModel): idle: bool nvidiaIndex: int nvidiaUUID: str allowToken: str class ScheduleMetadata(BaseModel): auth: Optional[AuthLevel] = None queuing_reason: Optional[QueuingReason] = None class QuotaInfos(BaseModel): left: int wait: timedelta class QueueEvent(BaseModel): event: Literal['ping', 'failed', 'succeeded'] data: Optional[ScheduleResponse] = None def sse_parse(text: str): event, *data = text.strip().splitlines() assert event.startswith('event:') event = event[6:].strip() if event in ('ping', 'failed'): return QueueEvent(event=event) assert event == 'succeeded' (data,) = data assert data.startswith('data:') data = data[5:].strip() return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data)) def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]: for text in res.iter_text(): if len(text) == 0: break try: yield sse_parse(text) except GeneratorExit: res.close() break except Exception as e: print(f"Error parsing SSE event: {e}", file=sys.stderr) continue class APIClient: def __init__(self, client: httpx.Client): self.client = client def startup_report(self, cgroup_path: str, gpu_size: GPUSize) -> httpx.codes: try: res = self.client.post('/startup-report', params={'cgroupPath': cgroup_path, 'gpuSize': gpu_size}) return httpx.codes(res.status_code) except Exception as e: print(f"Failed to send startup report: {e}", file=sys.stderr) return httpx.codes.INTERNAL_SERVER_ERROR def schedule(self, cgroup_path: str, task_id: int = 0, token: str | None = None, token_version: int = 1, duration_seconds: int = 0, enable_queue: bool = True): try: params: dict[str, str | int | bool] = {'cgroupPath': cgroup_path, 'taskId': task_id, 'enableQueue': enable_queue, 'tokenVersion': token_version, 'durationSeconds': duration_seconds} if token is not None: params['token'] = token req = self.client.build_request(method='POST', url='/schedule', params=params) res = self.client.send(req, stream=True) status = httpx.codes(res.status_code) auth: AuthLevel | None = res.headers.get(AUTHENTICATED_HEADER) queuing_reason: QueuingReason | None = res.headers.get(QUEUING_REASON_HEADER) metadata = ScheduleMetadata(auth=auth, queuing_reason=queuing_reason) if status is not httpx.codes.OK and status is not httpx.codes.TOO_MANY_REQUESTS: res.close() return status, metadata if "text/event-stream" in res.headers.get('content-type', ''): return sse_stream(res), metadata res.read() if status is httpx.codes.TOO_MANY_REQUESTS: return QuotaInfos(**res.json()), metadata if status is httpx.codes.OK: return ScheduleResponse(**res.json()), metadata assert_never(status) except Exception as e: print(f"Error in APIClient.schedule: {e}", file=sys.stderr) return httpx.codes.INTERNAL_SERVER_ERROR, ScheduleMetadata() def allow(self, allow_token: str, pid: int): try: res = self.client.post('/allow', params={'allowToken': allow_token, 'pid': pid}) return httpx.codes(res.status_code) except Exception as e: print(f"Error in APIClient.allow: {e}", file=sys.stderr) return httpx.codes.INTERNAL_SERVER_ERROR def release(self, allow_token: str, fail: bool = False) -> httpx.codes: try: res = 
self.client.post('/release', params={'allowToken': allow_token, 'fail': fail}) return httpx.codes(res.status_code) except Exception as e: print(f"Error in APIClient.release: {e}", file=sys.stderr) return httpx.codes.INTERNAL_SERVER_ERROR def get_queue_size(self) -> float: try: res = self.client.get('/queue-size') assert res.status_code == 200, res.status_code return res.json() except Exception as e: print(f"Error in APIClient.get_queue_size: {e}", file=sys.stderr) return 0.0 def remove_tqdm_multiprocessing_lock(): if _tqdm is None: return try: tqdm_lock = _tqdm.get_lock() if hasattr(tqdm_lock, 'locks'): pass except Exception as e: print(f"Error while trying to remove tqdm multiprocessing lock: {e}", file=sys.stderr) tqdm = _tqdm try: Success = gr.Success except AttributeError: Success = gr.Info Level: TypeAlias = "Literal['success', 'info', 'warning']" def modal(level: Level): if level == 'info': return gr.Info if level == 'success': return Success if level == 'warning': return gr.Warning return gr.Info class GradioPartialContext(NamedTuple): event_id: str | None in_event_listener: bool progress: Progress | None @staticmethod def get(): TrackedIterable.__reduce__ = tracked_iterable__reduce__ return GradioPartialContext( event_id=LocalContext.event_id.get(None), in_event_listener=LocalContext.in_event_listener.get(False), progress=LocalContext.progress.get(None), ) @staticmethod def apply(context: 'GradioPartialContext'): LocalContext.event_id.set(context.event_id) LocalContext.in_event_listener.set(context.in_event_listener) LocalContext.progress.set(context.progress) def get_queue_instance(): blocks = LocalContext.blocks.get(None) if blocks is None: return None return getattr(blocks, '_queue', None) def get_event(): queue = get_queue_instance() event_id = LocalContext.event_id.get(None) if queue is None or event_id is None: return None for job in getattr(queue, 'active_jobs', []): if job is None: continue for event in job: if getattr(event, '_id', None) == event_id: return event return None def get_server_port() -> int | None: from_request_context = True if (blocks := LocalContext.blocks.get(None)) is None: from_request_context = False if (blocks := Context.root_block) is None: return None if (server := getattr(blocks, "server", None)) is None: if from_request_context: warnings.warn("Gradio: No blocks.server inside a request") return -1 server_config = getattr(server, 'config', None) if isinstance(server_config, dict): return server_config.get('port') elif isinstance(server_config, Settings): warnings.warn("ZeroGPU: Gradio server.config appears to be the global ZeroGPU Config object. Cannot determine Gradio port from this object.") return None elif hasattr(server_config, 'port'): return server_config.port warnings.warn(f"ZeroGPU: Unexpected type for server.config ({type(server_config)}). 
Cannot determine Gradio port.") return None def try_process_queue_event(method_name: str, *args, **kwargs): queue = get_queue_instance() if queue is None: warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance") return method = getattr(queue, method_name, None) if callable(method): try: method(*args, **kwargs) except Exception as e: print(f"Error processing Gradio queue event '{method_name}': {e}", file=sys.stderr) QUEUE_RPC_METHODS = ["set_progress", "log_message"] def patch_gradio_queue(res_queue: Union[SimpleQueue[RegularResQueueResult | None], SimpleQueue[GeneratorResQueueResult | None]]): def rpc_method(method_name: str): def method(*args, **kwargs): if args and isinstance(args[0], Queue): args = args[1:] res_queue.put(GradioQueueEvent(method_name, args, kwargs)) return method for method_name in QUEUE_RPC_METHODS: if (method := getattr(Queue, method_name, None)) is None: warnings.warn(f"ZeroGPU: Gradio Queue has no {method_name} attribute") continue if not callable(method): warnings.warn(f"ZeroGPU: Gradio Queue {method_name} is not callable") continue setattr(Queue, method_name, rpc_method(method_name)) TrackedIterable.__reduce__ = tracked_iterable__reduce__ def tracked_iterable__reduce__(self): try: res: tuple = super(TrackedIterable, self).__reduce__() cls, base, state, *_ = res return cls, base, {**state, **{'iterable': None, '_tqdm': None}} except Exception: return object, (), {} def supports_auth(): try: return version.parse(gr.__version__) >= version.Version('4.27.0') except Exception: return False Param_one_launch = ParamSpec('Param_one_launch') def one_launch(task: Callable[Param_one_launch, None], *task_args: Param_one_launch.args, **task_kwargs: Param_one_launch.kwargs): _launch = gr.Blocks.launch @wraps(gr.Blocks.launch) def launch(*args, **kwargs): task(*task_args, **task_kwargs) gr.Blocks.launch = _launch return gr.Blocks.launch(*args, **kwargs) gr.Blocks.launch = launch class HTMLError(gr.Error): def __str__(self): return str(self.message) def error(title: str, message: str, html: bool = False): print(f"ERROR: {title} - {message}", file=sys.stderr) error_cls = HTMLError if html else gr.Error params = inspect.signature(gr.Error).parameters kwargs: dict[str, Any] = {} if 'title' in params: kwargs['title'] = title if 'print_exception' in params: kwargs['print_exception'] = False try: pass except Exception: pass def info(title: str, message: str, level: Level = 'info'): print(f"INFO: {title} - {message}") info_cls = modal(level) params = inspect.signature(gr.Info).parameters kwargs: dict[str, Any] = {} if 'title' in params: kwargs['title'] = title try: info_cls(message, **kwargs) except Exception: pass TOKEN_HEADER = 'X-IP-Token' UNUSED_MESSAGE = "GPU device not used" NO_GPU_MESSAGE_REGULAR = "No GPU was available" NO_GPU_MESSAGE_INQUEUE = "No GPU was available after 60 seconds" EXAMPLES_RETRY_MESSAGE = "Try re-running outside of examples if it happened after clicking one" SIGNUP_ON_HF_TXT = "Create a free account" SIGNUP_ON_HF_URL = "https://huggingface.co/join" SUBSCRIBE_TO_PRO_TXT = "Subscribe to Pro" SUBSCRIBE_TO_PRO_URL = "https://huggingface.co/settings/billing/subscription" def api_client(): assert Config.zero_device_api_url is not None httpx_client = httpx.Client(base_url=Config.zero_device_api_url, timeout=60, verify=False) return APIClient(httpx_client) def startup_report_client(cgroup_path: str, gpu_size: GPUSize): retries, max_retries = 0, 2 client = api_client() status = None while retries <= max_retries: status = client.startup_report(cgroup_path, 
gpu_size) if status is not httpx.codes.NOT_FOUND: break time.sleep(1) retries += 1 if status is not httpx.codes.OK: print(f"Error while initializing ZeroGPU: status {status}", file=sys.stderr) def html_string(html_contents: str, text_contents: str): class HTMLString(str): def __str__(self): return text_contents return HTMLString(html_contents) def _toast_action(auth: AuthLevel | None, supports_html: bool, pro_message: str, unlogged_desc: str, logged_desc: str, ending: str) -> tuple[str, str]: if not supports_auth() or auth == 'pro': return pro_message, pro_message link = SIGNUP_ON_HF_URL if auth is None else SUBSCRIBE_TO_PRO_URL text = SIGNUP_ON_HF_TXT if auth is None else SUBSCRIBE_TO_PRO_TXT desc = unlogged_desc if auth is None else logged_desc desc += f" {ending}." style = ";".join(["white-space: nowrap", "text-underline-offset: 2px", "color: var(--body-text-color)"]) html = f'{text} {desc}' markdown = f'[{text}]({link}) {desc}' return html, markdown def schedule(task_id: int, request: gr.Request | None = None, duration: timedelta = timedelta(0), _first_attempt: bool = True) -> Optional[ScheduleResponse]: try: gradio_version = version.parse(gr.__version__) if gradio_version.major < 4: print("ZeroGPU is only compatible with Gradio 4+", file=sys.stderr) return None except Exception: print("Could not parse Gradio version.", file=sys.stderr) return None GRADIO_HTML_TOASTS = gradio_version >= version.Version('4.39') GRADIO_HANDSHAKE = gradio_version >= version.Version('5.16.1') token, payload = _get_token_and_payload(request) if token is not None and (token_error := payload.get('error')): info("ZeroGPU client warning", f"Falling back to IP-based quotas ({token_error})", level='warning') duration_seconds = duration.seconds res, meta = api_client().schedule(cgroup_path=self_cgroup_device_path(), task_id=task_id, token=token, token_version=2 if GRADIO_HANDSHAKE else 1, duration_seconds=duration_seconds) if isinstance(res, ScheduleResponse): print("This Space is currently using 0 minutes, 0 seconds of the huggingface.co plan.") return res if isinstance(res, QuotaInfos): requested = duration.seconds message = "" if res.wait < timedelta(0): message = f"The requested GPU duration ({requested}s) is larger than the maximum allowed" elif token is None: message = f"Space app has reached its GPU limit. {EXAMPLES_RETRY_MESSAGE}" else: if payload.get('user') is None and res.wait == timedelta(0): message = "You have exceeded your runs limit." else: gpu = "Pro GPU" if meta.auth == 'pro' else ("free GPU" if meta.auth == 'regular' else "GPU") message = f"You have exceeded your {gpu} quota ({requested}s requested vs. {res.left}s left). 
Try again in {res.wait}" print(f"ZeroGPU quota exceeded: {message}", file=sys.stderr) return None if not isinstance(res, httpx.codes): if meta.queuing_reason in ('node', None): info("ZeroGPU queue", "Waiting for a GPU to become available") elif meta.queuing_reason == 'concurrency': info("ZeroGPU queue", "Waiting for a GPU slot on this Space") else: assert_never(meta.queuing_reason) connection_event = get_event() if connection_event is None and request is not None: warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance") while True: try: event = next(res) except StopIteration: print("Unexpected end of stream in schedule", file=sys.stderr) return None except httpx.RemoteProtocolError: if not _first_attempt: print("Error while re-trying after queue disconnect", file=sys.stderr) return None return schedule(task_id, request, duration, _first_attempt=False) except Exception as e: print(f"Error processing schedule event stream: {e}", file=sys.stderr) return None if event.event == 'ping': if connection_event is not None and not connection_event.alive: res.close() print("Connection closed by visitor while queueing", file=sys.stderr) return None continue if event.event == 'failed': if token is None: message = f"{NO_GPU_MESSAGE_INQUEUE}. {EXAMPLES_RETRY_MESSAGE}" else: _, details_markdown = _toast_action(auth=meta.auth, supports_html=GRADIO_HTML_TOASTS, pro_message="Retry later", unlogged_desc="to get a higher", logged_desc="to get the highest", ending="priority in ZeroGPU queues") message = f"{NO_GPU_MESSAGE_INQUEUE} {details_markdown}" print(f"ZeroGPU queue timeout: {message}", file=sys.stderr) return None if event.event == 'succeeded': assert event.data is not None if connection_event is not None and not connection_event.alive: release(event.data.allowToken) print("Connection closed by visitor on queue success", file=sys.stderr) return None info("ZeroGPU queue", "Successfully acquired a GPU", level='success') print("This Space is currently using 0 minutes, 0 seconds of the huggingface.co plan.") return event.data if res is httpx.codes.SERVICE_UNAVAILABLE: print(f"ZeroGPU client error: {NO_GPU_MESSAGE_REGULAR}", file=sys.stderr) return None if res is httpx.codes.UNAUTHORIZED: print("ZeroGPU client error: Expired ZeroGPU proxy token", file=sys.stderr) return None reason = httpx.codes.get_reason_phrase(res) if isinstance(res, int) else "Unknown" print(f"ZeroGPU API /schedule error: {res} ({reason})", file=sys.stderr) return None def allow(allow_token: str) -> None: process_id = os.getpid() if process_id == 1: print("CRITICAL: Allowing PID 1 on ZeroGPU will end up killing your Space. 
Aborting.", file=sys.stderr) return if api_client().allow(allow_token=allow_token, pid=process_id) is not httpx.codes.OK: print(f"API call to /allow failed for token {allow_token}", file=sys.stderr) def release(allow_token: str, *, fail: bool = False, allow_404: bool = True) -> None: res = api_client().release(allow_token=allow_token, fail=fail) if res is httpx.codes.NO_CONTENT: try: info("ZeroGPU client warning", UNUSED_MESSAGE, level='warning') except AttributeError: pass warnings.warn(UNUSED_MESSAGE, RuntimeWarning) return if res is httpx.codes.NOT_FOUND: if not allow_404: warnings.warn("ZeroGPU API /release warning: 404 Not Found") return if httpx.codes.is_success(res): return reason = httpx.codes.get_reason_phrase(res) if isinstance(res, int) else "Unknown" print(f"ZeroGPU API /release error: {res} ({reason})", file=sys.stderr) def _get_token(request: gr.Request | None) -> str | None: if request is None: return None headers = getattr(request, 'headers', None) if headers is None or not hasattr(headers, '__dict__'): print("ZeroGPU client error: Internal Gradio error (headers not found)", file=sys.stderr) return None if not hasattr(headers, 'get'): headers = headers.__dict__ return headers.get(TOKEN_HEADER.lower()) def _get_token_and_payload(request: gr.Request | None) -> tuple[str | None, dict[str, Any]]: token = _get_token(request) if token is None: return None, {} payload = jwt_payload(token) return token, payload def compute_base_free_memory(total_memory: int) -> int: pytorch_base_memory = 309002240 return total_memory - pytorch_base_memory - Config.zerogpu_cuda_reserved_memory CUDA_DEVICE_NAME_STATIC = Config.zerogpu_cuda_device_name CUDA_TOTAL_MEMORY_STATIC = Config.zerogpu_cuda_total_memory CUDA_MEM_GET_INFO_STATIC = (compute_base_free_memory(CUDA_TOTAL_MEMORY_STATIC), CUDA_TOTAL_MEMORY_STATIC) CUDA_DEVICE_CAPABILITY_STATIC = (Config.zerogpu_cuda_capability_major, Config.zerogpu_cuda_capability_minor) CUDA_DEVICE_PROPERTIES_STATIC = SimpleNamespace(name=CUDA_DEVICE_NAME_STATIC, major=CUDA_DEVICE_CAPABILITY_STATIC[0], minor=CUDA_DEVICE_CAPABILITY_STATIC[1], total_memory=CUDA_TOTAL_MEMORY_STATIC, multi_processor_count=Config.zerogpu_cuda_multi_processor_count) if torch: class MockCudaRuntime: def setDevice(self, device): pass def getDevice(self): return 0 def deviceSynchronize(self): pass def deviceGetStreamPriorityRange(self): return 0, 0 cudart = MockCudaRuntime() if torch and torch.version.cuda.startswith("12."): CUDA_MEMORY_STATS_AS_NESTED_DICT_STATIC = {"num_alloc_retries": 0, "num_ooms": 0, "max_split_size": -1, "num_sync_all_streams": 0, "num_device_alloc": 0, "num_device_free": 0, "allocation": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "segment": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "active": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "inactive_split": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "allocated_bytes": {"all": {"current": 0, "peak": 
0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "reserved_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "active_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "inactive_split_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "requested_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "oversize_allocations": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "oversize_segments": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}} else: CUDA_MEMORY_STATS_AS_NESTED_DICT_STATIC = {"num_alloc_retries": 0, "num_ooms": 0, "max_split_size": -1, "allocation": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "segment": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "active": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "inactive_split": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "allocated_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "reserved_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "active_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "inactive_split_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "requested_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "oversize_allocations": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "oversize_segments": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}} def cudaMemGetInfo(device: int, /): return CUDA_MEM_GET_INFO_STATIC PAGE_SIZE = 4096 try: TOTAL_MEMORY = 
os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES') except (ValueError, AttributeError): TOTAL_MEMORY = 8 * (1024**3) VM_MAX_SIZE = min(2**38, TOTAL_MEMORY // 2) BUFFER_SIZE = 128 * 2**20 BUFFER_COUNT = 2 if torch: TensorWithSizes: TypeAlias = 'tuple[torch.Tensor, int, int]' if torch: @dataclass class ZeroGPUTensorPack: base_dir: str batches: list[list[TensorWithSizes]] big_tensors: list[list[TensorWithSizes]] fakes: dict[torch.Tensor, list[torch.Tensor]] total_size: int def path(self): return f'{self.base_dir}/{id(self)}' def __del__(self): try: os.remove(self.path()) except (FileNotFoundError, TypeError, AttributeError): pass def write_packing(fd: int, tensor: torch.Tensor): try: clone = torch.empty_like(tensor) size = clone.untyped_storage().size() buffer = torch.UntypedStorage(VM_MAX_SIZE) buffer_ptr = buffer.data_ptr() offset = -buffer_ptr % PAGE_SIZE padding = -size % PAGE_SIZE clone.set_(buffer[offset:offset + size], 0, clone.shape, clone.stride()) clone.copy_(tensor) mv = memoryview((ctypes.c_char * (size + padding)).from_address(buffer_ptr + offset)) written_bytes = 0 while written_bytes < size: written_bytes += os.write(fd, mv[written_bytes:]) except Exception as e: print(f"Error during tensor write packing: {e}", file=sys.stderr) def pack_tensors(tensors: set[torch.Tensor], fakes: dict[torch.Tensor, list[torch.Tensor]], offload_dir: str, callback: Callable[[int], None] | None = None): callback = (lambda b: None) if callback is None else callback batches: list[list[TensorWithSizes]] = [] big_tensors: list[list[TensorWithSizes]] = [] tensors_with_sizes: list[tuple[torch.Tensor, int, int]] = [] for tensor in tensors: size = tensor.numel() * tensor.element_size() aligned_size = size + (-size % PAGE_SIZE) tensors_with_sizes.append((tensor, size, aligned_size)) current_batch, current_size = [], 0 for (tensor, size, aligned_size) in sorted(tensors_with_sizes, key=lambda item: item[2]): if aligned_size > BUFFER_SIZE: big_tensors.append((tensor, size, aligned_size)) continue current_size += aligned_size if current_size > BUFFER_SIZE: batches.append(current_batch) current_batch, current_size = [(tensor, size, aligned_size)], aligned_size else: current_batch.append((tensor, size, aligned_size)) if current_batch: batches.append(current_batch) get_meta = {tensor: empty_like_raw_alloc(tensor) for tensor in tensors} batches_meta = [[(get_meta[tensor], size, asize) for tensor, size, asize in batch] for batch in batches] big_tensors_meta = [(get_meta[tensor], size, asize) for tensor, size, asize in big_tensors] fakes_meta = {get_meta[tensor]: fake_list for tensor, fake_list in fakes.items()} pack = ZeroGPUTensorPack(base_dir=offload_dir, batches=batches_meta, big_tensors=big_tensors_meta, fakes=fakes_meta, total_size=sum([size for _, size, _ in tensors_with_sizes])) fd = -1 try: fd = os.open(pack.path(), os.O_CREAT | os.O_WRONLY | os.O_DIRECT) total_asize = sum([aligned_size for batch in batches for *_, aligned_size in batch]) total_asize += sum([aligned_size for *_, aligned_size in big_tensors]) if total_asize > 0: os.posix_fallocate(fd, 0, total_asize) for batch in batches: for tensor, size, _ in batch: write_packing(fd, tensor) callback(size) for tensor, size, _ in big_tensors: write_packing(fd, tensor) callback(size) return pack except Exception as e: print(f"Failed to pack tensors to disk: {e}", file=sys.stderr) return pack finally: if fd != -1: os.close(fd) def pack_to_cuda(pack: ZeroGPUTensorPack, callback: Callable[[int], None] | None = None): callback = (lambda b: None) if callback 
is None else callback free_buffers: ThreadQueue[torch.Tensor] = ThreadQueue() read_buffers: ThreadQueue[torch.Tensor] = ThreadQueue() for _ in range(BUFFER_COUNT): free_buffers.put(torch.ByteTensor(BUFFER_SIZE).pin_memory()) def read(fd: int, buffer: torch.Tensor, size: int): mv = memoryview((ctypes.c_char * size).from_address(buffer.data_ptr())) read_bytes = 0 while read_bytes < size: read_bytes += os.readv(fd, [mv[read_bytes:]]) def disk_to_pin(fd: int): for batch in pack.batches: buffer = free_buffers.get() batch_size = sum([aligned_size for *_, aligned_size in batch]) read(fd, buffer, batch_size) read_buffers.put(buffer) for *_, aligned_size in pack.big_tensors: read_bytes = 0 while read_bytes < aligned_size: buffer = free_buffers.get() read_size = min(BUFFER_SIZE, aligned_size - read_bytes) read(fd, buffer, read_size) read_buffers.put(buffer) read_bytes += read_size def pin_to_cuda(): total_duration_in_callback = 0 for batch in pack.batches: buffer = read_buffers.get() offset = 0 cuda_storages = [] for tensor, size, aligned_size in batch: cuda_storages.append(buffer[offset:offset + size].cuda(non_blocking=True)) offset += aligned_size torch.cuda.synchronize() free_buffers.put(buffer) batch_total_size = 0 for (tensor, size, _), cuda_storage in zip(batch, cuda_storages): cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda') cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride()) for fake in pack.fakes[tensor]: fake.data = cuda_tensor batch_total_size += size t0 = time.perf_counter() callback(batch_total_size) total_duration_in_callback += time.perf_counter() - t0 for tensor, size, _ in pack.big_tensors: cuda_storage = torch.empty(size, dtype=torch.uint8, device='cuda') offset = 0 while offset < size: buffer = read_buffers.get() read_size = min(BUFFER_SIZE, size - offset) cuda_storage[offset:offset + read_size] = buffer[:read_size] offset += read_size torch.cuda.synchronize() free_buffers.put(buffer) t0 = time.perf_counter() callback(read_size) total_duration_in_callback += time.perf_counter() - t0 cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda') cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride()) for fake in pack.fakes[tensor]: fake.data = cuda_tensor debug(f"{total_duration_in_callback=}") fd = -1 try: with ThreadPoolExecutor(2) as e: fd = os.open(pack.path(), os.O_RDONLY | os.O_DIRECT) futures = [e.submit(copy_context().run, disk_to_pin, fd), e.submit(copy_context().run, pin_to_cuda)] for future in as_completed(futures): future.result() except Exception as e: print(f"Error during pack_to_cuda: {e}", file=sys.stderr) finally: if fd != -1: os.close(fd) @contextmanager def cuda_unavailable(torch_module: ModuleType): _is_available = torch_module.cuda.is_available torch_module.cuda.is_available = lambda: False yield torch_module.cuda.is_available = _is_available def maybe_import_bitsandbytes(): try: if torch is None: return None bnb_version = version.parse(metadata.version('bitsandbytes')) if bnb_version < version.parse('0.40.0'): print(f"Warning: ZeroGPU requires bitsandbytes >= 0.40.0 (installed: {bnb_version})", file=sys.stderr) return None ctx_factory = (lambda: cuda_unavailable(torch)) if bnb_version < version.parse('0.43.1') else nullcontext with (ctx := ctx_factory()): importlib.import_module('bitsandbytes') if not isinstance(ctx, nullcontext): print("↑ Those bitsandbytes warnings are expected on ZeroGPU ↑", file=sys.stderr) return ctx_factory except (ImportError, 
metadata.PackageNotFoundError): return None except Exception as e: print(f"Unexpected error during bitsandbytes check: {e}", file=sys.stderr) return None bnb_import_context = maybe_import_bitsandbytes() if bnb_import_context and torch: from torch.utils.weak import WeakTensorKeyDictionary with (import_ctx := bnb_import_context()): CUDASetup = None if not isinstance(import_ctx, nullcontext): from bitsandbytes.cuda_setup.main import CUDASetup from bitsandbytes import cextension, functional from bitsandbytes.nn import Int8Params, Params4bit _param_to_8bit = Int8Params.to _param_cuda_8bit = Int8Params.cuda _param_to_4bit = Params4bit.to _param_cuda_4bit = Params4bit.cuda TensorToArgs_bnb = Tuple[torch.device, torch.dtype, bool, torch.memory_format] to_ops_8bit: dict[Int8Params, TensorToArgs_bnb | None] = WeakTensorKeyDictionary() to_ops_4bit: dict[Params4bit, TensorToArgs_bnb | None] = WeakTensorKeyDictionary() def _to_op_register_8bit(self: Int8Params, *args, **kwargs): parsed = torch._C._nn._parse_to(*args, **kwargs) device, *_ = parsed if not isinstance(device, torch.device) or device.type != 'cuda': return _param_to_8bit(self, *args, **kwargs) to_ops_8bit[self] = parsed return self def _to_op_register_4bit(self: Params4bit, *args, **kwargs): parsed = torch._C._nn._parse_to(*args, **kwargs) device, *_ = parsed if not isinstance(device, torch.device) or device.type != 'cuda': return _param_to_4bit(self, *args, **kwargs) to_ops_4bit[self] = parsed return self def _cuda_op_arg_check_bnb(device: Union[torch.device, int, str, None]) -> bool: if device is None or isinstance(device, int): return True if isinstance(device, str): device = torch.device(device) return device.type == 'cuda' def _cuda_op_register_8bit(self: Int8Params, device: Union[torch.device, int, str, None] = None, **kwargs): if not _cuda_op_arg_check_bnb(device): return _param_cuda_8bit(self, device, **kwargs) to_ops_8bit[self] = None return self def _cuda_op_register_4bit(self: Params4bit, device: Union[torch.device, int, str, None] = None, **kwargs): if not _cuda_op_arg_check_bnb(device): return _param_cuda_4bit(self, device, **kwargs) to_ops_4bit[self] = None return self def _patch_bnb(): Int8Params.to = _to_op_register_8bit Int8Params.cuda = _cuda_op_register_8bit Params4bit.to = _to_op_register_4bit Params4bit.cuda = _cuda_op_register_4bit def _unpatch_bnb(): Int8Params.to = _param_to_8bit Int8Params.cuda = _param_cuda_8bit Params4bit.to = _param_to_4bit Params4bit.cuda = _param_cuda_4bit def _move_bnb(): if CUDASetup is not None: CUDASetup._instance = None importlib.reload(cextension) functional.lib = cextension.lib for tensor, parsed_args in to_ops_8bit.items(): dtype, memory_format = (parsed_args[1], parsed_args[3]) if parsed_args else (None, None) tensor.data = _param_to_8bit(tensor, device='cuda', dtype=dtype, memory_format=memory_format) for tensor, parsed_args in to_ops_4bit.items(): dtype, memory_format = (parsed_args[1], parsed_args[3]) if parsed_args else (None, None) tensor.data = _param_to_4bit(tensor, device='cuda', dtype=dtype, memory_format=memory_format) else: def _patch_bnb(): pass def _unpatch_bnb(): pass def _move_bnb(): pass patch_bnb = _patch_bnb unpatch_bnb = _unpatch_bnb move_bnb = _move_bnb class _BitsAndBytesManager: def patch(self): return patch_bnb() def unpatch(self): return unpatch_bnb() def move(self): return move_bnb() if torch: PINNED_MEMORY_RATIO_LIMIT = 0.1 OPS_INPUTS_CHECK_NO_RETURN = (torch.Tensor.equal,) OPS_INPUT_CHECK_SELF_RETURN = (torch.Tensor.set_, torch.ops.aten.set_.source_Tensor) 
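# Illustrative sketch of the TorchFunctionMode hook that ZeroGPUFunctionMode (below) builds on:
# every torch-level call is routed through __torch_function__, so CUDA device arguments can be
# observed (or rewritten) before any real allocation happens. The mode and names here are
# examples only, not part of this module.
if torch is not None:
    class _DeviceLoggingMode(torch.overrides.TorchFunctionMode):
        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = {} if kwargs is None else kwargs
            device = kwargs.get('device')
            if device is not None and torch.device(device).type == 'cuda':
                # A real implementation would swap the device for 'cpu' and remember the tensor,
                # which is what ZeroGPUFunctionMode does.
                print(f"CUDA requested in {torch.overrides.resolve_name(func)}")
            return func(*args, **kwargs)

    # Example usage (commented out to avoid side effects at import time):
    # with _DeviceLoggingMode():
    #     torch.ones(2, device='cuda')  # logged; would fail here without an attached GPU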
OFFLOADED_ERROR_MESSAGE = "Cannot apply function {} on disk-offloaded Tensor {}" _tensor_make_subclass = torch.Tensor._make_subclass _asarray = torch.asarray _device = torch.device _cuda_init_v2 = torch._C._cuda_init _cuda_exchange_device = torch.cuda._exchange_device _cuda_available_v2 = torch.cuda.is_available _cuda_device_count_v2 = torch.cuda.device_count _cuda_current_device_v2 = torch.cuda.current_device _cuda_synchronize = torch.cuda.synchronize _cuda_get_device_capability_v2 = torch.cuda.get_device_capability _cuda_get_device_properties_v2 = torch.cuda.get_device_properties _cuda_get_device_name_v2 = torch.cuda.get_device_name _cuda_memory_stats_as_nested_dict = torch.cuda.memory.memory_stats_as_nested_dict _cuda_cudart = torch.cuda.cudart _cuda_maybe_exchange_device = getattr(torch.cuda, '_maybe_exchange_device', None) cuda_aliases: dict[torch.Tensor, torch.Tensor | None] = WeakTensorKeyDictionary() tensor_packs: list[ZeroGPUTensorPack] = [] class ZeroGPUTensor(torch.Tensor): pass def empty_fake(tensor: torch.Tensor): fake = empty_like_raw_alloc(tensor, requires_grad=tensor.requires_grad) if fake.__class__ != tensor.__class__: fake = _tensor_make_subclass(tensor.__class__, fake, require_grad=tensor.requires_grad) return fake def no_int_device(*args, **kwargs): if len(args) and isinstance(index := args[0], int): args = (f'cuda:{index}', *args[1:]) if isinstance(index := kwargs.get('device'), int): kwargs['device'] = f'cuda:{index}' return args, kwargs class ZeroGPUFunctionMode(torch.overrides.TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs: dict[str, Any] | None = None): kwargs = {} if kwargs is None else kwargs try: if func == torch._C._nn._parse_to: args, kwargs = no_int_device(*args, **kwargs) return func(*args, **kwargs) if func == torch.Tensor.cuda or func == torch.Tensor.cpu: memory_format = kwargs.get("memory_format") device_str = "cuda" if func == torch.Tensor.cuda else "cpu" to_kwargs = {"device": device_str} if memory_format is not None: to_kwargs["memory_format"] = memory_format return self.__torch_function__(torch.Tensor.to, types, (args[0],), to_kwargs) if func == torch.Tensor.to and len(args) > 1: parse_to_args, parse_to_kwargs = no_int_device(*args[1:], **kwargs) device, dtype, _, memory_format = torch._C._nn._parse_to(*parse_to_args, **parse_to_kwargs) return self.__torch_function__(torch.Tensor.to, types, (args[0],), {'device': device, 'dtype': dtype, 'memory_format': memory_format}) if func == torch.Tensor.data.__set__: self_tensor, target = args if target in cuda_aliases: if (target_original := cuda_aliases[target]) is None: print(OFFLOADED_ERROR_MESSAGE.format(torch.overrides.resolve_name(func), target), file=sys.stderr) return original = empty_fake(self_tensor) original.data = target_original cuda_aliases[self_tensor] = original elif self_tensor in cuda_aliases: del cuda_aliases[self_tensor] self_tensor.data = target return if func == torch.Tensor.device.__get__: tensor, = args if tensor in cuda_aliases: return torch.device('cuda', index=0) elif func == torch.Tensor.__repr__: tensor, = args if tensor in cuda_aliases: original = cuda_aliases[tensor] or tensor.to('meta') original_class = original.__class__ original.__class__ = ZeroGPUTensor try: return func(original, **kwargs) finally: original.__class__ = original_class elif func == torch.Tensor.untyped_storage: tensor, = args if tensor in cuda_aliases: if (original := cuda_aliases[tensor]) is None: print(OFFLOADED_ERROR_MESSAGE.format(torch.overrides.resolve_name(func), tensor), 
file=sys.stderr) return None res = func(original, **kwargs) res._zerogpu = True return res cuda: bool | None = None if (device := kwargs.get('device')) is not None: device = torch.device(device) cuda = device.type == 'cuda' if cuda: kwargs['device'] = torch.device('cpu') swapped, inputs_are_cuda = {}, set() def swap(t: torch.Tensor): nonlocal inputs_are_cuda if t not in cuda_aliases: inputs_are_cuda.add(False) return t original = cuda_aliases[t] if original is None: print(OFFLOADED_ERROR_MESSAGE.format(torch.overrides.resolve_name(func), t), file=sys.stderr) return t swapped[original] = t inputs_are_cuda.add(True) return original args_ = torch.utils._pytree.tree_map_only(torch.Tensor, swap, args) kwargs_ = torch.utils._pytree.tree_map_only(torch.Tensor, swap, kwargs) if inputs_are_cuda == {True} and cuda is not False: cuda = True if len(args) == 1 and torch.utils._python_dispatch.is_traceable_wrapper_subclass(wt := args[0]): if func in {torch.Tensor.detach, torch.ops.aten.alias.default, torch.ops.aten.clone.default}: with self: return torch.utils._python_dispatch.transform_subclass(wt, lambda _, t: func(t)) res = func(*args_, **kwargs_) for original, fake in swapped.items(): fake.data = empty_fake(original) if func in {torch.ops.aten.index.Tensor, torch.Tensor.__getitem__}: cuda = args[0] in cuda_aliases inputs_are_cuda = {cuda} if (isinstance(res, torch.Tensor) or func in OPS_INPUTS_CHECK_NO_RETURN) and not (func == torch.ops.aten.set_.source_Tensor and len(args_) == 3): st = args_[0] if len(args_) >= 1 and isinstance(args_[0], torch.Tensor) else None if (res is not st or func in OPS_INPUT_CHECK_SELF_RETURN) and inputs_are_cuda == {True, False}: print("RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 (ZeroGPU) and cpu!", file=sys.stderr) def register(t: torch.Tensor): if t in swapped and cuda is not False: return swapped[t] if cuda is not True: return t fake = empty_fake(t) cuda_aliases[fake] = t return fake return torch.utils._pytree.tree_map_only(torch.Tensor, register, res) except Exception as e: print(f"Error in ZeroGPUFunctionMode: {e}", file=sys.stderr) return func(*args, **kwargs) class DefaultDispatchMode(torch.utils._python_dispatch.TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs: dict[str, Any] | None = None): return func(*args, **(kwargs or {})) function_mode = ZeroGPUFunctionMode() dispatch_mode = DefaultDispatchMode() def _untyped_storage_new_register(*args, **kwargs): cuda = False if (device := kwargs.get('device')) is not None and device.type == 'cuda': cuda = True del kwargs['device'] storage = torch._C.StorageBase.__new__(*args, **kwargs) if cuda: storage._zerogpu = True return storage @property def _untyped_storage_device(self): if hasattr(self, '_zerogpu'): return torch.device('cuda', index=0) return torch._C.StorageBase.device.__get__(self) def _tensor_make_subclass_function_mode(*args, **kwargs): with torch._C.DisableTorchFunction(): return function_mode.__torch_function__(_tensor_make_subclass, (), args=args, kwargs=kwargs) def _asarray_function_mode(*args, **kwargs): with torch._C.DisableTorchFunction(): return function_mode.__torch_function__(_asarray, (), args=args, kwargs=kwargs) class _DeviceStringOnlyMeta(type): def __instancecheck__(cls, instance): return isinstance(instance, _device) class _DeviceStringOnly(metaclass=_DeviceStringOnlyMeta): def __new__(cls, *args, **kwargs): args, kwargs = no_int_device(*args, **kwargs) return _device(*args, **kwargs) def _cuda_init_raise_v2(): 
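        # In the main ZeroGPU process no CUDA device is attached yet, so low-level CUDA
        # initialization is turned into a no-op here; the real context is created later in
        # init_v2() inside the worker, once a device has been granted by the scheduler.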
        pass

    def _cuda_dummy_exchange_device(device):
        assert device in {-1, 0}
        return device

    def patch_v2():
        function_mode.__enter__()
        dispatch_mode.__enter__()
        torch.Tensor._make_subclass = _tensor_make_subclass_function_mode
        torch.UntypedStorage.__new__ = _untyped_storage_new_register
        torch.UntypedStorage.device = _untyped_storage_device
        torch.asarray = _asarray_function_mode
        torch.device = _DeviceStringOnly
        torch._C._cuda_init = _cuda_init_raise_v2
        torch.cuda._exchange_device = _cuda_dummy_exchange_device
        torch.cuda.is_available = lambda: True
        torch.cuda.device_count = lambda: 1
        torch.cuda.current_device = lambda: 0
        torch.cuda.synchronize = lambda *args: None
        torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY_STATIC
        torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES_STATIC
        torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME_STATIC
        torch.cuda.memory.memory_stats_as_nested_dict = lambda *args, **kwargs: CUDA_MEMORY_STATS_AS_NESTED_DICT_STATIC
        torch.cuda.cudart = lambda: cudart
        if _cuda_maybe_exchange_device is not None:
            # Patch the "maybe" variant with the same dummy used for _exchange_device.
            setattr(torch.cuda, '_maybe_exchange_device', _cuda_dummy_exchange_device)
        _BitsAndBytesManager().patch()

    def unpatch_v2():
        from contextlib import suppress
        with suppress(RuntimeError):
            dispatch_mode.__exit__(None, None, None)
            function_mode.__exit__(None, None, None)
        torch.Tensor._make_subclass = _tensor_make_subclass
        torch.UntypedStorage.__new__ = torch._C.StorageBase.__new__
        torch.UntypedStorage.device = torch._C.StorageBase.device
        torch.asarray = _asarray
        torch.device = _device
        torch._C._cuda_init = _cuda_init_v2
        torch.cuda._exchange_device = _cuda_exchange_device
        torch.cuda.is_available = _cuda_available_v2
        torch.cuda.device_count = _cuda_device_count_v2
        torch.cuda.current_device = _cuda_current_device_v2
        torch.cuda.synchronize = _cuda_synchronize
        torch.cuda.get_device_capability = _cuda_get_device_capability_v2
        torch.cuda.get_device_properties = _cuda_get_device_properties_v2
        torch.cuda.get_device_name = _cuda_get_device_name_v2
        torch.cuda.memory.memory_stats_as_nested_dict = _cuda_memory_stats_as_nested_dict
        torch.cuda.cudart = _cuda_cudart
        if _cuda_maybe_exchange_device is not None:
            # Restore the original _maybe_exchange_device saved at import time.
            setattr(torch.cuda, '_maybe_exchange_device', _cuda_maybe_exchange_device)
        _BitsAndBytesManager().unpatch()

    def _total_unpacked_size():
        tensors = [t for t in cuda_aliases.values() if t is not None]
        deduped = {AliasId.from_tensor(t): t for t in tensors}
        return sum([t.numel() * t.element_size() for t in deduped.values()])

    def _pack_v2_internal(offload_dir: str):
        originals, originals_dedup, fakes = set(), {}, defaultdict(list)
        for fake, original in cuda_aliases.items():
            if original is not None:
                original_id = AliasId.from_tensor(original)
                if original_id not in originals_dedup:
                    originals_dedup[original_id] = original
                    originals.add(original)
                fakes[originals_dedup[original_id]].append(fake)
        total_size = _total_unpacked_size()
        progress_context = (
            tqdm(total=total_size, unit='B', unit_scale=True, desc="ZeroGPU tensors packing")
            if tqdm is not None and total_size > 0 else nullcontext()
        )
        with progress_context as progress:
            update = progress.update if progress is not None else (lambda _: None)
            pack = pack_tensors(originals, fakes, offload_dir, callback=update)
        tensor_packs.append(pack)
        for fake_list in fakes.values():
            for fake in fake_list:
                cuda_aliases[fake] = None
        return total_size

    def pack_v2():
        total_size = _pack_v2_internal(Config.zerogpu_offload_dir)
        gc.collect()
        malloc_trim()
        return total_size

    def init_v2(nvidia_uuid: str):
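        # Setting CUDA_VISIBLE_DEVICES to the granted GPU (or MIG slice) UUID before the first
        # CUDA call makes that device show up as cuda:0 in this process; the dummy
        # `torch.Tensor([0]).cuda()` that follows just forces CUDA context creation on it.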
os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid torch.Tensor([0]).cuda() def size_v2(): return _total_unpacked_size() + sum([p.total_size for p in tensor_packs]) def _move_v2_internal(callback: Callable[[int], None] | None = None): cb = callback or (lambda _: None) pinned_limit, moved = _total_unpacked_size() * PINNED_MEMORY_RATIO_LIMIT, {} for fake, original in cuda_aliases.items(): if original is not None: original_id = AliasId.from_tensor(original) if original_id not in moved: use_pinned = original.numel() * original.element_size() < pinned_limit original_cuda = original.pin_memory().cuda(non_blocking=True) if use_pinned else original.cuda() moved[original_id] = original_cuda cb(fake.numel() * fake.element_size()) torch.cuda.synchronize() for fake, original in cuda_aliases.items(): if original is not None: fake.data = moved[AliasId.from_tensor(original)] for tensor_pack in tensor_packs: pack_to_cuda(tensor_pack, callback=cb) _BitsAndBytesManager().move() def move_v2(callback: Callable[[int], None] | None = None): cb = callback or (lambda _: None) with ThreadPoolExecutor(1) as e: e.submit(copy_context().run, _move_v2_internal, callback=cb).result() torch.cuda.synchronize() def is_in_bad_fork_v2(): return False CUDA_DEVICE_NAME_LEGACY, CUDA_TOTAL_MEMORY_LEGACY = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb', 42144366592 CUDA_MEM_GET_INFO_LEGACY = (41911451648, CUDA_TOTAL_MEMORY_LEGACY) CUDA_DEVICE_CAPABILITY_LEGACY = (8, 0) CUDA_DEVICE_PROPERTIES_LEGACY = SimpleNamespace(name=CUDA_DEVICE_NAME_LEGACY, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY_LEGACY, multi_processor_count=42) GENERIC_METHOD_NAMES = ['arange', 'as_tensor', 'asarray', 'bartlett_window', 'blackman_window', 'empty', 'empty_like', 'empty_strided', 'eye', 'full', 'full_like', 'hamming_window', 'hann_window', 'kaiser_window', 'linspace', 'logspace', 'ones', 'ones_like', 'rand', 'rand_like', 'randint', 'randint_like', 'randn', 'randn_like', 'randperm', 'range', 'sparse_bsc_tensor', 'sparse_bsr_tensor', 'sparse_compressed_tensor', 'sparse_coo_tensor', 'sparse_csc_tensor', 'sparse_csr_tensor', 'tensor', 'tril_indices', 'triu_indices', 'zeros', 'zeros_like'] TO_CUDA = (torch.device('cuda'), None, False, None) _tensor__deepcopy__, _tensor_to, _tensor_cuda, _tensor_cpu = torch.Tensor.__deepcopy__, torch.Tensor.to, torch.Tensor.cuda, torch.Tensor.cpu _torch_generics = {name: getattr(torch, name) for name in GENERIC_METHOD_NAMES} _cuda_init_legacy, _cuda_available_legacy, _cuda_device_count_legacy, _cuda_current_device_legacy = torch._C._cuda_init, torch.cuda.is_available, torch.cuda.device_count, torch.cuda.current_device _cuda_mem_get_info, _cuda_get_device_capability_legacy, _cuda_get_device_properties_legacy, _cuda_get_device_name_legacy = torch.cuda.mem_get_info, torch.cuda.get_device_capability, torch.cuda.get_device_properties, torch.cuda.get_device_name TensorToArgs_legacy = Tuple[Optional[torch.device], Optional[torch.dtype], bool, Optional[torch.memory_format]] to_ops: dict[torch.Tensor, TensorToArgs_legacy] = WeakTensorKeyDictionary() def _tensor_new_register(*args, **kwargs): new_tensor = torch._C._TensorBase.__new__(*args, **kwargs) if (base := getattr(new_tensor, '_base', None)) is not None and base in to_ops: to_ops[new_tensor] = to_ops[base] return new_tensor def _tensor_deepcopy_register(self: torch.Tensor, memo): new_tensor = _tensor__deepcopy__(self, memo) if isinstance(new_tensor, torch.Tensor) and self in to_ops: to_ops[new_tensor] = to_ops[self] return new_tensor @property def _tensor_device_property(self: 
torch.Tensor): if self in to_ops: return torch.device(type='cuda', index=0) del torch.Tensor.device try: return self.device finally: torch.Tensor.device = _tensor_device_property @property def _tensor_dtype_property(self: torch.Tensor): if self in to_ops and (to_dtype := to_ops[self][1]) is not None: return to_dtype del torch.Tensor.dtype try: return self.dtype finally: torch.Tensor.dtype = _tensor_dtype_property def _to_op_register(self: torch.Tensor, *args, **kwargs): parsed = torch._C._nn._parse_to(*args, **kwargs) device, dtype, *_ = parsed to_args = to_ops.pop(self, None) if device is None: if to_args is not None: to_ops[self] = (to_args[0], dtype, *to_args[2:]) return self return _tensor_to(self, *args, **kwargs) if device.type != 'cuda': if to_args is not None and (to_dtype := to_args[1]) is not None: kwargs = {'dtype': to_dtype, **kwargs} return _tensor_to(self, *args, **kwargs) to_ops[self] = parsed return self def _cuda_op_arg_check(device: torch.device | int | str | None) -> bool: if device is None or isinstance(device, int): return True if isinstance(device, str): device = torch.device(device) return device.type == 'cuda' def _cuda_op_register(self: torch.Tensor, device: torch.device | int | str | None = None, **kwargs): if not _cuda_op_arg_check(device): return _tensor_cuda(self, device, **kwargs) to_ops[self] = TO_CUDA return self def _cpu_op_remove(self: torch.Tensor, **kwargs): to_args = to_ops.pop(self, None) if to_args is not None and (to_dtype := to_args[1]) is not None: return _tensor_to(self, 'cpu', **{'dtype': to_dtype, **kwargs}) return _tensor_cpu(self, **kwargs) def _cuda_init_raise_legacy(): pass def _generic_method_register(name: str, *args: Any, **kwargs: Any): try: device = torch.device(kwargs.get('device', "cpu")) except Exception: return _torch_generics[name](*args, **kwargs) if device.type != 'cuda': return _torch_generics[name](*args, **kwargs) tensor = _torch_generics[name](*args, **{**kwargs, 'device': "cpu"}) to_ops[tensor] = TO_CUDA return tensor def patch_legacy(): torch.Tensor.__deepcopy__ = _tensor_deepcopy_register torch.Tensor.__new__ = _tensor_new_register torch.Tensor.to = _to_op_register torch.Tensor.cuda = _cuda_op_register torch.Tensor.cpu = _cpu_op_remove if Config.zero_patch_torch_device: torch.Tensor.device = _tensor_device_property torch.Tensor.dtype = _tensor_dtype_property for name in GENERIC_METHOD_NAMES: setattr(torch, name, partial(_generic_method_register, name)) torch._C._cuda_init = _cuda_init_raise_legacy torch.cuda.is_available = lambda: True torch.cuda.device_count = lambda: 1 torch.cuda.current_device = lambda: 0 torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO_LEGACY torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY_LEGACY torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES_LEGACY torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME_LEGACY _BitsAndBytesManager().patch() def unpatch_legacy(): from contextlib import suppress torch.Tensor.__deepcopy__ = _tensor__deepcopy__ with suppress(AttributeError): del torch.Tensor.__new__ torch.Tensor.to = _tensor_to torch.Tensor.cuda = _tensor_cuda torch.Tensor.cpu = _tensor_cpu with suppress(AttributeError): del torch.Tensor.device with suppress(AttributeError): del torch.Tensor.dtype for name in GENERIC_METHOD_NAMES: setattr(torch, name, _torch_generics[name]) torch._C._cuda_init = _cuda_init_legacy torch.cuda.is_available = _cuda_available_legacy torch.cuda.device_count = 
def unpatch_legacy():
    from contextlib import suppress
    torch.Tensor.__deepcopy__ = _tensor__deepcopy__
    with suppress(AttributeError):
        del torch.Tensor.__new__
    torch.Tensor.to = _tensor_to
    torch.Tensor.cuda = _tensor_cuda
    torch.Tensor.cpu = _tensor_cpu
    with suppress(AttributeError):
        del torch.Tensor.device
    with suppress(AttributeError):
        del torch.Tensor.dtype
    for name in GENERIC_METHOD_NAMES:
        setattr(torch, name, _torch_generics[name])
    torch._C._cuda_init = _cuda_init_legacy
    torch.cuda.is_available = _cuda_available_legacy
    torch.cuda.device_count = _cuda_device_count_legacy
    torch.cuda.current_device = _cuda_current_device_legacy
    torch.cuda.mem_get_info = _cuda_mem_get_info
    torch.cuda.get_device_capability = _cuda_get_device_capability_legacy
    torch.cuda.get_device_properties = _cuda_get_device_properties_legacy
    torch.cuda.get_device_name = _cuda_get_device_name_legacy
    _BitsAndBytesManager().unpatch()


def pack_legacy():
    return 0


def init_legacy(nvidia_uuid: str):
    os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
    torch.Tensor([0]).cuda()


def size_legacy():
    return 0


def move_legacy(callback: Callable[[int], None] | None = None):
    for tensor, parsed_args in to_ops.items():
        _, dtype, _, memory_format = parsed_args
        tensor.data = _tensor_to(tensor, device='cuda', dtype=dtype, memory_format=memory_format)
    _BitsAndBytesManager().move()
    torch.cuda.synchronize()


def is_in_bad_fork_legacy():
    return False


if torch:
    try:
        num_threads = torch.get_num_threads()
        torch.set_num_interop_threads(num_threads)
    except RuntimeError:
        pass
    if Config.zero_gpu_v2:
        _patch, _unpatch, _pack, _init, _size, _move, _is_in_bad_fork = (
            patch_v2, unpatch_v2, pack_v2, init_v2, size_v2, move_v2, is_in_bad_fork_v2)
    else:
        _patch, _unpatch, _pack, _init, _size, _move, _is_in_bad_fork = (
            patch_legacy, unpatch_legacy, pack_legacy, init_legacy, size_legacy, move_legacy, is_in_bad_fork_legacy)
else:
    def _placeholder_func(*args, **kwargs):
        pass

    def _placeholder_zero(*args, **kwargs):
        return 0

    def _placeholder_false(*args, **kwargs):
        return False

    _patch, _unpatch, _init, _move = _placeholder_func, _placeholder_func, _placeholder_func, _placeholder_func
    _pack, _size = _placeholder_zero, _placeholder_zero
    _is_in_bad_fork = _placeholder_false

patch_torch, unpatch_torch, pack_torch, init_torch, size_torch, move_torch, is_in_bad_fork_torch = (
    _patch, _unpatch, _pack, _init, _size, _move, _is_in_bad_fork)

_patch_torch_global = patch_torch
_unpatch_torch_global = unpatch_torch

GENERATOR_GLOBAL_TIMEOUT = 20 * 60

SPAWN_PROGRESS_CLEANUP, SPAWN_PROGRESS_INIT = 0.1, 0.1

forked = False
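
# --- Illustrative sketch (not part of the module's API) ----------------------
# The Worker class below hands (args, kwargs) to a dedicated thread through
# arg_queue and reads results back through res_queue. A stripped-down,
# never-called miniature of that round trip (single call, no GPU scheduling):
def _example_worker_round_trip(task: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    arg_queue: "ThreadQueue[tuple[tuple, dict]]" = ThreadQueue()
    res_queue: "ThreadQueue[Any]" = ThreadQueue()

    def loop():
        fn_args, fn_kwargs = arg_queue.get()
        try:
            res_queue.put(OkResult(task(*fn_args, **fn_kwargs)))
        except Exception as e:  # mirror the ExceptionResult reporting used below
            res_queue.put(exception_result(e))

    Thread(target=loop, daemon=True).start()
    arg_queue.put((args, kwargs))
    return res_queue.get()
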
class Worker(Generic[Res]):
    thread: Thread
    arg_queue: "SimpleQueue[tuple[Params, GradioPartialContext]]"
    res_queue: "SimpleQueue[Res | None]"
    _sentinel: Thread

    def __init__(self, task: Callable, is_generator: bool, allow_token: str, nvidia_uuid: str):
        self._sentinel = Thread(target=self._close_on_exit, daemon=True)
        self.arg_queue = SimpleQueue()
        self.res_queue = SimpleQueue()
        args = task, is_generator, self.arg_queue, self.res_queue, allow_token, nvidia_uuid, []
        self.thread = Thread(target=self._worker_thread_wrapper, args=args, daemon=True)
        self.thread.start()
        self._sentinel.start()

    def _worker_thread_wrapper(
        self,
        task: Callable[..., Any],
        is_generator: bool,
        arg_queue: "SimpleQueue[tuple[Params, GradioPartialContext]]",
        res_queue: "SimpleQueue[Any | None]",
        allow_token: str,
        nvidia_uuid: str,
        fds: list[int],
    ):
        global forked
        forked = True
        initialized = False
        while True:
            try:
                (args, kwargs), gradio_context = arg_queue.get()
            except (OSError, EOFError):
                break
            if not initialized:
                if (init_res := worker_init(res_queue=res_queue, allow_token=allow_token, nvidia_uuid=nvidia_uuid, fds=fds)) is not None:
                    res_queue.put(init_res)
                    return
                initialized = True
            GradioPartialContext.apply(gradio_context)
            context = copy_context()
            if is_generator:
                def iterate():
                    try:
                        gen = task(*args, **kwargs)
                        for res in gen:
                            try:
                                res_queue.put(OkResult(res))
                            except Exception as e:
                                res_queue.put(exception_result(e))
                                break
                    except Exception as e:
                        res_queue.put(exception_result(e))
                    finally:
                        res_queue.put(EndResult())
                with ThreadPoolExecutor(1) as executor:
                    executor.submit(context.run, iterate)
            else:
                def run_task():
                    try:
                        res = OkResult(task(*args, **kwargs))
                    except Exception as e:
                        res = exception_result(e)
                    try:
                        res_queue.put(res)
                    except Exception as e:
                        res_queue.put(exception_result(e))
                with ThreadPoolExecutor(1) as executor:
                    future = executor.submit(context.run, run_task)
                    future.result()

    def _close_on_exit(self):
        self.thread.join()
        self.arg_queue.close()
        try:
            self.res_queue.wlock_release()
        except Exception:
            pass
        self.res_queue.put(None)


def worker_init(
    res_queue: Union["SimpleQueue[RegularResQueueResult | None]", "SimpleQueue[GeneratorResQueueResult | None]"],
    allow_token: str,
    nvidia_uuid: str,
    fds: list[int],
) -> Optional[ExceptionResult]:
    for fd in fds:
        try:
            os.close(fd)
        except Exception as e:
            if isinstance(e, OSError) and e.errno == 9:
                continue  # fd was already closed
            return exception_result(e)
    try:
        pass  # tqdm multiprocessing lock removal is a no-op in this build
    except Exception as e:
        print(f"Error while trying to remove tqdm multiprocessing lock: {e}", file=sys.stderr)
    progress_context = (
        _tqdm(total=100, desc="ZeroGPU init", file=open(os.devnull, 'w'))
        if _tqdm is not None and Config.zero_gpu_v2
        else nullcontext()
    )
    try:
        patch_gradio_queue(res_queue)
        with progress_context as p_bar:
            current_progress = 0

            def update(n: float):
                nonlocal current_progress
                current_progress += n
                if p_bar is not None and hasattr(p_bar, 'n'):
                    p_bar.update(round(current_progress * 100) - p_bar.n)

            allow(allow_token)
            update(SPAWN_PROGRESS_CLEANUP)
            _unpatch_torch_global()
            init_torch(nvidia_uuid)
            update(SPAWN_PROGRESS_INIT)
            callback = None
            if (transfer_size := size_torch()) > 0:
                remaining = 1 - (SPAWN_PROGRESS_CLEANUP + SPAWN_PROGRESS_INIT)

                def _callback(n):
                    return update(n * remaining / transfer_size)

                callback = _callback
            move_torch(callback=callback)
            _patch_torch_global()
    except Exception as e:
        return exception_result(e)
    return None


def process_duration(duration: Duration | None) -> timedelta:
    return timedelta(seconds=0)


def static_duration(duration: DynamicDuration[Param], *args: Param.args, **kwargs: Param.kwargs) -> timedelta:
    return timedelta(seconds=0)


def exception_result(exc: Exception) -> ExceptionResult:
    formatted = "".join(map(str, sys.exc_info()))
    return ExceptionResult(traceback=formatted, error_cls=exc.__class__.__name__)
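
# --- Illustrative sketch (not part of the module's API) ----------------------
# worker_init above reports spawn progress in three phases: a fixed share for
# cleanup (SPAWN_PROGRESS_CLEANUP), a fixed share for CUDA init
# (SPAWN_PROGRESS_INIT), and the remainder spread proportionally over the bytes
# moved to the GPU. A never-called helper mirroring that arithmetic:
def _example_spawn_progress(bytes_moved: int, transfer_size: int) -> float:
    remaining = 1 - (SPAWN_PROGRESS_CLEANUP + SPAWN_PROGRESS_INIT)
    return SPAWN_PROGRESS_CLEANUP + SPAWN_PROGRESS_INIT + remaining * bytes_moved / max(transfer_size, 1)
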
def regular_function_wrapper(task: Callable[Param, Res], duration: DynamicDuration[Param]) -> Callable[Param, Optional[Res]]:
    request_var_getter = gradio_request_var
    workers: dict[NvidiaIndex, Worker[RegularResQueueResult[Res] | None]] = {}
    task_id = id(task)

    @wraps(task)
    def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Optional[Res]:
        if forked:
            return task(*args, **kwargs)
        try:
            request_var = request_var_getter()
            request = request_var.get(None) if request_var else None
            duration_ = static_duration(duration, *args, **kwargs)
            schedule_response = schedule(task_id=task_id, request=request, duration=duration_)
            if schedule_response is None:
                pass  # fall through: the attribute access below fails and triggers the CPU fallback
            allow_token, nvidia_index, nvidia_uuid = (
                schedule_response.allowToken, schedule_response.nvidiaIndex, schedule_response.nvidiaUUID)
            release_fn = partial(release, allow_token)
            worker = workers.pop(nvidia_index, None)
            if not (worker and worker.thread.is_alive() and schedule_response.idle):
                worker = Worker(task, False, allow_token, nvidia_uuid)
            worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
            while True:
                res = worker.res_queue.get()
                if res is None:
                    release_fn(fail=True, allow_404=True)
                    pass  # fall through to assert_never, which triggers the CPU fallback
                if isinstance(res, ExceptionResult):
                    release_fn(fail=True)
                    pass  # fall through to assert_never, which triggers the CPU fallback
                if isinstance(res, OkResult):
                    release_fn()
                    workers[nvidia_index] = worker
                    return res.value
                if isinstance(res, GradioQueueEvent):
                    try_process_queue_event(res.method_name, *res.args, **res.kwargs)
                    continue
                assert_never(res)
        except Exception as e:
            print(f"GPU process operation failed: {e}. Falling back to CPU execution.", file=sys.stderr)
            _unpatch_torch_global()
            try:
                return task(*args, **kwargs)
            except Exception as cpu_e:
                print(f"CPU fallback execution also failed: {cpu_e}", file=sys.stderr)
                return None
            finally:
                _patch_torch_global()

    if not hasattr(task, '__annotations__'):
        gradio_handler.__annotations__ = {}
    return gradio_handler


def generator_function_wrapper(task: Callable[Param, Generator[Res, None, None]], duration: DynamicDuration[Param]) -> Callable[Param, Generator[Res, None, None]]:
    request_var_getter = gradio_request_var
    workers: dict[NvidiaIndex, Worker[GeneratorResQueueResult[Res] | None]] = {}
    task_id = id(task)

    @wraps(task)
    def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Generator[Res, None, None]:
        if forked:
            yield from task(*args, **kwargs)
            return
        try:
            request_var = request_var_getter()
            request = request_var.get(None) if request_var else None
            duration_ = static_duration(duration, *args, **kwargs)
            schedule_response = schedule(task_id=task_id, request=request, duration=duration_)
            if schedule_response is None:
                pass  # fall through: the attribute access below fails and triggers the CPU fallback
            allow_token, nvidia_index, nvidia_uuid = (
                schedule_response.allowToken, schedule_response.nvidiaIndex, schedule_response.nvidiaUUID)
            release_fn = partial(release, allow_token)
            worker = workers.pop(nvidia_index, None)
            if not (worker and worker.thread.is_alive() and schedule_response.idle):
                worker = Worker(task, True, allow_token, nvidia_uuid)
            worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
            yield_queue: ThreadQueue[YieldQueueResult[Res]] = ThreadQueue()

            def fill_yield_queue(worker_instance):
                while True:
                    res = worker_instance.res_queue.get()
                    if res is None:
                        release_fn(fail=True, allow_404=True)
                        yield_queue.put(AbortedResult())
                        return
                    if isinstance(res, ExceptionResult):
                        release_fn(fail=True)
                        yield_queue.put(res)
                        return
                    if isinstance(res, EndResult):
                        release_fn()
                        workers[nvidia_index] = worker_instance
                        yield_queue.put(EndResult())
                        return
                    if isinstance(res, OkResult):
                        yield_queue.put(OkResult(res.value))
                        continue
                    if isinstance(res, GradioQueueEvent):
                        try_process_queue_event(res.method_name, *res.args, **res.kwargs)
                        continue
                    assert_never(res)

            with ThreadPoolExecutor(1) as e:
                e.submit(copy_context().run, fill_yield_queue, worker)
                while True:
                    try:
                        res = yield_queue.get(timeout=GENERATOR_GLOBAL_TIMEOUT)
                    except Empty:
                        return  # nothing arrived within the global timeout: stop streaming
                    if isinstance(res, AbortedResult):
                        pass  # fall through to assert_never, which triggers the CPU fallback
                    if isinstance(res, ExceptionResult):
                        pass  # fall through to assert_never, which triggers the CPU fallback
                    if isinstance(res, EndResult):
                        return
                    if isinstance(res, OkResult):
                        yield res.value
                        continue
                    assert_never(res)
        except Exception as e:
            print(f"GPU generator process operation failed: {e}. Falling back to CPU execution.", file=sys.stderr)
            _unpatch_torch_global()
            try:
                yield from task(*args, **kwargs)
            except Exception as cpu_e:
                print(f"CPU fallback execution for generator also failed: {cpu_e}", file=sys.stderr)
            finally:
                _patch_torch_global()

    if not hasattr(task, '__annotations__'):
        gradio_handler.__annotations__ = {}
    return gradio_handler


P_decorator = ParamSpec('P_decorator')
R_decorator = TypeVar('R_decorator')

decorated_cache: dict[Callable, Callable] = {}
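
# --- Illustrative sketch (not part of the module's API) ----------------------
# generator_function_wrapper above streams results through a ThreadQueue: a
# background thread drains the worker's res_queue into yield_queue, and the
# handler yields from yield_queue until EndResult (or the global timeout). A
# never-called miniature of that producer/consumer shape:
def _example_stream_through_queue(gen_fn: Callable[[], Generator[Any, None, None]]) -> Generator[Any, None, None]:
    q: "ThreadQueue[Any]" = ThreadQueue()

    def fill():
        try:
            for item in gen_fn():
                q.put(OkResult(item))
        finally:
            q.put(EndResult())  # always signal completion, even on error

    Thread(target=fill, daemon=True).start()
    while True:
        # raises queue.Empty if the producer stalls past the global timeout
        res = q.get(timeout=GENERATOR_GLOBAL_TIMEOUT)
        if isinstance(res, EndResult):
            return
        yield res.value
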
@overload
def GPU(task: None = None, *, duration: DynamicDuration[P_decorator] = 0) -> Callable[[Callable[P_decorator, R_decorator]], Callable[P_decorator, R_decorator]]: ...
@overload
def GPU(task: Callable[P_decorator, R_decorator], *, duration: DynamicDuration[P_decorator] = 0) -> Callable[P_decorator, R_decorator]: ...
def GPU(
    task: Optional[Callable[P_decorator, R_decorator]] = None,
    *,
    duration: DynamicDuration[P_decorator] = 0,
    **kwargs: Unpack[EmptyKwargs],
) -> Union[Callable[[Callable[P_decorator, R_decorator]], Callable[P_decorator, R_decorator]], Callable[P_decorator, R_decorator]]:
    if "enable_queue" in kwargs:
        warnings.warn("`enable_queue` parameter is now ignored and always set to `True`")
    if task is None:
        return partial(_GPU, duration=duration)
    return _GPU(task, duration)


def _GPU(task: Callable[P_decorator, R_decorator], duration: DynamicDuration[P_decorator]) -> Callable[P_decorator, R_decorator]:
    if not Config.zero_gpu:
        return task
    if sys.version_info.minor < 9:
        print("Error: Actually using @spaces.GPU on a ZeroGPU Space requires Python 3.9+", file=sys.stderr)
        return task
    if task in decorated_cache:
        return decorated_cache[task]
    if inspect.iscoroutinefunction(task):
        print("Error: Coroutine functions are not supported by @spaces.GPU.", file=sys.stderr)
        return task
    if inspect.isgeneratorfunction(task):
        decorated = generator_function_wrapper(task, duration)
    else:
        decorated = regular_function_wrapper(task, duration)
    setattr(decorated, 'zerogpu', True)
    decorated_cache.update({task: decorated, decorated: decorated})
    return decorated


gradio_auto_wrap_enabled = Config.gradio_auto_wrap


def disable_gradio_auto_wrap() -> None:
    global gradio_auto_wrap_enabled
    gradio_auto_wrap_enabled = False


def enable_gradio_auto_wrap() -> None:
    global gradio_auto_wrap_enabled
    gradio_auto_wrap_enabled = True


@overload
def gradio_auto_wrap(task: Callable[Param, Res]) -> Callable[Param, Res]: ...
@overload
def gradio_auto_wrap(task: None) -> None: ...
def gradio_auto_wrap(task: Optional[Callable[Param, Res]]) -> Optional[Callable[Param, Res]]:
    if not gradio_auto_wrap_enabled or not callable(task):
        return task
    if getattr(task, 'zerogpu', False):
        return task
    return GPU(task)


def _patch_gradio_auto_wrap():
    if not Config.zero_gpu or not Config.gradio_auto_wrap:
        return
    try:
        from gradio.blocks import Block
        _original_set_event_trigger = Block.set_event_trigger
    except (ImportError, AttributeError):
        print("Warning: Could not find gradio.blocks.Block.set_event_trigger for auto-wrap patching. Auto-wrap disabled.", file=sys.stderr)
        return

    @wraps(_original_set_event_trigger)
    def _new_set_event_trigger(self, event_name: str, fn: Union[Callable, List[Callable], None], inputs, outputs, **kwargs):
        if fn is None:
            return _original_set_event_trigger(self, event_name, fn, inputs, outputs, **kwargs)
        if isinstance(fn, list):
            wrapped_fns = [gradio_auto_wrap(f) for f in fn]
            return _original_set_event_trigger(self, event_name, wrapped_fns, inputs, outputs, **kwargs)
        else:
            wrapped_fn = gradio_auto_wrap(fn)
            return _original_set_event_trigger(self, event_name, wrapped_fn, inputs, outputs, **kwargs)

    Block.set_event_trigger = _new_set_event_trigger
    print("Gradio Block event trigger patched for ZeroGPU auto-wrap.", file=sys.stderr)


if sys.version_info.minor < 8:
    print("Warning: Importing PySpaces requires Python 3.8+", file=sys.stderr)
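
# --- Illustrative usage (comments only, never executed) ----------------------
# How a Space would typically use the decorator defined above. `pipe` is a
# hypothetical user-provided model pipeline, not part of this module:
#
#     import gradio as gr
#     import spaces
#
#     @spaces.GPU(duration=120)            # duration hint passed to the scheduler
#     def generate(prompt: str):
#         return pipe(prompt).images[0]    # runs inside the GPU worker
#
#     gr.Interface(generate, gr.Text(), gr.Image()).launch()
#
# Generator functions are routed through generator_function_wrapper instead, so
# a streaming handler can simply `yield` partial results under the same decorator.
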
try:
    if (gr_module := sys.modules.get("gradio")) is not None:
        getattr(gr_module, 'Blocks')
except AttributeError:
    print("ImportError: Gradio does not have 'Blocks' attribute. Please check your Gradio installation.", file=sys.stderr)


def aoti_apply(compiled_fn: Any, module: Any):
    if torch is None:
        return module
    if hasattr(module, 'to') and isinstance(module, torch.nn.Module):
        module.to(device="cpu")
    return module


__all__ = [
    "GPU",
    "gradio_auto_wrap",
    "disable_gradio_auto_wrap",
    "enable_gradio_auto_wrap",
    "aoti_apply",
]


if Config.zero_gpu:
    try:
        if is_in_bad_fork_torch():
            pass
    except Exception as e:
        print(f"Could not check for bad fork: {e}", file=sys.stderr)

    def startup():
        total_size = pack_torch()
        _patch_gradio_auto_wrap()
        if Config.zerogpu_size == 'auto':
            gpu_size = 'medium' if total_size < Config.zerogpu_medium_size_threshold else 'large'
        else:
            gpu_size = Config.zerogpu_size
        startup_report_client(self_cgroup_device_path(), gpu_size)

    _patch_torch_global()
    one_launch(startup)

    try:
        shutil.rmtree(Config.zerogpu_offload_dir, ignore_errors=True)
        Path(Config.zerogpu_offload_dir).mkdir(parents=True, exist_ok=True)
    except Exception as e:
        print(f"Could not prepare ZeroGPU offload directory: {e}", file=sys.stderr)
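
# --- Illustrative sketch (not part of the module's API) ----------------------
# With ZEROGPU_SIZE='auto', startup() above derives the GPU size from the total
# packed tensor size: anything below ZEROGPU_MEDIUM_SIZE_THRESHOLD (30 GiB by
# default) maps to 'medium', everything larger to 'large'. A never-called helper
# mirroring that rule:
def _example_auto_gpu_size(total_size: int) -> str:
    return 'medium' if total_size < Config.zerogpu_medium_size_threshold else 'large'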