Tilelli-llm / src /tilelli /core /tilelli_block.py
TilelliLab's picture
Mirror small files (code, paper, results)
f86dc09 verified
Raw
History Blame Contribute Delete
10.5 kB
"""tilelli.core.tilelli_block — heterogeneous-pathway block with a per-token
soft router.
Up to five structurally different operations run in parallel on the same
input and are mixed by a per-token softmax router. The optional Ternary
Dispenser (n_banks > 1) replicates each pathway across n_banks weight banks;
the router then assigns each token both pathway weights and one bank per
pathway. Each token's output depends on a single bank per pathway, so
parameter capacity multiplies by n_banks while the per-token function keeps
the size of one bank (note: this implementation still evaluates every
selected bank over the whole sequence).
"""
from __future__ import annotations
import torch
from torch import Tensor, nn
from tilelli.core.sparse_attention import SparseCausalAttention
from tilelli.core.ssm import DiagonalSSM
from tilelli.core.ternary_conv import TernaryCausalConv1d
from tilelli.core.ternary_linear import TernaryLinear
# Pathway names in router-output order. The 3-pathway variant omits the
# "wide" convolution and the "dense" feed-forward pathway.
PATHWAY_NAMES_3 = (
    "local",
    "state",
    "sparse",
)
PATHWAY_NAMES_5 = (
    "local",
    "wide",
    "state",
    "sparse",
    "dense",
)
class TernaryFFN(nn.Module):
    """Two-layer ternary MLP: d_model -> expand * d_model -> d_model.

    Both projections are TernaryLinear layers configured with the same
    quantization options; a GELU sits between them.
    """

    def __init__(
        self,
        d_model: int,
        expand: int = 2,
        quantize: bool = True,
        per_row: bool = False,
        hadamard: bool = False,
        lsq: bool = False,
    ) -> None:
        super().__init__()
        # Shared quantization settings for the up and down projections.
        qopts = dict(quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq)
        hidden = d_model * expand
        self.up = TernaryLinear(d_model, hidden, **qopts)
        self.down = TernaryLinear(hidden, d_model, **qopts)

    def forward(self, x: Tensor) -> Tensor:
        """Project up, apply GELU, project back down."""
        activated = nn.functional.gelu(self.up(x))
        return self.down(activated)
def _make_pathway(
    kind: str,
    d_model: int,
    d_head: int,
    kernel_size: int,
    wide_kernel_size: int,
    top_k: int,
    quantize: bool,
    per_row: bool,
    hadamard: bool,
    lsq: bool,
    dense_expand: int,
    fp_attention: bool,
) -> nn.Module:
    """Instantiate the pathway module named by ``kind``.

    Supported kinds: "local" and "wide" (ternary causal convolutions using
    ``kernel_size`` / ``wide_kernel_size``), "state" (DiagonalSSM), "sparse"
    (SparseCausalAttention) and "dense" (TernaryFFN). Any other kind raises
    ValueError.

    When ``fp_attention`` is set, the "sparse" pathway is built with
    quantize=False regardless of the global ``quantize`` flag. Rationale
    (Spectrum spinoff): attention is the precision-critical operation where
    ternary weights hurt most.
    """

    def causal_conv(width: int) -> nn.Module:
        # "local" and "wide" differ only in their kernel width.
        return TernaryCausalConv1d(
            d_model,
            kernel_size=width,
            quantize=quantize,
            per_row=per_row,
            lsq=lsq,
        )

    def sparse_attention() -> nn.Module:
        return SparseCausalAttention(
            d_model,
            d_head=d_head,
            top_k=top_k,
            quantize=False if fp_attention else quantize,
        )

    def dense_ffn() -> nn.Module:
        return TernaryFFN(
            d_model,
            expand=dense_expand,
            quantize=quantize,
            per_row=per_row,
            hadamard=hadamard,
            lsq=lsq,
        )

    # Constructors are deferred so only the requested pathway is built.
    factories = (
        ("local", lambda: causal_conv(kernel_size)),
        ("wide", lambda: causal_conv(wide_kernel_size)),
        ("state", lambda: DiagonalSSM(d_model)),
        ("sparse", sparse_attention),
        ("dense", dense_ffn),
    )
    for name, factory in factories:
        if kind == name:
            return factory()
    raise ValueError(f"unknown pathway kind: {kind}")
