"""tilelli.core.ternary_ssm: the State pathway of Tilelli.

From ARCHITECTURE.md:
    State path: small Mamba-style SSM. Long-range topic carry. O(n).

Day-0 scope: a **diagonal** state-space model, one independent scalar
recurrence per channel, which is the S4D / HiPPO-diag skeleton that
Mamba is built on. We skip Mamba's data-dependent selection for now;
that's a refinement on top of a working diagonal SSM, not the core idea.

The per-channel recurrence (c indexes the channel):

    h_t[c] = a[c] · h_{t-1}[c] + b[c] · x_t[c]
    y_t[c] = c_out[c] · h_t[c]

Three learnable per-channel scalars: decay `a`, input gain `b`, output
scale `c_out`. Stability demands |a| < 1; we enforce that with
`tanh(a_raw)`.

Training uses the **convolutional mode**. Because the recurrence is
linear and diagonal, with a zero initial state y_t unrolls to a causal
1-D convolution

    y_t[c] = sum_{i=0}^{t} K[c, i] · x_{t-i}[c]

with kernel

    K[c, i] = c_out[c] · a[c]^i · b[c]    for i = 0 … L-1

so a single depthwise `F.conv1d` gives us the whole output sequence in
one shot. This is the S4 trick. Inference uses the recurrent mode, a
simple per-step state update, O(L · C) sequential, which is what
Tilelli will actually run on CPU one token at a time.

A note on ternary weights here:
    The per-channel scalars are only O(C) parameters, vs O(C²) for the
    Linear layers. Ternarizing them saves almost nothing and makes the
    decay dynamics much harder to learn (the decay needs a continuous
    range of magnitudes below 1, which ternary {-α, 0, +α} can't
    cleanly express). We keep these few parameters in FP32 and are
    honest about it: the SSM is the one place in Tilelli where a little
    floating point lives. The big consumers, Linear and Conv, remain
    pure ternary.
"""
from __future__ import annotations

import math

import torch
from torch import Tensor, nn
from torch.nn import functional as F


class DiagonalSSM(nn.Module):
    """Per-channel diagonal state-space model. Input/output shape (B, L, C).

    Parameters are three per-channel vectors:
      - ``a_raw`` : pre-tanh decay; effective a = tanh(a_raw) ∈ (-1, 1)
      - ``b``     : input gain
      - ``c_out`` : output scale

    The state dimension equals the channel count (one scalar state per
    channel). For a wider state per channel, stack multiple DiagonalSSMs
    or move to a non-diagonal variant.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.channels = channels
        # Init decay near 0.9 so early training has long-ish memory.
        # tanh(1.5) ≈ 0.905.
        self.a_raw = nn.Parameter(torch.full((channels,), 1.5))
        self.b = nn.Parameter(torch.randn(channels) * (1.0 / math.sqrt(channels)))
        self.c_out = nn.Parameter(torch.randn(channels) * (1.0 / math.sqrt(channels)))

    # ------------------------------------------------------------------ #
    # Training forward: convolutional mode
    # ------------------------------------------------------------------ #
    def forward(self, x: Tensor) -> Tensor:
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, C), got shape {tuple(x.shape)}")
        B, L, C = x.shape
        if C != self.channels:
            raise ValueError(f"channel mismatch: module has {self.channels}, input has {C}")

        a = torch.tanh(self.a_raw)  # (C,), in (-1, 1)
        b = self.b                  # (C,)
        c_out = self.c_out          # (C,)

        # Build the per-channel causal kernel. We want
        #   y_t = sum_{d=0}^{L-1} (c_out * a^d * b) * x_{t-d}
        # torch.conv1d is cross-correlation: with left-pad L-1, the
        # LAST kernel element is delay 0, so the powers must run from
        # (L-1) down to 0 across the kernel's spatial axis.
        i = torch.arange(L - 1, -1, -1, device=x.device, dtype=x.dtype)  # (L,)
        powers = a.unsqueeze(-1) ** i.unsqueeze(0)                       # (C, L)
        kernel = (c_out * b).unsqueeze(-1) * powers                      # (C, L)
        kernel = kernel.unsqueeze(1)                                     # (C, 1, L)

        # Depthwise causal conv: left-pad L-1, groups=C
        x_ = x.transpose(1, 2)  # (B, C, L)
        x_ = F.pad(x_, (L - 1, 0))
        y = F.conv1d(x_, kernel, groups=C)
        return y.transpose(1, 2)  # (B, L, C)

    # ------------------------------------------------------------------ #
    # Inference: recurrent mode, O(L·C) sequential
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def infer(self, x: Tensor) -> Tensor:
        """Step-by-step recurrence. Agrees with `forward` up to
        floating-point rounding.

        This mirrors the path Tilelli runs at CPU inference time: one
        token in, one token out, with a state of shape (B, C) carried
        across steps. No length-L kernel to build. Each call starts
        from a zero state.
        """
        if x.dim() != 3:
            raise ValueError(f"expected (B, L, C), got shape {tuple(x.shape)}")
        B, L, C = x.shape
        # Same channel check as `forward`, so a mismatch fails loudly
        # instead of broadcasting.
        if C != self.channels:
            raise ValueError(f"channel mismatch: module has {self.channels}, input has {C}")

        a = torch.tanh(self.a_raw)
        b = self.b
        c_out = self.c_out

        h = torch.zeros(B, C, dtype=x.dtype, device=x.device)  # state, (B, C)
        ys = []
        for t in range(L):
            h = a * h + b * x[:, t]  # one recurrence step per token
            ys.append(c_out * h)
        return torch.stack(ys, dim=1)  # (B, L, C)
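

# ---------------------------------------------------------------------- #
# Reference sketch: recurrence vs. unrolled kernel, single channel
# ---------------------------------------------------------------------- #
# A minimal, dependency-free sketch of the equivalence the module docstring
# relies on, for one channel and a zero initial state. The helper names
# `ref_recurrent` and `ref_convolutional` are illustrative assumptions, not
# part of Tilelli's API; they are meant for checking small cases by hand,
# not for training or inference.
#
# Impulse-response example (the output is just the kernel):
#   ref_recurrent([1.0, 0.0, 0.0], a=0.5, b=2.0, c=3.0)      -> [6.0, 3.0, 1.5]
#   ref_convolutional([1.0, 0.0, 0.0], a=0.5, b=2.0, c=3.0)  -> [6.0, 3.0, 1.5]


def ref_recurrent(xs: list[float], a: float, b: float, c: float) -> list[float]:
    """y_t = c * h_t with h_t = a * h_{t-1} + b * x_t, starting from h = 0."""
    h = 0.0
    ys = []
    for x in xs:
        h = a * h + b * x
        ys.append(c * h)
    return ys


def ref_convolutional(xs: list[float], a: float, b: float, c: float) -> list[float]:
    """Same output via the causal kernel K[i] = c * a**i * b, i = 0 .. L-1."""
    L = len(xs)
    kernel = [c * (a ** i) * b for i in range(L)]
    # y_t = sum_{d=0}^{t} K[d] * x_{t-d}; delays d > t would read the zero pad.
    return [sum(kernel[d] * xs[t - d] for d in range(t + 1)) for t in range(L)]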