File size: 16,201 Bytes
f86dc09
"""tilelli.core.tilelli_lite: a clean 3-pathway block designed to beat a same-size vanilla baseline.

A prior 6-pathway variant of this architecture (~10.6M params) tied vanilla on
TinyStories byte-LM (mean 0.5737 vs vanilla 0.5707). Internal audit attributed
the tie to fragmentation: parameter budget was spent on pathways the byte-LM
data did not reward (an indexed-knowledge slot, a wide convolution, and a
non-selective state-space path).

Tilelli Lite cuts those underperforming slots and keeps the lessons that DO
show up at 10M scale: heterogeneous pathways with a learned router, and a
ternary-capable forward pass for inference. This module is a sibling to the
larger 5/6-pathway block (kept intact for non-byte-LM workloads); it is not
a drop-in replacement.

3-pathway block:
  - Local causal conv, k=5 (n-gram features; strictly more efficient than
    attention here)
  - Multi-head sparse causal attention (8 heads by default; d_head =
    d_model // n_heads, i.e. 48 at TilelliLiteLM's default d_model=384)
  - Dense FFN with expand=4 (matches vanilla's FFN ratio)

Other lessons folded in from the prior block's audit:
  - Learned positional embedding (recovers the position signal lost by
    the previous unembedding-only design)
  - Load-balance auxiliary loss properly wired through the router head
"""
from __future__ import annotations

import torch
from torch import Tensor, nn

from tilelli.core.sparse_attention import SparseCausalAttention
from tilelli.core.ternary_conv import TernaryCausalConv1d
from tilelli.core.ternary_linear import TernaryLinear


PATHWAY_NAMES_LITE = ("local", "sparse", "dense")


class TernaryFFN_Lite(nn.Module):
    """Wider FFN at expand=4 (matches vanilla's ratio)."""

    def __init__(self, d_model: int, expand: int = 4, quantize: bool = True) -> None:
        super().__init__()
        d_inner = d_model * expand
        self.up = TernaryLinear(d_model, d_inner, quantize=quantize)
        self.down = TernaryLinear(d_inner, d_model, quantize=quantize)

    def forward(self, x: Tensor) -> Tensor:
        return self.down(torch.nn.functional.gelu(self.up(x)))

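
# Parameter-count sketch for TernaryFFN_Lite (illustrative only, not used by
# the model). ASSUMPTION: TernaryLinear stores a plain (out x in) weight and,
# by default, no bias; that is an assumption about
# tilelli.core.ternary_linear, not something this module states.
# `_sketch_ffn_param_count` is a hypothetical helper name.
def _sketch_ffn_param_count(d_model: int, expand: int = 4, bias: bool = False) -> int:
    """Weights in up (d_model -> d_model*expand) plus down (d_model*expand -> d_model)."""
    d_inner = d_model * expand
    n = 2 * d_model * d_inner  # = 2 * expand * d_model**2
    if bias:
        n += d_inner + d_model  # one bias vector per projection, if present
    return n


# Usage: _sketch_ffn_param_count(384, 4) gives 1,179,648 FFN weights per block
# at the TilelliLiteLM defaults (d_model=384, ffn_expand=4), under the no-bias
# assumption above.
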

class TilelliLiteBlock(nn.Module):
    """3-pathway block: Local conv + Sparse multi-head attn + Dense FFN.

    All pathways always fire; per-token soft router mixes them. Load-balance
    aux loss penalizes router collapse to one pathway.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int = 8,
        kernel_size: int = 5,
        top_k: int = 16,
        ffn_expand: int = 4,
        quantize: bool = True,
        load_balance_weight: float = 0.01,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.n_pathways = 3
        self.load_balance_weight = load_balance_weight

        # Multi-head sparse attention. d_head computed from n_heads so total
        # head dim equals d_model (matches vanilla's attention shape).
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        d_head = d_model // n_heads

        self.norm = nn.LayerNorm(d_model)
        self.local = TernaryCausalConv1d(d_model, kernel_size=kernel_size, quantize=quantize)
        # Per-head sparse attention: n_heads instances of the existing single-head
        # implementation; their outputs are averaged in _multi_head_sparse.
        self.sparse_heads = nn.ModuleList([
            SparseCausalAttention(d_model, d_head=d_head, top_k=top_k)
            for _ in range(n_heads)
        ])
        self.sparse_proj = TernaryLinear(d_model, d_model, quantize=quantize)
        self.dense = TernaryFFN_Lite(d_model, expand=ffn_expand, quantize=quantize)

        self.router = TernaryLinear(d_model, self.n_pathways, quantize=quantize)
        self._aux_loss = torch.tensor(0.0)
        self._router_entropy = None  # (B, L) tensor, set by each forward()

    def _multi_head_sparse(self, h: Tensor) -> Tensor:
        """Average the outputs of n_heads single-head sparse attentions, then project."""
        # Each SparseCausalAttention head returns (B, L, d_model) rather than
        # (B, L, d_head), so concatenating would give n_heads * d_model features.
        # Averaging instead keeps the width at d_model and adds no parameters;
        # sparse_proj then mixes the merged signal.
        head_outs = [h_mod(h) for h_mod in self.sparse_heads]
        merged = torch.stack(head_outs, dim=0).mean(dim=0)   # (B, L, d_model)
        return self.sparse_proj(merged)

    def forward(self, x: Tensor) -> Tensor:
        h = self.norm(x)
        r = torch.softmax(self.router(h), dim=-1)   # (B, L, 3)

        out_local = self.local(h)                    # (B, L, d_model)
        out_sparse = self._multi_head_sparse(h)
        out_dense = self.dense(h)

        mixed = (
            r[..., 0:1] * out_local
            + r[..., 1:2] * out_sparse
            + r[..., 2:3] * out_dense
        )

        # Load-balance: per-pathway mean usage should approach 1/3.
        pathway_use = r.mean(dim=(0, 1))             # (3,)
        target = 1.0 / self.n_pathways
        self._aux_loss = ((pathway_use - target) ** 2).mean() * self.load_balance_weight

        # Cache per-token router entropy on this forward call so an outer
        # training loop can read it for a metacognition aux loss (see
        # scripts/train_router_metacog.py). Shape (B, L). On the
        # inference path nothing reads this; cheap to compute.
        self._router_entropy = -(r * (r + 1e-12).log()).sum(dim=-1)

        return x + mixed

    @property
    def aux_loss(self) -> Tensor:
        return self._aux_loss

    @torch.no_grad()
    def router_weights(self, x: Tensor) -> Tensor:
        h = self.norm(x)
        return torch.softmax(self.router(h), dim=-1)

    @torch.no_grad()
    def router_entropy(self, x: Tensor) -> Tensor:
        """Per-token entropy of the router distribution. Low → committed to one
        pathway (high confidence). High → uncertain mix."""
        r = self.router_weights(x)
        return -(r * (r + 1e-12).log()).sum(dim=-1)

    # ── Incremental-decode helpers ────────────────────────────────────── #
    # A block "cache" is a dict:
    #   {"conv_buffer": (B, k-1, D),
    #    "sparse_caches": [head_cache_dict for each head]}

    def empty_cache(self, batch_size: int, device, dtype) -> dict:
        return {
            "conv_buffer": self.local.empty_buffer(batch_size, device, dtype),
            "sparse_caches": [h.empty_cache(batch_size, device, dtype)
                              for h in self.sparse_heads],
        }

    def warmup_cache(self, x: Tensor) -> dict:
        """Build the cache from the full-prompt input x (B, L, D), i.e. the SAME x
        that was fed to forward() during prompt processing. The pathways operate
        on the normalized input, so the cache is built from h = self.norm(x)."""
        h = self.norm(x)
        return {
            "conv_buffer": self.local.warmup_buffer(h),
            "sparse_caches": [head.warmup_cache(h) for head in self.sparse_heads],
        }

    def forward_incremental(self, x_step: Tensor, cache: dict) -> tuple[Tensor, dict]:
        """One-token step through the block. Returns (out_step, new_cache).
        out_step already includes the residual (x_step + mixed), so the caller
        must not add x_step again."""
        h = self.norm(x_step)                                # (B, 1, D)
        r = torch.softmax(self.router(h), dim=-1)            # (B, 1, 3)

        # Local conv: prepend buffer, conv → 1 output, slide buffer
        out_local, new_conv_buf = self.local.forward_incremental(h, cache["conv_buffer"])

        # Sparse multi-head: each head incrementally updates its cache
        head_outs = []
        new_sparse_caches = []
        for head, hc in zip(self.sparse_heads, cache["sparse_caches"]):
            y_h, hc_new = head.forward_incremental(h, hc)
            head_outs.append(y_h)
            new_sparse_caches.append(hc_new)
        merged = torch.stack(head_outs, dim=0).mean(dim=0)   # (B, 1, D)
        out_sparse = self.sparse_proj(merged)

        # Dense FFN: stateless
        out_dense = self.dense(h)

        mixed = (
            r[..., 0:1] * out_local
            + r[..., 1:2] * out_sparse
            + r[..., 2:3] * out_dense
        )
        new_cache = {
            "conv_buffer": new_conv_buf,
            "sparse_caches": new_sparse_caches,
        }
        return x_step + mixed, new_cache

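
# Plain-Python sketch (illustrative, not used by the model) of the router math
# in TilelliLiteBlock.forward: per-token softmax over the three pathways, the
# load-balance aux loss weight * mean((mean_usage - 1/3)**2), and per-token
# router entropy. It flattens the (B, L) token grid into one list; the helper
# names are hypothetical. Because the aux term only sees batch-average usage,
# tokens that each commit to a different pathway still count as "balanced".
import math


def _sketch_softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def _sketch_router_stats(
    router_logits: list[list[float]], weight: float = 0.01
) -> tuple[list[float], float, list[float]]:
    """Return (mean pathway usage, load-balance aux loss, per-token entropy)."""
    probs = [_sketch_softmax(tok) for tok in router_logits]
    n_tok, n_path = len(probs), len(probs[0])
    usage = [sum(p[k] for p in probs) / n_tok for k in range(n_path)]
    target = 1.0 / n_path
    aux = weight * sum((u - target) ** 2 for u in usage) / n_path
    entropy = [-sum(q * math.log(q + 1e-12) for q in p) for p in probs]
    return usage, aux, entropy


# Usage: uniform logits such as [[0.0, 0.0, 0.0]] * 4 give usage ~1/3 each,
# aux ~0 and entropy ~ln 3 per token; [[5.0, 0.0, 0.0]] * 4 raises aux.
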

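
# Plain-Python sketch (illustrative, not used by the model) of why a (k-1)-long
# "conv_buffer" suffices for incremental decoding. ASSUMPTION: the local
# pathway behaves like a depthwise causal convolution with k-1 zeros of left
# padding; the real TernaryCausalConv1d lives in tilelli.core.ternary_conv and
# may differ in detail. One channel is shown; helper names are hypothetical.


def _sketch_causal_conv_full(xs: list[float], kernel: list[float]) -> list[float]:
    """Output t sees inputs t-k+1 .. t (zeros before the sequence start)."""
    k = len(kernel)
    padded = [0.0] * (k - 1) + list(xs)
    return [sum(kernel[j] * padded[t + j] for j in range(k)) for t in range(len(xs))]


def _sketch_warmup_buffer(xs: list[float], k: int) -> list[float]:
    """Last k-1 (zero-padded) prompt inputs: the analogue of warmup_buffer()."""
    padded = [0.0] * (k - 1) + list(xs)
    return padded[len(padded) - (k - 1):]


def _sketch_causal_conv_step(
    x_t: float, buffer: list[float], kernel: list[float]
) -> tuple[float, list[float]]:
    """Prepend the buffer, produce one output, slide the buffer by one."""
    window = list(buffer) + [x_t]
    y_t = sum(w * v for w, v in zip(kernel, window))
    return y_t, window[1:]


# Usage: warming the buffer on a prompt prefix and then stepping token by
# token reproduces _sketch_causal_conv_full on the whole sequence.

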
class TernaryEmbeddingLite(nn.Module):
    """Token id → ternary vector. Embedding weights are quantized to {-1, 0, +1}
    with a per-tensor scale at forward time."""

    def __init__(self, vocab_size: int, d_model: int, quantize: bool = True) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.quantize = quantize
        w = torch.randn(vocab_size, d_model) * (1.0 / d_model**0.5)
        self.weight = nn.Parameter(w)

    def forward(self, ids: Tensor) -> Tensor:
        if self.quantize:
            from tilelli.core.ternary import ternarize
            w_q = ternarize(self.weight)
        else:
            w_q = self.weight
        return w_q[ids]

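
# Plain-Python sketch (illustrative, not used by the model) of ternary
# quantization with a per-tensor scale. ASSUMPTION: it uses the "absmean"
# scheme from BitNet b1.58 (scale = mean |w|, codes = clamp(round(w / scale),
# -1, 1)); the actual tilelli.core.ternary.ternarize may choose its scale or
# threshold differently. Helper names are hypothetical.


def _sketch_ternarize_absmean(
    weights: list[list[float]], eps: float = 1e-8
) -> tuple[list[list[int]], float]:
    """Return ternary codes in {-1, 0, +1} and the shared per-tensor scale."""
    flat = [v for row in weights for v in row]
    scale = sum(abs(v) for v in flat) / len(flat) + eps
    codes = [[max(-1, min(1, round(v / scale))) for v in row] for row in weights]
    return codes, scale


def _sketch_ternary_lookup(weights: list[list[float]], ids: list[int]) -> list[list[float]]:
    """Embedding lookup on the dequantized table: row ids[i] as code * scale."""
    codes, scale = _sketch_ternarize_absmean(weights)
    return [[c * scale for c in codes[i]] for i in ids]
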

class TilelliLiteLM(nn.Module):
    """Byte-level LM with TilelliLiteBlock stack + learned positional embed."""

    def __init__(
        self,
        vocab_size: int = 256,
        d_model: int = 384,
        n_layers: int = 8,
        n_heads: int = 8,
        kernel_size: int = 5,
        top_k: int = 16,
        ffn_expand: int = 4,
        max_seq_len: int = 2048,
        quantize: bool = True,
        load_balance_weight: float = 0.01,
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.quantize = quantize

        self.embed = TernaryEmbeddingLite(vocab_size, d_model, quantize=quantize)
        # Learned positional embedding: kept in FP32 even in ternary mode
        # (position info must survive quantization).
        self.pos_embed = nn.Embedding(max_seq_len, d_model)
        nn.init.normal_(self.pos_embed.weight, std=0.02)

        self.blocks = nn.ModuleList([
            TilelliLiteBlock(
                d_model=d_model, n_heads=n_heads, kernel_size=kernel_size,
                top_k=top_k, ffn_expand=ffn_expand, quantize=quantize,
                load_balance_weight=load_balance_weight,
            )
            for _ in range(n_layers)
        ])

        self.final_norm = nn.LayerNorm(d_model)
        self.unembed = TernaryLinear(d_model, vocab_size, quantize=quantize)

    def forward(self, ids: Tensor) -> Tensor:
        L = ids.size(1)
        if L > self.max_seq_len:
            raise ValueError(f"sequence length {L} > max_seq_len {self.max_seq_len}")
        x = self.embed(ids)
        pos = torch.arange(L, device=ids.device)
        x = x + self.pos_embed(pos)
        for blk in self.blocks:
            x = blk(x)
        x = self.final_norm(x)
        return self.unembed(x)

    def loss(self, ids: Tensor, targets: Tensor | None = None) -> Tensor:
        """Autoregressive next-token loss + load-balance aux.

        Compatible with both the (ids,) "shift internally" convention and the
        (ids, targets) "caller-supplied targets" convention. If targets is None
        we shift ids ourselves; otherwise we trust the caller (train.py-style).
        """
        if targets is None:
            if ids.size(1) < 2:
                raise ValueError("loss needs sequence length >= 2")
            inp = ids[:, :-1]
            tgt = ids[:, 1:]
        else:
            inp, tgt = ids, targets
        logits = self(inp)
        ce = torch.nn.functional.cross_entropy(
            logits.reshape(-1, self.vocab_size),
            tgt.reshape(-1),
        )
        aux = sum(blk.aux_loss for blk in self.blocks)
        return ce + aux

    @torch.no_grad()
    def router_entropies(self, ids: Tensor) -> Tensor:
        """Per-layer router entropy, shape (n_layers, B, L)."""
        x = self.embed(ids)
        pos = torch.arange(ids.size(1), device=ids.device)
        x = x + self.pos_embed(pos)
        ents = []
        for blk in self.blocks:
            ents.append(blk.router_entropy(x))
            x = blk(x)
        return torch.stack(ents, dim=0)

    # ── Incremental generation with KV cache ──────────────────────────── #
    # Big perf win: each step runs one forward pass over a SINGLE new token,
    # using cached K/V for attention and a sliding buffer for the conv. Without
    # the cache, every step reruns the dense FFN (the dominant cost) over all L
    # positions; with it, the FFN runs on one position per step.
    #
    # Correctness: matches the non-cached forward at the final position up to
    # floating-point reordering noise (so not bit-exact), which does not change
    # the argmax. Verified by tests/test_kv_cache_parity.py.

    @torch.no_grad()
    def warmup_caches(self, ids: Tensor) -> tuple[Tensor, list[dict]]:
        """Run the full prompt forward and build per-layer caches. Returns the
        hidden states for ALL prompt positions (B, L, D), before final_norm,
        plus the caches; callers take [:, -1:, :] for the first next-token
        sample.
        """
        L = ids.size(1)
        if L > self.max_seq_len:
            raise ValueError(f"sequence length {L} > max_seq_len {self.max_seq_len}")
        x = self.embed(ids)
        pos = torch.arange(L, device=ids.device)
        x = x + self.pos_embed(pos)
        caches = []
        for blk in self.blocks:
            caches.append(blk.warmup_cache(x))
            x = blk(x)
        return x, caches

    @torch.no_grad()
    def step_with_cache(self, next_id: Tensor, pos_index: int,
                        caches: list[dict]) -> tuple[Tensor, list[dict]]:
        """Forward ONE new token (B, 1) at absolute position pos_index using the
        per-layer caches. Returns (logits (B, 1, V), new_caches); keep decoding
        with the returned caches rather than the ones passed in."""
        x = self.embed(next_id)                                  # (B, 1, D)
        pos = torch.tensor([pos_index], device=next_id.device)
        x = x + self.pos_embed(pos)
        new_caches = []
        for blk, c in zip(self.blocks, caches):
            x, c_new = blk.forward_incremental(x, c)
            new_caches.append(c_new)
        x = self.final_norm(x)
        return self.unembed(x), new_caches

    @torch.no_grad()
    def generate_with_cache(
        self,
        ids: Tensor,
        n_new_tokens: int,
        stop_ids: tuple[int, ...] = (10, 0),
        return_logits: bool = False,
    ) -> tuple[Tensor, list[int], list[float]]:
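        """Greedy generate up to n_new_tokens using the KV cache. Returns
        (full_ids, generated_id_list, confidence_per_step).

        Requires batch size 1: generated ids and confidences are collected as
        Python scalars. Generation stops early on any id in stop_ids (by
        default byte 10, newline, and byte 0, NUL) or once a generated token
        lands at position max_seq_len, where it can no longer be fed back.
        return_logits is currently unused.

        For non-greedy sampling, callers should use step_with_cache directly.
        """
        if ids.size(0) != 1:
            raise ValueError(f"generate_with_cache expects batch size 1, got {ids.size(0)}")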
        was_training = self.training
        self.eval()
        try:
            # Warm caches on the prompt; get the final-position logits via
            # one extra final_norm + unembed of the last hidden state.
            h_last, caches = self.warmup_caches(ids)              # (B, L, D)
            h_last_pos = self.final_norm(h_last[:, -1:, :])       # (B, 1, D)
            logits = self.unembed(h_last_pos)                     # (B, 1, V)
            cur_pos = ids.size(1)                                  # next pos to fill
            full = ids
            generated: list[int] = []
            confs: list[float] = []
            for _ in range(n_new_tokens):
                probs = torch.softmax(logits[:, -1, :], dim=-1)
                next_id = probs.argmax(dim=-1, keepdim=True)       # (B, 1)
                nid_int = int(next_id)
                confs.append(float(probs.max()))
                generated.append(nid_int)
                full = torch.cat([full, next_id], dim=1)
                if nid_int in stop_ids:
                    break
                if cur_pos + 1 > self.max_seq_len:
                    break
                logits, caches = self.step_with_cache(next_id, cur_pos, caches)
                cur_pos += 1
            return full, generated, confs
        finally:
            if was_training:
                self.train()
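

# Plain-Python sketch (illustrative, not used by the model) of two conventions
# in TilelliLiteLM: the internal next-token shift in loss() when targets is
# None, and the control flow of generate_with_cache() (greedy argmax, stop ids,
# max_seq_len bound). `step_fn` is a hypothetical stand-in for the model that
# maps the current prefix to next-token probabilities; it recomputes from the
# whole prefix, so this shows the loop, not the KV cache. Names are hypothetical.
from typing import Callable, Sequence


def _sketch_shift_for_lm(seq: list[int]) -> tuple[list[int], list[int]]:
    """Inputs drop the last id, targets drop the first: predict seq[t+1] from seq[:t+1]."""
    if len(seq) < 2:
        raise ValueError("need sequence length >= 2")
    return seq[:-1], seq[1:]


def _sketch_greedy_decode(
    step_fn: Callable[[list[int]], Sequence[float]],
    prompt: list[int],
    n_new_tokens: int,
    max_seq_len: int,
    stop_ids: tuple[int, ...] = (10, 0),
) -> tuple[list[int], list[int], list[float]]:
    full = list(prompt)
    generated: list[int] = []
    confs: list[float] = []
    cur_pos = len(full)        # position the next generated token will occupy
    probs = step_fn(full)      # analogue of warmup_caches + final_norm + unembed
    for _ in range(n_new_tokens):
        nid = max(range(len(probs)), key=lambda i: probs[i])  # greedy argmax
        confs.append(probs[nid])
        generated.append(nid)
        full.append(nid)
        if nid in stop_ids:
            break
        if cur_pos + 1 > max_seq_len:  # no pos_embed row left for this token
            break
        probs = step_fn(full)  # analogue of step_with_cache(next_id, cur_pos, ...)
        cur_pos += 1
    return full, generated, confs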