ereniko commited on
Commit
edfd803
·
verified ·
1 Parent(s): 0d72706

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +276 -0
train.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train İvme-Conversate.
3
+
4
+ Pulls together every decision we locked in:
5
+ - ~22M decoder (model.py)
6
+ - Muon + AdamW hybrid (muon.py)
7
+ - Warmup-Stable-Decay LR schedule
8
+ - Curriculum data (sequential read of train.bin = ascending quality)
9
+ - bf16 autocast + gradient accumulation to an effective batch of 1024 seqs (micro_batch 128 x grad_accum 8)
10
+ - Live weight EMA (the "checkpoint averaging" win, applied continuously)
11
+ - Flash attention via HF Kernels on the training box (set attn_backend)
12
+
13
+ Target run (TrainConfig defaults): ~1.518B tokens / ~1.05M tokens-per-step ≈ 1447 steps.
14
+ On an RTX 4090 (bf16, FA2) that's roughly an hour and well under $1.
15
+
16
+ Usage:
17
+ python train.py # full run, reads data/train.bin
18
+ python train.py --smoke # 50-step run on random data, no files needed
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import argparse
24
+ import math
25
+ import os
26
+ import time
27
+ from copy import deepcopy
28
+
29
+ import numpy as np
30
+ import torch
31
+
32
+ from model import IvmeConfig, IvmeConversate
33
+ from muon import build_optimizers, wsd_lr_multiplier
34
+
35
+
36
+ # --------------------------------------------------------------------------- #
37
+ # Training config
38
+ # --------------------------------------------------------------------------- #
39
class TrainConfig:
    """Hyper-parameters and paths for a training run.

    All values are plain class attributes; ``main`` instantiates the class and
    overrides individual fields on the instance (e.g. for ``--smoke`` runs).
    """

    # ---- Filesystem ----
    data_dir = "data"            # directory holding train.bin / val.bin
    out_dir = "checkpoints"      # where checkpoints are written

    # ---- Batch geometry ----
    # Tokens per optimizer step = micro_batch * grad_accum * seq_len.
    # With the defaults: 128 * 8 * 1024 = 1,048,576 (~1.05M) tokens/step,
    # sized for the RTX PRO 6000 Blackwell (96GB).
    seq_len = 1_024
    micro_batch = 128
    grad_accum = 8
    # 1.518B train tokens / 1.05M tokens per step ≈ 1447 steps for a single
    # Chinchilla-optimal pass over the data.
    total_steps = 1_447

    # ---- Optimization ----
    muon_lr = 2e-2
    adamw_lr = 0.0003
    weight_decay = 0.1
    grad_clip = 1.0
    warmup_steps = 100
    decay_frac = 0.2             # WSD: decay over the final 20% (starts ~step 1158)

    # ---- EMA / evaluation / checkpointing ----
    ema_decay = 0.999            # decay of the live weight EMA
    eval_interval = 500
    eval_iters = 50
    ckpt_interval = 1_000

    # ---- Misc ----
    attn_backend = "sdpa"        # switch to "kernels" on the training box
    seed = 1_337
