File size: 5,705 Bytes
d7ecc62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230508d
 
 
 
 
 
 
d7ecc62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#!/usr/bin/env python3
"""Profile a single training iteration of PAWN.

Runs cProfile on the full train_step + optimizer_step pipeline,
prints comparative stats, and identifies bottlenecks.

Usage:
    uv run python scripts/profile_step.py [--device cuda] [--batch-size 256]
"""

import argparse
import cProfile
import pstats
import time
import io
import torch

# ── PAWN imports ─────────────────────────────────────────────────────────────
from pawn.config import CLMConfig
from pawn.model import PAWNCLM
import chess_engine as engine


def generate_clm_batch(batch_size: int, device: str):
    """Build one CLM batch with the chess engine and place it on ``device``.

    Calls ``engine.generate_clm_batch(batch_size, 256, seed=42)`` (fixed seed,
    so repeated calls request the same data) and keeps the first three of the
    six returned arrays; the other three are discarded.

    Returns:
        A dict with ``"input_ids"`` and ``"targets"`` cast to int64 and
        ``"loss_mask"`` left in the dtype the engine produced, all on
        ``device``.
    """
    ids_np, targets_np, mask_np, _unused_a, _unused_b, _unused_c = (
        engine.generate_clm_batch(batch_size, 256, seed=42)
    )

    input_ids = torch.from_numpy(ids_np).long()
    targets = torch.from_numpy(targets_np).long()
    loss_mask = torch.from_numpy(mask_np)  # dtype intentionally untouched

    batch = {}
    batch["input_ids"] = input_ids.to(device)
    batch["targets"] = targets.to(device)
    batch["loss_mask"] = loss_mask.to(device)
    return batch



def clm_iteration(model, optimizer, scaler, batch, device, use_amp):
    """Run one full CLM train step: forward + loss + backward + optimizer.

    Args:
        model: Module exposing ``forward_train(input_ids, loss_mask, targets)``
            that returns ``(loss, metrics)``.
        optimizer: Optimizer over ``model.parameters()``.
        scaler: Gradient scaler used for loss scaling / unscaling.
        batch: Mapping with ``"input_ids"``, ``"loss_mask"`` and ``"targets"``.
        device: Device the step runs on. May be a bare type (``"cuda"``), an
            indexed string (``"cuda:0"``) or a ``torch.device``.
        use_amp: Whether autocast (mixed precision) is enabled.

    Returns:
        The loss of this step as a Python number (``loss.item()``).
    """
    # autocast takes a device *type*; strip any index so that values such as
    # "cuda:0" (or a torch.device object) are accepted as well.
    device_type = torch.device(device).type

    model.train()
    with torch.amp.autocast(device_type, enabled=use_amp):
        loss, _metrics = model.forward_train(
            batch["input_ids"], batch["loss_mask"], batch["targets"]
        )

    scaler.scale(loss).backward()
    # Gradients must be unscaled before clipping so the 1.0 threshold applies
    # to their true magnitude.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    return loss.item()



