Spaces:
Running on A10G
Running on A10G
File size: 1,877 Bytes
95cbc5b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 | import json
import argparse
import matplotlib.pyplot as plt
import os
from pathlib import Path
def plot_training(log_history, output_path):
    """Plot the logged reward per training step and save the figure.

    Draws the raw reward series and, when more than five points are
    available, a 5-step simple moving average on top of it.

    Args:
        log_history: Iterable of dict log entries. Only entries that contain
            both a "reward" and a "step" key are plotted; all others are
            ignored.
        output_path: Destination of the saved figure. When it is a filesystem
            path, missing parent directories are created first.

    Returns:
        None. If no entry carries reward data, a notice is printed and
        nothing is written.
    """
    # Extract rewards and steps
    # GRPOTrainer logs 'reward' in the history
    steps = []
    rewards = []
    for entry in log_history:
        if "reward" in entry and "step" in entry:
            steps.append(entry["step"])
            rewards.append(entry["reward"])

    if not steps:
        print("No reward data found in logs.")
        return

    fig = plt.figure(figsize=(10, 5))
    try:
        plt.plot(steps, rewards, label='Mean Reward (per step)', color='#2ecc71', alpha=0.4)

        # Simple moving average for trend
        window = 5
        if len(rewards) > window:
            sma = [
                sum(rewards[i:i + window]) / window
                for i in range(len(rewards) - window + 1)
            ]
            # The first average covers steps[0:window], so it is drawn at the
            # last step of that window.
            plt.plot(steps[window - 1:], sma, label=f'{window}-step Moving Avg', color='#e74c3c', linewidth=2)

        plt.title("CommitGuard — GRPO Training Reward Curve", fontsize=14)
        plt.xlabel("Training Step", fontsize=12)
        plt.ylabel("Mean Reward", fontsize=12)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()

        # savefig does not create directories; make sure the target folder
        # exists so a fresh checkout without e.g. "plots/" still works.
        if isinstance(output_path, (str, os.PathLike)):
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_path, dpi=180)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"Training plot saved to {output_path}")
def main():
    """Command-line entry point: read a trainer state file and plot rewards.

    Parses ``--log-file`` and ``--output``, loads the JSON log file and passes
    its "log_history" list (empty if the key is absent) to ``plot_training``.
    If the log file does not exist, a notice is printed and nothing is
    plotted.
    """
    parser = argparse.ArgumentParser(
        description="Plot the reward curve from a trainer state JSON file."
    )
    parser.add_argument(
        "--log-file",
        type=str,
        default="outputs/commitguard-llama-3b-grpo/final/trainer_state.json",
        help="Path to the JSON file containing a 'log_history' list.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="plots/training_reward_curve.png",
        help="Where to save the generated plot.",
    )
    args = parser.parse_args()

    log_path = Path(args.log_file)
    if not log_path.exists():
        print(f"Log file {log_path} not found yet. Training might still be in progress.")
        return

    # Decode explicitly as UTF-8 instead of relying on the platform's
    # locale-dependent default text encoding.
    with open(log_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    plot_training(data.get("log_history", []), args.output)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|