Paper: Matryoshka Representation Learning (arXiv:2205.13147)
This is a fine-tuned version of Snowflake/snowflake-arctic-embed-m-v2.0 optimized for academic and scientific literature search. The model was trained using contrastive learning with hard negative mining, on data curated specifically for academic search scenarios.
| Attribute | Value |
|---|---|
| Base Model | Snowflake/snowflake-arctic-embed-m-v2.0 |
| Architecture | GTE |
| Embedding Dimension | 768 |
| MRL Dimensions | 768, 512, 256, 128 |
| Max Sequence Length | 4096 |
| Pooling | CLS token |
| Precision | float16 |
| Model | Avg. | SciFact: Recall@10 | TRECCOVID: Recall@10 | NFCorpus: Recall@10 | SCIDOCS: Recall@10 | LitSearch: Recall@10 | QASA: Recall@10 |
|---|---|---|---|---|---|---|---|
| snowflake-arctic-embed-m-v2.0-academic | 0.3729 | 0.8609 | 0.0219 | 0.1770 | 0.2129 | 0.6435 | 0.3210 |
| snowflake-arctic-embed-m-v2.0 | 0.3654 | 0.8353 | 0.0224 | 0.1669 | 0.2122 | 0.6508 | 0.3046 |
| Parameter | Value |
|---|---|
| Learning Rate | 2e-5 |
| Batch Size | 8192 (effective) |
| Per-Device Batch Size | 32 |
| Warmup Steps | 100 |
| Weight Decay | 0.1 |
| Precision | fp16 |
| Max Length | 4096 |
| Loss Function | InfoNCE (Contrastive) |
| Temperature (τ) | 0.02 |
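For reference, the InfoNCE objective with in-batch negatives and τ = 0.02 can be sketched as follows. This is a minimal illustration of the loss family named above, not the exact training code (which additionally uses mined hard negatives):

```python
import torch
import torch.nn.functional as F

def info_nce_loss(query_emb: torch.Tensor, passage_emb: torch.Tensor, temperature: float = 0.02) -> torch.Tensor:
    """In-batch InfoNCE: the i-th passage is the positive for the i-th query,
    and every other passage in the batch serves as a negative.
    Both inputs are expected to be L2-normalized, shape (batch, dim)."""
    logits = query_emb @ passage_emb.T / temperature  # (batch, batch) similarity matrix
    labels = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, labels)
```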
The model was trained on LEAD (Liner Embedding Academic Dataset), a collection of ~55,560 samples tailored for academic search.
This model supports Matryoshka Representation Learning. You can truncate embeddings to smaller dimensions (512, 256, 128) for faster computation and reduced storage.
```python
# Full dimension (768)
full_embedding = embeddings[:, :768]

# MRL dimensions
embedding_512 = embeddings[:, :512]
embedding_256 = embeddings[:, :256]
embedding_128 = embeddings[:, :128]
```
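Because the full 768-dimensional vectors are unit-normalized, a truncated slice is no longer unit length, so re-normalize after truncation before computing cosine similarity. A minimal sketch, assuming `embeddings` is a `(batch, 768)` tensor produced as in the usage example below (the helper name is illustrative, not part of the model's API):

```python
import torch
import torch.nn.functional as F

def truncate_and_normalize(embeddings: torch.Tensor, dim: int = 256) -> torch.Tensor:
    """Keep the first `dim` MRL coordinates and re-normalize so cosine
    similarity stays well-defined at the reduced dimension."""
    return F.normalize(embeddings[:, :dim], p=2, dim=1)
```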
```python
import torch
from transformers import AutoModel, AutoTokenizer

model_path = "LinerAI/snowflake-arctic-embed-m-v2.0-academic"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)
model.eval()

# For queries: prepend the "query: " prefix before encoding
def encode_query(text):
    input_text = f"query: {text}"
    inputs = tokenizer(input_text, return_tensors="pt", max_length=4096, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings = outputs.last_hidden_state[:, 0]  # CLS token
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings

# For passages: no prefix
def encode_passage(text):
    inputs = tokenizer(text, return_tensors="pt", max_length=4096, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings = outputs.last_hidden_state[:, 0]  # CLS token
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings

# Example: academic search
query = "transformer models for protein structure prediction"
abstract = "We introduce AlphaFold, a deep learning system that predicts protein structures..."

query_emb = encode_query(query)
passage_emb = encode_passage(abstract)

similarity = torch.nn.functional.cosine_similarity(query_emb, passage_emb)
print(f"Similarity: {similarity.item():.4f}")
```
Query format: `query: {your_query_text}`
Passage format: `{your_passage_text}` (no prefix)
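If the checkpoint ships a sentence-transformers configuration, the same prefixing convention can be used through that library. This is a hedged sketch, not verified against the repository (fall back to the transformers snippet above otherwise); the repo id follows the document's own naming:

```python
from sentence_transformers import SentenceTransformer, util

# Assumption: the checkpoint exposes a sentence-transformers config.
model = SentenceTransformer(
    "LinerAI/snowflake-arctic-embed-m-v2.0-academic",
    trust_remote_code=True,
)

queries = ["transformer models for protein structure prediction"]
passages = ["We introduce AlphaFold, a deep learning system that predicts protein structures..."]

# Queries get the "query: " prefix; passages are encoded as-is.
query_emb = model.encode([f"query: {q}" for q in queries], normalize_embeddings=True)
passage_emb = model.encode(passages, normalize_embeddings=True)

print(util.cos_sim(query_emb, passage_emb))  # cosine similarity matrix
```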
The model requires `trust_remote_code=True` for loading. This model is released under the Apache 2.0 license.