---
license: apache-2.0
library_name: transformers
tags:
- custom_generate
- sampling
---
|
|
|
|
|
|
|
|
# DeepCONF Custom Generation Strategy |
|
|
|
|
|
This repository implements the DeepCONF (Deep Confidence-based Early Stopping) generation strategy for Hugging Face Transformers models, following the approach from the paper [Deep Think with Confidence](https://huggingface.co/papers/2508.15260) ([project page](https://jiaweizzhao.github.io/deepconf/)).
|
|
|
|
|
## Overview |
|
|
|
|
|
DeepCONF monitors the confidence of generated tokens and stops generation when confidence falls below a threshold. The confidence is calculated as the negative mean log probability of the top-k tokens from the full vocabulary (before sampling/filtering is applied), following the methodology from the [official DeepConf implementation](https://github.com/facebookresearch/deepconf). |
|
|
|
|
|
## Parameters |
|
|
|
|
|
- `enable_conf` (bool): Whether to enable the DeepCONF strategy. Defaults to `False`. |
|
|
- `enable_early_stopping` (bool): Whether to apply early stopping during generation (online mode) or just track confidences for post-processing (batch mode). Defaults to `True`. |
|
|
- `window_size` (int): Size of the sliding window for confidence calculation. Defaults to `2048`. |
|
|
- `threshold` (float): Confidence threshold for early stopping. Defaults to `17.0`. |
|
|
- `conf_topk` (int): Number of top tokens to use for confidence calculation from the full vocabulary. Defaults to `20`. |
|
|
- `output_confidences` (bool): If `True` and `return_dict_in_generate=True`, returns a per-step confidence tensor alongside generated sequences for debugging/visualization. |
|
|
- `deepconf_variant` (str): Optional variant for automatic threshold calibration (`"low"` or `"high"`). Requires `deepconf_warmup_confidences`. |
|
|
- `deepconf_warmup_confidences` (list/tensor): Warmup confidence values for threshold calibration. Used with `deepconf_variant`. |
|
|
- `deepconf_eta` (float): Optional override for eta value in threshold calculation (defaults: 0.1 for low, 0.9 for high). |
|
|
|
|
|
## Usage |
|
|
|
|
|
### Basic Usage |
|
|
|
|
|
To use this custom generation strategy, pass the Hub repository name via the `custom_generate` argument of `generate`:
|
|
|
|
|
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "your-model",
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("your-model")

# Prepare your prompt
question = "What is the square root of 144?"
messages = [{"role": "user", "content": question}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Configure generation with DeepCONF
gen_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=512,
    enable_conf=True,              # Enable DeepCONF
    window_size=2048,              # Sliding window size
    threshold=17.0,                # Confidence threshold
    conf_topk=20,                  # Top-k for confidence (default: 20)
    output_confidences=True,       # Return confidence scores
    return_dict_in_generate=True,  # Required for confidence output
)

# Generate with DeepCONF (Hub repo)
outputs = model.generate(
    **inputs,
    generation_config=gen_config,
    custom_generate="kashif/DeepConf",  # Hugging Face Hub repo
    trust_remote_code=True
)

# Access results
generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
print(f"Generated: {generated_text}")

# Access per-step confidences if requested
if hasattr(outputs, 'confidences'):
    confidences = outputs.confidences  # Shape: (batch_size, num_generated_tokens)
    print(f"Min confidence: {confidences.min().item():.3f}")
    print(f"Mean confidence: {confidences.mean().item():.3f}")
```
|
|
|
|
|
### Calibration (DeepConf-low/high) |
|
|
|
|
|
DeepConf's online stopping threshold can be automatically derived from a warmup phase. This allows you to calibrate the threshold based on actual model behavior rather than using a fixed value. |
|
|
|
|
|
**Step 1: Warmup Phase** - Generate multiple sequences and collect their minimum confidences: |
|
|
|
|
|
```python
from transformers import GenerationConfig

# Prepare inputs
question = "What is 2 + 2?"
messages = [{"role": "user", "content": question}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Configure warmup generation
warmup_cfg = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=256,
    enable_conf=True,             # Enable confidence tracking
    enable_early_stopping=False,  # Warmup must run without early stopping
    return_dict_in_generate=True,
    output_confidences=True,
    num_return_sequences=8,       # Generate 8 warmup sequences
)

# Generate warmup sequences
warmup_out = model.generate(
    **inputs,
    generation_config=warmup_cfg,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)

# Extract minimum confidence per sequence (C_t = min over all steps)
warmup_C = warmup_out.confidences.min(dim=1).values.tolist()
print(f"Warmup min confidences: {warmup_C}")
```
|
|
|
|
|
**Step 2: Production Generation** - Use warmup confidences to auto-derive threshold: |
|
|
|
|
|
```python
# Configure production generation with calibrated threshold
gen_cfg = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=512,
    enable_conf=True,
    return_dict_in_generate=True,
    output_confidences=True,

    # Automatic threshold calibration
    deepconf_variant="low",                # "low" (aggressive, 90th percentile) or "high" (permissive, 10th percentile)
    deepconf_warmup_confidences=warmup_C,  # Pass warmup confidences
    # Optional: deepconf_eta=0.1,          # Override eta (defaults: 0.1 for low, 0.9 for high)
)

# Generate with calibrated threshold
outputs = model.generate(
    **inputs,
    generation_config=gen_cfg,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)

print(f"Generated: {tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)}")
```
|
|
|
|
|
**Variant Explanation:** |
|
|
- **DeepConf-low** (eta=0.1): Uses 90th percentile threshold → More aggressive early stopping |
|
|
- **DeepConf-high** (eta=0.9): Uses 10th percentile threshold → More permissive, allows longer generation (see the sketch below)
|
|
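For intuition, the calibrated threshold is simply a percentile of the warmup minimum confidences. The following is a minimal standalone sketch of that calculation; `percentile_threshold` is a hypothetical helper for illustration, while the repo's own logic lives in `compute_warmup_threshold()` in `custom_generate/utils.py`:

```python
import numpy as np

def percentile_threshold(warmup_min_confs, variant="low"):
    # DeepConf-low (eta=0.1) keeps the top 10% most confident warmup traces,
    # so the stopping threshold sits at the 90th percentile of their minimum
    # confidences; DeepConf-high (eta=0.9) uses the 10th percentile instead.
    eta = 0.1 if variant == "low" else 0.9
    return float(np.percentile(warmup_min_confs, 100 * (1 - eta)))

# e.g. threshold = percentile_threshold(warmup_C, variant="low")
```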
|
|
|
### Two Modes of Operation |
|
|
|
|
|
DeepConf supports two modes that match different use cases: |
|
|
|
|
|
#### Mode 1: Online Early Stopping (Default) |
|
|
|
|
|
This is the default behavior where early stopping happens **during** generation: |
|
|
|
|
|
```python
# Online mode: Stop immediately when confidence drops
gen_config = GenerationConfig(
    enable_conf=True,
    enable_early_stopping=True,  # Default: True (online stopping)
    threshold=17.0,
    window_size=2048,
    max_new_tokens=512,
)

outputs = model.generate(
    **inputs,
    generation_config=gen_config,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)
```
|
|
|
|
|
**Use cases:** |
|
|
- Interactive generation where you want immediate results |
|
|
- Real-time applications |
|
|
- Single-sequence generation |
|
|
- Lower memory usage (no need to store full sequences) |
|
|
|
|
|
#### Mode 2: Batch Generation + Post-Processing |
|
|
|
|
|
Generate multiple sequences without early stopping, then analyze them afterward: |
|
|
|
|
|
```python
from custom_generate.utils import process_batch_results

# Phase 1: Generate multiple sequences WITHOUT early stopping
gen_config = GenerationConfig(
    enable_conf=True,
    enable_early_stopping=False,  # Disable online stopping
    output_confidences=True,
    return_dict_in_generate=True,
    max_new_tokens=64,
)

# Expand inputs for batch generation (e.g., 8 sequences)
num_sequences = 8
expanded_input_ids = inputs.input_ids.repeat(num_sequences, 1)
if 'attention_mask' in inputs and inputs.attention_mask is not None:
    expanded_attention_mask = inputs.attention_mask.repeat(num_sequences, 1)
else:
    expanded_attention_mask = None

# Generate batch
outputs = model.generate(
    input_ids=expanded_input_ids,
    attention_mask=expanded_attention_mask,
    generation_config=gen_config,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)

# Phase 2: Post-process to analyze confidence patterns
results = process_batch_results(
    outputs,
    tokenizer,
    window_size=2048,
    threshold=17.0
)

# Analyze results
print(f"Generated {results['num_traces']} sequences")
print(f"Min confidences: {results['min_confs']}")

for i, trace in enumerate(results['traces']):
    print(f"\nSequence {i+1}:")
    print(f"  Text: {trace['text'][:100]}...")
    print(f"  Min confidence: {trace['min_conf']:.3f}")
    print(f"  Would stop early: {trace['stopped_early']}")
    if trace['stopped_early']:
        print(f"  Stop position: {trace['stop_position']}")
```
|
|
|
|
|
**Use cases:** |
|
|
- Research and experimentation (try different thresholds without regenerating) |
|
|
- Batch serving (generate multiple candidates at once) |
|
|
- Analysis and voting (like the official implementation) |
|
|
- Calibration and threshold tuning |
|
|
|
|
|
**Utility Functions:** |
|
|
|
|
|
The `custom_generate/utils.py` module provides helper functions: |
|
|
|
|
|
- `process_batch_results()`: Analyze batch outputs to detect early stopping positions |
|
|
- `analyze_early_stopping()`: Calculate statistics on early stopping behavior |
|
|
- `compute_warmup_threshold()`: Derive threshold from warmup confidences |
|
|
- `extract_answer()`: Parse LaTeX `\boxed{answer}` patterns (sketched below)
|
|
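As an illustration of the last helper, here is a minimal sketch of `\boxed{...}` extraction. This is a hypothetical re-implementation, not the repo's actual `extract_answer()`, and for simplicity it does not handle nested braces:

```python
import re

def extract_boxed_answer(text: str):
    # Grab the contents of the last \boxed{...} in the generated text.
    # Simplified: fails on nested braces such as \boxed{\frac{1}{2}}.
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1] if matches else None

print(extract_boxed_answer("The answer is \\boxed{12}."))  # -> "12"
```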
|
|
|
#### Complete Workflow Example (Like Official DeepConf) |
|
|
|
|
|
This demonstrates the full workflow matching the official implementation: |
|
|
|
|
|
```python
from custom_generate.utils import process_batch_results, compute_warmup_threshold

# Step 1: Warmup phase - generate multiple sequences
warmup_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    max_new_tokens=64,
    enable_conf=True,
    enable_early_stopping=False,  # No stopping during warmup
    output_confidences=True,
    return_dict_in_generate=True,
)

# Expand for 8 warmup sequences
num_warmup = 8
expanded_ids = inputs.input_ids.repeat(num_warmup, 1)
expanded_mask = inputs.attention_mask.repeat(num_warmup, 1) if 'attention_mask' in inputs else None

warmup_outputs = model.generate(
    input_ids=expanded_ids,
    attention_mask=expanded_mask,
    generation_config=warmup_config,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)

# Process warmup to get min confidences
warmup_results = process_batch_results(warmup_outputs, tokenizer, window_size=10)
print(f"Warmup min confidences: {warmup_results['min_confs']}")

# Step 2: Compute threshold from warmup
threshold = compute_warmup_threshold(
    warmup_results['min_confs'],
    variant="low"  # or "high"
)
print(f"Calibrated threshold: {threshold:.3f}")

# Step 3: Final generation with calibrated threshold
final_config = GenerationConfig(
    enable_conf=True,
    enable_early_stopping=True,    # Online stopping with calibrated threshold
    threshold=threshold,
    window_size=10,
    max_new_tokens=128,
    return_dict_in_generate=True,  # Needed to access .sequences below
)

final_output = model.generate(
    **inputs,
    generation_config=final_config,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)
print(tokenizer.decode(final_output.sequences[0], skip_special_tokens=True))
```
|
|
|
|
|
## Technical Details |
|
|
|
|
|
### Confidence Calculation |
|
|
|
|
|
The confidence score for each generated token is calculated as follows: |
|
|
|
|
|
1. **Extract top-k tokens**: Get the top-k (default: 20) tokens with highest probabilities from the full vocabulary |
|
|
2. **Compute log probabilities**: Calculate log probabilities for these top-k tokens |
|
|
3. **Average**: The confidence score is `-mean(log_probs)` of the top-k tokens |
|
|
|
|
|
This approach: |
|
|
- Uses the **full probability distribution** (before any top-k/top-p/temperature filtering) |
|
|
- Always considers a **fixed number of tokens** (conf_topk=20) |
|
|
- Naturally **includes the sampled token** if it's in the top-k (see the sketch below)
|
|
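As a reference, here is a minimal standalone sketch of the per-token confidence just described. It illustrates the formula above and is not this repo's internal code:

```python
import torch
import torch.nn.functional as F

def token_confidence(logits: torch.Tensor, conf_topk: int = 20) -> torch.Tensor:
    # logits: (batch_size, vocab_size) raw next-token logits, BEFORE any
    # temperature/top-k/top-p filtering is applied.
    log_probs = F.log_softmax(logits, dim=-1)
    # Negative mean log-probability over the top-k most likely tokens;
    # a more peaked (confident) distribution yields a higher score.
    topk_log_probs = log_probs.topk(conf_topk, dim=-1).values
    return -topk_log_probs.mean(dim=-1)
```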
|
|
|
### Online Stopping |
|
|
|
|
|
The online method uses a sliding window of confidence scores: |
|
|
- Maintains a window of the last `window_size` (default: 2048) confidence scores |
|
|
- Calculates the mean confidence over this window |
|
|
- Stops generation when `mean_confidence < threshold`, as sketched below
|
|
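A minimal sketch of this stopping rule; `SlidingConfidenceStopper` is an illustrative helper under the assumptions above, not the repo's actual implementation:

```python
from collections import deque

class SlidingConfidenceStopper:
    def __init__(self, window_size: int = 2048, threshold: float = 17.0):
        # Keep only the most recent `window_size` per-token confidences.
        self.window = deque(maxlen=window_size)
        self.threshold = threshold

    def should_stop(self, confidence: float) -> bool:
        self.window.append(confidence)
        # Stop once the windowed mean confidence falls below the threshold.
        return sum(self.window) / len(self.window) < self.threshold
```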
|
|
|
## Requirements |
|
|
|
|
|
- PyTorch >= 1.13.0 |
|
|
- Transformers >= 4.35.0 |
|
|
|
|
|
|