Nitishkumar-ai committed on
Commit
33692a0
·
1 Parent(s): 8c862c5

Add scripts for hero case finding, hero details retrieval, and training log plotting

Browse files
scripts/find_hero_case.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+
4
def find_hero_case(
    data_path="data/devign_filtered.jsonl",
    cwes=("CWE-89", "CWE-78", "CWE-22", "CWE-119"),
    limit=5,
):
    """Print a short preview of vulnerable samples whose CWE is of interest.

    The data file is read as JSON Lines (one JSON object per line). A sample
    is a candidate when its "is_vulnerable" field is truthy and its "cwe"
    field is one of ``cwes``.

    Args:
        data_path: Path to the JSONL dataset.
        cwes: CWE identifiers to keep.
        limit: Maximum number of candidates to preview.

    Returns:
        None. Results are printed to stdout; "Data not found." is printed
        when ``data_path`` does not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data_path = Path(data_path)
    if not data_path.exists():
        print("Data not found.")
        return

    # A tuple keeps the membership test safe even for unhashable "cwe" values.
    wanted = tuple(cwes)

    # Stream the file and keep only the matches instead of the whole dataset.
    hero_candidates = []
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Trailing or separating blank lines are not records.
                continue
            sample = json.loads(line)
            # .get() so records lacking either field are skipped, not fatal.
            if sample.get("is_vulnerable") and sample.get("cwe") in wanted:
                hero_candidates.append(sample)

    print(f"Found {len(hero_candidates)} hero candidates.")

    for s in hero_candidates[:limit]:
        diff_preview = (s.get("diff") or "")[:300]
        print("-" * 40)
        print(f"Sample ID: {s.get('sample_id')}")
        print(f"CWE: {s['cwe']}")
        print(f"Diff Context:\n{diff_preview}...")
        print("-" * 40)
27
+
28
# Run the candidate search only when executed as a script, not on import.
if __name__ == "__main__":
    find_hero_case()
scripts/get_hero_details.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+
4
def get_hero_details(sample_id, data_path="data/devign_filtered.jsonl"):
    """Pretty-print the first record whose "sample_id" equals ``sample_id``.

    The data file is read as JSON Lines and scanned in order; scanning stops
    at the first match, which is written to stdout as indented JSON.

    Args:
        sample_id: Value compared against each record's "sample_id" field.
        data_path: Path to the JSONL dataset.

    Returns:
        None. Diagnostics (missing file, no matching record) go to stderr so
        that stdout only ever carries the JSON record.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    import sys  # local import: only needed for the stderr diagnostics

    data_path = Path(data_path)
    if not data_path.exists():
        print("Data not found.", file=sys.stderr)
        return

    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Blank lines are not records; json.loads would reject them.
                continue
            s = json.loads(line)
            # .get() so records without an id are skipped instead of raising.
            if s.get("sample_id") == sample_id:
                print(json.dumps(s, indent=2))
                return

    # Reaching this point means the whole file was scanned without a match.
    print(f"Sample {sample_id} not found.", file=sys.stderr)
12
+
13
# Script entry point: dump the record for one fixed, hard-coded sample id.
if __name__ == "__main__":
    get_hero_details("d9a3b33d2c9f996537b7f1d0246dee2d0120cefb")
scripts/plot_training_logs.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import argparse
3
+ import matplotlib.pyplot as plt
4
+ import os
5
+ from pathlib import Path
6
+
7
def plot_training(log_history, output_path, window=5):
    """Plot the per-step reward curve from a log history and save it as an image.

    Args:
        log_history: Iterable of dict log entries. Only entries that contain
            both a "step" and a "reward" key are plotted. ``None`` is treated
            as an empty history.
        output_path: Destination image path. Missing parent directories are
            created before saving.
        window: Number of consecutive rewards averaged for the trend line.

    Returns:
        None. Prints a status message whether or not a plot was written.
    """
    # GRPOTrainer logs 'reward' in the history; keep only entries that have
    # both coordinates so the x and y series stay aligned.
    points = [
        (entry["step"], entry["reward"])
        for entry in (log_history or [])
        if "reward" in entry and "step" in entry
    ]
    if not points:
        print("No reward data found in logs.")
        return

    steps = [step for step, _ in points]
    rewards = [reward for _, reward in points]

    fig = plt.figure(figsize=(10, 5))
    try:
        plt.plot(steps, rewards, label='Mean Reward (per step)', color='#2ecc71', alpha=0.4)

        # Simple moving average for trend. Require more samples than the
        # window so the trend line has at least two points to draw.
        if window >= 1 and len(rewards) > window:
            sma = [
                sum(rewards[i:i + window]) / window
                for i in range(len(rewards) - window + 1)
            ]
            # Each average is plotted at the step that closes its window.
            plt.plot(steps[window - 1:], sma, label=f'{window}-step Moving Avg', color='#e74c3c', linewidth=2)

        plt.title("CommitGuard — GRPO Training Reward Curve", fontsize=14)
        plt.xlabel("Training Step", fontsize=12)
        plt.ylabel("Mean Reward", fontsize=12)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()

        # savefig does not create directories; make sure the target exists.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_path, dpi=180)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"Training plot saved to {output_path}")
40
+
41
def main():
    """Parse CLI options, load the trainer state JSON and plot its reward curve.

    Options:
        --log-file: Path to the trainer state JSON file.
        --output: Path of the image to write.

    Prints a notice and does nothing else when the log file does not exist.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--log-file", type=str, default="outputs/commitguard-llama-3b-grpo/final/trainer_state.json")
    cli.add_argument("--output", type=str, default="plots/training_reward_curve.png")
    options = cli.parse_args()

    state_file = Path(options.log_file)
    if state_file.exists():
        with open(state_file, "r") as handle:
            trainer_state = json.load(handle)
        # A state file without a "log_history" key is plotted as empty.
        plot_training(trainer_state.get("log_history", []), options.output)
        return

    print(f"Log file {state_file} not found yet. Training might still be in progress.")
56
+
57
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()