| """tilelli.core.ternary_ssm — the State pathway of Tilelli. |
| |
| From ARCHITECTURE.md: |
| State path: small Mamba-style SSM. Long-range topic carry. O(n). |
| |
| Day-0 scope: a **diagonal** state-space model — one independent scalar |
| recurrence per channel — which is the S4D / HiPPO-diag skeleton that |
| Mamba is built on. We skip Mamba's data-dependent selection for now; |
| that's a refinement on top of a working diagonal SSM, not the core idea. |
| |
| The per-channel recurrence: |
| |
| h_t[c] = a[c] · h_{t-1}[c] + b[c] · x_t[c] |
| y_t[c] = c[c] · h_t[c] |
| |
| Three learnable per-channel scalars: decay `a`, input gain `b`, output |
| scale `c`. Stability demands |a| < 1; we enforce that with `tanh(a_raw)`. |
| |
| Training uses the **convolutional mode** — because the recurrence is |
| linear and diagonal, y_t unrolls to a 1-D convolution with kernel |
| |
| K[c, i] = c[c] · a[c]^i · b[c] for i = 0 … L-1 |
| |
| so a single depthwise `F.conv1d` gives us the whole output sequence in |
| one shot. This is the S4 trick. Inference uses the recurrent mode — a |
| simple per-step state update, O(L · C) sequential — which is what |
| Tilelli will actually run on CPU one token at a time. |
| |
| A note on ternary weights here: |
| The per-channel scalars are only O(C) parameters, vs O(C²) for the |
| Linear layers. Ternarizing them saves almost nothing and makes the |
| decay dynamics much harder to learn (decay must be in (0, 1), which |
| ternary {-α, 0, +α} can't cleanly express). We keep these few |
| parameters in FP32 and are honest about it: the SSM is the one place |
| in Tilelli where a little floating point lives. The big consumers — |
| Linear and Conv — remain pure ternary. |
| """ |
| from __future__ import annotations |
|
|
| import math |
|
|
| import torch |
| from torch import Tensor, nn |
| from torch.nn import functional as F |
|
|
|
|
class DiagonalSSM(nn.Module):
    """Per-channel diagonal state-space model. Input/output shape (B, L, C).

    Every channel ``c`` runs its own scalar recurrence, starting from a
    zero state at the beginning of each call::

        h_t[c] = a[c] * h_{t-1}[c] + b[c] * x_t[c]
        y_t[c] = c_out[c] * h_t[c]

    Parameters are three per-channel vectors of length ``channels``:
      - ``a_raw``  : pre-tanh decay; effective a = tanh(a_raw) ∈ (-1, 1)
      - ``b``      : input gain
      - ``c_out``  : output scale

    The state dimension equals the channel count (one scalar state per
    channel). For a wider state per channel, stack multiple DiagonalSSMs
    or move to a non-diagonal variant.

    ``forward`` evaluates the recurrence as one depthwise convolution;
    ``infer`` evaluates the same recurrence step by step without autograd.
    """

    def __init__(self, channels: int) -> None:
        """Create the three per-channel parameter vectors.

        Args:
            channels: number of independent channels ``C``.
        """
        super().__init__()
        self.channels = channels

        # tanh(1.5) ≈ 0.905, so every channel starts with the same
        # slowly-decaying memory.
        self.a_raw = nn.Parameter(torch.full((channels,), 1.5))
        # Gaussian init scaled by 1/sqrt(channels). `b` is drawn before
        # `c_out`; keep this order so seeded initialisation is unchanged.
        self.b = nn.Parameter(torch.randn(channels) * (1.0 / math.sqrt(channels)))
        self.c_out = nn.Parameter(torch.randn(channels) * (1.0 / math.sqrt(channels)))

    def _check_input(self, x: Tensor) -> tuple:
        """Validate that ``x`` is (B, L, C) with C == self.channels.

        Returns:
            The ``(B, L, C)`` sizes of ``x``.

        Raises:
            ValueError: if ``x`` is not 3-D or its last dimension does not
                match ``self.channels``.
        """
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, C), got shape {tuple(x.shape)}")
        B, L, C = x.shape
        if C != self.channels:
            raise ValueError(f"channel mismatch: module has {self.channels}, input has {C}")
        return B, L, C

    def forward(self, x: Tensor) -> Tensor:
        """Convolutional mode: compute the whole output sequence at once.

        Args:
            x: input of shape (B, L, C).

        Returns:
            Tensor of shape (B, L, C) holding ``y_t`` for every step.

        Raises:
            ValueError: if ``x`` is not 3-D or has the wrong channel count.
        """
        B, L, C = self._check_input(x)
        if L == 0:
            # Nothing to convolve; padding by L - 1 = -1 would fail below.
            return x.new_zeros((B, 0, C))

        a = torch.tanh(self.a_raw)

        # Exponents L-1 … 0. They are built in a dtype at least as wide as
        # the parameters so the integers stay exact even when `x` is
        # fp16/bf16 (whose integer range is too small for long sequences).
        exp_dtype = torch.promote_types(x.dtype, a.dtype)
        i = torch.arange(L - 1, -1, -1, device=x.device, dtype=exp_dtype)

        # kernel[c, 0, j] = c_out[c] * b[c] * a[c]^(L-1-j). The exponents are
        # reversed because F.conv1d is a cross-correlation: with L-1 zeros of
        # left padding, tap j multiplies x[t - (L-1-j)], so the last tap
        # (exponent 0) lines up with the current step.
        powers = a.unsqueeze(-1) ** i.unsqueeze(0)              # (C, L)
        kernel = (self.c_out * self.b).unsqueeze(-1) * powers   # (C, L)
        kernel = kernel.unsqueeze(1)                            # (C, 1, L)

        # conv1d wants (B, C, L); left-pad so the output is causal and has
        # length L. groups=C makes it depthwise: one kernel per channel.
        x_ = x.transpose(1, 2)
        x_ = F.pad(x_, (L - 1, 0))
        y = F.conv1d(x_, kernel, groups=C)
        return y.transpose(1, 2)

    @torch.no_grad()
    def infer(self, x: Tensor) -> Tensor:
        """Step-by-step recurrence. Agrees with `forward` numerically.

        Runs the per-channel update one time step at a time, keeping a
        state of shape (B, C). No length-L kernel is built. The state
        starts at zero on every call and is not returned, so a call
        processes one self-contained sequence.

        Args:
            x: input of shape (B, L, C).

        Returns:
            Tensor of shape (B, L, C), computed without autograd.

        Raises:
            ValueError: if ``x`` is not 3-D or has the wrong channel count
                (same checks as `forward`).
        """
        B, L, C = self._check_input(x)
        a = torch.tanh(self.a_raw)
        b = self.b
        c_out = self.c_out

        # Mixing `x` with the parameters promotes to this dtype on the
        # first update; allocating in it up front also makes L == 0 work.
        out_dtype = torch.promote_types(x.dtype, a.dtype)
        h = torch.zeros(B, C, dtype=out_dtype, device=x.device)
        y = torch.empty(B, L, C, dtype=out_dtype, device=x.device)
        for t in range(L):
            h = a * h + b * x[:, t]
            y[:, t] = c_out * h
        return y
|
|