| """tilelli.core.sparse_attention β the Sparse pathway of Tilelli. |
| |
| From ARCHITECTURE.md: |
| Sparse path: top-k = 8 selective attention. Precise lookup only. O(nΒ·k). |
| |
| Classic scaled dot-product attention is O(LΒ²) because every query attends |
| to every key. Our claim is that most tokens do *not* need dense |
| lookup β the Local conv and the State SSM already handle adjacency and |
| long-range carry, leaving the Sparse path for the rare precise lookups |
| ("fetch the variable named `x` defined 40 tokens ago"). For those cases, |
| a single query only needs to find its top few matches. |
| |
| Day-0 design: |
| - Q, K, V projections are `TernaryLinear`. This keeps the thesis |
| intact: every learned matmul in the block is ternary. |
| - Attention is single-head at first. Multi-head is an easy addition |
| once the single-head path is tested and trained. |
| - Causal mask + top-k: per query row, keep the k highest-scoring |
| *past* positions, set the rest to -inf, softmax over the rest. |
| - Because we only softmax over k values per row, the output is |
| trivially the weighted sum of k V-rows. That's the O(LΒ·k) claim. |
| |
| Two subtleties: |
| - At position t < k, fewer than k past positions exist. The top-k |
| over a row containing (t+1) real scores and (L - t - 1) -infs just |
| returns those (t+1) reals in the first slots and -infs in the rest; |
| softmax happily turns the -infs into zero. Nothing to special-case. |
| - scaled_dot_product uses sqrt(d_head) as the temperature. Keep it. |
| """ |
| from __future__ import annotations |
|
|
| import math |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
| from tilelli.core.ternary_linear import TernaryLinear |
|
|
|
|
class SparseCausalAttention(nn.Module):
    """Single-head causal top-k attention with ternary Q/K/V projections.

    Each query position attends only to the ``top_k`` highest-scoring
    positions among itself and the positions before it; every other
    score is set to -inf before the softmax.

    Parameters
    ----------
    d_model : int
        Input and output channel count.
    d_head : int
        Query/key dimensionality. V keeps d_model so the output width
        matches the input width without an extra projection. Must be >= 1.
    top_k : int
        How many past positions each query attends to. Defaults to 8 per
        the architecture spec. Must be >= 0; ``top_k=0`` leaves every row
        fully masked, which this module turns into an all-zero output.
    quantize : bool
        Forwarded unchanged to each of the three ``TernaryLinear``
        projections (presumably toggles ternary weight quantization —
        see ``TernaryLinear``).

    Notes
    -----
    ``forward`` builds the full (B, L, L) score matrix before selecting
    the top k per row, so it costs O(L^2) time and memory; only the kept
    attention weights end up non-zero.

    Incremental decoding uses a cache dict with keys ``"K"`` of shape
    (B, T, d_head) and ``"V"`` of shape (B, T, d_model), created by
    ``empty_cache`` or ``warmup_cache`` and extended one position per
    ``forward_incremental`` call.
    """

    def __init__(
        self,
        d_model: int,
        d_head: int = 32,
        top_k: int = 8,
        quantize: bool = True,
    ) -> None:
        super().__init__()
        # Fail at construction time rather than deep inside forward():
        # d_head feeds sqrt() in the score scale, and a negative top_k
        # would make Tensor.topk raise on the first call.
        if d_head < 1:
            raise ValueError(f"d_head must be >= 1, got {d_head}")
        if top_k < 0:
            raise ValueError(f"top_k must be >= 0, got {top_k}")
        self.d_model = d_model
        self.d_head = d_head
        self.top_k = top_k
        self.Wq = TernaryLinear(d_model, d_head, quantize=quantize)
        self.Wk = TernaryLinear(d_model, d_head, quantize=quantize)
        self.Wv = TernaryLinear(d_model, d_model, quantize=quantize)

    def _check_width(self, D: int) -> None:
        """Raise ValueError if the input channel count D differs from d_model."""
        if D != self.d_model:
            raise ValueError(f"d_model mismatch: module has {self.d_model}, input has {D}")

    def _topk_attend(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None
    ) -> Tensor:
        """Scaled dot-product attention restricted to the top-k scores per query row.

        ``mask`` is an optional boolean tensor, broadcastable to the score
        matrix, marking positions that must never be attended to.
        """
        scale = 1.0 / math.sqrt(self.d_head)
        scores = (q @ k.transpose(-1, -2)) * scale
        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))

        # Keep only the k_eff largest scores per row; everything else
        # becomes -inf so softmax gives it exactly zero weight. When a row
        # has fewer than k_eff visible positions, topk also picks some
        # masked slots, but those carry -inf and stay at zero weight.
        k_eff = min(self.top_k, scores.size(-1))
        topk_vals, topk_idx = scores.topk(k_eff, dim=-1)
        sparse_scores = torch.full_like(scores, float("-inf"))
        sparse_scores.scatter_(-1, topk_idx, topk_vals)

        # A row with no finite score (e.g. when top_k == 0) softmaxes to
        # NaN; map that to zero weights instead of propagating NaN.
        attn = torch.softmax(sparse_scores, dim=-1)
        attn = torch.nan_to_num(attn, nan=0.0)
        return attn @ v

    def forward(self, x: Tensor) -> Tensor:
        """Full-sequence causal top-k attention.

        Parameters
        ----------
        x : Tensor
            Input of shape (B, L, d_model).

        Returns
        -------
        Tensor
            For each position, the softmax-weighted sum of the V rows of
            its top-k visible positions; shape (B, L, d_model).

        Raises
        ------
        ValueError
            If ``x`` is not 3-D or its last dimension is not d_model.
        """
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, D), got shape {tuple(x.shape)}")
        _, L, D = x.shape
        self._check_width(D)

        q = self.Wq(x)
        k = self.Wk(x)
        v = self.Wv(x)

        # Strictly-upper-triangular mask: position t may see 0..t (itself
        # included), never a future position.
        causal = torch.ones(L, L, dtype=torch.bool, device=x.device).triu(1)
        return self._topk_attend(q, k, v, causal)

    def empty_cache(self, batch_size: int, device, dtype) -> dict:
        """Return a zero-length K/V cache for ``batch_size`` sequences."""
        return {
            "K": torch.empty(batch_size, 0, self.d_head, device=device, dtype=dtype),
            "V": torch.empty(batch_size, 0, self.d_model, device=device, dtype=dtype),
        }

    def warmup_cache(self, x: Tensor) -> dict:
        """Compute K, V for the full prompt and stash them as the cache.

        ``x`` is the prompt, shape (B, L, d_model). Only the cache is
        returned; the prompt's attention outputs are not computed here.

        Raises
        ------
        ValueError
            If ``x`` is not 3-D or its last dimension is not d_model.
        """
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, D), got shape {tuple(x.shape)}")
        self._check_width(x.size(-1))
        return {
            "K": self.Wk(x).contiguous(),
            "V": self.Wv(x).contiguous(),
        }

    def forward_incremental(self, x_step: Tensor, cache: dict) -> tuple[Tensor, dict]:
        """One-token step using a K/V cache.

        Parameters
        ----------
        x_step : Tensor
            The new token's input, shape (B, 1, d_model).
        cache : dict
            ``{"K": (B, T, d_head), "V": (B, T, d_model)}`` from
            ``empty_cache``, ``warmup_cache`` or a previous call.

        Returns
        -------
        tuple[Tensor, dict]
            ``(y_step, new_cache)`` where ``y_step`` is (B, 1, D) and
            ``new_cache`` is the cache extended by one position. The input
            cache dict is not modified.

        Raises
        ------
        ValueError
            If ``x_step`` is not (B, 1, d_model).
        """
        if x_step.dim() != 3 or x_step.size(1) != 1:
            raise ValueError(f"forward_incremental expects (B, 1, D), got {tuple(x_step.shape)}")
        self._check_width(x_step.size(-1))

        q_new = self.Wq(x_step)
        # torch.cat allocates fresh tensors, so the caller's cache is untouched.
        K = torch.cat([cache["K"], self.Wk(x_step)], dim=1)
        V = torch.cat([cache["V"], self.Wv(x_step)], dim=1)

        # Every cached position is in the past (or is this step itself),
        # so the single new query row needs no causal mask.
        y_step = self._topk_attend(q_new, K, V)
        return y_step, {"K": K, "V": V}
|
|