ZGZzz commited on
Commit
ff7a767
·
verified ·
1 Parent(s): 79c8ca3

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +167 -0
README.md ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: mit
5
+ tags:
6
+ - image-generation
7
+ - autoregressive
8
+ - next-scale-prediction
9
+ - exposure-bias
10
+ - post-training
11
+ - pytorch
12
+ - imagenet
13
+ library_name: pytorch
14
+ inference: false
15
+ model-index:
16
+ - name: ZGZzz/SAR
17
+ results:
18
+ - task:
19
+ type: image-generation
20
+ name: Image Generation
21
+ dataset:
22
+ name: ImageNet 256×256
23
+ type: imagenet-1k
24
+ config: 256x256
25
+ split: validation
26
+ metrics:
27
+ - type: fid
28
+ name: FID (FlexVAR-d16, +SAR)
29
+ value: 2.89
30
+ higher_is_better: false
31
+ - type: fid
32
+ name: FID (FlexVAR-d20, +SAR)
33
+ value: 2.35
34
+ higher_is_better: false
35
+ - type: fid
36
+ name: FID (FlexVAR-d24, +SAR)
37
+ value: 2.14
38
+ higher_is_better: false
39
+ datasets:
40
+ - ILSVRC/imagenet-1k
41
+ base_model:
42
+ - jiaosiyu1999/FlexVAR
43
+ pipeline_tag: text-to-image
44
+ ---
45
+
46
+ <div align="center">
47
+ <h1>Rethinking Training Dynamics in Scale-wise Autoregressive Generation</h1>
48
+
49
+ <a href="https://gengzezhou.github.io/" target="_blank">Gengze Zhou</a><sup>1*</sup>,
50
+ <a href="https://chongjiange.github.io/" target="_blank">Chongjian Ge</a><sup>2</sup>,
51
+ <a href="https://www.cs.unc.edu/~airsplay/" target="_blank">Hao Tan</a><sup>2</sup>,
52
+ <a href="https://pages.cs.wisc.edu/~fliu/" target="_blank">Feng Liu</a><sup>2</sup>,
53
+ <a href="https://yiconghong.me" target="_blank">Yicong Hong</a><sup>2</sup>
54
+
55
+ <sup>1</sup>Australian Institute for Machine Learning, Adelaide University &nbsp;&nbsp;&nbsp;
56
+ <sup>2</sup>Adobe Research
57
+
58
+ [![arXiv](https://img.shields.io/badge/arXiv-2512.06421-b31b1b.svg)](https://arxiv.org/abs/2512.06421)&nbsp;
59
+ [![huggingface weights](https://img.shields.io/badge/%F0%9F%A4%97%20Weights-SAR--ckpts-yellow)](https://huggingface.co/ZGZzz/SAR)&nbsp;
60
+ [![project page](https://img.shields.io/badge/Project%20Page-SAR-blue)](https://gengzezhou.github.io/SAR)&nbsp;
61
+ [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
62
+
63
+ </div>
64
+
65
+ ## Model Description
66
+
67
+ **Self-Autoregressive Refinement (SAR)** is a lightweight *post-training* algorithm for **scale-wise autoregressive (AR)** image generation (next-scale prediction). SAR mitigates **exposure bias** by addressing (1) train–test mismatch (teacher forcing vs. student forcing) and (2) imbalance in scale-wise learning difficulty.
68
+
69
+ SAR consists of:
70
+ - **Stagger-Scale Rollout (SSR):** a two-step rollout (teacher-forcing → student-forcing) with minimal compute overhead (one extra forward pass).
71
+ - **Contrastive Student-Forcing Loss (CSFL):** stabilizes student-forced training by aligning predictions with a teacher trajectory under self-generated contexts.
72
+
73
+ ## Key Features
74
+
75
+ - **Minimal overhead:** SSR adds only a lightweight additional forward pass to train on self-generated content.
76
+ - **General post-training recipe:** applies on top of pretrained scale-wise AR models.
77
+ - **Empirical gains:** e.g., reported **5.2% FID reduction** on FlexVAR-d16 with 10 SAR epochs.
78
+
79
+ ## Model Zoo (ImageNet 256×256)
80
+
81
+ | Model | Params | Base FID ↓ | SAR FID ↓ | SAR Weights |
82
+ |---|---:|---:|---:|---|
83
+ | SAR-d16 | 310M | 3.05 | **2.89** | `pretrained/SARd16-epo179.pth` |
84
+ | SAR-d20 | 600M | 2.41 | **2.35** | `pretrained/SARd20-epo249.pth` |
85
+ | SAR-d24 | 1.0B | 2.21 | **2.14** | `pretrained/SARd24-epo349.pth` |
86
+
87
+ ## How to Use
88
+
89
+ ### Installation
90
+
91
+ ```bash
92
+ git clone https://github.com/GengzeZhou/SAR.git
93
+ conda create -n sar python=3.10 -y
94
+ conda activate sar
95
+ pip install -r requirements.txt
96
+ # optional
97
+ pip install flash-attn xformers
98
+ ```
99
+
100
+ ### Sampling / Inference (Example)
101
+
102
+ ```python
103
+ import torch
104
+ from models import build_vae_var
105
+ from torchvision.utils import save_image
106
+
107
+ device = "cuda" if torch.cuda.is_available() else "cpu"
108
+
109
+ # Build VAE + VAR backbone (example: depth=16)
110
+ vae, model = build_vae_var(
111
+ V=8912, Cvae=32, device=device,
112
+ num_classes=1000, depth=16,
113
+ vae_ckpt="pretrained/FlexVAE.pth",
114
+ )
115
+
116
+ # Load SAR checkpoint
117
+ ckpt = torch.load("pretrained/SARd16-epo179.pth", map_location="cpu")
118
+ if "trainer" in ckpt:
119
+ ckpt = ckpt["trainer"]["var_wo_ddp"]
120
+ model.load_state_dict(ckpt, strict=False)
121
+ model.eval()
122
+
123
+ with torch.no_grad():
124
+ labels = torch.tensor([207, 88, 360, 387], device=device) # example ImageNet classes
125
+ images = model.autoregressive_infer_cfg(
126
+ vqvae=vae,
127
+ B=4,
128
+ label_B=labels,
129
+ cfg=2.5,
130
+ top_k=900,
131
+ top_p=0.95,
132
+ )
133
+
134
+ save_image(images, "samples.png", normalize=True, value_range=(-1, 1), nrow=4)
135
+ ```
136
+
137
+ ## Training (SAR Post-Training)
138
+
139
+ ```bash
140
+ bash scripts/train_SAR_d16.sh
141
+ bash scripts/train_SAR_d20.sh
142
+ bash scripts/train_SAR_d24.sh
143
+ ```
144
+
145
+ ## Evaluation
146
+
147
+ ```bash
148
+ bash scripts/setup_eval.sh
149
+ bash scripts/eval_SAR_d16.sh
150
+ bash scripts/eval_SAR_d20.sh
151
+ bash scripts/eval_SAR_d24.sh
152
+ ```
153
+
154
+ ## Acknowledgements
155
+
156
+ This codebase builds upon **VAR** and **FlexVAR**.
157
+
158
+ ## Citation
159
+
160
+ ```bibtex
161
+ @article{zhou2025rethinking,
162
+ title={Rethinking Training Dynamics in Scale-wise Autoregressive Generation},
163
+ author={Zhou, Gengze and Ge, Chongjian and Tan, Hao and Liu, Feng and Hong, Yicong},
164
+ journal={arXiv preprint arXiv:2512.06421},
165
+ year={2025}
166
+ }
167
+ ```