Instructions to use xiangjx/MuPaD-512 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use xiangjx/MuPaD-512 with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("xiangjx/MuPaD-512", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import numpy as np | |
def expand_t_like_x(t, x_cur):
    """Reshape a per-sample time vector so it broadcasts against ``x_cur``.

    Args:
        t: ``[batch]`` time vector, one value per sample.
        x_cur: ``[batch, ...]`` data tensor whose rank the result should match.

    Returns:
        A view of ``t`` with shape ``[batch, 1, ..., 1]`` and the same number of
        dimensions as ``x_cur``.
    """
    singleton_dims = (1,) * (x_cur.dim() - 1)
    return t.view(t.size(0), *singleton_dims)
def get_score_from_velocity(vt, xt, t, path_type="linear"):
    """Convert a velocity-model prediction into a score estimate.

    With path coefficients ``alpha_t``, ``sigma_t`` and their time derivatives
    ``d_alpha_t``, ``d_sigma_t``, this computes::

        ratio = alpha_t / d_alpha_t
        score = (ratio * vt - xt) / (sigma_t**2 - ratio * d_sigma_t * sigma_t)

    Args:
        vt: ``[batch, ...]`` velocity-model output.
        xt: ``[batch, ...]`` current data point ``x_t``.
        t: ``[batch]`` time tensor; reshaped to broadcast against ``xt``.
        path_type: ``"linear"`` (``alpha_t = 1 - t``, ``sigma_t = t``) or
            ``"cosine"`` (``alpha_t = cos(pi*t/2)``, ``sigma_t = sin(pi*t/2)``).

    Returns:
        The score tensor, broadcast to the shape of ``vt`` / ``xt``.

    Raises:
        NotImplementedError: if ``path_type`` is not ``"linear"`` or ``"cosine"``.

    Note:
        Both paths are singular at ``t = 0``. For the linear path the denominator
        reduces to ``t``. For the cosine path ``d_alpha_t`` is 0. Callers should keep
        ``t > 0``.
    """
    t = expand_t_like_x(t, xt)
    if path_type == "linear":
        alpha_t, d_alpha_t = 1 - t, -torch.ones_like(xt)
        sigma_t, d_sigma_t = t, torch.ones_like(xt)
    elif path_type == "cosine":
        half_pi_t = t * np.pi / 2
        alpha_t = torch.cos(half_pi_t)
        sigma_t = torch.sin(half_pi_t)
        d_alpha_t = -np.pi / 2 * torch.sin(half_pi_t)
        d_sigma_t = np.pi / 2 * torch.cos(half_pi_t)
    else:
        raise NotImplementedError(
            f"Unsupported path_type {path_type!r}; expected 'linear' or 'cosine'."
        )
    reverse_alpha_ratio = alpha_t / d_alpha_t
    var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
    return (reverse_alpha_ratio * vt - xt) / var
def compute_diffusion(t_cur):
    """Return the SDE diffusion term ``2 * t_cur``.

    The Euler-Maruyama sampler in this module scales its injected noise by the
    square root of this value and its score correction by half of it.
    """
    scale = 2
    return scale * t_cur
def _prepare_cfg_tensors(conditioning, conditioning_mask, modality_ids, cfg_scale):
    """Build the unconditional-branch inputs for classifier-free guidance.

    Args:
        conditioning: conditioning tensor of the conditional branch.
        conditioning_mask: optional mask paired with ``conditioning``.
        modality_ids: optional modality ids paired with ``conditioning``.
        cfg_scale: guidance scale. No unconditional inputs are built when ``<= 1``.

    Returns:
        ``(cond_null, mask_null, mod_null)``: an all-zeros tensor shaped like
        ``conditioning``, plus clones of the mask and modality ids (``None`` where
        the input is ``None``). Returns ``(None, None, None)`` when
        ``cfg_scale <= 1``.

    Raises:
        ValueError: if guidance is requested (``cfg_scale > 1``) without
            ``conditioning``.
    """
    if cfg_scale <= 1.0:
        return None, None, None
    if conditioning is None:
        raise ValueError("classifier-free guidance (cfg_scale > 1) requires `conditioning`")
    # The null branch zeroes the conditioning content but keeps the same mask and
    # modality layout. This lets both halves be concatenated along dim 0.
    cond_null = torch.zeros_like(conditioning)
    mask_null = None if conditioning_mask is None else conditioning_mask.clone()
    mod_null = None if modality_ids is None else modality_ids.clone()
    return cond_null, mask_null, mod_null
def _apply_cfg(branch, cfg_scale):
    """Combine a stacked ``[conditional; unconditional]`` batch with CFG.

    ``branch`` is split in two along dim 0: the first half holds the conditional
    outputs and the second half the unconditional ones. Returns
    ``uncond + cfg_scale * (cond - uncond)``.
    """
    half_cond, half_uncond = torch.chunk(branch, 2)
    guidance_delta = half_cond - half_uncond
    return half_uncond + cfg_scale * guidance_delta
def euler_sampler(
    model,
    latents,
    *,
    conditioning=None,
    conditioning_mask=None,
    modality_ids=None,
    num_steps=20,
    heun=False,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",  # not used, just for compatibility
    cls_latents=None,
    clamp_range=(-5.0, 5.0),
):
    """Euler (optionally Heun) ODE sampler supporting CLS and multimodal conditioning.

    Integrates from ``t = 1`` to ``t = 0`` on a uniform grid of ``num_steps`` steps,
    using the model's latent output as the velocity. When ``cls_latents`` is given,
    a CLS-token state is advanced jointly using the model's second output.

    Args:
        model: object whose ``inference(x, t, *, conditioning, conditioning_mask,
            modality_ids, cls_token)`` returns ``(latent_out, cls_out)``.
            ``cls_out`` may be ``None``.
        latents: initial latents at ``t = 1``. Their dtype is used for model calls.
        conditioning: conditioning tensor forwarded to the model. Required when
            classifier-free guidance is active; the unconditional branch uses an
            all-zeros tensor of the same shape.
        conditioning_mask: optional mask forwarded with ``conditioning``.
        modality_ids: optional modality ids forwarded with ``conditioning``.
        num_steps: number of integration steps.
        heun: if True, apply Heun's correction on every step except the last.
        cfg_scale: classifier-free guidance scale. Guidance is off when ``<= 1``.
        guidance_low, guidance_high: guidance is applied only on steps whose start
            time ``t`` satisfies ``guidance_low <= t <= guidance_high``.
        path_type: unused; accepted for signature compatibility with
            ``euler_maruyama_sampler``.
        cls_latents: optional initial CLS-token state, integrated jointly.
        clamp_range: ``(low, high)`` bounds applied to the latents after every
            step, or ``None`` to disable clamping.

    Returns:
        The final latents as float64. The CLS state is not returned.
    """
    cond_null, mask_null, mod_null = _prepare_cfg_tensors(
        conditioning, conditioning_mask, modality_ids, cfg_scale
    )
    _dtype = latents.dtype
    t_steps = torch.linspace(1, 0, num_steps + 1, dtype=torch.float64)
    x_next = latents.to(torch.float64)
    cls_next = None if cls_latents is None else cls_latents.to(torch.float64)
    device = x_next.device

    def _use_cfg(t):
        # CFG is active only inside the [guidance_low, guidance_high] time window.
        return cfg_scale > 1.0 and bool(guidance_low <= t <= guidance_high)

    def _clamp(x):
        # Heuristic outlier guard. The default (-5, 5) assumes roughly unit-variance
        # latents, so values beyond it are treated as artifacts.
        return x if clamp_range is None else x.clamp(*clamp_range)

    def _batch_inputs(x, cls_state, use_cfg):
        # Assemble model inputs. With CFG the batch is doubled as [conditional; unconditional].
        if not use_cfg:
            return x, conditioning, conditioning_mask, modality_ids, cls_state
        return (
            torch.cat([x, x], dim=0),
            torch.cat([conditioning, cond_null], dim=0),
            None if conditioning_mask is None else torch.cat([conditioning_mask, mask_null], dim=0),
            None if modality_ids is None else torch.cat([modality_ids, mod_null], dim=0),
            None if cls_state is None else torch.cat([cls_state, cls_state], dim=0),
        )

    def _velocity(x, cls_state, t, use_cfg):
        # One model evaluation at time t, guided when use_cfg is set; outputs in float64.
        model_in, cond_in, mask_in, mod_in, cls_in = _batch_inputs(x, cls_state, use_cfg)
        time_in = torch.ones(model_in.size(0), device=device, dtype=torch.float64) * t
        lat_out, cls_out = model.inference(
            model_in.to(dtype=_dtype),
            time_in.to(dtype=_dtype),
            # Forward None instead of crashing on `.to`. Whether the model accepts
            # conditioning=None depends on its implementation.
            conditioning=None if cond_in is None else cond_in.to(dtype=_dtype),
            conditioning_mask=mask_in,
            modality_ids=mod_in,
            cls_token=None if cls_in is None else cls_in.to(dtype=_dtype),
        )
        lat_out = lat_out.to(torch.float64)
        cls_out = None if cls_out is None else cls_out.to(torch.float64)
        if use_cfg:
            lat_out = _apply_cfg(lat_out, cfg_scale)
            if cls_out is not None:
                cls_out = _apply_cfg(cls_out, cfg_scale)
        return lat_out, cls_out

    with torch.no_grad():
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
            dt = t_next - t_cur
            x_cur, cls_cur = x_next, cls_next
            # Decided once from the step's start time and reused by the Heun corrector.
            use_cfg = _use_cfg(t_cur)

            # Euler predictor.
            d_cur, cls_d_cur = _velocity(x_cur, cls_cur, t_cur, use_cfg)
            x_next = _clamp(x_cur + dt * d_cur)
            if cls_cur is not None and cls_d_cur is not None:
                cls_next = cls_cur + dt * cls_d_cur

            # Heun corrector: average the slopes at both ends of the step.
            if heun and i < num_steps - 1:
                d_prime, cls_d_prime = _velocity(x_next, cls_next, t_next, use_cfg)
                # Clamp the corrected state too, so the guard still holds when heun=True.
                x_next = _clamp(x_cur + dt * (0.5 * d_cur + 0.5 * d_prime))
                if cls_next is not None and cls_d_prime is not None and cls_d_cur is not None:
                    cls_next = cls_cur + dt * (0.5 * cls_d_cur + 0.5 * cls_d_prime)
    return x_next
def euler_maruyama_sampler(
    model,
    latents,
    *,
    conditioning=None,
    conditioning_mask=None,
    modality_ids=None,
    num_steps=20,
    heun=False,  # not used, just for compatibility
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    clamp_range=(-5.0, 5.0),
):
    """Euler-Maruyama SDE sampler supporting CLS and multimodal conditioning.

    Each stochastic step computes
    ``x + [v - 0.5 * g(t) * score] * dt + sqrt(g(t)) * sqrt(|dt|) * eps``.
    Here ``v`` is the model output, ``score`` comes from ``get_score_from_velocity``,
    ``g(t) = compute_diffusion(t)`` and ``eps ~ N(0, I)``.

    The time grid has ``num_steps`` points evenly spaced from 1.0 to 0.04, followed
    by 0.0. The first ``num_steps - 1`` steps are stochastic. The last step, from the
    final grid point to 0, applies the drift only and returns that noise-free mean.

    Args:
        model: object whose ``inference(x, t, *, conditioning, conditioning_mask,
            modality_ids, cls_token)`` returns ``(latent_out, cls_out)``.
            ``cls_out`` may be ``None``.
        latents: initial latents at ``t = 1``. Their dtype is used for model calls.
        conditioning: conditioning tensor forwarded to the model. Required when
            classifier-free guidance is active; the unconditional branch uses an
            all-zeros tensor of the same shape.
        conditioning_mask: optional mask forwarded with ``conditioning``.
        modality_ids: optional modality ids forwarded with ``conditioning``.
        num_steps: number of time-grid points before the terminal ``t = 0``.
        heun: unused; accepted for signature compatibility with ``euler_sampler``.
        cfg_scale: classifier-free guidance scale. Guidance is off when ``<= 1``.
        guidance_low, guidance_high: guidance is applied only on steps whose start
            time ``t`` satisfies ``guidance_low <= t <= guidance_high``.
        path_type: interpolation path passed to ``get_score_from_velocity``.
        cls_latents: optional initial CLS-token state, integrated jointly.
        clamp_range: ``(low, high)`` bounds applied to the latents after every
            stochastic step, or ``None`` to disable clamping.

    Returns:
        The final mean latents as float64. The CLS state is not returned.

    Raises:
        ValueError: if ``num_steps < 1``.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    cond_null, mask_null, mod_null = _prepare_cfg_tensors(
        conditioning, conditioning_mask, modality_ids, cfg_scale
    )
    _dtype = latents.dtype
    t_steps = torch.linspace(1.0, 0.04, num_steps, dtype=torch.float64)
    t_steps = torch.cat([t_steps, torch.tensor([0.0], dtype=torch.float64)])
    x_next = latents.to(torch.float64)
    cls_next = None if cls_latents is None else cls_latents.to(torch.float64)
    device = x_next.device

    def _use_cfg(t):
        # CFG is active only inside the [guidance_low, guidance_high] time window.
        return cfg_scale > 1.0 and bool(guidance_low <= t <= guidance_high)

    def _clamp(x):
        # Heuristic outlier guard. The default (-5, 5) assumes roughly unit-variance
        # latents, so values beyond it are treated as artifacts.
        return x if clamp_range is None else x.clamp(*clamp_range)

    def _batch_inputs(x, cls_state, use_cfg):
        # Assemble model inputs, including modality ids, for this step's CFG decision.
        # With CFG the batch is doubled as [conditional; unconditional].
        if not use_cfg:
            return x, conditioning, conditioning_mask, modality_ids, cls_state
        return (
            torch.cat([x, x], dim=0),
            torch.cat([conditioning, cond_null], dim=0),
            None if conditioning_mask is None else torch.cat([conditioning_mask, mask_null], dim=0),
            None if modality_ids is None else torch.cat([modality_ids, mod_null], dim=0),
            None if cls_state is None else torch.cat([cls_state, cls_state], dim=0),
        )

    def _drift(x, cls_state, t, use_cfg, diffusion):
        # SDE drift v - 0.5 * g(t) * score for the latents, and for the CLS state when
        # the model returns a CLS output. CFG is applied to the drift when use_cfg is set.
        model_in, cond_in, mask_in, mod_in, cls_in = _batch_inputs(x, cls_state, use_cfg)
        time_in = torch.ones(model_in.size(0), device=device, dtype=torch.float64) * t
        v, cls_v = model.inference(
            model_in.to(dtype=_dtype),
            time_in.to(dtype=_dtype),
            # Forward None instead of crashing on `.to`. Whether the model accepts
            # conditioning=None depends on its implementation.
            conditioning=None if cond_in is None else cond_in.to(dtype=_dtype),
            conditioning_mask=mask_in,
            modality_ids=mod_in,
            cls_token=None if cls_in is None else cls_in.to(dtype=_dtype),
        )
        v = v.to(torch.float64)
        cls_v = None if cls_v is None else cls_v.to(torch.float64)
        # The score uses the float64 inputs, not the down-cast copies fed to the model.
        s = get_score_from_velocity(v, model_in, time_in, path_type=path_type)
        d = v - 0.5 * diffusion * s
        cls_d = None
        if cls_v is not None and cls_in is not None:
            cls_s = get_score_from_velocity(cls_v, cls_in, time_in, path_type=path_type)
            cls_d = cls_v - 0.5 * diffusion * cls_s
        if use_cfg:
            d = _apply_cfg(d, cfg_scale)
            if cls_d is not None:
                cls_d = _apply_cfg(cls_d, cfg_scale)
        return d, cls_d

    with torch.no_grad():
        for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]):
            dt = t_next - t_cur
            x_cur, cls_cur = x_next, cls_next
            use_cfg = _use_cfg(t_cur)
            diffusion = compute_diffusion(t_cur)
            # Noise increments scaled by sqrt(|dt|). They are drawn before the model
            # call, so the RNG consumption order is unchanged.
            sqrt_dt = torch.sqrt(torch.abs(dt))
            deps = torch.randn_like(x_cur, device=device) * sqrt_dt
            cls_deps = None if cls_cur is None else torch.randn_like(cls_cur, device=device) * sqrt_dt
            d_cur, cls_d_cur = _drift(x_cur, cls_cur, t_cur, use_cfg, diffusion)
            x_next = _clamp(x_cur + d_cur * dt + torch.sqrt(diffusion) * deps)
            if cls_cur is not None and cls_d_cur is not None:
                cls_next = cls_cur + cls_d_cur * dt + torch.sqrt(diffusion) * cls_deps

        # Final step to t = 0: drift only, no noise. The inputs are rebuilt for this
        # step's own CFG decision. This also keeps modality ids defined when num_steps == 1.
        t_cur, t_next = t_steps[-2], t_steps[-1]
        diffusion = compute_diffusion(t_cur)
        d_cur, _ = _drift(x_next, cls_next, t_cur, _use_cfg(t_cur), diffusion)
        # The returned mean is deliberately left unclamped (unchanged behaviour).
        mean_x = x_next + (t_next - t_cur) * d_cur
    return mean_x