File size: 10,511 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
"""tilelli.core.tilelli_block — heterogeneous-pathway block with a per-token
soft router.

Up to five structurally-different operations run in parallel on the same
input, mixed by a per-token softmax router. Optional Ternary Dispenser
(n_banks > 1) replicates each pathway across n_banks weight banks; the
router dispatches both pathway and bank per token. Compute per token stays
constant; parameter capacity multiplies by n_banks.
"""
from __future__ import annotations

import torch
from torch import Tensor, nn

from tilelli.core.sparse_attention import SparseCausalAttention
from tilelli.core.ssm import DiagonalSSM
from tilelli.core.ternary_conv import TernaryCausalConv1d
from tilelli.core.ternary_linear import TernaryLinear


# Pathway names for the 3- and 5-pathway block configurations. TilelliBlock
# stores the tuple matching its `pathways` argument on `pathway_names`.
PATHWAY_NAMES_3 = ("local", "state", "sparse")
PATHWAY_NAMES_5 = ("local", "wide", "state", "sparse", "dense")


class TernaryFFN(nn.Module):
    """Two-layer feed-forward block built from ``TernaryLinear`` layers.

    Features are projected from ``d_model`` up to ``d_model * expand``, passed
    through GELU, and projected back down to ``d_model``. Both projections
    receive the same quantization flags.
    """

    def __init__(
        self,
        d_model: int,
        expand: int = 2,
        quantize: bool = True,
        per_row: bool = False,
        hadamard: bool = False,
        lsq: bool = False,
    ) -> None:
        super().__init__()
        hidden_width = d_model * expand
        # One options dict so the up and down projections cannot drift apart.
        linear_opts = {
            "quantize": quantize,
            "per_row": per_row,
            "hadamard": hadamard,
            "lsq": lsq,
        }
        self.up = TernaryLinear(d_model, hidden_width, **linear_opts)
        self.down = TernaryLinear(hidden_width, d_model, **linear_opts)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the up-projection, GELU, then the down-projection to ``x``."""
        widened = self.up(x)
        activated = torch.nn.functional.gelu(widened)
        return self.down(activated)


def _make_pathway(
    kind: str,
    d_model: int,
    d_head: int,
    kernel_size: int,
    wide_kernel_size: int,
    top_k: int,
    quantize: bool,
    per_row: bool,
    hadamard: bool,
    lsq: bool,
    dense_expand: int,
    fp_attention: bool,
) -> nn.Module:
    """Construct the pathway module identified by ``kind``.

    Recognised kinds are "local" and "wide" (causal convs with
    ``kernel_size`` / ``wide_kernel_size``), "state" (diagonal SSM),
    "sparse" (sparse causal attention) and "dense" (``TernaryFFN``).

    When ``fp_attention`` is true the sparse pathway is built with
    ``quantize=False`` regardless of the global ``quantize`` flag. The stated
    motivation (Spectrum spinoff insight) is that attention is the
    precision-critical operation where ternary hurts most.

    Raises
    ------
    ValueError
        If ``kind`` is not one of the names above.
    """
    if kind in ("local", "wide"):
        # The two conv pathways differ only in their kernel width.
        width = kernel_size if kind == "local" else wide_kernel_size
        return TernaryCausalConv1d(
            d_model, kernel_size=width,
            quantize=quantize, per_row=per_row, lsq=lsq,
        )

    if kind == "state":
        return DiagonalSSM(d_model)

    if kind == "sparse":
        sparse_quantize = quantize if not fp_attention else False
        return SparseCausalAttention(
            d_model, d_head=d_head, top_k=top_k, quantize=sparse_quantize,
        )

    if kind == "dense":
        return TernaryFFN(
            d_model, expand=dense_expand,
            quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq,
        )

    raise ValueError(f"unknown pathway kind: {kind}")


class TilelliBlock(nn.Module):
    """One Tilelli block: parallel heterogeneous pathways mixed by a router.

    The input is layer-normalised, the router produces a per-token softmax
    over its outputs, every pathway is applied to the normalised activations,
    and the router-weighted sum of pathway outputs is added back to the input
    as a residual.

    Parameters
    ----------
    d_model : int
        Feature width; also the width of the residual stream.
    d_head, top_k : int
        Forwarded to the sparse-attention pathway.
    kernel_size, wide_kernel_size : int
        Kernel sizes of the "local" and "wide" causal-conv pathways.
    pathways : int, default 5
        3 builds local/state/sparse; 5 additionally builds wide and dense.
        Any other value raises ``ValueError``.
    n_banks : int, default 1
        Number of weight banks per pathway (Ternary Dispenser). 1 = original.
        >1 = MoE at the weight level: each pathway holds n_banks modules, the
        router argmax-picks one bank per token and pathway. Adds a
        load-balancing aux loss accessible via .aux_loss after each forward.
    skip_threshold, skip_mode
        Only used by .infer(). "per_call" skips a pathway when its largest
        router weight over the whole call is below the threshold; "per_token"
        zeroes a pathway's contribution on the tokens where its weight is
        below the threshold.
    quantize : bool
        Forwarded to the pathways and the router.
    per_row, hadamard, lsq : bool
        Ternary-quantization tricks forwarded to TernaryLinear / Conv. All
        default off so the existing aurora-ternary baseline stays identical.
    dense_expand : int
        Expansion factor of the "dense" feed-forward pathway.
    fp_attention : bool
        Build the sparse pathway with ``quantize=False``.
    top_k_routing : int, default 0
        When ``0 < top_k_routing < n_router_outputs`` only the k largest
        router weights per token are kept and renormalised. Applied in both
        .forward() and .infer(). Other values disable it.

    Raises
    ------
    ValueError
        On an unsupported ``pathways``, ``skip_mode`` or ``n_banks``.
    """

    def __init__(
        self,
        d_model: int,
        d_head: int = 32,
        kernel_size: int = 5,
        wide_kernel_size: int = 21,
        top_k: int = 8,
        pathways: int = 5,
        n_banks: int = 1,
        skip_threshold: float = 0.05,
        skip_mode: str = "per_call",
        quantize: bool = True,
        per_row: bool = False,
        hadamard: bool = False,
        lsq: bool = False,
        dense_expand: int = 2,
        fp_attention: bool = False,
        top_k_routing: int = 0,
    ) -> None:
        super().__init__()
        if pathways not in (3, 5):
            raise ValueError(f"pathways must be 3 or 5, got {pathways}")
        if skip_mode not in ("per_call", "per_token"):
            raise ValueError(f"skip_mode must be 'per_call' or 'per_token', got {skip_mode!r}")
        if n_banks < 1:
            raise ValueError(f"n_banks must be >= 1, got {n_banks}")
        self.d_model = d_model
        self.pathways = pathways
        self.n_banks = n_banks
        self.skip_threshold = skip_threshold
        self.skip_mode = skip_mode
        self.quantize = quantize
        self.top_k_routing = top_k_routing
        self.pathway_names = PATHWAY_NAMES_5 if pathways == 5 else PATHWAY_NAMES_3

        self.norm = nn.LayerNorm(d_model)

        def _build(kind: str) -> nn.Module | nn.ModuleList:
            """Build one pathway, or a ModuleList of n_banks separately built ones."""

            def make_one() -> nn.Module:
                return _make_pathway(
                    kind, d_model, d_head, kernel_size, wide_kernel_size,
                    top_k, quantize, per_row, hadamard, lsq, dense_expand,
                    fp_attention,
                )

            if n_banks <= 1:
                return make_one()
            return nn.ModuleList([make_one() for _ in range(n_banks)])

        # Registration order is kept as-is so parameter ordering is unchanged.
        self.local = _build("local")
        self.state = _build("state")
        self.sparse = _build("sparse")
        if pathways == 5:
            self.wide = _build("wide")
            self.dense = _build("dense")

        # Router: routes over (pathway × bank) when n_banks > 1, else pathways.
        n_router_outputs = pathways * n_banks
        self.router = TernaryLinear(
            d_model, n_router_outputs,
            quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq,
        )

        # Plain attribute (not a buffer) so the state_dict layout is unchanged.
        self._aux_loss = torch.tensor(0.0)

    def _pathway_modules(self) -> list[tuple[str, nn.Module | nn.ModuleList]]:
        """Return ``(name, module)`` pairs in the order of ``pathway_names``.

        Position ``i`` in this list is the pathway weighted by router column
        ``i`` (single-bank) or router column group ``i`` (multi-bank).
        """
        return [(name, getattr(self, name)) for name in self.pathway_names]

    def _compute_single_bank(self, h: Tensor, r: Tensor) -> Tensor:
        """Run every pathway on ``h`` and mix the outputs with weights ``r``.

        ``r[..., i]`` weights the i-th entry of ``_pathway_modules()``.
        """
        outputs = [mod(h) for _, mod in self._pathway_modules()]
        return sum(r[..., i:i + 1] * out for i, out in enumerate(outputs))

    def _compute_multi_bank(self, h: Tensor, r: Tensor) -> Tensor:
        """Multi-bank dispenser: per-token top-1 bank selection per pathway.

        r shape: (B, L, n_pathways * n_banks)

        Also stores the load-balancing loss on ``self._aux_loss`` as a side
        effect, on every call (including calls made from .infer()).
        """
        B, L, _ = r.shape
        plist = self._pathway_modules()
        n_paths = len(plist)
        # reshape (not view) so non-contiguous router weights are accepted too.
        r_2d = r.reshape(B, L, n_paths, self.n_banks)

        pathway_weights = r_2d.sum(dim=-1)  # (B, L, n_paths)
        bank_idx = r_2d.argmax(dim=-1)      # (B, L, n_paths)

        # Load balance: each bank should be selected ~1/n_banks of the time.
        # NOTE(review): r_2d sums to 1 over all (pathway, bank) cells per
        # token, so a perfectly uniform router gives 1/(n_paths * n_banks)
        # per cell, not 1/n_banks; this loss therefore has a non-zero floor.
        # Confirm whether per-pathway normalisation was intended.
        bank_probs = r_2d.mean(dim=(0, 1))  # (n_paths, n_banks)
        target = 1.0 / self.n_banks
        self._aux_loss = ((bank_probs - target) ** 2).mean() * 0.01

        mixed = torch.zeros(B, L, self.d_model, device=h.device, dtype=h.dtype)
        for p_idx, (_name, banks) in enumerate(plist):
            pw = pathway_weights[..., p_idx:p_idx + 1]  # (B, L, 1)
            bidx = bank_idx[..., p_idx]                 # (B, L)
            for b in range(self.n_banks):
                mask = (bidx == b)
                # A bank no token selected is never evaluated.
                if not mask.any():
                    continue
                # The bank runs on the full tensor; the mask then keeps only
                # the tokens that selected it.
                out = banks[b](h)
                mixed = mixed + pw * out * mask.unsqueeze(-1).to(out.dtype)
        return mixed

    def _maybe_topk_route(self, r: Tensor) -> Tensor:
        """Optionally restrict routing to the top-k router outputs per token (Mixtral-style).

        Returns ``r`` unchanged when ``top_k_routing`` is <= 0 or not smaller
        than the number of router outputs. Otherwise keeps the k largest
        weights per token, zeroes the rest, and renormalises to sum to 1.
        """
        if self.top_k_routing <= 0 or self.top_k_routing >= r.shape[-1]:
            return r
        top_vals, top_idx = r.topk(self.top_k_routing, dim=-1)
        mask = torch.zeros_like(r)
        mask.scatter_(-1, top_idx, top_vals)
        # clamp guards the division if every kept weight underflowed to zero.
        return mask / mask.sum(dim=-1, keepdim=True).clamp(min=1e-12)

    def _route(self, h: Tensor) -> Tensor:
        """Router weights for normalised activations ``h``, top-k applied.

        Shared by .forward() and .infer() so both mix pathways with the same
        weights.
        """
        return self._maybe_topk_route(torch.softmax(self.router(h), dim=-1))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` plus the router-weighted mix of pathway outputs."""
        h = self.norm(x)
        r = self._route(h)
        if self.n_banks <= 1:
            mixed = self._compute_single_bank(h, r)
        else:
            mixed = self._compute_multi_bank(h, r)
        return x + mixed

    @property
    def aux_loss(self) -> Tensor:
        """Load-balancing loss for multi-bank. Add to main loss during training."""
        return self._aux_loss

    @torch.no_grad()
    def infer(self, x: Tensor) -> Tensor:
        """Gradient-free forward pass with low-weight pathway skipping.

        Uses the same routing as .forward(), including ``top_k_routing``.
        Multi-bank blocks use the regular multi-bank mix without skipping.
        Single-bank blocks apply ``skip_threshold`` according to
        ``skip_mode``. A pathway module's own ``infer`` method is used when
        it has one.
        """
        h = self.norm(x)
        r = self._route(h)
        if self.n_banks > 1:
            return x + self._compute_multi_bank(h, r)
        y = torch.zeros_like(x)
        if self.skip_mode == "per_call":
            # One host transfer for all pathways instead of one .item() each.
            r_max = r.amax(dim=(0, 1)).tolist()
            for i, (_, mod) in enumerate(self._pathway_modules()):
                if r_max[i] >= self.skip_threshold:
                    step = mod.infer(h) if hasattr(mod, "infer") else mod(h)
                    y = y + r[..., i:i + 1] * step
            return x + y
        # per_token: every pathway runs; sub-threshold tokens are masked out.
        for i, (_, mod) in enumerate(self._pathway_modules()):
            step = mod.infer(h) if hasattr(mod, "infer") else mod(h)
            mask = (r[..., i:i + 1] >= self.skip_threshold).to(step.dtype)
            y = y + mask * r[..., i:i + 1] * step
        return x + y

    @torch.no_grad()
    def router_weights(self, x: Tensor) -> Tensor:
        """Per-token router distribution (raw softmax, before any top-k routing).

        For single-bank: shape (B, L, n_pathways).
        For multi-bank: pathway-level weights (banks summed). Shape (B, L, n_pathways).
        """
        r = torch.softmax(self.router(self.norm(x)), dim=-1)
        if self.n_banks > 1:
            B, L, _ = r.shape
            n_paths = len(self._pathway_modules())
            return r.reshape(B, L, n_paths, self.n_banks).sum(dim=-1)
        return r

    @torch.no_grad()
    def router_entropy(self, x: Tensor) -> Tensor:
        """Per-token entropy (natural log) of the pathway-level router weights."""
        # Clamp so that zero weights do not produce log(0).
        r = self.router_weights(x).clamp_min(1e-12)
        return -(r * r.log()).sum(dim=-1)