#!/usr/bin/env python3
"""Generic text generator — works with both bundled checkpoints.

Auto-routes between the two architectures based on the checkpoint config
(and on whether the checkpoint carries ``abstain.*`` weights):

    python infer.py          # default: chat with v4 (FP32 chat-SFT'd, deployed)
    python infer.py --ckpt checkpoints/tilelli_pretrain_v1_ternary.pt --prompt "Once upon a time"

For Lite-architecture checkpoints such as v4 (the deployed chat model), the
prompt is wrapped as ``USER: ... TILELLI:`` automatically unless you pass
--raw. For pretrain checkpoints there is no chat format, so the prompt is used
verbatim.
"""

import argparse
import sys
from pathlib import Path

# Make the in-repo ``src/`` package importable without installing it; this must
# run before the ``tilelli`` imports below.
sys.path.insert(0, str(Path(__file__).parent / "src"))

import torch  # noqa: E402

from tilelli.utils import safe_load_checkpoint  # noqa: E402
from tilelli.distillery.tokenize import ByteTokenizer  # noqa: E402

# How many missing parameter names to list in the load warning.
_MISSING_PREVIEW = 5


def _strip_prefix(state_dict, prefix, *, keep_unprefixed=False):
    """Return a copy of *state_dict* with a leading *prefix* removed from its keys.

    Only a prefix at the very start of a key is removed (never a mid-key match).
    Keys that do not start with *prefix* are dropped by default, or copied over
    unchanged when *keep_unprefixed* is true.
    """
    out = {}
    for k, v in state_dict.items():
        if k.startswith(prefix):
            out[k[len(prefix):]] = v
        elif keep_unprefixed:
            out[k] = v
    return out


def _warn_missing_keys(result, kind):
    """Warn on stderr when ``load_state_dict(strict=False)`` left parameters unloaded.

    *result* is the value returned by ``load_state_dict``; its ``missing_keys``
    are model entries the checkpoint did not provide (they keep their
    freshly-initialised values). Unexpected checkpoint keys are not reported,
    since checkpoints may legitimately carry extra entries.
    """
    missing = list(getattr(result, "missing_keys", ()))
    if not missing:
        return
    preview = ", ".join(missing[:_MISSING_PREVIEW])
    extra = len(missing) - _MISSING_PREVIEW
    more = f" (+{extra} more)" if extra > 0 else ""
    print(
        f"[infer] warning: {len(missing)} {kind} model key(s) missing from checkpoint: "
        f"{preview}{more}",
        file=sys.stderr,
    )


def load_model(ckpt_path: str):
    """Load *ckpt_path* and instantiate the matching model class in eval mode.

    The model config is the first non-empty of the checkpoint's
    ``base_model_cfg`` / ``model_cfg`` / ``config`` entries (``{}`` if none).
    The weights are its ``model`` entry, or the checkpoint itself when there is
    no ``model`` entry.

    Routing:
      * ``TilelliLiteLM`` (kind ``"lite"``) when the config's ``builder`` is
        ``"tilelli_lite"`` — also the default when no builder is recorded — or
        the weights contain ``abstain.weight`` / ``abstain.bias``. Before
        loading, ``abstain.*`` entries are dropped and a *leading* ``base.`` is
        stripped from the remaining weight names.
      * ``TilelliLM`` (kind ``"parent"``) otherwise.

    Weights are loaded with ``strict=False``; model parameters absent from the
    checkpoint are reported on stderr instead of raising.

    Returns:
        ``(model, cfg, kind)`` with ``kind`` either ``"lite"`` or ``"parent"``.
    """
    # NOTE(review): trusted=True presumably allows full (pickle-based)
    # deserialization — only point --ckpt at files you trust; confirm against
    # safe_load_checkpoint.
    ckpt = safe_load_checkpoint(ckpt_path, trusted=True)
    cfg = ckpt.get("base_model_cfg") or ckpt.get("model_cfg") or ckpt.get("config") or {}
    raw = ckpt.get("model", ckpt)
    builder = cfg.get("builder", "tilelli_lite")

    if builder == "tilelli_lite" or "abstain.weight" in raw or "abstain.bias" in raw:
        # Lite 3-pathway — the deployed chat v4 lives here
        from tilelli.core.tilelli_lite import TilelliLiteLM

        model = TilelliLiteLM(
            vocab_size=cfg.get("vocab_size", 256),
            d_model=cfg.get("d_model", 256),
            n_layers=cfg.get("n_layers", 8),
            n_heads=cfg.get("n_heads", 8),
            top_k=cfg.get("top_k", 16),
            ffn_expand=cfg.get("dense_expand", 4),
            max_seq_len=cfg.get("max_seq_len", 256),
            quantize=cfg.get("quantize", False),
        )
        # Drop the abstain head, then strip only a *leading* "base." wrapper
        # prefix. (str.replace would also rewrite a mid-key "base.", e.g.
        # "codebase.weight" -> "codeweight", silently skipping that weight.)
        without_abstain = {k: v for k, v in raw.items() if not k.startswith("abstain.")}
        base = _strip_prefix(without_abstain, "base.", keep_unprefixed=True)
        result = model.load_state_dict(base, strict=False)
        kind = "lite"
    else:
        # Parent multi-pathway (TilelliLM) — the ternary pretrain lives here
        from tilelli.core.tilelli_lm import TilelliLM

        model = TilelliLM(
            vocab_size=cfg.get("vocab_size", 256),
            d_model=cfg.get("d_model", 512),
            n_layers=cfg.get("n_layers", 7),
            d_head=cfg.get("d_head", 64),
            top_k=cfg.get("top_k", 8),
            pathways=cfg.get("pathways", 5),
            max_seq_len=cfg.get("max_seq_len", 256),
            quantize=cfg.get("quantize", True),
            n_banks=cfg.get("n_banks", 1),
            per_row=cfg.get("per_row", False),
            hadamard=cfg.get("hadamard", False),
            lsq=cfg.get("lsq", False),
            dense_expand=cfg.get("dense_expand", 2),
            fp_attention=cfg.get("fp_attention", False),
        )
        result = model.load_state_dict(raw, strict=False)
        kind = "parent"

    _warn_missing_keys(result, kind)
    model.eval()
    return model, cfg, kind


