import argparse
import re
import time
from pathlib import Path

import mlx.core as mx
from transformers import AutoTokenizer

from model import load_model


def generate_text(
    prompt: str,
    model_path: str,
    max_tokens: int = 100,
    temperature: float = 0.1,
    top_p: float = 0.9,
    system: str | None = None,
    final_only: bool = False,
    stop_at_boxed: bool = False,
    extract_boxed: bool = False,
    disable_chat_template: bool = False,
    repetition_penalty: float = 1.0,
    frequency_penalty: float = 0.0,
):
    """Generate text with the MLX model using temperature/top-p sampling
    and optional repetition/frequency penalties."""
    print("Loading model and tokenizer...")
    model = load_model(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
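
    # Use the model's chat template only when one ships with the weights and
    # the caller hasn't disabled it; otherwise fall back to raw completion.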
    chat_template_path = Path(model_path) / "chat_template.jinja"
    use_chat_format = chat_template_path.exists() and not disable_chat_template

    print(f"Chat template found: {use_chat_format}")
    print("Starting generation...")
    print(f"Prompt: {prompt}")

    if use_chat_format:
        messages = []
        # --final-only implies a default system prompt asking for only the
        # boxed final answer, unless the caller supplied one.
        if system is None and final_only:
            system = (
                "You are a helpful assistant. Do not reveal your reasoning. "
                "Respond with only the final answer enclosed in \\boxed{...}."
            )
        if system is not None:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})
        formatted_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print(f"Formatted prompt: {formatted_prompt}")
    else:
        # Raw completion mode: prepend BOS manually, since the encoding below
        # uses add_special_tokens=False.
        bos = tokenizer.bos_token or ""
        formatted_prompt = f"{bos}{prompt}"

    # Special tokens were already added by the template (or the manual BOS),
    # so don't let the tokenizer add them a second time.
    prompt_tokens = tokenizer.encode(formatted_prompt, add_special_tokens=False)
    prompt_tokens = mx.array([prompt_tokens])

    print(f"Prompt tokens shape: {prompt_tokens.shape}")
    print(
        f"First few token IDs: {prompt_tokens[0, : min(10, prompt_tokens.shape[1])].tolist()}"
    )

    start_time = time.time()
    generated_tokens = []
    freq_counts = {}

    running_text = ""
    seen_box_start = False
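
    # Decode loop. Note that the full sequence is re-run through the model on
    # every step (there is no KV cache here), so generation cost grows with
    # the square of the sequence length.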
    for i in range(max_tokens):
        logits = model(prompt_tokens)
        next_token_logits = logits[0, -1, :]

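        # Repetition penalty (CTRL-style): for every token already generated,
        # divide its logit by the penalty when positive and multiply when
        # negative, pushing repeated tokens' probabilities down either way.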
        if repetition_penalty and repetition_penalty != 1.0 and generated_tokens:
            logits_list = next_token_logits.tolist()
            for tid in set(generated_tokens):
                val = logits_list[tid]
                if val > 0:
                    logits_list[tid] = val / repetition_penalty
                else:
                    logits_list[tid] = val * repetition_penalty
            next_token_logits = mx.array(logits_list)

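        # Frequency penalty: subtract frequency_penalty * count(token) from
        # each logit, using the running freq_counts tally kept below.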
        if frequency_penalty and frequency_penalty > 0 and generated_tokens:
            vocab_size = next_token_logits.shape[-1]
            pen = [0.0] * vocab_size
            for tid, c in freq_counts.items():
                pen[tid] = frequency_penalty * float(c)
            next_token_logits = next_token_logits - mx.array(pen)

|
| | |
| | if temperature == 0: |
| | |
| | next_token = int(mx.argmax(next_token_logits).item()) |
| | else: |
| | |
| | scaled_logits = next_token_logits / temperature |
| |
|
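            # Top-p (nucleus) filtering: keep the smallest set of highest-
            # probability tokens whose cumulative mass reaches top_p; mask
            # the rest to -inf so they can never be sampled.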
            if 0.0 < top_p < 1.0:
                probs = mx.softmax(scaled_logits, axis=-1)
                sorted_probs = mx.sort(probs)[::-1]
                cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
                cutoff_index = mx.sum(cumulative_probs < top_p)
                cutoff_prob = sorted_probs[cutoff_index.item()]
                mask = probs >= cutoff_prob
                scaled_logits = mx.where(mask, scaled_logits, float("-inf"))

            next_token = mx.random.categorical(scaled_logits, num_samples=1).item()

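        # Normalise EOS to a set: some tokenizers expose a single id, others
        # a list of ids.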
        eos_ids = tokenizer.eos_token_id
        if isinstance(eos_ids, (list, tuple)):
            stop_ids = {int(t) for t in eos_ids}
        else:
            stop_ids = {int(eos_ids)}
        if next_token in stop_ids:
            print(f"Stopping generation at EOS token: {next_token}")
            break

        generated_tokens.append(next_token)
        freq_counts[next_token] = freq_counts.get(next_token, 0) + 1
        # Feed the sampled token back in for the next forward pass.
        prompt_tokens = mx.concatenate(
            [prompt_tokens, mx.array([[next_token]])], axis=1
        )

        # Debug: echo the first few sampled tokens.
        if i < 10:
            token_text = tokenizer.decode([next_token])
            print(f"Token {i}: {next_token} -> '{token_text}'")

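        # Early stop once a complete \boxed{...} answer has appeared.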
        if stop_at_boxed:
            running_text += tokenizer.decode([next_token], skip_special_tokens=False)
            if not seen_box_start and "\\boxed{" in running_text:
                seen_box_start = True
            # Look for the closing brace only after the box opens, so a stray
            # '}' earlier in the output cannot end generation prematurely.
            if seen_box_start and "}" in running_text.split("\\boxed{", 1)[1]:
                print("Stopping generation at boxed answer.")
                break

    end_time = time.time()

    if generated_tokens:
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        print("\n--- Response ---")
        print(response)
    else:
        print("\n--- No tokens generated ---")

    print("------------------")

    generation_speed = (
        len(generated_tokens) / (end_time - start_time) if generated_tokens else 0
    )
    print(f"Generated {len(generated_tokens)} tokens")
    print(f"Generation speed: {generation_speed:.2f} tokens/sec")

    if generated_tokens:
        full_response = tokenizer.decode(generated_tokens, skip_special_tokens=False)
        print(f"\nFull response (with special tokens): '{full_response}'")

        if extract_boxed:
            # Keep the last \boxed{...} match. The decoded text contains one
            # literal backslash, so the raw pattern must escape a single
            # backslash (r"\\boxed"), not two.
            m = None
            for m in re.finditer(r"\\boxed\{([^}]*)\}", full_response):
                pass
            if m:
                print(f"\nExtracted boxed answer: {m.group(1).strip()}")
            else:
                print("\nNo \\boxed{...} segment found to extract.")


def main():
    parser = argparse.ArgumentParser(description="Run inference with the MLX model.")
    parser.add_argument(
        "--model-path", type=str, default=".", help="Path to the model directory."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="What is the capital of France?",
        help="The prompt to start generation from.",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=100,
        help="The maximum number of tokens to generate.",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.1, help="Sampling temperature."
    )
    parser.add_argument(
        "--top-p", type=float, default=0.9, help="Top-p (nucleus) sampling parameter."
    )
    parser.add_argument(
        "--system", type=str, default=None, help="Optional system message for the chat template."
    )
    parser.add_argument(
        "--final-only",
        action="store_true",
        help="Instruct the model to output only the final answer inside \\boxed{...}.",
    )
    parser.add_argument(
        "--stop-at-boxed",
        action="store_true",
        help="Stop generation once a closing '}' appears after \\boxed{.",
    )
    parser.add_argument(
        "--extract-boxed",
        action="store_true",
        help="Extract and print the content inside the last \\boxed{...} in the response.",
    )
    parser.add_argument(
        "--disable-chat-template",
        action="store_true",
        help="Ignore chat_template.jinja and feed the raw prompt (prepended with BOS).",
    )
    parser.add_argument(
        "--repetition-penalty",
        type=float,
        default=1.0,
        help="Penalty (>1.0) to discourage previously generated tokens.",
    )
    parser.add_argument(
        "--frequency-penalty",
        type=float,
        default=0.0,
        help="Subtract this value times count(token) from logits before sampling.",
    )
    args = parser.parse_args()

    generate_text(
        args.prompt,
        args.model_path,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        system=args.system,
        final_only=args.final_only,
        stop_at_boxed=args.stop_at_boxed,
        extract_boxed=args.extract_boxed,
        disable_chat_template=args.disable_chat_template,
        repetition_penalty=args.repetition_penalty,
        frequency_penalty=args.frequency_penalty,
    )
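

# Example invocations (the script name and model directory are illustrative):
#   python generate.py --prompt "What is 2+2?" --final-only \
#       --stop-at-boxed --extract-boxed
#   python generate.py --model-path ./my-mlx-model --temperature 0 \
#       --disable-chat-template --prompt "The capital of France is"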


if __name__ == "__main__":
    main()