Spaces:
Runtime error
Runtime error
Commit ·
33692a0
1
Parent(s): 8c862c5
Add scripts for hero case finding, hero details retrieval, and training log plotting
Browse files- scripts/find_hero_case.py +29 -0
- scripts/get_hero_details.py +14 -0
- scripts/plot_training_logs.py +58 -0
scripts/find_hero_case.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
def find_hero_case(
    data_path="data/devign_filtered.jsonl",
    cwes=("CWE-89", "CWE-78", "CWE-22", "CWE-119"),
    limit=5,
):
    """Print a short preview of vulnerable samples matching selected CWEs.

    Reads a JSON Lines file, keeps the records whose ``is_vulnerable`` field
    is truthy and whose ``cwe`` field is one of ``cwes``, prints how many
    matched, and then prints the sample ID, CWE and the first 300 characters
    of the diff for the first ``limit`` matches.

    Args:
        data_path: Path to the JSONL dataset. Defaults to the location that
            used to be hard-coded in this function.
        cwes: CWE identifiers considered interesting.
        limit: Maximum number of matching samples to print in detail.

    Returns:
        None. Everything is written to stdout. If ``data_path`` does not
        exist, "Data not found." is printed and the function returns early.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data_path = Path(data_path)
    if not data_path.exists():
        print("Data not found.")
        return

    # A tuple (not a set) keeps membership tests safe even if a record's
    # "cwe" value happens to be unhashable, e.g. a list.
    wanted = tuple(cwes)

    # Filter while streaming so the whole dataset is never held in memory.
    hero_candidates = []
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (commonly a trailing newline).
                continue
            s = json.loads(line)
            # .get() so records missing a field are skipped, not fatal.
            if s.get("is_vulnerable") and s.get("cwe") in wanted:
                hero_candidates.append(s)

    print(f"Found {len(hero_candidates)} hero candidates.")

    for s in hero_candidates[:limit]:
        diff = s.get("diff") or ""
        print("-" * 40)
        print(f"Sample ID: {s.get('sample_id')}")
        print(f"CWE: {s['cwe']}")
        print(f"Diff Context:\n{diff[:300]}...")
        print("-" * 40)
|
| 27 |
+
|
| 28 |
+
if __name__ == "__main__":
|
| 29 |
+
find_hero_case()
|
scripts/get_hero_details.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
def get_hero_details(sample_id, data_path="data/devign_filtered.jsonl"):
    """Pretty-print the first dataset record whose ``sample_id`` matches.

    Scans a JSON Lines file line by line and prints the first record whose
    ``sample_id`` field equals ``sample_id`` as indented JSON, then stops.

    Args:
        sample_id: Value to compare against each record's ``sample_id``.
        data_path: Path to the JSONL dataset. Defaults to the location that
            used to be hard-coded in this function.

    Returns:
        None. Output goes to stdout. A missing file prints
        "Data not found." (the same message ``find_hero_case`` uses), and a
        sample that is not present is reported rather than silently ignored.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data_path = Path(data_path)
    if not data_path.exists():
        print("Data not found.")
        return

    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (commonly a trailing newline).
                continue
            s = json.loads(line)
            # .get() so a record without an ID is skipped, not fatal.
            if s.get("sample_id") == sample_id:
                print(json.dumps(s, indent=2))
                return

    # Make the "no match" outcome visible instead of printing nothing.
    print(f"Sample {sample_id} not found in {data_path}.")
|
| 12 |
+
|
| 13 |
+
if __name__ == "__main__":
    import sys

    # Look up the sample ID given as the first command-line argument; with no
    # argument, fall back to the sample this script originally inspected.
    get_hero_details(
        sys.argv[1]
        if len(sys.argv) > 1
        else "d9a3b33d2c9f996537b7f1d0246dee2d0120cefb"
    )
|
scripts/plot_training_logs.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import argparse
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
def plot_training(log_history, output_path, window=5):
    """Plot per-step reward and its moving average, and save the figure.

    Collects ``step``/``reward`` pairs from the log entries that contain
    both keys, draws the raw reward curve, overlays a simple moving average
    when there are more than ``window`` points, and writes the figure to
    ``output_path``.

    Args:
        log_history: Iterable of dict-like log entries. Entries lacking
            either ``"reward"`` or ``"step"`` are ignored.
        output_path: Destination passed to ``plt.savefig``. When it is a
            string or path-like object, its parent directory is created if
            it is missing.
        window: Number of consecutive points averaged for the trend line.

    Returns:
        None. Prints a confirmation after saving, or
        "No reward data found in logs." when nothing can be plotted.
    """
    # Extract rewards and steps.
    # GRPOTrainer logs 'reward' in the history.
    steps = []
    rewards = []
    for entry in log_history:
        if "reward" in entry and "step" in entry:
            steps.append(entry["step"])
            rewards.append(entry["reward"])

    if not steps:
        print("No reward data found in logs.")
        return

    fig = plt.figure(figsize=(10, 5))
    try:
        plt.plot(steps, rewards, label='Mean Reward (per step)', color='#2ecc71', alpha=0.4)

        # Simple moving average for trend; each average is plotted at the
        # step that ends its window, hence the steps[window-1:] offset.
        if window >= 1 and len(rewards) > window:
            sma = [sum(rewards[i:i+window])/window for i in range(len(rewards)-window+1)]
            plt.plot(steps[window-1:], sma, label=f'{window}-step Moving Avg', color='#e74c3c', linewidth=2)

        plt.title("CommitGuard — GRPO Training Reward Curve", fontsize=14)
        plt.xlabel("Training Step", fontsize=12)
        plt.ylabel("Mean Reward", fontsize=12)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()

        # savefig does not create missing directories, so make sure the
        # target folder (e.g. the default "plots/") exists first.
        if isinstance(output_path, (str, os.PathLike)):
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)

        plt.savefig(output_path, dpi=180)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"Training plot saved to {output_path}")
|
| 40 |
+
|
| 41 |
+
def main():
    """Command-line entry point: load a trainer state file and plot rewards.

    Parses ``--log-file`` and ``--output``, and if the log file exists, loads
    it as JSON and passes its ``log_history`` list (empty when the key is
    absent) to ``plot_training`` together with the output path. When the file
    does not exist, a notice is printed and nothing is plotted.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--log-file",
        type=str,
        default="outputs/commitguard-llama-3b-grpo/final/trainer_state.json",
    )
    cli.add_argument(
        "--output",
        type=str,
        default="plots/training_reward_curve.png",
    )
    options = cli.parse_args()

    state_file = Path(options.log_file)
    if state_file.exists():
        with open(state_file, "r") as handle:
            trainer_state = json.load(handle)
        history = trainer_state.get("log_history", [])
        plot_training(history, options.output)
    else:
        print(f"Log file {state_file} not found yet. Training might still be in progress.")
|
| 56 |
+
|
| 57 |
+
if __name__ == "__main__":
|
| 58 |
+
main()
|