"""tilelli.core.tilelli_lite — clean 3-pathway block designed to beat a
same-size vanilla baseline.

A prior 6-pathway variant of this architecture (~10.6M params) tied vanilla on
TinyStories byte-LM (mean 0.5737 vs vanilla 0.5707). Internal audit attributed
the tie to fragmentation: parameter budget was spent on pathways the byte-LM
data did not reward (an indexed-knowledge slot, a wide convolution, and a
non-selective state-space path).

Tilelli Lite cuts those underperforming slots and keeps the lessons that DO
show up at 10M scale: heterogeneous pathways with a learned router, and a
ternary-capable forward pass for inference. This module is a sibling to the
larger 5/6-pathway block (kept intact for non-byte-LM workloads); it is not a
drop-in replacement.

3-pathway block:
  - Local conv k=5 (n-grams; strictly more efficient than attention here)
  - Sparse causal attention with multiple heads (8 heads; d_head=48 with the
    LM defaults d_model=384, n_heads=8). Head outputs are averaged, then
    projected.
  - Dense FFN with expand=4 (matches vanilla's FFN ratio)

Other lessons folded in from the prior block's audit:
  - Learned positional embedding (recovers the position signal lost by the
    previous unembedding-only design)
  - Load-balance auxiliary loss properly wired through the router head
"""

from __future__ import annotations

import torch
from torch import Tensor, nn

from tilelli.core.sparse_attention import SparseCausalAttention
from tilelli.core.ternary_conv import TernaryCausalConv1d
from tilelli.core.ternary_linear import TernaryLinear

PATHWAY_NAMES_LITE = ("local", "sparse", "dense")


class TernaryFFN_Lite(nn.Module):
    """Two-layer GELU FFN built from ternary-capable linears.

    Hidden width is ``d_model * expand``; the default expand=4 matches
    vanilla's FFN ratio.
    """

    def __init__(self, d_model: int, expand: int = 4, quantize: bool = True) -> None:
        super().__init__()
        d_inner = d_model * expand
        self.up = TernaryLinear(d_model, d_inner, quantize=quantize)
        self.down = TernaryLinear(d_inner, d_model, quantize=quantize)

    def forward(self, x: Tensor) -> Tensor:
        return self.down(torch.nn.functional.gelu(self.up(x)))