def _cuda_sync():
    """Wait for queued CUDA work to finish; do nothing when CUDA is unavailable."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def profile_fn(fn, n_warmup=3, n_profile=10):
    """Run warmup iterations, time ``n_profile`` iterations, then profile one.

    Args:
        fn: Zero-argument callable executing one iteration.
        n_warmup: Number of untimed calls made first.
        n_profile: Number of wall-clock-timed calls.

    Returns:
        ``(prof, wall_times)`` where ``prof`` is a ``cProfile.Profile`` covering
        one additional call of ``fn`` and ``wall_times`` is a list of
        ``n_profile`` durations in seconds.

    Any exception raised by ``fn`` propagates to the caller; the profiler is
    always disabled before it does.
    """
    # Warmup (let CUDA kernels JIT, allocator settle)
    for _ in range(n_warmup):
        fn()
    _cuda_sync()

    # Timed runs (wall clock); synchronise on both sides so asynchronous GPU
    # work is attributed to the iteration that launched it.
    wall_times = []
    for _ in range(n_profile):
        _cuda_sync()
        t0 = time.perf_counter()
        fn()
        _cuda_sync()
        wall_times.append(time.perf_counter() - t0)

    # cProfile run (single iteration for call graph)
    _cuda_sync()
    prof = cProfile.Profile()
    prof.enable()
    try:
        fn()
        # Kept inside the profiled region so the wait shows up in the stats.
        _cuda_sync()
    finally:
        # Never leave the profiler hooked in if the iteration fails.
        prof.disable()

    return prof, wall_times


def print_profile(name: str, prof: cProfile.Profile, wall_times: list[float],
                  top_n: int = 25):
    """Print a titled wall-time summary followed by the top cProfile entries.

    Args:
        name: Title shown in the banner.
        prof: Profile that has already collected data.
        wall_times: Per-iteration durations in seconds; may be empty, in which
            case the wall-time line reports that no runs were timed.
        top_n: Number of rows of cumulative-time stats to print.
    """
    banner = "=" * 70

    print(f"\n{banner}")
    print(f" {name}")
    print(banner)
    if wall_times:
        mean_ms = sum(wall_times) / len(wall_times) * 1000
        min_ms = min(wall_times) * 1000
        max_ms = max(wall_times) * 1000
        print(f" Wall time: {mean_ms:.1f}ms mean | {min_ms:.1f}ms min | {max_ms:.1f}ms max "
              f"({len(wall_times)} runs)")
    else:
        # No timed runs: mean/min/max are undefined, so say so instead of
        # failing on a division by zero.
        print(" Wall time: n/a (0 runs)")
    print("─" * 70)

    # pstats writes to a stream; collect it so it goes out through print().
    buf = io.StringIO()
    stats = pstats.Stats(prof, stream=buf)
    stats.sort_stats("cumulative")
    stats.print_stats(top_n)
    print(buf.getvalue())


def main():
    """Parse CLI options, profile the PAWN training step and print a report.

    Builds a batch and a ``PAWNCLM`` model on the requested device, runs
    ``profile_fn`` over ``clm_iteration``, then prints the cProfile table and
    a summary (parameter count, step times, peak GPU memory, throughput).

    Exits with an argparse usage error when ``--iterations`` or
    ``--batch-size`` is smaller than 1.
    """
    parser = argparse.ArgumentParser(description="Profile PAWN training iteration")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--warmup", type=int, default=3)
    parser.add_argument("--iterations", type=int, default=10)
    parser.add_argument("--no-amp", action="store_true", help="Disable mixed precision")
    args = parser.parse_args()

    # With no timed iterations every statistic below would divide by zero, so
    # reject that up front with a clear message.
    if args.iterations < 1:
        parser.error("--iterations must be at least 1")
    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")

    device = args.device
    bs = args.batch_size
    use_amp = not args.no_amp

    # Parsed form of --device: gives the bare type for AMP helpers and lets
    # the memory counters target the selected GPU rather than the current one.
    torch_device = torch.device(device)
    track_cuda_memory = torch_device.type == "cuda" and torch.cuda.is_available()

    print(f"Device: {device}")
    print(f"Batch size: {bs}")
    print(f"AMP: {use_amp}")
    print(f"Warmup: {args.warmup} | Profile: {args.iterations} iterations")

    # ── Profile PAWN ─────────────────────────────────────────────────────────
    print("\nGenerating data...")
    batch = generate_clm_batch(bs, device)

    print("Setting up PAWN model...")
    cfg = CLMConfig()
    model = PAWNCLM(cfg).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01, betas=(0.9, 0.95))
    scaler = torch.amp.GradScaler(torch_device.type, enabled=use_amp)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"  Params: {n_params:,}")

    if track_cuda_memory:
        torch.cuda.reset_peak_memory_stats(torch_device)

    print("Profiling...")
    prof, times = profile_fn(
        lambda: clm_iteration(model, opt, scaler, batch, device, use_amp),
        n_warmup=args.warmup,
        n_profile=args.iterations,
    )

    peak_mb = 0
    if track_cuda_memory:
        peak_mb = torch.cuda.max_memory_allocated(torch_device) / (1024**2)

    # ── Print results ─────────────────────────────────────────────────────────
    print_profile("PAWN (next-token softmax)", prof, times)

    mean_ms = sum(times) / len(times) * 1000
    print(f"\n{'='*70}")
    print(" SUMMARY")
    print(f"{'='*70}")
    print(f" Parameters:          {n_params:,}")
    print(f" Mean step time (ms): {mean_ms:.1f}")
    print(f" Min step time (ms):  {min(times)*1000:.1f}")
    print(f" Peak GPU memory (MB):{peak_mb:.0f}")
    print(f" Throughput (games/s): {bs/mean_ms*1000:.0f}")
    print()


# Run the profiler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()