"""tilelli.core.sparse_attention: the Sparse pathway of Tilelli.

From ARCHITECTURE.md:
    Sparse path: top-k = 8 selective attention. Precise lookup only. O(n·k).

Classic scaled dot-product attention is O(L²) because every query attends
to every key. Our claim is that most tokens do *not* need dense
lookup: the Local conv and the State SSM already handle adjacency and
long-range carry, leaving the Sparse path for the rare precise lookups
("fetch the variable named `x` defined 40 tokens ago"). For those cases,
a single query only needs to find its top few matches.

Day-0 design:
  - Q, K, V projections are `TernaryLinear`. This keeps the thesis
    intact: every learned matmul in the block is ternary.
  - Attention is single-head at first. Multi-head is an easy addition
    once the single-head path is tested and trained.
  - Causal mask + top-k: per query row, keep the k highest-scoring
    visible positions (the query itself and its past), set the rest to
    -inf, and softmax over the kept entries.
  - Because we only softmax over k values per row, the output is
    trivially the weighted sum of k V-rows. That's the O(L·k) claim.

Two subtleties:
  - At position t with t + 1 < k, fewer than k visible positions exist.
    The top-k over a row containing (t+1) real scores and (L - t - 1)
    -infs just returns those (t+1) reals in the first slots and -infs in
    the rest; softmax turns the -infs into zero. Nothing to special-case.
  - scaled_dot_product uses sqrt(d_head) as the temperature. Keep it.
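
A minimal pure-Python sketch of the per-row rule described above (causal
mask, then top-k, then softmax). It is an illustration only and not this
module's torch path: the helper name `topk_causal_weights` is
hypothetical, and it assumes k >= 1 so that at least one score is finite.

```python
import math

NEG_INF = float("-inf")


def topk_causal_weights(row, t, k):
    # row: scaled scores of query t against every position 0..L-1.
    # Causal mask: positions j > t are hidden (the query may see itself).
    masked = [s if j <= t else NEG_INF for j, s in enumerate(row)]
    # Keep the k best entries; near the start some of them are -inf.
    order = sorted(range(len(masked)), key=lambda j: masked[j], reverse=True)
    sparse = [NEG_INF] * len(masked)
    for j in order[:k]:
        sparse[j] = masked[j]
    # Softmax: exp(-inf) is 0.0, so dropped slots get exactly zero weight.
    peak = max(sparse)
    exps = [math.exp(s - peak) for s in sparse]
    total = sum(exps)
    return [e / total for e in exps]


row = [0.5, 2.0, 1.0, 3.0, 0.0]
early = topk_causal_weights(row, t=1, k=3)  # only 2 visible positions
late = topk_causal_weights(row, t=4, k=2)   # 5 visible positions, 2 kept
```

In `early` the weight is split between positions 0 and 1 only, even
though k = 3; in `late` only the two best-scoring positions are nonzero.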
"""
from __future__ import annotations

import math

import torch
from torch import Tensor, nn

from tilelli.core.ternary_linear import TernaryLinear


class SparseCausalAttention(nn.Module):
    """Single-head causal top-k attention with ternary Q/K/V projections.

    Parameters
    ----------
    d_model : int
        Input and output channel count.
    d_head : int
        Query/key dimensionality. V keeps d_model so the output width
        matches the input width without an extra projection.
    top_k : int
        How many past positions each query attends to. Defaults to 8 per
        the architecture spec.
    """

    def __init__(
        self,
        d_model: int,
        d_head: int = 32,
        top_k: int = 8,
        quantize: bool = True,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.d_head = d_head
        self.top_k = top_k
        self.Wq = TernaryLinear(d_model, d_head, quantize=quantize)
        self.Wk = TernaryLinear(d_model, d_head, quantize=quantize)
        self.Wv = TernaryLinear(d_model, d_model, quantize=quantize)

    def forward(self, x: Tensor) -> Tensor:
        """Full-sequence forward pass: (B, L, D) in, (B, L, D) out."""
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, D), got shape {tuple(x.shape)}")
        B, L, D = x.shape
        if D != self.d_model:
            raise ValueError(f"d_model mismatch: module has {self.d_model}, input has {D}")

        q = self.Wq(x)                          # (B, L, d_head)
        k = self.Wk(x)                          # (B, L, d_head)
        v = self.Wv(x)                          # (B, L, D)

        # scores: (B, L_q, L_k)
        scale = 1.0 / math.sqrt(self.d_head)
        scores = (q @ k.transpose(-1, -2)) * scale

        # causal mask: j > i is forbidden
        causal = torch.ones(L, L, dtype=torch.bool, device=x.device).triu(1)
        scores = scores.masked_fill(causal, float("-inf"))

        # top-k per query row. `torch.topk` on a row containing -infs just
        # ranks the real scores first, so early rows need no special-casing.
        k_eff = min(self.top_k, L)
        topk_vals, topk_idx = scores.topk(k_eff, dim=-1)

        # sparse score matrix: -inf everywhere except the top-k slots.
        # Note: this reference path still materializes dense (B, L, L) scores.
        sparse_scores = torch.full_like(scores, float("-inf"))
        sparse_scores.scatter_(-1, topk_idx, topk_vals)

        # softmax over the sparse matrix. The causal mask keeps the diagonal,
        # so every row has at least one finite score when top_k >= 1; the
        # nan_to_num below is a defensive guard for all -inf rows (top_k == 0).
        attn = torch.softmax(sparse_scores, dim=-1)
        attn = torch.nan_to_num(attn, nan=0.0)

        return attn @ v                         # (B, L, D)

    # ── Incremental-decode helpers (KV cache) ─────────────────────────── #
    # Cache layout per head: a dict {"K": (B, L_past, d_head), "V": (B, L_past, D)}
    # On a 1-token step we project Q/K/V for the single new position,
    # APPEND K/V to the cache, then attend the new Q over the (now-extended)
    # K/V, applying the same top-k + softmax rules as the full-sequence
    # forward. Output is (B, 1, D) and matches what a full forward would
    # produce for that final position, up to floating-point reduction
    # order, which should not change a downstream argmax.

    def empty_cache(self, batch_size: int, device, dtype) -> dict:
        """Return a cache holding zero cached positions (L_past = 0)."""
        return {
            "K": torch.empty(batch_size, 0, self.d_head, device=device, dtype=dtype),
            "V": torch.empty(batch_size, 0, self.d_model, device=device, dtype=dtype),
        }

    def warmup_cache(self, x: Tensor) -> dict:
        """Compute K, V for the full prompt and stash them as the cache."""
        return {
            "K": self.Wk(x).contiguous(),
            "V": self.Wv(x).contiguous(),
        }

    def forward_incremental(self, x_step: Tensor, cache: dict) -> tuple[Tensor, dict]:
        """One-token step. Returns (y_step, new_cache) where y_step is (B, 1, D)
        and new_cache is the cache extended by one position.
        """
        if x_step.dim() != 3 or x_step.size(1) != 1:
            raise ValueError(f"forward_incremental expects (B, 1, D), got {tuple(x_step.shape)}")

        q_new = self.Wq(x_step)          # (B, 1, d_head)
        k_new = self.Wk(x_step)          # (B, 1, d_head)
        v_new = self.Wv(x_step)          # (B, 1, D)

        # Append to cache
        K = torch.cat([cache["K"], k_new], dim=1)     # (B, L+1, d_head)
        V = torch.cat([cache["V"], v_new], dim=1)     # (B, L+1, D)

        # Single-row attention: query is q_new (B, 1, d_head), keys are K (B, L+1, d_head)
        scale = 1.0 / math.sqrt(self.d_head)
        scores = (q_new @ K.transpose(-1, -2)) * scale   # (B, 1, L+1)
        # Causal: the new query CAN attend to itself + all past → no mask needed
        # (everything in K up to and including the new position is valid).

        # Top-k over the single row
        L_eff = scores.size(-1)
        k_eff = min(self.top_k, L_eff)
        topk_vals, topk_idx = scores.topk(k_eff, dim=-1)
        sparse_scores = torch.full_like(scores, float("-inf"))
        sparse_scores.scatter_(-1, topk_idx, topk_vals)

        attn = torch.softmax(sparse_scores, dim=-1)
        attn = torch.nan_to_num(attn, nan=0.0)
        y_step = attn @ V                                # (B, 1, D)

        return y_step, {"K": K, "V": V}