kashif HF Staff committed on
Commit
56bd97c
·
1 Parent(s): e0297b7
Files changed (2) hide show
  1. README.md +155 -0
  2. custom_generate/generate.py +7 -2
README.md CHANGED
@@ -18,10 +18,14 @@ DeepCONF monitors the confidence of generated tokens and stops generation when c
18
  ## Parameters
19
 
20
  - `enable_conf` (bool): Whether to enable the DeepCONF strategy. Defaults to `False`.
 
21
  - `window_size` (int): Size of the sliding window for confidence calculation. Defaults to `2048`.
22
  - `threshold` (float): Confidence threshold for early stopping. Defaults to `17.0`.
23
  - `conf_topk` (int): Number of top tokens to use for confidence calculation from the full vocabulary. Defaults to `20`.
24
  - `output_confidences` (bool): If `True` and `return_dict_in_generate=True`, returns a per-step confidence tensor alongside generated sequences for debugging/visualization.
 
 
 
25
 
26
  ## Usage
27
 
@@ -158,6 +162,157 @@ print(f"Generated: {tokenizer.decode(outputs.sequences[0], skip_special_tokens=T
158
  - **DeepConf-low** (eta=0.1): Uses 90th percentile threshold → More aggressive early stopping
159
  - **DeepConf-high** (eta=0.9): Uses 10th percentile threshold → More permissive, allows longer generation
160
 
161
  ## Technical Details
162
 
163
  ### Confidence Calculation
 
18
  ## Parameters
19
 
20
  - `enable_conf` (bool): Whether to enable the DeepCONF strategy. Defaults to `False`.
21
+ - `enable_early_stopping` (bool): Whether to apply early stopping during generation (online mode) or just track confidences for post-processing (batch mode). Defaults to `True`.
22
  - `window_size` (int): Size of the sliding window for confidence calculation. Defaults to `2048`.
23
  - `threshold` (float): Confidence threshold for early stopping. Defaults to `17.0`.
24
  - `conf_topk` (int): Number of top tokens to use for confidence calculation from the full vocabulary. Defaults to `20`.
25
  - `output_confidences` (bool): If `True` and `return_dict_in_generate=True`, returns a per-step confidence tensor alongside generated sequences for debugging/visualization.
26
+ - `deepconf_variant` (str): Optional variant for automatic threshold calibration (`"low"` or `"high"`). Requires `deepconf_warmup_confidences`.
27
+ - `deepconf_warmup_confidences` (list/tensor): Warmup confidence values for threshold calibration. Used with `deepconf_variant`.
28
+ - `deepconf_eta` (float): Optional override for eta value in threshold calculation (defaults: 0.1 for low, 0.9 for high).
29
 
30
  ## Usage
31
 
 
162
  - **DeepConf-low** (eta=0.1): Uses 90th percentile threshold → More aggressive early stopping
163
  - **DeepConf-high** (eta=0.9): Uses 10th percentile threshold → More permissive, allows longer generation
164
 
165
+ ### Two Modes of Operation
166
+
167
+ DeepConf supports two modes that match different use cases:
168
+
169
+ #### Mode 1: Online Early Stopping (Default)
170
+
171
+ This is the default behavior where early stopping happens **during** generation:
172
+
173
+ ```python
174
+ # Online mode: Stop immediately when confidence drops
175
+ gen_config = GenerationConfig(
176
+ enable_conf=True,
177
+ enable_early_stopping=True, # Default: True (online stopping)
178
+ threshold=17.0,
179
+ window_size=2048,
180
+ max_new_tokens=512,
181
+ )
182
+
183
+ outputs = model.generate(**inputs, generation_config=gen_config, custom_generate="kashif/DeepConf")
184
+ ```
185
+
186
+ **Use cases:**
187
+ - Interactive generation where you want immediate results
188
+ - Real-time applications
189
+ - Single-sequence generation
190
+ - Lower memory usage (no need to store full sequences)
191
+
192
+ #### Mode 2: Batch Generation + Post-Processing
193
+
194
+ Generate multiple sequences without early stopping, then analyze them afterward:
195
+
196
+ ```python
197
+ import torch
198
+
199
+ # Phase 1: Generate multiple sequences WITHOUT early stopping
200
+ gen_config = GenerationConfig(
201
+ enable_conf=True,
202
+ enable_early_stopping=False, # Disable online stopping
203
+ output_confidences=True,
204
+ return_dict_in_generate=True,
205
+ max_new_tokens=64,
206
+ )
207
+
208
+ # Expand inputs for batch generation (e.g., 8 sequences)
209
+ num_sequences = 8
210
+ expanded_input_ids = inputs.input_ids.repeat(num_sequences, 1)
211
+ if 'attention_mask' in inputs and inputs.attention_mask is not None:
212
+ expanded_attention_mask = inputs.attention_mask.repeat(num_sequences, 1)
213
+ else:
214
+ expanded_attention_mask = None
215
+
216
+ # Generate batch
217
+ outputs = model.generate(
218
+ input_ids=expanded_input_ids,
219
+ attention_mask=expanded_attention_mask,
220
+ generation_config=gen_config,
221
+ custom_generate="kashif/DeepConf"
222
+ )
223
+
224
+ # Phase 2: Post-process to analyze confidence patterns
225
+ from custom_generate.utils import process_batch_results
226
+
227
+ results = process_batch_results(
228
+ outputs,
229
+ tokenizer,
230
+ window_size=2048,
231
+ threshold=17.0
232
+ )
233
+
234
+ # Analyze results
235
+ print(f"Generated {results['num_traces']} sequences")
236
+ print(f"Min confidences: {results['min_confs']}")
237
+
238
+ for i, trace in enumerate(results['traces']):
239
+ print(f"\nSequence {i+1}:")
240
+ print(f" Text: {trace['text'][:100]}...")
241
+ print(f" Min confidence: {trace['min_conf']:.3f}")
242
+ print(f" Would stop early: {trace['stopped_early']}")
243
+ if trace['stopped_early']:
244
+ print(f" Stop position: {trace['stop_position']}")
245
+ ```
246
+
247
+ **Use cases:**
248
+ - Research and experimentation (try different thresholds without regenerating)
249
+ - Batch serving (generate multiple candidates at once)
250
+ - Analysis and voting (like the official implementation)
251
+ - Calibration and threshold tuning
252
+
253
+ **Utility Functions:**
254
+
255
+ The `custom_generate/utils.py` module provides helper functions:
256
+
257
+ - `process_batch_results()`: Analyze batch outputs to detect early stopping positions
258
+ - `analyze_early_stopping()`: Calculate statistics on early stopping behavior
259
+ - `compute_warmup_threshold()`: Derive threshold from warmup confidences
260
+ - `extract_answer()`: Parse LaTeX `\boxed{answer}` patterns
261
+
262
+ #### Complete Workflow Example (Like Official DeepConf)
263
+
264
+ This demonstrates the full workflow matching the official implementation:
265
+
266
+ ```python
267
+ # Step 1: Warmup phase - generate multiple sequences
268
+ warmup_config = GenerationConfig(
269
+ do_sample=True,
270
+ temperature=0.7,
271
+ max_new_tokens=64,
272
+ enable_conf=True,
273
+ enable_early_stopping=False, # No stopping during warmup
274
+ output_confidences=True,
275
+ return_dict_in_generate=True,
276
+ )
277
+
278
+ # Expand for 8 warmup sequences
279
+ num_warmup = 8
280
+ expanded_ids = inputs.input_ids.repeat(num_warmup, 1)
281
+ expanded_mask = inputs.attention_mask.repeat(num_warmup, 1) if 'attention_mask' in inputs else None
282
+
283
+ warmup_outputs = model.generate(
284
+ input_ids=expanded_ids,
285
+ attention_mask=expanded_mask,
286
+ generation_config=warmup_config,
287
+ custom_generate="kashif/DeepConf"
288
+ )
289
+
290
+ # Process warmup to get min confidences
291
+ from custom_generate.utils import process_batch_results, compute_warmup_threshold
292
+
293
+ warmup_results = process_batch_results(warmup_outputs, tokenizer, window_size=10)
294
+ print(f"Warmup min confidences: {warmup_results['min_confs']}")
295
+
296
+ # Step 2: Compute threshold from warmup
297
+ threshold = compute_warmup_threshold(
298
+ warmup_results['min_confs'],
299
+ variant="low" # or "high"
300
+ )
301
+ print(f"Calibrated threshold: {threshold:.3f}")
302
+
303
+ # Step 3: Final generation with calibrated threshold
304
+ final_config = GenerationConfig(
305
+ enable_conf=True,
306
+ enable_early_stopping=True, # Online stopping with calibrated threshold
307
+ threshold=threshold,
308
+ window_size=10,
309
+ max_new_tokens=128,
310
+ )
311
+
312
+ final_output = model.generate(**inputs, generation_config=final_config, custom_generate="kashif/DeepConf")
313
+ print(tokenizer.decode(final_output.sequences[0], skip_special_tokens=True))
314
+ ```
315
+
316
  ## Technical Details
317
 
318
  ### Confidence Calculation
custom_generate/generate.py CHANGED
@@ -52,6 +52,7 @@ def generate(
52
 
53
  # Get DeepCONF parameters from generation_config or set defaults
54
  enable_conf = getattr(generation_config, "enable_conf", False)
 
55
  window_size = getattr(generation_config, "window_size", 2048)
56
  threshold = getattr(
57
  generation_config, "threshold", 17.0
@@ -263,6 +264,10 @@ def generate(
263
 
264
  # Get top-k tokens from full probability distribution
265
  top_probs, _ = torch.topk(probs[i], k=conf_topk, dim=-1)
 
 
 
 
266
  log_probs = torch.log(top_probs)
267
  # Confidence is negative mean of log probabilities of top-k tokens
268
  conf = -log_probs.mean().item()
@@ -273,8 +278,8 @@ def generate(
273
  conf_group_lists[i].append(conf)
274
  conf_grouped_sums[i] += conf
275
 
276
- # Apply confidence-based early stopping when window is full
277
- if len(conf_group_lists[i]) >= window_size:
278
  avg_conf = conf_grouped_sums[i] / len(conf_group_lists[i])
279
  if avg_conf < threshold:
280
  deepconf_stopping[i] = False
 
52
 
53
  # Get DeepCONF parameters from generation_config or set defaults
54
  enable_conf = getattr(generation_config, "enable_conf", False)
55
+ enable_early_stopping = getattr(generation_config, "enable_early_stopping", True) # NEW: Allow disabling early stopping
56
  window_size = getattr(generation_config, "window_size", 2048)
57
  threshold = getattr(
58
  generation_config, "threshold", 17.0
 
264
 
265
  # Get top-k tokens from full probability distribution
266
  top_probs, _ = torch.topk(probs[i], k=conf_topk, dim=-1)
267
+ # Add epsilon for numerical stability (prevent log(0) = -inf)
268
+ # Use 1e-7 for float16 compatibility (float16 min ~6e-8)
269
+ eps = torch.finfo(top_probs.dtype).eps if top_probs.dtype == torch.float32 else 1e-7
270
+ top_probs = torch.clamp(top_probs, min=eps)
271
  log_probs = torch.log(top_probs)
272
  # Confidence is negative mean of log probabilities of top-k tokens
273
  conf = -log_probs.mean().item()
 
278
  conf_group_lists[i].append(conf)
279
  conf_grouped_sums[i] += conf
280
 
281
+ # Apply confidence-based early stopping when window is full (only if enabled)
282
+ if enable_early_stopping and len(conf_group_lists[i]) >= window_size:
283
  avg_conf = conf_grouped_sums[i] / len(conf_group_lists[i])
284
  if avg_conf < threshold:
285
  deepconf_stopping[i] = False