# MuPaD-512 / samplers.py
import torch
import numpy as np
def expand_t_like_x(t, x_cur):
"""Function to reshape time t to broadcastable dimension of x
Args:
t: [batch_dim,], time vector
x: [batch_dim,...], data point
"""
dims = [1] * (len(x_cur.size()) - 1)
t = t.view(t.size(0), *dims)
return t
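# Example: for t of shape [B] and x_cur of shape [B, C, H, W], the reshaped t
# has shape [B, 1, 1, 1], so expressions like t * x_cur broadcast elementwise.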
def get_score_from_velocity(vt, xt, t, path_type="linear"):
"""Wrapper function: transfrom velocity prediction model to score
Args:
velocity: [batch_dim, ...] shaped tensor; velocity model output
x: [batch_dim, ...] shaped tensor; x_t data point
t: [batch_dim,] time tensor
"""
t = expand_t_like_x(t, xt)
if path_type == "linear":
        alpha_t, d_alpha_t = 1 - t, -torch.ones_like(xt)
        sigma_t, d_sigma_t = t, torch.ones_like(xt)
elif path_type == "cosine":
alpha_t = torch.cos(t * np.pi / 2)
sigma_t = torch.sin(t * np.pi / 2)
d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
else:
        raise NotImplementedError(f"Unsupported path_type: {path_type}")
mean = xt
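    # Derivation: with x_t = alpha_t * x0 + sigma_t * eps and the velocity
    # v_t = d_alpha_t * x0 + d_sigma_t * eps, eliminating x0 gives
    #     eps = (alpha_t / d_alpha_t * v_t - x_t)
    #           / (alpha_t / d_alpha_t * d_sigma_t - sigma_t),
    # and the Gaussian perturbation kernel has score = -eps / sigma_t, i.e.
    #     score = (alpha_t / d_alpha_t * v_t - x_t)
    #             / (sigma_t**2 - alpha_t / d_alpha_t * d_sigma_t * sigma_t),
    # which is exactly the expression computed below.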
reverse_alpha_ratio = alpha_t / d_alpha_t
var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
score = (reverse_alpha_ratio * vt - mean) / var
return score
def compute_diffusion(t_cur):
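    """Time-dependent diffusion coefficient w(t) = 2t used by the reverse-time SDE sampler."""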
return 2 * t_cur
def _prepare_cfg_tensors(conditioning, conditioning_mask, modality_ids, cfg_scale):
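    """Build the null (unconditional) inputs for classifier-free guidance;
    returns (None, None, None) when guidance is disabled (cfg_scale <= 1)."""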
if cfg_scale <= 1.0:
return None, None, None
    cond_null = torch.zeros_like(conditioning)
mask_null = conditioning_mask.clone() if conditioning_mask is not None else None
mod_null = modality_ids.clone() if modality_ids is not None else None
return cond_null, mask_null, mod_null
def _apply_cfg(branch, cfg_scale):
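    """Split a CFG-doubled batch into (cond, uncond) halves and combine them
    as uncond + cfg_scale * (cond - uncond)."""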
branch_cond, branch_uncond = branch.chunk(2)
return branch_uncond + cfg_scale * (branch_cond - branch_uncond)
def euler_sampler(
model,
latents,
*,
conditioning=None,
conditioning_mask=None,
modality_ids=None,
num_steps=20,
heun=False,
cfg_scale=1.0,
guidance_low=0.0,
guidance_high=1.0,
path_type="linear", # not used, just for compatability
cls_latents=None,
):
"""Euler sampler supporting both CLS and multimodal conditioning."""
cond_null, mask_null, mod_null = _prepare_cfg_tensors(conditioning, conditioning_mask, modality_ids, cfg_scale)
_dtype = latents.dtype
t_steps = torch.linspace(1, 0, num_steps + 1, dtype=torch.float64)
x_next = latents.to(torch.float64)
cls_next = cls_latents.to(torch.float64) if cls_latents is not None else None
device = x_next.device
def _infer(model_input, cond_input, mask_input, t_scalar, cls_input, modality_input):
lat_out, cls_out = model.inference(
model_input.to(dtype=_dtype),
t_scalar.to(dtype=_dtype),
conditioning=cond_input.to(dtype=_dtype),
conditioning_mask=mask_input,
modality_ids=modality_input,
cls_token=None if cls_input is None else cls_input.to(dtype=_dtype),
)
lat_out = lat_out.to(torch.float64)
cls_out = None if cls_out is None else cls_out.to(torch.float64)
return lat_out, cls_out
with torch.no_grad():
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
x_cur = x_next
cls_cur = cls_next
if cfg_scale > 1.0 and guidance_low <= t_cur <= guidance_high:
model_input = torch.cat([x_cur] * 2, dim=0)
cond_cur = torch.cat([conditioning, cond_null], dim=0)
mask_cur = None if conditioning_mask is None else torch.cat([conditioning_mask, mask_null], dim=0)
modality_cur = None if modality_ids is None else torch.cat([modality_ids, mod_null], dim=0)
if cls_cur is not None:
cls_model_input = torch.cat([cls_cur] * 2, dim=0)
else:
cls_model_input = None
else:
model_input = x_cur
cond_cur = conditioning
mask_cur = conditioning_mask
modality_cur = modality_ids
cls_model_input = cls_cur
time_input = torch.ones(model_input.size(0), device=device, dtype=torch.float64) * t_cur
d_cur, cls_d_cur = _infer(model_input, cond_cur, mask_cur, time_input, cls_model_input, modality_cur)
if cfg_scale > 1.0 and guidance_low <= t_cur <= guidance_high:
d_cur = _apply_cfg(d_cur, cfg_scale)
if cls_d_cur is not None:
cls_d_cur = _apply_cfg(cls_d_cur, cfg_scale)
x_next = x_cur + (t_next - t_cur) * d_cur
            # Clamp latents to a reasonable range to prevent edge explosion:
            # standard VAE latents are roughly N(0, 1), so values beyond
            # +/- 5.0 are almost certainly artifacts/outliers.
            x_next = x_next.clamp(-5.0, 5.0)
if cls_cur is not None and cls_d_cur is not None:
cls_next = cls_cur + (t_next - t_cur) * cls_d_cur
if heun and (i < num_steps - 1):
if cfg_scale > 1.0 and guidance_low <= t_cur <= guidance_high:
                    model_input = torch.cat([x_next] * 2, dim=0)
cond_cur = torch.cat([conditioning, cond_null], dim=0)
mask_cur = None if conditioning_mask is None else torch.cat([conditioning_mask, mask_null], dim=0)
modality_cur = None if modality_ids is None else torch.cat([modality_ids, mod_null], dim=0)
if cls_next is not None:
cls_model_input = torch.cat([cls_next] * 2, dim=0)
else:
cls_model_input = None
else:
model_input = x_next
cond_cur = conditioning
mask_cur = conditioning_mask
modality_cur = modality_ids
cls_model_input = cls_next
time_input = torch.ones(model_input.size(0), device=model_input.device, dtype=torch.float64) * t_next
d_prime, cls_d_prime = _infer(model_input, cond_cur, mask_cur, time_input, cls_model_input, modality_cur)
if cfg_scale > 1.0 and guidance_low <= t_cur <= guidance_high:
d_prime = _apply_cfg(d_prime, cfg_scale)
if cls_d_prime is not None:
cls_d_prime = _apply_cfg(cls_d_prime, cfg_scale)
x_next = x_cur + (t_next - t_cur) * (0.5 * d_cur + 0.5 * d_prime)
if cls_next is not None and cls_d_prime is not None and cls_d_cur is not None:
cls_next = cls_cur + (t_next - t_cur) * (0.5 * cls_d_cur + 0.5 * cls_d_prime)
return x_next
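# Minimal usage sketch (shapes are illustrative assumptions, not taken from
# this repo; `model` must expose the `inference` signature used by `_infer`):
#
#   latents = torch.randn(4, 4, 64, 64)
#   cond = torch.randn(4, 77, 1024)
#   samples = euler_sampler(model, latents, conditioning=cond,
#                           num_steps=50, heun=True, cfg_scale=4.0)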
def euler_maruyama_sampler(
model,
latents,
*,
conditioning=None,
conditioning_mask=None,
modality_ids=None,
num_steps=20,
    heun=False,  # not used, just for compatibility
cfg_scale=1.0,
guidance_low=0.0,
guidance_high=1.0,
path_type="linear",
cls_latents=None,
):
cond_null, mask_null, mod_null = _prepare_cfg_tensors(conditioning, conditioning_mask, modality_ids, cfg_scale)
_dtype = latents.dtype
t_steps = torch.linspace(1.0, 0.04, num_steps, dtype=torch.float64)
t_steps = torch.cat([t_steps, torch.tensor([0.0], dtype=torch.float64)])
x_next = latents.to(torch.float64)
cls_next = cls_latents.to(torch.float64) if cls_latents is not None else None
device = x_next.device
def _infer(model_input, cond_input, mask_input, t_scalar, cls_input, modality_input):
lat_out, cls_out = model.inference(
model_input.to(dtype=_dtype),
t_scalar.to(dtype=_dtype),
conditioning=cond_input.to(dtype=_dtype),
conditioning_mask=mask_input,
modality_ids=modality_input,
cls_token=None if cls_input is None else cls_input.to(dtype=_dtype),
)
lat_out = lat_out.to(torch.float64)
cls_out = None if cls_out is None else cls_out.to(torch.float64)
return lat_out, cls_out
with torch.no_grad():
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):
dt = t_next - t_cur
x_cur = x_next
cls_cur = cls_next
if cfg_scale > 1.0 and guidance_low <= t_cur <= guidance_high:
model_input = torch.cat([x_cur] * 2, dim=0)
cond_cur = torch.cat([conditioning, cond_null], dim=0)
mask_cur = None if conditioning_mask is None else torch.cat([conditioning_mask, mask_null], dim=0)
modality_cur = None if modality_ids is None else torch.cat([modality_ids, mod_null], dim=0)
if cls_cur is not None:
cls_model_input = torch.cat([cls_cur] * 2, dim=0)
else:
cls_model_input = None
else:
model_input = x_cur
cond_cur = conditioning
mask_cur = conditioning_mask
modality_cur = modality_ids
cls_model_input = cls_cur
time_input = torch.ones(model_input.size(0), device=device, dtype=torch.float64) * t_cur
diffusion = compute_diffusion(t_cur)
eps_i = torch.randn_like(x_cur, device=device)
deps = eps_i * torch.sqrt(torch.abs(dt))
if cls_cur is not None:
cls_eps = torch.randn_like(cls_cur, device=device)
cls_deps = cls_eps * torch.sqrt(torch.abs(dt))
else:
cls_deps = None
v_cur, cls_v_cur = _infer(model_input, cond_cur, mask_cur, time_input, cls_model_input, modality_cur)
s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
d_cur = v_cur - 0.5 * diffusion * s_cur
if cls_v_cur is not None and cls_cur is not None:
cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur
else:
cls_d_cur = None
if cfg_scale > 1.0 and guidance_low <= t_cur <= guidance_high:
d_cur = _apply_cfg(d_cur, cfg_scale)
if cls_d_cur is not None:
cls_d_cur = _apply_cfg(cls_d_cur, cfg_scale)
x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            # As in euler_sampler: clamp latents to suppress outlier values
            # before the next step.
            x_next = x_next.clamp(-5.0, 5.0)
if cls_cur is not None and cls_d_cur is not None and cls_deps is not None:
cls_next = cls_cur + cls_d_cur * dt + torch.sqrt(diffusion) * cls_deps
t_cur, t_next = t_steps[-2], t_steps[-1]
dt = t_next - t_cur
x_cur = x_next
cls_cur = cls_next
        if cfg_scale > 1.0 and guidance_low <= t_cur <= guidance_high:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cond_cur = torch.cat([conditioning, cond_null], dim=0)
            mask_cur = None if conditioning_mask is None else torch.cat([conditioning_mask, mask_null], dim=0)
            modality_cur = None if modality_ids is None else torch.cat([modality_ids, mod_null], dim=0)
            if cls_cur is not None:
                cls_model_input = torch.cat([cls_cur] * 2, dim=0)
            else:
                cls_model_input = None
        else:
            model_input = x_cur
            cond_cur = conditioning
            mask_cur = conditioning_mask
            modality_cur = modality_ids
            cls_model_input = cls_cur
time_input = torch.ones(model_input.size(0), device=device, dtype=torch.float64) * t_cur
v_cur, cls_v_cur = _infer(model_input, cond_cur, mask_cur, time_input, cls_model_input, modality_cur)
s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
diffusion = compute_diffusion(t_cur)
d_cur = v_cur - 0.5 * diffusion * s_cur
if cls_v_cur is not None and cls_model_input is not None:
cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur
else:
cls_d_cur = None
if cfg_scale > 1.0 and guidance_low <= t_cur <= guidance_high:
d_cur = _apply_cfg(d_cur, cfg_scale)
if cls_d_cur is not None:
cls_d_cur = _apply_cfg(cls_d_cur, cfg_scale)
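        # Final step: deterministic (noise-free) Euler update; the sampler
        # returns the mean rather than injecting noise at t -> 0.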
mean_x = x_cur + dt * d_cur
        # The cls trajectory is intentionally not returned; it exists only to
        # stay coupled with the latent trajectory during sampling.
return mean_x
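# ---------------------------------------------------------------------------
# Smoke test: a sketch, not part of the original module. `_DummyModel` is a
# hypothetical stand-in that only mimics the `model.inference(...)` signature
# the samplers assume; real models come from the MuPaD-512 checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _DummyModel:
        def inference(self, x, t, conditioning=None, conditioning_mask=None,
                      modality_ids=None, cls_token=None):
            # A trivial "velocity" that pulls x toward zero, standing in for
            # a trained network so both samplers can be exercised end to end.
            return -x, (None if cls_token is None else -cls_token)

    model = _DummyModel()
    latents = torch.randn(2, 4, 8, 8)
    cond = torch.randn(2, 16, 32)
    out_ode = euler_sampler(model, latents, conditioning=cond, num_steps=8, heun=True)
    out_sde = euler_maruyama_sampler(model, latents, conditioning=cond, num_steps=8)
    print(out_ode.shape, out_sde.shape)  # both: torch.Size([2, 4, 8, 8])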