---
library_name: transformers
tags:
  - custom_generate
---

# LagKV Cache

## Introduction

![LagKV Cache diagram from the original paper](https://arxiv.org/html/2504.04704v1/x1.png)

LagKV is an efficient and robust KV cache compression algorithm. It uses information from lag tokens to compress the tokens that precede them, which significantly boosts compression performance with little computational overhead.

[Original GitHub repository](https://github.com/AI-Lab-China-Merchants-Bank/LagKV)

The details are described in the following paper:

[LagKV: Lag-Relative Information of the KV Cache Tells Which Tokens Are Important](https://arxiv.org/abs/2504.04704)
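To make the idea concrete, here is a minimal pure-Python sketch of lag-relative scoring for a single attention head. It is a simplified reading of the paper, not the implementation shipped in this repository: the function name `lagkv_keep_indices`, the toy sizes, and the single-pass structure are assumptions made for illustration. Each chunk of `lag_size` tokens is min-max normalized with statistics taken from the chunk that follows it (the lag chunk), each token is scored by the standard deviation across channels, and the top `lag_ratio` fraction is kept. Sink tokens and the trailing tokens that have no complete lag chunk after them are left untouched.

```python
import math
import random
import statistics


def lagkv_keep_indices(keys, values, lag_size=4, lag_ratio=0.5, sink_size=2):
    """Return the sorted token indices kept for one head (illustrative only).

    keys, values: lists of per-token vectors (lists of floats), equal length.
    """
    n = len(keys)
    keep = list(range(min(sink_size, n)))  # sink tokens are always kept
    n_chunks = max((n - sink_size) // lag_size, 0)
    n_keep = int(lag_size * lag_ratio)

    def chunk_scores(states, start):
        cur = states[start:start + lag_size]
        ref = states[start + lag_size:start + 2 * lag_size]  # the lag chunk
        dims = range(len(cur[0]))
        lo = [min(t[d] for t in ref) for d in dims]
        hi = [max(t[d] for t in ref) for d in dims]
        # Min-max normalize with the lag chunk's statistics, then take the
        # standard deviation across channels as each token's raw score.
        stds = [
            statistics.pstdev([(t[d] - lo[d]) / (hi[d] - lo[d] + 1e-6) for d in dims])
            for t in cur
        ]
        # Softmax over the tokens of the chunk.
        peak = max(stds)
        exps = [math.exp(s - peak) for s in stds]
        total = sum(exps)
        return [e / total for e in exps]

    # Every full chunk except the last one has a lag chunk after it.
    for c in range(n_chunks - 1):
        start = sink_size + c * lag_size
        k_scores = chunk_scores(keys, start)
        v_scores = chunk_scores(values, start)
        scores = [k + v for k, v in zip(k_scores, v_scores)]
        top = sorted(range(lag_size), key=lambda i: scores[i], reverse=True)[:n_keep]
        keep.extend(start + i for i in sorted(top))

    # The last full chunk and any remainder stay uncompressed.
    tail_start = sink_size + max(n_chunks - 1, 0) * lag_size
    keep.extend(range(tail_start, n))
    return keep


# Toy data: 15 tokens with 8 channels each.
rng = random.Random(0)
keys = [[rng.gauss(0, 1) for _ in range(8)] for _ in range(15)]
values = [[rng.gauss(0, 1) for _ in range(8)] for _ in range(15)]
keep = lagkv_keep_indices(keys, values)
print(len(keep), "of", len(keys), "tokens kept")
```

With these toy settings, 15 tokens reduce to 11: 2 sink tokens, 2 tokens from each of the two compressed chunks, and 5 uncompressed trailing tokens. The real cache works on batched, multi-head tensors and compresses incrementally during generation, so treat this only as a picture of the scoring rule.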

## Example usage

The custom generation method in this repository can be used just like the base `generate` method from `transformers`:

```py
# requires `transformers>=4.52.0`
from transformers import AutoModelForCausalLM, AutoTokenizer
# Preparing model, tokenizer, and model inputs
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", device_map="auto")
messages = [{"role": "user", "content": "Tell me a story about a cat."}]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
# Using lagkv cache
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    # lagkv cache arguments (default `lag_ratio=0.5,lag_size=128,lag_sink_size=16`)
    custom_generate="CMB-AI-LAB/lagkv_cache",
    trust_remote_code=True,
)
print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
assert "lagkvcache" in str(type(gen_out.past_key_values)).lower()
```