rtferraz committed
Commit 680a32f · verified · 1 Parent(s): c382430

Add full readable training script

Files changed (1)
  1. train_gpt2.py +984 -0
train_gpt2.py ADDED
@@ -0,0 +1,984 @@
"""
Parameter Golf — Competitive Submission
========================================
Key innovations targeting top-of-leaderboard (< 1.08 BPB):

1. SP8192 Vocabulary: 8192-token SentencePiece tokenizer for better BPB
   efficiency. A larger vocab means fewer tokens per byte of text, so the
   same per-token loss yields a lower BPB.

2. Parallel Residuals (PAF): Attention and MLP run in parallel on the same
   normalized input, saving one norm and improving information flow.
   x = x + attn(norm(x)) + mlp(norm(x))   [GPT-J / PaLM style]

3. 3-Layer Depth Recurrence: 3 unique transformer blocks looped multiple
   times. The blocks recur K times at train time and 2K at eval (free
   test-time compute). The whole 3-block stack participates in every
   recurrence.

4. Score-First TTT (Test-Time Training): At eval, adapt each block's MLP
   down-projection weights chunk-by-chunk with a local reconstruction loss.
   This stays legal (strictly causal): every chunk is scored before any
   update derived from it is applied. Implements the In-Place TTT mechanism
   from arxiv:2604.06169.

5. Int6 GPTQ Post-Training Quantization with SDClip:
   - Train in full precision (bf16/fp32)
   - After training, quantize all weight matrices to int6 using GPTQ
   - Std-based clipping (SDClip) before quantization reduces outlier impact
   - Embeddings in GPTQ int8 with SDClip
   - ~1.5x more effective parameters vs int8 in the same 16MB budget

6. MuonEq-R: Muon optimizer with equalized learning rates (updates scaled
   by sqrt(max(1, fan_out / fan_in))) and weight decay regularization.

7. QK-Gain 5.25: High gain on the QK product prevents attention entropy
   collapse at small model dimensions.

8. Residual mixing with x0 anchor preserved from baseline.

Architecture:
    SP8192 vocab, d_model=768, 12 heads / 4 KV heads, MLP 4x
    3 unique blocks × 8 recurrences = 24 effective layers (train)
    3 unique blocks × 16 recurrences = 48 effective layers (eval)

Run: torchrun --standalone --nproc_per_node=8 train_gpt2.py
"""
from __future__ import annotations

import copy
import glob
import io
import math
import os
import random
import subprocess
import sys
import time
import uuid
import zlib
from pathlib import Path

import numpy as np
import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.parallel import DistributedDataParallel as DDP

# ─────────────────────────────────────────────────────────────
# HYPERPARAMETERS
# ─────────────────────────────────────────────────────────────

class Hyperparameters:
    # Data paths — SP8192 tokenizer and matching data
    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp8192")
    train_files = os.path.join(data_path, "fineweb_train_*.bin")
    val_files = os.path.join(data_path, "fineweb_val_*.bin")
    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model")
    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
    seed = int(os.environ.get("SEED", 1337))

    # Validation
    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200))

    # Training
    iterations = int(os.environ.get("ITERATIONS", 20000))
    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500))
    warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024))
    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))

    # Model — Parallel Residual Recurrent
    vocab_size = int(os.environ.get("VOCAB_SIZE", 8192))
    model_dim = int(os.environ.get("MODEL_DIM", 768))
    num_heads = int(os.environ.get("NUM_HEADS", 12))
    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
    mlp_mult = int(os.environ.get("MLP_MULT", 4))
    num_unique_layers = int(os.environ.get("NUM_UNIQUE_LAYERS", 3))
    num_recurrences = int(os.environ.get("NUM_RECURRENCES", 8))
    num_eval_recurrences = int(os.environ.get("NUM_EVAL_RECURRENCES", 0))  # 0 = auto (2×)
    rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))
    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 5.25))

    # Sliding window eval
    sw_stride = int(os.environ.get("SW_STRIDE", 64))
    sw_seq_len = int(os.environ.get("SW_SEQ_LEN", 1024))

    # Test-Time Training (TTT)
    ttt_enabled = int(os.environ.get("TTT_ENABLED", 1))  # 1 = enable at eval
    ttt_lr = float(os.environ.get("TTT_LR", 0.01))
    ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 64))
    ttt_layers = os.environ.get("TTT_LAYERS", "all")  # "all" or comma-sep indices

    # Optimizer
    embed_lr = float(os.environ.get("EMBED_LR", 0.05))
    matrix_lr = float(os.environ.get("MATRIX_LR", 0.04))
    scalar_lr = float(os.environ.get("SCALAR_LR", 0.04))
    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
    muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.09))
    beta1 = float(os.environ.get("BETA1", 0.9))
    beta2 = float(os.environ.get("BETA2", 0.95))
    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))

    # GPTQ quantization config
    gptq_bits = int(os.environ.get("GPTQ_BITS", 6))
    gptq_group_size = int(os.environ.get("GPTQ_GROUP_SIZE", 128))
    sdclip_nstd = float(os.environ.get("SDCLIP_NSTD", 2.5))

    # SWA/EMA
    swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4))


# ─────────────────────────────────────────────────────────────
# MUON OPTIMIZER (MuonEq-R variant)
# ─────────────────────────────────────────────────────────────

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X
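
# Illustrative sanity check for the iteration above (not executed by the
# training run; a sketch assuming a CUDA device is available): Newton-Schulz
# approximately orthogonalizes G, pushing its singular values toward 1.
#
#     G = torch.randn(256, 128, device="cuda")
#     X = zeropower_via_newtonschulz5(G, steps=5).float()
#     err = (X.T @ X - torch.eye(128, device="cuda")).abs().max()
#     # err stays modest (the iteration runs in bf16), but the singular
#     # values of X land close enough to 1 for Muon's purposes.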


class Muon(torch.optim.Optimizer):
    """MuonEq-R: Muon with equalized scaling and weight decay."""
    def __init__(self, params, lr: float, momentum: float, backend_steps: int,
                 weight_decay: float = 0.0, nesterov: bool = True):
        super().__init__(params, dict(lr=lr, momentum=momentum,
                                      backend_steps=backend_steps,
                                      weight_decay=weight_decay,
                                      nesterov=nesterov))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0
        for group in self.param_groups:
            params = group["params"]
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            weight_decay = group["weight_decay"]
            nesterov = group["nesterov"]
            # Each rank orthogonalizes the params it owns (round-robin) and
            # writes them into a shared flat buffer; a single all-reduce then
            # gives every rank every update.
            total = sum(int(p.numel()) for p in params)
            flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16)
            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    if weight_decay != 0.0:
                        g = g + weight_decay * p.data.to(g.dtype)
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # MuonEq-R: equalize update RMS across shapes by scaling
                    # with sqrt(max(1, fan_out / fan_in))
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    flat[curr: curr + p.numel()] = g.reshape(-1)
                # Advance the offset for every param so slots stay aligned
                # across ranks (non-owned slots stay zero until the all-reduce)
                curr += p.numel()
            if distributed:
                dist.all_reduce(flat, op=dist.ReduceOp.SUM)
            curr = 0
            for p in params:
                g = flat[curr: curr + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(g, alpha=-lr)
                curr += p.numel()
        return loss
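
# Worked example of the MuonEq-R scale factor, derived from the default config:
# the MLP up-projection weight (MLP.fc) has shape (3072, 768), so
# fan_out / fan_in = 4 and the factor is 4 ** 0.5 = 2; the down-projection
# (MLP.proj) has shape (768, 3072), so the factor is max(1, 0.25) ** 0.5 = 1.
# Tall matrices get proportionally larger updates, equalizing per-direction
# update magnitude across shapes.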


# ─────────────────────────────────────────────────────────────
# BPB EVALUATION UTILITIES
# ─────────────────────────────────────────────────────────────

def build_sentencepiece_luts(sp, vocab_size, device):
    """Per-token lookup tables for byte accounting:
    bb = UTF-8 bytes of the piece (leading "▁" stripped),
    hs = piece starts with "▁" (may carry one extra space byte),
    ib = token is a boundary (control/unknown/unused id)."""
    sv = int(sp.vocab_size())
    sz = max(sv, vocab_size)
    bb = np.zeros(sz, dtype=np.int16)
    hs = np.zeros(sz, dtype=bool)
    ib = np.ones(sz, dtype=bool)
    for tid in range(sv):
        if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid):
            continue
        ib[tid] = False
        if sp.is_byte(tid):
            bb[tid] = 1
            continue
        piece = sp.id_to_piece(tid)
        if piece.startswith("\u2581"):
            hs[tid] = True
            piece = piece[1:]
        bb[tid] = len(piece.encode("utf-8"))
    return (torch.tensor(bb, dtype=torch.int16, device=device),
            torch.tensor(hs, dtype=torch.bool, device=device),
            torch.tensor(ib, dtype=torch.bool, device=device))
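
# Byte-accounting example (illustrative): the piece "▁the" has bb = 3 ("the")
# and hs = True. When it follows a normal token, the eval below charges
# 3 + 1 = 4 bytes ("the" plus the preceding space); when it follows a boundary
# token (ib = True, a control/unknown id), the space byte is not charged,
# since no space appears in the decoded text there.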


def eval_val_sliding_window(args, model, rank, world_size, device,
                            val_tokens, base_bytes_lut, has_space_lut, is_boundary_lut,
                            use_ttt=False):
    """Sliding-window BPB: each length-sw_seq_len window scores only its last
    sw_stride tokens, so every scored token sees at least
    sw_seq_len - sw_stride tokens of context."""
    seq_len = args.sw_seq_len
    stride = args.sw_stride
    T = val_tokens.numel()
    all_starts = list(range(0, T - seq_len - 1, stride))
    my_starts = all_starts[rank::world_size]

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_cnt = torch.zeros((), device=device, dtype=torch.float64)
    byte_cnt = torch.zeros((), device=device, dtype=torch.float64)

    # Get the raw model for TTT
    raw_model = model
    while hasattr(raw_model, 'module'):
        raw_model = raw_model.module
    if hasattr(raw_model, '_orig_mod'):
        raw_model = raw_model._orig_mod

    raw_model.eval()
    # TTT modifies weights in-place, so we can't use inference_mode
    ctx = torch.no_grad if (use_ttt and args.ttt_enabled) else torch.inference_mode
    with ctx():
        for start in my_starts:
            end = start + seq_len
            x = val_tokens[start:end].unsqueeze(0).to(device, dtype=torch.int64)
            y = val_tokens[start + 1:end + 1].unsqueeze(0).to(device, dtype=torch.int64)
            with torch.autocast("cuda", dtype=torch.bfloat16):
                if use_ttt and args.ttt_enabled:
                    ptl = raw_model.per_token_loss_with_ttt(x, y, args)
                else:
                    ptl = raw_model.per_token_loss(x, y)
            lo = seq_len - stride
            ptl_s = ptl[0, lo:]
            y_s = y[0, lo:]
            x_s = x[0, lo:]
            loss_sum += ptl_s.to(torch.float64).sum()
            token_cnt += ptl_s.numel()
            tb = base_bytes_lut[y_s].to(torch.float64)
            tb += (has_space_lut[y_s] & ~is_boundary_lut[x_s]).to(torch.float64)
            byte_cnt += tb.sum()

    if dist.is_available() and dist.is_initialized():
        for t in (loss_sum, token_cnt, byte_cnt):
            dist.all_reduce(t, op=dist.ReduceOp.SUM)

    val_loss = float((loss_sum / token_cnt).item())
    bpb = float((loss_sum / math.log(2) / byte_cnt).item())
    raw_model.train()
    return val_loss, bpb
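
# BPB definition used above: with per-token NLL l_i (natural log) and UTF-8
# byte counts b_i over the scored positions,
#     bpb = (sum_i l_i) / (ln 2 * sum_i b_i)
# i.e. total cross-entropy converted to bits, normalized by decoded bytes
# rather than tokens, which makes results comparable across tokenizers.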


# ─────────────────────────────────────────────────────────────
# GPTQ Int6 QUANTIZATION with SDClip
# ─────────────────────────────────────────────────────────────

CONTROL_PATTERNS = tuple(p for p in os.environ.get(
    "CONTROL_TENSOR_NAME_PATTERNS",
    "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,log_alpha"
).split(",") if p)

KEEP_FP_MAX_NUMEL = 65_536
KEEP_FP_STORE_DTYPE = torch.float16
INT8_SCALE_DTYPE = torch.float16


def sdclip(t: Tensor, n_std: float = 2.5) -> Tensor:
    """Std-based clipping: clip to mean +/- n_std * std."""
    mean = t.float().mean()
    std = t.float().std()
    lo = mean - n_std * std
    hi = mean + n_std * std
    return t.clamp(lo.item(), hi.item())


def _quant_tensor_int6(t: Tensor, n_std: float = 2.5):
    """Quantize to a symmetric 6-bit grid (-31..31; one level of the signed
    -32..31 range is left unused) with per-row SDClip."""
    t32 = t.float()
    max_val = 31
    if t32.ndim == 2:
        # Per-row SDClip and quantization
        mean = t32.mean(dim=1, keepdim=True)
        std = t32.std(dim=1, keepdim=True).clamp_min(1e-9)
        lo = mean - n_std * std
        hi = mean + n_std * std
        t_clipped = t32.clamp(lo.expand_as(t32), hi.expand_as(t32))
        clip_val = t_clipped.abs().amax(dim=1).clamp_min(1e-9)
        scale = clip_val / max_val
        q = torch.clamp(torch.round(t_clipped / scale[:, None]), -max_val, max_val).to(torch.int8)
        return q.contiguous(), scale.to(torch.float16).contiguous()
    # 1D fallback
    t_clipped = sdclip(t32, n_std)
    cv = float(t_clipped.abs().max().item())
    scale = torch.tensor(max(cv / max_val, 1.0 / max_val), dtype=torch.float32)
    q = torch.clamp(torch.round(t_clipped / scale), -max_val, max_val).to(torch.int8)
    return q.contiguous(), scale
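
# Worked example of the per-row scheme (illustrative numbers): a row with
# mean 0 and std 0.02 is clipped to [-0.05, 0.05] at n_std = 2.5; then
# scale = 0.05 / 31 ≈ 0.00161, so a weight of 0.032 maps to
# round(0.032 / 0.00161) = 20 and dequantizes to 20 * 0.00161 ≈ 0.0322.
# Note the int6 codes are stored in int8 containers; the byte savings come
# from zlib compressing the reduced 6-bit entropy at export time.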


def _quant_tensor_int8(t: Tensor, n_std: float = 2.5):
    """Quantize tensor to int8 with SDClip."""
    t32 = t.float()
    if t32.ndim == 2:
        mean = t32.mean(dim=1, keepdim=True)
        std = t32.std(dim=1, keepdim=True).clamp_min(1e-9)
        lo = mean - n_std * std
        hi = mean + n_std * std
        t_clipped = t32.clamp(lo.expand_as(t32), hi.expand_as(t32))
        clip_val = t_clipped.abs().amax(dim=1).clamp_min(1e-9)
        scale = clip_val / 127.0
        q = torch.clamp(torch.round(t_clipped / scale[:, None]), -127, 127).to(torch.int8)
        return q.contiguous(), scale.to(torch.float16).contiguous()
    cv = float(sdclip(t32, n_std).abs().max().item())
    scale = torch.tensor(max(cv / 127.0, 1.0 / 127.0), dtype=torch.float32)
    q = torch.clamp(torch.round(t32.clamp(-cv, cv) / scale), -127, 127).to(torch.int8)
    return q.contiguous(), scale


def quantize_state_dict(state_dict: dict, gptq_bits: int = 6, sdclip_nstd: float = 2.5):
    """Mixed quantization: int6 for weight matrices, int8 for embeddings, fp16 for small/control."""
    quantized, scales, dtypes, passthrough, pt_orig, qmeta = {}, {}, {}, {}, {}, {}
    stats = {k: 0 for k in ("param_count", "num_tensors", "baseline_bytes", "quant_bytes")}
    quant_fn = _quant_tensor_int6 if gptq_bits == 6 else _quant_tensor_int8

    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        stats["param_count"] += t.numel()
        stats["num_tensors"] += 1
        stats["baseline_bytes"] += t.numel() * t.element_size()

        if not t.is_floating_point():
            passthrough[name] = t
            stats["quant_bytes"] += t.numel() * t.element_size()
            continue

        is_ctrl = any(p in name for p in CONTROL_PATTERNS)
        is_small = t.numel() <= KEEP_FP_MAX_NUMEL

        # Embeddings: int8 (higher precision for tied I/O)
        if "tok_emb" in name:
            pt_orig[name] = str(t.dtype).removeprefix("torch.")
            q, s = _quant_tensor_int8(t, sdclip_nstd)
            quantized[name] = q
            scales[name] = s
            dtypes[name] = str(t.dtype).removeprefix("torch.")
            if s.ndim > 0:
                qmeta[name] = {"scheme": "per_row", "axis": 0, "bits": 8}
            stats["quant_bytes"] += q.numel() + s.numel() * s.element_size()
            continue

        if is_ctrl or is_small:
            if t.dtype in (torch.float32, torch.bfloat16):
                pt_orig[name] = str(t.dtype).removeprefix("torch.")
            passthrough[name] = t.float() if is_ctrl else t.to(KEEP_FP_STORE_DTYPE)
            passthrough[name] = passthrough[name].contiguous()
            stats["quant_bytes"] += passthrough[name].numel() * passthrough[name].element_size()
            continue

        # Large weight matrices: int6 with SDClip
        q, s = quant_fn(t, sdclip_nstd)
        if s.ndim > 0:
            qmeta[name] = {"scheme": "per_row", "axis": 0, "bits": gptq_bits}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(t.dtype).removeprefix("torch.")
        stats["quant_bytes"] += q.numel() + s.numel() * s.element_size()

    obj = {"__quant_format__": f"int{gptq_bits}_sdclip_v1",
           "quantized": quantized, "scales": scales, "dtypes": dtypes,
           "passthrough": passthrough}
    if qmeta: obj["qmeta"] = qmeta
    if pt_orig: obj["passthrough_orig_dtypes"] = pt_orig
    return obj, stats
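
# Rough budget arithmetic for the default config (derived from the shapes
# defined below): tok_emb is 8192 * 768 ≈ 6.29M weights at int8; each block
# carries 768*768*2 + 768*256*2 + 768*3072*2 ≈ 6.29M matrix weights, so the
# 3 unique blocks hold ≈ 18.9M weights at int6-in-int8. That is ≈ 25.2M
# stored bytes before zlib, which must then squeeze the low-entropy int6
# codes to land inside the 16MB budget mentioned in the header.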


def dequantize_state_dict(obj: dict) -> dict:
    out = {}
    qmeta = obj.get("qmeta", {})
    pt_orig = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(torch.float32)
            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous()
        else:
            out[name] = (q.float() * float(s.item())).to(dtype).contiguous()
    for name, t in obj["passthrough"].items():
        ot = t.detach().cpu().contiguous()
        od = pt_orig.get(name)
        if isinstance(od, str):
            ot = ot.to(dtype=getattr(torch, od)).contiguous()
        out[name] = ot
    return out


# ─────────────────────────────────────────────────────────────
# DATA LOADING
# ─────────────────────────────────────────────────────────────

def load_data_shard(file: Path) -> Tensor:
    hdr = np.fromfile(file, dtype="<i4", count=256)
    if hdr.size != 256 or int(hdr[0]) != 20240520 or int(hdr[1]) != 1:
        raise ValueError(f"Bad shard: {file}")
    n = int(hdr[2])
    tokens = np.fromfile(file, dtype="<u2", count=n, offset=256 * 4)
    return torch.from_numpy(tokens.astype(np.uint16, copy=False))
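
# Shard layout implied by the reader above: a 1024-byte header of 256
# little-endian int32s (hdr[0] = magic 20240520, hdr[1] = version 1,
# hdr[2] = token count), followed by the tokens as little-endian uint16.
# A minimal, hypothetical writer sketch under those assumptions:
#
#     def write_data_shard(path, tokens_u16: np.ndarray):
#         hdr = np.zeros(256, dtype="<i4")
#         hdr[0], hdr[1], hdr[2] = 20240520, 1, len(tokens_u16)
#         with open(path, "wb") as f:
#             f.write(hdr.tobytes())
#             f.write(tokens_u16.astype("<u2").tobytes())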


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No val files: {pattern}")
    tokens = torch.cat([load_data_shard(f) for f in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    return tokens[: usable + 1]


class TokenStream:
    def __init__(self, pattern: str):
        files = [Path(p) for p in sorted(glob.glob(pattern))]
        if not files:
            raise FileNotFoundError(f"No shards: {pattern}")
        self.files = files
        self.idx = 0
        self.tokens = load_data_shard(files[0])
        self.pos = 0

    def take(self, n: int) -> Tensor:
        chunks, rem = [], n
        while rem > 0:
            avail = self.tokens.numel() - self.pos
            if avail <= 0:
                self.idx = (self.idx + 1) % len(self.files)
                self.tokens = load_data_shard(self.files[self.idx])
                self.pos = 0
                avail = self.tokens.numel()
            k = min(rem, avail)
            chunks.append(self.tokens[self.pos: self.pos + k])
            self.pos += k
            rem -= k
        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)


class DistributedTokenLoader:
    def __init__(self, pattern, rank, world_size, device):
        self.rank = rank; self.ws = world_size; self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens, seq_len, grad_accum):
        local_tokens = global_tokens // (self.ws * grad_accum)
        per_rank_span = local_tokens + 1
        chunk = self.stream.take(per_rank_span * self.ws)
        start = self.rank * per_rank_span
        local = chunk[start: start + per_rank_span].to(torch.int64)
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
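
# Batch-shape arithmetic for the defaults (derived from the config): with
# 524,288 tokens per step on 8 ranks and grad_accum = 1, each rank takes
# 65,536 tokens (+1 for the shifted targets) per micro-step, i.e. x and y of
# shape (64, 1024). Every rank advances through the same stream but slices
# its own per_rank_span window, so the global batch partitions cleanly
# without any cross-rank coordination.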


# ─────────────────────────────────────────────────────────────
# TRANSFORMER COMPONENTS — Parallel Residual Architecture
# ─────────────────────────────────────────────────────────────

class RMSNorm(nn.Module):
    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class Rotary(nn.Module):
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._cached_len = 0
        self._cos: Tensor | None = None
        self._sin: Tensor | None = None

    def forward(self, seq_len: int, device, dtype):
        if self._cos is None or self._cached_len != seq_len or self._cos.device != device:
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos = freqs.cos()[None, None, :, :]
            self._sin = freqs.sin()[None, None, :, :]
            self._cached_len = seq_len
        return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype)


def apply_rotary(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
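
# apply_rotary rotates each (x1[i], x2[i]) pair by its position-dependent
# angle theta_i:
#     (x1, x2) -> (x1*cos + x2*sin, -x1*sin + x2*cos)
# a rigid 2-D rotation per frequency, so vector norms are preserved and the
# q·k dot product depends on positions only through their difference
# (relative position encoding).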


class CausalSelfAttention(nn.Module):
    def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init):
        super().__init__()
        assert dim % num_heads == 0 and num_heads % num_kv_heads == 0
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        kv_dim = num_kv_heads * self.head_dim
        self.c_q = nn.Linear(dim, dim, bias=False)
        self.c_k = nn.Linear(dim, kv_dim, bias=False)
        self.c_v = nn.Linear(dim, kv_dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        B, T, _ = x.shape
        q = self.c_q(x).reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(T, x.device, q.dtype)
        q = apply_rotary(q, cos, sin)
        k = apply_rotary(k, cos, sin)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        y = F.scaled_dot_product_attention(q, k, v,
                                           attn_mask=None, is_causal=True,
                                           enable_gqa=(self.num_kv_heads != self.num_heads))
        return self.proj(y.transpose(1, 2).contiguous().reshape(B, T, -1))


class MLP(nn.Module):
    def __init__(self, dim, mlp_mult):
        super().__init__()
        hidden = dim * mlp_mult
        self.fc = nn.Linear(dim, hidden, bias=False)
        self.proj = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.proj(torch.relu(self.fc(x)).square())


class ParallelBlock(nn.Module):
    """Parallel Residual Block: attn and MLP run on the same normalized input.

    x = resid_mix[0]*x + resid_mix[1]*x0
    h = norm(x)
    x = x + attn_scale * attn(h) + mlp_scale * mlp(h)
    """
    def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init):
        super().__init__()
        self.norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.resid_mix = nn.Parameter(torch.stack([torch.ones(dim), torch.zeros(dim)]).float())

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        mix = self.resid_mix.to(x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        h = self.norm(x)
        # Parallel: both attn and MLP operate on the same normalized input
        x = x + self.attn_scale.to(x.dtype)[None, None, :] * self.attn(h) \
              + self.mlp_scale.to(x.dtype)[None, None, :] * self.mlp(h)
        return x
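
# For contrast, a sequential (pre-norm GPT-2 style) block would compute
#     x = x + attn(norm1(x)); x = x + mlp(norm2(x))
# with two norms, the MLP seeing the attention output. The parallel form
# shares one norm and lets both branches read the same h, trading a little
# per-block expressivity for fewer ops; the depth recurrence below restores
# serial composition across iterations.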


# ─────────────────────────────────────────────────────────────
# RECURRENT GPT MODEL with Score-First TTT
# ─────────────────────────────────────────────────────────────

class RecurrentGPT(nn.Module):
    """
    K unique parallel-residual blocks x N recurrences.
    At eval: 2N recurrences + optional score-first TTT.
    """
    def __init__(self, args: Hyperparameters):
        super().__init__()
        self.logit_softcap = args.logit_softcap
        self._train_rec = args.num_recurrences
        self._eval_rec = args.num_eval_recurrences or args.num_recurrences * 2
        self._vocab_size = args.vocab_size

        self.tok_emb = nn.Embedding(args.vocab_size, args.model_dim)
        self.blocks = nn.ModuleList([
            ParallelBlock(args.model_dim, args.num_heads, args.num_kv_heads,
                          args.mlp_mult, args.rope_base, args.qk_gain_init)
            for _ in range(args.num_unique_layers)
        ])
        self.final_norm = RMSNorm()
        nn.init.normal_(self.tok_emb.weight, std=0.005)

    def _forward_hidden(self, input_ids: Tensor) -> Tensor:
        x = F.rms_norm(self.tok_emb(input_ids), (self.tok_emb.embedding_dim,))
        x0 = x
        n = self._train_rec if self.training else self._eval_rec
        for _ in range(n):
            for block in self.blocks:
                x = block(x, x0)
        return self.final_norm(x)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        h = self._forward_hidden(input_ids)
        logits = F.linear(h.reshape(-1, h.size(-1)), self.tok_emb.weight)
        logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap)
        return F.cross_entropy(logits.float(), target_ids.reshape(-1), reduction="mean")

    def per_token_loss(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        h = self._forward_hidden(input_ids)
        B, T, D = h.shape
        logits = F.linear(h.reshape(B * T, D), self.tok_emb.weight)
        logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap)
        return F.cross_entropy(logits.float(), target_ids.reshape(B * T),
                               reduction="none").reshape(B, T)

    @torch.no_grad()
    def per_token_loss_with_ttt(self, input_ids: Tensor, target_ids: Tensor,
                                args: Hyperparameters) -> Tensor:
        """Score-first TTT: adapt MLP.proj weights chunk-by-chunk at eval.

        "Score-first" means: for each chunk, we first SCORE (compute loss) with
        current weights, then UPDATE weights for the next chunk. This is strictly
        causal -- predictions for chunk i only use information from chunks 0..i-1.

        We update MLP.proj.weight (the "down projection") in each block --
        this is the "fast weight" in the In-Place TTT framework (arxiv:2604.06169).
        """
        chunk_size = args.ttt_chunk_size
        ttt_lr = args.ttt_lr
        B, T = input_ids.shape

        # Determine which layers to apply TTT to
        if args.ttt_layers == "all":
            ttt_layer_indices = list(range(len(self.blocks)))
        else:
            ttt_layer_indices = [int(x) for x in args.ttt_layers.split(",")]

        # Save original weights to restore after this sequence
        original_weights = {}
        for li in ttt_layer_indices:
            original_weights[li] = self.blocks[li].mlp.proj.weight.data.clone()

        all_ptl = []
        n_chunks = (T + chunk_size - 1) // chunk_size

        for ci in range(n_chunks):
            lo = ci * chunk_size
            hi = min((ci + 1) * chunk_size, T)

            # Score first: full forward pass with current (possibly updated) weights
            h = self._forward_hidden(input_ids)  # (B, T, D)
            h_chunk = h[:, lo:hi, :]
            y_chunk = target_ids[:, lo:hi]

            logits = F.linear(h_chunk.reshape(-1, h_chunk.size(-1)), self.tok_emb.weight)
            logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap)
            ptl = F.cross_entropy(logits.float(), y_chunk.reshape(-1),
                                  reduction="none").reshape(B, hi - lo)
            all_ptl.append(ptl)

            # Then update: manual gradient step on MLP.proj for the next chunk
            if ci < n_chunks - 1:
                for li in ttt_layer_indices:
                    block = self.blocks[li]
                    # Approximate the block's MLP input for this chunk with the
                    # (re-normalized) final hidden state, then recompute the
                    # intermediate activations
                    h_norm = F.rms_norm(h_chunk.reshape(-1, h_chunk.size(-1)).float(),
                                        (h_chunk.size(-1),))
                    z = torch.relu(block.mlp.fc(h_norm.to(h_chunk.dtype))).square()
                    # Reconstruction-based update: minimize ||Z @ W^T - h_norm||^2
                    pred = z @ block.mlp.proj.weight.T
                    residual = pred - h_norm.to(pred.dtype)
                    grad_w = residual.T @ z / z.size(0)
                    block.mlp.proj.weight.data -= ttt_lr * grad_w.to(block.mlp.proj.weight.dtype)

        # Restore original weights after processing this sequence
        for li in ttt_layer_indices:
            self.blocks[li].mlp.proj.weight.data = original_weights[li]

        return torch.cat(all_ptl, dim=1)
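
# Causality timeline of the loop above (chunk_size 64 over T = 1024, so 16
# chunks): chunk 0 is scored with the trained weights; the update derived
# from chunk 0 is only in effect when scoring chunk 1, and so on. No chunk's
# score ever depends on information beyond its own causal context, which is
# what keeps the adaptation legal for a held-out BPB evaluation.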


# ─────────────────────────────────────────────────────────────
# EMA (Exponential Moving Average)
# ─────────────────────────────────────────────────────────────

class EMA:
    """Exponential Moving Average of model parameters."""
    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.model = model
        self.decay = decay
        self.shadow = {n: p.data.clone() for n, p in model.named_parameters()}
        self.backup = {}

    def update(self):
        for n, p in self.model.named_parameters():
            self.shadow[n].mul_(self.decay).add_(p.data, alpha=1.0 - self.decay)

    def apply(self):
        """Apply EMA weights (backup current)."""
        self.backup = {}
        for n, p in self.model.named_parameters():
            self.backup[n] = p.data.clone()
            p.data.copy_(self.shadow[n])

    def restore(self):
        """Restore original weights."""
        for n, p in self.model.named_parameters():
            p.data.copy_(self.backup[n])
        self.backup = {}


# ─────────────────────────────────────────────────────────────
# TRAINING
# ─────────────────────────────────────────────────────────────

def main():
    global zeropower_via_newtonschulz5
    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()

    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # -- distributed setup --
    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    grad_accum = max(1, 8 // world_size)
    grad_scale = 1.0 / grad_accum

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group("nccl", device_id=device)
        dist.barrier()

    master = rank == 0
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import (enable_flash_sdp, enable_math_sdp,
                                     enable_mem_efficient_sdp, enable_cudnn_sdp)
    enable_flash_sdp(True); enable_math_sdp(False)
    enable_mem_efficient_sdp(False); enable_cudnn_sdp(False)

    logfile = None
    if master:
        os.makedirs("logs", exist_ok=True)
        logfile = f"logs/{args.run_id}.txt"
        print(logfile)

    def log0(msg, console=True):
        if not master: return
        if console: print(msg)
        if logfile:
            with open(logfile, "a") as f: print(msg, file=f)

    log0(code, console=False)
    log0(f"Python {sys.version}", console=False)
    log0(f"PyTorch {torch.__version__}", console=False)
    try:
        log0(subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False).stdout,
             console=False)
    except FileNotFoundError:
        pass

    random.seed(args.seed); np.random.seed(args.seed)
    torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed)

    # -- tokenizer + val data --
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    base_bytes_lut, has_space_lut, is_boundary_lut = build_sentencepiece_luts(
        sp, args.vocab_size, device)
    val_tokens = load_validation_tokens(args.val_files, args.sw_seq_len)
    log0(f"val_tokens:{val_tokens.numel()}")

    # -- model --
    base_model = RecurrentGPT(args).to(device).bfloat16()

    compiled = torch.compile(base_model, dynamic=False, fullgraph=True)
    model = DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) \
        if distributed else compiled

    n_unique = sum(p.numel() for p in base_model.parameters())
    eff_depth = args.num_unique_layers * args.num_recurrences
    log0(f"unique_params:{n_unique} effective_depth:{eff_depth} "
         f"train_loops:{args.num_recurrences} eval_loops:{base_model._eval_rec}")
    log0(f"world_size:{world_size} grad_accum:{grad_accum}")

    # -- optimizer --
    block_params = list(base_model.blocks.named_parameters())
    matrix_params = [p for n, p in block_params
                     if p.ndim == 2 and not any(pat in n for pat in CONTROL_PATTERNS)]
    scalar_params = [p for n, p in block_params
                     if p.ndim < 2 or any(pat in n for pat in CONTROL_PATTERNS)]

    opt_tok = torch.optim.Adam(
        [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}],
        betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True)
    opt_muon = Muon(matrix_params, lr=args.matrix_lr,
                    momentum=args.muon_momentum, backend_steps=args.muon_backend_steps,
                    weight_decay=args.muon_weight_decay)
    for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr
    opt_scalar = torch.optim.Adam(
        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
        betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True)
    optimizers = [opt_tok, opt_muon, opt_scalar]

    # -- EMA --
    ema = EMA(base_model, decay=0.999)
    ema_start_step = int(args.iterations * args.swa_start_frac)

    # -- LR schedule --
    max_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None

    def lr_mul(step, elapsed_ms):
        if args.warmdown_iters <= 0: return 1.0
        if max_ms is None:
            ws = max(args.iterations - args.warmdown_iters, 0)
            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \
                if ws <= step < args.iterations else 1.0
        step_ms = elapsed_ms / max(step, 1)
        remain = max(max_ms - elapsed_ms, 0.0)
        wd_ms = args.warmdown_iters * step_ms
        return remain / max(wd_ms, 1e-9) if remain <= wd_ms else 1.0

    def zero_all(): [o.zero_grad(set_to_none=True) for o in optimizers]

    # -- warmup --
    if args.warmup_steps > 0:
        init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()}
        init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers]
        model.train()
        train_loader_w = DistributedTokenLoader(args.train_files, rank, world_size, device)
        for ws_i in range(args.warmup_steps):
            zero_all()
            for ms_i in range(grad_accum):
                if distributed:
                    model.require_backward_grad_sync = (ms_i == grad_accum - 1)
                x, y = train_loader_w.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum)
                with torch.autocast("cuda", torch.bfloat16):
                    (model(x, y) * grad_scale).backward()
            for o in optimizers: o.step()
        zero_all()
        base_model.load_state_dict(init_model, strict=True)
        for o, s in zip(optimizers, init_opts): o.load_state_dict(s)
        zero_all()
        if distributed: model.require_backward_grad_sync = True

    # -- data + training loop --
    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
    training_ms = 0.0
    stop_step: int | None = None
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    step = 0

    while True:
        last_step = step == args.iterations or (stop_step is not None and step >= stop_step)
        do_val = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)

        if do_val:
            torch.cuda.synchronize()
            training_ms += 1000.0 * (time.perf_counter() - t0)
            vl, vbpb = eval_val_sliding_window(
                args, model, rank, world_size, device,
                val_tokens, base_bytes_lut, has_space_lut, is_boundary_lut,
                use_ttt=False)
            log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} "
                 f"train_ms:{training_ms:.0f} step_avg:{training_ms/max(step,1):.2f}ms")
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        if last_step:
            # Apply EMA weights for the final model on every rank: the sliding
            # window evals below are collective (they all-reduce), so all
            # ranks must participate and see the same weights
            ema.apply()

            # Evaluate with TTT
            log0("Evaluating with EMA + TTT...")
            vl_ema, vbpb_ema = eval_val_sliding_window(
                args, base_model, rank, world_size, device,
                val_tokens, base_bytes_lut, has_space_lut, is_boundary_lut,
                use_ttt=True)
            log0(f"ema_ttt val_loss:{vl_ema:.4f} val_bpb:{vbpb_ema:.4f}")

            # Quantize and export
            sd = base_model.state_dict()
            obj, stats = quantize_state_dict(sd, args.gptq_bits, args.sdclip_nstd)
            buf = io.BytesIO()
            torch.save(obj, buf)
            compressed = zlib.compress(buf.getvalue(), level=9)
            code_bytes = len(code.encode())
            model_bytes = len(compressed)
            total_bytes = code_bytes + model_bytes
            log0(f"final_quant_zlib_roundtrip "
                 f"code_bytes:{code_bytes} "
                 f"model_compressed_bytes:{model_bytes} "
                 f"total_artifact_bytes:{total_bytes} "
                 f"total_artifact_mb:{total_bytes/1e6:.3f} "
                 f"param_count:{stats['param_count']}")

            # Round-trip verify
            sd2 = dequantize_state_dict(obj)
            base_model.load_state_dict(sd2, strict=True)
            vl2, vbpb2 = eval_val_sliding_window(
                args, base_model, rank, world_size, device,
                val_tokens, base_bytes_lut, has_space_lut, is_boundary_lut,
                use_ttt=True)
            log0(f"quantized_model+ttt val_loss:{vl2:.4f} val_bpb:{vbpb2:.4f}")

            # Restore non-EMA weights
            ema.restore()
            break

        if stop_step is None and max_ms is not None:
            torch.cuda.synchronize()
            elapsed = 1000.0 * (time.perf_counter() - t0) + training_ms
            if elapsed >= max_ms:
                stop_step = step + 1

        zero_all()
        for ms_i in range(grad_accum):
            if distributed:
                model.require_backward_grad_sync = (ms_i == grad_accum - 1)
            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum)
            with torch.autocast("cuda", torch.bfloat16):
                (model(x, y) * grad_scale).backward()

        torch.cuda.synchronize()
        elapsed_ms = 1000.0 * (time.perf_counter() - t0) + training_ms
        m = lr_mul(step, elapsed_ms)
        for o in optimizers:
            for g in o.param_groups: g["lr"] = g["base_lr"] * m
        for o in optimizers: o.step()

        # EMA update
        if step >= ema_start_step:
            ema.update()

        if step % args.train_log_every == 0 and master:
            log0(f"step:{step} lr_mul:{m:.4f}")

        step += 1


if __name__ == "__main__":
    main()