Hybrid Linear Diffusion Transformer
Model Description
Hybrid Linear Diffusion Transformer is a text-to-image diffusion model trained from scratch in DC-AE latent space with a flow-matching objective. The final model is a 148.5M-parameter custom hybrid Diffusion Transformer that uses linear attention in most blocks, a few full-attention anchor blocks, and T5-base text conditioning.
Architecture
- Parameter count: 148.5M
- Training resolution: 512 x 512
- Latent space: DC-AE f32c32 latents with shape 32 x 16 x 16 for 512 x 512 images
- Text encoder: frozen google-t5/t5-base
- Image codec: frozen mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
- Transformer:
  - width/depth: dim=768, heads=12, depth=12
  - token-funnel image+text processing
  - per-block QKV
  - per-block adaLN
  - mostly linear attention
  - full-attention anchor blocks with image-to-text cross-attention
  - Mix-FFN with depthwise convolution on image tokens
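The latent shape listed above follows from the codec name: in DC-AE naming, f32c32 means a 32x spatial downsampling factor and 32 latent channels. A minimal sketch of that arithmetic is shown here; the helper name `latent_shape` is illustrative and not part of the repository.

```python
def latent_shape(height, width, downsample=32, channels=32):
    """Return (channels, latent_h, latent_w) for an f32c32-style autoencoder."""
    if height % downsample or width % downsample:
        raise ValueError("image size must be divisible by the downsampling factor")
    return channels, height // downsample, width // downsample


c, h, w = latent_shape(512, 512)
print((c, h, w))  # (32, 16, 16)
print(h * w)      # 256 spatial positions per image
```

If each latent position becomes one image token (an assumption about the patch size, which this card does not state), the transformer sees 256 image tokens per 512 x 512 image.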
Credits And Inspirations
This model is a custom research implementation that combines ideas from several prior systems rather than reproducing any one architecture exactly.
- Used directly
  - google-t5/t5-base as the frozen text encoder
  - mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers as the frozen latent autoencoder
- Inspired by
  - latent DiT-style transformer diffusion models
  - Sana-style efficient latent image modeling
  - flow-matching / rectified-flow training
  - classifier-free guidance sampling
What Changed In This Project
- kept a hybrid token-funnel design instead of a strict dual-stream design
- used mostly linear attention with a few full-attention anchor blocks
- added explicit cross-attention only in the full-attention blocks
- used per-block QKV and per-block adaLN heads
- built a custom latent-cache WebDataset training path for faster large-scale iteration
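For readers unfamiliar with linear attention: it replaces the softmax over all token pairs with a kernel feature map, so the keys and values can be summarized once and reused for every query. The sketch below is a dependency-free illustration with a ReLU feature map, as used in Sana-style linear attention. Whether this project uses the same feature map, normalization, or head layout is an assumption; the function names are illustrative.

```python
def relu(vec):
    return [max(0.0, x) for x in vec]


def linear_attention(q, k, v, eps=1e-6):
    """q, k: N vectors of size d; v: N vectors of size dv. Cost is O(N * d * dv)."""
    d, dv = len(k[0]), len(v[0])
    kv = [[0.0] * dv for _ in range(d)]  # running sum of phi(k_j) v_j^T
    k_sum = [0.0] * d                    # running sum of phi(k_j)
    for k_j, v_j in zip(k, v):
        phi_k = relu(k_j)
        for a in range(d):
            k_sum[a] += phi_k[a]
            for b in range(dv):
                kv[a][b] += phi_k[a] * v_j[b]
    out = []
    for q_i in q:
        phi_q = relu(q_i)
        denom = sum(phi_q[a] * k_sum[a] for a in range(d)) + eps
        out.append([sum(phi_q[a] * kv[a][b] for a in range(d)) / denom
                    for b in range(dv)])
    return out


def quadratic_reference(q, k, v, eps=1e-6):
    """Same result computed pairwise, O(N^2), for comparison only."""
    out = []
    for q_i in q:
        phi_q = relu(q_i)
        weights = [sum(a * b for a, b in zip(phi_q, relu(k_j))) for k_j in k]
        denom = sum(weights) + eps
        out.append([sum(w * v_j[b] for w, v_j in zip(weights, v)) / denom
                    for b in range(len(v[0]))])
    return out


q = [[1.0, -0.5], [0.2, 0.3]]
k = [[0.5, 1.0], [-1.0, 2.0], [0.3, 0.0]]
v = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0], [3.0, 3.0, 3.0]]
fast = linear_attention(q, k, v)
slow = quadratic_reference(q, k, v)
```

Both functions produce the same values; the linear form never materializes the N x N attention matrix. That is presumably why a design like this keeps full attention for only a few anchor blocks.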
Training Procedure
- Objective: flow matching
- Timesteps: logit-normal timestep sampling
- Training data path: latent-cache WebDataset shards
- Loader backend: raw webdataset
- Best long-run config:
  - batch size: 256
  - FP8 with torchao
  - torch.compile enabled
  - EMA enabled
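Logit-normal timestep sampling draws a Gaussian value and passes it through a sigmoid, which concentrates training on intermediate noise levels instead of sampling t uniformly. A minimal sketch follows; the location and scale values (0.0 and 1.0) are assumptions, since this card does not state the ones used in training.

```python
import math
import random


def sample_logit_normal(n, mean=0.0, std=1.0, rng=None):
    """Draw n timesteps in (0, 1): t = sigmoid(z), z ~ Normal(mean, std)."""
    rng = rng or random.Random()
    return [1.0 / (1.0 + math.exp(-rng.gauss(mean, std))) for _ in range(n)]


timesteps = sample_logit_normal(8, rng=random.Random(0))
print([round(t, 3) for t in timesteps])
```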
In the final training path:
- images were represented in frozen DC-AE latent space
- prompts were encoded with frozen T5-base
- noisy inputs were formed with continuous flow-matching interpolation
- the model predicted latent velocity
- checkpoints were evaluated with both validation loss and visual comparison
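The interpolation and velocity target in the list above can be sketched as follows. The convention used here (t = 0 is clean data, t = 1 is pure noise, target velocity = noise - data) is the common rectified-flow choice and is an assumption about this project; the training code may use the opposite direction.

```python
import random


def flow_matching_pair(x0, noise, t):
    """Return the noisy input x_t and the velocity target for one sample.

    x_t = (1 - t) * x0 + t * noise, velocity = noise - x0 (assumed convention).
    """
    x_t = [(1.0 - t) * a + t * e for a, e in zip(x0, noise)]
    velocity = [e - a for a, e in zip(x0, noise)]
    return x_t, velocity


def mse(pred, target):
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)


rng = random.Random(0)
x0 = [0.5, -1.0, 2.0]                   # stand-in for a flattened DC-AE latent
noise = [rng.gauss(0.0, 1.0) for _ in x0]
x_t, v_target = flow_matching_pair(x0, noise, t=0.3)

v_pred = [0.0, 0.0, 0.0]                # stand-in for the transformer output
loss = mse(v_pred, v_target)
```

Under this convention the clean latent can be recovered from any point on the path as x_t - t * velocity, which is why predicting velocity is sufficient for sampling.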
The final published checkpoint was chosen visually from the training run rather than by loss alone.
Best Checkpoint
The best checkpoint for this project was selected visually rather than purely by validation loss:
- checkpoint step: 28000
- weight source: EMA
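EMA weights are a running average of the training weights, which usually gives smoother samples than the raw weights at the same step. A minimal sketch of the update rule follows; the decay of 0.999 is an assumed value, and plain floats stand in for tensors.

```python
def ema_update(ema_params, model_params, decay=0.999):
    """In-place exponential moving average: ema = decay * ema + (1 - decay) * model."""
    for name, value in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


ema = {"w": 1.0}
for step_weight in (0.0, 0.0, 0.0):   # pretend the live weight jumped to 0.0
    ema_update(ema, {"w": step_weight}, decay=0.9)
print(round(ema["w"], 3))             # 0.729
```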
Evaluation
The final checkpoint was evaluated on 512 held-out raw images from ma-xu/fine-t2i.
Results:
- CLIP score (generated): 33.85 ± 3.70
- CLIP score (real images): 34.25 ± 3.63
- Raw-image FID: 91.46
- Raw-image KID: 0.00321 ± 0.00292
Interpretation:
- prompt alignment is strong, since generated-image CLIP score is close to the real-image CLIP score on the same held-out subset
- distribution quality still has room to improve, especially on local detail and faces
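To put "close" in numbers, the gap between the two CLIP scores can be expressed in units of the reported spread. This sketch assumes the ± values are per-sample standard deviations, which the card does not state explicitly.

```python
gen_mean, gen_std = 33.85, 3.70     # generated images
real_mean, real_std = 34.25, 3.63   # real held-out images

gap = real_mean - gen_mean
gap_in_std = gap / real_std
print(f"gap = {gap:.2f} CLIP points ({gap_in_std:.2f} real-image std)")
# gap = 0.40 CLIP points (0.11 real-image std)
```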
Caveat:
FID/KID here should be read as project-scale evaluation metrics rather than paper-scale benchmark claims. They were computed on a relatively small held-out subset and also reflect the full latent autoencoder bottleneck in the final decoded image quality.
Intended Use
- research experiments in latent diffusion / DiT training
- portfolio/demo image generation
- studying custom training stacks with latent caches and hybrid attention
Limitations
- Faces remain softer and blurrier than scene/object composition
- Quality is limited by model capacity, dataset quality, and latent resolution
- This is a research project, not a production-grade photorealistic generator
Training Data
The project used the synthetic_enhanced_prompt_random_resolution subset from the Hugging Face dataset ma-xu/fine-t2i, first as streamed WebDataset shards and then as a latent/text cache for faster training.
The strongest runs used the latent-cache repo:
- source dataset: ma-xu/fine-t2i
- latent cache dataset: akrao9/512t2ilatent
- effective training subset used in the final run: about 1.6 million samples
As with most text-to-image systems, quality depends heavily on dataset diversity, prompt-image alignment, and the realism of the underlying corpus.
Latent Cache Acceleration
The latent-cache dataset was created to move frozen encoder work out of the training loop.
Instead of repeatedly:
- decoding raw images
- encoding images with DC-AE
- encoding prompts with T5-base
the training loop consumed cached latent tensors and cached text features directly from WebDataset shards.
This improved the project in two ways:
- reduced per-step compute spent on frozen preprocessing
- increased practical throughput and made large-batch training easier
The strongest final runs used this latent-cache path together with the raw webdataset loader backend.
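The saving comes from running each frozen encoder once per sample instead of once per sample per epoch. The toy sketch below makes that count explicit. `encode_image` and `encode_text` are hypothetical stand-ins for the DC-AE encoder and T5-base, and an in-memory list stands in for the WebDataset shards; none of these names come from the repository.

```python
calls = {"image": 0, "text": 0}


def encode_image(image):
    """Hypothetical stand-in for the frozen DC-AE encoder."""
    calls["image"] += 1
    return [float(p) / 255.0 for p in image]


def encode_text(prompt):
    """Hypothetical stand-in for the frozen T5-base encoder."""
    calls["text"] += 1
    return [float(len(word)) for word in prompt.split()]


def build_cache(samples):
    """Run both frozen encoders once per sample and keep the results."""
    return [{"latent": encode_image(s["image"]), "text": encode_text(s["prompt"])}
            for s in samples]


def train(cache, epochs):
    """Training loop that reads cached features only; returns the step count."""
    steps = 0
    for _ in range(epochs):
        for item in cache:
            _ = item["latent"], item["text"]  # would be fed to the transformer
            steps += 1
    return steps


samples = [
    {"image": [0, 128, 255], "prompt": "a red barn"},
    {"image": [255, 255, 0], "prompt": "glass bridge at dawn"},
]
cache = build_cache(samples)
steps = train(cache, epochs=5)
print(steps, calls)  # 10 {'image': 2, 'text': 2}
```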
Limits And Scaling Outlook
Current strengths:
- coherent scene generation
- prompt understanding
- atmosphere, lighting, and object-level structure
Current weaknesses:
- faces remain softer and blurrier than the rest of the image
- local detail is less reliable than global layout
Interpretation:
- the architecture and training stack are working
- the main bottlenecks appear to be model capacity, data quality, and latent/detail limits rather than a hidden correctness bug
Scaling expectation:
This architecture is likely to benefit from larger models and better data, because the current failure mode is mostly detail fidelity rather than complete structural failure. That is a project inference based on the observed training behavior, not a claim of already-completed large-scale scaling experiments.
How To Run
This repository contains a custom architecture, so it is not loaded through a standard Diffusers pipeline.
Install
pip install -r requirements.txt
Run From Command Line
python inference.py \
  --ckpt modelhybriddit_ema.safetensors \
  --config model_config.json \
  --prompt "a glass bridge in the mountains with a volcano" \
  --steps 20 \
  --cfg 3.5 \
  --sampler heun \
  --out outputs
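The `--cfg` and `--sampler heun` options correspond to two standard techniques: classifier-free guidance and Heun's second-order ODE solver. The sketch below shows both on a toy velocity field. It is not the code in inference.py; the function names, the integration direction (t = 1 noise to t = 0 data), and the toy velocity functions are assumptions for illustration.

```python
def cfg_velocity(v_cond, v_uncond, scale):
    """Classifier-free guidance: move from the unconditional toward the conditional prediction."""
    return [u + scale * (c - u) for c, u in zip(v_cond, v_uncond)]


def heun_sample(velocity_fn, x, steps, t_start=1.0, t_end=0.0):
    """Integrate dx/dt = velocity_fn(x, t) with Heun's method (two evaluations per step)."""
    dt = (t_end - t_start) / steps
    t = t_start
    for _ in range(steps):
        v1 = velocity_fn(x, t)
        x_euler = [xi + dt * vi for xi, vi in zip(x, v1)]   # Euler predictor
        v2 = velocity_fn(x_euler, t + dt)
        x = [xi + 0.5 * dt * (a + b) for xi, a, b in zip(x, v1, v2)]  # trapezoid corrector
        t += dt
    return x


# Toy stand-ins for the transformer's conditional and unconditional outputs.
def toy_cond(x, t):
    return [1.0 for _ in x]


def toy_uncond(x, t):
    return [0.0 for _ in x]


def guided(x, t, scale=3.5):
    return cfg_velocity(toy_cond(x, t), toy_uncond(x, t), scale)


result = heun_sample(guided, x=[0.0, 0.0], steps=20)
```

With a constant guided velocity of 3.5 integrated from t = 1 to t = 0, each component ends near -3.5. In a real sampler each Heun step costs two velocity evaluations, and each evaluation needs both a conditional and an unconditional prediction when guidance is on.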
Load In Python
import json

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from model import DiT  # custom architecture, not a Diffusers class

repo_id = "akrao9/HybridDiT"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Download the EMA weights and the model config from the Hub.
weights_path = hf_hub_download(repo_id=repo_id, filename="modelhybriddit_ema.safetensors")
config_path = hf_hub_download(repo_id=repo_id, filename="model_config.json")

with open(config_path, "r", encoding="utf-8") as f:
    cfg = json.load(f)

# Build the transformer from the config values.
dit = DiT(
    latent_ch=cfg["latent_ch"],
    latent_size=cfg["latent_size"],
    text_dim=cfg["text_dim"],
    text_seq=cfg["text_seq"],
    dim=cfg["dit_dim"],
    n_heads=cfg["dit_heads"],
    n_blocks=cfg["dit_depth"],
    text_drop_block=cfg.get("text_drop_block"),
    full_attn_blocks=cfg.get("full_attn_blocks"),
).to(device).eval()

# strict=True raises an error on any missing or unexpected key.
state_dict = load_file(weights_path)
dit.load_state_dict(state_dict, strict=True)
print("Model loaded successfully")
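As a sanity check after loading, the parameter count can be compared with the 148.5M figure. The helper below counts elements from a mapping of tensor names to shapes, so it runs without the model; `count_parameters` is an illustrative name, not a repository function.

```python
import math


def count_parameters(shapes):
    """Sum the element counts of all tensors, given a name -> shape mapping."""
    return sum(math.prod(shape) for shape in shapes.values())


example = {"proj.weight": (768, 768), "proj.bias": (768,)}
print(count_parameters(example))  # 590592
```

With the loaded checkpoint, `count_parameters({k: tuple(v.shape) for k, v in state_dict.items()})` should land near 148.5M, assuming the file stores only the transformer's parameters; any stored buffers would raise the total slightly.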
Notes
- Best checkpoint: step_028000 EMA
- Model size: 148.5M parameters
- Training resolution: 512x512
- Effective training set used: about 1.6 million samples
- Source dataset: ma-xu/fine-t2i
- Latent cache dataset: akrao9/512t2ilatent
- This is a custom model, not a native Diffusers pipeline
Acknowledgements
- google-t5/t5-base
- mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
- torchao
- diffusers
- webdataset

