---
license: apache-2.0
library_name: transformers
tags:
- custom_generate
- sampling
---

# DeepCONF Custom Generation Strategy

This repository implements the DeepCONF (Deep Confidence-based Early Stopping) generation strategy for Hugging Face Transformers models, following the approach of the paper [Deep Think with Confidence](https://huggingface.co/papers/2508.15260) ([project page](https://jiaweizzhao.github.io/deepconf/)).

## Overview

DeepCONF monitors the confidence of generated tokens and stops generation when the confidence falls below a threshold. The confidence is calculated as the negative mean log probability of the top-k tokens from the full vocabulary (before any sampling/filtering is applied), following the methodology of the [official DeepConf implementation](https://github.com/facebookresearch/deepconf).
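In other words, for each generation step the strategy takes the probabilities p_1, …, p_k of the top-k vocabulary tokens (k = `conf_topk`) and computes

$$
C = -\frac{1}{k} \sum_{i=1}^{k} \log p_i
$$

Low values of C mean the probability mass is spread over many candidate tokens (the model is uncertain), which is what eventually triggers early stopping.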
## Parameters

- `enable_conf` (bool): Whether to enable the DeepCONF strategy. Defaults to `False`.
- `enable_early_stopping` (bool): Whether to apply early stopping during generation (online mode) or only track confidences for post-processing (batch mode). Defaults to `True`.
- `window_size` (int): Size of the sliding window for confidence calculation. Defaults to `2048`.
- `threshold` (float): Confidence threshold for early stopping. Defaults to `17.0`.
- `conf_topk` (int): Number of top tokens from the full vocabulary used for the confidence calculation. Defaults to `20`.
- `output_confidences` (bool): If `True` and `return_dict_in_generate=True`, returns a per-step confidence tensor alongside the generated sequences for debugging/visualization.
- `deepconf_variant` (str): Optional variant for automatic threshold calibration (`"low"` or `"high"`). Requires `deepconf_warmup_confidences`.
- `deepconf_warmup_confidences` (list/tensor): Warmup confidence values used for threshold calibration. Used together with `deepconf_variant`.
- `deepconf_eta` (float): Optional override for the eta value in the threshold calculation (defaults: 0.1 for `"low"`, 0.9 for `"high"`).

## Usage

### Basic Usage

To use this custom generation strategy, pass it directly to the `generate` method:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "your-model",
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("your-model")

# Prepare your prompt
question = "What is the square root of 144?"
messages = [{"role": "user", "content": question}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Configure generation with DeepCONF
gen_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=512,
    enable_conf=True,             # Enable DeepCONF
    window_size=2048,             # Sliding window size
    threshold=17.0,               # Confidence threshold
    conf_topk=20,                 # Top-k for confidence (default: 20)
    output_confidences=True,      # Return confidence scores
    return_dict_in_generate=True, # Required for confidence output
)

# Generate with DeepCONF (Hub repo)
outputs = model.generate(
    **inputs,
    generation_config=gen_config,
    custom_generate="kashif/DeepConf",  # Hugging Face Hub repo
    trust_remote_code=True
)

# Access results
generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
print(f"Generated: {generated_text}")

# Access per-step confidences if requested
if hasattr(outputs, 'confidences'):
    confidences = outputs.confidences  # Shape: (batch_size, num_generated_tokens)
    print(f"Min confidence: {confidences.min().item():.3f}")
    print(f"Mean confidence: {confidences.mean().item():.3f}")
```

### Calibration (DeepConf-low/high)

DeepConf's online stopping threshold can be derived automatically from a warmup phase. This lets you calibrate the threshold from actual model behavior instead of using a fixed value.

**Step 1: Warmup Phase** - Generate multiple sequences and collect their minimum confidences:

```python
from transformers import GenerationConfig

# Prepare inputs
question = "What is 2 + 2?"
messages = [{"role": "user", "content": question}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Configure warmup generation
warmup_cfg = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=256,
    enable_conf=True,             # Enable confidence tracking
    enable_early_stopping=False,  # Warmup should run without early stopping
    return_dict_in_generate=True,
    output_confidences=True,
    num_return_sequences=8,       # Generate 8 warmup sequences
)

# Generate warmup sequences
warmup_out = model.generate(
    **inputs,
    generation_config=warmup_cfg,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)

# Extract minimum confidence per sequence (C_t = min over all steps)
warmup_C = warmup_out.confidences.min(dim=1).values.tolist()
print(f"Warmup min confidences: {warmup_C}")
```

**Step 2: Production Generation** - Use the warmup confidences to derive the threshold automatically:

```python
# Configure production generation with calibrated threshold
gen_cfg = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=512,
    enable_conf=True,
    return_dict_in_generate=True,
    output_confidences=True,
    # Automatic threshold calibration
    deepconf_variant="low",                # "low" (aggressive, 90th percentile) or "high" (permissive, 10th percentile)
    deepconf_warmup_confidences=warmup_C,  # Pass warmup confidences
    # Optional: deepconf_eta=0.1,          # Override eta (defaults: 0.1 for low, 0.9 for high)
)

# Generate with calibrated threshold
outputs = model.generate(
    **inputs,
    generation_config=gen_cfg,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)

print(f"Generated: {tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)}")
```

**Variant Explanation** (see the sketch below):

- **DeepConf-low** (eta=0.1): Uses the 90th percentile threshold → more aggressive early stopping
- **DeepConf-high** (eta=0.9): Uses the 10th percentile threshold → more permissive, allows longer generation
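The percentile rule behind these variants can be sketched in a few lines. This is only an illustration of the mapping described above; in practice, let `deepconf_variant` or `compute_warmup_threshold()` derive the threshold for you. The helper name `sketch_threshold` is invented for this example:

```python
import numpy as np

def sketch_threshold(warmup_min_confs, variant="low", eta=None):
    """Illustrative percentile rule for deriving the stopping threshold
    from the minimum confidences of the warmup traces."""
    if eta is None:
        eta = 0.1 if variant == "low" else 0.9  # documented defaults
    # eta=0.1 -> 90th percentile (DeepConf-low, aggressive stopping)
    # eta=0.9 -> 10th percentile (DeepConf-high, permissive stopping)
    return float(np.percentile(warmup_min_confs, 100 * (1 - eta)))

# e.g. with the warmup confidences collected in Step 1:
# threshold = sketch_threshold(warmup_C, variant="low")
```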
### Two Modes of Operation

DeepConf supports two modes that match different use cases:

#### Mode 1: Online Early Stopping (Default)

This is the default behavior, where early stopping happens **during** generation:

```python
# Online mode: stop immediately when confidence drops
gen_config = GenerationConfig(
    enable_conf=True,
    enable_early_stopping=True,  # Default: True (online stopping)
    threshold=17.0,
    window_size=2048,
    max_new_tokens=512,
)

outputs = model.generate(
    **inputs,
    generation_config=gen_config,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)
```

**Use cases:**

- Interactive generation where you want immediate results
- Real-time applications
- Single-sequence generation
- Lower memory usage (no need to store full sequences)

#### Mode 2: Batch Generation + Post-Processing

Generate multiple sequences without early stopping, then analyze them afterward:

```python
import torch

# Phase 1: Generate multiple sequences WITHOUT early stopping
gen_config = GenerationConfig(
    enable_conf=True,
    enable_early_stopping=False,  # Disable online stopping
    output_confidences=True,
    return_dict_in_generate=True,
    max_new_tokens=64,
)

# Expand inputs for batch generation (e.g., 8 sequences)
num_sequences = 8
expanded_input_ids = inputs.input_ids.repeat(num_sequences, 1)
if 'attention_mask' in inputs and inputs.attention_mask is not None:
    expanded_attention_mask = inputs.attention_mask.repeat(num_sequences, 1)
else:
    expanded_attention_mask = None

# Generate batch
outputs = model.generate(
    input_ids=expanded_input_ids,
    attention_mask=expanded_attention_mask,
    generation_config=gen_config,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)

# Phase 2: Post-process to analyze confidence patterns
from custom_generate.utils import process_batch_results

results = process_batch_results(
    outputs,
    tokenizer,
    window_size=2048,
    threshold=17.0
)

# Analyze results
print(f"Generated {results['num_traces']} sequences")
print(f"Min confidences: {results['min_confs']}")

for i, trace in enumerate(results['traces']):
    print(f"\nSequence {i+1}:")
    print(f"  Text: {trace['text'][:100]}...")
    print(f"  Min confidence: {trace['min_conf']:.3f}")
    print(f"  Would stop early: {trace['stopped_early']}")
    if trace['stopped_early']:
        print(f"  Stop position: {trace['stop_position']}")
```

**Use cases:**

- Research and experimentation (try different thresholds without regenerating)
- Batch serving (generate multiple candidates at once)
- Analysis and voting (like the official implementation)
- Calibration and threshold tuning

**Utility Functions:**

The `custom_generate/utils.py` module provides helper functions:

- `process_batch_results()`: Analyze batch outputs to detect early stopping positions
- `analyze_early_stopping()`: Calculate statistics on early stopping behavior
- `compute_warmup_threshold()`: Derive a threshold from warmup confidences
- `extract_answer()`: Parse LaTeX `\boxed{answer}` patterns

#### Complete Workflow Example (Like Official DeepConf)

This demonstrates the full workflow matching the official implementation:

```python
# Step 1: Warmup phase - generate multiple sequences
warmup_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    max_new_tokens=64,
    enable_conf=True,
    enable_early_stopping=False,  # No stopping during warmup
    output_confidences=True,
    return_dict_in_generate=True,
)

# Expand for 8 warmup sequences
num_warmup = 8
expanded_ids = inputs.input_ids.repeat(num_warmup, 1)
expanded_mask = inputs.attention_mask.repeat(num_warmup, 1) if 'attention_mask' in inputs else None

warmup_outputs = model.generate(
    input_ids=expanded_ids,
    attention_mask=expanded_mask,
    generation_config=warmup_config,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)

# Process warmup to get min confidences
from custom_generate.utils import process_batch_results, compute_warmup_threshold

warmup_results = process_batch_results(warmup_outputs, tokenizer, window_size=10)
print(f"Warmup min confidences: {warmup_results['min_confs']}")

# Step 2: Compute threshold from warmup
threshold = compute_warmup_threshold(
    warmup_results['min_confs'],
    variant="low"  # or "high"
)
print(f"Calibrated threshold: {threshold:.3f}")

# Step 3: Final generation with calibrated threshold
final_config = GenerationConfig(
    enable_conf=True,
    enable_early_stopping=True,  # Online stopping with calibrated threshold
    threshold=threshold,
    window_size=10,
    max_new_tokens=128,
    return_dict_in_generate=True,
)

final_output = model.generate(
    **inputs,
    generation_config=final_config,
    custom_generate="kashif/DeepConf",
    trust_remote_code=True,
)
print(tokenizer.decode(final_output.sequences[0], skip_special_tokens=True))
```

## Technical Details

### Confidence Calculation

The confidence score for each generated token is computed as follows:

1. **Extract top-k tokens**: Take the top-k (default: 20) highest-probability tokens from the full vocabulary
2. **Compute log probabilities**: Calculate the log probabilities of these top-k tokens
3. **Average**: The confidence score is `-mean(log_probs)` over the top-k tokens

This approach:

- Uses the **full probability distribution** (before any top-k/top-p/temperature filtering)
- Always considers a **fixed number of tokens** (`conf_topk=20`)
- Naturally **includes the sampled token** if it is in the top-k

### Online Stopping

The online method uses a sliding window of confidence scores (see the sketch below):

- Maintains a window of the last `window_size` (default: 2048) confidence scores
- Calculates the mean confidence over this window
- Stops generation when `mean_confidence < threshold`
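For intuition, the two rules above can be written as a short, self-contained sketch. This is illustrative only; the helper names `token_confidence` and `should_stop` are invented here and are not part of the repository:

```python
from collections import deque

import torch

def token_confidence(logits: torch.Tensor, conf_topk: int = 20) -> float:
    """Per-step confidence: negative mean log-probability of the top-k
    tokens from the full (unfiltered) vocabulary distribution."""
    log_probs = torch.log_softmax(logits, dim=-1)               # full vocabulary
    topk_log_probs = log_probs.topk(conf_topk, dim=-1).values   # fixed number of tokens
    return -topk_log_probs.mean().item()

def should_stop(window: deque, threshold: float = 17.0) -> bool:
    """Online rule: stop once the mean confidence over the sliding window
    drops below the threshold."""
    return sum(window) / len(window) < threshold

# Inside a hypothetical decoding loop:
window = deque(maxlen=2048)  # window_size
# for each step:
#     window.append(token_confidence(next_token_logits))
#     if should_stop(window):
#         break
```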
## Requirements

- PyTorch >= 1.13.0
- Transformers >= 4.35.0