# File size: 1,877 Bytes
# 95cbc5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import json
import argparse
import matplotlib.pyplot as plt
import os
from pathlib import Path

def plot_training(log_history, output_path):
    """Plot the per-step reward found in a trainer log history and save it.

    Args:
        log_history: Iterable of dict log entries (``None`` is treated as
            empty). Only entries containing both a "reward" and a "step"
            key are plotted; every other entry is skipped.
        output_path: Destination of the image file. Missing parent
            directories are created before saving.

    Returns:
        None. When no entry carries reward data a notice is printed and
        nothing is plotted or written.
    """
    # GRPOTrainer logs 'reward' in the history
    points = [
        (entry["step"], entry["reward"])
        for entry in (log_history or [])
        if "reward" in entry and "step" in entry
    ]
    if not points:
        print("No reward data found in logs.")
        return

    steps = [step for step, _ in points]
    rewards = [reward for _, reward in points]

    fig = plt.figure(figsize=(10, 5))
    try:
        plt.plot(steps, rewards, label='Mean Reward (per step)', color='#2ecc71', alpha=0.4)

        # Simple moving average for trend; only drawn once there are more
        # points than the window, so the trend line has at least two points.
        window = 5
        if len(rewards) > window:
            sma = [
                sum(rewards[i:i + window]) / window
                for i in range(len(rewards) - window + 1)
            ]
            # Each average is plotted at the last step of its window.
            plt.plot(steps[window - 1:], sma, label=f'{window}-step Moving Avg', color='#e74c3c', linewidth=2)

        plt.title("CommitGuard — GRPO Training Reward Curve", fontsize=14)
        plt.xlabel("Training Step", fontsize=12)
        plt.ylabel("Mean Reward", fontsize=12)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()

        # savefig does not create directories, so make sure the target exists.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_path, dpi=180)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"Training plot saved to {output_path}")

def main():
    """Parse CLI arguments, load the trainer state JSON, and plot its rewards.

    Reads the file given by ``--log-file``, takes its "log_history" entry,
    and passes it together with ``--output`` to ``plot_training``. Prints a
    notice and returns without plotting when the log file is missing or
    does not contain valid JSON.
    """
    parser = argparse.ArgumentParser(
        description="Plot the reward curve recorded in a trainer_state.json file."
    )
    parser.add_argument(
        "--log-file",
        type=str,
        default="outputs/commitguard-llama-3b-grpo/final/trainer_state.json",
        help="Path to the trainer state JSON file containing 'log_history'.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="plots/training_reward_curve.png",
        help="Path of the image file to write.",
    )
    args = parser.parse_args()

    log_path = Path(args.log_file)
    if not log_path.exists():
        print(f"Log file {log_path} not found yet. Training might still be in progress.")
        return

    try:
        # Explicit encoding: the platform default is locale-dependent.
        with open(log_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as exc:
        # A file that is still being written can be truncated mid-document.
        print(f"Log file {log_path} is not valid JSON ({exc}). Training might still be in progress.")
        return

    # Tolerate a JSON root that is not an object by treating it as "no logs".
    log_history = data.get("log_history", []) if isinstance(data, dict) else []
    plot_training(log_history, args.output)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()