"""tilelli.core.tilelli_lite — clean 3-pathway block designed to beat a same-size vanilla baseline.
| |
| A prior 6-pathway variant of this architecture (~10.6M params) tied vanilla on |
| TinyStories byte-LM (mean 0.5737 vs vanilla 0.5707). Internal audit attributed |
| the tie to fragmentation: parameter budget was spent on pathways the byte-LM |
| data did not reward (an indexed-knowledge slot, a wide convolution, and a |
| non-selective state-space path). |
| |
| Tilelli Lite cuts those underperforming slots and keeps the lessons that DO |
| show up at 10M scale: heterogeneous pathways with a learned router, and a |
| ternary-capable forward pass for inference. This module is a sibling to the |
| larger 5/6-pathway block (kept intact for non-byte-LM workloads); it is not |
| a drop-in replacement. |
| |
| 3-pathway block: |
| - Local conv k=5 (n-grams; strictly more efficient than attention here) |
| - Sparse causal attention with multi-head (8 heads, d_head=48 by default) |
| - Dense FFN with expand=4 (matches vanilla's FFN ratio) |
| |
| Other lessons folded in from the prior block's audit: |
| - Learned positional embedding (recovers the position signal lost by |
| the previous unembedding-only design) |
| - Load-balance auxiliary loss properly wired through the router head |
| """ |
| from __future__ import annotations |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
| from tilelli.core.sparse_attention import SparseCausalAttention |
| from tilelli.core.ternary_conv import TernaryCausalConv1d |
| from tilelli.core.ternary_linear import TernaryLinear |
|
|
|
|
# Human-readable pathway labels, listed in the same order as the router's
# output columns in TilelliLiteBlock (0 -> local conv, 1 -> sparse attention,
# 2 -> dense FFN).
PATHWAY_NAMES_LITE: tuple[str, ...] = (
    "local",
    "sparse",
    "dense",
)
|
|
|
|
class TernaryFFN_Lite(nn.Module):
    """Two-layer ternary feed-forward network with a GELU in between.

    The hidden width is ``expand * d_model``; the default ``expand=4``
    matches the FFN ratio of the vanilla baseline.
    """

    def __init__(self, d_model: int, expand: int = 4, quantize: bool = True) -> None:
        super().__init__()
        hidden_width = expand * d_model
        # Projection into the wide hidden space, then back to d_model.
        self.up = TernaryLinear(d_model, hidden_width, quantize=quantize)
        self.down = TernaryLinear(hidden_width, d_model, quantize=quantize)

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``down(gelu(up(x)))``; the last dimension of x is d_model."""
        activate = torch.nn.functional.gelu
        hidden = activate(self.up(x))
        return self.down(hidden)
|
|
|
|
class TilelliLiteBlock(nn.Module):
    """3-pathway block: local conv + sparse multi-head attention + dense FFN.

    All three pathways run on every token. A per-token softmax router over the
    normalized input produces mixing weights (column 0 -> local, 1 -> sparse,
    2 -> dense), and the router-weighted sum of the pathway outputs is added
    back onto the residual stream.

    Each ``forward`` call records a load-balance auxiliary loss (read it via
    ``aux_loss``) that penalizes the averaged router usage for drifting away
    from uniform, discouraging collapse onto a single pathway.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int = 8,
        kernel_size: int = 5,
        top_k: int = 16,
        ffn_expand: int = 4,
        quantize: bool = True,
        load_balance_weight: float = 0.01,
    ) -> None:
        """Build the block.

        Args:
            d_model: residual-stream width; must be divisible by ``n_heads``.
            n_heads: number of single-head sparse attention modules.
            kernel_size: kernel size of the local causal conv pathway.
            top_k: forwarded to every ``SparseCausalAttention`` head.
            ffn_expand: hidden-width multiplier of the dense FFN pathway.
            quantize: forwarded to every ternary submodule.
            load_balance_weight: multiplier applied to the load-balance loss.

        Raises:
            ValueError: if ``n_heads`` is not positive or does not divide
                ``d_model``.
        """
        super().__init__()
        # Validate before dividing (n_heads == 0 previously surfaced as a
        # ZeroDivisionError) and before any submodule is allocated.
        if n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {n_heads}")
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model {d_model} must be divisible by n_heads {n_heads}"
            )
        self.d_model = d_model
        self.n_pathways = 3
        self.load_balance_weight = load_balance_weight
        d_head = d_model // n_heads

        # Submodule creation order is kept fixed so that parameter
        # initialization under a given RNG seed stays reproducible.
        self.norm = nn.LayerNorm(d_model)
        self.local = TernaryCausalConv1d(
            d_model, kernel_size=kernel_size, quantize=quantize
        )
        self.sparse_heads = nn.ModuleList(
            [
                SparseCausalAttention(d_model, d_head=d_head, top_k=top_k)
                for _ in range(n_heads)
            ]
        )
        self.sparse_proj = TernaryLinear(d_model, d_model, quantize=quantize)
        self.dense = TernaryFFN_Lite(d_model, expand=ffn_expand, quantize=quantize)

        self.router = TernaryLinear(d_model, self.n_pathways, quantize=quantize)
        self._aux_loss = torch.tensor(0.0)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _mix(r: Tensor, out_local: Tensor, out_sparse: Tensor, out_dense: Tensor) -> Tensor:
        """Router-weighted sum of the three pathway outputs.

        ``r`` holds per-token weights in its last dimension (index 0 local,
        1 sparse, 2 dense); the ``k:k+1`` slices keep that dimension so each
        weight broadcasts across the feature dimension of the outputs.
        """
        return (
            r[..., 0:1] * out_local
            + r[..., 1:2] * out_sparse
            + r[..., 2:3] * out_dense
        )

    @staticmethod
    def _entropy(r: Tensor) -> Tensor:
        """Entropy of a distribution over the last dimension of ``r``."""
        # The 1e-12 keeps log() finite when a router weight underflows to 0.
        return -(r * (r + 1e-12).log()).sum(dim=-1)

    def _merge_heads(self, head_outs: list[Tensor]) -> Tensor:
        """Average the per-head outputs, then apply ``sparse_proj``.

        Heads are averaged, not concatenated; the projection is
        d_model -> d_model, so each head presumably returns d_model-wide
        outputs (TODO confirm against SparseCausalAttention).
        """
        merged = torch.stack(head_outs, dim=0).mean(dim=0)
        return self.sparse_proj(merged)

    def _load_balance_loss(self, r: Tensor) -> Tensor:
        """Weighted MSE between mean pathway usage and the uniform target.

        ``r`` is averaged over its first two dims (batch and sequence for the
        (B, L, n_pathways) router output produced in ``forward``).
        """
        pathway_use = r.mean(dim=(0, 1))
        target = 1.0 / self.n_pathways
        return ((pathway_use - target) ** 2).mean() * self.load_balance_weight

    # ------------------------------------------------------------- full forward

    def _multi_head_sparse(self, h: Tensor) -> Tensor:
        """Run every single-head sparse attention on ``h``, average, project."""
        return self._merge_heads([head(h) for head in self.sparse_heads])

    def forward(self, x: Tensor) -> Tensor:
        """Full-sequence forward; returns ``x + mixed`` (residual included).

        Side effects: refreshes ``self._aux_loss`` (load-balance loss, kept
        differentiable for training) and ``self._router_entropy`` (per-token
        router entropy, detached diagnostic).
        """
        h = self.norm(x)
        r = torch.softmax(self.router(h), dim=-1)

        out_local = self.local(h)
        out_sparse = self._multi_head_sparse(h)
        out_dense = self.dense(h)
        mixed = self._mix(r, out_local, out_sparse, out_dense)

        self._aux_loss = self._load_balance_loss(r)
        # Diagnostic only. Detached so the stored tensor does not pin the
        # whole autograd graph in memory until the next forward call.
        self._router_entropy = self._entropy(r.detach())

        return x + mixed

    @property
    def aux_loss(self) -> Tensor:
        """Load-balance loss from the most recent ``forward`` call."""
        return self._aux_loss

    @torch.no_grad()
    def router_weights(self, x: Tensor) -> Tensor:
        """Softmax router weights for ``x`` (no gradient)."""
        h = self.norm(x)
        return torch.softmax(self.router(h), dim=-1)

    @torch.no_grad()
    def router_entropy(self, x: Tensor) -> Tensor:
        """Per-token entropy of the router distribution.

        Low -> the router is committed to one pathway (high confidence).
        High -> uncertain mix.
        """
        return self._entropy(self.router_weights(x))

    # ------------------------------------------------------ incremental decoding

    def empty_cache(self, batch_size: int, device, dtype) -> dict:
        """Fresh decoding cache: conv buffer plus one cache per sparse head."""
        return {
            "conv_buffer": self.local.empty_buffer(batch_size, device, dtype),
            "sparse_caches": [
                head.empty_cache(batch_size, device, dtype)
                for head in self.sparse_heads
            ],
        }

    def warmup_cache(self, x: Tensor) -> dict:
        """Build the cache from a full-prompt input ``x``.

        ``x`` must be the SAME tensor fed to ``forward()`` during prompt
        processing; the pathways see ``self.norm(x)``, so that is what the
        conv buffer and the per-head caches are built from.
        """
        h = self.norm(x)
        return {
            "conv_buffer": self.local.warmup_buffer(h),
            "sparse_caches": [head.warmup_cache(h) for head in self.sparse_heads],
        }

    def forward_incremental(self, x_step: Tensor, cache: dict) -> tuple[Tensor, dict]:
        """One-token step through the block.

        Returns ``(x_step + mixed, new_cache)``. The residual is already
        added, so the caller does not re-add it. ``cache`` is not mutated;
        a new dict is returned.
        """
        h = self.norm(x_step)
        r = torch.softmax(self.router(h), dim=-1)

        out_local, new_conv_buf = self.local.forward_incremental(
            h, cache["conv_buffer"]
        )

        head_outs = []
        new_sparse_caches = []
        for head, head_cache in zip(self.sparse_heads, cache["sparse_caches"]):
            y_head, head_cache_new = head.forward_incremental(h, head_cache)
            head_outs.append(y_head)
            new_sparse_caches.append(head_cache_new)
        out_sparse = self._merge_heads(head_outs)

        # The FFN is position-wise, so it needs no cache.
        out_dense = self.dense(h)

        mixed = self._mix(r, out_local, out_sparse, out_dense)
        new_cache = {
            "conv_buffer": new_conv_buf,
            "sparse_caches": new_sparse_caches,
        }
        return x_step + mixed, new_cache
