#!/usr/bin/env python3
"""
FunctionGemma evaluation script (v2).

Runs a (optionally LoRA-adapted) causal LM over a JSON benchmark of
function-calling prompts, parses ``call:NAME{key:value,...}`` outputs and
reports function-selection, argument, exact-match and rejection accuracy.
Every sample is evaluated with the same (short) system prompt for the chain.

Usage:
    python -m src.evaluate --model_path ./runs/<run_name>/final_model --benchmark_path ./data/benchmark_dataset.json
"""

import os
import re
import sys
import json
import math
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from tqdm import tqdm

# Make ``src.config`` importable even when the project root is not already
# on sys.path.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

DEFAULT_BENCHMARK_PATH = PROJECT_ROOT / "data" / "benchmark_dataset.json"
DEFAULT_RESULTS_DIR = PROJECT_ROOT / "results"

from src.config import (  # noqa: E402
    get_system_prompt, get_system_prompt_short, TOOLS, SOLANA_TOKENS, get_token_address
)

# Logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# NOTE(review): Gemma / FunctionGemma special-token strings used for the
# manual prompt fallback and for output clean-up; confirm they match the
# tokenizer's chat template.
_ESCAPE = "<escape>"
_START_OF_TURN = "<start_of_turn>"
_END_OF_TURN = "<end_of_turn>"

# Function-call patterns, compiled once.
_CALL_FULL_RE = re.compile(r'call:(\w+)\{([^}]*)\}')
# Same as above but without the closing brace, for generations cut off by
# max_new_tokens.
_CALL_TRUNCATED_RE = re.compile(r'call:(\w+)\{([^}]*)$')
_CALL_NAME_RE = re.compile(r'call:(\w+)')
# Last-resort detection when no ``call:`` prefix is present at all.
_FALLBACK_FUNCTIONS = ("SEARCH_TOKEN", "EXECUTE_SWAP")

# One ``key:value`` pair. Delimited values (escape token or double quotes) may
# contain commas; bare values run up to the next comma or closing brace.
_PARAM_RE = re.compile(
    r'(\w+):'
    r'(?:' + re.escape(_ESCAPE) + r'(.*?)' + re.escape(_ESCAPE)
    + r'|"([^"]*)"'
    + r'|([^,}]*))'
)

# Phrases that mark a clarification / refusal (Chinese variants kept for CN prompts).
_REJECTION_KEYWORDS = (
    "please specify", "could you", "what token", "which token",
    "请问", "请提供", "请告诉", "您能", "什么代币", "哪个代币",
    "sorry", "can't", "cannot", "unable", "抱歉", "无法",
    "more information", "more details", "更多信息",
)

# Set once the chat-template fallback warning has been logged.
_template_fallback_warned = False


def load_model(
    model_path: str,
    lora_path: Optional[str] = None,
    device: str = "auto",
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
):
    """Load a causal LM and its tokenizer, optionally quantized and LoRA-adapted.

    Args:
        model_path: Model id or local directory of the base model.
        lora_path: Optional PEFT adapter directory applied on top of the model.
        device: Forwarded to ``from_pretrained`` as ``device_map``.
        load_in_8bit: Quantize weights to 8 bit via bitsandbytes.
        load_in_4bit: Quantize weights to 4-bit NF4 via bitsandbytes. Ignored
            when ``load_in_8bit`` is also set.

    Returns:
        ``(model, tokenizer)``; the model is switched to eval mode and the
        tokenizer gets ``pad_token = eos_token`` if it had no pad token.
    """
    logger.info(f"Loading model: {model_path}")

    kwargs = {
        "device_map": device,
        "trust_remote_code": True,
    }

    if load_in_8bit or load_in_4bit:
        from transformers import BitsAndBytesConfig

    if load_in_8bit:
        # Passing ``load_in_8bit`` straight to from_pretrained is deprecated;
        # use a quantization config, consistent with the 4-bit branch.
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    elif load_in_4bit:
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        kwargs["torch_dtype"] = torch.bfloat16

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)

    if lora_path:
        logger.info(f"Loading LoRA adapter: {lora_path}")
        model = PeftModel.from_pretrained(model, lora_path)

    model.eval()
    return model, tokenizer