65
+
66
+
67
+ # --------------------------------------------------------------------------- #
68
+ # Data
69
+ # --------------------------------------------------------------------------- #
70
class BinDataset:
    """Serve (input, target) token batches from a packed uint16 ``.bin`` file.

    The file is memory-mapped read-only. Each sample is a window of
    ``seq_len + 1`` tokens: the first ``seq_len`` tokens are the input and the
    same window shifted by one token is the target.

    With ``curriculum=True`` windows are read strictly sequentially through an
    advancing pointer (``self.ptr``), so the on-disk ordering of the file is
    preserved; the pointer wraps to 0 when a full micro-batch no longer fits.
    With ``curriculum=False`` window starts are drawn uniformly at random via
    ``np.random``.

    Note: no shuffle buffer is implemented; sequential mode returns windows in
    exact file order.
    """

    def __init__(self, path, seq_len, micro_batch, device, curriculum=True):
        """Memory-map ``path`` and remember the batch geometry.

        Args:
            path: path to the packed uint16 token file.
            seq_len: number of tokens per input sequence.
            micro_batch: number of sequences returned by each ``get_batch``.
            device: torch device the returned tensors are moved to.
            curriculum: sequential reads when True, random windows when False.
        """
        self.data = np.memmap(path, dtype=np.uint16, mode="r")
        self.seq_len = seq_len
        self.micro_batch = micro_batch
        self.device = device
        self.curriculum = curriculum
        self.ptr = 0

    def get_batch(self):
        """Return one micro-batch as ``(x, y)`` int64 tensors on ``self.device``.

        Both tensors have shape ``(micro_batch, seq_len)``; ``y`` is ``x``
        shifted one token to the right within the same window.

        Raises:
            ValueError: in sequential mode, if the file holds fewer tokens than
                a single micro-batch requires.
        """
        span = self.seq_len + 1
        need = self.micro_batch
        total = len(self.data)
        if self.curriculum:
            block = need * span
            if block > total:
                raise ValueError(
                    f"dataset has {total} tokens but one micro-batch needs {block}"
                )
            # Wrap BEFORE reading so a pointer set from outside (e.g. when
            # resuming) can never run past the end of the file and produce
            # ragged windows. A batch that ends exactly at the last token is
            # still served.
            if self.ptr + block > total:
                self.ptr = 0  # wrap (a new epoch; rare at Chinchilla-optimal)
            starts = [self.ptr + i * span for i in range(need)]
            self.ptr += block
        else:
            starts = np.random.randint(0, total - span, size=need).tolist()

        # Read each window once (seq_len + 1 tokens) and derive inputs/targets
        # as shifted views, instead of slicing the memmap twice per window.
        windows = np.stack([self.data[s : s + span] for s in starts]).astype(np.int64)
        tokens = torch.from_numpy(windows)
        x = tokens[:, :-1].contiguous()
        y = tokens[:, 1:].contiguous()
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
99
+
100
+
101
class RandomDataset:
    """Synthetic data source for ``--smoke`` runs.

    Produces uniformly random token ids so the training loop can be exercised
    without any data files on disk.
    """

    def __init__(self, vocab, seq_len, micro_batch, device):
        self.vocab = vocab
        self.seq_len = seq_len
        self.micro_batch = micro_batch
        self.device = device

    def get_batch(self):
        """Return ``(x, y)``: two independent random ``(micro_batch, seq_len)`` id tensors."""
        shape = (self.micro_batch, self.seq_len)
        # Inputs are drawn first, then targets, each with its own randint call.
        inputs, targets = (
            torch.randint(0, self.vocab, shape, device=self.device) for _ in range(2)
        )
        return inputs, targets
111
+
112
+
113
+ # --------------------------------------------------------------------------- #
114
+ # EMA
115
+ # --------------------------------------------------------------------------- #
116
class EMA:
    """Live exponential moving average of a model's ``state_dict``.

    ``shadow`` maps every state-dict key to its own averaged tensor. After each
    ``update`` floating-point entries follow
    ``shadow = decay * shadow + (1 - decay) * current``; non-floating entries
    are simply copied from the model.

    Note: after ``t`` updates the initial weights still carry a weight of
    ``decay ** t`` in the average, so short runs with a decay close to 1 keep a
    noticeable share of the starting values.
    """

    def __init__(self, model, decay):
        """Snapshot the model's current state as the starting average.

        Args:
            model: module exposing ``state_dict()``.
            decay: EMA decay factor applied to the shadow on every update.
        """
        self.decay = decay
        # Clone every entry separately. A deepcopy of the whole state dict keeps
        # storage sharing between tied parameters (two keys, one tensor), which
        # would make update() apply the decay twice per step to that tensor.
        self.shadow = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }

    @torch.no_grad()
    def update(self, model):
        """Fold the model's current weights into the running average, in place."""
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v, alpha=1 - self.decay)
            else:
                # Integer/bool entries cannot be averaged; track the latest value.
                self.shadow[k].copy_(v)
