#!/usr/bin/env python3 """scripts/train_demo.py — minimal Tilelli demo trainer. Trains a tiny TilelliLM on a small text file. Useful as a smoke test that the stack composes end-to-end. Not a serious training recipe — see PAPER.md for the full setup. Usage: python scripts/train_demo.py --data path/to/text.txt --steps 1000 \ --d-model 128 --n-layers 4 --output checkpoints/tilelli_demo.pt """ from __future__ import annotations import argparse import time from pathlib import Path import torch from tilelli.core.tilelli_lm import TilelliLM from tilelli.distillery.tokenize import ByteTokenizer from tilelli.utils.runtime import ThermalGuard, polite_training def load_data(path: Path, tokenizer: ByteTokenizer, seq_len: int) -> torch.Tensor: text = path.read_text(encoding="utf-8", errors="replace") print(f"data: {len(text):,} chars from {path}") ids = tokenizer.encode(text) n_chunks = ids.numel() // seq_len return ids[: n_chunks * seq_len].view(n_chunks, seq_len) def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--data", type=Path, required=True) ap.add_argument("--steps", type=int, default=1000) ap.add_argument("--seq-len", type=int, default=256) ap.add_argument("--batch-size", type=int, default=4) ap.add_argument("--lr", type=float, default=3e-4) ap.add_argument("--d-model", type=int, default=128) ap.add_argument("--n-layers", type=int, default=4) ap.add_argument("--d-head", type=int, default=32) ap.add_argument("--top-k", type=int, default=8) ap.add_argument("--output", type=Path, default=Path("checkpoints/tilelli_demo.pt")) args = ap.parse_args() tok = ByteTokenizer() data = load_data(args.data, tok, args.seq_len) print(f"chunks: {data.size(0):,} of {args.seq_len}") model = TilelliLM( vocab_size=256, d_model=args.d_model, n_layers=args.n_layers, d_head=args.d_head, top_k=args.top_k, max_seq_len=args.seq_len, ) print(f"params: {model.parameter_count():,}") optim = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01) sched = torch.optim.lr_scheduler.CosineAnnealingLR( optim, T_max=args.steps, eta_min=args.lr * 0.01 ) guard = ThermalGuard() model.train() t0 = time.time() best_loss = float("inf") for step in range(args.steps): guard.maybe_throttle(step) idx = torch.randint(0, data.size(0), (args.batch_size,)) chunk = data[idx] loss = model.loss(chunk[:, :-1], chunk[:, 1:]) optim.zero_grad() loss.backward() optim.step() sched.step() if loss.item() < best_loss: best_loss = loss.item() if step % 50 == 0: print(f"step {step:5d} loss {loss.item():.4f} best {best_loss:.4f}") polite_training() args.output.parent.mkdir(parents=True, exist_ok=True) torch.save({"model": model.state_dict(), "config": vars(args)}, args.output) print(f"saved to {args.output} after {time.time() - t0:.1f}s; best loss {best_loss:.4f}") if __name__ == "__main__": main()