# commitguard-env/scripts/evaluate.py
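"""Evaluation harness for CommitGuard.

Runs a causal LM (optionally a LoRA adapter on top of a base model) over a
JSONL test set of commit diffs, parses each verdict with the environment's
parse_action, and reports binary and CWE-level accuracy metrics.
"""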
import json
import argparse
import sys
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Add project root to path so project-local imports resolve regardless of CWD
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))

from agent_prompt import SYSTEM_PROMPT
from commitguard_env.parse_action import parse_action

def format_eval_prompt(sample):
# Matches the format_prompt in train_grpo.py for consistency
return (
f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
f"{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
f"Analyze this commit and submit your verdict.\n\n"
f"Code diff:\n```diff\n{sample['diff']}\n```<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
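
# The template above is the Llama 3 Instruct chat format written out by hand
# (roughly what tokenizer.apply_chat_template would produce for these models),
# kept inline so it stays byte-identical to the training-time prompt in
# train_grpo.py.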

def evaluate(model_path, test_file, is_lora=False, base_model=None, output_file="eval_results.json"):
    """Run the model on every test sample and compute accuracy metrics.

    Args:
        model_path: HF model ID or local path (the adapter path when is_lora=True).
        test_file: JSONL test set, one sample per line.
        is_lora: if True, load model_path as a LoRA adapter on top of base_model.
        base_model: base model ID; required when is_lora is True.
        output_file: path the JSON results are written to.

    Returns:
        The results dict (summary metrics plus per-sample predictions).
    """
print(f"Loading model from {model_path}...")
device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load model
    if is_lora:
        if not base_model:
            raise ValueError("base_model is required if is_lora=True")
        print(f"Loading LoRA adapter from {model_path} with base model {base_model}")
        # Deferred import: unsloth is heavyweight and only needed for this path.
        from unsloth import FastLanguageModel

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=base_model,
            max_seq_length=2048,
            load_in_4bit=True,
        )
        model = PeftModel.from_pretrained(model, model_path)
        FastLanguageModel.for_inference(model)  # switch unsloth model to inference mode
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
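
    # Note: this path loads full float16 weights with device_map="auto", which
    # assumes enough accelerator memory for the whole model; the LoRA path
    # above loads the base model in 4-bit instead.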
# Load test data
print(f"Loading test data from {test_file}...")
with open(test_file, "r", encoding="utf-8") as f:
samples = [json.loads(line) for line in f if line.strip()]
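
    # Illustrative sample shape (field names are the ones read below; the
    # values here are made up):
    #   {"sample_id": "abc123", "diff": "--- a/foo.c ...", "is_vulnerable": 1,
    #    "cwe": "CWE-476"}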
results = {
"summary": {
"total": len(samples),
"correct_binary": 0,
"correct_cwe": 0,
"false_positives": 0,
"false_negatives": 0,
"binary_accuracy": 0,
"cwe_accuracy": 0,
"false_positive_rate": 0,
"false_negative_rate": 0,
"cwe_breakdown": {},
},
"predictions": [],
}
print(f"Starting evaluation on {len(samples)} samples...")
for i, sample in enumerate(samples):
prompt = format_eval_prompt(sample)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)  # follow wherever the model was placed
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=False,  # greedy decoding; a temperature would be ignored here
                pad_token_id=tokenizer.eos_token_id,  # Llama tokenizers define no pad token
            )
response = tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
# Use the official parser from commitguard_env
prediction = parse_action(response)
gt_vulnerable = bool(sample["is_vulnerable"])
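        # parse_action may return is_vulnerable=None on unparseable output;
        # treat that as a "safe" prediction.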
pred_vulnerable = bool(prediction.is_vulnerable) if prediction.is_vulnerable is not None else False
correct = pred_vulnerable == gt_vulnerable
if correct:
results["summary"]["correct_binary"] += 1
if gt_vulnerable and not pred_vulnerable:
results["summary"]["false_negatives"] += 1
elif not gt_vulnerable and pred_vulnerable:
results["summary"]["false_positives"] += 1
cwe = sample.get("cwe") or "CWE-OTHER"
if cwe not in results["summary"]["cwe_breakdown"]:
results["summary"]["cwe_breakdown"][cwe] = {"total": 0, "correct": 0, "accuracy": 0}
results["summary"]["cwe_breakdown"][cwe]["total"] += 1
if correct:
results["summary"]["cwe_breakdown"][cwe]["correct"] += 1
        # CWE "accuracy" counts only true vulnerabilities that were both
        # flagged and labeled with the matching CWE.
        if (
            gt_vulnerable
            and correct
            and prediction.vuln_type
            and prediction.vuln_type.strip().upper() == cwe.strip().upper()
        ):
            results["summary"]["correct_cwe"] += 1
results["predictions"].append({
"sample_id": sample["sample_id"],
"ground_truth": gt_vulnerable,
"predicted": pred_vulnerable,
"predicted_cwe": prediction.vuln_type,
"actual_cwe": cwe,
"response": response,
})
if (i + 1) % 10 == 0:
print(f" Processed {i+1}/{len(samples)} samples...")
    # Final summary stats
    summary = results["summary"]
    total = summary["total"]
    vuln_count = sum(1 for s in samples if s["is_vulnerable"])
    safe_count = total - vuln_count
    # Derive the confusion matrix: every vulnerable sample is either a TP or
    # an FN, and every safe sample is either a TN or an FP.
    tp = vuln_count - summary["false_negatives"]
    fp = summary["false_positives"]
    fn = summary["false_negatives"]
    tn = safe_count - fp
summary["tp"] = tp
summary["fp"] = fp
summary["tn"] = tn
summary["fn"] = fn
summary["precision"] = tp / (tp + fp) if (tp + fp) > 0 else 0
summary["recall"] = tp / (tp + fn) if (tp + fn) > 0 else 0
p, r = summary["precision"], summary["recall"]
summary["f1"] = 2 * p * r / (p + r) if (p + r) > 0 else 0
summary["binary_accuracy"] = summary["correct_binary"] / total if total > 0 else 0
summary["cwe_accuracy"] = summary["correct_cwe"] / vuln_count if vuln_count > 0 else 0
summary["false_positive_rate"] = summary["false_positives"] / safe_count if safe_count > 0 else 0
summary["false_negative_rate"] = summary["false_negatives"] / vuln_count if vuln_count > 0 else 0
for cwe in summary["cwe_breakdown"]:
stats = summary["cwe_breakdown"][cwe]
stats["accuracy"] = stats["correct"] / stats["total"] if stats["total"] > 0 else 0
print(f"\nEvaluation Complete:")
print(f" Binary Accuracy: {summary['binary_accuracy']:.2%}")
print(f" Precision: {summary['precision']:.2%}")
print(f" Recall: {summary['recall']:.2%}")
print(f" F1 Score: {summary['f1']:.2%}")
print(f" CWE Accuracy: {summary['cwe_accuracy']:.2%}")
print(f" Confusion: TP={summary['tp']} FP={summary['fp']} TN={summary['tn']} FN={summary['fn']}")
with open(output_file, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2)
print(f"Results saved to {output_file}")
return results
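
# Example invocations (the adapter path is illustrative):
#   python scripts/evaluate.py
#   python scripts/evaluate.py --model-path outputs/grpo_adapter --is-lora \
#       --base-model meta-llama/Llama-3.2-3B-Instruct --output lora_eval.json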
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", default="meta-llama/Llama-3.2-3B-Instruct")
parser.add_argument("--test-file", default="data/devign_test.jsonl")
parser.add_argument("--is-lora", action="store_true")
parser.add_argument("--base-model", default="meta-llama/Llama-3.2-3B-Instruct")
parser.add_argument("--output", default="eval_results.json")
args = parser.parse_args()
evaluate(args.model_path, args.test_file, args.is_lora, args.base_model, args.output)