| """tilelli.core.tilelli_block — heterogeneous-pathway block with a per-token |
| soft router. |
| |
| Up to five structurally-different operations run in parallel on the same |
| input, mixed by a per-token softmax router. Optional Ternary Dispenser |
| (n_banks > 1) replicates each pathway across n_banks weight banks; the |
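router dispatches both pathway and bank per token, and parameter capacity
multiplies by n_banks. Each token mixes exactly one bank per pathway, but
every bank chosen by at least one token is evaluated over the full sequence,
so actual compute grows with the number of distinct banks in use (up to
n_banks times the single-bank cost).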
| """ |
| from __future__ import annotations |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
| from tilelli.core.sparse_attention import SparseCausalAttention |
| from tilelli.core.ssm import DiagonalSSM |
| from tilelli.core.ternary_conv import TernaryCausalConv1d |
| from tilelli.core.ternary_linear import TernaryLinear |
|
|
|
|
# Pathway names for the 3- and 5-pathway block variants, listed in the same
# order TilelliBlock stacks its pathways (and therefore its router outputs).
PATHWAY_NAMES_3: tuple[str, ...] = ("local", "state", "sparse")
PATHWAY_NAMES_5: tuple[str, ...] = (
    "local",
    "wide",
    "state",
    "sparse",
    "dense",
)
|
|
|
|
class TernaryFFN(nn.Module):
    """Two-layer ternary MLP: ``d_model -> expand * d_model -> d_model``.

    Both projections are ``TernaryLinear`` layers built with the same
    quantization options; a GELU sits between them.
    """

    def __init__(
        self,
        d_model: int,
        expand: int = 2,
        quantize: bool = True,
        per_row: bool = False,
        hadamard: bool = False,
        lsq: bool = False,
    ) -> None:
        super().__init__()
        # The up- and down-projection share one set of quantization options.
        quant_opts = dict(
            quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq,
        )
        hidden_width = expand * d_model
        self.up = TernaryLinear(d_model, hidden_width, **quant_opts)
        self.down = TernaryLinear(hidden_width, d_model, **quant_opts)

    def forward(self, x: Tensor) -> Tensor:
        """Apply up-projection, GELU, then down-projection."""
        expanded = self.up(x)
        activated = torch.nn.functional.gelu(expanded)
        return self.down(activated)
|
|
|
|
| def _make_pathway( |
| kind: str, |
| d_model: int, |
| d_head: int, |
| kernel_size: int, |
| wide_kernel_size: int, |
| top_k: int, |
| quantize: bool, |
| per_row: bool, |
| hadamard: bool, |
| lsq: bool, |
| dense_expand: int, |
| fp_attention: bool, |
| ) -> nn.Module: |
| """Build a single pathway module of the named kind. |
| |
| fp_attention=True forces the Sparse pathway's Q/K/V projections to FP32 |
| even when the global quantize is True. From the Spectrum spinoff insight: |
| attention is the precision-critical operation where ternary hurts most. |
| """ |
| if kind == "local": |
| return TernaryCausalConv1d( |
| d_model, kernel_size=kernel_size, |
| quantize=quantize, per_row=per_row, lsq=lsq, |
| ) |
| if kind == "wide": |
| return TernaryCausalConv1d( |
| d_model, kernel_size=wide_kernel_size, |
| quantize=quantize, per_row=per_row, lsq=lsq, |
| ) |
| if kind == "state": |
| return DiagonalSSM(d_model) |
| if kind == "sparse": |
| attn_quantize = False if fp_attention else quantize |
| return SparseCausalAttention( |
| d_model, d_head=d_head, top_k=top_k, quantize=attn_quantize, |
| ) |
| if kind == "dense": |
| return TernaryFFN( |
| d_model, expand=dense_expand, |
| quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq, |
| ) |
| raise ValueError(f"unknown pathway kind: {kind}") |
|
|
|
|
class TilelliBlock(nn.Module):
    """One Tilelli block: parallel heterogeneous pathways mixed by a router.

    ``forward(x)`` returns ``x + sum_p w_p * pathway_p(LayerNorm(x))``, where
    the per-token weights ``w`` are a softmax over the router logits,
    optionally restricted to the top-k entries (``top_k_routing``).

    Parameters
    ----------
    d_model : int
        Width of the residual stream (last dimension of ``x``).
    d_head, kernel_size, wide_kernel_size, top_k, dense_expand, fp_attention
        Forwarded to ``_make_pathway`` when building each pathway.
    pathways : int, default 5
        3 -> ("local", "state", "sparse");
        5 -> ("local", "wide", "state", "sparse", "dense").
    n_banks : int, default 1
        Number of weight banks per pathway (Ternary Dispenser). 1 = original.
        >1 = MoE at the weight level: each pathway holds n_banks copies and
        the router argmax-picks one bank per pathway per token. A bank
        load-balancing aux loss is exposed via ``.aux_loss`` after each
        ``forward``.
    skip_threshold : float, default 0.05
    skip_mode : {"per_call", "per_token"}, default "per_call"
        Only used by single-bank ``.infer()``. "per_call" skips a pathway
        entirely when its largest router weight over the whole input is
        below ``skip_threshold``; "per_token" evaluates every pathway and
        zeroes its contribution at positions whose weight is below it.
    quantize : bool, default True
        Forwarded to the router and to the pathway constructors.
    per_row, hadamard, lsq : bool
        Ternary-quantization tricks forwarded to TernaryLinear / Conv. All
        default off so the existing aurora-ternary baseline stays identical.
    top_k_routing : int, default 0
        If > 0 and smaller than the number of router outputs, keep only the
        top-k router weights per token and renormalize them. Applied
        identically in ``forward`` and ``infer``.

    Raises
    ------
    ValueError
        On an unsupported ``pathways``, ``skip_mode`` or ``n_banks`` value.
    """

    # Coefficient applied to the bank load-balancing loss (multi-bank only).
    _AUX_LOSS_COEF = 0.01

    def __init__(
        self,
        d_model: int,
        d_head: int = 32,
        kernel_size: int = 5,
        wide_kernel_size: int = 21,
        top_k: int = 8,
        pathways: int = 5,
        n_banks: int = 1,
        skip_threshold: float = 0.05,
        skip_mode: str = "per_call",
        quantize: bool = True,
        per_row: bool = False,
        hadamard: bool = False,
        lsq: bool = False,
        dense_expand: int = 2,
        fp_attention: bool = False,
        top_k_routing: int = 0,
    ) -> None:
        super().__init__()
        if pathways not in (3, 5):
            raise ValueError(f"pathways must be 3 or 5, got {pathways}")
        if skip_mode not in ("per_call", "per_token"):
            raise ValueError(f"skip_mode must be 'per_call' or 'per_token', got {skip_mode!r}")
        if n_banks < 1:
            raise ValueError(f"n_banks must be >= 1, got {n_banks}")
        self.d_model = d_model
        self.pathways = pathways
        self.n_banks = n_banks
        self.skip_threshold = skip_threshold
        self.skip_mode = skip_mode
        self.quantize = quantize
        self.top_k_routing = top_k_routing
        self.pathway_names = PATHWAY_NAMES_5 if pathways == 5 else PATHWAY_NAMES_3

        self.norm = nn.LayerNorm(d_model)

        def _build(kind: str) -> nn.Module | nn.ModuleList:
            # Single-bank blocks hold a bare module (no ModuleList wrapper),
            # so their parameter names carry no bank index.
            def _one() -> nn.Module:
                return _make_pathway(
                    kind, d_model, d_head, kernel_size, wide_kernel_size,
                    top_k, quantize, per_row, hadamard, lsq, dense_expand,
                    fp_attention,
                )

            if n_banks <= 1:
                return _one()
            return nn.ModuleList([_one() for _ in range(n_banks)])

        # Construction order fixes RNG use during init and parameter order.
        self.local = _build("local")
        self.state = _build("state")
        self.sparse = _build("sparse")
        if pathways == 5:
            self.wide = _build("wide")
            self.dense = _build("dense")

        # One logit per (pathway, bank) pair, pathway-major: index p * n_banks + b.
        self.router = TernaryLinear(
            d_model, pathways * n_banks,
            quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq,
        )

        # Latest bank load-balancing loss; stays 0 when n_banks == 1.
        self._aux_loss = torch.tensor(0.0)

    def _pathway_modules(self) -> list[tuple[str, nn.Module | nn.ModuleList]]:
        """(name, module-or-banks) pairs in router-output order."""
        if self.pathways == 5:
            return [
                ("local", self.local),
                ("wide", self.wide),
                ("state", self.state),
                ("sparse", self.sparse),
                ("dense", self.dense),
            ]
        return [
            ("local", self.local),
            ("state", self.state),
            ("sparse", self.sparse),
        ]

    def _route(self, h: Tensor) -> Tensor:
        """Router weights actually used for mixing: softmax, then optional top-k."""
        return self._maybe_topk_route(torch.softmax(self.router(h), dim=-1))

    def _compute_single_bank(self, h: Tensor, r: Tensor) -> Tensor:
        """Weighted sum of all pathway outputs; ``r[..., i]`` weights pathway i."""
        outputs = [mod(h) for _, mod in self._pathway_modules()]
        return sum(r[..., i:i + 1] * out for i, out in enumerate(outputs))

    def _bank_balance_loss(self, r_2d: Tensor) -> Tensor:
        """Penalty for uneven bank usage within each pathway.

        ``r_2d`` has shape (B, L, n_pathways, n_banks). For every pathway, the
        router mass it receives is split across its banks. That split is
        aggregated over all tokens and compared to the uniform share
        1 / n_banks. Normalizing per pathway keeps the loss from also pushing
        the *pathway* distribution toward uniform. A pathway that received
        no mass at all (possible with top-k routing) contributes zero.
        """
        bank_mass = r_2d.sum(dim=(0, 1))                      # (P, n_banks)
        path_mass = bank_mass.sum(dim=-1, keepdim=True)       # (P, 1)
        target = 1.0 / self.n_banks
        share = torch.where(
            path_mass > 0,
            bank_mass / path_mass.clamp_min(1e-12),
            torch.full_like(bank_mass, target),
        )
        return ((share - target) ** 2).mean() * self._AUX_LOSS_COEF

    def _compute_multi_bank(
        self, h: Tensor, r: Tensor, *, update_aux: bool = True,
    ) -> Tensor:
        """Multi-bank dispenser: per-token top-1 bank selection per pathway.

        r shape: (B, L, n_pathways * n_banks), pathway-major. Each pathway is
        weighted by the sum of its banks' router weights. Each token uses the
        output of its argmax bank. Every bank chosen by at least one token is
        evaluated on the whole ``h``.

        When ``update_aux`` is True, ``self._aux_loss`` is refreshed from ``r``.
        """
        B, L, _ = r.shape
        plist = self._pathway_modules()
        r_2d = r.reshape(B, L, len(plist), self.n_banks)

        pathway_weights = r_2d.sum(dim=-1)   # (B, L, P)
        bank_idx = r_2d.argmax(dim=-1)       # (B, L, P)

        if update_aux:
            self._aux_loss = self._bank_balance_loss(r_2d)

        mixed = torch.zeros(B, L, self.d_model, device=h.device, dtype=h.dtype)
        for p_idx, (_name, banks) in enumerate(plist):
            pw = pathway_weights[..., p_idx:p_idx + 1]
            bidx = bank_idx[..., p_idx]
            for b, bank in enumerate(banks):
                mask = bidx == b
                if not mask.any():
                    continue  # no token picked this bank: skip its compute
                out = bank(h)
                mixed = mixed + pw * out * mask.unsqueeze(-1).to(out.dtype)
        return mixed

    def _maybe_topk_route(self, r: Tensor) -> Tensor:
        """Optionally restrict routing to the top-k router entries per token (Mixtral-style).

        Disabled when ``top_k_routing <= 0`` or k covers every entry. Otherwise
        non-top-k weights are zeroed and the kept ones renormalized to sum to 1.
        """
        if self.top_k_routing <= 0 or self.top_k_routing >= r.shape[-1]:
            return r
        top_vals, top_idx = r.topk(self.top_k_routing, dim=-1)
        mask = torch.zeros_like(r)
        mask.scatter_(-1, top_idx, top_vals)
        return mask / mask.sum(dim=-1, keepdim=True).clamp(min=1e-12)

    def forward(self, x: Tensor) -> Tensor:
        """Residual block ``x + mix(pathways(norm(x)))``.

        With ``n_banks > 1`` this also refreshes ``self.aux_loss``.
        """
        h = self.norm(x)
        r = self._route(h)
        if self.n_banks <= 1:
            mixed = self._compute_single_bank(h, r)
        else:
            mixed = self._compute_multi_bank(h, r)
        return x + mixed

    @property
    def aux_loss(self) -> Tensor:
        """Bank load-balancing loss from the latest forward (0 if n_banks == 1).

        Add to the main loss during training.
        """
        return self._aux_loss

    @torch.no_grad()
    def infer(self, x: Tensor) -> Tensor:
        """Gradient-free evaluation with optional pathway skipping.

        Uses the same routing as ``forward`` (including top-k). Multi-bank
        blocks delegate to the dispenser without touching ``aux_loss``.
        Single-bank blocks apply ``skip_mode``/``skip_threshold`` and prefer a
        pathway's own ``.infer`` method when it has one.
        """
        h = self.norm(x)
        r = self._route(h)
        if self.n_banks > 1:
            # Inference must not overwrite the training aux loss.
            return x + self._compute_multi_bank(h, r, update_aux=False)
        y = torch.zeros_like(x)
        plist = self._pathway_modules()
        if self.skip_mode == "per_call":
            # Largest weight each pathway gets anywhere in the batch.
            r_max = r.amax(dim=(0, 1))
            for i, (_, mod) in enumerate(plist):
                if r_max[i].item() >= self.skip_threshold:
                    step = mod.infer(h) if hasattr(mod, "infer") else mod(h)
                    y = y + r[..., i:i + 1] * step
            return x + y
        for i, (_, mod) in enumerate(plist):
            step = mod.infer(h) if hasattr(mod, "infer") else mod(h)
            keep = (r[..., i:i + 1] >= self.skip_threshold).to(step.dtype)
            y = y + keep * r[..., i:i + 1] * step
        return x + y

    @torch.no_grad()
    def router_weights(self, x: Tensor) -> Tensor:
        """Per-token router softmax distribution (before any top-k restriction).

        For single-bank: shape (B, L, n_pathways).
        For multi-bank: pathway-level weights (banks summed). Shape (B, L, n_pathways).
        """
        r = torch.softmax(self.router(self.norm(x)), dim=-1)
        if self.n_banks > 1:
            B, L, _ = r.shape
            n_paths = len(self._pathway_modules())
            return r.reshape(B, L, n_paths, self.n_banks).sum(dim=-1)
        return r

    @torch.no_grad()
    def router_entropy(self, x: Tensor) -> Tensor:
        """Per-token entropy (nats) of ``router_weights(x)``; shape (B, L)."""
        r = self.router_weights(x).clamp_min(1e-12)
        return -(r * r.log()).sum(dim=-1)
|
|