def parse_functiongemma_output(response: str) -> Tuple[Optional[str], Optional[Dict]]:
    """Extract the function name and arguments from a model response.

    Expected format: ``call:FUNC_NAME{key:value,...}``. Parsing degrades
    gracefully:

    1. complete call -> name and parsed arguments;
    2. call cut off before ``}`` (truncated generation) -> name and whatever
       arguments were emitted;
    3. ``call:NAME`` with no brace -> name and ``{}``;
    4. a known function name mentioned anywhere -> name and ``{}``;
    5. otherwise ``(None, None)``.
    """
    match = _CALL_FULL_RE.search(response) or _CALL_TRUNCATED_RE.search(response)
    if match is None:
        name_match = _CALL_NAME_RE.search(response)
        if name_match:
            return name_match.group(1), {}
        for func in _FALLBACK_FUNCTIONS:
            if func in response:
                return func, {}
        return None, None

    return match.group(1), parse_params_string(match.group(2))


def parse_params_string(params_str: str) -> Dict:
    """Parse ``key:value,key:value`` into a dict with light type coercion.

    Values wrapped in the escape token or in double quotes may contain commas.
    Each value is stripped, then: ``"N%"`` becomes ``N / 100`` (float); a
    value containing ``.`` is tried as float, otherwise as int; anything that
    fails to convert stays a string.
    """
    args = {}
    if not params_str:
        return args

    for match in _PARAM_RE.finditer(params_str):
        key = match.group(1)
        escaped, quoted, bare = match.group(2), match.group(3), match.group(4)
        if escaped is not None:
            value = escaped
        elif quoted is not None:
            value = quoted
        else:
            # A bare value may carry an unterminated escape token when the
            # generation was truncated; drop the marker itself.
            value = (bare or "").replace(_ESCAPE, "")
        value = value.strip()

        # Percentages are stored as fractions (e.g. "0.5%" -> 0.005).
        if value.endswith('%'):
            try:
                args[key] = float(value[:-1]) / 100
                continue
            except ValueError:
                pass

        try:
            args[key] = float(value) if '.' in value else int(value)
        except ValueError:
            args[key] = value

    return args


def is_rejection_response(response: str) -> bool:
    """Return True if the response is a rejection / clarification request.

    A response with no ``call:`` marker is always treated as a rejection;
    otherwise it is one only if it contains a rejection keyword
    (case-insensitive).
    """
    if 'call:' not in response:
        return True

    response_lower = response.lower()
    return any(keyword.lower() in response_lower for keyword in _REJECTION_KEYWORDS)


