| """ |
| Train İvme-Conversate. |
| |
| Pulls together every decision we locked in: |
| - ~22M decoder (model.py) |
| - Muon + AdamW hybrid (muon.py) |
| - Warmup-Stable-Decay LR schedule |
| - Curriculum data (sequential read of train.bin = ascending quality) |
| - bf16 autocast + gradient accumulation: micro_batch × grad_accum sequences per optimizer step (128 × 8 = 1024 with the TrainConfig defaults) |
| - Live weight EMA (the "checkpoint averaging" win, applied continuously) |
| - Flash attention via HF Kernels on the training box (set attn_backend) |
| |
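| Target run: total_steps × micro_batch × grad_accum × seq_len tokens — with the TrainConfig defaults, 1447 × 128 × 8 × 1024 ≈ 1.52B tokens (≈1.05M tokens per optimizer step). |
| Rough estimate on an RTX 4090 (bf16, FA2): about an hour and well under $1. |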
| |
| Usage: |
| python train.py # full run, reads data/train.bin |
| python train.py --smoke # 50-step run on random data, no files needed |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import math |
| import os |
| import time |
| from copy import deepcopy |
|
|
| import numpy as np |
| import torch |
|
|
| from model import IvmeConfig, IvmeConversate |
| from muon import build_optimizers, wsd_lr_multiplier |
|
|
|
|
| |
| |
| |
class TrainConfig:
    """Hyperparameters for one training run.

    Values live as class attributes. `main` instantiates the class and, for
    `--smoke` runs, overrides some of them on the instance only.
    """

    # --- paths ---
    data_dir: str = "data"
    out_dir: str = "checkpoints"

    # --- batch geometry ---
    # One optimizer step sees micro_batch * grad_accum sequences of seq_len
    # tokens: 128 * 8 = 1024 sequences (~1.05M tokens) with these defaults.
    seq_len: int = 1024
    micro_batch: int = 128
    grad_accum: int = 8

    total_steps: int = 1447

    # --- optimisation ---
    muon_lr: float = 0.02
    adamw_lr: float = 3e-4
    weight_decay: float = 0.1
    grad_clip: float = 1.0
    warmup_steps: int = 100
    # Handed to wsd_lr_multiplier; presumably the fraction of total_steps
    # spent in the decay phase — confirm against muon.py.
    decay_frac: float = 0.2

    # --- weight averaging, evaluation, checkpointing ---
    # The intervals count optimizer steps.
    ema_decay: float = 0.999
    eval_interval: int = 500
    eval_iters: int = 50
    ckpt_interval: int = 1_000

    # --- misc ---
    attn_backend: str = "sdpa"
    seed: int = 1337
|
|
|
|
| |
| |
| |
class BinDataset:
    """Serve (input, target) token batches from a packed uint16 `.bin` file.

    The file is opened as a read-only `np.memmap`. Each sample is a window of
    `seq_len + 1` consecutive tokens. The input is its first `seq_len` tokens,
    and the target is the same window shifted by one token.

    Sampling modes:

    * `curriculum=True`: windows are read back-to-back from a moving pointer,
      so the file is consumed in on-disk order. This preserves whatever
      ordering (e.g. a quality curriculum) the file was written in. When the
      next batch would run past the end of the file, the pointer wraps to 0
      and the leftover tail is skipped.
    * `curriculum=False`: window starts are drawn uniformly at random using
      NumPy's global RNG.

    No shuffle buffer is applied in either mode.
    """

    def __init__(self, path, seq_len, micro_batch, device, curriculum=True):
        self.data = np.memmap(path, dtype=np.uint16, mode="r")
        self.seq_len = seq_len
        self.micro_batch = micro_batch
        self.device = device
        self.curriculum = curriculum
        # Token offset of the next sequential batch. Callers may set it
        # directly (e.g. when resuming), so get_batch validates it before
        # every read instead of trusting it.
        self.ptr = 0

    def get_batch(self):
        """Return `(x, y)` int64 tensors of shape `(micro_batch, seq_len)` on `device`.

        `y` is `x` shifted one token to the right within each window.

        Raises:
            ValueError: if the file is too short to supply a batch in the
                current sampling mode.
        """
        span = self.seq_len + 1  # one window = inputs + one extra target token
        n_tokens = len(self.data)

        if self.curriculum:
            batch_tokens = self.micro_batch * span
            if batch_tokens > n_tokens:
                raise ValueError(
                    f"{n_tokens} tokens cannot fill one batch of "
                    f"{self.micro_batch} x {span} tokens"
                )
            # Check *before* reading. An externally set pointer may not leave
            # room for a full batch, and a batch that ends exactly at the end
            # of the file is still valid.
            if self.ptr + batch_tokens > n_tokens:
                self.ptr = 0
            # Sequential windows are contiguous: one slice + reshape.
            windows = self.data[self.ptr : self.ptr + batch_tokens].reshape(
                self.micro_batch, span
            )
            self.ptr += batch_tokens
        else:
            if n_tokens <= span:
                raise ValueError(
                    f"{n_tokens} tokens cannot supply a window of {span} tokens"
                )
            starts = np.random.randint(0, n_tokens - span, size=self.micro_batch)
            # One gather of shape (micro_batch, span) instead of per-row slices.
            windows = self.data[starts[:, None] + np.arange(span)]

        batch = torch.from_numpy(np.asarray(windows, dtype=np.int64))
        x = batch[:, :-1].contiguous()
        y = batch[:, 1:].contiguous()
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
|
|
|
|
class RandomDataset:
    """Uniform-random token batches for `--smoke` runs; needs no data files.

    Inputs and targets are drawn independently of each other, so there is
    nothing to learn. The class only drives the training loop end to end.
    """

    def __init__(self, vocab, seq_len, micro_batch, device):
        self.vocab = vocab
        self.seq_len = seq_len
        self.micro_batch = micro_batch
        self.device = device

    def _draw(self):
        # One (micro_batch, seq_len) tensor of token ids in [0, vocab).
        shape = (self.micro_batch, self.seq_len)
        return torch.randint(0, self.vocab, shape, device=self.device)

    def get_batch(self):
        """Return `(x, y)`: two independently drawn random token tensors."""
        inputs = self._draw()
        targets = self._draw()
        return inputs, targets
|
|
|
|
| |
| |
| |
class EMA:
    """Live exponential moving average ("shadow" copy) of a model's `state_dict`.

    This is a continuous version of checkpoint averaging. After each
    `update`, every floating-point entry becomes
    `d * shadow + (1 - d) * current`. Non-floating entries (e.g. integer
    buffers) are copied verbatim.

    Decay warmup: with a fixed decay, the initial (random) weights keep a
    share of `decay ** n` after `n` updates. For example,
    `0.999 ** 1447 ≈ 0.235`, which is far too much for a short run. So by
    default the k-th update (k = 0, 1, ...) uses
    `d = min(decay, (1 + k) / (10 + k))`, the schedule of TensorFlow's
    `ExponentialMovingAverage` with `num_updates`. Pass `warmup=False` for a
    plain fixed-decay EMA.

    The counter `num_updates` belongs to this instance. A freshly constructed
    EMA starts the warmup at k = 0 even if `shadow` is later replaced by a
    checkpointed copy; set `num_updates` explicitly to continue a schedule.
    """

    def __init__(self, model, decay, warmup=True):
        self.decay = decay
        self.warmup = warmup
        # Number of update() calls made on this instance; drives the warmup.
        self.num_updates = 0
        # Independent copy of every state_dict entry (parameters and buffers),
        # so later in-place training updates do not leak into the average.
        self.shadow = deepcopy(model.state_dict())
        for v in self.shadow.values():
            v.requires_grad_(False)

    def effective_decay(self):
        """Return the decay the next `update` call will apply."""
        if not self.warmup:
            return self.decay
        n = self.num_updates
        return min(self.decay, (1 + n) / (10 + n))

    @torch.no_grad()
    def update(self, model):
        """Blend `model`'s current weights into the shadow copy."""
        d = self.effective_decay()
        self.num_updates += 1
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(d).add_(v, alpha=1 - d)
            else:
                self.shadow[k].copy_(v)
|
|
|
|
| |
| |
| |
def _resume_data_pointer(n_tokens, start_step, grad_accum, micro_batch, seq_len):
    """Return the token offset a sequential reader should resume from.

    `start_step * grad_accum` micro-batches of `micro_batch` windows of
    `seq_len + 1` tokens have already been consumed. That count is folded
    into a single pass over the file, assuming the reader wraps to 0 whenever
    the next batch would not fit. The result always leaves room for one full
    batch, and is 0 when the file cannot hold even one batch.
    """
    batch_tokens = micro_batch * (seq_len + 1)
    batches_per_pass = n_tokens // batch_tokens
    if batches_per_pass == 0:
        return 0
    return ((start_step * grad_accum) % batches_per_pass) * batch_tokens


def main(smoke=False, resume=None):
    """Train İvme-Conversate, or run a tiny `--smoke` job on random tokens.

    Args:
        smoke: shrink the run (50 steps, 4 x 2 micro-batches of 128 tokens)
            and train on `RandomDataset`, so no data files are needed.
        resume: optional path to a checkpoint written by this script. Model,
            EMA and (when present) optimizer states are restored, and
            training continues with the step *after* the one that was saved.

    Writes periodic checkpoints `ivme_step{N}.pt` and a final `ivme_final.pt`
    under `TrainConfig.out_dir`.
    """
    cfg = TrainConfig()
    if smoke:
        cfg.total_steps = 50
        cfg.eval_interval = 25
        cfg.eval_iters = 5
        cfg.ckpt_interval = 9999
        cfg.warmup_steps = 5
        cfg.micro_batch = 4
        cfg.grad_accum = 2
        cfg.seq_len = 128

    torch.manual_seed(cfg.seed)
    # BinDataset's random (validation) sampling draws from NumPy's global RNG.
    np.random.seed(cfg.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device == "cuda"
    print(f"[train] device={device} amp(bf16)={use_amp} smoke={smoke}")

    mcfg = IvmeConfig(max_seq_len=cfg.seq_len, attn_backend=cfg.attn_backend)
    model = IvmeConversate(mcfg).to(device)
    print(f"[train] model params: {model.num_params()/1e6:.1f}M")

    muon, adamw = build_optimizers(
        model, muon_lr=cfg.muon_lr, adamw_lr=cfg.adamw_lr, weight_decay=cfg.weight_decay
    )
    ema = EMA(model, cfg.ema_decay)

    if smoke:
        train_ds = RandomDataset(mcfg.vocab_size, cfg.seq_len, cfg.micro_batch, device)
        val_ds = train_ds
    else:
        train_ds = BinDataset(os.path.join(cfg.data_dir, "train.bin"),
                              cfg.seq_len, cfg.micro_batch, device, curriculum=True)
        val_ds = BinDataset(os.path.join(cfg.data_dir, "val.bin"),
                            cfg.seq_len, cfg.micro_batch, device, curriculum=False)

    os.makedirs(cfg.out_dir, exist_ok=True)

    start_step = 0
    if resume:
        print(f"[resume] loading {resume}")
        # SECURITY: weights_only=False unpickles arbitrary objects (the
        # checkpoint stores the IvmeConfig instance). Only resume from
        # checkpoints you trust.
        ckpt = torch.load(resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model"])
        ema.shadow = ckpt["ema"]
        # Periodic checkpoints store "step" = index of the last step that
        # completed before saving. Resuming *at* it would repeat that step and
        # its data. Checkpoints from this version also carry "next_step".
        if "next_step" in ckpt:
            start_step = ckpt["next_step"]
        elif "step" in ckpt:
            start_step = ckpt["step"] + 1

        if "muon" in ckpt and "adamw" in ckpt:
            muon.load_state_dict(ckpt["muon"])
            adamw.load_state_dict(ckpt["adamw"])
            print("[resume] restored optimizer states")
        else:
            print("[resume] WARNING: checkpoint has no optimizer state — "
                  "Muon/AdamW restart cold (a brief loss bump for ~20-50 steps is normal)")

        if not smoke:
            # Account for wrap-around so the pointer is never left where a
            # full batch does not fit.
            train_ds.ptr = _resume_data_pointer(
                len(train_ds.data), start_step, cfg.grad_accum, cfg.micro_batch, cfg.seq_len
            )
            print(f"[resume] data pointer -> token {train_ds.ptr:,} "
                  f"(resuming at step {start_step})")

    amp_ctx = (torch.autocast(device_type="cuda", dtype=torch.bfloat16)
               if use_amp else torch.autocast(device_type="cpu", enabled=False))

    @torch.no_grad()
    def evaluate():
        """Mean loss over `cfg.eval_iters` validation batches (live weights)."""
        model.eval()
        losses = []
        for _ in range(cfg.eval_iters):
            x, y = val_ds.get_batch()
            with amp_ctx:
                _, loss = model(x, y)
            losses.append(loss.item())
        model.train()
        return sum(losses) / len(losses)

    model.train()
    t0 = time.time()
    tokens_seen = 0
    last_step = cfg.total_steps - 1

    for step in range(start_step, cfg.total_steps):
        mult = wsd_lr_multiplier(step, cfg.total_steps, cfg.warmup_steps, cfg.decay_frac)
        for g in muon.param_groups:
            g["lr"] = cfg.muon_lr * mult
        for g in adamw.param_groups:
            g["lr"] = cfg.adamw_lr * mult

        muon.zero_grad(set_to_none=True)
        adamw.zero_grad(set_to_none=True)

        # Keep the running loss on-device. Calling .item() every micro-batch
        # would force a host/device sync grad_accum times per step.
        accum_loss = torch.zeros((), device=device)
        for _ in range(cfg.grad_accum):
            x, y = train_ds.get_batch()
            with amp_ctx:
                _, loss = model(x, y)
                loss = loss / cfg.grad_accum
            loss.backward()
            accum_loss += loss.detach()
            tokens_seen += x.numel()

        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        muon.step()
        adamw.step()
        ema.update(model)

        if step % 10 == 0:
            dt = time.time() - t0
            tps = tokens_seen / max(dt, 1e-6)
            print(f"step {step:>5}/{cfg.total_steps} | loss {accum_loss.item():.4f} "
                  f"| lr_mult {mult:.3f} | {tps/1e3:.0f}K tok/s | {tokens_seen/1e6:.1f}M tok")

        # Also evaluate on the final step. Otherwise the last periodic eval can
        # land far from the end (step 1000 of 1447 with the defaults).
        if (step > 0 and step % cfg.eval_interval == 0) or step == last_step:
            vloss = evaluate()
            print(f" [eval] step {step}: val_loss {vloss:.4f} | val_ppl {math.exp(vloss):.2f}")

        if step > 0 and step % cfg.ckpt_interval == 0:
            path = os.path.join(cfg.out_dir, f"ivme_step{step}.pt")
            torch.save({"model": model.state_dict(), "ema": ema.shadow,
                        "muon": muon.state_dict(), "adamw": adamw.state_dict(),
                        "cfg": mcfg, "step": step, "next_step": step + 1}, path)
            print(f" [ckpt] saved {path}")

    final = os.path.join(cfg.out_dir, "ivme_final.pt")
    torch.save({"model": model.state_dict(), "ema": ema.shadow, "cfg": mcfg,
                "step": cfg.total_steps, "next_step": cfg.total_steps}, final)
    print(f"[train] done in {(time.time()-t0):.1f}s | final -> {final}")
|
|
|
|
if __name__ == "__main__":
    # CLI entry point. The option names deliberately mirror main()'s keyword
    # parameters, so the parsed namespace can be forwarded as-is.
    parser = argparse.ArgumentParser()
    parser.add_argument("--smoke", action="store_true")
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="path to a checkpoint .pt to resume from",
    )
    main(**vars(parser.parse_args()))