Commit 30add1f by kashif · 1 Parent(s): cfa4f52

update readme

Files changed (1):
  1. README.md +102 -54

README.md CHANGED
@@ -20,95 +20,144 @@ DeepCONF monitors the confidence of generated tokens and stops generation when c
  - `enable_conf` (bool): Whether to enable the DeepCONF strategy. Defaults to `False`.
  - `window_size` (int): Size of the sliding window for confidence calculation. Defaults to `2048`.
  - `threshold` (float): Confidence threshold for early stopping. Defaults to `17.0`.
- - `conf_topk` (int): Number of top tokens to use for confidence calculation from the full vocabulary. Defaults to `20` (matches official implementation).
  - `output_confidences` (bool): If `True` and `return_dict_in_generate=True`, returns a per-step confidence tensor alongside generated sequences for debugging/visualization.

  ## Usage

  To use this custom generation strategy, you can pass it directly to the `generate` method:

  ```python
- from transformers import AutoModelForCausalLM, AutoTokenizer

- model = AutoModelForCausalLM.from_pretrained("your-model")
  tokenizer = AutoTokenizer.from_pretrained("your-model")

- inputs = tokenizer("Hello, world!", return_tensors="pt")

  # Generate with DeepCONF (Hub repo)
  outputs = model.generate(
      **inputs,
-     enable_conf=True,
-     window_size=2048,
-     threshold=17.0,
-     output_confidences=True,       # request confidences
-     return_dict_in_generate=True,  # required to get tensors
-     max_new_tokens=100,
      custom_generate="kashif/DeepConf",  # Hugging Face Hub repo
      trust_remote_code=True
  )
  ```

- ## Calibration (DeepConf-low/high)

- DeepConf’s online stopping threshold is derived from a short warmup phase. You collect warmup trace confidences, then pass them into the generator to auto-derive the threshold for either DeepConf-low (aggressive) or DeepConf-high (permissive).

- 1. Warmup (num_return_sequences): collect per-trace confidences (C_t = min(step_confidences))
  ```python
  from transformers import GenerationConfig

- prompt = "Explain artificial intelligence."
- Ninit = 8  # number of warmup traces
- warmup_C = []
-
- warm_cfg = GenerationConfig.from_model_config(model.config)
- warm_cfg.do_sample = True
- warm_cfg.temperature = 0.7
- warm_cfg.top_p = 0.95
- warm_cfg.max_new_tokens = 64
- warm_cfg.enable_conf = True
- warm_cfg.return_dict_in_generate = True
- warm_cfg.output_confidences = True
- warm_cfg.num_return_sequences = Ninit
- # IMPORTANT: Do not set `warm_cfg.threshold` here. Warmup should not apply online early stopping.
-
- out = model.generate(
-     **tokenizer(prompt, return_tensors="pt"),
-     generation_config=warm_cfg,
      custom_generate="kashif/DeepConf",
      trust_remote_code=True,
  )
- # Per-trace Ct = min over steps
- warmup_C = out.confidences.min(dim=1).values.tolist()
  ```

- 2. Online: pass warmup confidences to auto-derive threshold
  ```python
- gen_cfg = GenerationConfig.from_model_config(model.config)
- gen_cfg.enable_conf = True
- gen_cfg.return_dict_in_generate = True
- gen_cfg.output_confidences = True
-
- # Choose a variant:
- # - DeepConf-low (aggressive): eta=0.1 → 90th percentile threshold
- # - DeepConf-high (permissive): eta=0.9 → 10th percentile threshold
- gen_cfg.deepconf_variant = "low"  # or "high"
- # Optional: override eta explicitly
- # gen_cfg.deepconf_eta = 0.1  # defaults: 0.1 for low, 0.9 for high
-
- # Provide warmup confidences; the threshold will be derived internally
- gen_cfg.deepconf_warmup_confidences = warmup_C
-
- out = model.generate(
-     **tokenizer(prompt, return_tensors="pt"),
      custom_generate="kashif/DeepConf",
      trust_remote_code=True,
-     generation_config=gen_cfg,
-     max_new_tokens=128,
  )
  ```

  ## Technical Details

  ### Confidence Calculation
@@ -123,7 +172,6 @@ This approach:
  - Uses the **full probability distribution** (before any top-k/top-p/temperature filtering)
  - Always considers a **fixed number of tokens** (conf_topk=20)
  - Naturally **includes the sampled token** if it's in the top-k
- - Matches the **official DeepConf implementation** exactly

  ### Online Stopping

  - `enable_conf` (bool): Whether to enable the DeepCONF strategy. Defaults to `False`.
  - `window_size` (int): Size of the sliding window for confidence calculation. Defaults to `2048`.
  - `threshold` (float): Confidence threshold for early stopping. Defaults to `17.0`.
+ - `conf_topk` (int): Number of top tokens to use for confidence calculation from the full vocabulary. Defaults to `20`.
  - `output_confidences` (bool): If `True` and `return_dict_in_generate=True`, returns a per-step confidence tensor alongside generated sequences for debugging/visualization.

  ## Usage

+ ### Basic Usage
+
  To use this custom generation strategy, you can pass it directly to the `generate` method:

  ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

+ # Load model and tokenizer
+ model = AutoModelForCausalLM.from_pretrained(
+     "your-model",
+     torch_dtype="auto",
+     device_map="auto"
+ )
  tokenizer = AutoTokenizer.from_pretrained("your-model")

+ # Prepare your prompt
+ question = "What is the square root of 144?"
+ messages = [{"role": "user", "content": question}]
+ prompt = tokenizer.apply_chat_template(
+     messages,
+     tokenize=False,
+     add_generation_prompt=True
+ )
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+ # Configure generation with DeepCONF
+ gen_config = GenerationConfig(
+     do_sample=True,
+     temperature=0.7,
+     top_p=0.95,
+     max_new_tokens=512,
+     enable_conf=True,              # Enable DeepCONF
+     window_size=2048,              # Sliding window size
+     threshold=17.0,                # Confidence threshold
+     conf_topk=20,                  # Top-k for confidence (default: 20)
+     output_confidences=True,       # Return confidence scores
+     return_dict_in_generate=True,  # Required for confidence output
+ )

  # Generate with DeepCONF (Hub repo)
  outputs = model.generate(
      **inputs,
+     generation_config=gen_config,
      custom_generate="kashif/DeepConf",  # Hugging Face Hub repo
      trust_remote_code=True
  )
+
+ # Access results
+ generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
+ print(f"Generated: {generated_text}")
+
+ # Access per-step confidences if requested
+ if hasattr(outputs, 'confidences'):
+     confidences = outputs.confidences  # Shape: (batch_size, num_generated_tokens)
+     print(f"Min confidence: {confidences.min().item():.3f}")
+     print(f"Mean confidence: {confidences.mean().item():.3f}")
  ```

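The earlier revision shown above passed the DeepCONF options directly as keyword arguments to `generate()` rather than through a `GenerationConfig`. Assuming `generate()` still forwards these extra keyword arguments into the generation config (as that revision relied on), a minimal sketch of the same call in that style:

```python
# Minimal sketch (assumption): DeepCONF options passed as plain generate() kwargs,
# mirroring the earlier revision of this README rather than building a GenerationConfig.
outputs = model.generate(
    **inputs,
    enable_conf=True,
    window_size=2048,
    threshold=17.0,
    conf_topk=20,
    output_confidences=True,       # request per-step confidences
    return_dict_in_generate=True,  # required to get the confidences tensor
    max_new_tokens=512,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)
```

Either style should yield the same `outputs` object; the `GenerationConfig` route above is simply more explicit.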
 
+ ### Calibration (DeepConf-low/high)

+ DeepConf's online stopping threshold can be automatically derived from a warmup phase. This allows you to calibrate the threshold based on actual model behavior rather than using a fixed value.
+
+ **Step 1: Warmup Phase** - Generate multiple sequences and collect their minimum confidences:

  ```python
  from transformers import GenerationConfig

+ # Prepare inputs
+ question = "What is 2 + 2?"
+ messages = [{"role": "user", "content": question}]
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+ # Configure warmup generation
+ warmup_cfg = GenerationConfig(
+     do_sample=True,
+     temperature=0.7,
+     top_p=0.95,
+     max_new_tokens=256,
+     enable_conf=True,              # Enable confidence tracking
+     return_dict_in_generate=True,
+     output_confidences=True,
+     num_return_sequences=8,        # Generate 8 warmup sequences
+     # Note: Do NOT set threshold here - warmup should run without early stopping
+ )
+
+ # Generate warmup sequences
+ warmup_out = model.generate(
+     **inputs,
+     generation_config=warmup_cfg,
      custom_generate="kashif/DeepConf",
      trust_remote_code=True,
  )
+
+ # Extract minimum confidence per sequence (C_t = min over all steps)
+ warmup_C = warmup_out.confidences.min(dim=1).values.tolist()
+ print(f"Warmup min confidences: {warmup_C}")
  ```

+ **Step 2: Production Generation** - Use warmup confidences to auto-derive the threshold:
+
  ```python
+ # Configure production generation with a calibrated threshold
+ gen_cfg = GenerationConfig(
+     do_sample=True,
+     temperature=0.7,
+     top_p=0.95,
+     max_new_tokens=512,
+     enable_conf=True,
+     return_dict_in_generate=True,
+     output_confidences=True,
+
+     # Automatic threshold calibration
+     deepconf_variant="low",                # "low" (aggressive, 90th percentile) or "high" (permissive, 10th percentile)
+     deepconf_warmup_confidences=warmup_C,  # Pass warmup confidences
+     # Optional: deepconf_eta=0.1,          # Override eta (defaults: 0.1 for low, 0.9 for high)
+ )
+
+ # Generate with the calibrated threshold
+ outputs = model.generate(
+     **inputs,
+     generation_config=gen_cfg,
      custom_generate="kashif/DeepConf",
      trust_remote_code=True,
  )
+
+ print(f"Generated: {tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)}")
  ```

+ **Variant Explanation** (see the sketch below):
+ - **DeepConf-low** (eta=0.1): Uses 90th percentile threshold → More aggressive early stopping
+ - **DeepConf-high** (eta=0.9): Uses 10th percentile threshold → More permissive, allows longer generation
+
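The percentile rule above can be made concrete. A minimal sketch, assuming the threshold is simply a percentile of the warmup trace confidences `warmup_C` collected in Step 1 (this illustrates the rule described here; it is not the repository's internal code):

```python
import numpy as np

def derive_threshold(warmup_confidences, variant="low", eta=None):
    # eta = 0.1 ("low")  -> 90th percentile: higher threshold, more traces stopped early
    # eta = 0.9 ("high") -> 10th percentile: lower threshold, fewer traces stopped
    if eta is None:
        eta = 0.1 if variant == "low" else 0.9
    return float(np.percentile(warmup_confidences, 100 * (1 - eta)))

threshold_low = derive_threshold(warmup_C, variant="low")
threshold_high = derive_threshold(warmup_C, variant="high")
```

With `deepconf_warmup_confidences` set, the generator derives this value internally; the sketch only shows where the 90th/10th percentile figures come from.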
  ## Technical Details

  ### Confidence Calculation

  - Uses the **full probability distribution** (before any top-k/top-p/temperature filtering)
  - Always considers a **fixed number of tokens** (conf_topk=20)
  - Naturally **includes the sampled token** if it's in the top-k
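As a rough illustration of the bullets above, a per-step confidence consistent with this description is the mean negative log-probability of the `conf_topk` most likely tokens, computed from the raw distribution before any sampling filters are applied. The exact formula used by the repository may differ; this is a sketch:

```python
import torch
import torch.nn.functional as F

def step_confidence(logits: torch.Tensor, conf_topk: int = 20) -> torch.Tensor:
    """Sketch: mean negative log-probability of the top `conf_topk` tokens,
    taken from the full vocabulary distribution (before temperature/top-k/top-p)."""
    log_probs = F.log_softmax(logits.float(), dim=-1)      # (batch, vocab_size)
    topk_logprobs, _ = log_probs.topk(conf_topk, dim=-1)   # (batch, conf_topk)
    return -topk_logprobs.mean(dim=-1)                     # larger value = more peaked / confident

```

A peaked distribution yields a large value (the runner-up tokens have very negative log-probabilities), while a flat distribution yields a value near log(vocab_size) ≈ 10–11 for a ~50k-token vocabulary, which is why a default threshold around 17.0 is plausible.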
 
  ### Online Stopping

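A minimal sketch of the windowed stopping rule that `window_size` and `threshold` describe, assuming generation stops once the mean confidence over the most recent `window_size` steps falls below `threshold` (an illustration of the behavior, not the repository's internal code):

```python
from collections import deque

def should_stop_early(step_confidences, window_size=2048, threshold=17.0):
    """Return True if the sliding-window mean confidence ever drops below the threshold."""
    window = deque(maxlen=window_size)
    for conf in step_confidences:
        window.append(conf)
        if sum(window) / len(window) < threshold:
            return True  # low-confidence region reached: stop generating
    return False
```

When calibration is used, `threshold` is replaced by the value derived from the warmup confidences above.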