Qwen3.6-27B Full-Stack TopK SAEs

Companion to caiovicentino1/qwen36-27b-sae-papergrade (L11 / L31 / L55, d_sae=65536). Together they cover 14 layers at every-4th spacing from L11→L63, enabling Probe-Gated Adaptive Compute (PGAC) routing experiments and full-stack circuit analysis on a reasoning model.

This is the first publicly released full-stack SAE family for any reasoning-class LLM.

Architecture

  • Backbone: Qwen3.6-27B (hybrid GDN + standard attention, 64 decoder layers)
  • SAE type: TopK SAE (Gao et al. 2024) with AuxK dead-feature loss
  • d_in: 5120 (Qwen3.6-27B residual stream)
  • d_sae: 40960 (8× expansion)
  • k_topk: 128 (Gao et al. interpretability sweet spot)
  • AuxK: k_aux=2560, alpha=1/32
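
A minimal sketch of the corresponding training loss (main TopK reconstruction MSE plus the AuxK dead-feature term of Gao et al. 2024). It assumes `sae` is an instance of the TopKSAE module shown in the Loading section below and `dead_mask` is a boolean (d_sae,) tensor marking latents that have not fired recently; the exact dead-feature criterion used during training is not specified on this card.

import torch
import torch.nn.functional as F

def topk_auxk_loss(sae, x, dead_mask, k_aux=2560, alpha=1/32):
    # main term: MSE between the input activations and the TopK reconstruction
    x_hat, z, _ = sae(x)
    recon_loss = F.mse_loss(x_hat, x)

    # AuxK term: reconstruct the residual error with the top-k_aux *dead* latents
    e = x - x_hat
    pre = (x - sae.b_dec) @ sae.W_enc + sae.b_enc
    pre_dead = pre.masked_fill(~dead_mask, float('-inf'))   # only dead latents compete
    k_aux = min(k_aux, int(dead_mask.sum()))
    if k_aux == 0:
        return recon_loss
    top_v, top_i = pre_dead.topk(k_aux, dim=-1)
    z_aux = torch.zeros_like(pre).scatter_(-1, top_i, F.relu(top_v))
    aux_loss = F.mse_loss(z_aux @ sae.W_dec, e)
    return recon_loss + alpha * aux_loss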

Layers

11 SAEs trained from scratch on 170M tokens each:

L15, L19, L23, L27, L35, L39, L43, L47, L51, L59, L63

Combined with papergrade L11 / L31 / L55 → 14 layers total.
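
For scripting across both repos, the layer → (repo, d_sae) mapping is simply (a convenience sketch built from the tables on this card):

PAPERGRADE = 'caiovicentino1/qwen36-27b-sae-papergrade'   # d_sae = 65536
FULLSTACK  = 'caiovicentino1/qwen36-27b-sae-fullstack'    # d_sae = 40960

SAE_SOURCES = {
    **{L: (PAPERGRADE, 65536) for L in (11, 31, 55)},
    **{L: (FULLSTACK, 40960) for L in (15, 19, 23, 27, 35, 39, 43, 47, 51, 59, 63)},
}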

Validation Report (1M held-out tokens, 2026-05-03)

(ve = fraction of variance explained; L0 = mean active features per token, equal to k)

Layer   Source       d_sae   ve       L0    Alive %
L11     papergrade   65536   0.8274   128    99.60%
L15     fullstack    40960   0.6614   128   100.00%
L19     fullstack    40960   0.6322   128   100.00%
L23     fullstack    40960   0.6067   128   100.00%
L27     fullstack    40960   0.5885   128   100.00%
L31     papergrade   65536   0.6955   128    98.31%
L35     fullstack    40960   0.5890   128   100.00%
L39     fullstack    40960   0.5874   128   100.00%
L43     fullstack    40960   0.6036   128   100.00%
L47     fullstack    40960   0.6108   128   100.00%
L51     fullstack    40960   0.6509   128   100.00%
L55     papergrade   65536   0.7950   128    88.85%
L59     fullstack    40960   0.6301   128   100.00%
L63     fullstack    40960   0.7054   128   100.00%

Fullstack median ve: 0.6322 — comparable to Llama-Scope nano (32k expansion). Fullstack median alive%: 100% — strictly higher than every papergrade layer.

The classic U-shape is visible: early (L11=0.83) and late (L55=0.80, L63=0.71) layers are easier to compress; mid layers (L27-L43, ~0.59-0.61) are intrinsically harder due to compositional semantic density.

Loading

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

L = 23  # any of: 15, 19, 23, 27, 35, 39, 43, 47, 51, 59, 63
weights = load_file(hf_hub_download(
    'caiovicentino1/qwen36-27b-sae-fullstack',
    f'sae_L{L}_latest.safetensors'
))
# weights = {'W_enc': (5120, 40960), 'W_dec': (40960, 5120),
#            'b_enc': (40960,), 'b_dec': (5120,)}

For papergrade L11 / L31 / L55, use the companion repo.

Minimal TopK SAE inference module:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKSAE(nn.Module):
    def __init__(self, d_in=5120, n=40960, k=128):
        super().__init__()
        self.W_enc = nn.Parameter(torch.zeros(d_in, n))
        self.W_dec = nn.Parameter(torch.zeros(n, d_in))
        self.b_enc = nn.Parameter(torch.zeros(n))
        self.b_dec = nn.Parameter(torch.zeros(d_in))
        self.k = k

    def forward(self, x):
        # encode: center on the decoder bias, then project into the latent space
        pre = (x - self.b_dec) @ self.W_enc + self.b_enc
        # TopK sparsity: keep only the k largest pre-activations per position
        top_v, top_i = pre.topk(self.k, dim=-1)
        z = torch.zeros_like(pre).scatter_(-1, top_i, F.relu(top_v))
        # decode back into the residual stream
        return z @ self.W_dec + self.b_dec, z, top_i

# instantiate and copy the downloaded tensors into the module (weights are bf16)
sae = TopKSAE(5120, 40960, 128)
for name, tensor in weights.items():
    getattr(sae, name).data = tensor
sae.eval()
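
A hedged usage sketch: pull the layer-L residual stream out of Qwen3.6-27B with transformers and encode it with the SAE. It assumes hidden_states[L] is the hook point the SAEs were trained on (index 0 = embeddings, index L = output of decoder layer L); verify the indexing against the training config before trusting feature activations. Loading a 27B model requires substantial GPU memory or offloading.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = 'Qwen/Qwen3.6-27B'
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map='auto')  # requires accelerate

inputs = tok('The integral of x^2 is', return_tensors='pt').to(model.device)
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)
    resid = out.hidden_states[L]                  # (1, seq_len, 5120) residual at layer L
    recon, z, top_i = sae.to(resid.device).to(resid.dtype)(resid)

print('top features at final token:', top_i[0, -1][:10].tolist())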

Files

  • sae_L{layer}_latest.safetensors — weights (W_enc, W_dec, b_enc, b_dec) in bf16
  • sae_L{layer}_cfg.json — sae_lens-compatible config
  • val_report.json — validation metrics (1M tokens, 2026-05-03); see the loading snippet after this list
  • val_report_comparison.png — fullstack vs papergrade plot
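
To rebuild the validation table above from the shipped report (the exact JSON schema is not documented here, so inspect the keys first):

import json
from huggingface_hub import hf_hub_download

report_path = hf_hub_download('caiovicentino1/qwen36-27b-sae-fullstack', 'val_report.json')
with open(report_path) as f:
    report = json.load(f)
print(list(report)[:5])   # look at the top-level keys before relying on a schema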

PGAC Phase 3 — Final Kernel Benchmark (5-way)

These SAEs were built to enable Probe-Gated Adaptive Compute (PGAC) — a method that uses a linear probe at one layer to predict SAE features at a downstream layer, allowing intermediate layers to be skipped during inference.
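
A minimal sketch of the gating idea, not the benchmarked implementation: a linear probe reads the residual stream at an early layer, predicts the downstream SAE feature activations, and a thresholded confidence score decides whether the intermediate layers can be skipped. The probe shape, confidence proxy, and threshold below are all assumptions.

import torch
import torch.nn as nn

class PGACGate(nn.Module):
    """Probe on the layer-i residual that predicts layer-j SAE features and emits a skip decision."""
    def __init__(self, d_model=5120, d_sae=40960, threshold=0.9):
        super().__init__()
        self.probe = nn.Linear(d_model, d_sae)   # residual at layer i -> predicted SAE pre-activations at layer j
        self.threshold = threshold

    def forward(self, resid_i):
        pred = self.probe(resid_i)                        # predicted downstream feature activations
        conf = torch.sigmoid(pred).max(dim=-1).values     # crude per-token confidence proxy
        return conf > self.threshold, pred

# inside a custom decoder forward (pseudocode):
#   skip, _ = gate(resid[i])
#   if skip is True for the current token, jump from layer i directly to layer j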

We benchmarked PGAC end-to-end against vanilla baselines on GSM8K (N=100, Qwen3.6-27B). Results (2026-05-04):

Config                    Mode                              Speedup (vs no-thinking baseline)      GSM8K acc   Δ acc vs best baseline
Baseline                  no-thinking                       1.00×                                  0.39        −0.35
Baseline                  thinking (oracle)                 0.21×                                  0.74         0.00 (ceiling)
PGAC skip 5 layers        no-thinking                       1.19×                                  0.53        −0.21
PGAC skip 10 layers       no-thinking                       1.34×                                  0.54        −0.20
PGAC skip 10 + thinking   thinking (full skip)              0.21×                                  0.61        −0.13
Thinking-aware PGAC v2    thinking (skip in answer phase)   0.22× (1.03× vs thinking baseline)     0.72        −0.02

Honest verdict

The original PGAC simulation result (3.57× speedup at 8pp accuracy loss, Phase 2) does not survive a real kernel implementation in reasoning mode. Three findings:

  1. Layer skip in thinking mode degrades quality (−13pp accuracy with full skip) and gives no wall-clock speedup (KV cache cost dominates).

  2. The +15pp accuracy gain in no-thinking mode was an artifact of output truncation (the verbose model's answer gets cut off before the final number; skipping mid-layers yields a direct answer within the token budget). This is reproducible but does not represent a reasoning improvement.

  3. Thinking-aware PGAC v2 (skip layers ONLY during the answer phase, after </think>) preserves quality (−2pp, within noise) at a modest 1.03× speedup over the thinking baseline. This is the only configuration where the PGAC kernel works honestly. The speedup is bounded by a structural constraint: the think phase is ~90% of generated tokens.
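
A sketch of the v2 gating rule from item 3: layer skipping is enabled only once the model has emitted </think>, i.e. during the answer phase. The token lookup and the decoding loop are illustrative assumptions, not the benchmark kernel.

# assumes `tok` is the Qwen3.6-27B tokenizer loaded earlier; if '</think>' is not a
# single special token for this tokenizer, detect the phase boundary on decoded text instead
think_end_id = tok.convert_tokens_to_ids('</think>')

def in_answer_phase(generated_ids):
    # True only after the think phase has closed
    return think_end_id in list(generated_ids)

# inside a custom decoding loop (pseudocode):
#   skip_layers = PGAC_SKIP_SET if in_answer_phase(generated_ids) else set()
#   logits = forward_with_skips(model, input_ids, skip_layers)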

Where the speedup ceiling lives

Layer skip alone is bounded to ~1.5× on reasoning models because:

  • Think phase dominates token count (~90%)
  • KV cache cost is memory-bandwidth bound, not FLOP bound (layer skip cuts FLOPs only)
  • Reasoning is carried by a sparse subset of attention heads spread across many layers (RLKV, 2025) — whole-layer ablation removes heads that are still needed

For genuinely large compound speedups (3-5×) on reasoning models, layer skip must be combined with KV cache eviction (ThinKV, R-KV, kvpress) and sequence-level early exit (DEER, SpecExit). We pursue this direction (a unified probe controller across all three axes) in follow-up work.

Companion paper (in submission): "Probe-Detected Grokking in Multi-Probe DPO: Orthogonal Learning Beyond Task-Specific Detectors" (NeurIPS MI Workshop 2026).

Training

  • Tokens: 200M target (170M completed, plateau confirmed in mid layers)
  • Hardware: 1× RTX PRO 6000 Blackwell (96GB VRAM)
  • Time: ~16h
  • Cost: R$72 ($15)
  • Corpus: fineweb-edu (70%) + OpenThoughts-114k (20%) + OpenMathInstruct-2 (10%)
  • Optimizer: torch.optim.Adam (bf16 momentum, no bnb)
  • Schedule: Cosine warmup 5K → peak 2e-4 → floor 1.5e-4 (bumped mid-training from 6e-5; see the sketch below)
  • Mixed precision: bf16 (model + SAE params + Adam state)

The bf16 momentum + lower expansion (8× vs papergrade's 13×) were chosen so that 11 SAEs could train simultaneously on a 96GB GPU. The quality trade-off is documented above.
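
A sketch of that learning-rate schedule (only the 5K warmup, the 2e-4 peak and the 1.5e-4 floor come from the run config; the total step count and the shape of the mid-training floor bump are assumptions and are not modeled here):

import math

def lr_at(step, total_steps, warmup=5_000, peak=2e-4, floor=1.5e-4):
    # linear warmup to the peak, then cosine decay down to the floor
    if step < warmup:
        return peak * step / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * progress))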

Reproducibility

Build script: build_nb_pgac_phase4.py (canonical). Training notebook: see OpenInterpretability repo for nb_pgac_phase4_fullstack_sae.ipynb.

Citation

If you use these SAEs, please cite:

@misc{vicentino2026fullstack,
  title={Full-Stack Sparse Autoencoders for Qwen3.6-27B Reasoning Model},
  author={Vicentino, Caio},
  year={2026},
  howpublished={\url{https://huggingface.co/caiovicentino1/qwen36-27b-sae-fullstack}},
}

License

Apache 2.0 — see LICENSE.


Part of the OpenInterpretability ecosystem.
