# commitguard-env/scripts/plot_training_logs.py
# Author: Nitishkumar-ai
# Commit message: Add scripts for hero case finding, hero details retrieval, and training log plotting
# Commit: 33692a0
import json
import argparse
import matplotlib.pyplot as plt
import os
from pathlib import Path
def plot_training(log_history, output_path):
    """Plot the per-step reward curve from a log history and save it as an image.

    Only entries of ``log_history`` that contain both a ``"reward"`` and a
    ``"step"`` key are plotted; every other entry is ignored. When more than
    five reward points are available, a 5-step simple moving average is drawn
    on top of the raw curve to show the trend.

    Args:
        log_history: Iterable of dict-like log entries.
        output_path: Destination passed to ``plt.savefig``. When it is a
            filesystem path, missing parent directories are created first.

    Returns:
        None. Prints a notice and returns early when no entry carries reward
        data; otherwise prints the location of the saved plot.
    """
    # GRPOTrainer logs 'reward' in the history; keep only entries with both keys.
    points = [
        (entry["step"], entry["reward"])
        for entry in log_history
        if "reward" in entry and "step" in entry
    ]
    if not points:
        print("No reward data found in logs.")
        return

    steps = [step for step, _ in points]
    rewards = [reward for _, reward in points]

    # savefig does not create missing directories, so make sure the target
    # directory exists (only meaningful when output_path is a real path).
    if isinstance(output_path, (str, os.PathLike)):
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    window = 5
    fig = plt.figure(figsize=(10, 5))
    try:
        plt.plot(steps, rewards, label='Mean Reward (per step)', color='#2ecc71', alpha=0.4)

        # Simple moving average for trend
        if len(rewards) > window:
            sma = [
                sum(rewards[i:i + window]) / window
                for i in range(len(rewards) - window + 1)
            ]
            # Each average is aligned with the last step of its window.
            plt.plot(steps[window - 1:], sma, label=f'{window}-step Moving Avg', color='#e74c3c', linewidth=2)

        plt.title("CommitGuard — GRPO Training Reward Curve", fontsize=14)
        plt.xlabel("Training Step", fontsize=12)
        plt.ylabel("Mean Reward", fontsize=12)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        plt.savefig(output_path, dpi=180)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"Training plot saved to {output_path}")
def main():
    """Parse command-line options, load the trainer log file, and plot rewards.

    Options:
        --log-file: Path to the JSON log file to read.
        --output: Path of the image file to write.

    If the log file does not exist, a notice is printed and nothing is
    plotted. Otherwise the file's ``"log_history"`` list (empty when the key
    is absent) is handed to ``plot_training`` together with the output path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--log-file",
        type=str,
        default="outputs/commitguard-llama-3b-grpo/final/trainer_state.json",
    )
    parser.add_argument("--output", type=str, default="plots/training_reward_curve.png")
    args = parser.parse_args()

    log_path = Path(args.log_file)
    if not log_path.exists():
        print(f"Log file {log_path} not found yet. Training might still be in progress.")
        return

    # Decode explicitly as UTF-8 rather than relying on the platform's
    # locale-dependent default encoding.
    with open(log_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    plot_training(data.get("log_history", []), args.output)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()