"""Plotting utilities for CommitGuard evaluation (reward curve, accuracy comparison, per-CWE breakdown)."""
import matplotlib.pyplot as plt
import json
import os
import argparse
def plot_reward_curve(wandb_data_path, output_path="plots/reward_curve.png"):
    """Plot the GRPO training reward curve and save it as a PNG.

    Args:
        wandb_data_path: Path to a JSON file (exported from Wandb) containing
            a list of records, each with "step" and "reward" keys.
        output_path: File path the figure is written to.

    Silently skips (with a message) when the input file does not exist, so
    the script can run before the Wandb export is available.
    """
    if not os.path.exists(wandb_data_path):
        print(f"Skipping: {wandb_data_path} not found.")
        return
    with open(wandb_data_path, "r") as f:
        data = json.load(f)
    steps = [d["step"] for d in data]
    rewards = [d["reward"] for d in data]
    plt.figure(figsize=(10, 6))
    plt.plot(steps, rewards, label="GRPO Reward", color="#2ecc71", linewidth=2)
    plt.xlabel("Training Step")
    plt.ylabel("Mean Reward")
    plt.title("CommitGuard — GRPO Training Reward Curve")
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.legend()
    plt.tight_layout()  # consistent with plot_per_cwe_breakdown; avoids clipped labels
    plt.savefig(output_path)
    plt.close()  # fix: release the figure so repeated calls don't accumulate memory
    print(f"Saved: {output_path}")
def plot_accuracy_comparison(baseline_acc, trained_acc, output_path="plots/baseline_vs_trained.png"):
    """Plot a bar chart comparing baseline vs. trained detection accuracy.

    Args:
        baseline_acc: Baseline (untrained) accuracy in percent (0-100).
        trained_acc: Trained-model accuracy in percent (0-100).
        output_path: File path the figure is written to.
    """
    labels = ['Baseline (Untrained)', 'CommitGuard (Trained)']
    accuracies = [baseline_acc, trained_acc]
    colors = ['#95a5a6', '#3498db']
    plt.figure(figsize=(8, 6))
    bars = plt.bar(labels, accuracies, color=colors)
    plt.ylabel("Detection Accuracy (%)")
    plt.title("Vulnerability Detection: Baseline vs. Trained")
    plt.ylim(0, 100)
    # Annotate each bar with its value just above the top edge.
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 1,
                 f'{height}%', ha='center', va='bottom', fontweight='bold')
    plt.tight_layout()  # consistent with plot_per_cwe_breakdown; avoids clipped labels
    plt.savefig(output_path)
    plt.close()  # fix: release the figure so repeated calls don't accumulate memory
    print(f"Saved: {output_path}")
def plot_per_cwe_breakdown(cwe_data, output_path="plots/per_cwe.png"):
    """Plot a grouped bar chart of per-CWE baseline vs. trained accuracy.

    Args:
        cwe_data: Mapping of CWE id to a [baseline, trained] accuracy pair,
            e.g. {"CWE-89": [40, 85], "CWE-119": [30, 60], ...}.
        output_path: File path the figure is written to.
    """
    cwes = list(cwe_data.keys())
    baseline_vals = [v[0] for v in cwe_data.values()]
    trained_vals = [v[1] for v in cwe_data.values()]
    x = range(len(cwes))
    width = 0.35  # half-offset each group so the two bars sit side by side
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar([i - width/2 for i in x], baseline_vals, width, label='Baseline', color='#95a5a6')
    ax.bar([i + width/2 for i in x], trained_vals, width, label='Trained', color='#e67e22')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Detection Accuracy by CWE Type')
    ax.set_xticks(x)
    ax.set_xticklabels(cwes, rotation=45)
    ax.legend()
    ax.set_ylim(0, 100)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)  # fix: release the figure so repeated calls don't accumulate memory
    print(f"Saved: {output_path}")
if __name__ == "__main__":
    # CLI: pick which plot(s) to generate; defaults to everything.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--mode", choices=["reward", "accuracy", "cwe", "all"], default="all")
    cli_args = arg_parser.parse_args()

    os.makedirs("plots", exist_ok=True)
    selected = cli_args.mode

    # Example usage for morning shift:
    if selected in ("reward", "all"):
        plot_reward_curve("plots/wandb_simulated.json")

    if selected in ("accuracy", "all"):
        # Placeholder numbers (to be updated by Divyank/Deepak's eval)
        plot_accuracy_comparison(baseline_acc=32, trained_acc=68)

    if selected in ("cwe", "all"):
        # Placeholder data
        per_cwe_results = {
            "CWE-89": [40, 85],
            "CWE-119": [30, 60],
            "CWE-79": [25, 70],
            "CWE-20": [35, 55],
        }
        plot_per_cwe_breakdown(per_cwe_results)