""" |
|
|
FunctionGemma evaluation script (v2). |
|
|
|
|
|
Uses a unified system prompt for evaluation. |
|
|
|
|
|
Usage: |
|
|
python -m src.evaluate --model_path ./runs/<run>/final_model --benchmark_path ./data/benchmark_dataset.json |
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
import sys |
|
|
import json |
|
|
import argparse |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
from datetime import datetime |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
from threading import Lock |
|
|
|
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from peft import PeftModel |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
PROJECT_ROOT = Path(__file__).resolve().parent.parent |
|
|
if str(PROJECT_ROOT) not in sys.path: |
|
|
sys.path.insert(0, str(PROJECT_ROOT)) |
|
|
|
|
|
DEFAULT_BENCHMARK_PATH = PROJECT_ROOT / "data" / "benchmark_dataset.json" |
|
|
DEFAULT_RESULTS_DIR = PROJECT_ROOT / "results" |
|
|
|
|
|
from src.config import ( |
|
|
get_system_prompt, get_system_prompt_short, TOOLS, |
|
|
SOLANA_TOKENS, get_token_address |
|
|
) |
|
|
|
|
|
|
|
|
logging.basicConfig( |
|
|
level=logging.INFO, |
|
|
format='%(asctime)s - %(levelname)s - %(message)s' |
|
|
) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|


def load_model(
    model_path: str,
    lora_path: Optional[str] = None,
    device: str = "auto",
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
):
    """Load model and tokenizer."""
    logger.info(f"Loading model: {model_path}")

    kwargs = {
        "device_map": device,
        "trust_remote_code": True,
    }

    if load_in_8bit:
        kwargs["load_in_8bit"] = True
    elif load_in_4bit:
        from transformers import BitsAndBytesConfig
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        kwargs["torch_dtype"] = torch.bfloat16

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)

    if lora_path:
        logger.info(f"Loading LoRA adapter: {lora_path}")
        model = PeftModel.from_pretrained(model, lora_path)

    model.eval()
    return model, tokenizer
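

# Illustrative usage of load_model (paths are placeholders, not part of this repo's layout):
#   model, tokenizer = load_model("./runs/<run>/final_model", load_in_4bit=True)
#   model, tokenizer = load_model("./runs/<run>/base_model", lora_path="./runs/<run>/lora")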


def parse_functiongemma_output(response: str) -> Tuple[Optional[str], Optional[Dict]]:
    """
    Parse FunctionGemma formatted output.

    Format: <start_function_call>call:FUNC_NAME{key:<escape>value<escape>,...}<end_function_call>
    """
    # Full form: function name, argument block, and closing tag.
    pattern = r'<start_function_call>call:(\w+)\{([^}]*)\}<end_function_call>'
    match = re.search(pattern, response)

    if not match:
        # Fall back to a call whose closing tag was truncated.
        pattern = r'<start_function_call>call:(\w+)\{([^}]*)\}'
        match = re.search(pattern, response)

    if not match:
        # Fall back to a bare function name without an argument block.
        pattern = r'<start_function_call>call:(\w+)'
        match = re.search(pattern, response)
        if match:
            return match.group(1), {}

        # Last resort: look for a known function name anywhere in the response.
        for func in ["SEARCH_TOKEN", "EXECUTE_SWAP"]:
            if func in response:
                return func, {}

        return None, None

    func_name = match.group(1)
    params_str = match.group(2) if len(match.groups()) > 1 else ""

    args = parse_params_string(params_str)

    return func_name, args


def parse_params_string(params_str: str) -> Dict:
    """Parse parameter string."""
    args = {}
    if not params_str:
        return args

    # Values are either wrapped in <escape>...<escape> or bare (terminated by ',' or '}').
    param_pattern = r'(\w+):(?:<escape>([^<]*)<escape>|([^,}]+))'

    for match in re.finditer(param_pattern, params_str):
        key = match.group(1)
        value = match.group(2) if match.group(2) is not None else match.group(3)

        if value is None:
            continue

        value = value.strip()

        # Percentages are normalized to fractions (e.g. "5%" -> 0.05).
        if value.endswith('%'):
            try:
                args[key] = float(value[:-1]) / 100
                continue
            except ValueError:
                pass

        # Prefer numeric types; fall back to the raw string.
        try:
            if '.' in value:
                args[key] = float(value)
            else:
                args[key] = int(value)
        except ValueError:
            args[key] = value

    return args
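

# Illustrative parse (argument names and values here are made up for the example):
# a response containing
#   <start_function_call>call:EXECUTE_SWAP{from_token:<escape>SOL<escape>,amount:1.5}<end_function_call>
# would be parsed by parse_functiongemma_output into
#   ("EXECUTE_SWAP", {"from_token": "SOL", "amount": 1.5})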


def is_rejection_response(response: str) -> bool:
    """Check if the response is a rejection/clarification."""
    if '<start_function_call>' not in response:
        return True

    rejection_keywords = [
        "please specify", "could you", "what token", "which token",
        "请问", "请提供", "请告诉", "您能", "什么代币", "哪个代币",
        "sorry", "can't", "cannot", "unable", "抱歉", "无法",
        "more information", "more details", "更多信息",
    ]

    response_lower = response.lower()
    for keyword in rejection_keywords:
        if keyword.lower() in response_lower:
            return True

    return False


def format_messages_for_model(
    messages: List[Dict],
    tokenizer,
    tools: List[Dict] = None,
) -> str:
    """Format messages into the model chat template."""
    if hasattr(tokenizer, 'apply_chat_template'):
        try:
            return tokenizer.apply_chat_template(
                messages,
                tools=tools,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            pass

    # Fallback: build the Gemma-style turn format manually.
    formatted = ""
    for msg in messages:
        role = msg["role"]
        content = msg["content"]

        if role == "system":
            formatted += f"<start_of_turn>system\n{content}<end_of_turn>\n"
        elif role == "user":
            formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
        elif role == "assistant":
            formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"

    formatted += "<start_of_turn>model\n"
    return formatted


def generate_response(
    model,
    tokenizer,
    prompt: str,
    system_prompt: str,
    max_new_tokens: int = 256,
) -> str:
    """Generate model response."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]

    input_text = format_messages_for_model(messages, tokenizer, TOOLS)
    inputs = tokenizer(input_text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens and drop the turn terminator.
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
    response = response.replace("<end_of_turn>", "").strip()

    return response


def compare_arguments(expected: Dict, actual: Dict) -> Tuple[float, List[str]]:
    """Compare expected vs actual arguments."""
    if not expected:
        return 1.0 if not actual else 0.0, []

    if not actual:
        return 0.0, ["No arguments extracted"]

    errors = []
    total_keys = set(expected.keys()) | set(actual.keys())

    if not total_keys:
        return 1.0, []

    matched = 0

    for key in expected.keys():
        exp_val = expected.get(key)
        act_val = actual.get(key)

        # Keys whose expected value is None are not scored.
        if exp_val is None:
            continue

        if act_val is None:
            errors.append(f"Missing key: {key}")
            continue

        if str(exp_val) == str(act_val):
            matched += 1
        elif isinstance(exp_val, str) and isinstance(act_val, str):
            # Matching 10-character prefixes count as a partial match.
            if exp_val[:10] == act_val[:10]:
                matched += 0.5
                errors.append(f"Partial match for {key}")
            else:
                errors.append(f"Value mismatch for {key}: expected {exp_val}, got {act_val}")
        elif isinstance(exp_val, (int, float)) and isinstance(act_val, (int, float)):
            if abs(float(exp_val) - float(act_val)) < 0.01:
                matched += 1
            else:
                errors.append(f"Value mismatch for {key}: expected {exp_val}, got {act_val}")
        else:
            errors.append(f"Type mismatch for {key}")

    # Report keys the model produced that were never expected.
    for key in actual.keys():
        if key not in expected:
            errors.append(f"Extra key: {key}")

    # Score over the expected keys that carry a value; guard against an empty denominator.
    scored_keys = [k for k in expected.keys() if expected.get(k) is not None]
    score = matched / len(scored_keys) if scored_keys else 1.0
    return score, errors
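

# Illustrative scoring (made-up values): with expected {"token": "SOL", "amount": 1.0}
# and actual {"token": "SOL", "amount": 2.0}, the token matches exactly while the amount
# differs by more than 0.01, so compare_arguments returns a score of 0.5 together with a
# single "Value mismatch" error.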


def process_single_sample(
    sample: Dict,
    idx: int,
    model,
    tokenizer,
    system_prompt: str,
) -> Dict:
    """Process one sample and return evaluation result."""
    sample_id = sample.get("id", idx + 1)
    category = sample.get("category", "unknown")
    user_input = sample["input"]
    expected_func = sample["expected"]["function_name"]
    expected_args = sample["expected"].get("arguments", {})

    # The input is either a chat-style dict with "messages" or a plain string prompt.
    if isinstance(user_input, dict) and "messages" in user_input:
        prompt = ""
        for msg in user_input["messages"]:
            if msg.get("role") == "user":
                prompt = msg.get("content", "")
                break
    else:
        prompt = str(user_input)

    response = generate_response(model, tokenizer, prompt, system_prompt)

    actual_func, actual_args = parse_functiongemma_output(response)
    is_rejection = is_rejection_response(response)

    func_correct = False
    args_correct = False
    exact_match = False
    arg_score = 0.0
    error_msg = None
    rejection_correct = False

    if expected_func is None:
        # The sample expects the model to reject or ask for clarification.
        func_correct = is_rejection or actual_func is None
        args_correct = func_correct
        exact_match = func_correct
        arg_score = 1.0 if func_correct else 0.0
        rejection_correct = func_correct

        if not func_correct:
            error_msg = f"Expected rejection, got {actual_func}"
    else:
        func_correct = actual_func == expected_func

        if func_correct:
            arg_score, arg_errors = compare_arguments(expected_args, actual_args or {})
            args_correct = arg_score >= 0.99
            exact_match = args_correct

            if not args_correct:
                error_msg = "; ".join(arg_errors)
        else:
            error_msg = f"Expected {expected_func}, got {actual_func}"

    result = {
        "sample_id": sample_id,
        "category": category,
        "expected_func": expected_func,
        "actual_func": actual_func,
        "func_correct": func_correct,
        "args_correct": args_correct,
        "exact_match": exact_match,
        "rejection_correct": rejection_correct,
        "arg_score": arg_score,
        "error_msg": error_msg,
        "user_input": user_input,
        "expected_args": expected_args,
        "actual_args": actual_args,
        "response": response,
    }

    return result


def evaluate_benchmark(
    model,
    tokenizer,
    benchmark: List[Dict],
    chain: str = "solana",
    verbose: bool = False,
    num_workers: int = 1,
) -> Dict:
    """Evaluate the benchmark (supports concurrency)."""
    system_prompt = get_system_prompt_short(chain)

    results = {
        "total": len(benchmark),
        "function_correct": 0,
        "arguments_correct": 0,
        "exact_match": 0,
        "rejection_correct": 0,
        "total_arg_score": 0.0,
        "by_category": {},
        "by_function": {},
        "errors": [],
    }

    results_lock = Lock()

    if num_workers > 1:
        logger.info(f"Evaluating with {num_workers} worker threads")

        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = {
                executor.submit(
                    process_single_sample,
                    sample, i, model, tokenizer, system_prompt
                ): i for i, sample in enumerate(benchmark)
            }

            with tqdm(total=len(benchmark), desc="Evaluation") as pbar:
                for future in as_completed(futures):
                    result = future.result()

                    with results_lock:
                        _update_results(results, result, verbose)

                    pbar.update(1)
    else:
        logger.info("Evaluating with a single thread")
        for i, sample in enumerate(tqdm(benchmark, desc="Evaluation")):
            result = process_single_sample(sample, i, model, tokenizer, system_prompt)
            _update_results(results, result, verbose)

    return results


def _update_results(results: Dict, result: Dict, verbose: bool):
    """Update aggregated evaluation results."""
    sample_id = result["sample_id"]
    category = result["category"]
    expected_func = result["expected_func"]
    actual_func = result["actual_func"]
    func_correct = result["func_correct"]
    args_correct = result["args_correct"]
    exact_match = result["exact_match"]
    rejection_correct = result["rejection_correct"]
    arg_score = result["arg_score"]
    error_msg = result["error_msg"]

    # Overall counters.
    if func_correct:
        results["function_correct"] += 1
    if args_correct:
        results["arguments_correct"] += 1
    if exact_match:
        results["exact_match"] += 1
    if rejection_correct:
        results["rejection_correct"] += 1
    results["total_arg_score"] += arg_score

    # Per-category counters.
    if category not in results["by_category"]:
        results["by_category"][category] = {
            "total": 0, "func_correct": 0, "exact_match": 0, "arg_score": 0.0
        }
    results["by_category"][category]["total"] += 1
    if func_correct:
        results["by_category"][category]["func_correct"] += 1
    if exact_match:
        results["by_category"][category]["exact_match"] += 1
    results["by_category"][category]["arg_score"] += arg_score

    # Per-function counters (rejection samples are grouped under "None").
    func_key = expected_func or "None"
    if func_key not in results["by_function"]:
        results["by_function"][func_key] = {
            "total": 0, "func_correct": 0, "exact_match": 0, "arg_score": 0.0
        }
    results["by_function"][func_key]["total"] += 1
    if func_correct:
        results["by_function"][func_key]["func_correct"] += 1
    if exact_match:
        results["by_function"][func_key]["exact_match"] += 1
    results["by_function"][func_key]["arg_score"] += arg_score

    # Keep the first few failures for the report.
    if error_msg and len(results["errors"]) < 10:
        results["errors"].append({
            "id": sample_id,
            "category": category,
            "input": result["user_input"],
            "expected_func": expected_func,
            "actual_func": actual_func,
            "expected_args": result["expected_args"],
            "actual_args": result["actual_args"],
            "error": error_msg,
            "response": result["response"][:200],
        })

    if verbose:
        status = "✓" if exact_match else "✗"

        user_input = result["user_input"]
        if isinstance(user_input, dict):
            user_msg = ""
            if "messages" in user_input:
                for msg in user_input["messages"]:
                    if msg.get("role") == "user":
                        user_msg = msg.get("content", "")
                        break
            input_preview = user_msg[:50] if user_msg else str(user_input)[:50]
        else:
            input_preview = str(user_input)[:50]
        logger.info(f"[{sample_id}] {status} {category}: {input_preview}...")


def print_report(results: Dict):
    """Print evaluation report."""
    total = results["total"]

    print("\n" + "=" * 70)
    print("FunctionGemma Evaluation Report")
    print("=" * 70)
    print(f"\nTotal samples: {total}")

    print("\n" + "-" * 70)
    print("Overall metrics")
    print("-" * 70)

    func_acc = results["function_correct"] / total * 100 if total > 0 else 0
    arg_acc = results["arguments_correct"] / total * 100 if total > 0 else 0
    exact_acc = results["exact_match"] / total * 100 if total > 0 else 0
    avg_arg_score = results["total_arg_score"] / total * 100 if total > 0 else 0

    # Rejection samples are aggregated under the "None" function key.
    rejection_total = results["by_function"].get("None", {}).get("total", 0)
    rejection_acc = results["rejection_correct"] / rejection_total * 100 if rejection_total > 0 else 0

    print(f"Function selection accuracy: {func_acc:.2f}%")
    print(f"Argument accuracy: {arg_acc:.2f}%")
    print(f"Exact match accuracy: {exact_acc:.2f}%")
    print(f"Average argument score: {avg_arg_score:.2f}%")
    print(f"Rejection accuracy: {rejection_acc:.2f}%")

    print("\n" + "-" * 70)
    print("By function")
    print("-" * 70)

    for func, stats in sorted(results["by_function"].items()):
        func_total = stats["total"]
        func_correct = stats["func_correct"] / func_total * 100 if func_total > 0 else 0
        func_arg_score = stats["arg_score"] / func_total * 100 if func_total > 0 else 0
        func_exact = stats["exact_match"] / func_total * 100 if func_total > 0 else 0

        print(f"{func:15} | samples: {func_total:3} | func acc: {func_correct:6.2f}% | "
              f"arg score: {func_arg_score:6.2f}% | exact: {func_exact:6.2f}%")

    if results["errors"]:
        print("\n" + "-" * 70)
        print("Error samples")
        print("-" * 70)

        for err in results["errors"][:5]:
            print(f"\nID: {err['id']} | category: {err['category']}")
            print(f"Input: {err['input']}")
            print(f"Expected: {err['expected_func']} | Actual: {err['actual_func']}")
            print(f"Error: {err['error']}")

    print("\n" + "=" * 70)


def main():
    parser = argparse.ArgumentParser(description="FunctionGemma evaluation (v2)")
    parser.add_argument("--model_path", type=str, required=True, help="Model path")
    parser.add_argument("--lora_path", type=str, default=None, help="LoRA adapter path")
    parser.add_argument("--benchmark_path", type=str, default=str(DEFAULT_BENCHMARK_PATH), help="Benchmark dataset path")
    parser.add_argument("--output_path", type=str, default=None, help="Output path (defaults to results/ with timestamp)")
    parser.add_argument("--chain", type=str, default="solana", help="Chain name")
    parser.add_argument("--load_in_8bit", action="store_true", help="Enable 8-bit quantization")
    parser.add_argument("--load_in_4bit", action="store_true", help="Enable 4-bit quantization")
    parser.add_argument("--verbose", action="store_true", help="Verbose logging")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of worker threads (default 4)")
    args = parser.parse_args()

    model, tokenizer = load_model(
        args.model_path,
        lora_path=args.lora_path,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit,
    )

    benchmark_path = Path(args.benchmark_path)
    logger.info(f"Loading benchmark: {benchmark_path}")
    with open(benchmark_path, 'r', encoding='utf-8') as f:
        benchmark = json.load(f)

    logger.info(f"Benchmark samples: {len(benchmark)}")

    logger.info("Starting evaluation...")
    results = evaluate_benchmark(
        model, tokenizer, benchmark,
        chain=args.chain,
        verbose=args.verbose,
        num_workers=args.num_workers,
    )

    print_report(results)

    output_path = Path(args.output_path) if args.output_path else DEFAULT_RESULTS_DIR / f"evaluation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    logger.info(f"Evaluation saved to: {output_path}")


if __name__ == "__main__":
    main()