| """tilelli.core.tilelli_lm — minimal byte-level language model built on |
| ternary primitives + heterogeneous-pathway blocks. |
| |
| Stacks TilelliBlock layers on top of a byte embedding and a ternary |
| unembedding, plus a learned positional embedding. |
| """ |
| from __future__ import annotations |
|
|
| import torch |
| from torch import Tensor, nn |
| from torch.nn import functional as F |
|
|
| from tilelli.core.ternary_linear import TernaryLinear |
| from tilelli.core.tilelli_block import TilelliBlock |
|
|
|
|
class TilelliLM(nn.Module):
    """Byte-level Tilelli language model.

    Pipeline (as constructed in ``__init__``): token embedding + learned
    absolute positional embedding -> ``n_layers`` ``TilelliBlock`` layers ->
    final ``LayerNorm`` -> ``TernaryLinear`` unembedding to ``vocab_size``
    logits.
    """

    def __init__(
        self,
        vocab_size: int = 256,
        d_model: int = 128,
        n_layers: int = 4,
        d_head: int = 32,
        top_k: int = 8,
        pathways: int = 5,
        max_seq_len: int = 512,
        quantize: bool = True,
        n_banks: int = 1,
        per_row: bool = False,
        hadamard: bool = False,
        lsq: bool = False,
        dense_expand: int = 2,
        fp_attention: bool = False,
        top_k_routing: int = 0,
    ) -> None:
        """Build the model.

        Args:
            vocab_size: Number of token ids (default 256, one per byte value).
            d_model: Width of the embeddings / residual stream.
            n_layers: Number of stacked ``TilelliBlock`` layers.
            max_seq_len: Size of the learned positional table; inputs longer
                than this are rejected with ``ValueError``.
            quantize, per_row, hadamard, lsq: Forwarded to every block and to
                the ``TernaryLinear`` unembedding.
            d_head, top_k, pathways, n_banks, dense_expand, fp_attention,
            top_k_routing: Forwarded to every ``TilelliBlock``; see that
                class for their meaning. ``n_banks > 1`` additionally makes
                ``loss`` include the load-balancing term from ``aux_loss``.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.quantize = quantize
        self.n_banks = n_banks
        self.per_row = per_row
        self.hadamard = hadamard
        self.lsq = lsq
        self.dense_expand = dense_expand
        self.fp_attention = fp_attention
        self.top_k_routing = top_k_routing

        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList(
            [
                TilelliBlock(
                    d_model=d_model,
                    d_head=d_head,
                    top_k=top_k,
                    pathways=pathways,
                    n_banks=n_banks,
                    quantize=quantize,
                    per_row=per_row,
                    hadamard=hadamard,
                    lsq=lsq,
                    dense_expand=dense_expand,
                    fp_attention=fp_attention,
                    top_k_routing=top_k_routing,
                )
                for _ in range(n_layers)
            ]
        )
        self.norm_out = nn.LayerNorm(d_model)
        self.unembed = TernaryLinear(
            d_model, vocab_size,
            quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq,
        )

    def _embed(self, ids: Tensor) -> Tensor:
        """Validate ``ids`` and return token + positional embeddings.

        Shared by ``forward`` and ``router_entropies`` so both apply the same
        input checks.

        Raises:
            ValueError: if ``ids`` is not 2-D, or its length exceeds
                ``max_seq_len`` (the positional table has no rows past that).
        """
        if ids.dim() != 2:
            raise ValueError(f"expected (B, L), got shape {tuple(ids.shape)}")
        seq_len = ids.size(1)
        if seq_len > self.max_seq_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
        positions = torch.arange(seq_len, device=ids.device)
        # (L, d_model) -> (1, L, d_model) so it broadcasts over the batch.
        return self.token_emb(ids) + self.pos_emb(positions)[None, :, :]

    def forward(self, ids: Tensor) -> Tensor:
        """Map token ids of shape (B, L) to next-token logits (B, L, vocab_size).

        Raises:
            ValueError: if ``ids`` is not 2-D or is longer than ``max_seq_len``.
        """
        x = self._embed(ids)
        for block in self.blocks:
            x = block(x)
        x = self.norm_out(x)
        return self.unembed(x)

    @property
    def aux_loss(self) -> Tensor:
        """Sum of per-block load-balancing aux losses.

        Returns a scalar zero tensor on the embedding's device when banking is
        off (``n_banks <= 1``) or there are no blocks, so callers always get a
        ``Tensor`` (``sum`` over an empty generator would yield the int ``0``).
        Presumably reflects each block's most recent forward pass; confirm in
        ``TilelliBlock``.
        """
        if self.n_banks <= 1 or len(self.blocks) == 0:
            return torch.tensor(0.0, device=self.token_emb.weight.device)
        return sum(b.aux_loss for b in self.blocks)

    def loss(self, ids: Tensor, targets: Tensor) -> Tensor:
        """Cross-entropy loss + load-balance aux when banking is on.

        ``logits`` and ``targets`` are flattened over batch and sequence, so
        ``targets`` must hold one id per position of ``ids``.
        """
        logits = self.forward(ids)
        ce = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1))
        if self.n_banks > 1:
            # Read after forward() so the blocks' aux terms belong to this batch.
            return ce + self.aux_loss
        return ce

    @torch.no_grad()
    def generate(self, ids: Tensor, n_new_tokens: int) -> Tensor:
        """Greedily append ``n_new_tokens`` argmax tokens to ``ids``.

        At each step only the last ``max_seq_len`` tokens are fed to the model.
        The module is switched to eval mode for the duration and switched back
        to train mode afterwards if it was training on entry.

        Raises:
            ValueError: if tokens are requested from an empty (B, 0) prompt,
                since there is no last position to predict from.
        """
        if n_new_tokens > 0 and ids.dim() == 2 and ids.size(1) == 0:
            raise ValueError("generate() needs a non-empty prompt")
        was_training = self.training
        self.eval()
        try:
            for _ in range(n_new_tokens):
                # Sliding window: the positional table only covers max_seq_len.
                window = ids[:, -self.max_seq_len:]
                last_logits = self.forward(window)[:, -1, :]
                next_id = last_logits.argmax(dim=-1, keepdim=True)
                ids = torch.cat([ids, next_id], dim=1)
            return ids
        finally:
            if was_training:
                self.train()

    @torch.no_grad()
    def router_entropies(self, ids: Tensor) -> list[Tensor]:
        """Return ``block.router_entropy`` for each block, in order.

        Each block's entropy is computed on that block's own input (the output
        of the previous block); the final ``norm_out`` is not applied.

        Raises:
            ValueError: if ``ids`` is not 2-D or is longer than ``max_seq_len``.
        """
        x = self._embed(ids)
        entropies: list[Tensor] = []
        for block in self.blocks:
            entropies.append(block.router_entropy(x))
            x = block(x)
        return entropies

    def parameter_count(self) -> int:
        """Total number of scalar parameters (trainable or not)."""
        return sum(p.numel() for p in self.parameters())
|
|