import json
import argparse
import sys
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Add project root to path for imports
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))

from agent_prompt import SYSTEM_PROMPT
from commitguard_env.parse_action import parse_action


def format_eval_prompt(sample):
    # Matches format_prompt in train_grpo.py so evaluation uses the same chat template
    return (
        f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        f"{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        f"Analyze this commit and submit your verdict.\n\n"
        f"Code diff:\n```diff\n{sample['diff']}\n```<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )


def evaluate(model_path, test_file, is_lora=False, base_model=None, output_file="eval_results.json"):
    """
    Run the model on test samples and compute accuracy metrics.
    """
    print(f"Loading model from {model_path}...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model: either a LoRA adapter on a 4-bit base (via unsloth) or a full HF checkpoint
    if is_lora:
        if not base_model:
            raise ValueError("base_model is required if is_lora=True")
        print(f"Loading LoRA adapter from {model_path} with base model {base_model}")
        from unsloth import FastLanguageModel
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=base_model,
            max_seq_length=2048,
            load_in_4bit=True,
        )
        model = PeftModel.from_pretrained(model, model_path)
        FastLanguageModel.for_inference(model)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(model_path)
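
    # Small safeguard (not strictly required for single-sequence prompts): Llama-3
    # instruct tokenizers ship without a pad token, so fall back to EOS to keep
    # generate() from complaining.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token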

    # Load test data
    print(f"Loading test data from {test_file}...")
    with open(test_file, "r", encoding="utf-8") as f:
        samples = [json.loads(line) for line in f if line.strip()]

    results = {
        "summary": {
            "total": len(samples),
            "correct_binary": 0,
            "correct_cwe": 0,
            "false_positives": 0,
            "false_negatives": 0,
            "binary_accuracy": 0,
            "cwe_accuracy": 0,
            "false_positive_rate": 0,
            "false_negative_rate": 0,
            "cwe_breakdown": {},
        },
        "predictions": [],
    }
| print(f"Starting evaluation on {len(samples)} samples...") | |
| for i, sample in enumerate(samples): | |
| prompt = format_eval_prompt(sample) | |
| inputs = tokenizer(prompt, return_tensors="pt").to(device) | |
| with torch.no_grad(): | |
| output = model.generate( | |
| **inputs, | |
| max_new_tokens=256, | |
| temperature=0.1, | |
| do_sample=False, | |
| ) | |
| response = tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) | |
| # Use the official parser from commitguard_env | |
| prediction = parse_action(response) | |
| gt_vulnerable = bool(sample["is_vulnerable"]) | |
| pred_vulnerable = bool(prediction.is_vulnerable) if prediction.is_vulnerable is not None else False | |
| correct = pred_vulnerable == gt_vulnerable | |
| if correct: | |
| results["summary"]["correct_binary"] += 1 | |
| if gt_vulnerable and not pred_vulnerable: | |
| results["summary"]["false_negatives"] += 1 | |
| elif not gt_vulnerable and pred_vulnerable: | |
| results["summary"]["false_positives"] += 1 | |
| cwe = sample.get("cwe") or "CWE-OTHER" | |
| if cwe not in results["summary"]["cwe_breakdown"]: | |
| results["summary"]["cwe_breakdown"][cwe] = {"total": 0, "correct": 0, "accuracy": 0} | |
| results["summary"]["cwe_breakdown"][cwe]["total"] += 1 | |
| if correct: | |
| results["summary"]["cwe_breakdown"][cwe]["correct"] += 1 | |
| if gt_vulnerable and correct and prediction.vuln_type and prediction.vuln_type.strip().upper() == cwe.strip().upper(): | |
| results["summary"]["correct_cwe"] += 1 | |
| results["predictions"].append({ | |
| "sample_id": sample["sample_id"], | |
| "ground_truth": gt_vulnerable, | |
| "predicted": pred_vulnerable, | |
| "predicted_cwe": prediction.vuln_type, | |
| "actual_cwe": cwe, | |
| "response": response, | |
| }) | |
| if (i + 1) % 10 == 0: | |
| print(f" Processed {i+1}/{len(samples)} samples...") | |

    # Final summary stats
    summary = results["summary"]
    total = summary["total"]
    vuln_count = sum(1 for s in samples if s["is_vulnerable"])
    safe_count = total - vuln_count

    # Confusion matrix: positives are vulnerable commits,
    # so TP = vulnerable samples the model correctly flagged
    tp = vuln_count - summary["false_negatives"]
    fp = summary["false_positives"]
    fn = summary["false_negatives"]
    tn = safe_count - fp
    summary["tp"] = tp
    summary["fp"] = fp
    summary["tn"] = tn
    summary["fn"] = fn

    summary["precision"] = tp / (tp + fp) if (tp + fp) > 0 else 0
    summary["recall"] = tp / (tp + fn) if (tp + fn) > 0 else 0
    p, r = summary["precision"], summary["recall"]
    summary["f1"] = 2 * p * r / (p + r) if (p + r) > 0 else 0
    summary["binary_accuracy"] = summary["correct_binary"] / total if total > 0 else 0
    summary["cwe_accuracy"] = summary["correct_cwe"] / vuln_count if vuln_count > 0 else 0
    summary["false_positive_rate"] = summary["false_positives"] / safe_count if safe_count > 0 else 0
    summary["false_negative_rate"] = summary["false_negatives"] / vuln_count if vuln_count > 0 else 0

    for cwe in summary["cwe_breakdown"]:
        stats = summary["cwe_breakdown"][cwe]
        stats["accuracy"] = stats["correct"] / stats["total"] if stats["total"] > 0 else 0
| print(f"\nEvaluation Complete:") | |
| print(f" Binary Accuracy: {summary['binary_accuracy']:.2%}") | |
| print(f" Precision: {summary['precision']:.2%}") | |
| print(f" Recall: {summary['recall']:.2%}") | |
| print(f" F1 Score: {summary['f1']:.2%}") | |
| print(f" CWE Accuracy: {summary['cwe_accuracy']:.2%}") | |
| print(f" Confusion: TP={summary['tp']} FP={summary['fp']} TN={summary['tn']} FN={summary['fn']}") | |
| with open(output_file, "w", encoding="utf-8") as f: | |
| json.dump(results, f, indent=2) | |
| print(f"Results saved to {output_file}") | |
| return results | |


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", default="meta-llama/Llama-3.2-3B-Instruct")
    parser.add_argument("--test-file", default="data/devign_test.jsonl")
    parser.add_argument("--is-lora", action="store_true")
    parser.add_argument("--base-model", default="meta-llama/Llama-3.2-3B-Instruct")
    parser.add_argument("--output", default="eval_results.json")
    args = parser.parse_args()

    evaluate(args.model_path, args.test_file, args.is_lora, args.base_model, args.output)
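
# Example invocations (script name and adapter path are illustrative; adjust them to
# where this file and your trained LoRA checkpoint actually live):
#   python eval.py                                   # base model on the default test split
#   python eval.py --model-path outputs/grpo_lora --is-lora \
#       --base-model meta-llama/Llama-3.2-3B-Instruct --output eval_lora.json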