#!/usr/bin/env python3
"""Profile a single training iteration of PAWN.

Runs cProfile on the full train_step + optimizer_step pipeline, prints
comparative stats, and identifies bottlenecks.

Usage:
    uv run python scripts/profile_step.py [--device cuda] [--batch-size 256]
"""
import argparse
import cProfile
import io
import pstats
import time

import torch

# ── PAWN imports ─────────────────────────────────────────────────────────────
from pawn.config import CLMConfig
from pawn.model import PAWNCLM
import chess_engine as engine


def _device_type(device: str) -> str:
    """Return the bare device type ("cuda", "cpu"), stripping any ":N" index.

    torch.amp.autocast and torch.amp.GradScaler accept only the device *type*,
    so passing e.g. "cuda:1" straight through would raise.
    """
    return device.split(":", 1)[0]


def _maybe_sync(device: str) -> None:
    """Synchronize CUDA, but only when profiling actually targets a CUDA device."""
    if _device_type(device) == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize()


def generate_clm_batch(batch_size: int, device: str) -> dict[str, torch.Tensor]:
    """Generate a CLM batch and move it to *device*.

    Returns a dict with "input_ids" and "targets" cast to int64, plus
    "loss_mask" kept in the dtype the engine produced. Shapes are whatever
    engine.generate_clm_batch emits for (batch_size, 256) — presumably
    (batch_size, 256); confirm against the engine binding.
    """
    # Fixed seed so every profiling run sees identical data.
    input_ids, targets, loss_mask, _mid, _gl, _tc = \
        engine.generate_clm_batch(batch_size, 256, seed=42)
    return {
        "input_ids": torch.from_numpy(input_ids).long().to(device),
        "targets": torch.from_numpy(targets).long().to(device),
        "loss_mask": torch.from_numpy(loss_mask).to(device),
    }


def clm_iteration(model, optimizer, scaler, batch, device, use_amp):
    """One full CLM train step: forward + loss + backward + optimizer.

    Args:
        model: module exposing forward_train(input_ids, loss_mask, targets)
            that returns (loss, metrics).
        optimizer: the optimizer stepping the model's parameters.
        scaler: torch.amp.GradScaler (may be disabled).
        batch: dict from generate_clm_batch().
        device: device string, e.g. "cuda", "cuda:0", or "cpu".
        use_amp: enable autocast mixed precision.

    Returns:
        The scalar loss value as a Python float.
    """
    model.train()
    # autocast wants the device *type*; accept "cuda:0"-style strings too.
    with torch.amp.autocast(_device_type(device), enabled=use_amp):
        loss, _metrics = model.forward_train(
            batch["input_ids"], batch["loss_mask"], batch["targets"]
        )
    scaler.scale(loss).backward()
    # Unscale before clipping so the norm is measured on real gradients.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    return loss.item()


def profile_fn(fn, n_warmup=3, n_profile=10, device="cuda"):
    """Run warmup iterations of *fn*, then time and cProfile it.

    CUDA is synchronized around each timed run (when *device* is CUDA) so
    wall times measure completed kernel work rather than launch latency.

    Args:
        fn: zero-argument callable to profile.
        n_warmup: untimed warmup calls (lets CUDA kernels JIT, allocator settle).
        n_profile: number of timed calls.
        device: device string the work runs on; controls CUDA syncing.

    Returns:
        (cProfile.Profile, list of per-run wall times in seconds).
    """
    # Warmup — excluded from all measurements.
    for _ in range(n_warmup):
        fn()
    _maybe_sync(device)

    # Timed runs (wall clock).
    wall_times = []
    for _ in range(n_profile):
        _maybe_sync(device)
        t0 = time.perf_counter()
        fn()
        _maybe_sync(device)
        wall_times.append(time.perf_counter() - t0)

    # cProfile run (single iteration — enough for the call graph, and keeps
    # profiler overhead out of the wall-clock numbers above).
    _maybe_sync(device)
    prof = cProfile.Profile()
    prof.enable()
    fn()
    _maybe_sync(device)
    prof.disable()

    return prof, wall_times


def print_profile(name: str, prof: cProfile.Profile,
                  wall_times: list[float], top_n: int = 25):
    """Print a wall-time summary header plus the top-N cumulative-time stats."""
    mean_ms = sum(wall_times) / len(wall_times) * 1000
    min_ms = min(wall_times) * 1000
    max_ms = max(wall_times) * 1000
    print(f"\n{'='*70}")
    print(f" {name}")
    print(f"{'='*70}")
    print(f" Wall time: {mean_ms:.1f}ms mean | {min_ms:.1f}ms min | {max_ms:.1f}ms max "
          f"({len(wall_times)} runs)")
    print(f"{'─'*70}")
    s = io.StringIO()
    ps = pstats.Stats(prof, stream=s)
    ps.sort_stats("cumulative")
    ps.print_stats(top_n)
    print(s.getvalue())


def main():
    parser = argparse.ArgumentParser(description="Profile PAWN training iteration")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--warmup", type=int, default=3)
    parser.add_argument("--iterations", type=int, default=10)
    parser.add_argument("--no-amp", action="store_true", help="Disable mixed precision")
    args = parser.parse_args()

    device = args.device
    device_type = _device_type(device)
    # Only track CUDA memory / synchronize when the work actually runs on CUDA.
    on_cuda = device_type == "cuda" and torch.cuda.is_available()
    bs = args.batch_size
    use_amp = not args.no_amp

    print(f"Device: {device}")
    print(f"Batch size: {bs}")
    print(f"AMP: {use_amp}")
    print(f"Warmup: {args.warmup} | Profile: {args.iterations} iterations")

    # ── Profile PAWN ─────────────────────────────────────────────────────────
    print("\nGenerating data...")
    batch = generate_clm_batch(bs, device)

    print("Setting up PAWN model...")
    cfg = CLMConfig()
    model = PAWNCLM(cfg).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01,
                            betas=(0.9, 0.95))
    # GradScaler wants the device *type* ("cuda"/"cpu"), not "cuda:0".
    scaler = torch.amp.GradScaler(device_type, enabled=use_amp)
    n_params = sum(p.numel() for p in model.parameters())
    print(f" Params: {n_params:,}")

    if on_cuda:
        torch.cuda.reset_peak_memory_stats()

    print("Profiling...")
    prof, times = profile_fn(
        lambda: clm_iteration(model, opt, scaler, batch, device, use_amp),
        n_warmup=args.warmup,
        n_profile=args.iterations,
        device=device,
    )

    peak_mb = 0
    if on_cuda:
        peak_mb = torch.cuda.max_memory_allocated() / (1024**2)

    # ── Print results ─────────────────────────────────────────────────────────
    print_profile("PAWN (next-token softmax)", prof, times)

    mean_ms = sum(times) / len(times) * 1000
    print(f"\n{'='*70}")
    print(" SUMMARY")
    print(f"{'='*70}")
    print(f" Parameters: {n_params:,}")
    print(f" Mean step time (ms): {mean_ms:.1f}")
    print(f" Min step time (ms): {min(times)*1000:.1f}")
    print(f" Peak GPU memory (MB):{peak_mb:.0f}")
    print(f" Throughput (games/s): {bs/mean_ms*1000:.0f}")
    print()


if __name__ == "__main__":
    main()