---
language:
- en
license: mit
tags:
- image-generation
- autoregressive
- next-scale-prediction
- exposure-bias
- post-training
- pytorch
- imagenet
library_name: pytorch
inference: false
model-index:
- name: ZGZzz/SAR
  results:
  - task:
      type: image-generation
      name: Image Generation
    dataset:
      name: ImageNet 256×256
      type: imagenet-1k
      config: 256x256
      split: validation
    metrics:
    - type: fid
      name: FID (FlexVAR-d16, +SAR)
      value: 2.89
      higher_is_better: false
    - type: fid
      name: FID (FlexVAR-d20, +SAR)
      value: 2.35
      higher_is_better: false
    - type: fid
      name: FID (FlexVAR-d24, +SAR)
      value: 2.14
      higher_is_better: false
datasets:
- ILSVRC/imagenet-1k
base_model:
- jiaosiyu1999/FlexVAR
pipeline_tag: unconditional-image-generation
---
<div align="center">
<h1>Rethinking Training Dynamics in Scale-wise Autoregressive Generation</h1>
<a href="https://gengzezhou.github.io/" target="_blank">Gengze Zhou</a><sup>1*</sup>,
<a href="https://chongjiange.github.io/" target="_blank">Chongjian Ge</a><sup>2</sup>,
<a href="https://www.cs.unc.edu/~airsplay/" target="_blank">Hao Tan</a><sup>2</sup>,
<a href="https://pages.cs.wisc.edu/~fliu/" target="_blank">Feng Liu</a><sup>2</sup>,
<a href="https://yiconghong.me" target="_blank">Yicong Hong</a><sup>2</sup>
<sup>1</sup>Australian Institute for Machine Learning, The University of Adelaide
<sup>2</sup>Adobe Research
[Paper](https://arxiv.org/abs/2512.06421)
[Model](https://huggingface.co/ZGZzz/SAR)
[Project Page](https://gengzezhou.github.io/SAR)
[License: MIT](https://opensource.org/licenses/MIT)
</div>
## Model Description
**Self-Autoregressive Refinement (SAR)** is a lightweight *post-training* algorithm for **scale-wise autoregressive (AR)** image generation (next-scale prediction). SAR mitigates **exposure bias** by addressing (1) train–test mismatch (teacher forcing vs. student forcing) and (2) imbalance in scale-wise learning difficulty.
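The train–test mismatch behind exposure bias can be illustrated with a toy one-dimensional sequence model (the dynamics and the 0.01 bias below are purely illustrative assumptions, unrelated to SAR's actual models):

```python
true_next = lambda x: 0.9 * x           # "ground-truth" dynamics (toy assumption)
model_next = lambda x: 0.9 * x + 0.01   # learned model with a small per-step bias

x_true = x_free = 1.0
tf_errors, sf_errors = [], []
for _ in range(20):
    pred_tf = model_next(x_true)   # teacher forcing: conditioned on ground truth
    x_free = model_next(x_free)    # student forcing: conditioned on its own output
    x_true = true_next(x_true)
    tf_errors.append(abs(pred_tf - x_true))
    sf_errors.append(abs(x_free - x_true))

# Per-step error stays flat under teacher forcing but accumulates
# step after step under student forcing.
```

Training only ever sees the flat teacher-forced error, while inference lives on the compounding student-forced trajectory; SAR's post-training explicitly closes this gap.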
SAR consists of:
- **Stagger-Scale Rollout (SSR):** a two-step rollout (teacher-forcing → student-forcing) with minimal compute overhead (one extra forward pass).
- **Contrastive Student-Forcing Loss (CSFL):** stabilizes student-forced training by aligning predictions with a teacher trajectory under self-generated contexts.
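The interaction of the two components can be sketched in toy scalar form. Everything below — the mean-pooling "model", its weights, and the squared-error form of the alignment loss — is an illustrative assumption, not the authors' implementation:

```python
import random

random.seed(0)

# Toy stand-in for a scale-wise AR model: predicts the next scale's value
# from the mean of its context (weights are illustrative, not trained).
def model(context, w=0.95, b=0.02):
    return w * sum(context) / len(context) + b

def ssr_step(scales, k):
    """One hypothetical SSR step at scale k (scales = ground-truth values)."""
    # Pass 1 (teacher forcing): predict scale k from ground-truth scales < k.
    pred_k = model(scales[:k])
    # Pass 2 (student forcing): replace scale k's ground truth with the
    # model's own prediction, then predict scale k + 1 -- this is the one
    # extra forward pass SSR adds.
    sf_pred = model(scales[:k] + [pred_k])
    # Teacher trajectory: the same prediction under a fully ground-truth context.
    teacher_pred = model(scales[:k + 1])
    # CSFL stand-in: align the student-forced prediction with the teacher
    # trajectory (squared error here; the paper's exact loss may differ).
    return (sf_pred - teacher_pred) ** 2

scales = [random.random() for _ in range(5)]
loss = ssr_step(scales, k=3)
```

The key structural point survives the simplification: the model is supervised on a context containing its own prediction, while the target comes from the teacher trajectory.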
## Key Features
- **Minimal overhead:** SSR adds only a lightweight additional forward pass to train on self-generated content.
- **General post-training recipe:** applies on top of pretrained scale-wise AR models.
- **Empirical gains:** e.g., reported **5.2% FID reduction** on FlexVAR-d16 with 10 SAR epochs.
## Model Zoo (ImageNet 256×256)
| Model | Params | Base FID ↓ | SAR FID ↓ | SAR Weights |
|---|---:|---:|---:|---|
| SAR-d16 | 310M | 3.05 | **2.89** | `pretrained/SARd16-epo179.pth` |
| SAR-d20 | 600M | 2.41 | **2.35** | `pretrained/SARd20-epo249.pth` |
| SAR-d24 | 1.0B | 2.21 | **2.14** | `pretrained/SARd24-epo349.pth` |
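As a sanity check, the relative FID reductions implied by the table can be computed directly; the d16 row reproduces the 5.2% figure quoted under "Key Features":

```python
# (base FID, SAR FID) pairs from the Model Zoo table above.
pairs = {"d16": (3.05, 2.89), "d20": (2.41, 2.35), "d24": (2.21, 2.14)}
reduction = {k: round(100 * (base - sar) / base, 1) for k, (base, sar) in pairs.items()}
print(reduction)  # {'d16': 5.2, 'd20': 2.5, 'd24': 3.2}
```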
## How to Use
### Installation
```bash
git clone https://github.com/GengzeZhou/SAR.git
cd SAR
conda create -n sar python=3.10 -y
conda activate sar
pip install -r requirements.txt
# optional
pip install flash-attn xformers
```
### Sampling / Inference (Example)
```python
import torch
from torchvision.utils import save_image

from models import build_vae_var

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build VAE + VAR backbone (example: depth=16)
vae, model = build_vae_var(
    V=8912, Cvae=32, device=device,
    num_classes=1000, depth=16,
    vae_ckpt="pretrained/FlexVAE.pth",
)

# Load SAR checkpoint
ckpt = torch.load("pretrained/SARd16-epo179.pth", map_location="cpu")
if "trainer" in ckpt:
    ckpt = ckpt["trainer"]["var_wo_ddp"]
model.load_state_dict(ckpt, strict=False)
model.eval()

with torch.no_grad():
    labels = torch.tensor([207, 88, 360, 387], device=device)  # example ImageNet classes
    images = model.autoregressive_infer_cfg(
        vqvae=vae,
        B=4,
        label_B=labels,
        cfg=2.5,
        top_k=900,
        top_p=0.95,
    )

save_image(images, "samples.png", normalize=True, value_range=(-1, 1), nrow=4)
```
## Training (SAR Post-Training)
```bash
bash scripts/train_SAR_d16.sh
bash scripts/train_SAR_d20.sh
bash scripts/train_SAR_d24.sh
```
## Evaluation
```bash
bash scripts/setup_eval.sh
bash scripts/eval_SAR_d16.sh
bash scripts/eval_SAR_d20.sh
bash scripts/eval_SAR_d24.sh
```
## Acknowledgements
This codebase builds upon **VAR** and **FlexVAR**.
## Citation
```bibtex
@article{zhou2025rethinking,
  title={Rethinking Training Dynamics in Scale-wise Autoregressive Generation},
  author={Zhou, Gengze and Ge, Chongjian and Tan, Hao and Liu, Feng and Hong, Yicong},
  journal={arXiv preprint arXiv:2512.06421},
  year={2025}
}
```