"""tilelli.core.sparse_attention — the Sparse pathway of Tilelli. From ARCHITECTURE.md: Sparse path: top-k = 8 selective attention. Precise lookup only. O(n·k). Classic scaled dot-product attention is O(L²) because every query attends to every key. Our claim is that most tokens do *not* need dense lookup — the Local conv and the State SSM already handle adjacency and long-range carry, leaving the Sparse path for the rare precise lookups ("fetch the variable named `x` defined 40 tokens ago"). For those cases, a single query only needs to find its top few matches. Day-0 design: - Q, K, V projections are `TernaryLinear`. This keeps the thesis intact: every learned matmul in the block is ternary. - Attention is single-head at first. Multi-head is an easy addition once the single-head path is tested and trained. - Causal mask + top-k: per query row, keep the k highest-scoring *past* positions, set the rest to -inf, softmax over the rest. - Because we only softmax over k values per row, the output is trivially the weighted sum of k V-rows. That's the O(L·k) claim. Two subtleties: - At position t < k, fewer than k past positions exist. The top-k over a row containing (t+1) real scores and (L - t - 1) -infs just returns those (t+1) reals in the first slots and -infs in the rest; softmax happily turns the -infs into zero. Nothing to special-case. - scaled_dot_product uses sqrt(d_head) as the temperature. Keep it. """ from __future__ import annotations import math import torch from torch import Tensor, nn from tilelli.core.ternary_linear import TernaryLinear class SparseCausalAttention(nn.Module): """Single-head causal top-k attention with ternary Q/K/V projections. Parameters ---------- d_model : int Input and output channel count. d_head : int Query/key dimensionality. V keeps d_model so the output width matches the input width without an extra projection. top_k : int How many past positions each query attends to. Defaults to 8 per the architecture spec. """ def __init__( self, d_model: int, d_head: int = 32, top_k: int = 8, quantize: bool = True, ) -> None: super().__init__() self.d_model = d_model self.d_head = d_head self.top_k = top_k self.Wq = TernaryLinear(d_model, d_head, quantize=quantize) self.Wk = TernaryLinear(d_model, d_head, quantize=quantize) self.Wv = TernaryLinear(d_model, d_model, quantize=quantize) def forward(self, x: Tensor) -> Tensor: if x.dim() != 3: raise ValueError(f"expected (B, L, D), got shape {tuple(x.shape)}") B, L, D = x.shape if D != self.d_model: raise ValueError(f"d_model mismatch: module has {self.d_model}, input has {D}") q = self.Wq(x) # (B, L, d_head) k = self.Wk(x) # (B, L, d_head) v = self.Wv(x) # (B, L, D) # scores: (B, L_q, L_k) scale = 1.0 / math.sqrt(self.d_head) scores = (q @ k.transpose(-1, -2)) * scale # causal mask: j > i is forbidden causal = torch.ones(L, L, dtype=torch.bool, device=x.device).triu(1) scores = scores.masked_fill(causal, float("-inf")) # top-k per query row. `torch.topk` on a row containing -infs just # ranks the real scores first — nothing to special-case for t < k. k_eff = min(self.top_k, L) topk_vals, topk_idx = scores.topk(k_eff, dim=-1) # sparse score matrix: -inf everywhere except the top-k slots sparse_scores = torch.full_like(scores, float("-inf")) sparse_scores.scatter_(-1, topk_idx, topk_vals) # softmax over the sparse matrix. Rows that are entirely -inf (t=0 # with no past) can produce NaNs; clean them up to zero. attn = torch.softmax(sparse_scores, dim=-1) attn = torch.nan_to_num(attn, nan=0.0) return attn @ v # (B, L, D) # ── Incremental-decode helpers (KV cache) ─────────────────────────── # # Cache layout per head: a dict {"K": (B, L_past, d_head), "V": (B, L_past, D)} # On a 1-token step we project Q/K/V for the single new position, # APPEND K/V to the cache, then attend the new Q over the (now-extended) # K/V — applying the same top-k + softmax rules as the full-sequence # forward. Output is (B, 1, D), identical to what a full forward would # produce for that final position (bit-exact in float, modulo float # ordering, which doesn't affect argmax). def empty_cache(self, batch_size: int, device, dtype) -> dict: return { "K": torch.empty(batch_size, 0, self.d_head, device=device, dtype=dtype), "V": torch.empty(batch_size, 0, self.d_model, device=device, dtype=dtype), } def warmup_cache(self, x: Tensor) -> dict: """Compute K, V for the full prompt and stash them as the cache.""" return { "K": self.Wk(x).contiguous(), "V": self.Wv(x).contiguous(), } def forward_incremental(self, x_step: Tensor, cache: dict) -> tuple[Tensor, dict]: """One-token step. Returns (y_step, new_cache) where y_step is (B, 1, D) and new_cache is the cache extended by one position. """ if x_step.dim() != 3 or x_step.size(1) != 1: raise ValueError(f"forward_incremental expects (B, 1, D), got {tuple(x_step.shape)}") q_new = self.Wq(x_step) # (B, 1, d_head) k_new = self.Wk(x_step) # (B, 1, d_head) v_new = self.Wv(x_step) # (B, 1, D) # Append to cache K = torch.cat([cache["K"], k_new], dim=1) # (B, L+1, d_head) V = torch.cat([cache["V"], v_new], dim=1) # (B, L+1, D) # Single-row attention: query is q_new (B, 1, d_head), keys are K (B, L+1, d_head) scale = 1.0 / math.sqrt(self.d_head) scores = (q_new @ K.transpose(-1, -2)) * scale # (B, 1, L+1) # Causal: the new query CAN attend to itself + all past → no mask needed # (everything in K up to and including the new position is valid). # Top-k over the single row L_eff = scores.size(-1) k_eff = min(self.top_k, L_eff) topk_vals, topk_idx = scores.topk(k_eff, dim=-1) sparse_scores = torch.full_like(scores, float("-inf")) sparse_scores.scatter_(-1, topk_idx, topk_vals) attn = torch.softmax(sparse_scores, dim=-1) attn = torch.nan_to_num(attn, nan=0.0) y_step = attn @ V # (B, 1, D) return y_step, {"K": K, "V": V}