File size: 6,846 Bytes
"""tilelli.core.sparse_attention — the Sparse pathway of Tilelli.
From ARCHITECTURE.md:
Sparse path: top-k = 8 selective attention. Precise lookup only. O(n·k).
Classic scaled dot-product attention is O(L²) because every query attends
to every key. Our claim is that most tokens do *not* need dense
lookup — the Local conv and the State SSM already handle adjacency and
long-range carry, leaving the Sparse path for the rare precise lookups
("fetch the variable named `x` defined 40 tokens ago"). For those cases,
a single query only needs to find its top few matches.
Day-0 design:
- Q, K, V projections are `TernaryLinear`. This keeps the thesis
intact: every learned matmul in the block is ternary.
- Attention is single-head at first. Multi-head is an easy addition
once the single-head path is tested and trained.
- Causal mask + top-k: per query row, keep the k highest-scoring
positions at or before the query (the causal mask leaves the diagonal
visible), set everything else to -inf, and softmax over the k survivors.
- Because we only softmax over k values per row, the output is
trivially the weighted sum of k V-rows. That's the O(L·k) claim.
Two subtleties:
- At position t < k, fewer than k past positions exist. The top-k
over a row containing (t+1) real scores and (L - t - 1) -infs just
returns those (t+1) reals in the first slots and -infs in the rest;
softmax happily turns the -infs into zero. Nothing to special-case.
- scaled_dot_product uses sqrt(d_head) as the temperature. Keep it.
"""
from __future__ import annotations
import math
import torch
from torch import Tensor, nn
from tilelli.core.ternary_linear import TernaryLinear
class SparseCausalAttention(nn.Module):
    """Single-head causal top-k attention with ternary Q/K/V projections.

    Every query scores all positions up to and including its own, keeps
    only its ``top_k`` highest scores, softmaxes over those survivors and
    returns the matching weighted sum of value rows.

    Parameters
    ----------
    d_model : int
        Input and output channel count.
    d_head : int
        Query/key dimensionality. V keeps d_model so the output width
        matches the input width without an extra projection.
    top_k : int
        How many positions (the current one included) each query attends
        to. Defaults to 8 per the architecture spec.
    quantize : bool
        Passed through unchanged to each ``TernaryLinear`` projection.

    Raises
    ------
    ValueError
        If ``d_head`` is not positive or ``top_k`` is negative.
    """

    def __init__(
        self,
        d_model: int,
        d_head: int = 32,
        top_k: int = 8,
        quantize: bool = True,
    ) -> None:
        # Fail at construction time with a clear message instead of deep
        # inside forward(): d_head feeds 1/sqrt(d_head), top_k feeds topk().
        if d_head < 1:
            raise ValueError(f"d_head must be >= 1, got {d_head}")
        if top_k < 0:
            raise ValueError(f"top_k must be >= 0, got {top_k}")
        super().__init__()
        self.d_model = d_model
        self.d_head = d_head
        self.top_k = top_k
        self.Wq = TernaryLinear(d_model, d_head, quantize=quantize)
        self.Wk = TernaryLinear(d_model, d_head, quantize=quantize)
        self.Wv = TernaryLinear(d_model, d_model, quantize=quantize)

    @staticmethod
    def _topk_softmax(scores: Tensor, top_k: int) -> Tensor:
        """Turn raw scores into weights that keep only the best ``top_k`` per row.

        Shared by the full-sequence and the incremental path so both apply
        exactly the same selection rule. ``scores`` may already contain
        -inf entries (masked positions); the result has the same shape.
        """
        k_eff = min(top_k, scores.size(-1))
        topk_vals, topk_idx = scores.topk(k_eff, dim=-1)
        # Dense layout again: -inf everywhere except the selected slots.
        # If topk had to pick -inf entries (fewer than k real scores in the
        # row) they are written back as -inf, so nothing needs special-casing.
        sparse_scores = torch.full_like(scores, float("-inf"))
        sparse_scores.scatter_(-1, topk_idx, topk_vals)
        attn = torch.softmax(sparse_scores, dim=-1)
        # A row without any finite entry softmaxes to NaN; map it to zeros.
        return torch.nan_to_num(attn, nan=0.0)

    def forward(self, x: Tensor) -> Tensor:
        """Run causal top-k attention over a full sequence.

        Parameters
        ----------
        x : Tensor
            Input of shape (B, L, d_model).

        Returns
        -------
        Tensor
            Output of shape (B, L, d_model).

        Raises
        ------
        ValueError
            If ``x`` is not 3-D or its last dimension is not ``d_model``.
        """
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, D), got shape {tuple(x.shape)}")
        _, L, D = x.shape
        if D != self.d_model:
            raise ValueError(f"d_model mismatch: module has {self.d_model}, input has {D}")

        q = self.Wq(x)  # (B, L, d_head)
        k = self.Wk(x)  # (B, L, d_head)
        v = self.Wv(x)  # (B, L, D)

        # scores: (B, L_q, L_k), scaled by 1/sqrt(d_head)
        scale = 1.0 / math.sqrt(self.d_head)
        scores = (q @ k.transpose(-1, -2)) * scale

        # Causal mask: key index j > query index i is forbidden. The
        # diagonal stays visible, so every query can at least see itself.
        causal = torch.ones(L, L, dtype=torch.bool, device=x.device).triu(1)
        scores = scores.masked_fill(causal, float("-inf"))

        attn = self._topk_softmax(scores, self.top_k)
        return attn @ v  # (B, L, D)

    # ── Incremental-decode helpers (KV cache) ───────────────────────────── #
    # Cache layout: a dict {"K": (B, L_past, d_head), "V": (B, L_past, D)}.
    # On a 1-token step we project Q/K/V for the single new position,
    # APPEND K/V to the cache, then attend the new Q over the (now-extended)
    # K/V — applying the same top-k + softmax rule as the full-sequence
    # forward, so the (B, 1, D) output corresponds to what a full forward
    # would produce for that final position.

    def empty_cache(self, batch_size: int, device, dtype) -> dict[str, Tensor]:
        """Return a cache holding zero past positions for ``batch_size`` sequences."""
        return {
            "K": torch.empty(batch_size, 0, self.d_head, device=device, dtype=dtype),
            "V": torch.empty(batch_size, 0, self.d_model, device=device, dtype=dtype),
        }

    def warmup_cache(self, x: Tensor) -> dict[str, Tensor]:
        """Compute K, V for the full prompt and stash them as the cache."""
        return {
            "K": self.Wk(x).contiguous(),
            "V": self.Wv(x).contiguous(),
        }

    def forward_incremental(self, x_step: Tensor, cache: dict) -> tuple[Tensor, dict]:
        """One-token decode step.

        Parameters
        ----------
        x_step : Tensor
            The new position, shape (B, 1, d_model).
        cache : dict
            ``{"K": ..., "V": ...}`` as produced by ``empty_cache``,
            ``warmup_cache`` or a previous call. It is not modified.

        Returns
        -------
        tuple[Tensor, dict]
            ``(y_step, new_cache)`` where ``y_step`` is (B, 1, d_model) and
            ``new_cache`` is the cache extended by one position.

        Raises
        ------
        ValueError
            If ``x_step`` is not (B, 1, D) or D differs from ``d_model``.
        """
        if x_step.dim() != 3 or x_step.size(1) != 1:
            raise ValueError(f"forward_incremental expects (B, 1, D), got {tuple(x_step.shape)}")
        D = x_step.size(2)
        if D != self.d_model:
            # Same check and message as forward(), so both entry points agree.
            raise ValueError(f"d_model mismatch: module has {self.d_model}, input has {D}")

        q_new = self.Wq(x_step)  # (B, 1, d_head)
        k_new = self.Wk(x_step)  # (B, 1, d_head)
        v_new = self.Wv(x_step)  # (B, 1, D)

        # Append the new position; torch.cat builds new tensors, so the
        # caller's cache entries stay untouched.
        K = torch.cat([cache["K"], k_new], dim=1)  # (B, L+1, d_head)
        V = torch.cat([cache["V"], v_new], dim=1)  # (B, L+1, D)

        scale = 1.0 / math.sqrt(self.d_head)
        scores = (q_new @ K.transpose(-1, -2)) * scale  # (B, 1, L+1)

        # No causal mask needed: the new query may attend to itself and to
        # everything already in the cache, which is all of K.
        attn = self._topk_softmax(scores, self.top_k)
        y_step = attn @ V  # (B, 1, D)
        return y_step, {"K": K, "V": V}
|