Spaces:
Runtime error
Runtime error
| """ | |
| Lightweight offline evaluation of baseline strategies against CommitGuard env. | |
| No GPU or model required — runs locally with deterministic strategies. | |
| Reports precision, recall, F1, accuracy, mean reward, and per-CWE breakdown. | |
| Usage: | |
| python scripts/eval_baselines.py --episodes 500 | |
| python scripts/eval_baselines.py --episodes 1000 --strategy always_vuln | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import random | |
| import sys | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
# Parent of this script's directory, i.e. the repository checkout root. It is
# put first on sys.path so the local ``commitguard_env`` package imported below
# resolves when the script is run directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))
| from commitguard_env.environment import CommitGuardEnvironment | |
| from commitguard_env.models import CommitGuardAction | |
@dataclass
class EvalResult:
    """Confusion-matrix counts and rewards accumulated over evaluation episodes.

    "Vulnerable" is the positive class: ``tp``/``fn`` count vulnerable samples
    predicted vulnerable/safe, and ``fp``/``tn`` count safe samples predicted
    vulnerable/safe. Derived metrics are exposed as read-only properties and
    fall back to 0.0 when their denominator is zero.
    """

    tp: int = 0
    fp: int = 0
    tn: int = 0
    fn: int = 0
    total_reward: float = 0.0  # sum of all per-episode rewards
    episodes: int = 0  # number of episodes recorded
    rewards: list[float] = field(default_factory=list)  # per-episode rewards, in order
    # CWE label -> {"tp": .., "fp": .., "tn": .., "fn": ..} counts for that label.
    per_cwe: dict[str, dict[str, int]] = field(default_factory=dict)

    @property
    def precision(self) -> float:
        """TP / (TP + FP); 0.0 when nothing was predicted vulnerable."""
        predicted_pos = self.tp + self.fp
        return self.tp / predicted_pos if predicted_pos > 0 else 0.0

    @property
    def recall(self) -> float:
        """TP / (TP + FN); 0.0 when no sample was actually vulnerable."""
        actual_pos = self.tp + self.fn
        return self.tp / actual_pos if actual_pos > 0 else 0.0

    @property
    def f1(self) -> float:
        """Harmonic mean of precision and recall; 0.0 when both are zero."""
        p, r = self.precision, self.recall
        return 2 * p * r / (p + r) if (p + r) > 0 else 0.0

    @property
    def accuracy(self) -> float:
        """(TP + TN) / total classified; 0.0 when nothing was classified."""
        total = self.tp + self.fp + self.tn + self.fn
        return (self.tp + self.tn) / total if total > 0 else 0.0

    @property
    def mean_reward(self) -> float:
        """Average reward per episode; 0.0 before any episode is recorded."""
        return self.total_reward / self.episodes if self.episodes > 0 else 0.0

    def to_dict(self) -> dict:
        """Return a JSON-serialisable summary with metrics rounded to 4 decimals.

        ``per_cwe`` is copied so callers cannot mutate this result's counters
        through the returned dict.
        """
        return {
            "episodes": self.episodes,
            "confusion": {"tp": self.tp, "fp": self.fp, "tn": self.tn, "fn": self.fn},
            "precision": round(self.precision, 4),
            "recall": round(self.recall, 4),
            "f1": round(self.f1, 4),
            "accuracy": round(self.accuracy, 4),
            "mean_reward": round(self.mean_reward, 4),
            "per_cwe": {cwe: dict(counts) for cwe, counts in self.per_cwe.items()},
        }
def make_action(strategy: str) -> CommitGuardAction:
    """Build the single verdict action that a baseline strategy submits.

    Args:
        strategy: One of ``"always_vuln"`` (always flags CWE-119 with a fixed
            exploit sketch), ``"always_safe"`` (always reports no vulnerability)
            or ``"random"`` (a fair coin flip between the two).

    Raises:
        ValueError: For any other strategy name.
    """
    if strategy == "always_safe":
        return CommitGuardAction(
            action_type="verdict",
            is_vulnerable=False,
            vuln_type="NONE",
        )
    if strategy == "always_vuln":
        return CommitGuardAction(
            action_type="verdict",
            is_vulnerable=True,
            vuln_type="CWE-119",
            exploit_sketch="buffer overflow via unchecked memcpy",
        )
    if strategy != "random":
        raise ValueError(f"unknown strategy: {strategy}")

    # The coin flip uses the module-level ``random`` generator, so it is
    # reproducible under a global seed.
    flagged = random.choice([True, False])
    if flagged:
        verdict_fields = {"vuln_type": "CWE-119", "exploit_sketch": "buffer overflow"}
    else:
        verdict_fields = {"vuln_type": "NONE", "exploit_sketch": None}
    return CommitGuardAction(action_type="verdict", is_vulnerable=flagged, **verdict_fields)
def _index_samples(samples) -> dict:
    """Map ``sample_id`` -> sample, keeping the first sample for duplicate ids."""
    index = {}
    for sample in samples:
        # setdefault keeps first-match semantics, like a linear scan would.
        index.setdefault(sample.sample_id, sample)
    return index


def _confusion_cell(pred: bool, gt: bool) -> str:
    """Name the confusion-matrix cell ("tp"/"fp"/"tn"/"fn") for one prediction."""
    if pred:
        return "tp" if gt else "fp"
    return "fn" if gt else "tn"


def run_eval(env: CommitGuardEnvironment, episodes: int, strategy: str) -> EvalResult:
    """Play ``episodes`` one-step episodes with a fixed baseline strategy.

    Each episode resets the environment, submits the verdict from
    ``make_action(strategy)``, and scores that verdict against the
    ground-truth label of the sample the environment served.

    Args:
        env: Environment under evaluation. Its private ``_samples`` list and
            ``_state.current_sample_id`` are read to recover the served sample.
        episodes: Number of episodes to run.
        strategy: Strategy name accepted by ``make_action``.

    Returns:
        An ``EvalResult`` holding overall and per-CWE confusion counts plus the
        reward of every episode.

    Raises:
        LookupError: If the environment's current sample id is not present in
            ``env._samples``.
        ValueError: Propagated from ``make_action`` for an unknown strategy.
    """
    result = EvalResult()
    # Index once instead of scanning the whole sample list every episode.
    samples_by_id = _index_samples(env._samples)
    for _ in range(episodes):
        env.reset()
        action = make_action(strategy)
        _obs, reward, _done = env.step(action)

        sample_id = env._state.current_sample_id
        sample = samples_by_id.get(sample_id)
        if sample is None:
            # Re-index in case the environment replaced its sample list.
            samples_by_id = _index_samples(env._samples)
            sample = samples_by_id.get(sample_id)
        if sample is None:
            raise LookupError(
                f"current sample id {sample_id!r} not found in env._samples"
            )

        cell = _confusion_cell(bool(action.is_vulnerable), bool(sample.is_vulnerable))
        cwe_counts = result.per_cwe.setdefault(
            sample.cwe or "None", dict.fromkeys(("tp", "fp", "tn", "fn"), 0)
        )
        setattr(result, cell, getattr(result, cell) + 1)
        cwe_counts[cell] += 1

        result.total_reward += reward
        result.rewards.append(reward)
        result.episodes += 1
    return result
def plot_results(results: dict[str, EvalResult], out_dir: Path) -> None:
    """Save a two-panel comparison figure to ``out_dir / "eval_baselines.png"``.

    The left panel shows each strategy's running mean reward over episodes.
    The right panel shows grouped bars for precision, recall, F1 and accuracy.

    Args:
        results: Mapping of strategy name to its evaluation result. When empty,
            the function returns immediately without importing matplotlib.
        out_dir: Directory the PNG is written into; it must already exist.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    if not results:
        # Nothing to draw, and the bar-width computation below would divide by zero.
        return

    import matplotlib

    matplotlib.use("Agg")  # non-interactive backend: render straight to a file
    import matplotlib.pyplot as plt

    metrics = ["precision", "recall", "f1", "accuracy"]
    out_path = out_dir / "eval_baselines.png"
    fig, (reward_ax, metric_ax) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        # Left: running mean reward (mean of the first i+1 episode rewards).
        for name, res in results.items():
            running_mean = []
            total = 0.0
            for i, reward in enumerate(res.rewards):
                total += reward
                running_mean.append(total / (i + 1))
            reward_ax.plot(
                running_mean, label=f"{name} (mean={res.mean_reward:.2f})", linewidth=1.2
            )
        reward_ax.set_title("Cumulative Mean Reward")
        reward_ax.set_xlabel("Episode")
        reward_ax.set_ylabel("Mean Reward")
        reward_ax.legend()
        reward_ax.grid(True, alpha=0.3)

        # Right: one group per metric, one bar per strategy, centred on the tick.
        positions = range(len(metrics))
        width = 0.8 / len(results)
        for i, (name, res) in enumerate(results.items()):
            values = [getattr(res, m) for m in metrics]
            offset = (i - len(results) / 2 + 0.5) * width
            metric_ax.bar([p + offset for p in positions], values, width, label=name)
        metric_ax.set_xticks(list(positions))
        metric_ax.set_xticklabels([m.upper() for m in metrics])
        metric_ax.set_ylim(0, 1.05)
        metric_ax.set_title("Classification Metrics by Strategy")
        metric_ax.legend()
        metric_ax.grid(True, alpha=0.3, axis="y")

        fig.tight_layout()
        fig.savefig(out_path, dpi=180)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)
    print(f"Plot saved to {out_path}")
