import json
import argparse
import os
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from pathlib import Path
import sys
# Add project root to path for imports
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))
from agent_prompt import SYSTEM_PROMPT
from commitguard_env.parse_action import parse_action
def format_eval_prompt(sample):
# Matches the format_prompt in train_grpo.py for consistency
return (
f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
f"{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
f"Analyze this commit and submit your verdict.\n\n"
f"Code diff:\n```diff\n{sample['diff']}\n```<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
def evaluate(model_path, test_file, is_lora=False, base_model=None, output_file="eval_results.json"):
"""
Run model on test samples, compute accuracy metrics.
"""
print(f"Loading model from {model_path}...")
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model
if is_lora:
if not base_model:
raise ValueError("base_model is required if is_lora=True")
print(f"Loading LoRA adapter from {model_path} with base model {base_model}")
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=base_model,
            max_seq_length=2048,
            load_in_4bit=True,
)
model = PeftModel.from_pretrained(model, model_path)
FastLanguageModel.for_inference(model)
else:
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Load test data
print(f"Loading test data from {test_file}...")
with open(test_file, "r", encoding="utf-8") as f:
samples = [json.loads(line) for line in f if line.strip()]
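    # Aggregate counters plus a per-sample prediction log; the rates are filled in after the loop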
results = {
"summary": {
"total": len(samples),
"correct_binary": 0,
"correct_cwe": 0,
"false_positives": 0,
"false_negatives": 0,
"binary_accuracy": 0,
"cwe_accuracy": 0,
"false_positive_rate": 0,
"false_negative_rate": 0,
"cwe_breakdown": {},
},
"predictions": [],
}
print(f"Starting evaluation on {len(samples)} samples...")
for i, sample in enumerate(samples):
prompt = format_eval_prompt(sample)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=256,
                do_sample=False,  # greedy decoding; a temperature setting would be ignored here
                pad_token_id=tokenizer.eos_token_id,
)
response = tokenizer.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
# Use the official parser from commitguard_env
prediction = parse_action(response)
gt_vulnerable = bool(sample["is_vulnerable"])
pred_vulnerable = bool(prediction.is_vulnerable) if prediction.is_vulnerable is not None else False
correct = pred_vulnerable == gt_vulnerable
if correct:
results["summary"]["correct_binary"] += 1
if gt_vulnerable and not pred_vulnerable:
results["summary"]["false_negatives"] += 1
elif not gt_vulnerable and pred_vulnerable:
results["summary"]["false_positives"] += 1
cwe = sample.get("cwe") or "CWE-OTHER"
if cwe not in results["summary"]["cwe_breakdown"]:
results["summary"]["cwe_breakdown"][cwe] = {"total": 0, "correct": 0, "accuracy": 0}
results["summary"]["cwe_breakdown"][cwe]["total"] += 1
if correct:
results["summary"]["cwe_breakdown"][cwe]["correct"] += 1
if gt_vulnerable and correct and prediction.vuln_type and prediction.vuln_type.strip().upper() == cwe.strip().upper():
results["summary"]["correct_cwe"] += 1
results["predictions"].append({
"sample_id": sample["sample_id"],
"ground_truth": gt_vulnerable,
"predicted": pred_vulnerable,
"predicted_cwe": prediction.vuln_type,
"actual_cwe": cwe,
"response": response,
})
if (i + 1) % 10 == 0:
print(f" Processed {i+1}/{len(samples)} samples...")
# Final summary stats
summary = results["summary"]
total = summary["total"]
vuln_count = sum(1 for s in samples if s["is_vulnerable"])
safe_count = total - vuln_count
    # Confusion matrix: true positives are the vulnerable samples that were not missed
    tp = vuln_count - summary["false_negatives"]
fp = summary["false_positives"]
fn = summary["false_negatives"]
tn = safe_count - fp
summary["tp"] = tp
summary["fp"] = fp
summary["tn"] = tn
summary["fn"] = fn
summary["precision"] = tp / (tp + fp) if (tp + fp) > 0 else 0
summary["recall"] = tp / (tp + fn) if (tp + fn) > 0 else 0
p, r = summary["precision"], summary["recall"]
summary["f1"] = 2 * p * r / (p + r) if (p + r) > 0 else 0
summary["binary_accuracy"] = summary["correct_binary"] / total if total > 0 else 0
summary["cwe_accuracy"] = summary["correct_cwe"] / vuln_count if vuln_count > 0 else 0
summary["false_positive_rate"] = summary["false_positives"] / safe_count if safe_count > 0 else 0
summary["false_negative_rate"] = summary["false_negatives"] / vuln_count if vuln_count > 0 else 0
for cwe in summary["cwe_breakdown"]:
stats = summary["cwe_breakdown"][cwe]
stats["accuracy"] = stats["correct"] / stats["total"] if stats["total"] > 0 else 0
print(f"\nEvaluation Complete:")
print(f" Binary Accuracy: {summary['binary_accuracy']:.2%}")
print(f" Precision: {summary['precision']:.2%}")
print(f" Recall: {summary['recall']:.2%}")
print(f" F1 Score: {summary['f1']:.2%}")
print(f" CWE Accuracy: {summary['cwe_accuracy']:.2%}")
print(f" Confusion: TP={summary['tp']} FP={summary['fp']} TN={summary['tn']} FN={summary['fn']}")
with open(output_file, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2)
print(f"Results saved to {output_file}")
return results
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", default="meta-llama/Llama-3.2-3B-Instruct")
parser.add_argument("--test-file", default="data/devign_test.jsonl")
parser.add_argument("--is-lora", action="store_true")
parser.add_argument("--base-model", default="meta-llama/Llama-3.2-3B-Instruct")
parser.add_argument("--output", default="eval_results.json")
args = parser.parse_args()
evaluate(args.model_path, args.test_file, args.is_lora, args.base_model, args.output)
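
# Example invocation (script name and checkpoint path are illustrative):
#   python evaluate.py --is-lora --model-path outputs/grpo_lora \
#       --test-file data/devign_test.jsonl --output eval_results.json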