@torch.no_grad()
def generate(model, prompt_ids: torch.Tensor, n_new: int = 120, stop_ids=(10, 0)) -> torch.Tensor:
    """Greedily extend *prompt_ids* by up to *n_new* tokens.

    Dispatch order:
      1. ``model.generate_with_cache(...)`` if present — its first return value
         is returned;
      2. ``model.generate(...)`` if present — note *stop_ids* is NOT forwarded;
      3. otherwise a slow argmax loop that re-runs the model on the last
         ``model.max_seq_len`` (default 256) tokens each step and stops right
         after emitting any token in *stop_ids*.

    Args:
        model: A Tilelli LM of either architecture.
        prompt_ids: Prompt token ids; ``main`` passes a single row of shape
            ``(1, T)``. The fallback loop requires batch size 1.
        n_new: Maximum number of new tokens.
        stop_ids: Token ids that end generation (default 10 and 0 —
            presumably newline and NUL under the byte tokenizer; TODO confirm).

    Returns:
        For the fallback loop, the prompt followed by the generated tokens; the
        model-provided paths return whatever those methods return (presumably
        the same layout).
    """
    if hasattr(model, "generate_with_cache"):
        full, _, _ = model.generate_with_cache(prompt_ids, n_new_tokens=n_new, stop_ids=stop_ids)
        return full
    if hasattr(model, "generate"):
        return model.generate(prompt_ids, n_new_tokens=n_new)

    # Fall back to a slow loop
    ids = prompt_ids
    max_ctx = getattr(model, "max_seq_len", 256)
    for _ in range(n_new):
        window = ids[:, -max_ctx:]  # never feed more context than the model supports
        logits = model(window)
        if logits.ndim == 3:
            logits = logits[:, -1, :]  # keep only the last position's logits
        nxt = logits.argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, nxt], dim=1)
        if int(nxt) in stop_ids:  # int() on a 1-element tensor: batch size 1 only
            break
    return ids


def main():
    """Parse CLI arguments, load the checkpoint, generate, and print the text to stdout."""
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--ckpt",
        default="checkpoints/tilelli_chat_v4.pt",
        help="Checkpoint to load. Default = the FP32 chat-SFT'd v4 (deployed).",
    )
    ap.add_argument(
        "--prompt",
        default=None,
        help="Text to continue. For v4 it gets wrapped as USER:/TILELLI:.",
    )
    ap.add_argument(
        "--raw",
        action="store_true",
        help="Skip the USER:/TILELLI: wrapping (treat prompt as continuation seed).",
    )
    ap.add_argument("--max-new", type=int, default=120)
    args = ap.parse_args()

    tok = ByteTokenizer()
    model, cfg, kind = load_model(args.ckpt)
    n_params = sum(p.numel() for p in model.parameters())
    print(
        f"[infer] {args.ckpt}",
        f"({kind}, {n_params/1e6:.2f}M params, quantize={cfg.get('quantize')})",
        file=sys.stderr,
    )

    # An absent (or empty) --prompt falls back to an architecture-appropriate default.
    prompt = args.prompt or ("Hello, who are you?" if kind == "lite" else "Once upon a time")
    if kind == "lite" and not args.raw:
        prompt = f"USER: {prompt}\nTILELLI:"

    ids = tok.encode(prompt).long().unsqueeze(0)  # add a leading batch dimension of 1
    out = generate(model, ids, n_new=args.max_new)
    text = tok.decode(out[0].tolist())
    print(text)


if __name__ == "__main__":
    main()