Spaces:
Runtime error
Runtime error
| import matplotlib.pyplot as plt | |
| import json | |
| import os | |
| import argparse | |
def plot_reward_curve(wandb_data_path, output_path="plots/reward_curve.png"):
    """Plot the training reward curve and save it as an image.

    Args:
        wandb_data_path: Path to a JSON file holding a list of records, each
            with 'step' and 'reward' keys (exported from Wandb).
        output_path: Destination image path. Missing parent directories are
            created before saving.

    Returns:
        None. When ``wandb_data_path`` does not exist, a "Skipping" message
        is printed and nothing is plotted.
    """
    if not os.path.exists(wandb_data_path):
        print(f"Skipping: {wandb_data_path} not found.")
        return

    with open(wandb_data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    steps = [d["step"] for d in data]
    rewards = [d["reward"] for d in data]

    # savefig does not create directories; make sure the target exists so a
    # custom output_path works without relying on the caller.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(steps, rewards, label="GRPO Reward", color="#2ecc71", linewidth=2)
        plt.xlabel("Training Step")
        plt.ylabel("Mean Reward")
        plt.title("CommitGuard — GRPO Training Reward Curve")
        plt.grid(True, linestyle="--", alpha=0.7)
        plt.legend()
        plt.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate figures.
        plt.close(fig)
    print(f"Saved: {output_path}")
def plot_accuracy_comparison(baseline_acc, trained_acc, output_path="plots/baseline_vs_trained.png"):
    """Plot a two-bar chart comparing baseline vs trained accuracy.

    Args:
        baseline_acc: Accuracy of the untrained baseline, in percent.
        trained_acc: Accuracy of the trained model, in percent.
        output_path: Destination image path. Missing parent directories are
            created before saving.

    Returns:
        None. The chart is written to ``output_path``.
    """
    labels = ['Baseline (Untrained)', 'CommitGuard (Trained)']
    accuracies = [baseline_acc, trained_acc]
    colors = ['#95a5a6', '#3498db']

    # savefig does not create directories; make sure the target exists so a
    # custom output_path works without relying on the caller.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure(figsize=(8, 6))
    try:
        bars = plt.bar(labels, accuracies, color=colors)
        plt.ylabel("Detection Accuracy (%)")
        plt.title("Vulnerability Detection: Baseline vs. Trained")
        plt.ylim(0, 100)

        # Annotate each bar with its value, centred just above the bar top.
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width() / 2., height + 1,
                     f'{height}%', ha='center', va='bottom', fontweight='bold')

        plt.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate figures.
        plt.close(fig)
    print(f"Saved: {output_path}")
def plot_per_cwe_breakdown(cwe_data, output_path="plots/per_cwe.png"):
    """Plot a grouped bar chart of per-CWE accuracy, baseline vs trained.

    Args:
        cwe_data: Mapping of CWE label to a ``[baseline, trained]`` pair,
            e.g. ``{"CWE-89": [baseline, trained], "CWE-119": [...], ...}``.
        output_path: Destination image path. Missing parent directories are
            created before saving.

    Returns:
        None. The chart is written to ``output_path``.
    """
    cwes = list(cwe_data.keys())
    baseline_vals = [v[0] for v in cwe_data.values()]
    trained_vals = [v[1] for v in cwe_data.values()]
    x = range(len(cwes))
    width = 0.35

    # savefig does not create directories; make sure the target exists so a
    # custom output_path works without relying on the caller.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Shift each series half a bar-width off the tick so the two bars of
        # a CWE sit side by side around its label.
        ax.bar([i - width / 2 for i in x], baseline_vals, width, label='Baseline', color='#95a5a6')
        ax.bar([i + width / 2 for i in x], trained_vals, width, label='Trained', color='#e67e22')
        ax.set_ylabel('Accuracy (%)')
        ax.set_title('Detection Accuracy by CWE Type')
        ax.set_xticks(x)
        ax.set_xticklabels(cwes, rotation=45)
        ax.legend()
        ax.set_ylim(0, 100)
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate figures.
        plt.close(fig)
    print(f"Saved: {output_path}")
if __name__ == "__main__":
    # Command-line entry point: choose which plot(s) to generate.
    cli = argparse.ArgumentParser()
    cli.add_argument("--mode", choices=["reward", "accuracy", "cwe", "all"], default="all")
    selected = cli.parse_args().mode
    run_all = selected == "all"

    # All default output paths live under plots/.
    os.makedirs("plots", exist_ok=True)

    # Example usage for morning shift.
    if run_all or selected == "reward":
        plot_reward_curve("plots/wandb_simulated.json")

    if run_all or selected == "accuracy":
        # Placeholder numbers (to be updated by Divyank/Deepak's eval).
        plot_accuracy_comparison(baseline_acc=32, trained_acc=68)

    if run_all or selected == "cwe":
        # Placeholder per-CWE data as [baseline, trained] pairs.
        placeholder_cwe = {}
        placeholder_cwe["CWE-89"] = [40, 85]
        placeholder_cwe["CWE-119"] = [30, 60]
        placeholder_cwe["CWE-79"] = [25, 70]
        placeholder_cwe["CWE-20"] = [35, 55]
        plot_per_cwe_breakdown(placeholder_cwe)