#!/usr/bin/env python3
"""Profile a single training iteration of PAWN.
Runs cProfile on the full train_step + optimizer_step pipeline,
prints comparative stats, and identifies bottlenecks.
Usage:
uv run python scripts/profile_step.py [--device cuda] [--batch-size 256]
"""
import argparse
import cProfile
import pstats
import time
import io
import torch
# ββ PAWN imports βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
from pawn.config import CLMConfig
from pawn.model import PAWNCLM
import chess_engine as engine
def generate_clm_batch(batch_size: int, device: str):
    """Build one fixed-seed CLM training batch and place it on *device*.

    The native engine returns six numpy arrays; only the three the train
    step consumes (token ids, targets, loss mask) are kept.  Ids and
    targets are widened to int64 for embedding/cross-entropy indexing.
    """
    # seed=42 keeps the batch identical across runs, so timings are comparable.
    outputs = engine.generate_clm_batch(batch_size, 256, seed=42)
    ids_np, tgt_np, mask_np = outputs[0], outputs[1], outputs[2]
    host_batch = {
        "input_ids": torch.from_numpy(ids_np).long(),
        "targets": torch.from_numpy(tgt_np).long(),
        "loss_mask": torch.from_numpy(mask_np),
    }
    return {key: tensor.to(device) for key, tensor in host_batch.items()}
def clm_iteration(model, optimizer, scaler, batch, device, use_amp):
"""One full CLM train step: forward + loss + backward + optimizer."""
model.train()
with torch.amp.autocast(device, enabled=use_amp):
loss, _metrics = model.forward_train(
batch["input_ids"], batch["loss_mask"], batch["targets"]
)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
return loss.item()
def profile_fn(fn, n_warmup=3, n_profile=10):
    """Benchmark *fn*: warmup, wall-clock timing, then one cProfile pass.

    Returns a tuple of (cProfile.Profile, list of per-iteration wall times
    in seconds).
    """
    def _sync():
        # Drain queued CUDA work so wall-clock boundaries are honest;
        # no-op on CPU-only machines.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # Warmup: lets kernels JIT-compile and the allocator reach steady state.
    for _ in range(n_warmup):
        fn()
    _sync()

    # Timed iterations (wall clock, synchronized on both sides).
    wall_times = []
    for _ in range(n_profile):
        _sync()
        started = time.perf_counter()
        fn()
        _sync()
        wall_times.append(time.perf_counter() - started)

    # One additional iteration under cProfile for the call-graph breakdown.
    _sync()
    profiler = cProfile.Profile()
    profiler.enable()
    fn()
    _sync()
    profiler.disable()
    return profiler, wall_times
def print_profile(name: str, prof: cProfile.Profile, wall_times: list[float],
                  top_n: int = 25):
    """Print a banner with wall-time stats, then the top_n cumulative
    cProfile entries for *prof*."""
    times_ms = [t * 1000 for t in wall_times]
    mean_ms = sum(times_ms) / len(times_ms)
    rule = "=" * 70
    print(f"\n{rule}")
    print(f" {name}")
    print(f"{rule}")
    print(f" Wall time: {mean_ms:.1f}ms mean | {min(times_ms):.1f}ms min | {max(times_ms):.1f}ms max "
          f"({len(wall_times)} runs)")
    # NOTE(review): 'β' here looks like a mis-encoded box-drawing '─';
    # reproduced as-is to keep output identical — confirm intended glyph.
    print(f"{'β'*70}")
    buf = io.StringIO()
    stats = pstats.Stats(prof, stream=buf)
    stats.sort_stats("cumulative")
    stats.print_stats(top_n)
    print(buf.getvalue())
def main():
    """CLI entry point: parse args, build model and data, profile one
    training step, and print a summary."""
    parser = argparse.ArgumentParser(description="Profile PAWN training iteration")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--warmup", type=int, default=3)
    parser.add_argument("--iterations", type=int, default=10)
    parser.add_argument("--no-amp", action="store_true", help="Disable mixed precision")
    args = parser.parse_args()

    device, bs = args.device, args.batch_size
    use_amp = not args.no_amp
    print(f"Device: {device}")
    print(f"Batch size: {bs}")
    print(f"AMP: {use_amp}")
    print(f"Warmup: {args.warmup} | Profile: {args.iterations} iterations")

    # Data and model setup.
    print("\nGenerating data...")
    batch = generate_clm_batch(bs, device)
    print("Setting up PAWN model...")
    cfg = CLMConfig()
    model = PAWNCLM(cfg).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01, betas=(0.9, 0.95))
    scaler = torch.amp.GradScaler(device, enabled=use_amp)
    n_params = sum(p.numel() for p in model.parameters())
    print(f" Params: {n_params:,}")

    # Reset the peak-memory counter so the reported figure covers only
    # the profiled iterations.
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

    print("Profiling...")
    prof, times = profile_fn(
        lambda: clm_iteration(model, opt, scaler, batch, device, use_amp),
        n_warmup=args.warmup,
        n_profile=args.iterations,
    )
    if torch.cuda.is_available():
        peak_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
    else:
        peak_mb = 0

    # Report.
    print_profile("PAWN (next-token softmax)", prof, times)
    mean_ms = sum(times) / len(times) * 1000
    print(f"\n{'='*70}")
    print(f" SUMMARY")
    print(f"{'='*70}")
    print(f" Parameters: {n_params:,}")
    print(f" Mean step time (ms): {mean_ms:.1f}")
    print(f" Min step time (ms): {min(times)*1000:.1f}")
    print(f" Peak GPU memory (MB):{peak_mb:.0f}")
    print(f" Throughput (games/s): {bs/mean_ms*1000:.0f}")
    print()
# Stray "|" extraction residue after main() removed — it was a syntax error.
if __name__ == "__main__":
    main()