133
+
134
+
135
+ # --------------------------------------------------------------------------- #
136
+ # Train
137
+ # --------------------------------------------------------------------------- #
138
def main(smoke=False, resume=None):
    """Run the full training loop and write checkpoints to ``TrainConfig.out_dir``.

    Builds the model, the Muon + AdamW optimizers and the weight EMA, then
    trains for ``cfg.total_steps`` optimizer steps with gradient accumulation,
    a WSD learning-rate multiplier, gradient clipping, periodic evaluation and
    periodic checkpointing. A final checkpoint ``ivme_final.pt`` is always
    written at the end.

    Args:
        smoke: if True, shrink the run (50 steps, tiny batches) and train on
            random tokens so no data files are needed.
        resume: optional path to a checkpoint produced by this script; model,
            EMA, optimizer state (when present) and the data pointer are
            restored from it.
    """
    cfg = TrainConfig()
    if smoke:
        cfg.total_steps = 50
        cfg.eval_interval = 25
        cfg.eval_iters = 5
        cfg.ckpt_interval = 9999
        cfg.warmup_steps = 5
        cfg.micro_batch = 4
        cfg.grad_accum = 2
        cfg.seq_len = 128

    torch.manual_seed(cfg.seed)
    # Validation windows are drawn with np.random, so seed NumPy as well to
    # make evaluation batches reproducible between runs.
    np.random.seed(cfg.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device == "cuda"
    print(f"[train] device={device} amp(bf16)={use_amp} smoke={smoke}")

    mcfg = IvmeConfig(max_seq_len=cfg.seq_len, attn_backend=cfg.attn_backend)
    model = IvmeConversate(mcfg).to(device)
    print(f"[train] model params: {model.num_params()/1e6:.1f}M")

    muon, adamw = build_optimizers(
        model, muon_lr=cfg.muon_lr, adamw_lr=cfg.adamw_lr, weight_decay=cfg.weight_decay
    )
    ema = EMA(model, cfg.ema_decay)

    if smoke:
        train_ds = RandomDataset(mcfg.vocab_size, cfg.seq_len, cfg.micro_batch, device)
        val_ds = train_ds
    else:
        train_ds = BinDataset(os.path.join(cfg.data_dir, "train.bin"),
                              cfg.seq_len, cfg.micro_batch, device, curriculum=True)
        val_ds = BinDataset(os.path.join(cfg.data_dir, "val.bin"),
                            cfg.seq_len, cfg.micro_batch, device, curriculum=False)

    os.makedirs(cfg.out_dir, exist_ok=True)

    # ---- Resume from a checkpoint, if requested ----
    start_step = 0
    if resume:
        print(f"[resume] loading {resume}")
        # SECURITY: weights_only=False unpickles arbitrary Python objects
        # (presumably required because the checkpoint stores the model config
        # object under "cfg"). Only resume from checkpoints you trust.
        ckpt = torch.load(resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model"])
        ema.shadow = ckpt["ema"]
        start_step = ckpt.get("step", 0)
        # Optimizer momentum buffers (Muon) and moments (AdamW) — restore if the
        # checkpoint has them; older checkpoints won't, so we warn and continue.
        if "muon" in ckpt and "adamw" in ckpt:
            muon.load_state_dict(ckpt["muon"])
            adamw.load_state_dict(ckpt["adamw"])
            print("[resume] restored optimizer states")
        else:
            print("[resume] WARNING: checkpoint has no optimizer state — "
                  "Muon/AdamW restart cold (a brief loss bump for ~20-50 steps is normal)")
        # Fast-forward the curriculum data pointer to where we left off so we
        # don't re-read from the top of train.bin and break the curriculum order.
        if not smoke:
            batch_tokens = cfg.micro_batch * (cfg.seq_len + 1)
            train_ds.ptr = start_step * cfg.grad_accum * batch_tokens
            # Restart from the top unless a whole micro-batch still fits after
            # the pointer; a partial read would yield ragged windows.
            if train_ds.ptr + batch_tokens > len(train_ds.data):
                train_ds.ptr = 0
            print(f"[resume] data pointer -> token {train_ds.ptr:,} "
                  f"(resuming at step {start_step})")

    amp_ctx = (torch.autocast(device_type="cuda", dtype=torch.bfloat16)
               if use_amp else torch.autocast(device_type="cpu", enabled=False))

    @torch.no_grad()
    def evaluate():
        """Mean loss of the live (non-EMA) model over ``cfg.eval_iters`` val batches."""
        model.eval()
        losses = []
        for _ in range(cfg.eval_iters):
            x, y = val_ds.get_batch()
            with amp_ctx:
                _, loss = model(x, y)
            losses.append(loss.item())
        model.train()
        return sum(losses) / len(losses)

    model.train()
    t0 = time.time()
    tokens_seen = 0  # tokens processed in this session only (not restored on resume)

    for step in range(start_step, cfg.total_steps):
        # Set the WSD-scheduled lr on both optimizers.
        mult = wsd_lr_multiplier(step, cfg.total_steps, cfg.warmup_steps, cfg.decay_frac)
        for g in muon.param_groups:
            g["lr"] = cfg.muon_lr * mult
        for g in adamw.param_groups:
            g["lr"] = cfg.adamw_lr * mult

        muon.zero_grad(set_to_none=True)
        adamw.zero_grad(set_to_none=True)

        # Accumulate the loss on-device; calling .item() on every micro-step
        # would force a host/device sync each time. It is read only when logged.
        accum_loss = torch.zeros((), device=device)
        for _ in range(cfg.grad_accum):
            x, y = train_ds.get_batch()
            with amp_ctx:
                _, loss = model(x, y)
            # Scale so the accumulated gradient is the mean over micro-batches.
            loss = loss / cfg.grad_accum
            loss.backward()
            accum_loss += loss.detach().float()
            tokens_seen += x.numel()

        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        muon.step()
        adamw.step()
        ema.update(model)

        if step % 10 == 0:
            dt = time.time() - t0
            tps = tokens_seen / max(dt, 1e-6)
            loss_val = accum_loss.item()
            print(f"step {step:>5}/{cfg.total_steps} | loss {loss_val:.4f} "
                  f"| lr_mult {mult:.3f} | {tps/1e3:.0f}K tok/s | {tokens_seen/1e6:.1f}M tok")

        if step > 0 and step % cfg.eval_interval == 0:
            vloss = evaluate()
            print(f"  [eval] step {step}: val_loss {vloss:.4f} | val_ppl {math.exp(vloss):.2f}")

        if step > 0 and step % cfg.ckpt_interval == 0:
            path = os.path.join(cfg.out_dir, f"ivme_step{step}.pt")
            # "step" records the number of COMPLETED optimizer steps (this one
            # included), matching the final checkpoint's convention. Storing
            # the loop index instead would make a resume repeat this step on
            # the same data.
            torch.save({"model": model.state_dict(), "ema": ema.shadow,
                        "muon": muon.state_dict(), "adamw": adamw.state_dict(),
                        "cfg": mcfg, "step": step + 1}, path)
            print(f"  [ckpt] saved {path}")

    # Final save: both the trained weights and the EMA weights (use EMA for eval).
    final = os.path.join(cfg.out_dir, "ivme_final.pt")
    torch.save({"model": model.state_dict(), "ema": ema.shadow, "cfg": mcfg,
                "step": cfg.total_steps}, final)
    print(f"[train] done in {(time.time()-t0):.1f}s | final -> {final}")
268
+
269
+
270
if __name__ == "__main__":
    # Command-line entry point: parse the two optional flags and forward them
    # to main() as keyword arguments (the parsed names are "smoke" and "resume").
    parser = argparse.ArgumentParser()
    parser.add_argument("--smoke", action="store_true")
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="path to a checkpoint .pt to resume from",
    )
    main(**vars(parser.parse_args()))