class TilelliLiteBlock(nn.Module):
    """3-pathway block: Local conv + Sparse multi-head attn + Dense FFN.

    Pre-norm residual block: ``out = x + sum_i r_i * pathway_i(norm(x))``,
    where ``r`` is a per-token softmax over the 3 pathways produced by the
    router. All pathways always fire. A load-balance aux loss penalizes
    router collapse onto one pathway.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int = 8,
        kernel_size: int = 5,
        top_k: int = 16,
        ffn_expand: int = 4,
        quantize: bool = True,
        load_balance_weight: float = 0.01,
    ) -> None:
        super().__init__()
        # Validate before deriving d_head from the ratio.
        if d_model % n_heads != 0:
            raise ValueError(f"d_model {d_model} must be divisible by n_heads {n_heads}")
        self.d_model = d_model
        self.n_pathways = len(PATHWAY_NAMES_LITE)  # == 3
        self.load_balance_weight = load_balance_weight
        # Per-head inner dim, so n_heads * d_head == d_model (vanilla's
        # attention shape). Each head still returns d_model-wide outputs;
        # they are averaged in _merge_heads, not concatenated.
        d_head = d_model // n_heads

        # NOTE: submodule construction order is kept fixed. It determines both
        # random init consumption and state_dict key order.
        self.norm = nn.LayerNorm(d_model)
        self.local = TernaryCausalConv1d(d_model, kernel_size=kernel_size, quantize=quantize)
        # n_heads independent instances of the single-head sparse attention.
        self.sparse_heads = nn.ModuleList([
            SparseCausalAttention(d_model, d_head=d_head, top_k=top_k)
            for _ in range(n_heads)
        ])
        self.sparse_proj = TernaryLinear(d_model, d_model, quantize=quantize)
        self.dense = TernaryFFN_Lite(d_model, expand=ffn_expand, quantize=quantize)
        self.router = TernaryLinear(d_model, self.n_pathways, quantize=quantize)
        # Populated by forward(); read by the outer training loop.
        self._aux_loss = torch.tensor(0.0)
        self._router_entropy: Tensor | None = None

    # ── internal helpers ────────────────────────────────────────────────

    @staticmethod
    def _entropy(r: Tensor) -> Tensor:
        """Entropy over the last (pathway) axis; 1e-12 guards log(0)."""
        return -(r * (r + 1e-12).log()).sum(dim=-1)

    @staticmethod
    def _mix(r: Tensor, out_local: Tensor, out_sparse: Tensor, out_dense: Tensor) -> Tensor:
        """Router-weighted sum of the three pathway outputs (order = PATHWAY_NAMES_LITE)."""
        return (
            r[..., 0:1] * out_local
            + r[..., 1:2] * out_sparse
            + r[..., 2:3] * out_dense
        )

    def _merge_heads(self, head_outs: list[Tensor]) -> Tensor:
        """Average the per-head outputs (each d_model wide), then project.

        Averaging rather than concatenating keeps the width at d_model and
        gives a smoothed multi-head signal.
        """
        merged = torch.stack(head_outs, dim=0).mean(dim=0)
        return self.sparse_proj(merged)

    def _multi_head_sparse(self, h: Tensor) -> Tensor:
        """Run every sparse head over the full sequence and merge them."""
        return self._merge_heads([head(h) for head in self.sparse_heads])

    # ── full-sequence path ──────────────────────────────────────────────

    def forward(self, x: Tensor) -> Tensor:
        """Full-sequence forward; returns ``x + mixed``.

        Side effects: refreshes ``self._aux_loss`` (load-balance loss, scaled
        by ``load_balance_weight``) and ``self._router_entropy`` (per-token
        router entropy over the leading axes of ``x``) for the outer training
        loop to read.
        """
        h = self.norm(x)
        r = torch.softmax(self.router(h), dim=-1)  # (B, L, 3)
        out_local = self.local(h)
        out_sparse = self._multi_head_sparse(h)
        out_dense = self.dense(h)
        mixed = self._mix(r, out_local, out_sparse, out_dense)

        # Load-balance: per-pathway mean usage should approach 1/n_pathways.
        pathway_use = r.mean(dim=(0, 1))  # (3,)
        target = 1.0 / self.n_pathways
        self._aux_loss = ((pathway_use - target) ** 2).mean() * self.load_balance_weight

        # Cached for an outer metacognition aux loss (see
        # scripts/train_router_metacog.py). Nothing reads it at inference; cheap.
        self._router_entropy = self._entropy(r)
        return x + mixed

    @property
    def aux_loss(self) -> Tensor:
        """Load-balance loss from the most recent forward() call."""
        return self._aux_loss

    @torch.no_grad()
    def router_weights(self, x: Tensor) -> Tensor:
        """Per-token softmax router distribution over the 3 pathways."""
        h = self.norm(x)
        return torch.softmax(self.router(h), dim=-1)

    @torch.no_grad()
    def router_entropy(self, x: Tensor) -> Tensor:
        """Per-token entropy of the router distribution.

        Low → committed to one pathway (high confidence). High → uncertain mix.
        """
        return self._entropy(self.router_weights(x))

    # ── Incremental-decode helpers ──────────────────────────────────────
    #
    # A block "cache" is a dict:
    #   {"conv_buffer":   (B, k-1, D),
    #    "sparse_caches": [head_cache_dict for each head]}

    def empty_cache(self, batch_size: int, device, dtype) -> dict:
        """Fresh (empty-history) cache for this block."""
        return {
            "conv_buffer": self.local.empty_buffer(batch_size, device, dtype),
            "sparse_caches": [
                head.empty_cache(batch_size, device, dtype) for head in self.sparse_heads
            ],
        }

    def warmup_cache(self, x: Tensor) -> dict:
        """Build the cache from a full-prompt input x (B, L, D).

        ``x`` must be the SAME tensor that was fed to forward() during prompt
        processing. The pathways see ``self.norm(x)``, so that is what the
        per-pathway warmups receive.
        """
        h = self.norm(x)
        return {
            "conv_buffer": self.local.warmup_buffer(h),
            "sparse_caches": [head.warmup_cache(h) for head in self.sparse_heads],
        }

    def forward_incremental(self, x_step: Tensor, cache: dict) -> tuple[Tensor, dict]:
        """One-token step through the block.

        Returns ``(out_step, new_cache)``. ``out_step`` already includes the
        residual (``x_step + mixed``), so the caller must not re-add it. Unlike
        forward(), this does not update ``_aux_loss`` / ``_router_entropy``.
        """
        h = self.norm(x_step)  # (B, 1, D)
        r = torch.softmax(self.router(h), dim=-1)  # (B, 1, 3)

        # Local conv: prepend buffer, conv → 1 output, slide buffer.
        out_local, new_conv_buf = self.local.forward_incremental(h, cache["conv_buffer"])

        # Sparse multi-head: each head incrementally updates its own cache.
        head_outs = []
        new_sparse_caches = []
        for head, head_cache in zip(self.sparse_heads, cache["sparse_caches"]):
            y_head, head_cache_new = head.forward_incremental(h, head_cache)
            head_outs.append(y_head)
            new_sparse_caches.append(head_cache_new)
        out_sparse = self._merge_heads(head_outs)

        # Dense FFN: stateless.
        out_dense = self.dense(h)

        mixed = self._mix(r, out_local, out_sparse, out_dense)
        new_cache = {
            "conv_buffer": new_conv_buf,
            "sparse_caches": new_sparse_caches,
        }
        return x_step + mixed, new_cache


class TernaryEmbeddingLite(nn.Module):
    """Token id → ternary vector.

    Embedding weights are quantized to {-1,0,+1} with a per-tensor scale at
    forward time (via ``tilelli.core.ternary.ternarize``) when ``quantize``
    is set. Otherwise the raw weights are used.
    """

    def __init__(self, vocab_size: int, d_model: int, quantize: bool = True) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.quantize = quantize
        w = torch.randn(vocab_size, d_model) * (1.0 / d_model**0.5)
        self.weight = nn.Parameter(w)

    def forward(self, ids: Tensor) -> Tensor:
        if self.quantize:
            from tilelli.core.ternary import ternarize

            w_q = ternarize(self.weight)
        else:
            w_q = self.weight
        return w_q[ids]


class TilelliLiteLM(nn.Module):
    """Byte-level LM with TilelliLiteBlock stack + learned positional embed."""

    def __init__(
        self,
        vocab_size: int = 256,
        d_model: int = 384,
        n_layers: int = 8,
        n_heads: int = 8,
        kernel_size: int = 5,
        top_k: int = 16,
        ffn_expand: int = 4,
        max_seq_len: int = 2048,
        quantize: bool = True,
        load_balance_weight: float = 0.01,
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.quantize = quantize

        self.embed = TernaryEmbeddingLite(vocab_size, d_model, quantize=quantize)
        # Learned positional embedding — plain nn.Embedding (not ternarized)
        # even in ternary mode: position info must survive quantization.
        self.pos_embed = nn.Embedding(max_seq_len, d_model)
        nn.init.normal_(self.pos_embed.weight, std=0.02)

        self.blocks = nn.ModuleList([
            TilelliLiteBlock(
                d_model=d_model,
                n_heads=n_heads,
                kernel_size=kernel_size,
                top_k=top_k,
                ffn_expand=ffn_expand,
                quantize=quantize,
                load_balance_weight=load_balance_weight,
            )
            for _ in range(n_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)
        self.unembed = TernaryLinear(d_model, vocab_size, quantize=quantize)

    def _embed_with_positions(self, ids: Tensor) -> Tensor:
        """Token embedding plus positional embedding for positions 0..L-1.

        Raises ValueError if the sequence length L exceeds max_seq_len (the
        positional table has exactly max_seq_len rows).
        """
        L = ids.size(1)
        if L > self.max_seq_len:
            raise ValueError(f"sequence length {L} > max_seq_len {self.max_seq_len}")
        x = self.embed(ids)
        pos = torch.arange(L, device=ids.device)
        return x + self.pos_embed(pos)

    def forward(self, ids: Tensor) -> Tensor:
        """Logits for every position of ``ids`` (B, L) → (B, L, vocab_size)."""
        x = self._embed_with_positions(ids)
        for blk in self.blocks:
            x = blk(x)
        x = self.final_norm(x)
        return self.unembed(x)

    def loss(self, ids: Tensor, targets: Tensor | None = None) -> Tensor:
        """Autoregressive next-token loss + load-balance aux.

        Compatible with both the (ids,) "shift internally" convention and
        the (ids, targets) "caller-supplied targets" convention. If targets
        is None we shift ids ourselves; otherwise we trust the caller
        (train.py-style).
        """
        if targets is None:
            if ids.size(1) < 2:
                raise ValueError("loss needs sequence length >= 2")
            inp = ids[:, :-1]
            tgt = ids[:, 1:]
        else:
            inp, tgt = ids, targets
        logits = self(inp)
        ce = torch.nn.functional.cross_entropy(
            logits.reshape(-1, self.vocab_size),
            tgt.reshape(-1),
        )
        # Each block's aux loss was refreshed by the forward pass just above.
        aux = sum(blk.aux_loss for blk in self.blocks)
        return ce + aux

    @torch.no_grad()
    def router_entropies(self, ids: Tensor) -> Tensor:
        """Per-layer router entropy, shape (n_layers, B, L)."""
        x = self._embed_with_positions(ids)
        ents = []
        for blk in self.blocks:
            ents.append(blk.router_entropy(x))
            x = blk(x)
        return torch.stack(ents, dim=0)

    # ── Incremental generation with KV cache ────────────────────────────
    #
    # Big perf win: each step does one forward pass over a SINGLE new token,
    # using cached K/V for attention and a sliding buffer for the conv. The
    # dense FFN was the dominant cost without cache; with cache it runs once
    # per step, not L times.
    #
    # Correctness: numerically equivalent to the non-cached forward at the
    # final position, up to floating-point reordering noise (which in practice
    # does not change the argmax). Verified by tests/test_kv_cache_parity.py.

    @torch.no_grad()
    def warmup_caches(self, ids: Tensor) -> tuple[Tensor, list[dict]]:
        """Run the full prompt forward and build per-layer caches.

        Returns ``(x, caches)``. ``x`` is the (B, L, D) output of the last
        block for ALL prompt positions, BEFORE final_norm. The caller slices
        the last position and applies final_norm + unembed to get the first
        next-token logits.
        """
        x = self._embed_with_positions(ids)
        caches = []
        for blk in self.blocks:
            caches.append(blk.warmup_cache(x))
            x = blk(x)
        return x, caches

    @torch.no_grad()
    def step_with_cache(self, next_id: Tensor, pos_index: int,
                        caches: list[dict]) -> tuple[Tensor, list[dict]]:
        """Forward ONE new token (B, 1) at absolute position pos_index.

        Returns ``(logits, new_caches)`` with logits of shape (B, 1, V). The
        input ``caches`` list is not mutated; a new list is returned.

        Raises ValueError if pos_index is outside [0, max_seq_len).
        """
        # Guard explicitly: an out-of-range Embedding index is an opaque
        # IndexError on CPU and a device-side assert on CUDA.
        if not 0 <= pos_index < self.max_seq_len:
            raise ValueError(
                f"pos_index {pos_index} out of range for max_seq_len {self.max_seq_len}"
            )
        x = self.embed(next_id)  # (B, 1, D)
        pos = torch.tensor([pos_index], device=next_id.device)
        x = x + self.pos_embed(pos)
        new_caches = []
        for blk, cache in zip(self.blocks, caches):
            x, cache_new = blk.forward_incremental(x, cache)
            new_caches.append(cache_new)
        x = self.final_norm(x)
        return self.unembed(x), new_caches

    @torch.no_grad()
    def generate_with_cache(
        self,
        ids: Tensor,
        n_new_tokens: int,
        stop_ids: tuple[int, ...] = (10, 0),
        return_logits: bool = False,
    ) -> tuple[Tensor, list[int], list[float]]:
        """Greedy generate up to n_new_tokens using the KV cache.

        Args:
            ids: prompt of shape (1, L) with L >= 1. Only batch size 1 is
                supported, because generated ids/confidences are Python scalars.
            n_new_tokens: maximum number of tokens to generate.
            stop_ids: generation stops right after emitting any of these ids
                (defaults: byte 10, i.e. '\\n', and byte 0). The stop token is
                included in the outputs.
            return_logits: currently ignored; kept for API compatibility.

        Returns:
            (full_ids, generated_id_list, confidence_per_step), where the
            confidence is the softmax probability of the chosen token.
            Generation also stops once the next position would exceed
            max_seq_len.

        Raises:
            ValueError: if ids is not (1, L) or the prompt is empty.

        For non-greedy sampling, callers should use step_with_cache directly.
        The model's train/eval mode is restored on exit.
        """
        # Fail fast, before toggling mode or running the prompt.
        if ids.dim() != 2 or ids.size(0) != 1:
            raise ValueError(
                f"generate_with_cache supports ids of shape (1, L); got {tuple(ids.shape)}"
            )
        if ids.size(1) == 0:
            raise ValueError("generate_with_cache needs a non-empty prompt")

        was_training = self.training
        self.eval()
        try:
            # Warm caches on the prompt; get the final-position logits via
            # one extra final_norm + unembed of the last hidden state.
            h_all, caches = self.warmup_caches(ids)  # (B, L, D)
            h_last = self.final_norm(h_all[:, -1:, :])  # (B, 1, D)
            logits = self.unembed(h_last)  # (B, 1, V)
            cur_pos = ids.size(1)  # next position to fill
            full = ids
            generated: list[int] = []
            confs: list[float] = []
            for _ in range(n_new_tokens):
                probs = torch.softmax(logits[:, -1, :], dim=-1)
                next_id = probs.argmax(dim=-1, keepdim=True)  # (B, 1)
                nid_int = int(next_id)
                confs.append(float(probs.max()))
                generated.append(nid_int)
                full = torch.cat([full, next_id], dim=1)
                if nid_int in stop_ids:
                    break
                # Feeding next_id would need position cur_pos; stop if that
                # is past the positional table.
                if cur_pos + 1 > self.max_seq_len:
                    break
                logits, caches = self.step_with_cache(next_id, cur_pos, caches)
                cur_pos += 1
            return full, generated, confs
        finally:
            if was_training:
                self.train()