"""tilelli.core.tilelli_block — heterogeneous-pathway block with a per-token soft router. Up to five structurally-different operations run in parallel on the same input, mixed by a per-token softmax router. Optional Ternary Dispenser (n_banks > 1) replicates each pathway across n_banks weight banks; the router dispatches both pathway and bank per token. Compute per token stays constant; parameter capacity multiplies by n_banks. """ from __future__ import annotations import torch from torch import Tensor, nn from tilelli.core.sparse_attention import SparseCausalAttention from tilelli.core.ssm import DiagonalSSM from tilelli.core.ternary_conv import TernaryCausalConv1d from tilelli.core.ternary_linear import TernaryLinear PATHWAY_NAMES_3 = ("local", "state", "sparse") PATHWAY_NAMES_5 = ("local", "wide", "state", "sparse", "dense") class TernaryFFN(nn.Module): """Tiny feed-forward network with ternary weights: d → expand·d → d.""" def __init__( self, d_model: int, expand: int = 2, quantize: bool = True, per_row: bool = False, hadamard: bool = False, lsq: bool = False, ) -> None: super().__init__() d_inner = d_model * expand self.up = TernaryLinear( d_model, d_inner, quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq, ) self.down = TernaryLinear( d_inner, d_model, quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq, ) def forward(self, x: Tensor) -> Tensor: return self.down(torch.nn.functional.gelu(self.up(x))) def _make_pathway( kind: str, d_model: int, d_head: int, kernel_size: int, wide_kernel_size: int, top_k: int, quantize: bool, per_row: bool, hadamard: bool, lsq: bool, dense_expand: int, fp_attention: bool, ) -> nn.Module: """Build a single pathway module of the named kind. fp_attention=True forces the Sparse pathway's Q/K/V projections to FP32 even when the global quantize is True. From the Spectrum spinoff insight: attention is the precision-critical operation where ternary hurts most. """ if kind == "local": return TernaryCausalConv1d( d_model, kernel_size=kernel_size, quantize=quantize, per_row=per_row, lsq=lsq, ) if kind == "wide": return TernaryCausalConv1d( d_model, kernel_size=wide_kernel_size, quantize=quantize, per_row=per_row, lsq=lsq, ) if kind == "state": return DiagonalSSM(d_model) if kind == "sparse": attn_quantize = False if fp_attention else quantize return SparseCausalAttention( d_model, d_head=d_head, top_k=top_k, quantize=attn_quantize, ) if kind == "dense": return TernaryFFN( d_model, expand=dense_expand, quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq, ) raise ValueError(f"unknown pathway kind: {kind}") class TilelliBlock(nn.Module): """One Tilelli block: parallel heterogeneous pathways mixed by a router. Parameters ---------- n_banks : int, default 1 Number of weight banks per pathway (Ternary Dispenser). 1 = original. >1 = MoE at the weight level: each pathway holds n_banks copies, the router argmax-picks one bank per token. Adds a load-balancing aux loss accessible via .aux_loss after each forward. per_row, hadamard, lsq : bool Ternary-quantization tricks forwarded to TernaryLinear / Conv. All default off so the existing aurora-ternary baseline stays identical. skip_threshold, skip_mode : as before — only used by .infer(). """ def __init__( self, d_model: int, d_head: int = 32, kernel_size: int = 5, wide_kernel_size: int = 21, top_k: int = 8, pathways: int = 5, n_banks: int = 1, skip_threshold: float = 0.05, skip_mode: str = "per_call", quantize: bool = True, per_row: bool = False, hadamard: bool = False, lsq: bool = False, dense_expand: int = 2, fp_attention: bool = False, top_k_routing: int = 0, ) -> None: super().__init__() if pathways not in (3, 5): raise ValueError(f"pathways must be 3 or 5, got {pathways}") if skip_mode not in ("per_call", "per_token"): raise ValueError(f"skip_mode must be 'per_call' or 'per_token', got {skip_mode!r}") if n_banks < 1: raise ValueError(f"n_banks must be >= 1, got {n_banks}") self.d_model = d_model self.pathways = pathways self.n_banks = n_banks self.skip_threshold = skip_threshold self.skip_mode = skip_mode self.quantize = quantize self.top_k_routing = top_k_routing self.pathway_names = PATHWAY_NAMES_5 if pathways == 5 else PATHWAY_NAMES_3 self.norm = nn.LayerNorm(d_model) def _build(kind: str) -> nn.Module | nn.ModuleList: mk = lambda: _make_pathway( kind, d_model, d_head, kernel_size, wide_kernel_size, top_k, quantize, per_row, hadamard, lsq, dense_expand, fp_attention, ) if n_banks <= 1: return mk() return nn.ModuleList([mk() for _ in range(n_banks)]) self.local = _build("local") self.state = _build("state") self.sparse = _build("sparse") if pathways == 5: self.wide = _build("wide") self.dense = _build("dense") # Router: routes over (pathway × bank) when n_banks > 1, else pathways. n_router_outputs = pathways * n_banks self.router = TernaryLinear( d_model, n_router_outputs, quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq, ) self._aux_loss = torch.tensor(0.0) def _pathway_modules(self) -> list[tuple[str, nn.Module | nn.ModuleList]]: if self.pathways == 5: return [ ("local", self.local), ("wide", self.wide), ("state", self.state), ("sparse", self.sparse), ("dense", self.dense), ] return [ ("local", self.local), ("state", self.state), ("sparse", self.sparse), ] def _compute_single_bank(self, h: Tensor, r: Tensor) -> Tensor: outputs = [mod(h) for _, mod in self._pathway_modules()] return sum(r[..., i:i + 1] * outputs[i] for i in range(len(outputs))) def _compute_multi_bank(self, h: Tensor, r: Tensor) -> Tensor: """Multi-bank dispenser: per-token top-1 bank selection per pathway. r shape: (B, L, n_pathways * n_banks) """ B, L, _ = r.shape plist = self._pathway_modules() n_paths = len(plist) r_2d = r.view(B, L, n_paths, self.n_banks) pathway_weights = r_2d.sum(dim=-1) # (B, L, n_paths) bank_idx = r_2d.argmax(dim=-1) # (B, L, n_paths) # Load balance: each bank should be selected ~1/n_banks of the time. bank_probs = r_2d.mean(dim=(0, 1)) # (n_paths, n_banks) target = 1.0 / self.n_banks self._aux_loss = ((bank_probs - target) ** 2).mean() * 0.01 mixed = torch.zeros(B, L, self.d_model, device=h.device, dtype=h.dtype) for p_idx, (_name, banks) in enumerate(plist): pw = pathway_weights[..., p_idx:p_idx + 1] # (B, L, 1) bidx = bank_idx[..., p_idx] # (B, L) for b in range(self.n_banks): mask = (bidx == b) if not mask.any(): continue out = banks[b](h) mixed = mixed + pw * out * mask.unsqueeze(-1).to(out.dtype) return mixed def _maybe_topk_route(self, r: Tensor) -> Tensor: """Optionally restrict routing to the top-k pathways per token (Mixtral-style).""" if self.top_k_routing <= 0 or self.top_k_routing >= r.shape[-1]: return r top_vals, top_idx = r.topk(self.top_k_routing, dim=-1) mask = torch.zeros_like(r) mask.scatter_(-1, top_idx, top_vals) return mask / mask.sum(dim=-1, keepdim=True).clamp(min=1e-12) def forward(self, x: Tensor) -> Tensor: h = self.norm(x) r = torch.softmax(self.router(h), dim=-1) r = self._maybe_topk_route(r) if self.n_banks <= 1: mixed = self._compute_single_bank(h, r) else: mixed = self._compute_multi_bank(h, r) return x + mixed @property def aux_loss(self) -> Tensor: """Load-balancing loss for multi-bank. Add to main loss during training.""" return self._aux_loss @torch.no_grad() def infer(self, x: Tensor) -> Tensor: h = self.norm(x) r = torch.softmax(self.router(h), dim=-1) if self.n_banks > 1: return x + self._compute_multi_bank(h, r) y = torch.zeros_like(x) if self.skip_mode == "per_call": r_max = r.amax(dim=(0, 1)) for i, (_, mod) in enumerate(self._pathway_modules()): if r_max[i].item() >= self.skip_threshold: step = mod.infer(h) if hasattr(mod, "infer") else mod(h) y = y + r[..., i:i + 1] * step return x + y for i, (_, mod) in enumerate(self._pathway_modules()): step = mod.infer(h) if hasattr(mod, "infer") else mod(h) mask = (r[..., i:i + 1] >= self.skip_threshold).to(step.dtype) y = y + mask * r[..., i:i + 1] * step return x + y @torch.no_grad() def router_weights(self, x: Tensor) -> Tensor: """Per-token router distribution. For single-bank: shape (B, L, n_pathways). For multi-bank: pathway-level weights (banks summed). Shape (B, L, n_pathways). """ r = torch.softmax(self.router(self.norm(x)), dim=-1) if self.n_banks > 1: B, L, _ = r.shape n_paths = len(self._pathway_modules()) return r.view(B, L, n_paths, self.n_banks).sum(dim=-1) return r @torch.no_grad() def router_entropy(self, x: Tensor) -> Tensor: r = self.router_weights(x).clamp_min(1e-12) return -(r * r.log()).sum(dim=-1)