|
|
|
|
class TernaryEmbeddingLite(nn.Module):
    """Token id -> embedding vector, optionally from a ternarized table.

    With ``quantize=True`` the whole weight matrix is passed through
    ``tilelli.core.ternary.ternarize`` on every forward call before indexing
    (presumably mapping it to {-1, 0, +1} times a scale; see that helper).
    With ``quantize=False`` the raw float weights are indexed directly.
    """

    def __init__(self, vocab_size: int, d_model: int, quantize: bool = True) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.quantize = quantize
        # Gaussian init scaled by 1/sqrt(d_model).
        init_scale = 1.0 / d_model**0.5
        self.weight = nn.Parameter(torch.randn(vocab_size, d_model) * init_scale)

    def forward(self, ids: Tensor) -> Tensor:
        """Gather rows of the (possibly ternarized) table for ``ids``."""
        if not self.quantize:
            return self.weight[ids]
        # Imported lazily, exactly where it is needed.
        from tilelli.core.ternary import ternarize
        return ternarize(self.weight)[ids]
|
|
|
|
class TilelliLiteLM(nn.Module):
    """Byte-level LM built from a stack of TilelliLiteBlocks.

    The pipeline is: ternary token embedding plus learned positional
    embedding, then the stack of ``TilelliLiteBlock`` layers, then a final
    LayerNorm and a ternary unembedding to ``vocab_size`` logits.
    """

    def __init__(
        self,
        vocab_size: int = 256,
        d_model: int = 384,
        n_layers: int = 8,
        n_heads: int = 8,
        kernel_size: int = 5,
        top_k: int = 16,
        ffn_expand: int = 4,
        max_seq_len: int = 2048,
        quantize: bool = True,
        load_balance_weight: float = 0.01,
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.quantize = quantize

        self.embed = TernaryEmbeddingLite(vocab_size, d_model, quantize=quantize)
        # Learned absolute positions; sequences longer than max_seq_len are
        # rejected by _embed.
        self.pos_embed = nn.Embedding(max_seq_len, d_model)
        nn.init.normal_(self.pos_embed.weight, std=0.02)

        self.blocks = nn.ModuleList([
            TilelliLiteBlock(
                d_model=d_model, n_heads=n_heads, kernel_size=kernel_size,
                top_k=top_k, ffn_expand=ffn_expand, quantize=quantize,
                load_balance_weight=load_balance_weight,
            )
            for _ in range(n_layers)
        ])

        self.final_norm = nn.LayerNorm(d_model)
        self.unembed = TernaryLinear(d_model, vocab_size, quantize=quantize)

    def _embed(self, ids: Tensor, offset: int = 0) -> Tensor:
        """Token + positional embedding for ``ids`` at absolute positions
        ``offset .. offset + ids.size(1) - 1``.

        Raises:
            ValueError: if the last position would fall outside the
                positional table (i.e. beyond ``max_seq_len``).
        """
        end = offset + ids.size(1)
        if end > self.max_seq_len:
            raise ValueError(f"sequence length {end} > max_seq_len {self.max_seq_len}")
        pos = torch.arange(offset, end, device=ids.device)
        return self.embed(ids) + self.pos_embed(pos)

    def forward(self, ids: Tensor) -> Tensor:
        """Map token ids (B, L) to next-token logits over ``vocab_size``."""
        x = self._embed(ids)
        for blk in self.blocks:
            x = blk(x)
        return self.unembed(self.final_norm(x))

    def loss(self, ids: Tensor, targets: Tensor | None = None) -> Tensor:
        """Autoregressive next-token loss + load-balance aux.

        Compatible with both the (ids,) "shift internally" convention and the
        (ids, targets) "caller-supplied targets" convention. If targets is None
        we shift ids ourselves; otherwise we trust the caller (train.py-style).
        """
        if targets is None:
            if ids.size(1) < 2:
                raise ValueError("loss needs sequence length >= 2")
            inp = ids[:, :-1]
            tgt = ids[:, 1:]
        else:
            inp, tgt = ids, targets
        logits = self(inp)
        ce = torch.nn.functional.cross_entropy(
            logits.reshape(-1, self.vocab_size),
            tgt.reshape(-1),
        )
        # Each block's aux_loss was refreshed by the forward pass just above.
        aux = sum(blk.aux_loss for blk in self.blocks)
        return ce + aux

    @torch.no_grad()
    def router_entropies(self, ids: Tensor) -> Tensor:
        """Per-layer router entropy, shape (n_layers, B, L).

        Raises ValueError if ids are longer than ``max_seq_len``.
        """
        x = self._embed(ids)
        ents = []
        for blk in self.blocks:
            ents.append(blk.router_entropy(x))
            x = blk(x)
        return torch.stack(ents, dim=0)

    @torch.no_grad()
    def warmup_caches(self, ids: Tensor) -> tuple[Tensor, list[dict]]:
        """Run the full prompt forward and build the per-layer decode caches.

        Returns ``(x, caches)``. ``x`` is the hidden state after the last
        block for ALL prompt positions, NOT yet passed through
        ``final_norm``. Callers take ``x[:, -1:, :]`` to get the first
        next-token logits. ``caches`` has one dict per block.
        """
        x = self._embed(ids)
        caches = []
        for blk in self.blocks:
            # The cache must be built from the block's own input.
            caches.append(blk.warmup_cache(x))
            x = blk(x)
        return x, caches

    @torch.no_grad()
    def step_with_cache(self, next_id: Tensor, pos_index: int,
                        caches: list[dict]) -> tuple[Tensor, list[dict]]:
        """Forward ONE new token (B, 1) at absolute position ``pos_index``.

        Returns ``(logits, new_caches)``; the input cache list is not
        modified in place. Raises ValueError if ``pos_index`` is outside the
        positional table.
        """
        x = self._embed(next_id, offset=pos_index)
        new_caches = []
        for blk, cache in zip(self.blocks, caches):
            x, cache_new = blk.forward_incremental(x, cache)
            new_caches.append(cache_new)
        return self.unembed(self.final_norm(x)), new_caches

    @torch.no_grad()
    def generate_with_cache(
        self,
        ids: Tensor,
        n_new_tokens: int,
        stop_ids: tuple[int, ...] = (10, 0),
        return_logits: bool = False,
    ) -> tuple[Tensor, list[int], list[float]]:
        """Greedily generate up to ``n_new_tokens`` tokens using the caches.

        Args:
            ids: prompt of shape (1, L) with L >= 1 (single sequence only).
            n_new_tokens: maximum number of tokens to generate.
            stop_ids: generation stops right after emitting any of these ids
                (the stop id is included in the output).
            return_logits: currently ignored; accepted for call-site
                compatibility.

        Returns:
            ``(full_ids, generated_id_list, confidence_per_step)``, where
            confidence is the softmax probability of each chosen token.
            Generation also stops early once ``max_seq_len`` positions are
            used.

        Raises:
            ValueError: if ``ids`` is not a non-empty (1, L) tensor.

        For non-greedy sampling, callers should use step_with_cache directly.
        The module's train/eval mode is restored on exit.
        """
        if ids.dim() != 2 or ids.size(0) != 1:
            raise ValueError(
                f"generate_with_cache expects ids of shape (1, L); got {tuple(ids.shape)}"
            )
        if ids.size(1) == 0:
            raise ValueError("generate_with_cache needs a non-empty prompt")

        was_training = self.training
        self.eval()
        try:
            h_all, caches = self.warmup_caches(ids)
            logits = self.unembed(self.final_norm(h_all[:, -1:, :]))
            cur_pos = ids.size(1)
            full = ids
            generated: list[int] = []
            confs: list[float] = []
            for step in range(n_new_tokens):
                probs = torch.softmax(logits[:, -1, :], dim=-1)
                next_id = probs.argmax(dim=-1, keepdim=True)
                nid_int = int(next_id)
                confs.append(float(probs.max()))
                generated.append(nid_int)
                full = torch.cat([full, next_id], dim=1)
                if nid_int in stop_ids:
                    break
                # Last requested token emitted: skip the forward step whose
                # logits would never be read.
                if step == n_new_tokens - 1:
                    break
                # next_id would sit at position cur_pos; stop if that
                # position is beyond the positional table.
                if cur_pos + 1 > self.max_seq_len:
                    break
                logits, caches = self.step_with_cache(next_id, cur_pos, caches)
                cur_pos += 1
            return full, generated, confs
        finally:
            if was_training:
                self.train()
|
|