"""Simple examples showing DeepConf sample generations."""

import functools

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_name: str):
    """Load a causal LM and its tokenizer from the local HF cache, memoized per name.

    Loading dominates the cost of each example. Keeping the most recent model
    resident avoids reloading the same weights on every call. maxsize=1 bounds
    memory to a single resident model.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16, device_map="auto", local_files_only=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
    return model, tokenizer


def _confidence_stats(confidences):
    """Return ``(min, max, mean)`` of a confidence tensor, or three ``None``s if absent/empty.

    NOTE(review): assumes ``confidences`` is a torch tensor, as the original
    ``.min().item()`` usage implies; confirm against the DeepConf output type.
    """
    if confidences is None or confidences.numel() == 0:
        # .min()/.max() raise on empty tensors; report "no stats" instead of crashing.
        return None, None, None
    # Upcast so mean() works for any dtype and does not accumulate in fp16.
    values = confidences.float()
    return values.min().item(), values.max().item(), values.mean().item()


def generate_with_deepconf(
    question: str,
    enable_early_stopping: bool = True,
    threshold: float = 10.0,
    window_size: int = 10,
    max_tokens: int = 128,
    model_name: str = "Qwen/Qwen2.5-0.5B-Instruct",
):
    """Generate an answer to ``question`` with the DeepConf custom decoding strategy.

    Args:
        question: User message; it is wrapped in the tokenizer's chat template.
        enable_early_stopping: Forwarded to DeepConf's ``enable_early_stopping`` option.
        threshold: Forwarded to DeepConf's ``threshold`` option.
        window_size: Forwarded to DeepConf's ``window_size`` option.
        max_tokens: Upper bound on newly generated tokens (``max_new_tokens``).
        model_name: Hub id of the model. It must already be in the local cache,
            because loading uses ``local_files_only=True``.

    Returns:
        A dict with these keys:

        - ``text``: the decoded continuation only, with special tokens removed.
        - ``tokens``: the number of newly generated tokens.
        - ``min_conf`` / ``max_conf`` / ``mean_conf``: floats, or ``None`` when
          the output carries no confidences.
    """
    model, tokenizer = _load_model_and_tokenizer(model_name)

    messages = [{"role": "user", "content": question}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs.input_ids.shape[1]

    gen_config = GenerationConfig(
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        max_new_tokens=max_tokens,
        enable_conf=True,
        enable_early_stopping=enable_early_stopping,
        threshold=threshold,
        window_size=window_size,
        output_confidences=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    outputs = model.generate(
        **inputs,
        generation_config=gen_config,
        custom_generate="kashif/DeepConf",
        trust_remote_code=True,
    )

    # For decoder-only models `sequences` is prompt + continuation. Decode only
    # the continuation so the printed "Generated" text does not echo the chat prompt.
    generated_text = tokenizer.decode(outputs.sequences[0, prompt_len:], skip_special_tokens=True)
    tokens_generated = outputs.sequences.shape[1] - prompt_len

    min_conf, max_conf, mean_conf = _confidence_stats(getattr(outputs, "confidences", None))

    return {
        "text": generated_text,
        "tokens": tokens_generated,
        "min_conf": min_conf,
        "max_conf": max_conf,
        "mean_conf": mean_conf,
    }


def print_result(title: str, question: str, result: dict):
    """Print a boxed report for a single generation result.

    Args:
        title: Heading shown between the two ``=`` rules.
        question: The prompt that was sent to the model.
        result: Mapping with ``"tokens"``, ``"text"`` and ``"min_conf"``. When
            ``"min_conf"`` is not ``None``, ``"max_conf"`` and ``"mean_conf"``
            are printed as well, each with three decimals.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    report = [
        "",
        heavy_rule,
        f"{title}",
        heavy_rule,
        f"Question: {question}",
        "",
        f"Generated ({result['tokens']} tokens):",
        light_rule,
        str(result["text"]),
        light_rule,
    ]

    if result["min_conf"] is not None:
        report += [
            "",
            "Confidence stats:",
            f" Min: {result['min_conf']:.3f}",
            f" Max: {result['max_conf']:.3f}",
            f" Mean: {result['mean_conf']:.3f}",
        ]

    # A single write of the newline-joined report is equivalent to one print per line.
    print("\n".join(report))


def main():
    """Run a fixed set of DeepConf demos and print each result, then a summary."""
    # The original file had this glyph mojibaked as "â–ˆ" (UTF-8 "█" decoded as cp1252).
    banner = "█" * 80

    print("\n" + banner)
    print("DEEPCONF SAMPLE GENERATIONS")
    print(banner)

    # NOTE(review): the "Aggressive"/"Permissive" labels assume a lower threshold
    # stops earlier. In the DeepConf paper, online generation stops once windowed
    # confidence falls *below* the threshold, which would make a higher threshold
    # stop earlier. Confirm against the kashif/DeepConf docs.
    # Each entry is (title, question, keyword arguments for generate_with_deepconf).
    examples = [
        (
            "Example 1: Math (Aggressive Early Stopping)",
            "What is 25 * 4?",
            {"enable_early_stopping": True, "threshold": 8.0, "window_size": 5, "max_tokens": 64},
        ),
        (
            "Example 2: Math (Permissive Early Stopping)",
            "What is 25 * 4?",
            {"enable_early_stopping": True, "threshold": 15.0, "window_size": 5, "max_tokens": 64},
        ),
        (
            "Example 3: Math (No Early Stopping)",
            "What is 25 * 4?",
            {"enable_early_stopping": False, "max_tokens": 64},
        ),
        (
            "Example 4: Word Problem",
            "If 5 apples cost $10, how much do 3 apples cost?",
            {"enable_early_stopping": True, "threshold": 8.0, "window_size": 5, "max_tokens": 96},
        ),
        (
            "Example 5: Factual Question",
            "Who wrote Romeo and Juliet?",
            {"enable_early_stopping": True, "threshold": 6.0, "window_size": 5, "max_tokens": 64},
        ),
        (
            "Example 6: Calculation",
            "Calculate: (15 + 8) × 2",
            {"enable_early_stopping": True, "threshold": 7.0, "window_size": 5, "max_tokens": 96},
        ),
        (
            "Example 7: Definition",
            "Define photosynthesis in simple terms.",
            {"enable_early_stopping": True, "threshold": 10.0, "window_size": 10, "max_tokens": 128},
        ),
        (
            "Example 8: Step-by-step Solution",
            "Solve: x + 5 = 12. Show your steps.",
            {"enable_early_stopping": True, "threshold": 8.0, "window_size": 5, "max_tokens": 96},
        ),
    ]

    for title, question, kwargs in examples:
        result = generate_with_deepconf(question, **kwargs)
        print_result(title, question, result)

    print(f"\n{banner}")
    print("ALL EXAMPLES COMPLETE")
    print(banner)
    print("\nKey observations:")
    print("- Lower threshold → Earlier stopping (fewer tokens)")
    print("- Higher threshold → Later stopping (more tokens)")
    # Generation still ends at EOS, so max_tokens is an upper bound, not a guarantee.
    print("- No early stopping → Runs until EOS or max_tokens")
    print("- Confidence varies based on model certainty")


if __name__ == "__main__":
    main()