class TilelliBlock(nn.Module):
"""One Tilelli block: parallel heterogeneous pathways mixed by a router.
Parameters
----------
n_banks : int, default 1
Number of weight banks per pathway (Ternary Dispenser). 1 = original.
>1 = MoE at the weight level: each pathway holds n_banks copies, the
router argmax-picks one bank per token. Adds a load-balancing aux
loss accessible via .aux_loss after each forward.
per_row, hadamard, lsq : bool
Ternary-quantization tricks forwarded to TernaryLinear / Conv. All
default off so the existing aurora-ternary baseline stays identical.
skip_threshold, skip_mode : as before — only used by .infer().
"""
def __init__(
self,
d_model: int,
d_head: int = 32,
kernel_size: int = 5,
wide_kernel_size: int = 21,
top_k: int = 8,
pathways: int = 5,
n_banks: int = 1,
skip_threshold: float = 0.05,
skip_mode: str = "per_call",
quantize: bool = True,
per_row: bool = False,
hadamard: bool = False,
lsq: bool = False,
dense_expand: int = 2,
fp_attention: bool = False,
top_k_routing: int = 0,
) -> None:
super().__init__()
if pathways not in (3, 5):
raise ValueError(f"pathways must be 3 or 5, got {pathways}")
if skip_mode not in ("per_call", "per_token"):
raise ValueError(f"skip_mode must be 'per_call' or 'per_token', got {skip_mode!r}")
if n_banks < 1:
raise ValueError(f"n_banks must be >= 1, got {n_banks}")
self.d_model = d_model
self.pathways = pathways
self.n_banks = n_banks
self.skip_threshold = skip_threshold
self.skip_mode = skip_mode
self.quantize = quantize
self.top_k_routing = top_k_routing
self.pathway_names = PATHWAY_NAMES_5 if pathways == 5 else PATHWAY_NAMES_3
self.norm = nn.LayerNorm(d_model)
def _build(kind: str) -> nn.Module | nn.ModuleList:
mk = lambda: _make_pathway(
kind, d_model, d_head, kernel_size, wide_kernel_size,
top_k, quantize, per_row, hadamard, lsq, dense_expand,
fp_attention,
)
if n_banks <= 1:
return mk()
return nn.ModuleList([mk() for _ in range(n_banks)])
self.local = _build("local")
self.state = _build("state")
self.sparse = _build("sparse")
if pathways == 5:
self.wide = _build("wide")
self.dense = _build("dense")
# Router: routes over (pathway × bank) when n_banks > 1, else pathways.
n_router_outputs = pathways * n_banks
self.router = TernaryLinear(
d_model, n_router_outputs,
quantize=quantize, per_row=per_row, hadamard=hadamard, lsq=lsq,
)
self._aux_loss = torch.tensor(0.0)
def _pathway_modules(self) -> list[tuple[str, nn.Module | nn.ModuleList]]:
if self.pathways == 5:
return [
("local", self.local),
("wide", self.wide),
("state", self.state),
("sparse", self.sparse),
("dense", self.dense),
]
return [
("local", self.local),
("state", self.state),
("sparse", self.sparse),
]
def _compute_single_bank(self, h: Tensor, r: Tensor) -> Tensor:
outputs = [mod(h) for _, mod in self._pathway_modules()]
return sum(r[..., i:i + 1] * outputs[i] for i in range(len(outputs)))
def _compute_multi_bank(self, h: Tensor, r: Tensor) -> Tensor:
"""Multi-bank dispenser: per-token top-1 bank selection per pathway.
r shape: (B, L, n_pathways * n_banks)
"""
B, L, _ = r.shape
plist = self._pathway_modules()
n_paths = len(plist)
r_2d = r.view(B, L, n_paths, self.n_banks)
pathway_weights = r_2d.sum(dim=-1) # (B, L, n_paths)
bank_idx = r_2d.argmax(dim=-1) # (B, L, n_paths)
# Load balance: each bank should be selected ~1/n_banks of the time.
bank_probs = r_2d.mean(dim=(0, 1)) # (n_paths, n_banks)
target = 1.0 / self.n_banks
self._aux_loss = ((bank_probs - target) ** 2).mean() * 0.01
mixed = torch.zeros(B, L, self.d_model, device=h.device, dtype=h.dtype)
for p_idx, (_name, banks) in enumerate(plist):
pw = pathway_weights[..., p_idx:p_idx + 1] # (B, L, 1)
bidx = bank_idx[..., p_idx] # (B, L)
for b in range(self.n_banks):
mask = (bidx == b)
if not mask.any():
continue
out = banks[b](h)
mixed = mixed + pw * out * mask.unsqueeze(-1).to(out.dtype)
return mixed
def _maybe_topk_route(self, r: Tensor) -> Tensor:
"""Optionally restrict routing to the top-k pathways per token (Mixtral-style)."""
if self.top_k_routing <= 0 or self.top_k_routing >= r.shape[-1]:
return r
top_vals, top_idx = r.topk(self.top_k_routing, dim=-1)
mask = torch.zeros_like(r)
mask.scatter_(-1, top_idx, top_vals)
return mask / mask.sum(dim=-1, keepdim=True).clamp(min=1e-12)
def forward(self, x: Tensor) -> Tensor:
h = self.norm(x)
r = torch.softmax(self.router(h), dim=-1)
r = self._maybe_topk_route(r)
if self.n_banks <= 1:
mixed = self._compute_single_bank(h, r)
else:
mixed = self._compute_multi_bank(h, r)
return x + mixed
@property
def aux_loss(self) -> Tensor:
"""Load-balancing loss for multi-bank. Add to main loss during training."""
return self._aux_loss
@torch.no_grad()
def infer(self, x: Tensor) -> Tensor:
h = self.norm(x)
r = torch.softmax(self.router(h), dim=-1)
if self.n_banks > 1:
return x + self._compute_multi_bank(h, r)
y = torch.zeros_like(x)
if self.skip_mode == "per_call":
r_max = r.amax(dim=(0, 1))
for i, (_, mod) in enumerate(self._pathway_modules()):
if r_max[i].item() >= self.skip_threshold:
step = mod.infer(h) if hasattr(mod, "infer") else mod(h)
y = y + r[..., i:i + 1] * step
return x + y
for i, (_, mod) in enumerate(self._pathway_modules()):
step = mod.infer(h) if hasattr(mod, "infer") else mod(h)
mask = (r[..., i:i + 1] >= self.skip_threshold).to(step.dtype)
y = y + mask * r[..., i:i + 1] * step
return x + y
@torch.no_grad()
def router_weights(self, x: Tensor) -> Tensor:
"""Per-token router distribution.
For single-bank: shape (B, L, n_pathways).
For multi-bank: pathway-level weights (banks summed). Shape (B, L, n_pathways).
"""
r = torch.softmax(self.router(self.norm(x)), dim=-1)
if self.n_banks > 1:
B, L, _ = r.shape
n_paths = len(self._pathway_modules())
return r.view(B, L, n_paths, self.n_banks).sum(dim=-1)
return r
@torch.no_grad()
def router_entropy(self, x: Tensor) -> Tensor:
r = self.router_weights(x).clamp_min(1e-12)
return -(r * r.log()).sum(dim=-1)