Hybrid Linear Diffusion Transformer

Model Description

Hybrid Linear Diffusion Transformer is a text-to-image diffusion model trained from scratch with flow matching in DC-AE latent space. The final model is a 148.5M-parameter custom hybrid Diffusion Transformer (DiT) that uses mostly linear attention, a few full-attention anchor blocks, and T5-base text conditioning.

Architecture

  • Parameter count: 148.5M
  • Training resolution: 512 x 512
  • Latent space: DC-AE f32c32 latents with shape 32 x 16 x 16 for 512 x 512 images
  • Text encoder: frozen google-t5/t5-base
  • Image codec: frozen mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
  • Transformer:
    • width/depth: dim=768, heads=12, depth=12
    • token-funnel image+text processing
    • per-block QKV
    • per-block adaLN
    • mostly linear attention
    • full-attention anchor blocks with image-to-text cross-attention
    • Mix-FFN with depthwise convolution on image tokens
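
Most blocks use linear attention, so their cost grows linearly with the number of tokens rather than quadratically. The sketch below assumes a ReLU feature map, the kernel popularized by Sana; the exact kernel, normalization and epsilon used in this model are not documented here. It computes one head both ways to show that reordering the matrix products gives the same result. The 256-token example assumes one token per position of the 16 x 16 DC-AE latent grid, which is also an assumption. The head size of 64 follows from dim=768 and heads=12.

```python
import numpy as np


def relu_linear_attention(q, k, v, eps=1e-6):
    """Linear attention with a ReLU feature map (Sana-style kernel; assumed here).

    q, k: (num_tokens, head_dim); v: (num_tokens, value_dim).
    Cost is O(N * d * d_v), i.e. linear in the number of tokens N.
    """
    phi_q, phi_k = np.maximum(q, 0.0), np.maximum(k, 0.0)  # feature map phi = ReLU
    kv = phi_k.T @ v            # (head_dim, value_dim): summary of all keys/values
    k_sum = phi_k.sum(axis=0)   # (head_dim,): normalizer statistics
    return (phi_q @ kv) / (phi_q @ k_sum + eps)[:, None]


def quadratic_reference(q, k, v, eps=1e-6):
    """The same kernel evaluated the O(N^2) way, to check the reordering."""
    scores = np.maximum(q, 0.0) @ np.maximum(k, 0.0).T  # (N, N) similarity matrix
    return (scores @ v) / (scores.sum(axis=1, keepdims=True) + eps)


# One head of a dim=768, heads=12 block: head_dim = 768 / 12 = 64.
# 256 tokens assumes one token per position of the 16 x 16 latent grid.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((256, 64)) for _ in range(3))
out = relu_linear_attention(q, k, v)
```

The full-attention anchor blocks keep the explicit N x N score matrix. That lets them model sharp token-to-token interactions, which a low-rank kernel summary can blur.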

[Figure: sample image grid]

[Figure: model architecture diagram]

Credits And Inspirations

This model is a custom research implementation that combines ideas from several prior systems rather than reproducing any one architecture exactly.

  • Used directly
    • google-t5/t5-base as the frozen text encoder
    • mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers as the frozen latent autoencoder
  • Inspired by
    • latent DiT-style transformer diffusion models
    • Sana-style efficient latent image modeling
    • flow-matching / rectified-flow training
    • classifier-free guidance sampling

What Changed In This Project

  • kept a hybrid token-funnel design instead of a strict dual-stream design
  • used mostly linear attention with a few full-attention anchor blocks
  • added explicit cross-attention only in the full-attention blocks
  • used per-block QKV and per-block adaLN heads
  • built a custom latent-cache WebDataset training path for faster large-scale iteration
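
"Per-block adaLN" means each transformer block owns its own modulation head. The alternative is sharing one modulation network across all blocks. The sketch below follows the DiT adaLN-Zero pattern as an assumption: the conditioning vector passes through SiLU and a linear layer to produce shift, scale and gate, and zero initialization makes every block start as an identity map. The class name, the choice of conditioning vector and cond_dim=768 are hypothetical.

```python
import numpy as np


def silu(x):
    """SiLU activation, applied to the conditioning vector as in DiT."""
    return x / (1.0 + np.exp(-x))


def layer_norm(x, eps=1e-6):
    """LayerNorm over the last axis without learned affine parameters."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


class AdaLNModulation:
    """Hypothetical per-block adaLN head (DiT adaLN-Zero style).

    Maps a conditioning vector to shift, scale and gate for one sublayer.
    Weights start at zero, so the gated residual makes the block an identity.
    """

    def __init__(self, cond_dim, dim):
        self.weight = np.zeros((cond_dim, 3 * dim))
        self.bias = np.zeros(3 * dim)

    def __call__(self, x, cond, sublayer):
        # x: (num_tokens, dim) tokens; cond: (cond_dim,) conditioning vector.
        shift, scale, gate = np.split(silu(cond) @ self.weight + self.bias, 3)
        h = layer_norm(x) * (1.0 + scale) + shift  # modulated normalization
        return x + gate * sublayer(h)              # gated residual update


# Twelve blocks (depth=12), each with its own modulation head.
rng = np.random.default_rng(0)
blocks = [AdaLNModulation(cond_dim=768, dim=768) for _ in range(12)]
tokens = rng.standard_normal((256, 768))
cond = rng.standard_normal(768)  # e.g. a timestep embedding (assumed)
out = tokens
for block in blocks:
    out = block(out, cond, sublayer=np.tanh)  # tanh stands in for attention/FFN
```

Giving every block its own head costs more parameters than a shared head. In exchange, each depth can respond differently to the timestep and text conditioning.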

Training Procedure

  • Objective: flow matching
  • Timesteps: logit-normal timestep sampling
  • Training data path: latent-cache WebDataset shards
  • Loader backend: raw webdataset
  • Best long-run config:
    • batch size: 256
    • FP8 with torchao
    • torch.compile enabled
    • EMA enabled
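
Logit-normal sampling draws a Gaussian value and squashes it through a sigmoid. Compared with uniform sampling, this puts more training steps at intermediate noise levels. The sketch below uses mean 0 and standard deviation 1 as assumed defaults, since the values used in these runs are not stated.

```python
import math
import random


def sample_logit_normal(n, mean=0.0, std=1.0, rng=None):
    """Draw n timesteps t = sigmoid(u) with u ~ N(mean, std^2).

    mean/std are assumed defaults, not the documented training values.
    """
    rng = rng or random.Random()
    return [1.0 / (1.0 + math.exp(-rng.gauss(mean, std))) for _ in range(n)]


rng = random.Random(0)
timesteps = sample_logit_normal(10_000, rng=rng)
# Fraction of timesteps in the middle half of [0, 1].
mid_fraction = sum(0.25 <= t <= 0.75 for t in timesteps) / len(timesteps)
```

With these defaults, roughly 73% of draws land in [0.25, 0.75], compared with 50% under uniform sampling.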

In the final training path:

  1. images were represented in frozen DC-AE latent space
  2. prompts were encoded with frozen T5-base
  3. noisy inputs were formed with continuous flow-matching interpolation
  4. the model predicted latent velocity
  5. checkpoints were evaluated with both validation loss and visual comparison
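
Steps 3 and 4 can be sketched as a single training target. The interpolation direction is assumed: t=0 is clean data, t=1 is pure noise, and the target velocity is noise minus data, as in common rectified-flow code. Some codebases flip the direction. The function names are illustrative, not taken from this repository.

```python
import numpy as np


def flow_matching_pair(x0, noise, t):
    """Build a noisy input and its velocity target (rectified-flow convention).

    Assumes t=0 is clean data and t=1 is pure noise; some codebases flip this.
    """
    x_t = (1.0 - t) * x0 + t * noise   # straight-line interpolation
    v_target = noise - x0              # d(x_t)/dt, constant along the path
    return x_t, v_target


def flow_matching_loss(v_pred, v_target):
    """Mean squared error between predicted and target latent velocity."""
    return float(np.mean((v_pred - v_target) ** 2))


rng = np.random.default_rng(0)
x0 = rng.standard_normal((32, 16, 16))     # stand-in for one cached DC-AE latent
noise = rng.standard_normal(x0.shape)      # Gaussian noise of the same shape
x_t, v_target = flow_matching_pair(x0, noise, t=0.3)
baseline = flow_matching_loss(np.zeros_like(v_target), v_target)  # zero predictor
```

Because the path is a straight line, the target velocity does not depend on t. This is one reason few-step ODE samplers work reasonably well on flow-matching models.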

The final published checkpoint was chosen visually from the training run rather than by loss alone.

Best Checkpoint

The published weights are:

  • checkpoint step: 28000
  • weight source: EMA
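
The EMA weights are a slowly moving average of the training weights, updated after each optimizer step. A minimal sketch follows. The decay value is an assumption, since the decay used in this run is not stated.

```python
def ema_update(ema_params, model_params, decay=0.999):
    """Move each EMA weight a small step toward the live model weight.

    ema <- decay * ema + (1 - decay) * model. Works for floats or NumPy arrays.
    decay=0.999 is an assumed default, not the value used in training.
    """
    for name, value in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


# Toy run: the live weight jumps to 1.0 and the EMA copy follows gradually.
ema = {"w": 0.0}
for _ in range(100):
    ema_update(ema, {"w": 1.0}, decay=0.9)
```

In PyTorch the same update is commonly written in place as `ema_p.lerp_(p, 1 - decay)` over paired parameters.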

Evaluation

The final checkpoint was evaluated on 512 held-out raw images from ma-xu/fine-t2i.

Results:

  • CLIP score (generated): 33.85 ± 3.70
  • CLIP score (real images): 34.25 ± 3.63
  • Raw-image FID: 91.46
  • Raw-image KID: 0.00321 ± 0.00292

Interpretation:

  • prompt alignment is strong: the generated-image CLIP score (33.85) is within 0.4 points of the real-image CLIP score (34.25) on the same held-out subset
  • distribution quality still has room to improve, especially on local detail and faces
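
The comparison above rests on how a CLIP score is defined. A common convention, used by torchmetrics' CLIPScore, is 100 times the clipped cosine similarity between paired image and text embeddings. Whether this project used that exact convention and CLIP backbone is an assumption. The sketch takes precomputed embeddings and uses toy 2-D vectors in place of real CLIP features.

```python
import numpy as np


def clip_score(image_emb, text_emb):
    """Per-pair CLIP score: 100 * max(cosine(image, text), 0).

    Rows are paired image/text embeddings from the same CLIP model.
    """
    img = image_emb / np.linalg.norm(image_emb, axis=-1, keepdims=True)
    txt = text_emb / np.linalg.norm(text_emb, axis=-1, keepdims=True)
    return 100.0 * np.maximum(np.sum(img * txt, axis=-1), 0.0)


def mean_std(scores):
    """Summarize as mean and population standard deviation (ddof=0, assumed)."""
    return float(np.mean(scores)), float(np.std(scores))


image_emb = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
text_emb = np.array([[2.0, 0.0], [1.0, 0.0], [-1.0, 0.0]])
scores = clip_score(image_emb, text_emb)  # aligned, 45 degrees, orthogonal
```

Running the same function on real images with their own captions gives the reference score. That is why the real-image number serves as an approximate ceiling for the generated-image number.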

Caveat:

FID/KID here should be read as project-scale evaluation metrics, not paper-scale benchmark claims. They were computed on a relatively small held-out subset of 512 images, and the decoded images also inherit the reconstruction limits of the frozen DC-AE autoencoder.
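
The standard definitions behind the two distribution metrics are given below for reference. In them, μ and Σ are the mean and covariance of Inception features for real (r) and generated (g) images. The x_i are the m real feature vectors, the y_j are the n generated ones, and d is the feature dimension. FID plugs sample estimates into a closed form, so it is biased upward when only a few hundred images are available. KID uses an unbiased estimator of the squared MMD with a cubic polynomial kernel, and it is usually averaged over random subsets, which is where a ± spread comes from. The exact feature extractor and subset settings used for this card are not stated.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right)

k(x, y) = \left(\tfrac{1}{d}\, x^\top y + 1\right)^3

\mathrm{KID} = \frac{1}{m(m-1)} \sum_{i \ne j} k(x_i, x_j)
  + \frac{1}{n(n-1)} \sum_{i \ne j} k(y_i, y_j)
  - \frac{2}{mn} \sum_{i=1}^{m} \sum_{j=1}^{n} k(x_i, y_j)
```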

Intended Use

  • research experiments in latent diffusion / DiT training
  • portfolio/demo image generation
  • studying custom training stacks with latent caches and hybrid attention

Limitations

  • Faces remain softer and blurrier than scene/object composition
  • Quality is limited by model capacity, dataset quality, and latent resolution
  • This is a research project, not a production-grade photorealistic generator

Training Data

The project used the synthetic_enhanced_prompt_random_resolution subset from the Hugging Face dataset ma-xu/fine-t2i, first as streamed WebDataset shards and then as a latent/text cache for faster training.

The strongest runs trained from the latent-cache dataset:

  • source dataset: ma-xu/fine-t2i
  • latent-cache dataset: akrao9/512t2ilatent
  • effective training subset used in the final run: about 1.6 million samples

As with most text-to-image systems, quality depends heavily on dataset diversity, prompt-image alignment, and the realism of the underlying corpus.

Latent Cache Acceleration

The latent-cache dataset was created to move frozen encoder work out of the training loop.

Instead of repeatedly:

  • decoding raw images
  • encoding images with DC-AE
  • encoding prompts with T5-base

the training loop consumed cached latent tensors and cached text features directly from WebDataset shards.

This improved the project in two ways:

  • reduced per-step compute spent on frozen preprocessing
  • increased practical throughput and made large-batch training easier

The strongest final runs used this latent-cache path together with the raw webdataset loader backend.
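
WebDataset shards are plain tar files in which all files sharing a key (the name before the first dot) form one sample. The sketch below writes and reads such a shard using only tarfile and NumPy, storing a cached latent, cached T5 features and the prompt per sample. The field names (latent.npy, t5.npy, txt), the float16 storage and TEXT_SEQ are hypothetical; the actual layout of akrao9/512t2ilatent may differ. The 768 feature width matches T5-base's hidden size.

```python
import io
import os
import tarfile
import tempfile
from collections import defaultdict

import numpy as np


def _to_bytes(field, value):
    """Serialize one field: .npy fields via np.save, everything else as UTF-8 text."""
    if field.endswith(".npy"):
        buf = io.BytesIO()
        np.save(buf, value)
        return buf.getvalue()
    return value.encode("utf-8")


def _from_bytes(field, raw):
    """Inverse of _to_bytes."""
    return np.load(io.BytesIO(raw)) if field.endswith(".npy") else raw.decode("utf-8")


def write_shard(path, samples):
    """Write {key: {field: value}} as a tar where each file is named key.field."""
    with tarfile.open(path, "w") as tar:
        for key, sample in samples.items():
            for field, value in sample.items():
                data = _to_bytes(field, value)
                info = tarfile.TarInfo(name=f"{key}.{field}")
                info.size = len(data)
                tar.addfile(info, io.BytesIO(data))


def read_shard(path):
    """Group tar members back into samples by key (text before the first '.')."""
    samples = defaultdict(dict)
    with tarfile.open(path, "r") as tar:
        for member in tar.getmembers():
            key, field = member.name.split(".", 1)
            samples[key][field] = _from_bytes(field, tar.extractfile(member).read())
    return dict(samples)


TEXT_SEQ = 64  # hypothetical cached T5 sequence length
rng = np.random.default_rng(0)
samples = {
    f"{i:06d}": {
        "latent.npy": rng.standard_normal((32, 16, 16)).astype(np.float16),  # DC-AE latent
        "t5.npy": rng.standard_normal((TEXT_SEQ, 768)).astype(np.float16),   # T5-base states
        "txt": "a glass bridge in the mountains with a volcano",
    }
    for i in range(4)
}

with tempfile.TemporaryDirectory() as tmp:
    shard_path = os.path.join(tmp, "shard-000000.tar")
    write_shard(shard_path, samples)
    restored = read_shard(shard_path)
```

Since the training loop only needs to read these arrays, neither DC-AE nor T5 has to run during training. The `webdataset` package streams shards laid out this way, grouping files by their shared key.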

Limits And Scaling Outlook

Current strengths:

  • coherent scene generation
  • prompt understanding
  • atmosphere, lighting, and object-level structure

Current weaknesses:

  • faces remain softer and blurrier than the rest of the image
  • local detail is less reliable than global layout

Interpretation:

  • the architecture and training stack are working
  • the main bottlenecks appear to be model capacity, data quality, and latent/detail limits rather than a hidden correctness bug

Scaling expectation:

This architecture is likely to benefit from larger models and better data, because the current failure mode is mostly detail fidelity rather than complete structural failure. That is a project inference based on the observed training behavior, not a claim of already-completed large-scale scaling experiments.

How To Run

This repository contains a custom architecture, so it is not loaded through a standard Diffusers pipeline.

Install

pip install -r requirements.txt

Run From Command Line

python inference.py \
  --ckpt modelhybriddit_ema.safetensors \
  --config model_config.json \
  --prompt "a glass bridge in the mountains with a volcano" \
  --steps 20 \
  --cfg 3.5 \
  --sampler heun \
  --out outputs
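
The --steps, --cfg and --sampler flags suggest an ODE sampler with classifier-free guidance. The sketch below is a minimal version under stated assumptions: it integrates the predicted velocity from t=1 (noise) to t=0 (data) on a uniform grid, applies Heun's predictor-corrector update, and finishes with a plain Euler step. The schedule and guidance details in inference.py may differ. toy_model is a hypothetical stand-in for the DiT whose exact flow ends at 1.0 when conditioned and at 0.0 when not.

```python
import numpy as np


def guided_velocity(model, x, t, cfg):
    """Classifier-free guidance: push the unconditional velocity toward the conditional one."""
    v_uncond = model(x, t, cond=False)
    v_cond = model(x, t, cond=True)
    return v_uncond + cfg * (v_cond - v_uncond)


def heun_sample(model, noise, steps=20, cfg=3.5):
    """Integrate dx/dt = v from t=1 (noise) to t=0 (data) with Heun's method.

    Assumes a uniform time grid; the schedule in inference.py may differ.
    """
    ts = np.linspace(1.0, 0.0, steps + 1)
    x = noise
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        dt = t_next - t_cur                        # negative: stepping toward data
        v_cur = guided_velocity(model, x, t_cur, cfg)
        x_pred = x + dt * v_cur                    # Euler predictor
        if t_next == 0.0:
            x = x_pred                             # final step: keep the Euler result
        else:
            v_next = guided_velocity(model, x_pred, t_next, cfg)
            x = x + dt * 0.5 * (v_cur + v_next)    # trapezoidal corrector
    return x


def toy_model(x, t, cond):
    """Toy velocity field whose flow ends at 1.0 (conditional) or 0.0 (unconditional)."""
    target = 1.0 if cond else 0.0
    return (x - target) / t


rng = np.random.default_rng(0)
sample = heun_sample(toy_model, rng.standard_normal((32, 16, 16)), steps=20, cfg=3.5)
```

With cfg=3.5 the toy sample lands at 3.5 rather than at the conditional target of 1.0. This shows that a guidance scale above 1 extrapolates past the conditional prediction. Real implementations typically batch the conditional and unconditional passes into one forward call.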

Load In Python

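import json

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# model.py ships with this repository (it is not a pip package);
# run this from a local copy of the repo so the import resolves.
from model import DiT

repo_id = "akrao9/HybridDiT"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fetch the EMA weights and the architecture config from the Hub.
weights_path = hf_hub_download(repo_id=repo_id, filename="modelhybriddit_ema.safetensors")
config_path = hf_hub_download(repo_id=repo_id, filename="model_config.json")

with open(config_path, "r", encoding="utf-8") as f:
    cfg = json.load(f)

# Rebuild the network from the saved config; optional keys fall back to None if absent.
dit = DiT(
    latent_ch=cfg["latent_ch"],
    latent_size=cfg["latent_size"],
    text_dim=cfg["text_dim"],
    text_seq=cfg["text_seq"],
    dim=cfg["dit_dim"],
    n_heads=cfg["dit_heads"],
    n_blocks=cfg["dit_depth"],
    text_drop_block=cfg.get("text_drop_block"),
    full_attn_blocks=cfg.get("full_attn_blocks"),
).to(device).eval()

# strict=True raises an error if the config and the checkpoint disagree.
state_dict = load_file(weights_path, device=device)
dit.load_state_dict(state_dict, strict=True)

n_params = sum(p.numel() for p in dit.parameters())
print(f"Model loaded successfully ({n_params / 1e6:.1f}M parameters)")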

Notes

  • Best checkpoint: step_028000 EMA
  • Model size: 148.5M parameters
  • Training resolution: 512x512
  • Effective training set used: about 1.6 million samples
  • Source dataset: ma-xu/fine-t2i
  • Latent cache dataset: akrao9/512t2ilatent
  • This is a custom model, not a native Diffusers pipeline

Acknowledgements

  • google-t5/t5-base
  • mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
  • torchao
  • diffusers
  • webdataset