# Tilelli-llm/src/tilelli/distillery/tokenize.py
"""tilelli.distillery.tokenize: day-0 byte-level tokenizer.

Why byte-level:

- Zero training. Deterministic. No BPE merges table, no corpus sweep.
- Universal coverage: any text, any language, any code, any math symbol
  fits in 256 ids. Perfect for our four initial sources (English,
  Python, Ubuntu commands, math) without a single special case.
- Aligns with the manifesto's "built from absolute zero" clause. We
  literally implemented it in twenty lines.
- A BPE-style learned tokenizer can replace this later as a Distillery
  upgrade. Until then, every downstream piece (shard, trainer,
  probes) works against the byte interface and benefits for free when
  the tokenizer improves.

Limits we accept day-0:

- Sequence length in bytes is ~3-4× that of a good BPE tokenizer for
  English, ~1× for code. This matters for context-window calculations
  but not for correctness. We're validating the architecture, not
  pushing tokens/second yet.
"""
from __future__ import annotations

from typing import Iterable

import torch
from torch import Tensor


class ByteTokenizer:
    """UTF-8 byte-level tokenizer. Vocab size is fixed at 256.

    encode(text) and decode(ids) are exact inverses for any str that can
    be encoded as UTF-8 (a str holding lone surrogates raises
    ``UnicodeEncodeError`` in encode): the encode path is
    ``text.encode("utf-8")`` and decode is
    ``bytes(ids).decode("utf-8", errors="replace")``. The ``errors="replace"``
    is a conservative default so decode never raises on malformed UTF-8,
    which is useful when sampling mid-sequence leaves us with a dangling
    multi-byte codepoint.
    """

    vocab_size: int = 256

    def encode(self, text: str) -> Tensor:
        """str → 1-D int64 tensor of byte ids.

        Uses ``torch.frombuffer`` so encoding a 50 MB text doesn't
        build a Python list of ~50 million ints on the way through.
        The ``bytearray`` wrapper makes the buffer writable;
        ``frombuffer`` warns on read-only buffers such as ``bytes``.
        """
        data = text.encode("utf-8")
        if not data:
            # frombuffer raises on a zero-length buffer, so short-circuit.
            return torch.empty(0, dtype=torch.int64)
        buf = torch.frombuffer(bytearray(data), dtype=torch.uint8)
        # .to(int64) copies, so the result does not alias the bytearray.
        return buf.to(torch.int64)

    def decode(self, ids: Tensor | Iterable[int]) -> str:
        """1-D tensor (or iterable of ints) → str.

        Ids must lie in ``range(256)``; ``bytes`` raises ``ValueError``
        otherwise.
        """
        if isinstance(ids, Tensor):
            if ids.dim() != 1:
                raise ValueError(f"expected 1-D tensor, got shape {tuple(ids.shape)}")
            ids = ids.tolist()
        return bytes(int(i) for i in ids).decode("utf-8", errors="replace")
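

# A minimal, torch-free sketch of the same byte-level round trip, showing what
# errors="replace" does when a sequence is cut inside a multi-byte codepoint.
# The helper names encode_ids and decode_ids are hypothetical stand-ins for
# illustration only; they are not part of this module and skip the tensor path.

```python
from typing import Iterable, List


def encode_ids(text: str) -> List[int]:
    """str -> list of UTF-8 byte values, each in range(256)."""
    return list(text.encode("utf-8"))


def decode_ids(ids: Iterable[int]) -> str:
    """Iterable of byte values -> str; malformed bytes become U+FFFD."""
    return bytes(int(i) for i in ids).decode("utf-8", errors="replace")


text = "h\u00e9llo"                  # 5 characters, "é" takes 2 bytes in UTF-8
ids = encode_ids(text)               # 6 ids, one per byte
round_trip = decode_ids(ids)         # the original string comes back unchanged
truncated = decode_ids(ids[:2])      # cut after the first byte of "é"

print(len(text), len(ids))           # character count vs. byte count
print(round_trip == text)
print(repr(truncated))               # the dangling byte decodes to U+FFFD
```

# The gap between len(text) and len(ids) is the sequence-length cost the module
# docstring accepts: non-ASCII characters take more than one id each.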