def format_messages_for_model(
    messages: List[Dict],
    tokenizer,
    tools: Optional[List[Dict]] = None,
) -> str:
    """Render chat messages into a prompt string ending with a model turn.

    Uses the tokenizer's chat template (passing ``tools``) when available.
    If that is missing or raises, falls back to manual Gemma-style turn
    formatting (without tool schemas) and logs a one-time warning.
    """
    global _template_fallback_warned

    if hasattr(tokenizer, 'apply_chat_template'):
        try:
            return tokenizer.apply_chat_template(
                messages,
                tools=tools,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception as exc:
            # Best-effort fallback, but never silently: a different prompt
            # format changes evaluation results.
            if not _template_fallback_warned:
                _template_fallback_warned = True
                logger.warning(
                    f"apply_chat_template failed ({exc!r}); "
                    f"falling back to manual prompt formatting"
                )

    # Manual formatting fallback; the "assistant" role is rendered as "model".
    role_tags = {"system": "system", "user": "user", "assistant": "model"}
    parts = []
    for msg in messages:
        tag = role_tags.get(msg["role"])
        if tag is None:
            continue
        parts.append(f"{_START_OF_TURN}{tag}\n{msg['content']}{_END_OF_TURN}\n")
    parts.append(f"{_START_OF_TURN}model\n")
    return "".join(parts)


def generate_response(
    model,
    tokenizer,
    prompt: str,
    system_prompt: str,
    max_new_tokens: int = 256,
) -> str:
    """Generate the model's reply to one user prompt.

    Returns only the newly generated text (special tokens kept, so call
    markers survive), cut at the first eos / end-of-turn marker and stripped.
    NOTE: sampling is enabled (temperature 0.1), so results are not fully
    deterministic across runs.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]

    input_text = format_messages_for_model(messages, tokenizer, TOOLS)
    inputs = tokenizer(input_text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=False)
    # Drop everything from the first end marker on (the original code meant
    # to strip a marker but its replace() call was a no-op).
    for marker in (tokenizer.eos_token, _END_OF_TURN):
        if marker:
            response = response.split(marker, 1)[0]
    return response.strip()


def compare_arguments(
    expected: Dict,
    actual: Dict,
    *,
    rel_tol: float = 1e-6,
    abs_tol: float = 1e-9,
) -> Tuple[float, List[str]]:
    """Score extracted arguments against the expected ones.

    Each expected key with a non-None value contributes equally:
    - 1 for a string-equal or numerically close match;
    - 0.5 when two strings share their first 10 characters (e.g. a
      contract-address prefix);
    - 0 otherwise.

    Keys only present in ``actual`` are reported as errors but do not lower
    the score.

    Args:
        expected: Reference arguments. If empty, the score is 1.0 only when
            ``actual`` is empty too.
        actual: Arguments parsed from the model output.
        rel_tol, abs_tol: Numeric tolerances for ``math.isclose``. Kept tight
            because percentages are parsed as fractions, so 0.5% (0.005) and
            1% (0.01) must not compare equal.

    Returns:
        ``(score, errors)`` with ``score`` in [0, 1]. If every expected value
        is None there is nothing to score and the score is 1.0.
    """
    if not expected:
        return (1.0 if not actual else 0.0), []

    if not actual:
        return 0.0, ["No arguments extracted"]

    errors = []
    scored_keys = [key for key, value in expected.items() if value is not None]
    matched = 0.0

    for key in scored_keys:
        exp_val = expected[key]
        act_val = actual.get(key)

        if act_val is None:
            errors.append(f"Missing key: {key}")
            continue

        if str(exp_val) == str(act_val):
            matched += 1
        elif isinstance(exp_val, str) and isinstance(act_val, str):
            # Partial credit for a shared prefix (contract address prefix)
            if exp_val[:10] == act_val[:10]:
                matched += 0.5
                errors.append(f"Partial match for {key}")
            else:
                errors.append(f"Value mismatch for {key}: expected {exp_val}, got {act_val}")
        elif isinstance(exp_val, (int, float)) and isinstance(act_val, (int, float)):
            if math.isclose(float(exp_val), float(act_val), rel_tol=rel_tol, abs_tol=abs_tol):
                matched += 1
            else:
                errors.append(f"Value mismatch for {key}: expected {exp_val}, got {act_val}")
        else:
            errors.append(f"Type mismatch for {key}")

    for key in actual:
        if key not in expected:
            errors.append(f"Extra key: {key}")

    # Guard against expected dicts whose values are all None (previously a
    # ZeroDivisionError).
    score = matched / len(scored_keys) if scored_keys else 1.0
    return score, errors


def process_single_sample(
    sample: Dict,
    idx: int,
    model,
    tokenizer,
    system_prompt: str,
) -> Dict:
    """Run the model on one benchmark sample and score the result.

    ``sample`` is read via ``sample["input"]`` (a string, or a dict with a
    ``"messages"`` list whose first user message is used),
    ``sample["expected"]["function_name"]`` (None means a rejection is
    expected), and optional ``"arguments"``, ``"id"`` and ``"category"``.

    Returns a flat dict of per-sample outcomes consumed by ``_update_results``.
    """
    sample_id = sample.get("id", idx + 1)
    category = sample.get("category", "unknown")
    user_input = sample["input"]
    expected_func = sample["expected"]["function_name"]
    expected_args = sample["expected"].get("arguments", {})

    # Extract user message
    if isinstance(user_input, dict) and "messages" in user_input:
        prompt = ""
        for msg in user_input["messages"]:
            if msg.get("role") == "user":
                prompt = msg.get("content", "")
                break
    else:
        prompt = str(user_input)

    response = generate_response(model, tokenizer, prompt, system_prompt)

    actual_func, actual_args = parse_functiongemma_output(response)
    is_rejection = is_rejection_response(response)

    func_correct = False
    args_correct = False
    exact_match = False
    arg_score = 0.0
    error_msg = None
    rejection_correct = False

    if expected_func is None:
        # Expecting rejection: any rejection-like response, or no parsable call.
        func_correct = is_rejection or actual_func is None
        args_correct = func_correct
        exact_match = func_correct
        arg_score = 1.0 if func_correct else 0.0
        rejection_correct = func_correct
        if not func_correct:
            error_msg = f"Expected rejection, got {actual_func}"
    else:
        func_correct = actual_func == expected_func
        if func_correct:
            arg_score, arg_errors = compare_arguments(expected_args, actual_args or {})
            # Tolerate float rounding in the averaged score.
            args_correct = arg_score >= 0.99
            exact_match = args_correct
            if not args_correct:
                error_msg = "; ".join(arg_errors)
        else:
            error_msg = f"Expected {expected_func}, got {actual_func}"

    return {
        "sample_id": sample_id,
        "category": category,
        "expected_func": expected_func,
        "actual_func": actual_func,
        "func_correct": func_correct,
        "args_correct": args_correct,
        "exact_match": exact_match,
        "rejection_correct": rejection_correct,
        "arg_score": arg_score,
        "error_msg": error_msg,
        "user_input": user_input,
        "expected_args": expected_args,
        "actual_args": actual_args,
        "response": response,
    }


def evaluate_benchmark(
    model,
    tokenizer,
    benchmark: List[Dict],
    chain: str = "solana",
    verbose: bool = False,
    num_workers: int = 1,
) -> Dict:
    """Evaluate every benchmark sample and return aggregated metrics.

    With ``num_workers > 1`` samples are submitted to a thread pool sharing
    the same model, and results are aggregated in completion order, so the
    capped ``errors`` list may differ between runs.
    """
    system_prompt = get_system_prompt_short(chain)

    results = {
        "total": len(benchmark),
        "function_correct": 0,
        "arguments_correct": 0,
        "exact_match": 0,
        "rejection_correct": 0,
        "total_arg_score": 0.0,
        "by_category": {},
        "by_function": {},
        "errors": [],
    }

    # Protect result updates with a lock
    results_lock = Lock()

    if num_workers > 1:
        logger.info(f"Evaluating with {num_workers} worker threads")

        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = {
                executor.submit(
                    process_single_sample,
                    sample, i, model, tokenizer, system_prompt
                ): i
                for i, sample in enumerate(benchmark)
            }

            with tqdm(total=len(benchmark), desc="Evaluation") as pbar:
                for future in as_completed(futures):
                    result = future.result()
                    with results_lock:
                        _update_results(results, result, verbose)
                    pbar.update(1)
    else:
        logger.info("Evaluating with a single thread")
        for i, sample in enumerate(tqdm(benchmark, desc="Evaluation")):
            result = process_single_sample(sample, i, model, tokenizer, system_prompt)
            _update_results(results, result, verbose)

    return results


def _update_results(results: Dict, result: Dict, verbose: bool):
    """Fold one per-sample result into the aggregate ``results`` dict in place.

    Updates the overall counters, the per-category and per-expected-function
    breakdowns (rejection samples are keyed as ``"None"``), and records at
    most 10 error samples.
    """
    sample_id = result["sample_id"]
    category = result["category"]
    expected_func = result["expected_func"]
    actual_func = result["actual_func"]
    func_correct = result["func_correct"]
    args_correct = result["args_correct"]
    exact_match = result["exact_match"]
    rejection_correct = result["rejection_correct"]
    arg_score = result["arg_score"]
    error_msg = result["error_msg"]

    # Overall stats
    if func_correct:
        results["function_correct"] += 1
    if args_correct:
        results["arguments_correct"] += 1
    if exact_match:
        results["exact_match"] += 1
    if rejection_correct:
        results["rejection_correct"] += 1
    results["total_arg_score"] += arg_score

    # By category
    if category not in results["by_category"]:
        results["by_category"][category] = {
            "total": 0, "func_correct": 0, "exact_match": 0, "arg_score": 0.0
        }
    results["by_category"][category]["total"] += 1
    if func_correct:
        results["by_category"][category]["func_correct"] += 1
    if exact_match:
        results["by_category"][category]["exact_match"] += 1
    results["by_category"][category]["arg_score"] += arg_score

    # By function
    func_key = expected_func or "None"
    if func_key not in results["by_function"]:
        results["by_function"][func_key] = {
            "total": 0, "func_correct": 0, "exact_match": 0, "arg_score": 0.0
        }
    results["by_function"][func_key]["total"] += 1
    if func_correct:
        results["by_function"][func_key]["func_correct"] += 1
    if exact_match:
        results["by_function"][func_key]["exact_match"] += 1
    results["by_function"][func_key]["arg_score"] += arg_score

    # Record errors
    if error_msg and len(results["errors"]) < 10:
        results["errors"].append({
            "id": sample_id,
            "category": category,
            "input": result["user_input"],
            "expected_func": expected_func,
            "actual_func": actual_func,
            "expected_args": result["expected_args"],
            "actual_args": result["actual_args"],
            "error": error_msg,
            "response": result["response"][:200],
        })

    if verbose:
        status = "✓" if exact_match else "✗"
        # Extract user message preview for logs
        user_input = result["user_input"]
        if isinstance(user_input, dict):
            user_msg = ""
            if "messages" in user_input:
                for msg in user_input["messages"]:
                    if msg.get("role") == "user":
                        user_msg = msg.get("content", "")
                        break
            input_preview = user_msg[:50] if user_msg else str(user_input)[:50]
        else:
            input_preview = str(user_input)[:50]
        logger.info(f"[{sample_id}] {status} {category}: {input_preview}...")


def print_report(results: Dict):
    """Print overall metrics, a per-function breakdown and up to 5 error samples."""
    total = results["total"]

    print("\n" + "=" * 70)
    print("FunctionGemma Evaluation Report")
    print("=" * 70)

    print(f"\nTotal samples: {total}")

    print("\n" + "-" * 70)
    print("Overall metrics")
    print("-" * 70)

    func_acc = results["function_correct"] / total * 100 if total > 0 else 0
    arg_acc = results["arguments_correct"] / total * 100 if total > 0 else 0
    exact_acc = results["exact_match"] / total * 100 if total > 0 else 0
    avg_arg_score = results["total_arg_score"] / total * 100 if total > 0 else 0

    # Rejection accuracy is measured only over samples expecting no call.
    rejection_total = results["by_function"].get("None", {}).get("total", 0)
    rejection_acc = results["rejection_correct"] / rejection_total * 100 if rejection_total > 0 else 0

    print(f"Function selection accuracy: {func_acc:.2f}%")
    print(f"Argument accuracy: {arg_acc:.2f}%")
    print(f"Exact match accuracy: {exact_acc:.2f}%")
    print(f"Average argument score: {avg_arg_score:.2f}%")
    print(f"Rejection accuracy: {rejection_acc:.2f}%")

    print("\n" + "-" * 70)
    print("By function")
    print("-" * 70)

    for func, stats in sorted(results["by_function"].items()):
        func_total = stats["total"]
        func_correct = stats["func_correct"] / func_total * 100 if func_total > 0 else 0
        func_arg_score = stats["arg_score"] / func_total * 100 if func_total > 0 else 0
        func_exact = stats["exact_match"] / func_total * 100 if func_total > 0 else 0
        print(f"{func:15} | samples: {func_total:3} | func acc: {func_correct:6.2f}% | "
              f"arg score: {func_arg_score:6.2f}% | exact: {func_exact:6.2f}%")

    if results["errors"]:
        print("\n" + "-" * 70)
        print("Error samples")
        print("-" * 70)
        for err in results["errors"][:5]:
            print(f"\nID: {err['id']} | category: {err['category']}")
            print(f"Input: {err['input']}")
            print(f"Expected: {err['expected_func']} | Actual: {err['actual_func']}")
            print(f"Error: {err['error']}")

    print("\n" + "=" * 70)


def main():
    """CLI entry point: load model and benchmark, evaluate, print and save results."""
    parser = argparse.ArgumentParser(description="FunctionGemma evaluation (v2)")
    parser.add_argument("--model_path", type=str, required=True, help="Model path")
    parser.add_argument("--lora_path", type=str, default=None, help="LoRA adapter path")
    parser.add_argument("--benchmark_path", type=str, default=str(DEFAULT_BENCHMARK_PATH),
                        help="Benchmark dataset path")
    parser.add_argument("--output_path", type=str, default=None,
                        help="Output path (defaults to results/ with timestamp)")
    parser.add_argument("--chain", type=str, default="solana", help="Chain name")
    parser.add_argument("--load_in_8bit", action="store_true", help="Enable 8-bit quantization")
    parser.add_argument("--load_in_4bit", action="store_true", help="Enable 4-bit quantization")
    parser.add_argument("--verbose", action="store_true", help="Verbose logging")
    parser.add_argument("--num_workers", type=int, default=4,
                        help="Number of worker threads (default 4)")

    args = parser.parse_args()

    model, tokenizer = load_model(
        args.model_path,
        lora_path=args.lora_path,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit,
    )

    benchmark_path = Path(args.benchmark_path)
    logger.info(f"Loading benchmark: {benchmark_path}")
    with open(benchmark_path, 'r', encoding='utf-8') as f:
        benchmark = json.load(f)
    logger.info(f"Benchmark samples: {len(benchmark)}")

    logger.info("Starting evaluation...")
    results = evaluate_benchmark(
        model, tokenizer, benchmark,
        chain=args.chain,
        verbose=args.verbose,
        num_workers=args.num_workers,
    )

    print_report(results)

    if args.output_path:
        output_path = Path(args.output_path)
    else:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_path = DEFAULT_RESULTS_DIR / f"evaluation_{timestamp}.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    logger.info(f"Evaluation saved to: {output_path}")


if __name__ == "__main__":
    main()