"""tilelli.distillery.tokenize — day-0 byte-level tokenizer. Why byte-level: - Zero training. Deterministic. No BPE merges table, no corpus sweep. - Universal coverage: any text, any language, any code, any math symbol fits in 256 ids. Perfect for our four initial sources — English, Python, Ubuntu commands, math — without a single special case. - Aligns with the manifesto's "built from absolute zero" clause. We literally implemented it in twenty lines. - A BPE-style learned tokenizer can replace this later as a Distillery upgrade. Until then, every downstream piece (shard, trainer, probes) works against the byte interface and benefits for free when the tokenizer improves. Limits we accept day-0: - Sequence length in bytes is ~3-4× that of a good BPE tokenizer for English, ~1× for code. This matters for context-window calculations but not for correctness. We're validating the architecture, not pushing tokens/second yet. """ from __future__ import annotations from typing import Iterable import torch from torch import Tensor class ByteTokenizer: """UTF-8 byte-level tokenizer. Vocab size is fixed at 256. encode(text) and decode(ids) are exact inverses for any str input: the encode path is ``text.encode("utf-8")`` and decode is ``bytes(ids).decode("utf-8", errors="replace")``. The ``errors="replace"`` is a conservative default so decode never raises — useful when sampling mid-sequence leaves us with a dangling multi-byte codepoint. """ vocab_size: int = 256 def encode(self, text: str) -> Tensor: """str → 1-D int64 tensor of byte ids. Uses ``torch.frombuffer`` so encoding a 50 MB text doesn't allocate a 1.4 GB Python list of ints on the way through. The ``bytearray`` wrapper is what makes the buffer writable, which ``frombuffer`` requires. """ data = text.encode("utf-8") if not data: return torch.empty(0, dtype=torch.int64) buf = torch.frombuffer(bytearray(data), dtype=torch.uint8) return buf.to(torch.int64) def decode(self, ids: Tensor | Iterable[int]) -> str: """1-D tensor (or iterable of ints) → str.""" if isinstance(ids, Tensor): if ids.dim() != 1: raise ValueError(f"expected 1-D tensor, got shape {tuple(ids.shape)}") ids = ids.tolist() return bytes(int(i) for i in ids).decode("utf-8", errors="replace")