def main() -> None:
    """Command-line entry point: evaluate baseline strategies and save reports.

    Before each selected strategy runs, the module-level ``random`` generator
    and the environment's ``_rng`` are both reseeded with ``--seed``. This is
    presumably so every strategy is scored on the same sample sequence.
    Per-strategy metrics are printed, written together to
    ``<out-dir>/eval_baselines.json``, and plotted to
    ``<out-dir>/eval_baselines.png`` when matplotlib is installed.
    """
    strategy_names = ["always_vuln", "always_safe", "random"]

    ap = argparse.ArgumentParser(description="Evaluate CommitGuard baseline strategies")
    ap.add_argument("--episodes", type=int, default=500)
    ap.add_argument("--strategy", type=str, default="all",
                    choices=[*strategy_names, "all"])
    ap.add_argument("--out-dir", type=Path, default=Path("plots"))
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument(
        "--data-path",
        type=Path,
        default=REPO_ROOT / "data" / "devign_filtered.jsonl",
        help="dataset file passed to CommitGuardEnvironment",
    )
    args = ap.parse_args()
    if args.episodes <= 0:
        # Zero or negative counts would run no episodes and report all-zero metrics.
        ap.error("--episodes must be a positive integer")

    env = CommitGuardEnvironment(data_path=args.data_path)
    strategies = list(strategy_names) if args.strategy == "all" else [args.strategy]

    results: dict[str, EvalResult] = {}
    for strat in strategies:
        random.seed(args.seed)
        env._rng = random.Random(args.seed)
        res = run_eval(env, args.episodes, strat)
        results[strat] = res
        d = res.to_dict()
        print(f"\n{'='*50}")
        print(f" Strategy: {strat} ({args.episodes} episodes)")
        print(f"{'='*50}")
        print(f" Accuracy: {d['accuracy']:.2%}")
        print(f" Precision: {d['precision']:.2%}")
        print(f" Recall: {d['recall']:.2%}")
        print(f" F1: {d['f1']:.2%}")
        print(f" Mean Reward: {d['mean_reward']:.4f}")
        print(f" Confusion: TP={d['confusion']['tp']} FP={d['confusion']['fp']} "
              f"TN={d['confusion']['tn']} FN={d['confusion']['fn']}")

    args.out_dir.mkdir(parents=True, exist_ok=True)
    all_results = {name: res.to_dict() for name, res in results.items()}
    out_path = args.out_dir / "eval_baselines.json"
    out_path.write_text(json.dumps(all_results, indent=2), encoding="utf-8")
    print(f"\nResults saved to {out_path}")

    try:
        plot_results(results, args.out_dir)
    except ImportError:
        # Plotting is optional; the JSON report above is already written.
        print("matplotlib not available — skipping plot")


if __name__ == "__main__":
    main()