---
language:
- en
license: mit
tags:
- image-generation
- autoregressive
- next-scale-prediction
- exposure-bias
- post-training
- pytorch
- imagenet
library_name: pytorch
inference: false
model-index:
- name: ZGZzz/SAR
  results:
  - task:
      type: image-generation
      name: Image Generation
    dataset:
      name: ImageNet 256×256
      type: imagenet-1k
      config: 256x256
      split: validation
    metrics:
    - type: fid
      name: FID (FlexVAR-d16, +SAR)
      value: 2.89
      higher_is_better: false
    - type: fid
      name: FID (FlexVAR-d20, +SAR)
      value: 2.35
      higher_is_better: false
    - type: fid
      name: FID (FlexVAR-d24, +SAR)
      value: 2.14
      higher_is_better: false
datasets:
- ILSVRC/imagenet-1k
base_model:
- jiaosiyu1999/FlexVAR
pipeline_tag: unconditional-image-generation
---

<div align="center">
<h1>Rethinking Training Dynamics in Scale-wise Autoregressive Generation</h1>

<a href="https://gengzezhou.github.io/" target="_blank">Gengze Zhou</a><sup>1*</sup>,
<a href="https://chongjiange.github.io/" target="_blank">Chongjian Ge</a><sup>2</sup>,
<a href="https://www.cs.unc.edu/~airsplay/" target="_blank">Hao Tan</a><sup>2</sup>,
<a href="https://pages.cs.wisc.edu/~fliu/" target="_blank">Feng Liu</a><sup>2</sup>,
<a href="https://yiconghong.me" target="_blank">Yicong Hong</a><sup>2</sup>

<sup>1</sup>Australian Institute for Machine Learning, The University of Adelaide &nbsp;&nbsp;&nbsp;
<sup>2</sup>Adobe Research

[![arXiv](https://img.shields.io/badge/arXiv-2512.06421-b31b1b.svg)](https://arxiv.org/abs/2512.06421)&nbsp;
[![huggingface weights](https://img.shields.io/badge/%F0%9F%A4%97%20Weights-SAR--ckpts-yellow)](https://huggingface.co/ZGZzz/SAR)&nbsp;
[![project page](https://img.shields.io/badge/Project%20Page-SAR-blue)](https://gengzezhou.github.io/SAR)&nbsp;
[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)

</div>

## Model Description

**Self-Autoregressive Refinement (SAR)** is a lightweight *post-training* algorithm for **scale-wise autoregressive (AR)** image generation (next-scale prediction). SAR mitigates **exposure bias** by addressing (1) train–test mismatch (teacher forcing vs. student forcing) and (2) imbalance in scale-wise learning difficulty.

SAR consists of two components, sketched in code below:
- **Stagger-Scale Rollout (SSR):** a two-step rollout (teacher-forcing → student-forcing) with minimal compute overhead (one extra forward pass).
- **Contrastive Student-Forcing Loss (CSFL):** stabilizes student-forced training by aligning predictions with a teacher trajectory under self-generated contexts.
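
A minimal sketch of one SAR post-training step combining the two components. The helper names here (`encode_to_scales`, the `forcing` keyword) and the KL form of CSFL are illustrative assumptions, not the repository's actual API:

```python
import torch
import torch.nn.functional as F

def sar_step(model, vae, images, labels):
    """One SAR post-training step: Stagger-Scale Rollout (SSR) followed by
    the Contrastive Student-Forcing Loss (CSFL). Illustrative sketch only."""
    # Ground-truth multi-scale token maps from the frozen tokenizer
    # (`encode_to_scales` is a hypothetical helper).
    gt_tokens = vae.encode_to_scales(images)

    # SSR step 1 -- teacher forcing: each scale is conditioned on the
    # ground-truth tokens of all coarser scales (the standard pretraining pass).
    tf_logits = model(gt_tokens, labels, forcing="teacher")

    # Sample the model's own predictions to build a self-generated context.
    self_tokens = [torch.distributions.Categorical(logits=l).sample()
                   for l in tf_logits]

    # SSR step 2 -- student forcing: one extra forward pass in which each
    # scale is conditioned on the model's own coarser-scale samples,
    # exposing training to the contexts encountered at inference time.
    sf_logits = model(self_tokens, labels, forcing="student")

    # CSFL: keep student-forced predictions close to the detached teacher
    # trajectory so training on self-generated context remains stable.
    return sum(
        F.kl_div(F.log_softmax(sf, dim=-1),
                 F.log_softmax(tf.detach(), dim=-1),
                 reduction="batchmean", log_target=True)
        for sf, tf in zip(sf_logits, tf_logits)
    )
```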

## Key Features

- **Minimal overhead:** SSR adds only a lightweight additional forward pass to train on self-generated content.
- **General post-training recipe:** applies on top of pretrained scale-wise AR models.
- **Empirical gains:** e.g., a reported **5.2% relative FID reduction** on FlexVAR-d16 with 10 SAR epochs (3.05 → 2.89 in the Model Zoo table below; (3.05 − 2.89) / 3.05 ≈ 5.2%).

## Model Zoo (ImageNet 256×256)

| Model | Params | Base FID ↓ | SAR FID ↓ | SAR Weights |
|---|---:|---:|---:|---|
| SAR-d16 | 310M | 3.05 | **2.89** | `pretrained/SARd16-epo179.pth` |
| SAR-d20 | 600M | 2.41 | **2.35** | `pretrained/SARd20-epo249.pth` |
| SAR-d24 | 1.0B | 2.21 | **2.14** | `pretrained/SARd24-epo349.pth` |
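
The checkpoints can be fetched programmatically with `huggingface_hub`; the sketch below assumes the weights are stored in this repo under the paths listed in the table:

```python
from huggingface_hub import hf_hub_download

# Download the d16 SAR checkpoint from this repo (path per the table above).
ckpt_path = hf_hub_download(repo_id="ZGZzz/SAR",
                            filename="pretrained/SARd16-epo179.pth")
print(ckpt_path)  # local cache path; pass it to torch.load() below
```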

## How to Use

### Installation

```bash
git clone https://github.com/GengzeZhou/SAR.git
cd SAR
conda create -n sar python=3.10 -y
conda activate sar
pip install -r requirements.txt
# optional: faster attention kernels
pip install flash-attn xformers
```

### Sampling / Inference (Example)

```python
import torch
from models import build_vae_var
from torchvision.utils import save_image

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build VAE + VAR backbone (example: depth=16)
vae, model = build_vae_var(
    V=8912, Cvae=32, device=device,
    num_classes=1000, depth=16,
    vae_ckpt="pretrained/FlexVAE.pth",
)

# Load the SAR checkpoint; training checkpoints nest the model weights
# under trainer -> var_wo_ddp, while plain state dicts load directly.
ckpt = torch.load("pretrained/SARd16-epo179.pth", map_location="cpu")
if "trainer" in ckpt:
    ckpt = ckpt["trainer"]["var_wo_ddp"]
model.load_state_dict(ckpt, strict=False)
model.eval()

with torch.no_grad():
    labels = torch.tensor([207, 88, 360, 387], device=device)  # example ImageNet classes
    images = model.autoregressive_infer_cfg(
        vqvae=vae,
        B=4,
        label_B=labels,
        cfg=2.5,
        top_k=900,
        top_p=0.95,
    )

save_image(images, "samples.png", normalize=True, value_range=(-1, 1), nrow=4)
```

## Training (SAR Post-Training)

```bash
bash scripts/train_SAR_d16.sh
bash scripts/train_SAR_d20.sh
bash scripts/train_SAR_d24.sh
```

## Evaluation

```bash
bash scripts/setup_eval.sh
bash scripts/eval_SAR_d16.sh
bash scripts/eval_SAR_d20.sh
bash scripts/eval_SAR_d24.sh
```

## Acknowledgements

This codebase builds upon **VAR** and **FlexVAR**.

## Citation

```bibtex
@article{zhou2025rethinking,
  title={Rethinking Training Dynamics in Scale-wise Autoregressive Generation},
  author={Zhou, Gengze and Ge, Chongjian and Tan, Hao and Liu, Feng and Hong, Yicong},
  journal={arXiv preprint arXiv:2512.06421},
  year={2025}
}
```