Tilelli-llm / src /tilelli /core /sparse_attention.py
TilelliLab's picture
Mirror small files (code, paper, results)
f86dc09 verified
Raw
History Blame Contribute Delete
6.85 kB
"""tilelli.core.sparse_attention — the Sparse pathway of Tilelli.
From ARCHITECTURE.md:
Sparse path: top-k = 8 selective attention. Precise lookup only. O(n·k).
Classic scaled dot-product attention is O(L²) because every query attends
to every key. Our claim is that most tokens do *not* need dense
lookup β€” the Local conv and the State SSM already handle adjacency and
long-range carry, leaving the Sparse path for the rare precise lookups
("fetch the variable named `x` defined 40 tokens ago"). For those cases,
a single query only needs to find its top few matches.
Day-0 design:
- Q, K, V projections are `TernaryLinear`. This keeps the thesis
intact: every learned matmul in the block is ternary.
- Attention is single-head at first. Multi-head is an easy addition
once the single-head path is tested and trained.
- Causal mask + top-k: per query row, keep the k highest-scoring
*past* positions, set the rest to -inf, softmax over the rest.
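- Because we only softmax over k values per row, the output is
trivially the weighted sum of k V-rows. That's the O(L·k) claim.
(The current reference implementation still materializes the dense
L×L score matrix before top-k selection, so its actual cost is O(L²);
a gather-based kernel is needed to realize the O(L·k) bound.)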
Two subtleties:
- At position t < k, fewer than k past positions exist. The top-k
over a row containing (t+1) real scores and (L - t - 1) -infs just
returns those (t+1) reals in the first slots and -infs in the rest;
softmax happily turns the -infs into zero. Nothing to special-case.
- scaled_dot_product uses sqrt(d_head) as the temperature. Keep it.
"""
from __future__ import annotations
import math
import torch
from torch import Tensor, nn
from tilelli.core.ternary_linear import TernaryLinear
class SparseCausalAttention(nn.Module):
    """Single-head causal top-k attention with ternary Q/K/V projections.

    Each query position ``t`` scores every key ``j <= t`` (including itself),
    keeps only the ``top_k`` highest scores, and softmaxes over those. Rows
    with fewer than ``top_k`` visible keys (``t + 1 < top_k``) simply attend
    to all of them.

    Note: this reference implementation materializes the dense ``(B, L, L)``
    score matrix before selecting the top-k entries, so its compute and
    memory cost is O(L²). The O(L·k) figure describes the attention pattern,
    not the cost of this code.

    Parameters
    ----------
    d_model : int
        Input and output channel count.
    d_head : int
        Query/key dimensionality. V keeps d_model so the output width
        matches the input width without an extra projection. Must be >= 1.
    top_k : int
        How many positions (the query's own position included) each query
        attends to. Defaults to 8 per the architecture spec. Must be >= 1.
    quantize : bool
        Forwarded unchanged to every ``TernaryLinear`` projection.

    Raises
    ------
    ValueError
        If ``d_head`` or ``top_k`` is smaller than 1.
    """

    def __init__(
        self,
        d_model: int,
        d_head: int = 32,
        top_k: int = 8,
        quantize: bool = True,
    ) -> None:
        super().__init__()
        # Reject configurations that would otherwise fail late (d_head=0 divides
        # by zero in the score scale, negative top_k breaks torch.topk) or
        # silently degenerate (top_k=0 selects nothing, so every row softmaxes
        # to NaN and is then zeroed).
        if d_head < 1:
            raise ValueError(f"d_head must be >= 1, got {d_head}")
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        self.d_model = d_model
        self.d_head = d_head
        self.top_k = top_k
        self.Wq = TernaryLinear(d_model, d_head, quantize=quantize)
        self.Wk = TernaryLinear(d_model, d_head, quantize=quantize)
        self.Wv = TernaryLinear(d_model, d_model, quantize=quantize)

    def _topk_softmax(self, scores: Tensor) -> Tensor:
        """Softmax each row of *scores* over its ``top_k`` largest entries only.

        Entries outside a row's top-k, and entries that are already ``-inf``
        (e.g. causally masked keys), receive weight zero. Shared by
        ``forward`` and ``forward_incremental`` so both paths apply the exact
        same selection rules.
        """
        # A row can be shorter than top_k (short sequences, early decode
        # steps); torch.topk requires k <= row length.
        k_eff = min(self.top_k, scores.size(-1))
        # topk on a row containing -infs ranks the finite scores first; if a
        # row has fewer than k_eff finite scores, the extra picks are -inf and
        # scattering them back below is harmless.
        topk_vals, topk_idx = scores.topk(k_eff, dim=-1)
        # Dense row that is -inf everywhere except the selected slots.
        sparse_scores = torch.full_like(scores, float("-inf"))
        sparse_scores.scatter_(-1, topk_idx, topk_vals)
        attn = torch.softmax(sparse_scores, dim=-1)
        # Defensive guard: causal rows always contain their own diagonal, so a
        # NaN row here should only arise from non-finite scores (e.g. inf/NaN
        # in the inputs). Such rows are zeroed rather than propagated.
        return torch.nan_to_num(attn, nan=0.0)

    def forward(self, x: Tensor) -> Tensor:
        """Run causal top-k attention over a full sequence.

        Parameters
        ----------
        x : Tensor
            Input of shape ``(B, L, d_model)``.

        Returns
        -------
        Tensor
            Output of shape ``(B, L, d_model)``. Row ``t`` is a softmax-weighted
            sum of at most ``top_k`` value rows taken from positions ``<= t``.

        Raises
        ------
        ValueError
            If ``x`` is not 3-D or its last dimension differs from ``d_model``.
        """
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, D), got shape {tuple(x.shape)}")
        _, L, D = x.shape
        if D != self.d_model:
            raise ValueError(f"d_model mismatch: module has {self.d_model}, input has {D}")
        q = self.Wq(x)  # (B, L, d_head)
        k = self.Wk(x)  # (B, L, d_head)
        v = self.Wv(x)  # (B, L, D)
        # scores: (B, L_q, L_k), temperature sqrt(d_head)
        scale = 1.0 / math.sqrt(self.d_head)
        scores = (q @ k.transpose(-1, -2)) * scale
        # Causal mask: key j > query i is forbidden. The diagonal stays
        # visible, so every row keeps at least one finite score.
        causal = torch.ones(L, L, dtype=torch.bool, device=x.device).triu(1)
        scores = scores.masked_fill(causal, float("-inf"))
        return self._topk_softmax(scores) @ v  # (B, L, D)

    # ── Incremental-decode helpers (KV cache) ─────────────────────────── #
    # Cache layout: a dict {"K": (B, L_past, d_head), "V": (B, L_past, D)}.
    # On a 1-token step we project Q/K/V for the single new position,
    # APPEND K/V to the cache, then attend the new Q over the (now-extended)
    # K/V with the same top-k + softmax rules as the full-sequence forward.
    # The (B, 1, D) output matches the last row of a full forward over the
    # same tokens, up to floating-point reordering and tie-breaking in topk.

    def empty_cache(
        self,
        batch_size: int,
        device: torch.device | str | None,
        dtype: torch.dtype | None,
    ) -> dict:
        """Return a zero-length KV cache to seed ``forward_incremental``."""
        return {
            "K": torch.empty(batch_size, 0, self.d_head, device=device, dtype=dtype),
            "V": torch.empty(batch_size, 0, self.d_model, device=device, dtype=dtype),
        }

    def warmup_cache(self, x: Tensor) -> dict:
        """Compute K, V for the full prompt ``x`` and return them as a cache.

        Only the cache is produced; run ``forward`` separately if the prompt's
        attention outputs are also needed.
        """
        return {
            "K": self.Wk(x).contiguous(),
            "V": self.Wv(x).contiguous(),
        }

    def forward_incremental(self, x_step: Tensor, cache: dict) -> tuple[Tensor, dict]:
        """Run one decode step against a KV cache.

        Parameters
        ----------
        x_step : Tensor
            The new token's input, shape ``(B, 1, d_model)``.
        cache : dict
            ``{"K": (B, L_past, d_head), "V": (B, L_past, d_model)}`` as built by
            ``empty_cache``, ``warmup_cache`` or a previous call. It is not
            mutated.

        Returns
        -------
        tuple[Tensor, dict]
            ``(y_step, new_cache)`` where ``y_step`` is ``(B, 1, d_model)`` and
            ``new_cache`` is ``cache`` extended by the new position.

        Raises
        ------
        ValueError
            If ``x_step`` is not ``(B, 1, D)`` or ``D`` differs from ``d_model``.
        """
        if x_step.dim() != 3 or x_step.size(1) != 1:
            raise ValueError(f"forward_incremental expects (B, 1, D), got {tuple(x_step.shape)}")
        # Same width check as forward(), so both entry points fail identically.
        D = x_step.size(-1)
        if D != self.d_model:
            raise ValueError(f"d_model mismatch: module has {self.d_model}, input has {D}")
        q_new = self.Wq(x_step)  # (B, 1, d_head)
        k_new = self.Wk(x_step)  # (B, 1, d_head)
        v_new = self.Wv(x_step)  # (B, 1, D)
        # torch.cat allocates new tensors, so the caller's cache is untouched.
        K = torch.cat([cache["K"], k_new], dim=1)  # (B, L_past + 1, d_head)
        V = torch.cat([cache["V"], v_new], dim=1)  # (B, L_past + 1, D)
        scale = 1.0 / math.sqrt(self.d_head)
        scores = (q_new @ K.transpose(-1, -2)) * scale  # (B, 1, L_past + 1)
        # No causal mask needed: the newest query may attend to itself and to
        # every cached position.
        y_step = self._topk_softmax(scores) @ V  # (B, 1, D)
        return y_step, {"K": K, "V": V}