# HybridNLI-PolyEncoder v2

A hybrid zero-shot multi-label text classifier combining a frozen DeBERTa-v3-large NLI backbone with four lightweight PolyEncoder specialist heads. Macro F1 = 0.9054 (+0.2366 over the zero-shot NLI baseline).
## Architecture
```
Input Text → DeBERTa-v3-large (FROZEN, one pass)
                    │ token embeddings (B, T, 1024)
     ┌──────────────┼──────────────┬─────────────┐
     │              │              │             │
  Economy      Technology      Finance     Environment
  PolyHead      PolyHead       PolyHead     PolyHead
  alpha_e       alpha_t        alpha_f      alpha_v    ← learned blend weights
     │              │              │             │
     └──────────────┼──────────────┴─────────────┘
                    │   (Health: NLI-only, 0.87 baseline)
          val-tuned per-label thresholds
                    │
            Final Predictions
```
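The per-label blending step in the diagram can be sketched as follows. This is a minimal illustration, not the repository's actual code: the class name `BlendedLabelHead` and the alpha initialization are assumptions; only the idea of a learned scalar mixing the NLI score with the PolyEncoder score comes from the architecture above.

```python
import torch
import torch.nn as nn

class BlendedLabelHead(nn.Module):
    """Hypothetical sketch of one label's blending step: a learned
    scalar alpha mixes the frozen NLI entailment score with the
    PolyEncoder head's score for that label."""

    def __init__(self):
        super().__init__()
        # sigmoid(0) = 0.5, so the blend starts balanced (assumption)
        self.alpha_logit = nn.Parameter(torch.zeros(1))

    def forward(self, nli_score: torch.Tensor, poly_score: torch.Tensor) -> torch.Tensor:
        alpha = torch.sigmoid(self.alpha_logit)
        # Convex combination of the two score sources
        return alpha * nli_score + (1 - alpha) * poly_score
```

Because alpha is a parameter of the head, each label can settle on its own NLI/PolyEncoder mix during training.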
## Benchmark
| Label | NLI Baseline | v2 Val-Tuned | Δ |
|---|---|---|---|
| Macro F1 | 0.6688 | 0.9054 | +0.2366 |
| Economy | 0.4354 | 0.8584 | +0.4230 |
| Technology | 0.7317 | 0.9447 | +0.2130 |
| Finance | 0.6826 | 0.8673 | +0.1847 |
| Environment | 0.6232 | 0.8786 | +0.2554 |
| Health | 0.8712 | 0.9780 | +0.1068 |
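As a quick sanity check, the macro F1 in the table is the unweighted mean of the five per-label val-tuned scores:

```python
# Verify that the reported macro F1 is the mean of the per-label F1s
per_label = {
    'Economy': 0.8584,
    'Technology': 0.9447,
    'Finance': 0.8673,
    'Environment': 0.8786,
    'Health': 0.9780,
}
macro_f1 = sum(per_label.values()) / len(per_label)
print(round(macro_f1, 4))  # 0.9054
```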
## Key Innovations
- **Dual loss** — BCE on the raw PolyHead logits plus BCE on the alpha-blended score gives alpha a direct gradient path (fixes a silent v1 bug in which the blend weights received no gradient).
- **Learned alpha** — each label independently discovers its optimal NLI/PolyEncoder blend.
- **Minimal training** — only 2.1M of 435M parameters are updated (0.49%).
- **KL alignment** — each PolyHead is aligned with the NLI backbone on high-confidence samples.
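The dual-loss idea can be sketched as below. This is an assumed formulation, not the repository's training code: the function name, the equal weighting of the two terms, and the clamping are illustrative choices; the source states only that both the poly logits and the blended score receive a BCE loss.

```python
import torch
import torch.nn.functional as F

def dual_loss(poly_logits: torch.Tensor,
              blended_scores: torch.Tensor,
              targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of the dual loss: BCE on the raw PolyHead
    logits plus BCE on the alpha-blended probabilities, so alpha gets
    a direct gradient signal instead of being trained only indirectly."""
    loss_poly = F.binary_cross_entropy_with_logits(poly_logits, targets)
    # Blended scores are already probabilities; clamp for numerical safety
    loss_blend = F.binary_cross_entropy(
        blended_scores.clamp(1e-6, 1 - 1e-6), targets
    )
    return loss_poly + loss_blend
```

Since the blended score is a function of alpha, the second BCE term backpropagates through the blend weights directly, which is what fixes the silent v1 bug.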
## Quick Start
```python
from transformers import pipeline
from huggingface_hub import hf_hub_download
import torch

# 1. Load the frozen DeBERTa NLI backbone
nli_pipe = pipeline(
    'zero-shot-classification',
    model='MoritzLaurer/deberta-v3-large-zeroshot-v2.0',
    device=0,
)

# 2. Download and load the PolyHead checkpoint
ckpt_path = hf_hub_download(
    repo_id='tdnathmlenthusiast/hybrid-nli-polyencoder',
    filename='multihead_v2.pt',
)
model, thresholds = load_model_from_checkpoint(ckpt_path, nli_pipe)
model = model.to('cuda')

# Move the cached label embeddings to the same device as the model
for l, emb in model._label_embs.items():
    model._label_embs[l] = emb.to(model.device)

# 3. Classify
texts = ['The Fed raised rates to fight inflation.']
nli_s = run_nli_inference([{'text': t} for t in texts])
preds = model.predict(texts, nli_s, thresholds=thresholds)
print(preds[0]['predicted_labels'])
```
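The val-tuned per-label thresholds mentioned above can be reproduced with a simple sweep. This is a generic sketch under assumptions, not the repository's tuning code: the candidate grid and the NumPy-based F1 helper are illustrative; only the idea of picking each label's cutoff by maximizing F1 on a validation split comes from the model card.

```python
import numpy as np

def f1(preds: np.ndarray, labels: np.ndarray) -> float:
    """Binary F1 from boolean prediction/label arrays."""
    tp = np.sum(preds & labels)
    fp = np.sum(preds & ~labels)
    fn = np.sum(~preds & labels)
    return 2 * tp / max(2 * tp + fp + fn, 1)

def tune_threshold(scores: np.ndarray, labels: np.ndarray) -> float:
    """Hypothetical per-label threshold sweep: return the cutoff in
    [0.05, 0.95] that maximizes F1 on the validation scores."""
    candidates = np.linspace(0.05, 0.95, 19)
    f1s = [f1(scores >= t, labels) for t in candidates]
    return float(candidates[int(np.argmax(f1s))])
```

Running one sweep per label (on validation scores, never on test) yields the `thresholds` dict that `model.predict` consumes.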