Spaces:
Runtime error
Runtime error
| import json | |
| import argparse | |
| import matplotlib.pyplot as plt | |
| import os | |
| from pathlib import Path | |
def plot_training(log_history, output_path, *, window=5):
    """Plot reward vs. training step from a trainer log history and save it.

    Collects every entry of ``log_history`` that has both a ``"reward"`` and a
    ``"step"`` key, plots reward against step, overlays a simple moving
    average when there are more than ``window`` points, and writes the figure
    to ``output_path``.

    Args:
        log_history: Iterable of dicts, such as the ``"log_history"`` list
            loaded from a ``trainer_state.json`` file. Entries missing either
            ``"reward"`` or ``"step"`` are ignored.
        output_path: Destination image path (str or path-like). Missing parent
            directories are created.
        window: Width of the moving-average window (keyword-only, default 5).

    Returns:
        None. When no reward entries are found, prints a message and returns
        without creating any figure, file, or directory.

    Raises:
        ValueError: If ``window`` is less than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window!r}")

    # Only entries carrying both keys form a point on the curve.
    steps = []
    rewards = []
    for entry in log_history:
        if "reward" in entry and "step" in entry:
            steps.append(entry["step"])
            rewards.append(entry["reward"])

    if not steps:
        print("No reward data found in logs.")
        return

    fig = plt.figure(figsize=(10, 5))
    try:
        plt.plot(steps, rewards, label='Mean Reward (per step)', color='#2ecc71', alpha=0.4)

        # Simple moving average for trend; only drawn when there are more
        # points than the window so the smoothed line has at least two points.
        if len(rewards) > window:
            sma = [
                sum(rewards[i:i + window]) / window
                for i in range(len(rewards) - window + 1)
            ]
            # Each average is aligned with the last step of its window.
            plt.plot(steps[window - 1:], sma, label=f'{window}-step Moving Avg',
                     color='#e74c3c', linewidth=2)

        plt.title("CommitGuard — GRPO Training Reward Curve", fontsize=14)
        plt.xlabel("Training Step", fontsize=12)
        plt.ylabel("Mean Reward", fontsize=12)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()

        # savefig does not create directories; the default output lives under
        # a "plots/" folder that may not exist yet.
        out = Path(output_path)
        out.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(out, dpi=180)
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    print(f"Training plot saved to {output_path}")
def main(argv=None):
    """Command-line entry point: plot the reward curve from a trainer state file.

    Reads ``--log-file`` (a ``trainer_state.json`` file, or a directory
    containing one), extracts its ``"log_history"`` list, and passes it to
    ``plot_training`` together with ``--output``. A missing, unparsable, or
    unexpectedly shaped file produces a message instead of a traceback.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        None.
    """
    parser = argparse.ArgumentParser(
        description="Plot the GRPO training reward curve from a trainer_state.json file."
    )
    parser.add_argument(
        "--log-file",
        type=str,
        default="outputs/commitguard-llama-3b-grpo/final/trainer_state.json",
        help="Path to trainer_state.json, or a directory containing it.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="plots/training_reward_curve.png",
        help="Where to write the plot image.",
    )
    args = parser.parse_args(argv)

    log_path = Path(args.log_file)
    # Also accept a checkpoint/output directory rather than the JSON file itself.
    if log_path.is_dir():
        log_path = log_path / "trainer_state.json"
    if not log_path.exists():
        print(f"Log file {log_path} not found yet. Training might still be in progress.")
        return

    try:
        # JSON is UTF-8; don't depend on the platform's default encoding.
        with open(log_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        # A truncated file (e.g. one still being written) is treated like the
        # "not found yet" case above rather than crashing.
        print(f"Could not parse {log_path} as JSON ({err}). The file may be incomplete.")
        return

    if not isinstance(data, dict):
        print(f"Unexpected format in {log_path}: expected a JSON object with a 'log_history' list.")
        return

    plot_training(data.get("log_history", []), args.output)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()