"""
Simple examples of DeepConf sample generations with confidence-based early stopping.
"""
|
|
import functools

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
|
|
DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_name: str):
    """Load a causal LM and its tokenizer from the local Hugging Face cache, memoized.

    Loading weights is by far the most expensive step, so consecutive calls with
    the same ``model_name`` reuse the already-loaded objects. ``maxsize=1`` keeps
    at most one model resident; asking for a different model evicts the old one.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16, device_map="auto", local_files_only=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
    return model, tokenizer


def _confidence_stats(confidences):
    """Return ``(min, max, mean)`` of ``confidences`` as Python floats.

    Returns ``(None, None, None)`` when ``confidences`` is ``None`` or empty, so
    callers never hit the error ``Tensor.min()`` raises on an empty tensor.
    Values are reduced in float32 to avoid float16 precision loss in the mean.
    """
    if confidences is None:
        return None, None, None
    conf = torch.as_tensor(confidences, dtype=torch.float32)
    if conf.numel() == 0:
        return None, None, None
    return conf.min().item(), conf.max().item(), conf.mean().item()


def generate_with_deepconf(
    question: str,
    enable_early_stopping: bool = True,
    threshold: float = 10.0,
    window_size: int = 10,
    max_tokens: int = 128,
    model_name: str = DEFAULT_MODEL_NAME,
) -> dict:
    """Generate an answer to ``question`` with DeepConf and summarize the result.

    Sampling is enabled (``do_sample=True``), so outputs differ between runs.

    Args:
        question: User message; it is wrapped with the tokenizer's chat template.
        enable_early_stopping: Forwarded to DeepConf's ``enable_early_stopping``.
        threshold: Confidence threshold forwarded to DeepConf.
        window_size: Confidence window size forwarded to DeepConf (presumably
            measured in tokens -- confirm against the DeepConf implementation).
        max_tokens: Upper bound on newly generated tokens (``max_new_tokens``).
        model_name: Hugging Face model id. It must already be in the local cache
            because loading uses ``local_files_only=True``.

    Returns:
        Dict with ``text`` (decoded completion only, prompt excluded),
        ``tokens`` (number of generated tokens), and ``min_conf`` /
        ``max_conf`` / ``mean_conf`` (floats, or ``None`` when the output
        carries no confidences).
    """
    model, tokenizer = _load_model_and_tokenizer(model_name)

    messages = [{"role": "user", "content": question}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    gen_config = GenerationConfig(
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        max_new_tokens=max_tokens,
        enable_conf=True,
        enable_early_stopping=enable_early_stopping,
        threshold=threshold,
        window_size=window_size,
        output_confidences=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # NOTE: trust_remote_code=True downloads and executes the generation code
    # published in the "kashif/DeepConf" Hub repository.
    outputs = model.generate(
        **inputs,
        generation_config=gen_config,
        custom_generate="kashif/DeepConf",
        trust_remote_code=True,
    )

    # `sequences` holds prompt + completion; slice off the prompt so the
    # returned text is only what the model generated.
    prompt_len = inputs.input_ids.shape[1]
    generated_text = tokenizer.decode(outputs.sequences[0, prompt_len:], skip_special_tokens=True)
    tokens_generated = outputs.sequences.shape[1] - prompt_len

    min_conf, max_conf, mean_conf = _confidence_stats(getattr(outputs, "confidences", None))

    return {
        "text": generated_text,
        "tokens": tokens_generated,
        "min_conf": min_conf,
        "max_conf": max_conf,
        "mean_conf": mean_conf,
    }
|
|
|
|
def print_result(title: str, question: str, result: dict):
    """Print one generation result as a framed, human-readable report.

    Args:
        title: Heading printed between two ``=`` rules.
        question: The prompt that was sent to the model.
        result: Dict with ``text``, ``tokens``, ``min_conf``, ``max_conf`` and
            ``mean_conf`` keys. The confidence section is printed only when
            ``min_conf`` is not ``None``.
    """
    heavy = "=" * 80
    light = "-" * 80

    print("\n" + heavy)
    print(f"{title}")
    print(heavy)
    print(f"Question: {question}")
    print(f"\nGenerated ({result['tokens']} tokens):")
    print(light)
    print(result["text"])
    print(light)

    # No confidence data: generate_with_deepconf sets the stats to None.
    if result["min_conf"] is None:
        return

    print("\nConfidence stats:")
    for label, key in ((" Min: ", "min_conf"), (" Max: ", "max_conf"), (" Mean: ", "mean_conf")):
        print(f"{label}{result[key]:.3f}")
|
|
|
|
# Each entry: (title, question, keyword arguments for generate_with_deepconf).
# Keeping the question in one place guarantees the printed question matches
# the one actually sent to the model.
_EXAMPLES = (
    (
        "Example 1: Math (Aggressive Early Stopping)",
        "What is 25 * 4?",
        dict(enable_early_stopping=True, threshold=8.0, window_size=5, max_tokens=64),
    ),
    (
        "Example 2: Math (Permissive Early Stopping)",
        "What is 25 * 4?",
        dict(enable_early_stopping=True, threshold=15.0, window_size=5, max_tokens=64),
    ),
    (
        "Example 3: Math (No Early Stopping)",
        "What is 25 * 4?",
        dict(enable_early_stopping=False, max_tokens=64),
    ),
    (
        "Example 4: Word Problem",
        "If 5 apples cost $10, how much do 3 apples cost?",
        dict(enable_early_stopping=True, threshold=8.0, window_size=5, max_tokens=96),
    ),
    (
        "Example 5: Factual Question",
        "Who wrote Romeo and Juliet?",
        dict(enable_early_stopping=True, threshold=6.0, window_size=5, max_tokens=64),
    ),
    (
        "Example 6: Calculation",
        "Calculate: (15 + 8) × 2",
        dict(enable_early_stopping=True, threshold=7.0, window_size=5, max_tokens=96),
    ),
    (
        "Example 7: Definition",
        "Define photosynthesis in simple terms.",
        dict(enable_early_stopping=True, threshold=10.0, window_size=10, max_tokens=128),
    ),
    (
        "Example 8: Step-by-step Solution",
        "Solve: x + 5 = 12. Show your steps.",
        dict(enable_early_stopping=True, threshold=8.0, window_size=5, max_tokens=96),
    ),
)


def main():
    """Run every example in ``_EXAMPLES`` and print each result."""
    banner = "█" * 80

    print("\n" + banner)
    print("DEEPCONF SAMPLE GENERATIONS")
    print(banner)

    for title, question, kwargs in _EXAMPLES:
        result = generate_with_deepconf(question, **kwargs)
        print_result(title, question, result)

    print("\n" + banner)
    print("ALL EXAMPLES COMPLETE")
    print(banner)
    print("\nKey observations:")
    # NOTE(review): if DeepConf stops when the windowed confidence falls *below*
    # `threshold` (as described in the DeepConf paper), a higher threshold would
    # stop earlier. Confirm these two lines and the Example 1/2 labels against
    # the kashif/DeepConf implementation.
    print("- Lower threshold → Earlier stopping (fewer tokens)")
    print("- Higher threshold → Later stopping (more tokens)")
    # Without early stopping, generation still ends at EOS, so max_tokens is
    # only an upper bound.
    print("- No early stopping → Runs until EOS or max_tokens")
    print("- Confidence varies based on model certainty")


if __name__ == "__main__":
    main()
|
|