"""tilelli.core.tilelli_lm — minimal byte-level language model built on ternary primitives + heterogeneous-pathway blocks. Stacks TilelliBlock layers on top of a byte embedding and a ternary unembedding, plus a learned positional embedding. """ from __future__ import annotations import torch from torch import Tensor, nn from torch.nn import functional as F from tilelli.core.ternary_linear import TernaryLinear from tilelli.core.tilelli_block import TilelliBlock class TilelliLM(nn.Module): """Byte-level Tilelli language model.""" def __init__( self, vocab_size: int = 256, d_model: int = 128, n_layers: int = 4, d_head: int = 32, top_k: int = 8, pathways: int = 5, max_seq_len: int = 512, quantize: bool = True, n_banks: int = 1, per_row: bool = False, hadamard: bool = False, lsq: bool = False, dense_expand: int = 2, fp_attention: bool = False, top_k_routing: int = 0, ) -> None: super().__init__() self.vocab_size = vocab_size self.d_model = d_model self.n_layers = n_layers self.max_seq_len = max_seq_len self.quantize = quantize self.n_banks = n_banks self.per_row = per_row self.hadamard = hadamard self.lsq = lsq self.dense_expand = dense_expand self.fp_attention = fp_attention self.top_k_routing = top_k_routing self.token_emb = nn.Embedding(vocab_size, d_model) self.pos_emb = nn.Embedding(max_seq_len, d_model) self.blocks = nn.ModuleList( [ TilelliBlock( d_model=d_model, d_head=d_head, top_k=top_k, pathways=pathways, n_banks=n_banks, quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq, dense_expand=dense_expand, fp_attention=fp_attention, top_k_routing=top_k_routing, ) for _ in range(n_layers) ] ) self.norm_out = nn.LayerNorm(d_model) self.unembed = TernaryLinear( d_model, vocab_size, quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq, ) def forward(self, ids: Tensor) -> Tensor: if ids.dim() != 2: raise ValueError(f"expected (B, L), got shape {tuple(ids.shape)}") B, L = ids.shape if L > self.max_seq_len: raise ValueError(f"sequence length {L} exceeds max_seq_len {self.max_seq_len}") positions = torch.arange(L, device=ids.device) x = self.token_emb(ids) + self.pos_emb(positions)[None, :, :] for block in self.blocks: x = block(x) x = self.norm_out(x) return self.unembed(x) @property def aux_loss(self) -> Tensor: """Sum of per-block load-balancing aux losses. Zero when n_banks=1.""" if self.n_banks <= 1: return torch.tensor(0.0, device=self.token_emb.weight.device) return sum(b.aux_loss for b in self.blocks) def loss(self, ids: Tensor, targets: Tensor) -> Tensor: """Cross-entropy loss + load-balance aux when banking is on.""" logits = self.forward(ids) ce = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1)) if self.n_banks > 1: return ce + self.aux_loss return ce @torch.no_grad() def generate(self, ids: Tensor, n_new_tokens: int) -> Tensor: was_training = self.training self.eval() try: for _ in range(n_new_tokens): ids_in = ids[:, -self.max_seq_len:] logits = self.forward(ids_in)[:, -1, :] next_id = logits.argmax(dim=-1, keepdim=True) ids = torch.cat([ids, next_id], dim=1) return ids finally: if was_training: self.train() @torch.no_grad() def router_entropies(self, ids: Tensor) -> list[Tensor]: if ids.dim() != 2: raise ValueError(f"expected (B, L), got shape {tuple(ids.shape)}") positions = torch.arange(ids.size(1), device=ids.device) x = self.token_emb(ids) + self.pos_emb(positions)[None, :, :] out = [] for block in self.blocks: out.append(block.router_entropy(x)) x = block(x) return out def parameter_count(self) -> int: return sum(p.numel() for p in self.parameters())