"""scripts/train_demo.py — minimal Tilelli demo trainer.

Trains a tiny TilelliLM on a small text file. Useful as a smoke
test that the stack composes end-to-end. Not a serious training
recipe — see PAPER.md for the full setup.

Usage:
    python scripts/train_demo.py --data path/to/text.txt --steps 1000 \
        --d-model 128 --n-layers 4 --output checkpoints/tilelli_demo.pt
"""
| from __future__ import annotations |
|
|
| import argparse |
| import time |
| from pathlib import Path |
|
|
| import torch |
|
|
| from tilelli.core.tilelli_lm import TilelliLM |
| from tilelli.distillery.tokenize import ByteTokenizer |
| from tilelli.utils.runtime import ThermalGuard, polite_training |
|
|
|
|
def load_data(path: Path, tokenizer: ByteTokenizer, seq_len: int) -> torch.Tensor:
    """Read a text file, tokenize it, and cut the ids into fixed-length rows.

    Args:
        path: Text file to read. It is decoded as UTF-8; undecodable bytes
            are replaced instead of raising.
        tokenizer: Object whose ``encode(text)`` returns a tensor of token
            ids. The result is used via ``numel()``, slicing and ``view``
            (presumably a 1-D tensor — confirm against ``ByteTokenizer``).
        seq_len: Number of token ids per row; must be positive.

    Returns:
        The encoded ids reshaped to ``(n_chunks, seq_len)``. Trailing ids that
        do not fill a complete row are dropped.

    Raises:
        ValueError: If ``seq_len`` is not positive, or if the file encodes to
            fewer than ``seq_len`` ids so that no complete row exists.
    """
    # Reject bad lengths up front: 0 would otherwise surface as a bare
    # ZeroDivisionError and negative values would slice nonsensically.
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    text = path.read_text(encoding="utf-8", errors="replace")
    print(f"data: {len(text):,} chars from {path}")
    ids = tokenizer.encode(text)

    n_chunks = ids.numel() // seq_len
    if n_chunks == 0:
        # An empty (0, seq_len) result would only fail later, when the caller
        # tries to sample a row index from an empty range.
        raise ValueError(
            f"{path} encodes to {ids.numel()} tokens, fewer than seq_len={seq_len}"
        )

    # Drop the ragged tail so the ids reshape cleanly into whole rows.
    return ids[: n_chunks * seq_len].view(n_chunks, seq_len)
|
|
|
|
def main() -> None:
    """Parse CLI options, train a small TilelliLM on byte chunks, and save it.

    The text file given by ``--data`` is tokenized with ``ByteTokenizer`` and
    split into rows of ``--seq-len`` ids. Each step samples ``--batch-size``
    rows at random (with replacement) and trains the model to predict every
    id from the ids before it in the row. Optimisation uses AdamW with a
    cosine learning-rate schedule that decays to 1% of ``--lr`` over
    ``--steps`` steps.

    On completion the model ``state_dict`` and the run configuration are
    written to ``--output`` with ``torch.save``; parent directories are
    created as needed.

    Exits through ``argparse``'s error path (status 2) when ``--seq-len`` is
    below 2 or the data file yields no complete chunk.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data", type=Path, required=True)
    ap.add_argument("--steps", type=int, default=1000)
    ap.add_argument("--seq-len", type=int, default=256)
    ap.add_argument("--batch-size", type=int, default=4)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--d-model", type=int, default=128)
    ap.add_argument("--n-layers", type=int, default=4)
    ap.add_argument("--d-head", type=int, default=32)
    ap.add_argument("--top-k", type=int, default=8)
    ap.add_argument("--output", type=Path, default=Path("checkpoints/tilelli_demo.pt"))
    args = ap.parse_args()

    # Inputs are chunk[:, :-1] and targets chunk[:, 1:], so a row needs at
    # least two ids for either slice to be non-empty.
    if args.seq_len < 2:
        ap.error("--seq-len must be at least 2")

    tok = ByteTokenizer()
    data = load_data(args.data, tok, args.seq_len)
    if data.size(0) == 0:
        # Without this, torch.randint(0, 0, ...) in the loop fails with a
        # message that says nothing about the data file being too small.
        ap.error(
            f"{args.data} yields no complete chunk of {args.seq_len} tokens; "
            "use a larger file or a smaller --seq-len"
        )
    print(f"chunks: {data.size(0):,} of {args.seq_len}")

    model = TilelliLM(
        vocab_size=256,  # byte-level ids, matching ByteTokenizer
        d_model=args.d_model,
        n_layers=args.n_layers,
        d_head=args.d_head,
        top_k=args.top_k,
        max_seq_len=args.seq_len,
    )
    print(f"params: {model.parameter_count():,}")

    optim = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim, T_max=args.steps, eta_min=args.lr * 0.01
    )
    guard = ThermalGuard()

    model.train()
    t0 = time.time()
    best_loss = float("inf")
    for step in range(args.steps):
        guard.maybe_throttle(step)

        # Sample a random batch of rows (with replacement).
        idx = torch.randint(0, data.size(0), (args.batch_size,))
        chunk = data[idx]

        # Next-token objective: position t predicts the id at t + 1.
        loss = model.loss(chunk[:, :-1], chunk[:, 1:])
        optim.zero_grad()
        loss.backward()
        optim.step()
        sched.step()

        # Read the scalar once; each .item() call is a separate tensor-to-
        # Python conversion.
        loss_value = loss.item()
        if loss_value < best_loss:
            best_loss = loss_value
        if step % 50 == 0:
            print(f"step {step:5d} loss {loss_value:.4f} best {best_loss:.4f}")

        # NOTE(review): the reviewed source had lost its indentation, so the
        # nesting of this call could not be recovered; it is assumed to run
        # once per step. Confirm against tilelli.utils.runtime.
        polite_training()

    args.output.parent.mkdir(parents=True, exist_ok=True)
    # Store paths as plain strings so the checkpoint holds only primitive
    # values and can be read back with torch.load(..., weights_only=True),
    # which refuses to unpickle pathlib objects.
    config = {
        key: str(value) if isinstance(value, Path) else value
        for key, value in vars(args).items()
    }
    torch.save({"model": model.state_dict(), "config": config}, args.output)
    print(f"saved to {args.output} after {time.time() - t0:.1f}s; best loss {best_loss:.4f}")
|
|
|
|
# Script entry point: run the trainer only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|