Ignaciohhhhggfgjfrffd committed on
Commit 74d924f · verified · 1 Parent(s): 7276778

Upload __init__.py

Files changed (1)
  1. __init__.py +1838 -0
__init__.py ADDED
@@ -0,0 +1,1838 @@
1
+ import base64
2
+ import ctypes
3
+ import gc
4
+ import inspect
5
+ import json
6
+ import mmap
7
+ import os
8
+ import shutil
9
+ import signal
10
+ import sys
11
+ import time
12
+ import warnings
13
+ from collections import defaultdict
14
+ from concurrent.futures import as_completed, ThreadPoolExecutor
15
+ from contextlib import contextmanager, nullcontext
16
+ from contextvars import copy_context
17
+ from dataclasses import dataclass
18
+ from datetime import timedelta
19
+ from functools import lru_cache as cache, partial, wraps
20
+ from importlib import metadata
21
+ import importlib
22
+ from queue import Empty, Queue as ThreadQueue
23
+ from threading import Thread
24
+ from types import ModuleType, SimpleNamespace
25
+ from typing import (
26
+ Any, Callable, Dict, Generator, Generic, List, Literal, NamedTuple,
27
+ Optional, Set, Tuple, Type, TypedDict, TypeVar, Union, overload
28
+ )
29
+ from typing_extensions import (
30
+ assert_never, ParamSpec, TypeAlias, Unpack, get_args
31
+ )
32
+ from pathlib import Path
33
+ from packaging import version
34
+
35
+ import gradio as gr
36
+ import httpx
37
+ from gradio.context import Context, LocalContext
38
+ from gradio.helpers import Progress, TrackedIterable
39
+ from gradio.queueing import Queue
40
+ from pydantic import BaseModel
41
+
42
+ warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML")
43
+
44
+ try:
45
+ import torch
46
+ from torch.utils.weak import WeakTensorKeyDictionary
47
+ except ImportError:
48
+ torch = None
49
+ WeakTensorKeyDictionary = dict
50
+
51
+ if torch and "weights_only" in inspect.signature(torch.load).parameters:
52
+ _original_torch_load = torch.load
53
+ @wraps(_original_torch_load)
54
+ def patched_torch_load(*args, **kwargs):
55
+ kwargs.setdefault("weights_only", False)
56
+ return _original_torch_load(*args, **kwargs)
57
+ torch.load = patched_torch_load
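+ # Note: recent PyTorch releases (2.6+) flipped the torch.load default to weights_only=True;
+ # the patch above restores weights_only=False only when the caller did not pass it explicitly
+ # (setdefault), so arbitrary pickled checkpoints keep loading. Only safe for trusted files.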
58
+
59
+ try:
60
+ from tqdm import tqdm as _tqdm
61
+ except ImportError:
62
+ _tqdm = None
63
+
64
+ def boolean(value: str | None) -> bool:
65
+ return value is not None and value.lower() in ("1", "t", "true")
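+ # e.g. boolean("1") / boolean("true") / boolean("T") -> True; boolean(None) / boolean("0") -> False.
+ # Used below to parse the SPACES_* / ZEROGPU_* environment flags.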
66
+
67
+ class Settings:
68
+ def __init__(self):
69
+ self.zero_gpu = boolean(os.getenv('SPACES_ZERO_GPU'))
70
+ self.zero_device_api_url = os.getenv('SPACES_ZERO_DEVICE_API_URL')
71
+ self.gradio_auto_wrap = boolean(os.getenv('SPACES_GRADIO_AUTO_WRAP'))
72
+ self.zero_patch_torch_device = boolean(os.getenv('ZERO_GPU_PATCH_TORCH_DEVICE'))
73
+ self.zero_gpu_v2 = boolean(os.getenv('ZEROGPU_V2'))
74
+ GPUSizeConfig = Literal['auto', 'medium', 'large']
75
+ self.zerogpu_size: Union[Literal['medium', 'large'], Literal['auto']] = os.getenv('ZEROGPU_SIZE', 'large')
76
+ self.zerogpu_medium_size_threshold = int(os.getenv('ZEROGPU_MEDIUM_SIZE_THRESHOLD', 30 * 2**30))
77
+ ZEROGPU_OFFLOAD_DIR_DEFAULT = str(Path.home() / '.zerogpu' / 'tensors')
78
+ self.zerogpu_offload_dir = os.getenv('ZEROGPU_OFFLOAD_DIR', ZEROGPU_OFFLOAD_DIR_DEFAULT)
79
+ self.zerogpu_proc_self_cgroup_path = os.getenv('ZEROGPU_PROC_SELF_CGROUP_PATH', '/proc/self/cgroup')
80
+ self.zerogpu_cuda_device_name = os.getenv('ZEROGPU_CUDA_DEVICE_NAME', "NVIDIA H200 MIG 3g.71gb")
81
+ self.zerogpu_cuda_total_memory = int(os.getenv('ZEROGPU_CUDA_TOTAL_MEMORY', 74625056768))
82
+ self.zerogpu_cuda_reserved_memory = int(os.getenv('ZEROGPU_CUDA_RESERVED_MEMORY', 0))
83
+ self.zerogpu_cuda_capability_major = int(os.getenv('ZEROGPU_CUDA_CAPABILITY_MAJOR', 9))
84
+ self.zerogpu_cuda_capability_minor = int(os.getenv('ZEROGPU_CUDA_CAPABILITY_MINOR', 0))
85
+ self.zerogpu_cuda_multi_processor_count = int(os.getenv('ZEROGPU_CUDA_MULTI_PROCESSOR_COUNT', 60))
86
+
87
+ Config = Settings()
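+ # All settings are read once from environment variables at import time; unset boolean flags
+ # default to False via boolean(None), and the ZEROGPU_CUDA_* values default to an H200 MIG 3g.71gb profile.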
88
+
89
+ if Config.zero_gpu:
90
+ if Config.zero_device_api_url is None:
91
+ print("Error: SPACES_ZERO_DEVICE_API_URL environment variable must be set on ZeroGPU Spaces.", file=sys.stderr)
92
+ GPUSizeConfig = Literal['auto', 'medium', 'large']
93
+ if Config.zerogpu_size not in get_args(GPUSizeConfig):
94
+ print(f"Error: ZEROGPU_SIZE should be one of {', '.join(get_args(GPUSizeConfig))}", file=sys.stderr)
95
+
96
+ T = TypeVar('T')
97
+
98
+ @cache
99
+ def self_cgroup_device_path() -> str:
100
+ try:
101
+ cgroup_content = Path(Config.zerogpu_proc_self_cgroup_path).read_text()
102
+ for line in cgroup_content.strip().split('\n'):
103
+ contents = line.split(':devices:')
104
+ if len(contents) == 2:
105
+ return contents[1]
106
+ except Exception as e:
107
+ print(f"Could not determine cgroup device path: {e}", file=sys.stderr)
108
+ return ""
109
+
110
+ class SimpleQueue(ThreadQueue[T]):
111
+ def put(self, obj: T):
112
+ try:
113
+ super().put(obj)
114
+ except Exception as e:
115
+ print(f"Error in SimpleQueue.put: {e}", file=sys.stderr)
116
+
117
+ def close(self):
118
+ try:
119
+ pass
120
+ except Exception as e:
121
+ print(f"Error closing SimpleQueue: {e}", file=sys.stderr)
122
+
123
+ def wlock_release(self):
124
+ try:
125
+ pass
126
+ except Exception:
127
+ pass
128
+
129
+ def drop_params(fn: Callable[[], T]) -> Callable[..., T]:
130
+ def drop(*args, **kwargs):
131
+ return fn()
132
+ return drop
133
+
134
+ def gradio_request_var():
135
+ try:
136
+ from gradio.context import LocalContext
137
+ return LocalContext.request
138
+ except ImportError:
139
+ print("Could not import Gradio LocalContext. Ensure Gradio version is at least 3.46.", file=sys.stderr)
140
+ return None
141
+
142
+ def malloc_trim():
143
+ try:
144
+ ctypes.CDLL("libc.so.6").malloc_trim(0)
145
+ except (OSError, AttributeError) as e:
146
+ print(f"malloc_trim not available on this system: {e}", file=sys.stderr)
147
+
148
+ debug = partial(print, 'SPACES_ZERO_GPU_DEBUG')
149
+
150
+ def jwt_payload(token: str) -> dict[str, Any]:
151
+ try:
152
+ _, payload, _ = token.split('.')
153
+ return json.loads(base64.urlsafe_b64decode(f'{payload}=='))
154
+ except Exception as e:
155
+ print(f"Error decoding JWT payload: {e}", file=sys.stderr)
156
+ return {}
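+ # Decodes only the middle (payload) segment of the JWT; no signature verification is performed.
+ # The appended '==' guarantees valid base64 padding, and urlsafe_b64decode tolerates extra
+ # trailing padding, so this works whatever the payload length.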
157
+
158
+ if torch:
159
+ @wraps(torch.empty_like)
160
+ def empty_like_raw_alloc(tensor: torch.Tensor, **kwargs) -> torch.Tensor:
161
+ empty = torch.empty_like(tensor, **{**kwargs, 'requires_grad': False})
162
+ if (nbytes := empty.untyped_storage().nbytes()) > 0:
163
+ try:
164
+ buffer = mmap.mmap(-1, nbytes, prot=mmap.PROT_READ | mmap.PROT_WRITE)
165
+ buffer_torch = torch.frombuffer(buffer, dtype=torch.uint8)
166
+ empty.set_(buffer_torch.untyped_storage(), 0, empty.shape, empty.stride())
167
+ except Exception as e:
168
+ print(f"Failed to create mmap buffer for tensor: {e}", file=sys.stderr)
169
+ empty.requires_grad_(kwargs.get('requires_grad', False))
170
+ return empty
171
+
172
+ Params = Tuple[Tuple[object, ...], Dict[str, Any]]
173
+ Res = TypeVar('Res')
174
+ Param = ParamSpec('Param')
175
+
176
+ class EmptyKwargs(TypedDict):
177
+ pass
178
+
179
+ @dataclass
180
+ class OkResult(Generic[Res]):
181
+ value: Res
182
+
183
+ @dataclass
184
+ class ExceptionResult:
185
+ traceback: str
186
+ error_cls: str
187
+
188
+ @dataclass
189
+ class AbortedResult:
190
+ pass
191
+
192
+ @dataclass
193
+ class EndResult:
194
+ pass
195
+
196
+ @dataclass
197
+ class GradioQueueEvent:
198
+ method_name: str
199
+ args: tuple[Any, ...]
200
+ kwargs: dict[str, Any]
201
+
202
+ RegularResQueueResult: TypeAlias = Union["OkResult[Res]", "ExceptionResult", "GradioQueueEvent"]
203
+ GeneratorResQueueResult: TypeAlias = Union["OkResult[Res]", "ExceptionResult", "EndResult", "GradioQueueEvent"]
204
+ YieldQueueResult: TypeAlias = Union["OkResult[Res]", "ExceptionResult", "EndResult", "AbortedResult"]
205
+
206
+ Duration: TypeAlias = Union[int, timedelta]
207
+ DynamicDuration: TypeAlias = Union[Duration, Callable[Param, Duration], None]
208
+
209
+ if torch:
210
+ class AliasId(NamedTuple):
211
+ data_ptr: int
212
+ dtype: torch.dtype
213
+ shape: tuple[int, ...]
214
+ stride: tuple[int, ...]
215
+
216
+ @classmethod
217
+ def from_tensor(cls, tensor: torch.Tensor):
218
+ return cls(
219
+ tensor.data_ptr(),
220
+ tensor.dtype,
221
+ tensor.shape,
222
+ tensor.stride(),
223
+ )
224
+
225
+ AllowToken = str
226
+ NvidiaIndex = int
227
+ NvidiaUUID = str
228
+ CGroupPath = str
229
+ TaskId = int
230
+ GPUSize = Literal['medium', 'large']
231
+ AuthLevel = Literal['regular', 'pro']
232
+ QueuingReason = Literal['node', 'concurrency']
233
+
234
+ AUTHENTICATED_HEADER = 'X-Authenticated'
235
+ QUEUING_REASON_HEADER = 'X-Queuing-Reason'
236
+
237
+ class ScheduleResponse(BaseModel):
238
+ idle: bool
239
+ nvidiaIndex: int
240
+ nvidiaUUID: str
241
+ allowToken: str
242
+
243
+ class ScheduleMetadata(BaseModel):
244
+ auth: Optional[AuthLevel] = None
245
+ queuing_reason: Optional[QueuingReason] = None
246
+
247
+ class QuotaInfos(BaseModel):
248
+ left: int
249
+ wait: timedelta
250
+
251
+ class QueueEvent(BaseModel):
252
+ event: Literal['ping', 'failed', 'succeeded']
253
+ data: Optional[ScheduleResponse] = None
254
+
255
+ def sse_parse(text: str):
256
+ event, *data = text.strip().splitlines()
257
+ assert event.startswith('event:')
258
+ event = event[6:].strip()
259
+ if event in ('ping', 'failed'):
260
+ return QueueEvent(event=event)
261
+ assert event == 'succeeded'
262
+ (data,) = data
263
+ assert data.startswith('data:')
264
+ data = data[5:].strip()
265
+ return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data))
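+ # Expected wire format (inferred from the parsing above), one event per text chunk:
+ # event: ping | failed | succeeded
+ # data: {"idle": false, "nvidiaIndex": 0, "nvidiaUUID": "GPU-...", "allowToken": "..."}  (succeeded only)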
266
+
267
+ def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]:
268
+ for text in res.iter_text():
269
+ if len(text) == 0:
270
+ break
271
+ try:
272
+ yield sse_parse(text)
273
+ except GeneratorExit:
274
+ res.close()
275
+ break
276
+ except Exception as e:
277
+ print(f"Error parsing SSE event: {e}", file=sys.stderr)
278
+ continue
279
+
280
+ class APIClient:
281
+ def __init__(self, client: httpx.Client):
282
+ self.client = client
283
+
284
+ def startup_report(self, cgroup_path: str, gpu_size: GPUSize) -> httpx.codes:
285
+ try:
286
+ res = self.client.post('/startup-report', params={'cgroupPath': cgroup_path, 'gpuSize': gpu_size})
287
+ return httpx.codes(res.status_code)
288
+ except Exception as e:
289
+ print(f"Failed to send startup report: {e}", file=sys.stderr)
290
+ return httpx.codes.INTERNAL_SERVER_ERROR
291
+
292
+ def schedule(self, cgroup_path: str, task_id: int = 0, token: str | None = None, token_version: int = 1, duration_seconds: int = 0, enable_queue: bool = True):
293
+ try:
294
+ params: dict[str, str | int | bool] = {'cgroupPath': cgroup_path, 'taskId': task_id, 'enableQueue': enable_queue, 'tokenVersion': token_version, 'durationSeconds': duration_seconds}
295
+ if token is not None:
296
+ params['token'] = token
297
+ req = self.client.build_request(method='POST', url='/schedule', params=params)
298
+ res = self.client.send(req, stream=True)
299
+ status = httpx.codes(res.status_code)
300
+ auth: AuthLevel | None = res.headers.get(AUTHENTICATED_HEADER)
301
+ queuing_reason: QueuingReason | None = res.headers.get(QUEUING_REASON_HEADER)
302
+ metadata = ScheduleMetadata(auth=auth, queuing_reason=queuing_reason)
303
+ if status is not httpx.codes.OK and status is not httpx.codes.TOO_MANY_REQUESTS:
304
+ res.close()
305
+ return status, metadata
306
+ if "text/event-stream" in res.headers.get('content-type', ''):
307
+ return sse_stream(res), metadata
308
+ res.read()
309
+ if status is httpx.codes.TOO_MANY_REQUESTS:
310
+ return QuotaInfos(**res.json()), metadata
311
+ if status is httpx.codes.OK:
312
+ return ScheduleResponse(**res.json()), metadata
313
+ assert_never(status)
314
+ except Exception as e:
315
+ print(f"Error in APIClient.schedule: {e}", file=sys.stderr)
316
+ return httpx.codes.INTERNAL_SERVER_ERROR, ScheduleMetadata()
317
+
318
+ def allow(self, allow_token: str, pid: int):
319
+ try:
320
+ res = self.client.post('/allow', params={'allowToken': allow_token, 'pid': pid})
321
+ return httpx.codes(res.status_code)
322
+ except Exception as e:
323
+ print(f"Error in APIClient.allow: {e}", file=sys.stderr)
324
+ return httpx.codes.INTERNAL_SERVER_ERROR
325
+
326
+ def release(self, allow_token: str, fail: bool = False) -> httpx.codes:
327
+ try:
328
+ res = self.client.post('/release', params={'allowToken': allow_token, 'fail': fail})
329
+ return httpx.codes(res.status_code)
330
+ except Exception as e:
331
+ print(f"Error in APIClient.release: {e}", file=sys.stderr)
332
+ return httpx.codes.INTERNAL_SERVER_ERROR
333
+
334
+ def get_queue_size(self) -> float:
335
+ try:
336
+ res = self.client.get('/queue-size')
337
+ assert res.status_code == 200, res.status_code
338
+ return res.json()
339
+ except Exception as e:
340
+ print(f"Error in APIClient.get_queue_size: {e}", file=sys.stderr)
341
+ return 0.0
342
+
343
+ def remove_tqdm_multiprocessing_lock():
344
+ if _tqdm is None:
345
+ return
346
+ try:
347
+ tqdm_lock = _tqdm.get_lock()
348
+ if hasattr(tqdm_lock, 'locks'):
349
+ pass
350
+ except Exception as e:
351
+ print(f"Error while trying to remove tqdm multiprocessing lock: {e}", file=sys.stderr)
352
+
353
+ tqdm = _tqdm
354
+
355
+ try:
356
+ Success = gr.Success
357
+ except AttributeError:
358
+ Success = gr.Info
359
+
360
+ Level: TypeAlias = "Literal['success', 'info', 'warning']"
361
+
362
+ def modal(level: Level):
363
+ if level == 'info': return gr.Info
364
+ if level == 'success': return Success
365
+ if level == 'warning': return gr.Warning
366
+ return gr.Info
367
+
368
+ class GradioPartialContext(NamedTuple):
369
+ event_id: str | None
370
+ in_event_listener: bool
371
+ progress: Progress | None
372
+
373
+ @staticmethod
374
+ def get():
375
+ TrackedIterable.__reduce__ = tracked_iterable__reduce__
376
+ return GradioPartialContext(
377
+ event_id=LocalContext.event_id.get(None),
378
+ in_event_listener=LocalContext.in_event_listener.get(False),
379
+ progress=LocalContext.progress.get(None),
380
+ )
381
+
382
+ @staticmethod
383
+ def apply(context: 'GradioPartialContext'):
384
+ LocalContext.event_id.set(context.event_id)
385
+ LocalContext.in_event_listener.set(context.in_event_listener)
386
+ LocalContext.progress.set(context.progress)
387
+
388
+ def get_queue_instance():
389
+ blocks = LocalContext.blocks.get(None)
390
+ if blocks is None: return None
391
+ return getattr(blocks, '_queue', None)
392
+
393
+ def get_event():
394
+ queue = get_queue_instance()
395
+ event_id = LocalContext.event_id.get(None)
396
+ if queue is None or event_id is None: return None
397
+ for job in getattr(queue, 'active_jobs', []):
398
+ if job is None: continue
399
+ for event in job:
400
+ if getattr(event, '_id', None) == event_id:
401
+ return event
402
+ return None
403
+
404
+ def get_server_port() -> int | None:
405
+ from_request_context = True
406
+ if (blocks := LocalContext.blocks.get(None)) is None:
407
+ from_request_context = False
408
+ if (blocks := Context.root_block) is None: return None
409
+ if (server := getattr(blocks, "server", None)) is None:
410
+ if from_request_context:
411
+ warnings.warn("Gradio: No blocks.server inside a request")
412
+ return -1
413
+
414
+ server_config = getattr(server, 'config', None)
415
+
416
+ if isinstance(server_config, dict):
417
+ return server_config.get('port')
418
+ elif isinstance(server_config, Settings):
419
+ warnings.warn("ZeroGPU: Gradio server.config appears to be the global ZeroGPU Config object. Cannot determine Gradio port from this object.")
420
+ return None
421
+ elif hasattr(server_config, 'port'):
422
+ return server_config.port
423
+
424
+ warnings.warn(f"ZeroGPU: Unexpected type for server.config ({type(server_config)}). Cannot determine Gradio port.")
425
+ return None
426
+
427
+ def try_process_queue_event(method_name: str, *args, **kwargs):
428
+ queue = get_queue_instance()
429
+ if queue is None:
430
+ warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
431
+ return
432
+ method = getattr(queue, method_name, None)
433
+ if callable(method):
434
+ try:
435
+ method(*args, **kwargs)
436
+ except Exception as e:
437
+ print(f"Error processing Gradio queue event '{method_name}': {e}", file=sys.stderr)
438
+
439
+ QUEUE_RPC_METHODS = ["set_progress", "log_message"]
440
+
441
+ def patch_gradio_queue(res_queue: Union[SimpleQueue[RegularResQueueResult | None], SimpleQueue[GeneratorResQueueResult | None]]):
442
+ def rpc_method(method_name: str):
443
+ def method(*args, **kwargs):
444
+ if args and isinstance(args[0], Queue): args = args[1:]
445
+ res_queue.put(GradioQueueEvent(method_name, args, kwargs))
446
+ return method
447
+
448
+ for method_name in QUEUE_RPC_METHODS:
449
+ if (method := getattr(Queue, method_name, None)) is None:
450
+ warnings.warn(f"ZeroGPU: Gradio Queue has no {method_name} attribute")
451
+ continue
452
+ if not callable(method):
453
+ warnings.warn(f"ZeroGPU: Gradio Queue {method_name} is not callable")
454
+ continue
455
+ setattr(Queue, method_name, rpc_method(method_name))
456
+ TrackedIterable.__reduce__ = tracked_iterable__reduce__
457
+
458
+ def tracked_iterable__reduce__(self):
459
+ try:
460
+ res: tuple = super(TrackedIterable, self).__reduce__()
461
+ cls, base, state, *_ = res
462
+ return cls, base, {**state, **{'iterable': None, '_tqdm': None}}
463
+ except Exception:
464
+ return object, (), {}
465
+
466
+ def supports_auth():
467
+ try:
468
+ return version.parse(gr.__version__) >= version.Version('4.27.0')
469
+ except Exception:
470
+ return False
471
+
472
+ Param_one_launch = ParamSpec('Param_one_launch')
473
+
474
+ def one_launch(task: Callable[Param_one_launch, None], *task_args: Param_one_launch.args, **task_kwargs: Param_one_launch.kwargs):
475
+ _launch = gr.Blocks.launch
476
+ @wraps(gr.Blocks.launch)
477
+ def launch(*args, **kwargs):
478
+ task(*task_args, **task_kwargs)
479
+ gr.Blocks.launch = _launch
480
+ return gr.Blocks.launch(*args, **kwargs)
481
+ gr.Blocks.launch = launch
482
+
483
+ class HTMLError(gr.Error):
484
+ def __str__(self): return str(self.message)
485
+
486
+ def error(title: str, message: str, html: bool = False):
487
+ print(f"ERROR: {title} - {message}", file=sys.stderr)
488
+ error_cls = HTMLError if html else gr.Error
489
+ params = inspect.signature(gr.Error).parameters
490
+ kwargs: dict[str, Any] = {}
491
+ if 'title' in params: kwargs['title'] = title
492
+ if 'print_exception' in params: kwargs['print_exception'] = False
493
+ raise error_cls(message, **kwargs)
497
+
498
+ def info(title: str, message: str, level: Level = 'info'):
499
+ print(f"INFO: {title} - {message}")
500
+ info_cls = modal(level)
501
+ params = inspect.signature(gr.Info).parameters
502
+ kwargs: dict[str, Any] = {}
503
+ if 'title' in params: kwargs['title'] = title
504
+ try:
505
+ info_cls(message, **kwargs)
506
+ except Exception:
507
+ pass
508
+
509
+ TOKEN_HEADER = 'X-IP-Token'
510
+ UNUSED_MESSAGE = "GPU device not used"
511
+ NO_GPU_MESSAGE_REGULAR = "No GPU was available"
512
+ NO_GPU_MESSAGE_INQUEUE = "No GPU was available after 60 seconds"
513
+ EXAMPLES_RETRY_MESSAGE = "Try re-running outside of examples if it happened after clicking one"
514
+ SIGNUP_ON_HF_TXT = "Create a free account"
515
+ SIGNUP_ON_HF_URL = "https://huggingface.co/join"
516
+ SUBSCRIBE_TO_PRO_TXT = "Subscribe to Pro"
517
+ SUBSCRIBE_TO_PRO_URL = "https://huggingface.co/settings/billing/subscription"
518
+
519
+ def api_client():
520
+ assert Config.zero_device_api_url is not None
521
+ httpx_client = httpx.Client(base_url=Config.zero_device_api_url, timeout=60, verify=False)
522
+ return APIClient(httpx_client)
523
+
524
+ def startup_report_client(cgroup_path: str, gpu_size: GPUSize):
525
+ retries, max_retries = 0, 2
526
+ client = api_client()
527
+ status = None
528
+ while retries <= max_retries:
529
+ status = client.startup_report(cgroup_path, gpu_size)
530
+ if status is not httpx.codes.NOT_FOUND:
531
+ break
532
+ time.sleep(1)
533
+ retries += 1
534
+ if status is not httpx.codes.OK:
535
+ print(f"Error while initializing ZeroGPU: status {status}", file=sys.stderr)
536
+
537
+ def html_string(html_contents: str, text_contents: str):
538
+ class HTMLString(str):
539
+ def __str__(self): return text_contents
540
+ return HTMLString(html_contents)
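+ # Returns a str subclass whose value is the HTML markup while str(instance) yields the plain-text
+ # fallback, so one object can feed both HTML-capable toasts and plain log output.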
541
+
542
+ def _toast_action(auth: AuthLevel | None, supports_html: bool, pro_message: str, unlogged_desc: str, logged_desc: str, ending: str) -> tuple[str, str]:
543
+ if not supports_auth() or auth == 'pro':
544
+ return pro_message, pro_message
545
+ link = SIGNUP_ON_HF_URL if auth is None else SUBSCRIBE_TO_PRO_URL
546
+ text = SIGNUP_ON_HF_TXT if auth is None else SUBSCRIBE_TO_PRO_TXT
547
+ desc = unlogged_desc if auth is None else logged_desc
548
+ desc += f" {ending}."
549
+ style = ";".join(["white-space: nowrap", "text-underline-offset: 2px", "color: var(--body-text-color)"])
550
+ html = f'<a style="{style}" href="{link}">{text}</a> {desc}'
551
+ markdown = f'[{text}]({link}) {desc}'
552
+ return html, markdown
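+ # Returns the same call-to-action rendered twice, (html, markdown), so callers can use the HTML
+ # variant when the installed Gradio supports HTML in toasts and the markdown one otherwise.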
553
+
554
+ def schedule(task_id: int, request: gr.Request | None = None, duration: timedelta = timedelta(0), _first_attempt: bool = True) -> Optional[ScheduleResponse]:
555
+ try:
556
+ gradio_version = version.parse(gr.__version__)
557
+ if gradio_version.major < 4:
558
+ print("ZeroGPU is only compatible with Gradio 4+", file=sys.stderr)
559
+ return None
560
+ except Exception:
561
+ print("Could not parse Gradio version.", file=sys.stderr)
562
+ return None
563
+
564
+ GRADIO_HTML_TOASTS = gradio_version >= version.Version('4.39')
565
+ GRADIO_HANDSHAKE = gradio_version >= version.Version('5.16.1')
566
+ token, payload = _get_token_and_payload(request)
567
+ if token is not None and (token_error := payload.get('error')):
568
+ info("ZeroGPU client warning", f"Falling back to IP-based quotas ({token_error})", level='warning')
569
+
570
+ duration_seconds = duration.seconds
571
+
572
+ res, meta = api_client().schedule(cgroup_path=self_cgroup_device_path(), task_id=task_id, token=token, token_version=2 if GRADIO_HANDSHAKE else 1, duration_seconds=duration_seconds)
573
+
574
+ if isinstance(res, ScheduleResponse):
575
+ print("This Space is currently using 0 minutes, 0 seconds of the huggingface.co plan.")
576
+ return res
577
+ if isinstance(res, QuotaInfos):
578
+ requested = duration.seconds
579
+ message = ""
580
+ if res.wait < timedelta(0):
581
+ message = f"The requested GPU duration ({requested}s) is larger than the maximum allowed"
582
+ elif token is None:
583
+ message = f"Space app has reached its GPU limit. {EXAMPLES_RETRY_MESSAGE}"
584
+ else:
585
+ if payload.get('user') is None and res.wait == timedelta(0):
586
+ message = "You have exceeded your runs limit."
587
+ else:
588
+ gpu = "Pro GPU" if meta.auth == 'pro' else ("free GPU" if meta.auth == 'regular' else "GPU")
589
+ message = f"You have exceeded your {gpu} quota ({requested}s requested vs. {res.left}s left). Try again in {res.wait}"
590
+ print(f"ZeroGPU quota exceeded: {message}", file=sys.stderr)
591
+ return None
592
+ if not isinstance(res, httpx.codes):
593
+ if meta.queuing_reason in ('node', None): info("ZeroGPU queue", "Waiting for a GPU to become available")
594
+ elif meta.queuing_reason == 'concurrency': info("ZeroGPU queue", "Waiting for a GPU slot on this Space")
595
+ else: assert_never(meta.queuing_reason)
596
+ connection_event = get_event()
597
+ if connection_event is None and request is not None:
598
+ warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
599
+ while True:
600
+ try:
601
+ event = next(res)
602
+ except StopIteration:
603
+ print("Unexpected end of stream in schedule", file=sys.stderr)
604
+ return None
605
+ except httpx.RemoteProtocolError:
606
+ if not _first_attempt:
607
+ print("Error while re-trying after queue disconnect", file=sys.stderr)
608
+ return None
609
+ return schedule(task_id, request, duration, _first_attempt=False)
610
+ except Exception as e:
611
+ print(f"Error processing schedule event stream: {e}", file=sys.stderr)
612
+ return None
613
+ if event.event == 'ping':
614
+ if connection_event is not None and not connection_event.alive:
615
+ res.close()
616
+ print("Connection closed by visitor while queueing", file=sys.stderr)
617
+ return None
618
+ continue
619
+ if event.event == 'failed':
620
+ if token is None:
621
+ message = f"{NO_GPU_MESSAGE_INQUEUE}. {EXAMPLES_RETRY_MESSAGE}"
622
+ else:
623
+ _, details_markdown = _toast_action(auth=meta.auth, supports_html=GRADIO_HTML_TOASTS, pro_message="Retry later", unlogged_desc="to get a higher", logged_desc="to get the highest", ending="priority in ZeroGPU queues")
624
+ message = f"{NO_GPU_MESSAGE_INQUEUE} {details_markdown}"
625
+ print(f"ZeroGPU queue timeout: {message}", file=sys.stderr)
626
+ return None
627
+ if event.event == 'succeeded':
628
+ assert event.data is not None
629
+ if connection_event is not None and not connection_event.alive:
630
+ release(event.data.allowToken)
631
+ print("Connection closed by visitor on queue success", file=sys.stderr)
632
+ return None
633
+ info("ZeroGPU queue", "Successfully acquired a GPU", level='success')
634
+ print("This Space is currently using 0 minutes, 0 seconds of the huggingface.co plan.")
635
+ return event.data
636
+ if res is httpx.codes.SERVICE_UNAVAILABLE:
637
+ print(f"ZeroGPU client error: {NO_GPU_MESSAGE_REGULAR}", file=sys.stderr)
638
+ return None
639
+ if res is httpx.codes.UNAUTHORIZED:
640
+ print("ZeroGPU client error: Expired ZeroGPU proxy token", file=sys.stderr)
641
+ return None
642
+ reason = httpx.codes.get_reason_phrase(res) if isinstance(res, int) else "Unknown"
643
+ print(f"ZeroGPU API /schedule error: {res} ({reason})", file=sys.stderr)
644
+ return None
645
+
646
+ def allow(allow_token: str) -> None:
647
+ process_id = os.getpid()
648
+ if process_id == 1:
649
+ print("CRITICAL: Allowing PID 1 on ZeroGPU will end up killing your Space. Aborting.", file=sys.stderr)
650
+ return
651
+ if api_client().allow(allow_token=allow_token, pid=process_id) is not httpx.codes.OK:
652
+ print(f"API call to /allow failed for token {allow_token}", file=sys.stderr)
653
+
654
+ def release(allow_token: str, *, fail: bool = False, allow_404: bool = True) -> None:
655
+ res = api_client().release(allow_token=allow_token, fail=fail)
656
+ if res is httpx.codes.NO_CONTENT:
657
+ try:
658
+ info("ZeroGPU client warning", UNUSED_MESSAGE, level='warning')
659
+ except AttributeError:
660
+ pass
661
+ warnings.warn(UNUSED_MESSAGE, RuntimeWarning)
662
+ return
663
+ if res is httpx.codes.NOT_FOUND:
664
+ if not allow_404:
665
+ warnings.warn("ZeroGPU API /release warning: 404 Not Found")
666
+ return
667
+ if httpx.codes.is_success(res):
668
+ return
669
+ reason = httpx.codes.get_reason_phrase(res) if isinstance(res, int) else "Unknown"
670
+ print(f"ZeroGPU API /release error: {res} ({reason})", file=sys.stderr)
671
+
672
+ def _get_token(request: gr.Request | None) -> str | None:
673
+ if request is None: return None
674
+ headers = getattr(request, 'headers', None)
675
+ if headers is None or not hasattr(headers, '__dict__'):
676
+ print("ZeroGPU client error: Internal Gradio error (headers not found)", file=sys.stderr)
677
+ return None
678
+ if not hasattr(headers, 'get'):
679
+ headers = headers.__dict__
680
+ return headers.get(TOKEN_HEADER.lower())
681
+
682
+ def _get_token_and_payload(request: gr.Request | None) -> tuple[str | None, dict[str, Any]]:
683
+ token = _get_token(request)
684
+ if token is None: return None, {}
685
+ payload = jwt_payload(token)
686
+ return token, payload
687
+
688
+ def compute_base_free_memory(total_memory: int) -> int:
689
+ pytorch_base_memory = 309002240
690
+ return total_memory - pytorch_base_memory - Config.zerogpu_cuda_reserved_memory
691
+
692
+ CUDA_DEVICE_NAME_STATIC = Config.zerogpu_cuda_device_name
693
+ CUDA_TOTAL_MEMORY_STATIC = Config.zerogpu_cuda_total_memory
694
+ CUDA_MEM_GET_INFO_STATIC = (compute_base_free_memory(CUDA_TOTAL_MEMORY_STATIC), CUDA_TOTAL_MEMORY_STATIC)
695
+ CUDA_DEVICE_CAPABILITY_STATIC = (Config.zerogpu_cuda_capability_major, Config.zerogpu_cuda_capability_minor)
696
+ CUDA_DEVICE_PROPERTIES_STATIC = SimpleNamespace(name=CUDA_DEVICE_NAME_STATIC, major=CUDA_DEVICE_CAPABILITY_STATIC[0], minor=CUDA_DEVICE_CAPABILITY_STATIC[1], total_memory=CUDA_TOTAL_MEMORY_STATIC, multi_processor_count=Config.zerogpu_cuda_multi_processor_count)
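+ # These static values are what the patched torch.cuda.* getters report while no real GPU is
+ # attached to the process; they describe the device advertised by the ZEROGPU_CUDA_* defaults
+ # above (an H200 MIG 3g.71gb slice) rather than anything queried from a driver.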
697
+
698
+ if torch:
699
+ class MockCudaRuntime:
700
+ def setDevice(self, device):
701
+ pass
702
+ def getDevice(self):
703
+ return 0
704
+ def deviceSynchronize(self):
705
+ pass
706
+ def deviceGetStreamPriorityRange(self):
707
+ return 0, 0
708
+ cudart = MockCudaRuntime()
709
+
710
+ if torch and torch.version.cuda and torch.version.cuda.startswith("12."):
711
+ CUDA_MEMORY_STATS_AS_NESTED_DICT_STATIC = {"num_alloc_retries": 0, "num_ooms": 0, "max_split_size": -1, "num_sync_all_streams": 0, "num_device_alloc": 0, "num_device_free": 0, "allocation": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "segment": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "active": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "inactive_split": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "allocated_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "reserved_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "active_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "inactive_split_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "requested_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "oversize_allocations": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "oversize_segments": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}
712
+ else:
713
+ CUDA_MEMORY_STATS_AS_NESTED_DICT_STATIC = {"num_alloc_retries": 0, "num_ooms": 0, "max_split_size": -1, "allocation": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "segment": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "active": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "inactive_split": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "allocated_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "reserved_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "active_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "inactive_split_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "requested_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "oversize_allocations": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "oversize_segments": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}
714
+
715
+ def cudaMemGetInfo(device: int, /):
716
+ return CUDA_MEM_GET_INFO_STATIC
717
+
718
+ PAGE_SIZE = 4096
719
+ try:
720
+ TOTAL_MEMORY = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
721
+ except (ValueError, AttributeError):
722
+ TOTAL_MEMORY = 8 * (1024**3)
723
+ VM_MAX_SIZE = min(2**38, TOTAL_MEMORY // 2)
724
+ BUFFER_SIZE = 128 * 2**20
725
+ BUFFER_COUNT = 2
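+ # Offload pipeline constants: tensors are staged through BUFFER_COUNT pinned buffers of
+ # BUFFER_SIZE (128 MiB) each, giving a simple double-buffered disk -> pinned -> CUDA copy chain.
+ # PAGE_SIZE alignment is required because the pack files are opened with O_DIRECT, and
+ # VM_MAX_SIZE caps the scratch allocation used when serializing a single tensor.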
726
+ if torch:
727
+ TensorWithSizes: TypeAlias = 'tuple[torch.Tensor, int, int]'
728
+
729
+ if torch:
730
+ @dataclass
731
+ class ZeroGPUTensorPack:
732
+ base_dir: str
733
+ batches: list[list[TensorWithSizes]]
734
+ big_tensors: list[list[TensorWithSizes]]
735
+ fakes: dict[torch.Tensor, list[torch.Tensor]]
736
+ total_size: int
737
+
738
+ def path(self):
739
+ return f'{self.base_dir}/{id(self)}'
740
+
741
+ def __del__(self):
742
+ try:
743
+ os.remove(self.path())
744
+ except (FileNotFoundError, TypeError, AttributeError):
745
+ pass
746
+
747
+ def write_packing(fd: int, tensor: torch.Tensor):
748
+ try:
749
+ clone = torch.empty_like(tensor)
750
+ size = clone.untyped_storage().size()
751
+ buffer = torch.UntypedStorage(VM_MAX_SIZE)
752
+ buffer_ptr = buffer.data_ptr()
753
+ offset = -buffer_ptr % PAGE_SIZE
754
+ padding = -size % PAGE_SIZE
755
+ clone.set_(buffer[offset:offset + size], 0, clone.shape, clone.stride())
756
+ clone.copy_(tensor)
757
+ mv = memoryview((ctypes.c_char * (size + padding)).from_address(buffer_ptr + offset))
758
+ written_bytes = 0
759
+ while written_bytes < size:
760
+ written_bytes += os.write(fd, mv[written_bytes:])
761
+ except Exception as e:
762
+ print(f"Error during tensor write packing: {e}", file=sys.stderr)
763
+
764
+ def pack_tensors(tensors: set[torch.Tensor], fakes: dict[torch.Tensor, list[torch.Tensor]], offload_dir: str, callback: Callable[[int], None] | None = None):
765
+ callback = (lambda b: None) if callback is None else callback
766
+ batches: list[list[TensorWithSizes]] = []
767
+ big_tensors: list[list[TensorWithSizes]] = []
768
+ tensors_with_sizes: list[tuple[torch.Tensor, int, int]] = []
769
+ for tensor in tensors:
770
+ size = tensor.numel() * tensor.element_size()
771
+ aligned_size = size + (-size % PAGE_SIZE)
772
+ tensors_with_sizes.append((tensor, size, aligned_size))
773
+ current_batch, current_size = [], 0
774
+ for (tensor, size, aligned_size) in sorted(tensors_with_sizes, key=lambda item: item[2]):
775
+ if aligned_size > BUFFER_SIZE:
776
+ big_tensors.append((tensor, size, aligned_size))
777
+ continue
778
+ current_size += aligned_size
779
+ if current_size > BUFFER_SIZE:
780
+ batches.append(current_batch)
781
+ current_batch, current_size = [(tensor, size, aligned_size)], aligned_size
782
+ else:
783
+ current_batch.append((tensor, size, aligned_size))
784
+ if current_batch:
785
+ batches.append(current_batch)
786
+ get_meta = {tensor: empty_like_raw_alloc(tensor) for tensor in tensors}
787
+ batches_meta = [[(get_meta[tensor], size, asize) for tensor, size, asize in batch] for batch in batches]
788
+ big_tensors_meta = [(get_meta[tensor], size, asize) for tensor, size, asize in big_tensors]
789
+ fakes_meta = {get_meta[tensor]: fake_list for tensor, fake_list in fakes.items()}
790
+ pack = ZeroGPUTensorPack(base_dir=offload_dir, batches=batches_meta, big_tensors=big_tensors_meta, fakes=fakes_meta, total_size=sum([size for _, size, _ in tensors_with_sizes]))
791
+ fd = -1
792
+ try:
793
+ fd = os.open(pack.path(), os.O_CREAT | os.O_WRONLY | os.O_DIRECT)
794
+ total_asize = sum([aligned_size for batch in batches for *_, aligned_size in batch])
795
+ total_asize += sum([aligned_size for *_, aligned_size in big_tensors])
796
+ if total_asize > 0:
797
+ os.posix_fallocate(fd, 0, total_asize)
798
+ for batch in batches:
799
+ for tensor, size, _ in batch:
800
+ write_packing(fd, tensor)
801
+ callback(size)
802
+ for tensor, size, _ in big_tensors:
803
+ write_packing(fd, tensor)
804
+ callback(size)
805
+ return pack
806
+ except Exception as e:
807
+ print(f"Failed to pack tensors to disk: {e}", file=sys.stderr)
808
+ return pack
809
+ finally:
810
+ if fd != -1:
811
+ os.close(fd)
812
+
813
+ def pack_to_cuda(pack: ZeroGPUTensorPack, callback: Callable[[int], None] | None = None):
814
+ callback = (lambda b: None) if callback is None else callback
815
+ free_buffers: ThreadQueue[torch.Tensor] = ThreadQueue()
816
+ read_buffers: ThreadQueue[torch.Tensor] = ThreadQueue()
817
+ for _ in range(BUFFER_COUNT):
818
+ free_buffers.put(torch.ByteTensor(BUFFER_SIZE).pin_memory())
819
+ def read(fd: int, buffer: torch.Tensor, size: int):
820
+ mv = memoryview((ctypes.c_char * size).from_address(buffer.data_ptr()))
821
+ read_bytes = 0
822
+ while read_bytes < size:
823
+ read_bytes += os.readv(fd, [mv[read_bytes:]])
824
+ def disk_to_pin(fd: int):
825
+ for batch in pack.batches:
826
+ buffer = free_buffers.get()
827
+ batch_size = sum([aligned_size for *_, aligned_size in batch])
828
+ read(fd, buffer, batch_size)
829
+ read_buffers.put(buffer)
830
+ for *_, aligned_size in pack.big_tensors:
831
+ read_bytes = 0
832
+ while read_bytes < aligned_size:
833
+ buffer = free_buffers.get()
834
+ read_size = min(BUFFER_SIZE, aligned_size - read_bytes)
835
+ read(fd, buffer, read_size)
836
+ read_buffers.put(buffer)
837
+ read_bytes += read_size
838
+ def pin_to_cuda():
839
+ total_duration_in_callback = 0
840
+ for batch in pack.batches:
841
+ buffer = read_buffers.get()
842
+ offset = 0
843
+ cuda_storages = []
844
+ for tensor, size, aligned_size in batch:
845
+ cuda_storages.append(buffer[offset:offset + size].cuda(non_blocking=True))
846
+ offset += aligned_size
847
+ torch.cuda.synchronize()
848
+ free_buffers.put(buffer)
849
+ batch_total_size = 0
850
+ for (tensor, size, _), cuda_storage in zip(batch, cuda_storages):
851
+ cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda')
852
+ cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride())
853
+ for fake in pack.fakes[tensor]:
854
+ fake.data = cuda_tensor
855
+ batch_total_size += size
856
+ t0 = time.perf_counter()
857
+ callback(batch_total_size)
858
+ total_duration_in_callback += time.perf_counter() - t0
859
+ for tensor, size, _ in pack.big_tensors:
860
+ cuda_storage = torch.empty(size, dtype=torch.uint8, device='cuda')
861
+ offset = 0
862
+ while offset < size:
863
+ buffer = read_buffers.get()
864
+ read_size = min(BUFFER_SIZE, size - offset)
865
+ cuda_storage[offset:offset + read_size] = buffer[:read_size]
866
+ offset += read_size
867
+ torch.cuda.synchronize()
868
+ free_buffers.put(buffer)
869
+ t0 = time.perf_counter()
870
+ callback(read_size)
871
+ total_duration_in_callback += time.perf_counter() - t0
872
+ cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda')
873
+ cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride())
874
+ for fake in pack.fakes[tensor]:
875
+ fake.data = cuda_tensor
876
+ debug(f"{total_duration_in_callback=}")
877
+ fd = -1
878
+ try:
879
+ with ThreadPoolExecutor(2) as e:
880
+ fd = os.open(pack.path(), os.O_RDONLY | os.O_DIRECT)
881
+ futures = [e.submit(copy_context().run, disk_to_pin, fd), e.submit(copy_context().run, pin_to_cuda)]
882
+ for future in as_completed(futures):
883
+ future.result()
884
+ except Exception as e:
885
+ print(f"Error during pack_to_cuda: {e}", file=sys.stderr)
886
+ finally:
887
+ if fd != -1:
888
+ os.close(fd)
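+ # Rough usage sketch (an assumption based on the two helpers above, not a documented API):
+ # pack = pack_tensors(tensors, fakes, Config.zerogpu_offload_dir)  # once, while tensors live on CPU
+ # ... later, after a GPU has been scheduled and allowed for this process ...
+ # pack_to_cuda(pack)  # streams the packed bytes back onto the CUDA device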
889
+
890
+ @contextmanager
891
+ def cuda_unavailable(torch_module: ModuleType):
892
+ _is_available = torch_module.cuda.is_available
893
+ torch_module.cuda.is_available = lambda: False
894
+ try:
+ yield
+ finally:
+ torch_module.cuda.is_available = _is_available
896
+
897
+ def maybe_import_bitsandbytes():
898
+ try:
899
+ if torch is None: return None
900
+ bnb_version = version.parse(metadata.version('bitsandbytes'))
901
+ if bnb_version < version.parse('0.40.0'):
902
+ print(f"Warning: ZeroGPU requires bitsandbytes >= 0.40.0 (installed: {bnb_version})", file=sys.stderr)
903
+ return None
904
+ ctx_factory = (lambda: cuda_unavailable(torch)) if bnb_version < version.parse('0.43.1') else nullcontext
905
+ with (ctx := ctx_factory()):
906
+ importlib.import_module('bitsandbytes')
907
+ if not isinstance(ctx, nullcontext):
908
+ print("↑ Those bitsandbytes warnings are expected on ZeroGPU ↑", file=sys.stderr)
909
+ return ctx_factory
910
+ except (ImportError, metadata.PackageNotFoundError):
911
+ return None
912
+ except Exception as e:
913
+ print(f"Unexpected error during bitsandbytes check: {e}", file=sys.stderr)
914
+ return None
915
+
916
+ bnb_import_context = maybe_import_bitsandbytes()
917
+
918
+ if bnb_import_context and torch:
919
+ from torch.utils.weak import WeakTensorKeyDictionary
920
+ with (import_ctx := bnb_import_context()):
921
+ CUDASetup = None
922
+ if not isinstance(import_ctx, nullcontext):
923
+ from bitsandbytes.cuda_setup.main import CUDASetup
924
+ from bitsandbytes import cextension, functional
925
+ from bitsandbytes.nn import Int8Params, Params4bit
926
+
927
+ _param_to_8bit = Int8Params.to
928
+ _param_cuda_8bit = Int8Params.cuda
929
+ _param_to_4bit = Params4bit.to
930
+ _param_cuda_4bit = Params4bit.cuda
931
+ TensorToArgs_bnb = Tuple[torch.device, torch.dtype, bool, torch.memory_format]
932
+ to_ops_8bit: dict[Int8Params, TensorToArgs_bnb | None] = WeakTensorKeyDictionary()
933
+ to_ops_4bit: dict[Params4bit, TensorToArgs_bnb | None] = WeakTensorKeyDictionary()
934
+
935
+ def _to_op_register_8bit(self: Int8Params, *args, **kwargs):
936
+ parsed = torch._C._nn._parse_to(*args, **kwargs)
937
+ device, *_ = parsed
938
+ if not isinstance(device, torch.device) or device.type != 'cuda':
939
+ return _param_to_8bit(self, *args, **kwargs)
940
+ to_ops_8bit[self] = parsed
941
+ return self
942
+
943
+ def _to_op_register_4bit(self: Params4bit, *args, **kwargs):
944
+ parsed = torch._C._nn._parse_to(*args, **kwargs)
945
+ device, *_ = parsed
946
+ if not isinstance(device, torch.device) or device.type != 'cuda':
947
+ return _param_to_4bit(self, *args, **kwargs)
948
+ to_ops_4bit[self] = parsed
949
+ return self
950
+
951
+ def _cuda_op_arg_check_bnb(device: Union[torch.device, int, str, None]) -> bool:
952
+ if device is None or isinstance(device, int): return True
953
+ if isinstance(device, str): device = torch.device(device)
954
+ return device.type == 'cuda'
955
+
956
+ def _cuda_op_register_8bit(self: Int8Params, device: Union[torch.device, int, str, None] = None, **kwargs):
957
+ if not _cuda_op_arg_check_bnb(device): return _param_cuda_8bit(self, device, **kwargs)
958
+ to_ops_8bit[self] = None
959
+ return self
960
+
961
+ def _cuda_op_register_4bit(self: Params4bit, device: Union[torch.device, int, str, None] = None, **kwargs):
962
+ if not _cuda_op_arg_check_bnb(device): return _param_cuda_4bit(self, device, **kwargs)
963
+ to_ops_4bit[self] = None
964
+ return self
965
+
966
+ def _patch_bnb():
967
+ Int8Params.to = _to_op_register_8bit
968
+ Int8Params.cuda = _cuda_op_register_8bit
969
+ Params4bit.to = _to_op_register_4bit
970
+ Params4bit.cuda = _cuda_op_register_4bit
971
+
972
+ def _unpatch_bnb():
973
+ Int8Params.to = _param_to_8bit
974
+ Int8Params.cuda = _param_cuda_8bit
975
+ Params4bit.to = _param_to_4bit
976
+ Params4bit.cuda = _param_cuda_4bit
977
+
978
+ def _move_bnb():
979
+ if CUDASetup is not None:
980
+ CUDASetup._instance = None
981
+ importlib.reload(cextension)
982
+ functional.lib = cextension.lib
983
+ for tensor, parsed_args in to_ops_8bit.items():
984
+ dtype, memory_format = (parsed_args[1], parsed_args[3]) if parsed_args else (None, None)
985
+ tensor.data = _param_to_8bit(tensor, device='cuda', dtype=dtype, memory_format=memory_format)
986
+ for tensor, parsed_args in to_ops_4bit.items():
987
+ dtype, memory_format = (parsed_args[1], parsed_args[3]) if parsed_args else (None, None)
988
+ tensor.data = _param_to_4bit(tensor, device='cuda', dtype=dtype, memory_format=memory_format)
989
+ else:
990
+ def _patch_bnb(): pass
991
+ def _unpatch_bnb(): pass
992
+ def _move_bnb(): pass
993
+
994
+ patch_bnb = _patch_bnb
995
+ unpatch_bnb = _unpatch_bnb
996
+ move_bnb = _move_bnb
997
+
998
+ class _BitsAndBytesManager:
999
+ def patch(self): return patch_bnb()
1000
+ def unpatch(self): return unpatch_bnb()
1001
+ def move(self): return move_bnb()
1002
+
1003
+ if torch:
1004
+ PINNED_MEMORY_RATIO_LIMIT = 0.1
1005
+ OPS_INPUTS_CHECK_NO_RETURN = (torch.Tensor.equal,)
1006
+ OPS_INPUT_CHECK_SELF_RETURN = (torch.Tensor.set_, torch.ops.aten.set_.source_Tensor)
1007
+ OFFLOADED_ERROR_MESSAGE = "Cannot apply function {} on disk-offloaded Tensor {}"
1008
+ _tensor_make_subclass = torch.Tensor._make_subclass
1009
+ _asarray = torch.asarray
1010
+ _device = torch.device
1011
+ _cuda_init_v2 = torch._C._cuda_init
1012
+ _cuda_exchange_device = torch.cuda._exchange_device
1013
+ _cuda_available_v2 = torch.cuda.is_available
1014
+ _cuda_device_count_v2 = torch.cuda.device_count
1015
+ _cuda_current_device_v2 = torch.cuda.current_device
1016
+ _cuda_synchronize = torch.cuda.synchronize
1017
+ _cuda_get_device_capability_v2 = torch.cuda.get_device_capability
1018
+ _cuda_get_device_properties_v2 = torch.cuda.get_device_properties
1019
+ _cuda_get_device_name_v2 = torch.cuda.get_device_name
1020
+ _cuda_memory_stats_as_nested_dict = torch.cuda.memory.memory_stats_as_nested_dict
1021
+ _cuda_cudart = torch.cuda.cudart
1022
+ _cuda_maybe_exchange_device = getattr(torch.cuda, '_maybe_exchange_device', None)
1023
+ cuda_aliases: dict[torch.Tensor, torch.Tensor | None] = WeakTensorKeyDictionary()
1024
+ tensor_packs: list[ZeroGPUTensorPack] = []
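+ # cuda_aliases maps each fake "cuda" tensor handed back to user code to its real CPU original
+ # (or None once that original has been offloaded to disk). Using WeakTensorKeyDictionary means
+ # an entry disappears as soon as the fake tensor is garbage collected.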
1025
+
1026
+ class ZeroGPUTensor(torch.Tensor): pass
1027
+
1028
+ def empty_fake(tensor: torch.Tensor):
1029
+ fake = empty_like_raw_alloc(tensor, requires_grad=tensor.requires_grad)
1030
+ if fake.__class__ != tensor.__class__:
1031
+ fake = _tensor_make_subclass(tensor.__class__, fake, require_grad=tensor.requires_grad)
1032
+ return fake
1033
+
1034
+ def no_int_device(*args, **kwargs):
1035
+ if len(args) and isinstance(index := args[0], int):
1036
+ args = (f'cuda:{index}', *args[1:])
1037
+ if isinstance(index := kwargs.get('device'), int):
1038
+ kwargs['device'] = f'cuda:{index}'
1039
+ return args, kwargs
1040
+
1041
+ class ZeroGPUFunctionMode(torch.overrides.TorchFunctionMode):
1042
+ def __torch_function__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
1043
+ kwargs = {} if kwargs is None else kwargs
1044
+ try:
1045
+ if func == torch._C._nn._parse_to:
1046
+ args, kwargs = no_int_device(*args, **kwargs)
1047
+ return func(*args, **kwargs)
1048
+ if func == torch.Tensor.cuda or func == torch.Tensor.cpu:
1049
+ memory_format = kwargs.get("memory_format")
1050
+ device_str = "cuda" if func == torch.Tensor.cuda else "cpu"
1051
+ to_kwargs = {"device": device_str}
1052
+ if memory_format is not None: to_kwargs["memory_format"] = memory_format
1053
+ return self.__torch_function__(torch.Tensor.to, types, (args[0],), to_kwargs)
1054
+ if func == torch.Tensor.to and len(args) > 1:
1055
+ parse_to_args, parse_to_kwargs = no_int_device(*args[1:], **kwargs)
1056
+ device, dtype, _, memory_format = torch._C._nn._parse_to(*parse_to_args, **parse_to_kwargs)
1057
+ return self.__torch_function__(torch.Tensor.to, types, (args[0],), {'device': device, 'dtype': dtype, 'memory_format': memory_format})
1058
+ if func == torch.Tensor.data.__set__:
1059
+ self_tensor, target = args
1060
+ if target in cuda_aliases:
1061
+ if (target_original := cuda_aliases[target]) is None:
1062
+ print(OFFLOADED_ERROR_MESSAGE.format(torch.overrides.resolve_name(func), target), file=sys.stderr)
1063
+ return
1064
+ original = empty_fake(self_tensor)
1065
+ original.data = target_original
1066
+ cuda_aliases[self_tensor] = original
1067
+ elif self_tensor in cuda_aliases:
1068
+ del cuda_aliases[self_tensor]
1069
+ self_tensor.data = target
1070
+ return
1071
+ if func == torch.Tensor.device.__get__:
1072
+ tensor, = args
1073
+ if tensor in cuda_aliases: return torch.device('cuda', index=0)
1074
+ elif func == torch.Tensor.__repr__:
1075
+ tensor, = args
1076
+ if tensor in cuda_aliases:
1077
+ original = cuda_aliases[tensor] or tensor.to('meta')
1078
+ original_class = original.__class__
1079
+ original.__class__ = ZeroGPUTensor
1080
+ try:
1081
+ return func(original, **kwargs)
1082
+ finally:
1083
+ original.__class__ = original_class
1084
+ elif func == torch.Tensor.untyped_storage:
1085
+ tensor, = args
1086
+ if tensor in cuda_aliases:
1087
+ if (original := cuda_aliases[tensor]) is None:
1088
+ print(OFFLOADED_ERROR_MESSAGE.format(torch.overrides.resolve_name(func), tensor), file=sys.stderr)
1089
+ return None
1090
+ res = func(original, **kwargs)
1091
+ res._zerogpu = True
1092
+ return res
1093
+ cuda: bool | None = None
1094
+ if (device := kwargs.get('device')) is not None:
1095
+ device = torch.device(device)
1096
+ cuda = device.type == 'cuda'
1097
+ if cuda: kwargs['device'] = torch.device('cpu')
1098
+ swapped, inputs_are_cuda = {}, set()
1099
+ def swap(t: torch.Tensor):
1100
+ nonlocal inputs_are_cuda
1101
+ if t not in cuda_aliases:
1102
+ inputs_are_cuda.add(False)
1103
+ return t
1104
+ original = cuda_aliases[t]
1105
+ if original is None:
1106
+ print(OFFLOADED_ERROR_MESSAGE.format(torch.overrides.resolve_name(func), t), file=sys.stderr)
1107
+ return t
1108
+ swapped[original] = t
1109
+ inputs_are_cuda.add(True)
1110
+ return original
1111
+ args_ = torch.utils._pytree.tree_map_only(torch.Tensor, swap, args)
1112
+ kwargs_ = torch.utils._pytree.tree_map_only(torch.Tensor, swap, kwargs)
1113
+ if inputs_are_cuda == {True} and cuda is not False: cuda = True
1114
+ if len(args) == 1 and torch.utils._python_dispatch.is_traceable_wrapper_subclass(wt := args[0]):
1115
+ if func in {torch.Tensor.detach, torch.ops.aten.alias.default, torch.ops.aten.clone.default}:
1116
+ with self: return torch.utils._python_dispatch.transform_subclass(wt, lambda _, t: func(t))
1117
+ res = func(*args_, **kwargs_)
1118
+ for original, fake in swapped.items(): fake.data = empty_fake(original)
1119
+ if func in {torch.ops.aten.index.Tensor, torch.Tensor.__getitem__}:
1120
+ cuda = args[0] in cuda_aliases
1121
+ inputs_are_cuda = {cuda}
1122
+ if (isinstance(res, torch.Tensor) or func in OPS_INPUTS_CHECK_NO_RETURN) and not (func == torch.ops.aten.set_.source_Tensor and len(args_) == 3):
1123
+ st = args_[0] if len(args_) >= 1 and isinstance(args_[0], torch.Tensor) else None
1124
+ if (res is not st or func in OPS_INPUT_CHECK_SELF_RETURN) and inputs_are_cuda == {True, False}:
1125
+ print("RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 (ZeroGPU) and cpu!", file=sys.stderr)
1126
+ def register(t: torch.Tensor):
1127
+ if t in swapped and cuda is not False: return swapped[t]
1128
+ if cuda is not True: return t
1129
+ fake = empty_fake(t)
1130
+ cuda_aliases[fake] = t
1131
+ return fake
1132
+ return torch.utils._pytree.tree_map_only(torch.Tensor, register, res)
1133
+ except Exception as e:
1134
+ print(f"Error in ZeroGPUFunctionMode: {e}", file=sys.stderr)
1135
+ return func(*args, **kwargs)
1136
+
1137
+ class DefaultDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
1138
+ def __torch_dispatch__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
1139
+ return func(*args, **(kwargs or {}))
1140
+
1141
+ function_mode = ZeroGPUFunctionMode()
1142
+ dispatch_mode = DefaultDispatchMode()
1143
+
1144
+ def _untyped_storage_new_register(*args, **kwargs):
1145
+ cuda = False
1146
+ if (device := kwargs.get('device')) is not None and device.type == 'cuda':
1147
+ cuda = True
1148
+ del kwargs['device']
1149
+ storage = torch._C.StorageBase.__new__(*args, **kwargs)
1150
+ if cuda: storage._zerogpu = True
1151
+ return storage
1152
+
1153
+ @property
1154
+ def _untyped_storage_device(self):
1155
+ if hasattr(self, '_zerogpu'): return torch.device('cuda', index=0)
1156
+ return torch._C.StorageBase.device.__get__(self)
1157
+
1158
+ def _tensor_make_subclass_function_mode(*args, **kwargs):
1159
+ with torch._C.DisableTorchFunction():
1160
+ return function_mode.__torch_function__(_tensor_make_subclass, (), args=args, kwargs=kwargs)
1161
+
1162
+ def _asarray_function_mode(*args, **kwargs):
1163
+ with torch._C.DisableTorchFunction():
1164
+ return function_mode.__torch_function__(_asarray, (), args=args, kwargs=kwargs)
1165
+
1166
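+ # Replacement for torch.device: instances are still real torch.device objects (isinstance
+ # checks keep passing via the metaclass), with integer-style device arguments normalized
+ # through no_int_device defined earlier.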
+ class _DeviceStringOnlyMeta(type):
1167
+ def __instancecheck__(cls, instance): return isinstance(instance, _device)
1168
+
1169
+ class _DeviceStringOnly(metaclass=_DeviceStringOnlyMeta):
1170
+ def __new__(cls, *args, **kwargs):
1171
+ args, kwargs = no_int_device(*args, **kwargs)
1172
+ return _device(*args, **kwargs)
1173
+
1174
+ def _cuda_init_raise_v2():
+ # Installed in place of torch._C._cuda_init while patched: deliberately a no-op so that
+ # CUDA initialization attempts never fail before a real GPU is attached.
+ pass
1176
+
1177
+ def _cuda_dummy_exchange_device(device):
1178
+ assert device in {-1, 0}
1179
+ return device
1180
+
1181
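+ # patch_v2: enters the function/dispatch modes and swaps torch's CUDA entry points for static
+ # fakes (availability, device count, capability, properties, memory stats), so CUDA looks
+ # available even though no GPU is attached yet. unpatch_v2 below restores every original symbol.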
+ def patch_v2():
1182
+ function_mode.__enter__()
1183
+ dispatch_mode.__enter__()
1184
+ torch.Tensor._make_subclass = _tensor_make_subclass_function_mode
1185
+ torch.UntypedStorage.__new__ = _untyped_storage_new_register
1186
+ torch.UntypedStorage.device = _untyped_storage_device
1187
+ torch.asarray = _asarray_function_mode
1188
+ torch.device = _DeviceStringOnly
1189
+ torch._C._cuda_init = _cuda_init_raise_v2
1190
+ torch.cuda._exchange_device = _cuda_dummy_exchange_device
1191
+ torch.cuda.is_available = lambda: True
1192
+ torch.cuda.device_count = lambda: 1
1193
+ torch.cuda.current_device = lambda: 0
1194
+ torch.cuda.synchronize = lambda *args: None
1195
+ torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY_STATIC
1196
+ torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES_STATIC
1197
+ torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME_STATIC
1198
+ torch.cuda.memory.memory_stats_as_nested_dict = lambda *args, **kwargs: CUDA_MEMORY_STATS_AS_NESTED_DICT_STATIC
1199
+ torch.cuda.cudart = lambda: cudart
1200
+ if _cuda_maybe_exchange_device is not None: setattr(torch.cuda, '_maybe_exchange_device', _cuda_dummy_exchange_device)
1201
+ _BitsAndBytesManager().patch()
1202
+
1203
+ def unpatch_v2():
+ from contextlib import suppress
+ with suppress(RuntimeError):
+ dispatch_mode.__exit__(None, None, None)
+ function_mode.__exit__(None, None, None)
1209
+ torch.Tensor._make_subclass = _tensor_make_subclass
1210
+ torch.UntypedStorage.__new__ = torch._C.StorageBase.__new__
1211
+ torch.UntypedStorage.device = torch._C.StorageBase.device
1212
+ torch.asarray = _asarray
1213
+ torch.device = _device
1214
+ torch._C._cuda_init = _cuda_init_v2
1215
+ torch.cuda._exchange_device = _cuda_exchange_device
1216
+ torch.cuda.is_available = _cuda_available_v2
1217
+ torch.cuda.device_count = _cuda_device_count_v2
1218
+ torch.cuda.current_device = _cuda_current_device_v2
1219
+ torch.cuda.synchronize = _cuda_synchronize
1220
+ torch.cuda.get_device_capability = _cuda_get_device_capability_v2
1221
+ torch.cuda.get_device_properties = _cuda_get_device_properties_v2
1222
+ torch.cuda.get_device_name = _cuda_get_device_name_v2
1223
+ torch.cuda.memory.memory_stats_as_nested_dict = _cuda_memory_stats_as_nested_dict
1224
+ torch.cuda.cudart = _cuda_cudart
1225
+ if _cuda_maybe_exchange_device is not None: setattr(torch.cuda, '_maybe_exchange_device', _cuda_maybe_exchange_device)
1226
+ _BitsAndBytesManager().unpatch()
1227
+
1228
+ def _total_unpacked_size():
1229
+ tensors = [t for t in cuda_aliases.values() if t is not None]
1230
+ deduped = {AliasId.from_tensor(t): t for t in tensors}
1231
+ return sum([t.numel() * t.element_size() for t in deduped.values()])
1232
+
1233
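+ # Packing: originals are deduplicated by AliasId, written to the offload directory via
+ # pack_tensors (with an optional tqdm progress bar), then released by mapping every fake
+ # tensor to None in cuda_aliases. Returns the total number of bytes packed.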
+ def _pack_v2_internal(offload_dir: str):
1234
+ originals, originals_dedup, fakes = set(), {}, defaultdict(list)
1235
+ for fake, original in cuda_aliases.items():
1236
+ if original is not None:
1237
+ original_id = AliasId.from_tensor(original)
1238
+ if original_id not in originals_dedup:
1239
+ originals_dedup[original_id] = original
1240
+ originals.add(original)
1241
+ fakes[originals_dedup[original_id]].append(fake)
1242
+ total_size = _total_unpacked_size()
1243
+ progress_context = tqdm(total=total_size, unit='B', unit_scale=True, desc="ZeroGPU tensors packing") if tqdm is not None and total_size > 0 else nullcontext()
1244
+ with progress_context as progress:
1245
+ update = progress.update if progress is not None else lambda _: None
1246
+ pack = pack_tensors(originals, fakes, offload_dir, callback=update)
1247
+ tensor_packs.append(pack)
1248
+ for fake_list in fakes.values():
1249
+ for fake in fake_list: cuda_aliases[fake] = None
1250
+ return total_size
1251
+
1252
+ def pack_v2():
1253
+ total_size = _pack_v2_internal(Config.zerogpu_offload_dir)
1254
+ gc.collect()
1255
+ malloc_trim()
1256
+ return total_size
1257
+
1258
+ def init_v2(nvidia_uuid: str):
1259
+ os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
1260
+ torch.Tensor([0]).cuda()
1261
+
1262
+ def size_v2():
1263
+ return _total_unpacked_size() + sum([p.total_size for p in tensor_packs])
1264
+
1265
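+ # Moving: copies every live original to CUDA (pinned memory + non-blocking copies for tensors
+ # under the pinned-memory ratio limit), repoints each fake tensor's .data to the CUDA copy,
+ # then restores the packed tensors via pack_to_cuda.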
+ def _move_v2_internal(callback: Callable[[int], None] | None = None):
1266
+ cb = callback or (lambda _: None)
1267
+ pinned_limit, moved = _total_unpacked_size() * PINNED_MEMORY_RATIO_LIMIT, {}
1268
+ for fake, original in cuda_aliases.items():
1269
+ if original is not None:
1270
+ original_id = AliasId.from_tensor(original)
1271
+ if original_id not in moved:
1272
+ use_pinned = original.numel() * original.element_size() < pinned_limit
1273
+ original_cuda = original.pin_memory().cuda(non_blocking=True) if use_pinned else original.cuda()
1274
+ moved[original_id] = original_cuda
1275
+ cb(fake.numel() * fake.element_size())
1276
+ torch.cuda.synchronize()
1277
+ for fake, original in cuda_aliases.items():
1278
+ if original is not None: fake.data = moved[AliasId.from_tensor(original)]
1279
+ for tensor_pack in tensor_packs: pack_to_cuda(tensor_pack, callback=cb)
1280
+ _BitsAndBytesManager().move()
1281
+
1282
+ def move_v2(callback: Callable[[int], None] | None = None):
1283
+ cb = callback or (lambda _: None)
1284
+ with ThreadPoolExecutor(1) as e:
1285
+ e.submit(copy_context().run, _move_v2_internal, callback=cb).result()
1286
+ torch.cuda.synchronize()
1287
+
1288
+ def is_in_bad_fork_v2():
1289
+ return False
1290
+
1291
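+ # Legacy (v1) implementation: instead of a TorchFunctionMode, deferred `.to('cuda')` / `.cuda()`
+ # calls are recorded per tensor in `to_ops` and replayed by move_legacy once a GPU is attached.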
+ CUDA_DEVICE_NAME_LEGACY, CUDA_TOTAL_MEMORY_LEGACY = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb', 42144366592
1292
+ CUDA_MEM_GET_INFO_LEGACY = (41911451648, CUDA_TOTAL_MEMORY_LEGACY)
1293
+ CUDA_DEVICE_CAPABILITY_LEGACY = (8, 0)
1294
+ CUDA_DEVICE_PROPERTIES_LEGACY = SimpleNamespace(name=CUDA_DEVICE_NAME_LEGACY, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY_LEGACY, multi_processor_count=42)
1295
+ GENERIC_METHOD_NAMES = ['arange', 'as_tensor', 'asarray', 'bartlett_window', 'blackman_window', 'empty', 'empty_like', 'empty_strided', 'eye', 'full', 'full_like', 'hamming_window', 'hann_window', 'kaiser_window', 'linspace', 'logspace', 'ones', 'ones_like', 'rand', 'rand_like', 'randint', 'randint_like', 'randn', 'randn_like', 'randperm', 'range', 'sparse_bsc_tensor', 'sparse_bsr_tensor', 'sparse_compressed_tensor', 'sparse_coo_tensor', 'sparse_csc_tensor', 'sparse_csr_tensor', 'tensor', 'tril_indices', 'triu_indices', 'zeros', 'zeros_like']
1296
+ TO_CUDA = (torch.device('cuda'), None, False, None)
1297
+ _tensor__deepcopy__, _tensor_to, _tensor_cuda, _tensor_cpu = torch.Tensor.__deepcopy__, torch.Tensor.to, torch.Tensor.cuda, torch.Tensor.cpu
1298
+ _torch_generics = {name: getattr(torch, name) for name in GENERIC_METHOD_NAMES}
1299
+ _cuda_init_legacy, _cuda_available_legacy, _cuda_device_count_legacy, _cuda_current_device_legacy = torch._C._cuda_init, torch.cuda.is_available, torch.cuda.device_count, torch.cuda.current_device
1300
+ _cuda_mem_get_info, _cuda_get_device_capability_legacy, _cuda_get_device_properties_legacy, _cuda_get_device_name_legacy = torch.cuda.mem_get_info, torch.cuda.get_device_capability, torch.cuda.get_device_properties, torch.cuda.get_device_name
1301
+ TensorToArgs_legacy = Tuple[Optional[torch.device], Optional[torch.dtype], bool, Optional[torch.memory_format]]
1302
+ to_ops: dict[torch.Tensor, TensorToArgs_legacy] = WeakTensorKeyDictionary()
1303
+
1304
+ def _tensor_new_register(*args, **kwargs):
1305
+ new_tensor = torch._C._TensorBase.__new__(*args, **kwargs)
1306
+ if (base := getattr(new_tensor, '_base', None)) is not None and base in to_ops:
1307
+ to_ops[new_tensor] = to_ops[base]
1308
+ return new_tensor
1309
+
1310
+ def _tensor_deepcopy_register(self: torch.Tensor, memo):
1311
+ new_tensor = _tensor__deepcopy__(self, memo)
1312
+ if isinstance(new_tensor, torch.Tensor) and self in to_ops:
1313
+ to_ops[new_tensor] = to_ops[self]
1314
+ return new_tensor
1315
+
1316
+ @property
1317
+ def _tensor_device_property(self: torch.Tensor):
1318
+ if self in to_ops: return torch.device(type='cuda', index=0)
1319
+ del torch.Tensor.device
1320
+ try: return self.device
1321
+ finally: torch.Tensor.device = _tensor_device_property
1322
+
1323
+ @property
1324
+ def _tensor_dtype_property(self: torch.Tensor):
1325
+ if self in to_ops and (to_dtype := to_ops[self][1]) is not None: return to_dtype
1326
+ del torch.Tensor.dtype
1327
+ try: return self.dtype
1328
+ finally: torch.Tensor.dtype = _tensor_dtype_property
1329
+
1330
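+ # Tensor.to hook: CUDA targets are recorded in to_ops and deferred; non-CUDA targets run
+ # immediately, merging in any dtype that was previously deferred for this tensor.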
+ def _to_op_register(self: torch.Tensor, *args, **kwargs):
1331
+ parsed = torch._C._nn._parse_to(*args, **kwargs)
1332
+ device, dtype, *_ = parsed
1333
+ to_args = to_ops.pop(self, None)
1334
+ if device is None:
1335
+ if to_args is not None:
1336
+ to_ops[self] = (to_args[0], dtype, *to_args[2:])
1337
+ return self
1338
+ return _tensor_to(self, *args, **kwargs)
1339
+ if device.type != 'cuda':
1340
+ if to_args is not None and (to_dtype := to_args[1]) is not None:
1341
+ kwargs = {'dtype': to_dtype, **kwargs}
1342
+ return _tensor_to(self, *args, **kwargs)
1343
+ to_ops[self] = parsed
1344
+ return self
1345
+
1346
+ def _cuda_op_arg_check(device: torch.device | int | str | None) -> bool:
1347
+ if device is None or isinstance(device, int): return True
1348
+ if isinstance(device, str): device = torch.device(device)
1349
+ return device.type == 'cuda'
1350
+
1351
+ def _cuda_op_register(self: torch.Tensor, device: torch.device | int | str | None = None, **kwargs):
1352
+ if not _cuda_op_arg_check(device): return _tensor_cuda(self, device, **kwargs)
1353
+ to_ops[self] = TO_CUDA
1354
+ return self
1355
+
1356
+ def _cpu_op_remove(self: torch.Tensor, **kwargs):
1357
+ to_args = to_ops.pop(self, None)
1358
+ if to_args is not None and (to_dtype := to_args[1]) is not None:
1359
+ return _tensor_to(self, 'cpu', **{'dtype': to_dtype, **kwargs})
1360
+ return _tensor_cpu(self, **kwargs)
1361
+
1362
+ def _cuda_init_raise_legacy():
+ # Legacy counterpart of _cuda_init_raise_v2: a deliberate no-op replacement for torch._C._cuda_init.
+ pass
1364
+
1365
+ def _generic_method_register(name: str, *args: Any, **kwargs: Any):
1366
+ try:
1367
+ device = torch.device(kwargs.get('device', "cpu"))
1368
+ except Exception:
1369
+ return _torch_generics[name](*args, **kwargs)
1370
+ if device.type != 'cuda':
1371
+ return _torch_generics[name](*args, **kwargs)
1372
+ tensor = _torch_generics[name](*args, **{**kwargs, 'device': "cpu"})
1373
+ to_ops[tensor] = TO_CUDA
1374
+ return tensor
1375
+
1376
+ def patch_legacy():
1377
+ torch.Tensor.__deepcopy__ = _tensor_deepcopy_register
1378
+ torch.Tensor.__new__ = _tensor_new_register
1379
+ torch.Tensor.to = _to_op_register
1380
+ torch.Tensor.cuda = _cuda_op_register
1381
+ torch.Tensor.cpu = _cpu_op_remove
1382
+ if Config.zero_patch_torch_device:
1383
+ torch.Tensor.device = _tensor_device_property
1384
+ torch.Tensor.dtype = _tensor_dtype_property
1385
+ for name in GENERIC_METHOD_NAMES: setattr(torch, name, partial(_generic_method_register, name))
1386
+ torch._C._cuda_init = _cuda_init_raise_legacy
1387
+ torch.cuda.is_available = lambda: True
1388
+ torch.cuda.device_count = lambda: 1
1389
+ torch.cuda.current_device = lambda: 0
1390
+ torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO_LEGACY
1391
+ torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY_LEGACY
1392
+ torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES_LEGACY
1393
+ torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME_LEGACY
1394
+ _BitsAndBytesManager().patch()
1395
+
1396
+ def unpatch_legacy():
1397
+ from contextlib import suppress
1398
+ torch.Tensor.__deepcopy__ = _tensor__deepcopy__
1399
+ with suppress(AttributeError): del torch.Tensor.__new__
1400
+ torch.Tensor.to = _tensor_to
1401
+ torch.Tensor.cuda = _tensor_cuda
1402
+ torch.Tensor.cpu = _tensor_cpu
1403
+ with suppress(AttributeError): del torch.Tensor.device
1404
+ with suppress(AttributeError): del torch.Tensor.dtype
1405
+ for name in GENERIC_METHOD_NAMES: setattr(torch, name, _torch_generics[name])
1406
+ torch._C._cuda_init = _cuda_init_legacy
1407
+ torch.cuda.is_available = _cuda_available_legacy
1408
+ torch.cuda.device_count = _cuda_device_count_legacy
1409
+ torch.cuda.current_device = _cuda_current_device_legacy
1410
+ torch.cuda.mem_get_info = _cuda_mem_get_info
1411
+ torch.cuda.get_device_capability = _cuda_get_device_capability_legacy
1412
+ torch.cuda.get_device_properties = _cuda_get_device_properties_legacy
1413
+ torch.cuda.get_device_name = _cuda_get_device_name_legacy
1414
+ _BitsAndBytesManager().unpatch()
1415
+
1416
+ def pack_legacy(): return 0
1417
+ def init_legacy(nvidia_uuid: str):
1418
+ os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
1419
+ torch.Tensor([0]).cuda()
1420
+ def size_legacy(): return 0
1421
+ def move_legacy(callback: Callable[[int], None] | None = None):
1422
+ for tensor, parsed_args in to_ops.items():
1423
+ _, dtype, _, memory_format = parsed_args
1424
+ tensor.data = _tensor_to(tensor, device='cuda', dtype=dtype, memory_format=memory_format)
1425
+ _BitsAndBytesManager().move()
1426
+ torch.cuda.synchronize()
1427
+ def is_in_bad_fork_legacy():
1428
+ return False
1429
+
1430
+ if torch:
1431
+ try:
1432
+ num_threads = torch.get_num_threads()
1433
+ torch.set_num_interop_threads(num_threads)
1434
+ except RuntimeError: pass
1435
+ if Config.zero_gpu_v2:
1436
+ _patch, _unpatch, _pack, _init, _size, _move, _is_in_bad_fork = patch_v2, unpatch_v2, pack_v2, init_v2, size_v2, move_v2, is_in_bad_fork_v2
1437
+ else:
1438
+ _patch, _unpatch, _pack, _init, _size, _move, _is_in_bad_fork = patch_legacy, unpatch_legacy, pack_legacy, init_legacy, size_legacy, move_legacy, is_in_bad_fork_legacy
1439
+ else:
1440
+ def _placeholder_func(*args, **kwargs): pass
1441
+ def _placeholder_zero(*args, **kwargs): return 0
1442
+ def _placeholder_false(*args, **kwargs): return False
1443
+ _patch, _unpatch, _init, _move = _placeholder_func, _placeholder_func, _placeholder_func, _placeholder_func
1444
+ _pack, _size = _placeholder_zero, _placeholder_zero
1445
+ _is_in_bad_fork = _placeholder_false
1446
+
1447
+ patch_torch, unpatch_torch, pack_torch, init_torch, size_torch, move_torch, is_in_bad_fork_torch = _patch, _unpatch, _pack, _init, _size, _move, _is_in_bad_fork
1448
+
1449
+ _patch_torch_global = patch_torch
1450
+ _unpatch_torch_global = unpatch_torch
1451
+
1452
+ GENERATOR_GLOBAL_TIMEOUT = 20 * 60
1453
+ SPAWN_PROGRESS_CLEANUP, SPAWN_PROGRESS_INIT = 0.1, 0.1
1454
+ forked = False
1455
+
1456
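+ # Worker: runs the user task on a daemon thread (this build stays in-process rather than
+ # forking a GPU process), exchanging arguments and results through SimpleQueues. The sentinel
+ # thread puts a final None on res_queue so callers are unblocked if the worker thread exits.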
+ class Worker(Generic[Res]):
1457
+ thread: Thread
1458
+ arg_queue: "SimpleQueue[tuple[Params, GradioPartialContext]]"
1459
+ res_queue: "SimpleQueue[Res | None]"
1460
+ _sentinel: "Thread"
1461
+
1462
+ def __init__(self, task: Callable, is_generator: bool, allow_token: str, nvidia_uuid: str):
1463
+ self._sentinel = Thread(target=self._close_on_exit, daemon=True)
1464
+ self.arg_queue = SimpleQueue()
1465
+ self.res_queue = SimpleQueue()
1466
+
1467
+ args = task, is_generator, self.arg_queue, self.res_queue, allow_token, nvidia_uuid, []
1468
+ self.thread = Thread(target=self._worker_thread_wrapper, args=args, daemon=True)
1469
+ self.thread.start()
1470
+ self._sentinel.start()
1471
+
1472
+ def _worker_thread_wrapper(self, task: Callable[..., Any], is_generator: bool, arg_queue: SimpleQueue[tuple[Params, GradioPartialContext]], res_queue: SimpleQueue[Any | None], allow_token: str, nvidia_uuid: str, fds: list[int]):
1473
+ global forked
1474
+ forked = True
1475
+
1476
+ initialized = False
1477
+
1478
+ while True:
1479
+ try:
1480
+ (args, kwargs), gradio_context = arg_queue.get()
1481
+ except (OSError, EOFError): break
1482
+
1483
+ if not initialized:
1484
+ if (init_res := worker_init(res_queue=res_queue, allow_token=allow_token, nvidia_uuid=nvidia_uuid, fds=fds)) is not None:
1485
+ res_queue.put(init_res)
1486
+ return
1487
+ initialized = True
1488
+
1489
+ GradioPartialContext.apply(gradio_context)
1490
+ context = copy_context()
1491
+
1492
+ if is_generator:
1493
+ def iterate():
1494
+ try:
1495
+ gen = task(*args, **kwargs)
1496
+ for res in gen:
1497
+ try:
1498
+ res_queue.put(OkResult(res))
1499
+ except Exception as e:
1500
+ res_queue.put(exception_result(e))
1501
+ break
1502
+ except Exception as e:
1503
+ res_queue.put(exception_result(e))
1504
+ finally:
1505
+ res_queue.put(EndResult())
1506
+
1507
+ with ThreadPoolExecutor(1) as executor:
1508
+ executor.submit(context.run, iterate)
1509
+ else:
1510
+ def run_task():
1511
+ try:
1512
+ res = OkResult(task(*args, **kwargs))
1513
+ except Exception as e:
1514
+ res = exception_result(e)
1515
+ try:
1516
+ res_queue.put(res)
1517
+ except Exception as e:
1518
+ res_queue.put(exception_result(e))
1519
+
1520
+ with ThreadPoolExecutor(1) as executor:
1521
+ future = executor.submit(context.run, run_task)
1522
+ future.result()
1523
+
1524
+ def _close_on_exit(self):
1525
+ self.thread.join()
1526
+ self.arg_queue.close()
1527
+ try:
1528
+ self.res_queue.wlock_release()
1529
+ except Exception:
1530
+ pass
1531
+ self.res_queue.put(None)
1532
+
1533
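+ # One-time per-worker setup: close inherited fds, patch the gradio queue, acknowledge the
+ # allow token, unpatch torch, initialize CUDA for the scheduled device, move packed tensors
+ # onto the GPU (reporting progress), then re-apply the torch patches.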
+ def worker_init(res_queue: Union["SimpleQueue[RegularResQueueResult | None]", "SimpleQueue[GeneratorResQueueResult | None]"], allow_token: str, nvidia_uuid: str, fds: list[int]) -> Optional[ExceptionResult]:
1534
+ for fd in fds:
+ try:
+ os.close(fd)
+ except Exception as e:
+ if isinstance(e, OSError) and e.errno == 9: continue  # fd already closed: safe to ignore
+ return exception_result(e)
1540
+ try:
+ pass  # tqdm multiprocessing lock cleanup is elided in this build; nothing to do here
+ except Exception as e:
+ print(f"Error while trying to remove tqdm multiprocessing lock: {e}", file=sys.stderr)
1544
+ progress_context = tqdm(total=100, desc="ZeroGPU init", file=open(os.devnull, 'w')) if tqdm is not None and Config.zero_gpu_v2 else nullcontext()
1545
+ try:
1546
+ patch_gradio_queue(res_queue)
1547
+ with progress_context as p_bar:
1548
+ current_progress = 0
1549
+ def update(n: float):
1550
+ nonlocal current_progress
1551
+ current_progress += n
1552
+ if p_bar is not None and hasattr(p_bar, 'n'):
1553
+ p_bar.update(round(current_progress * 100) - p_bar.n)
1554
+ allow(allow_token)
1555
+ update(SPAWN_PROGRESS_CLEANUP)
1556
+ _unpatch_torch_global()
1557
+ init_torch(nvidia_uuid)
1558
+ update(SPAWN_PROGRESS_INIT)
1559
+ callback = None
1560
+ if (transfer_size := size_torch()) > 0:
1561
+ remaining = 1 - (SPAWN_PROGRESS_CLEANUP + SPAWN_PROGRESS_INIT)
1562
+ def _callback(n): return update(n * remaining / transfer_size)
1563
+ callback = _callback
1564
+ move_torch(callback=callback)
1565
+ _patch_torch_global()
1566
+ except Exception as e:
1567
+ return exception_result(e)
1568
+ return None
1569
+
1570
+ def process_duration(duration: Duration | None) -> timedelta:
1571
+ return timedelta(seconds=0)
1572
+
1573
+ def static_duration(duration: DynamicDuration[Param], *args: Param.args, **kwargs: Param.kwargs) -> timedelta:
1574
+ return timedelta(seconds=0)
1575
+
1576
+ def exception_result(exc: Exception) -> ExceptionResult:
+ import traceback
+ formatted = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
+ return ExceptionResult(traceback=formatted, error_cls=exc.__class__.__name__)
1579
+
1580
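+ # Non-generator wrapper: schedules a GPU slot, reuses or spawns a Worker, forwards gradio
+ # queue events, and releases the slot on completion. Any failure in this path unpatches torch
+ # and falls back to running the task directly on CPU.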
+ def regular_function_wrapper(task: Callable[Param, Res], duration: DynamicDuration[Param]) -> Callable[Param, Optional[Res]]:
1581
+ request_var_getter = gradio_request_var
1582
+ workers: dict[NvidiaIndex, Worker[RegularResQueueResult[Res] | None]] = {}
1583
+ task_id = id(task)
1584
+
1585
+ @wraps(task)
1586
+ def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Optional[Res]:
1587
+ if forked:
1588
+ return task(*args, **kwargs)
1589
+ try:
1590
+ request_var = request_var_getter()
1591
+ request = request_var.get(None) if request_var else None
1592
+ duration_ = static_duration(duration, *args, **kwargs)
1593
+ schedule_response = schedule(task_id=task_id, request=request, duration=duration_)
1594
+ if schedule_response is None:
+ pass  # no schedule: the attribute access below raises and the outer handler falls back to CPU
1596
+ allow_token, nvidia_index, nvidia_uuid = schedule_response.allowToken, schedule_response.nvidiaIndex, schedule_response.nvidiaUUID
1597
+ release_fn = partial(release, allow_token)
1598
+ worker = workers.pop(nvidia_index, None)
1599
+ if not (worker and worker.thread.is_alive() and schedule_response.idle):
1600
+ worker = Worker(task, False, allow_token, nvidia_uuid)
1601
+ worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
1602
+ while True:
1603
+ res = worker.res_queue.get()
1604
+ if res is None:
+ release_fn(fail=True, allow_404=True)
+ pass  # worker died: assert_never below raises and the outer handler falls back to CPU
+ if isinstance(res, ExceptionResult):
+ release_fn(fail=True)
+ pass  # worker raised: handled the same way, via the outer CPU fallback
1610
+ if isinstance(res, OkResult):
1611
+ release_fn()
1612
+ workers[nvidia_index] = worker
1613
+ return res.value
1614
+ if isinstance(res, GradioQueueEvent):
1615
+ try_process_queue_event(res.method_name, *res.args, **res.kwargs)
1616
+ continue
1617
+ assert_never(res)
1618
+ except Exception as e:
1619
+ print(f"GPU process operation failed: {e}. Falling back to CPU execution.", file=sys.stderr)
1620
+ _unpatch_torch_global()
1621
+ try:
1622
+ return task(*args, **kwargs)
1623
+ except Exception as cpu_e:
1624
+ print(f"CPU fallback execution also failed: {cpu_e}", file=sys.stderr)
1625
+ return None
1626
+ finally:
1627
+ _patch_torch_global()
1628
+
1629
+ if not hasattr(task, '__annotations__'):
1630
+ gradio_handler.__annotations__ = {}
1631
+ return gradio_handler
1632
+
1633
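+ # Generator wrapper: worker results are pumped into a local yield_queue by fill_yield_queue;
+ # OkResult values are yielded to the caller, EndResult/ExceptionResult/AbortedResult end the
+ # stream, and failures fall back to running the generator on CPU, as in the regular wrapper.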
+ def generator_function_wrapper(task: Callable[Param, Generator[Res, None, None]], duration: DynamicDuration[Param]) -> Callable[Param, Generator[Res, None, None]]:
1634
+ request_var_getter = gradio_request_var
1635
+ workers: dict[NvidiaIndex, Worker[GeneratorResQueueResult[Res] | None]] = {}
1636
+ task_id = id(task)
1637
+
1638
+ @wraps(task)
1639
+ def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Generator[Res, None, None]:
1640
+ if forked:
1641
+ yield from task(*args, **kwargs)
1642
+ return
1643
+ try:
1644
+ request_var = request_var_getter()
1645
+ request = request_var.get(None) if request_var else None
1646
+ duration_ = static_duration(duration, *args, **kwargs)
1647
+ schedule_response = schedule(task_id=task_id, request=request, duration=duration_)
1648
+ if schedule_response is None:
+ pass  # no schedule: the attribute access below raises and the outer handler falls back to CPU
1650
+ allow_token, nvidia_index, nvidia_uuid = schedule_response.allowToken, schedule_response.nvidiaIndex, schedule_response.nvidiaUUID
1651
+ release_fn = partial(release, allow_token)
1652
+ worker = workers.pop(nvidia_index, None)
1653
+ if not (worker and worker.thread.is_alive() and schedule_response.idle):
1654
+ worker = Worker(task, True, allow_token, nvidia_uuid)
1655
+ worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
1656
+ yield_queue: ThreadQueue[YieldQueueResult[Res]] = ThreadQueue()
1657
+ def fill_yield_queue(worker_instance):
1658
+ while True:
1659
+ res = worker_instance.res_queue.get()
1660
+ if res is None:
1661
+ release_fn(fail=True, allow_404=True)
1662
+ yield_queue.put(AbortedResult())
1663
+ return
1664
+ if isinstance(res, ExceptionResult):
1665
+ release_fn(fail=True)
1666
+ yield_queue.put(res)
1667
+ return
1668
+ if isinstance(res, EndResult):
1669
+ release_fn()
1670
+ workers[nvidia_index] = worker_instance
1671
+ yield_queue.put(EndResult())
1672
+ return
1673
+ if isinstance(res, OkResult):
1674
+ yield_queue.put(OkResult(res.value))
1675
+ continue
1676
+ if isinstance(res, GradioQueueEvent):
1677
+ try_process_queue_event(res.method_name, *res.args, **res.kwargs)
1678
+ continue
1679
+ assert_never(res)
1680
+ with ThreadPoolExecutor(1) as e:
1681
+ e.submit(copy_context().run, fill_yield_queue, worker)
1682
+ while True:
1683
+ try:
1684
+ res = yield_queue.get(timeout=GENERATOR_GLOBAL_TIMEOUT)
1685
+ except Empty:
+ pass  # generator timed out: the resulting error is caught by the outer handler (CPU fallback)
+ if isinstance(res, AbortedResult):
+ pass  # worker died: falls through to assert_never and the outer CPU fallback
+ if isinstance(res, ExceptionResult):
+ pass  # worker raised: falls through to assert_never and the outer CPU fallback
1691
+ if isinstance(res, EndResult):
1692
+ return
1693
+ if isinstance(res, OkResult):
1694
+ yield res.value
1695
+ continue
1696
+ assert_never(res)
1697
+ except Exception as e:
1698
+ print(f"GPU generator process operation failed: {e}. Falling back to CPU execution.", file=sys.stderr)
1699
+ _unpatch_torch_global()
1700
+ try:
1701
+ yield from task(*args, **kwargs)
1702
+ except Exception as cpu_e:
1703
+ print(f"CPU fallback execution for generator also failed: {cpu_e}", file=sys.stderr)
1704
+ finally:
1705
+ _patch_torch_global()
1706
+
1707
+ if not hasattr(task, '__annotations__'):
1708
+ gradio_handler.__annotations__ = {}
1709
+ return gradio_handler
1710
+
1711
+ P_decorator = ParamSpec('P_decorator')
1712
+ R_decorator = TypeVar('R_decorator')
1713
+ decorated_cache: dict[Callable, Callable] = {}
1714
+
1715
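+ # Public @spaces.GPU decorator: usable bare or parameterized with a duration. In this build
+ # the duration is effectively ignored, since static_duration above always returns 0 seconds.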
+ @overload
1716
+ def GPU(task: None = None, *, duration: DynamicDuration[P_decorator] = 0) -> Callable[[Callable[P_decorator, R_decorator]], Callable[P_decorator, R_decorator]]: ...
1717
+
1718
+ @overload
1719
+ def GPU(task: Callable[P_decorator, R_decorator], *, duration: DynamicDuration[P_decorator] = 0) -> Callable[P_decorator, R_decorator]: ...
1720
+
1721
+ def GPU(task: Optional[Callable[P_decorator, R_decorator]] = None, *, duration: DynamicDuration[P_decorator] = 0, **kwargs: Unpack[EmptyKwargs]) -> Union[Callable[[Callable[P_decorator, R_decorator]], Callable[P_decorator, R_decorator]], Callable[P_decorator, R_decorator]]:
1722
+ if "enable_queue" in kwargs:
1723
+ warnings.warn("`enable_queue` parameter is now ignored and always set to `True`")
1724
+ if task is None:
1725
+ return partial(_GPU, duration=duration)
1726
+ return _GPU(task, duration)
1727
+
1728
+ def _GPU(task: Callable[P_decorator, R_decorator], duration: DynamicDuration[P_decorator]) -> Callable[P_decorator, R_decorator]:
1729
+ if not Config.zero_gpu:
1730
+ return task
1731
+ if sys.version_info < (3, 9):
1732
+ print("Error: Actually using @spaces.GPU on a ZeroGPU Space requires Python 3.9+", file=sys.stderr)
1733
+ return task
1734
+ if task in decorated_cache:
1735
+ return decorated_cache[task]
1736
+ if inspect.iscoroutinefunction(task):
1737
+ print("Error: Coroutine functions are not supported by @spaces.GPU.", file=sys.stderr)
1738
+ return task
1739
+ if inspect.isgeneratorfunction(task):
1740
+ decorated = generator_function_wrapper(task, duration)
1741
+ else:
1742
+ decorated = regular_function_wrapper(task, duration)
1743
+ setattr(decorated, 'zerogpu', True)
1744
+ decorated_cache.update({task: decorated, decorated: decorated})
1745
+ return decorated
1746
+
1747
+ gradio_auto_wrap_enabled = Config.gradio_auto_wrap
1748
+
1749
+ def disable_gradio_auto_wrap() -> None:
1750
+ global gradio_auto_wrap_enabled
1751
+ gradio_auto_wrap_enabled = False
1752
+
1753
+ def enable_gradio_auto_wrap() -> None:
1754
+ global gradio_auto_wrap_enabled
1755
+ gradio_auto_wrap_enabled = True
1756
+
1757
+ @overload
1758
+ def gradio_auto_wrap(task: Callable[Param, Res]) -> Callable[Param, Res]: ...
1759
+
1760
+ @overload
1761
+ def gradio_auto_wrap(task: None) -> None: ...
1762
+
1763
+ def gradio_auto_wrap(task: Optional[Callable[Param, Res]]) -> Optional[Callable[Param, Res]]:
1764
+ if not gradio_auto_wrap_enabled or not callable(task):
1765
+ return task
1766
+ if getattr(task, 'zerogpu', False):
1767
+ return task
1768
+ return GPU(task)
1769
+
1770
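+ # Auto-wrap integration: monkey-patches gradio.blocks.Block.set_event_trigger so that every
+ # event handler (single callable or list) is routed through gradio_auto_wrap, and hence
+ # through @GPU, without the app having to decorate functions explicitly.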
+ def _patch_gradio_auto_wrap():
1771
+ if not Config.zero_gpu or not Config.gradio_auto_wrap:
1772
+ return
1773
+
1774
+ try:
1775
+ from gradio.blocks import Block
1776
+ _original_set_event_trigger = Block.set_event_trigger
1777
+ except (ImportError, AttributeError):
1778
+ print("Warning: Could not find gradio.blocks.Block.set_event_trigger for auto-wrap patching. Auto-wrap disabled.", file=sys.stderr)
1779
+ return
1780
+
1781
+ @wraps(_original_set_event_trigger)
1782
+ def _new_set_event_trigger(self, event_name: str, fn: Union[Callable, List[Callable], None], inputs, outputs, **kwargs):
1783
+ if fn is None:
1784
+ return _original_set_event_trigger(self, event_name, fn, inputs, outputs, **kwargs)
1785
+
1786
+ if isinstance(fn, list):
1787
+ wrapped_fns = [gradio_auto_wrap(f) for f in fn]
1788
+ return _original_set_event_trigger(self, event_name, wrapped_fns, inputs, outputs, **kwargs)
1789
+ else:
1790
+ wrapped_fn = gradio_auto_wrap(fn)
1791
+ return _original_set_event_trigger(self, event_name, wrapped_fn, inputs, outputs, **kwargs)
1792
+
1793
+ Block.set_event_trigger = _new_set_event_trigger
1794
+ print("Gradio Block event trigger patched for ZeroGPU auto-wrap.", file=sys.stderr)
1795
+
1796
+ if sys.version_info < (3, 8):
1797
+ print("Warning: Importing PySpaces requires Python 3.8+", file=sys.stderr)
1798
+
1799
+ try:
1800
+ if (gr_module := sys.modules.get("gradio")) is not None:
1801
+ getattr(gr_module, 'Blocks')
1802
+ except AttributeError:
+ print("ImportError: Gradio does not have 'Blocks' attribute. Please check your Gradio installation.", file=sys.stderr)
1805
+
1806
+ def aoti_apply(compiled_fn: Any, module: Any):
1807
+ if torch is None:
1808
+ return module
1809
+ if hasattr(module, 'to') and isinstance(module, torch.nn.Module):
1810
+ module.to(device="cpu")
1811
+ return module
1812
+
1813
+ __all__ = ["GPU", "gradio_auto_wrap", "disable_gradio_auto_wrap", "enable_gradio_auto_wrap", "aoti_apply"]
1814
+
1815
+ if Config.zero_gpu:
1816
+ try:
1817
+ if is_in_bad_fork_torch():
+ pass  # informational only in this build: a bad fork is not treated as fatal
1819
+ except Exception as e:
1820
+ print(f"Could not check for bad fork: {e}", file=sys.stderr)
1821
+
1822
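+ # startup(): packs tensors to the offload directory, installs the gradio auto-wrap patch,
+ # derives the GPU size ('medium'/'large') from the packed total unless configured explicitly,
+ # and reports it to the scheduler via startup_report_client.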
+ def startup():
1823
+ total_size = pack_torch()
1824
+ _patch_gradio_auto_wrap()
1825
+
1826
+ if Config.zerogpu_size == 'auto':
1827
+ gpu_size = 'medium' if total_size < Config.zerogpu_medium_size_threshold else 'large'
1828
+ else:
1829
+ gpu_size = Config.zerogpu_size
1830
+ startup_report_client(self_cgroup_device_path(), gpu_size)
1831
+
1832
+ _patch_torch_global()
1833
+ one_launch(startup)
1834
+ try:
1835
+ shutil.rmtree(Config.zerogpu_offload_dir, ignore_errors=True)
1836
+ Path(Config.zerogpu_offload_dir).mkdir(parents=True, exist_ok=True)
1837
+ except Exception as e:
1838
+ print(f"Could not prepare ZeroGPU offload directory: {e}", file=sys.stderr)