HybridNLI-PolyEncoder v2

A hybrid zero-shot multi-label text classifier that combines a frozen DeBERTa-v3-large NLI backbone with four lightweight PolyEncoder specialist heads.

Macro F1 = 0.9054 (+0.2366 over the zero-shot NLI baseline).

Architecture

Input Text → DeBERTa-v3-large (FROZEN, one pass)
                  │ token embeddings (B, T, 1024)
   ┌──────────────┼──────────────┬─────────────┐
   │              │              │             │
Economy       Technology     Finance    Environment
PolyHead       PolyHead      PolyHead    PolyHead
alpha_e        alpha_t       alpha_f     alpha_v   ← learned blend weights
   │              │              │             │
   └──────────────┼──────────────┴─────────────┘
                  │  Health: NLI-only (0.87 baseline)
         val-tuned per-label thresholds
                  │
            Final Predictions

Benchmark

Label        NLI Baseline   v2 Val-Tuned   Δ
Macro F1     0.6688         0.9054         +0.2366
Economy      0.4354         0.8584         +0.4230
Technology   0.7317         0.9447         +0.2130
Finance      0.6826         0.8673         +0.1847
Environment  0.6232         0.8786         +0.2554
Health       0.8712         0.9780         +0.1068
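The "val-tuned" column comes from the per-label threshold step in the architecture diagram, which amounts to a 1-D search per label over validation scores. A minimal sketch of that search; the grid and the F1 criterion are assumptions:

```python
import numpy as np
from sklearn.metrics import f1_score

def tune_threshold(val_scores, val_labels, grid=np.linspace(0.05, 0.95, 19)):
    """Pick the decision threshold that maximizes F1 for one label
    on held-out validation scores (sketch; the real grid may differ)."""
    best_t, best_f1 = 0.5, -1.0
    for t in grid:
        f1 = f1_score(val_labels, (val_scores >= t).astype(int), zero_division=0)
        if f1 > best_f1:
            best_t, best_f1 = t, f1
    return best_t

# Toy example: scores for one label on five validation texts
scores = np.array([0.2, 0.4, 0.55, 0.8, 0.9])
labels = np.array([0, 0, 1, 1, 1])
print(tune_threshold(scores, labels))
```

Because the search runs independently per label, each label ends up with its own operating point rather than a shared 0.5 cutoff.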

Key Innovations

  1. Dual-loss — BCE(poly_logits) + BCE(blended_score) gives alpha a direct gradient path (fixing a silent v1 bug in which alpha received no gradient).
  2. Learned alpha — each label independently discovers the optimal NLI/Poly blend.
  3. Minimal training — only 2.1M / 435M params updated (0.49%).
  4. KL alignment — PolyHead aligns with NLI on high-confidence samples.
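Points 1 and 4 can be written as a single training objective. A sketch under stated assumptions: the exact KL weighting and the confidence mask used to select "high-confidence samples" are guesses, not the repo's implementation.

```python
import torch
import torch.nn.functional as F

def dual_loss(poly_logits, blended_scores, targets, nli_scores,
              kl_weight=0.1, conf=0.9):
    """Sketch of the v2 objective: BCE on the raw poly logits plus BCE on the
    blended score, so alpha sits inside a loss term and receives a direct
    gradient. An optional KL term pulls the poly head toward the NLI
    distribution on samples where NLI is confident."""
    loss = F.binary_cross_entropy_with_logits(poly_logits, targets)
    loss = loss + F.binary_cross_entropy(
        blended_scores.clamp(1e-6, 1 - 1e-6), targets)
    # KL alignment only where the NLI score is near 0 or 1 (assumed mask)
    mask = (nli_scores > conf) | (nli_scores < 1 - conf)
    if mask.any():
        p = torch.sigmoid(poly_logits[mask]).clamp(1e-6, 1 - 1e-6)
        q = nli_scores[mask].clamp(1e-6, 1 - 1e-6)
        kl = q * (q / p).log() + (1 - q) * ((1 - q) / (1 - p)).log()
        loss = loss + kl_weight * kl.mean()
    return loss

loss = dual_loss(torch.zeros(4), torch.full((4,), 0.5),
                 torch.tensor([0., 1., 0., 1.]),
                 torch.tensor([0.95, 0.05, 0.5, 0.99]))
```

Only the head and alpha parameters receive these gradients; the DeBERTa backbone stays frozen, which is why just 2.1M of 435M parameters update.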

Quick Start

from transformers import pipeline
from huggingface_hub import hf_hub_download
import torch

# 1. Load the frozen DeBERTa-v3 NLI backbone as a zero-shot pipeline
nli_pipe = pipeline(
    'zero-shot-classification',
    model='MoritzLaurer/deberta-v3-large-zeroshot-v2.0',
    device=0
)

# 2. Download the checkpoint and rebuild the model plus tuned thresholds
#    (load_model_from_checkpoint and run_nli_inference are helpers shipped
#    with this repo)
ckpt_path = hf_hub_download(
    repo_id='tdnathmlenthusiast/hybrid-nli-polyencoder',
    filename='multihead_v2.pt'
)
model, thresholds = load_model_from_checkpoint(ckpt_path, nli_pipe)
model = model.to('cuda')

# Move the cached per-label embeddings onto the same device as the model
for l, emb in model._label_embs.items():
    model._label_embs[l] = emb.to(model.device)

# 3. Score texts with NLI once, then classify with the hybrid heads
texts = ['The Fed raised rates to fight inflation.']
nli_s = run_nli_inference([{'text': t} for t in texts])
preds = model.predict(texts, nli_s, thresholds=thresholds)
print(preds[0]['predicted_labels'])