File size: 3,444 Bytes
95cbc5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import matplotlib.pyplot as plt
import json
import os
import argparse

def plot_reward_curve(wandb_data_path, output_path="plots/reward_curve.png"):
    """
    Plot the training reward curve and save it as an image.

    The JSON file (exported from Wandb) may use either layout:
      * a list of records: ``[{"step": 0, "reward": 0.1}, ...]``
      * a column mapping:  ``{"step": [0, ...], "reward": [0.1, ...]}``

    Args:
        wandb_data_path: Path to the JSON export. If the file does not exist,
            or contains no data points, a message is printed and nothing is
            plotted.
        output_path: Where to save the PNG. Missing parent directories are
            created.

    Returns:
        None. The chart is written to ``output_path``.
    """
    if not os.path.exists(wandb_data_path):
        print(f"Skipping: {wandb_data_path} not found.")
        return

    with open(wandb_data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Accept both the record-list layout and a dict keyed by 'step'/'reward'
    # (the literal reading of "a JSON file with 'step' and 'reward' keys").
    if isinstance(data, dict):
        steps = list(data["step"])
        rewards = list(data["reward"])
    else:
        steps = [d["step"] for d in data]
        rewards = [d["reward"] for d in data]

    if not steps:
        # An empty chart would be misleading; skip like the missing-file case.
        print(f"Skipping: {wandb_data_path} contains no data points.")
        return

    # Let callers pass any output_path, not only paths under the "plots"
    # directory that the CLI creates.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(steps, rewards, label="GRPO Reward", color="#2ecc71", linewidth=2)
        ax.set_xlabel("Training Step")
        ax.set_ylabel("Mean Reward")
        ax.set_title("CommitGuard — GRPO Training Reward Curve")
        ax.grid(True, linestyle="--", alpha=0.7)
        ax.legend()
        fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
    print(f"Saved: {output_path}")

def plot_accuracy_comparison(baseline_acc, trained_acc, output_path="plots/baseline_vs_trained.png"):
    """
    Plot a bar chart comparing baseline vs. trained detection accuracy.

    Args:
        baseline_acc: Accuracy of the untrained model, in percent (0-100).
        trained_acc: Accuracy of the trained model, in percent (0-100).
        output_path: Where to save the PNG. Missing parent directories are
            created.

    Returns:
        None. The chart is written to ``output_path``.
    """
    labels = ['Baseline (Untrained)', 'CommitGuard (Trained)']
    accuracies = [baseline_acc, trained_acc]
    colors = ['#95a5a6', '#3498db']

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        bars = ax.bar(labels, accuracies, color=colors)
        ax.set_ylabel("Detection Accuracy (%)")
        ax.set_title("Vulnerability Detection: Baseline vs. Trained")
        ax.set_ylim(0, 100)

        for bar, acc in zip(bars, accuracies):
            height = bar.get_height()
            # With the axis capped at 100, a label above a near-100% bar would
            # sit outside the axes and overlap the title; put it inside instead.
            if height > 92:
                y, va = height - 1, 'top'
            else:
                y, va = height + 1, 'bottom'
            # '.4g' keeps whole numbers as-is ("68%") and trims long floats.
            ax.text(bar.get_x() + bar.get_width() / 2.0, y, f'{acc:.4g}%',
                    ha='center', va=va, fontweight='bold')

        fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
    print(f"Saved: {output_path}")

def plot_per_cwe_breakdown(cwe_data, output_path="plots/per_cwe.png"):
    """
    Plot a grouped bar chart of baseline vs. trained accuracy per CWE type.

    Args:
        cwe_data: Mapping of CWE id to ``[baseline, trained]`` accuracies in
            percent, e.g. ``{"CWE-89": [40, 85], "CWE-119": [30, 60]}``.
            Only the first two entries of each value are used.
        output_path: Where to save the PNG. Missing parent directories are
            created.

    Returns:
        None. If ``cwe_data`` is empty, a message is printed and no file is
        written.
    """
    if not cwe_data:
        print("Skipping: no per-CWE data to plot.")
        return

    cwes = list(cwe_data)
    baseline_vals = [cwe_data[c][0] for c in cwes]
    trained_vals = [cwe_data[c][1] for c in cwes]

    positions = list(range(len(cwes)))
    width = 0.35

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Offset the two series by half a bar width so they sit side by side.
        ax.bar([i - width / 2 for i in positions], baseline_vals, width, label='Baseline', color='#95a5a6')
        ax.bar([i + width / 2 for i in positions], trained_vals, width, label='Trained', color='#e67e22')

        ax.set_ylabel('Accuracy (%)')
        ax.set_title('Detection Accuracy by CWE Type')
        ax.set_xticks(positions)
        # Right-align rotated labels so each one ends under its own bar group.
        ax.set_xticklabels(cwes, rotation=45, ha='right')
        ax.legend()
        ax.set_ylim(0, 100)

        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
    print(f"Saved: {output_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate CommitGuard reward and accuracy plots under ./plots."
    )
    parser.add_argument("--mode", choices=["reward", "accuracy", "cwe", "all"], default="all",
                        help="Which plot(s) to generate (default: %(default)s).")
    parser.add_argument("--reward-data", default="plots/wandb_simulated.json",
                        help="Wandb JSON export for the reward curve (default: %(default)s).")
    # Defaults below are placeholders until the evaluation numbers
    # (Divyank/Deepak's eval) are available; pass real values on the CLI.
    parser.add_argument("--baseline-acc", type=float, default=32,
                        help="Baseline accuracy in percent (default: %(default)s).")
    parser.add_argument("--trained-acc", type=float, default=68,
                        help="Trained accuracy in percent (default: %(default)s).")
    parser.add_argument("--cwe-data", default=None,
                        help='JSON file mapping CWE id to [baseline, trained], '
                             'e.g. {"CWE-89": [40, 85]}. Placeholder data is used when omitted.')
    args = parser.parse_args()

    os.makedirs("plots", exist_ok=True)

    if args.mode in ("reward", "all"):
        plot_reward_curve(args.reward_data)

    if args.mode in ("accuracy", "all"):
        plot_accuracy_comparison(baseline_acc=args.baseline_acc, trained_acc=args.trained_acc)

    if args.mode in ("cwe", "all"):
        if args.cwe_data:
            with open(args.cwe_data, "r", encoding="utf-8") as f:
                cwe_data = json.load(f)
        else:
            # Placeholder data
            cwe_data = {
                "CWE-89": [40, 85],
                "CWE-119": [30, 60],
                "CWE-79": [25, 70],
                "CWE-20": [35, 55]
            }
        plot_per_cwe_breakdown(cwe_data)