# commitguard-env/scripts/plot_results.py
# Author: Nitishkumar-ai — "Initial clean deploy commit" (b74db43)
import matplotlib.pyplot as plt
import json
import os
import argparse
def plot_reward_curve(wandb_data_path, output_path="plots/reward_curve.png"):
    """
    Plot the training reward curve and save it as a PNG.

    Expects a JSON file containing a list of objects with 'step' and
    'reward' keys (exported from Wandb).

    Args:
        wandb_data_path: Path to the exported JSON metrics file.
        output_path: Destination PNG path; its directory is created if needed.
    """
    if not os.path.exists(wandb_data_path):
        print(f"Skipping: {wandb_data_path} not found.")
        return
    with open(wandb_data_path, "r") as f:
        data = json.load(f)
    steps = [d["step"] for d in data]
    rewards = [d["reward"] for d in data]
    # Ensure the destination directory exists so savefig does not fail
    # when called with a path outside the default "plots/" folder.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    plt.figure(figsize=(10, 6))
    plt.plot(steps, rewards, label="GRPO Reward", color="#2ecc71", linewidth=2)
    plt.xlabel("Training Step")
    plt.ylabel("Mean Reward")
    plt.title("CommitGuard — GRPO Training Reward Curve")
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.legend()
    plt.savefig(output_path)
    plt.close()  # free the figure; repeated calls otherwise leak figures
    print(f"Saved: {output_path}")
def plot_accuracy_comparison(baseline_acc, trained_acc, output_path="plots/baseline_vs_trained.png"):
    """
    Plot a bar chart comparing baseline vs trained detection accuracy.

    Args:
        baseline_acc: Baseline (untrained) accuracy in percent (0-100).
        trained_acc: Trained-model accuracy in percent (0-100).
        output_path: Destination PNG path; its directory is created if needed.
    """
    labels = ['Baseline (Untrained)', 'CommitGuard (Trained)']
    accuracies = [baseline_acc, trained_acc]
    colors = ['#95a5a6', '#3498db']
    # Ensure the destination directory exists so savefig does not fail.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    plt.figure(figsize=(8, 6))
    bars = plt.bar(labels, accuracies, color=colors)
    plt.ylabel("Detection Accuracy (%)")
    plt.title("Vulnerability Detection: Baseline vs. Trained")
    plt.ylim(0, 100)
    # Annotate each bar with its value just above the bar top.
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 1,
                 f'{height}%', ha='center', va='bottom', fontweight='bold')
    plt.savefig(output_path)
    plt.close()  # free the figure; repeated calls otherwise leak figures
    print(f"Saved: {output_path}")
def plot_per_cwe_breakdown(cwe_data, output_path="plots/per_cwe.png"):
    """
    Plot a grouped bar chart of baseline vs trained accuracy per CWE.

    Args:
        cwe_data: Mapping of CWE label to a [baseline, trained] accuracy
            pair, e.g. {"CWE-89": [40, 85], "CWE-119": [30, 60], ...}.
        output_path: Destination PNG path; its directory is created if needed.
    """
    cwes = list(cwe_data.keys())
    baseline_vals = [v[0] for v in cwe_data.values()]
    trained_vals = [v[1] for v in cwe_data.values()]
    x = range(len(cwes))
    width = 0.35
    # Ensure the destination directory exists so savefig does not fail.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    fig, ax = plt.subplots(figsize=(12, 6))
    # Offset the two bar groups by half a bar width on each side of the tick.
    ax.bar([i - width/2 for i in x], baseline_vals, width, label='Baseline', color='#95a5a6')
    ax.bar([i + width/2 for i in x], trained_vals, width, label='Trained', color='#e67e22')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Detection Accuracy by CWE Type')
    ax.set_xticks(x)
    ax.set_xticklabels(cwes, rotation=45)
    ax.legend()
    ax.set_ylim(0, 100)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)  # free the figure; repeated calls otherwise leak figures
    print(f"Saved: {output_path}")
if __name__ == "__main__":
    # CLI entry point: pick which plot(s) to generate.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--mode", choices=["reward", "accuracy", "cwe", "all"], default="all")
    cli_args = arg_parser.parse_args()
    os.makedirs("plots", exist_ok=True)
    selected = cli_args.mode

    # Example usage for morning shift:
    if selected in ("reward", "all"):
        plot_reward_curve("plots/wandb_simulated.json")
    if selected in ("accuracy", "all"):
        # Placeholder numbers (to be updated by Divyank/Deepak's eval)
        plot_accuracy_comparison(baseline_acc=32, trained_acc=68)
    if selected in ("cwe", "all"):
        # Placeholder data
        per_cwe_results = {
            "CWE-89": [40, 85],
            "CWE-119": [30, 60],
            "CWE-79": [25, 70],
            "CWE-20": [35, 55],
        }
        plot_per_cwe_breakdown(per_cwe_results)