"""
Simple examples showing DeepConf sample generations
"""

from functools import lru_cache

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


@lru_cache(maxsize=1)
def load_model_and_tokenizer(model_name: str):
    """Load the model and tokenizer once; subsequent calls reuse the in-memory copies."""
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def generate_with_deepconf(
    question: str,
    enable_early_stopping: bool = True,
    threshold: float = 10.0,
    window_size: int = 10,
    max_tokens: int = 128,
):
    """Generate with DeepConf and return the decoded text plus confidence stats."""

    # Load model and tokenizer (loaded once per process, see load_model_and_tokenizer)
    model, tokenizer = load_model_and_tokenizer("Qwen/Qwen2.5-0.5B-Instruct")

    # Prepare prompt
    messages = [{"role": "user", "content": question}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Configure generation; enable_conf / enable_early_stopping / threshold /
    # window_size / output_confidences are DeepConf-specific knobs consumed by
    # the custom decoding loop (see sliding_window_confidence below for the idea)
    gen_config = GenerationConfig(
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        max_new_tokens=max_tokens,
        enable_conf=True,  # track per-token confidences during decoding
        enable_early_stopping=enable_early_stopping,  # allow confidence-based early stop
        threshold=threshold,  # confidence level the early-stop check compares against
        window_size=window_size,  # number of recent tokens per confidence window
        output_confidences=True,  # return the confidence trace with the output
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Generate; custom_generate fetches the DeepConf decoding loop from the
    # kashif/DeepConf Hub repo (hence trust_remote_code=True)
    outputs = model.generate(
        **inputs,
        generation_config=gen_config,
        custom_generate="kashif/DeepConf",
        trust_remote_code=True,
    )

    # Extract results
    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    tokens_generated = outputs.sequences.shape[1] - inputs.input_ids.shape[1]

    # The confidence trace is attached by the custom loop when output_confidences=True
    if hasattr(outputs, "confidences") and outputs.confidences is not None:
        min_conf = outputs.confidences.min().item()
        max_conf = outputs.confidences.max().item()
        mean_conf = outputs.confidences.mean().item()
    else:
        min_conf = max_conf = mean_conf = None

    return {
        "text": generated_text,
        "tokens": tokens_generated,
        "min_conf": min_conf,
        "max_conf": max_conf,
        "mean_conf": mean_conf,
    }
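

# --- Illustrative sketch (assumption; not called by this script) -------------
# The actual decoding loop is fetched from the kashif/DeepConf Hub repo via
# custom_generate, so the precise stopping rule lives there. This helper only
# sketches the presumed role of window_size: average the most recent
# window_size per-token confidences into a rolling "group" confidence, which
# the loop can compare against `threshold` to decide whether to keep decoding.
def sliding_window_confidence(token_confidences, window_size):
    """Mean confidence over the last `window_size` tokens, or None if too few."""
    if len(token_confidences) < window_size:
        return None  # not enough tokens yet to fill one window
    window = token_confidences[-window_size:]
    return sum(window) / window_size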


def print_result(title: str, question: str, result: dict):
    """Pretty print generation result"""
    print(f"\n{'=' * 80}")
    print(f"{title}")
    print(f"{'=' * 80}")
    print(f"Question: {question}")
    print(f"\nGenerated ({result['tokens']} tokens):")
    print(f"{'-' * 80}")
    print(result["text"])
    print(f"{'-' * 80}")

    if result["min_conf"] is not None:
        print("\nConfidence stats:")
        print(f"  Min: {result['min_conf']:.3f}")
        print(f"  Max: {result['max_conf']:.3f}")
        print(f"  Mean: {result['mean_conf']:.3f}")


if __name__ == "__main__":
    print("\n" + "β–ˆ" * 80)
    print("DEEPCONF SAMPLE GENERATIONS")
    print("β–ˆ" * 80)

    # Example 1: Math with aggressive early stopping
    result = generate_with_deepconf(
        "What is 25 * 4?", enable_early_stopping=True, threshold=8.0, window_size=5, max_tokens=64
    )
    print_result("Example 1: Math (Aggressive Early Stopping)", "What is 25 * 4?", result)

    # Example 2: Math with permissive early stopping
    result = generate_with_deepconf(
        "What is 25 * 4?", enable_early_stopping=True, threshold=15.0, window_size=5, max_tokens=64
    )
    print_result("Example 2: Math (Permissive Early Stopping)", "What is 25 * 4?", result)

    # Example 3: Math without early stopping
    result = generate_with_deepconf("What is 25 * 4?", enable_early_stopping=False, max_tokens=64)
    print_result("Example 3: Math (No Early Stopping)", "What is 25 * 4?", result)

    # Example 4: Reasoning question
    result = generate_with_deepconf(
        "If 5 apples cost $10, how much do 3 apples cost?",
        enable_early_stopping=True,
        threshold=8.0,
        window_size=5,
        max_tokens=96,
    )
    print_result("Example 4: Word Problem", "If 5 apples cost $10, how much do 3 apples cost?", result)

    # Example 5: Factual question
    result = generate_with_deepconf(
        "Who wrote Romeo and Juliet?", enable_early_stopping=True, threshold=6.0, window_size=5, max_tokens=64
    )
    print_result("Example 5: Factual Question", "Who wrote Romeo and Juliet?", result)

    # Example 6: Calculation
    result = generate_with_deepconf(
        "Calculate: (15 + 8) Γ— 2", enable_early_stopping=True, threshold=7.0, window_size=5, max_tokens=96
    )
    print_result("Example 6: Calculation", "Calculate: (15 + 8) Γ— 2", result)

    # Example 7: Definition
    result = generate_with_deepconf(
        "Define photosynthesis in simple terms.",
        enable_early_stopping=True,
        threshold=10.0,
        window_size=10,
        max_tokens=128,
    )
    print_result("Example 7: Definition", "Define photosynthesis in simple terms.", result)

    # Example 8: Step-by-step
    result = generate_with_deepconf(
        "Solve: x + 5 = 12. Show your steps.", enable_early_stopping=True, threshold=8.0, window_size=5, max_tokens=96
    )
    print_result("Example 8: Step-by-step Solution", "Solve: x + 5 = 12. Show your steps.", result)

    print(f"\n{'█' * 80}")
    print("ALL EXAMPLES COMPLETE")
    print("█" * 80)
    print("\nKey observations:")
    print("- Lower threshold β†’ Earlier stopping (fewer tokens)")
    print("- Higher threshold β†’ Later stopping (more tokens)")
    print("- No early stopping β†’ Always generates max_tokens")
    print("- Confidence varies based on model certainty")