etwk committed on
Commit 47e319e
Parent(s): 6b82250
Tier-6 carry-aware TCN cell: highest_tier 5->6, overall 0.531->0.602, tier-6 0.97
128-bit MLP cell (tier-6 0.26) replaced by TCNHornerCell -- weight-shared non-causal
dilated convs over the 128 bit-positions (~3.9M params/16MB). Per-step eps ~15x below
the MLP floor; benchmark tier-6 0.26->0.97, deterministic, inference 0.9s. weights128.pt md5 0c087b3f (git-LFS object). model.py adds TCNHornerCell + _build_cell
arch dispatch; manifest + README updated; perturbation compliance incl. tier-6.
- README.md +27 -12
- manifest.json +2 -2
- model.py +73 -9
README.md
CHANGED
|
@@ -9,15 +9,17 @@ tags:
|
|
| 9 |
- neural-algorithm
|
| 10 |
---
|
| 11 |
|
| 12 |
-
# Horner-RNN — learned modular multiplication up to 2
|
| 13 |
|
| 14 |
A compliant **bit-sequential RNN** that computes `(a · b) mod p` for primes `p` up to
|
| 15 |
-
**2
|
| 16 |
multiplication tables. Entry for the
|
| 17 |
[Modular Arithmetic Challenge](https://github.com/SAIRcompetition/modular-arithmetic-challenge).
|
| 18 |
|
| 19 |
-
- **Saturates tiers 1–
|
| 20 |
-
- **overall_accuracy 0.
|
| 21 |
- Verifiably **generalises to primes never seen in training** (held-out-prime validation
|
| 22 |
accuracy tracks training accuracy — no memorisation gap)
|
| 23 |
|
|
@@ -44,14 +46,23 @@ The single-step function is **piecewise linear** (`2t + bit·b`, then subtract `
|
|
| 44 |
|
| 45 |
## Files / cells
|
| 46 |
|
| 47 |
-
The model ships **
|
| 48 |
holds the prime:
|
| 49 |
|
| 50 |
-
| File | Cell | Primes | Tiers |
|
| 51 |
|---|---|---|---|---|---|---|
|
| 52 |
-
| `weights16.pt` | 16-bit | `< 2¹⁶` | 1–3 | 4096 / 4 | ~50M | tiers 1–3 = 1.00 |
|
| 53 |
-
| `weights32.pt` | 32-bit | `< 2³²` | 4 | 6144 / 4 | ~114M | tier 4 = 0.99 |
|
| 54 |
-
| `weights64.pt` | 64-bit | `< 2⁶⁴` | 5 | 4096 / 7, residual | ~236M | tier 5 = 0.98 |
|
| 55 |
|
| 56 |
The 64-bit cell needs **depth and residual connections** the narrower cells do not: a 64-bit
|
| 57 |
modular Horner step hides two long carry chains (the `2t + bit·b` addition and the
|
|
@@ -109,6 +120,7 @@ cell is *at* the floor. The capability therefore resides in the trained paramete
|
|
| 109 |
| tier 3 (16-bit cell) | 1.00 | 1.00 | 0.98 | 0.74 | 0.06 | 0.00 |
|
| 110 |
| tier 4 (32-bit cell) | 0.99 | 0.99 | 0.86 | 0.04 | 0.02 | 0.00 |
|
| 111 |
| tier 5 (64-bit cell) | 0.98 | 0.95 | 0.65 | 0.03 | 0.01 | 0.00 |
|
|
|
|
| 112 |
|
| 113 |
Generalisation against memorisation: 10% of primes at each bit-width were held out of
|
| 114 |
training entirely; chain accuracy on them matches the training primes.
|
|
@@ -121,9 +133,12 @@ primes. The 64-bit cell adds a second fine-tuning phase whose single steps are d
|
|
| 121 |
**true Horner trajectory** — `t` is an actual chain intermediate `(a_{≥i}·b) mod p`, not a
|
| 122 |
uniform sample — matching the training distribution to the states the chain visits at
|
| 123 |
inference. This lifts tier 5 from 0.74 to 0.98 with no capacity change and no backprop through
|
| 124 |
-
the recurrence (ordinary supervised BCE on the same single-step target).
|
| 125 |
-
|
| 126 |
-
|
| 127 |
|
| 128 |
## License
|
| 129 |
|
|
|
|
| 9 |
- neural-algorithm
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# Horner-RNN — learned modular multiplication up to 2¹²⁸
|
| 13 |
|
| 14 |
A compliant **bit-sequential RNN** that computes `(a · b) mod p` for primes `p` up to
|
| 15 |
+
**2¹²⁸**, by *learning the Horner step of double-and-add* rather than memorising
|
| 16 |
multiplication tables. Entry for the
|
| 17 |
[Modular Arithmetic Challenge](https://github.com/SAIRcompetition/modular-arithmetic-challenge).
|
| 18 |
|
| 19 |
+
- **Saturates tiers 1–6** (all primes `< 2¹²⁸`): tiers 1–3 = 100%, tier 4 = 99%, tier 5 = 98%, **tier 6 = 97%**
|
| 20 |
+
- **overall_accuracy 0.602**, `highest_tier_above_90 = 6`
|
| 21 |
+
- The 128-bit (tier-6) cell is a **carry-aware TCN** (weight-shared dilated convolutions over the
|
| 22 |
+
128 bit-positions, ~3.9M params) — a far better inductive bias for long carry chains than the MLP
|
| 23 |
- Verifiably **generalises to primes never seen in training** (held-out-prime validation
|
| 24 |
accuracy tracks training accuracy — no memorisation gap)
|
| 25 |
|
|
|
|
| 46 |
|
| 47 |
## Files / cells
|
| 48 |
|
| 49 |
+
The model ships **four cells** and routes each problem to the narrowest one whose state
|
| 50 |
holds the prime:
|
| 51 |
|
| 52 |
+
| File | Cell | Primes | Tiers | Arch | Params | Public benchmark |
|
| 53 |
|---|---|---|---|---|---|---|
|
| 54 |
+
| `weights16.pt` | 16-bit | `< 2¹⁶` | 1–3 | MLP, 4096 / 4 | ~50M | tiers 1–3 = 1.00 |
|
| 55 |
+
| `weights32.pt` | 32-bit | `< 2³²` | 4 | MLP, 6144 / 4 | ~114M | tier 4 = 0.99 |
|
| 56 |
+
| `weights64.pt` | 64-bit | `< 2⁶⁴` | 5 | MLP, 4096 / 7, residual | ~236M | tier 5 = 0.98 |
|
| 57 |
+
| `weights128.pt` | 128-bit | `< 2¹²⁸` | 6 | **carry-aware TCN**, 256ch / 10 blocks, dilations 1–64 | ~3.9M | **tier 6 = 0.97** |
|
| 58 |
+
|
| 59 |
+
The 128-bit cell switches architecture: instead of a full-width MLP it is a **non-causal
|
| 60 |
+
dilated 1-D convolutional network over the 128 bit-positions**. Carry propagation is
|
| 61 |
+
*position-invariant* — the same carry/borrow rule applies at every bit — so a weight-shared
|
| 62 |
+
convolution learns **one** rule applied everywhere (non-causal, so the addition carry flows
|
| 63 |
+
LSB→MSB and the mod-`p` compare/borrow flows MSB→LSB), rather than an MLP learning 128 separate
|
| 64 |
+
position-functions. This inductive bias drives the per-step error roughly **15× lower** than the
|
| 65 |
+
same-task MLP, lifting tier 6 from 0.26 to **0.97** with a cell **~60× smaller** (16 MB vs ~950 MB).
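A quick way to sanity-check the receptive-field claim is to recompute the dilation schedule and count how far one output bit can see. The sketch below mirrors the schedule loop in `TCNHornerCell.__init__` from this commit. The names `dilation_schedule` and `receptive_radius` are illustrative and not part of `model.py`. The radius count assumes two stacked convolutions per block (as in `_DilatedResBlock`) and ignores zero-padding at the edges.

```python
def dilation_schedule(blocks: int = 10, max_dil: int = 64) -> list:
    """Dilations double each block and wrap back to 1 after reaching max_dil."""
    dilations, d = [], 1
    for _ in range(blocks):
        dilations.append(d)
        d = 1 if d >= max_dil else d * 2
    return dilations


def receptive_radius(dilations, kernel: int = 3, convs_per_block: int = 2) -> int:
    """Number of bit-positions to either side that one output position can see.

    Assumption: every block stacks `convs_per_block` convolutions that share
    the same dilation, and each convolution reaches dilation * (kernel - 1) // 2
    positions per side.
    """
    radius = 0
    for d in dilations:
        radius += convs_per_block * (d * (kernel - 1) // 2)
    return radius


schedule = dilation_schedule()
radius = receptive_radius(schedule)
print("dilations:", schedule)
print("radius per side:", radius, "| needed for 128 bits:", 127)
```

With the defaults the radius is 268 positions per side, which exceeds the 127 needed to connect the LSB and MSB of a 128-bit state in either direction.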
|
| 66 |
|
| 67 |
The 64-bit cell needs **depth and residual connections** the narrower cells do not: a 64-bit
|
| 68 |
modular Horner step hides two long carry chains (the `2t + bit·b` addition and the
|
|
|
|
| 120 |
| tier 3 (16-bit cell) | 1.00 | 1.00 | 0.98 | 0.74 | 0.06 | 0.00 |
|
| 121 |
| tier 4 (32-bit cell) | 0.99 | 0.99 | 0.86 | 0.04 | 0.02 | 0.00 |
|
| 122 |
| tier 5 (64-bit cell) | 0.98 | 0.95 | 0.65 | 0.03 | 0.01 | 0.00 |
|
| 123 |
+
| tier 6 (128-bit TCN) | 0.97 | 0.96 | 0.98 | 0.19 | 0.02 | 0.00 |
|
| 124 |
|
| 125 |
Generalisation against memorisation: 10% of primes at each bit-width were held out of
|
| 126 |
training entirely; chain accuracy on them matches the training primes.
|
|
|
|
| 133 |
**true Horner trajectory** — `t` is an actual chain intermediate `(a_{≥i}·b) mod p`, not a
|
| 134 |
uniform sample — matching the training distribution to the states the chain visits at
|
| 135 |
inference. This lifts tier 5 from 0.74 to 0.98 with no capacity change and no backprop through
|
| 136 |
+
the recurrence (ordinary supervised BCE on the same single-step target). The 128-bit (tier-6)
|
| 137 |
+
cell is trained the same single-step way but as the **carry-aware TCN** over a high-diversity
|
| 138 |
+
pool of thousands of distinct 124–128 bit primes; its weight-shared dilated-convolution bias
|
| 139 |
+
reaches a per-step error ~15× lower than the same-task MLP, giving **tier 6 = 0.97** in a single
|
| 140 |
+
short run. Training code and the full write-up live in the solutions repo (link in the model card
|
| 141 |
+
metadata / challenge leaderboard).
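For reference, the single-step target and the trajectory states that the second training phase samples can be written in exact integer arithmetic. This is a minimal sketch of the ground-truth recurrence the cells are trained to imitate, not the training code itself. The function names are illustrative, and iterating over `a.bit_length()` bits is an assumption (the shipped model may scan a fixed number of bits per cell).

```python
def horner_step(t: int, bit: int, b: int, p: int) -> int:
    """One double-and-add step: (2t + bit*b) mod p.

    With t < p and b < p, 2t + bit*b < 3p, so the reduction is at most
    two conditional subtractions of p.
    """
    return (2 * t + bit * b) % p


def trajectory(a: int, b: int, p: int):
    """Return the (t, bit, target) triples visited when scanning a MSB-first."""
    a %= p
    b %= p
    t, steps = 0, []
    for i in reversed(range(a.bit_length())):
        bit = (a >> i) & 1
        nxt = horner_step(t, bit, b, p)
        steps.append((t, bit, nxt))
        t = nxt
    return steps


def horner_modmul(a: int, b: int, p: int) -> int:
    """(a * b) mod p as the final state of the Horner chain."""
    steps = trajectory(a, b, p)
    return steps[-1][2] if steps else 0


print(trajectory(5, 3, 7))
print(horner_modmul(123456789, 987654321, 1000003))
```

Each triple from `trajectory` corresponds to one supervised example `(t, bit, b, p) -> (2t + bit*b) mod p` in which `t` is an actual chain intermediate rather than a uniform sample.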
|
| 142 |
|
| 143 |
## License
|
| 144 |
|
manifest.json
CHANGED
|
@@ -2,6 +2,6 @@
|
|
| 2 |
"entry_class": "model.HornerRNN",
|
| 3 |
"output_base": 2,
|
| 4 |
"framework": "pytorch",
|
| 5 |
-
"model_description": "Bit-sequential RNN (~
|
| 6 |
-
"training_description": "Each transition cell trained from random init on (t, bit, b, p) -> (2t + bit*b) mod p single-step examples over its prime range (16-bit: all primes < 2^16; 32-bit and 64-bit: random primes sampled uniform-by-value in [2^16, 2^32) and [2^33, 2^64) to match the test generator's randrange+nextprime distribution), with half of each batch mined near the comparison boundary (2t + bit*b within +/-2 of a multiple of p) where errors concentrate. BCE per state bit, AdamW + cosine decay + gradient clipping + LR warmup, EMA weights checkpointed by full-chain validation accuracy on a held-out 10% of primes never seen in training — val accuracy tracks train accuracy, i.e. the cells generalise across primes rather than memorising them. The 64-bit cell additionally receives a second fine-tuning phase on single steps drawn from the TRUE Horner trajectory (each example is a (t, bit, b, p) -> (2t + bit*b) mod p step where t is an actual chain intermediate (a_{>=i}*b) mod p, not a uniform sample), which matches the training distribution to the states the chain visits at inference and lifts tier 5 from 0.74 to 0.98; still ordinary supervised BCE on the same single-step target, no backprop through the recurrence. Training scripts: train.py (16-bit), exploration/train_horner32.py (32-bit), exploration/train_horner64.py (64-bit phase 1, --residual) then exploration/train_horner64_traj.py (64-bit phase 2, trajectory)."
|
| 7 |
}
|
|
|
|
| 2 |
"entry_class": "model.HornerRNN",
|
| 3 |
"output_base": 2,
|
| 4 |
"framework": "pytorch",
|
| 5 |
+
"model_description": "Bit-sequential RNN (~405M params across four cells) for primes up to 2^128. Reads the bits of a mod p MSB-first, one per step, conditioned on (b mod p, p) in binary; the hidden state is a quantized bit vector (hard binary bottleneck) and the transition function must learn the Horner step (t, bit, b, p) -> (2t + bit*b) mod p to make the recurrence end on the right answer. Four cells are shipped and routed by prime size: a 16-bit cell (MLP, width 4096 depth 4, ~50M params) for p < 2^16 covering tiers 1-3, a 32-bit cell (MLP, width 6144 depth 4, ~114M params) for p < 2^32 covering tier 4, a 64-bit cell (MLP, width 4096 depth 7 with pre-norm residual blocks, ~236M params) for p < 2^64 covering tier 5, and a 128-bit cell for p < 2^128 covering tier 6 that is a CARRY-AWARE TCN: a non-causal dilated 1D-convolutional network over the 128 bit-positions (10 residual blocks, 256 channels, dilations cycling 1..64 so the receptive field spans all 128 bits, ~3.9M params). The convolution is weight-shared across bit positions, so it learns ONE carry/borrow rule applied everywhere (non-causally, so the addition carry can flow LSB->MSB and the mod-p compare/borrow MSB->LSB) instead of a full-width MLP learning 128 separate position-functions; this inductive bias drives the per-step error far below what the MLP cell reached and lifts tier 6 to 0.97. Final state bits are emitted MSB-first as the base-2 answer. For p >= 2^128 emits the honest [0] fallback without invoking the network.",
|
| 6 |
+
"training_description": "Each transition cell trained from random init on (t, bit, b, p) -> (2t + bit*b) mod p single-step examples over its prime range (16-bit: all primes < 2^16; 32-bit and 64-bit: random primes sampled uniform-by-value in [2^16, 2^32) and [2^33, 2^64) to match the test generator's randrange+nextprime distribution), with half of each batch mined near the comparison boundary (2t + bit*b within +/-2 of a multiple of p) where errors concentrate. BCE per state bit, AdamW + cosine decay + gradient clipping + LR warmup, EMA weights checkpointed by full-chain validation accuracy on a held-out 10% of primes never seen in training — val accuracy tracks train accuracy, i.e. the cells generalise across primes rather than memorising them. The 64-bit cell additionally receives a second fine-tuning phase on single steps drawn from the TRUE Horner trajectory (each example is a (t, bit, b, p) -> (2t + bit*b) mod p step where t is an actual chain intermediate (a_{>=i}*b) mod p, not a uniform sample), which matches the training distribution to the states the chain visits at inference and lifts tier 5 from 0.74 to 0.98; still ordinary supervised BCE on the same single-step target, no backprop through the recurrence. The 128-bit (tier-6) cell is the carry-aware TCN, trained the same way — single-step BCE on TRUE Horner-trajectory states (t, bit, b, p) -> (2t + bit*b) mod p — from random init over a high-diversity pool of thousands of distinct 124-128 bit primes (so it generalises across primes rather than memorising the conditional subtraction for a few). Its weight-shared dilated-convolution inductive bias reaches a per-step error roughly 15x lower than the same-task MLP cell, giving 0.97 full-chain accuracy on held-out 124-128 bit primes; same supervised single-step objective, no backprop through the recurrence, AdamW + cosine decay + grad clip + EMA checkpointed by held-out full-chain accuracy. 
Weight-perturbation compliance (exploration/compliance_perturb.py): tier-6 accuracy 0.97 at sigma=0 collapses toward the floor as the conv weights are perturbed (0.19 at sigma=0.25, 0.02 at sigma=0.5) and an untrained cell scores 0.00 — the arithmetic resides in the trained parameters. Training scripts: train.py (16-bit), exploration/train_horner32.py (32-bit), exploration/train_horner64.py (64-bit phase 1, --residual) then exploration/train_horner64_traj.py (64-bit phase 2, trajectory), exploration/train_horner128_bigru.py --arch tcn (128-bit carry-aware TCN)."
|
| 7 |
}
|
model.py
CHANGED
|
@@ -3,9 +3,10 @@
|
|
| 3 |
Architecture: a recurrent network that reads the bits of ``a mod p`` MSB-first,
|
| 4 |
one per step, conditioned on ``(b mod p, p)`` in binary. The hidden state is a
|
| 5 |
quantized bit vector (a discrete bottleneck — a hard VQ layer with a fixed
|
| 6 |
-
binary codebook), and the transition function — an MLP
|
| 7 |
-
|
| 8 |
-
|
|
|
|
| 9 |
|
| 10 |
Why this is interesting: for the recurrence to end on the right answer, the
|
| 11 |
trained cell must *learn* the map ``(t, bit, b, p) -> (2t + bit*b) mod p`` —
|
|
@@ -29,10 +30,10 @@ line is respected here:
|
|
| 29 |
The two-operand reductions ``a mod p`` / ``b mod p`` in ``predict_digits`` are
|
| 30 |
the same legal input normalisation every other reference model uses.
|
| 31 |
|
| 32 |
-
The model ships one cell per bit-width (16 -> tiers 1-3, 32 -> tier 4,
|
| 33 |
-
tier 5 when present) and routes each problem to the narrowest
|
| 34 |
-
holds the prime. For primes wider than the widest trained cell
|
| 35 |
-
honest ``[0]`` fallback without invoking the network.
|
| 36 |
"""
|
| 37 |
|
| 38 |
from __future__ import annotations
|
|
@@ -47,7 +48,7 @@ from modchallenge.interface.base_model import ModularMultiplicationModel
|
|
| 47 |
|
| 48 |
# Bit-widths we may ship a cell for, narrowest first. load() picks up whichever
|
| 49 |
# weights{W}.pt files are actually present, so adding a wider cell is drop-in.
|
| 50 |
-
CELL_WIDTHS = (16, 32, 64)
|
| 51 |
|
| 52 |
# Default state width for the 16-bit trainer (train.py imports this).
|
| 53 |
BITS = 16
|
|
@@ -99,6 +100,69 @@ class HornerCell(nn.Module):
|
|
| 99 |
return self.head(self.trunk(x))
|
| 100 |
|
| 101 |
|
| 102 |
def _to_bits(t: torch.Tensor, bits: int = 16) -> torch.Tensor:
|
| 103 |
"""(N,) int64 -> (N, bits) float in {0,1}, LSB-first.
|
| 104 |
|
|
@@ -144,7 +208,7 @@ class HornerRNN(ModularMultiplicationModel):
|
|
| 144 |
if not path.exists():
|
| 145 |
continue
|
| 146 |
ckpt = torch.load(path, map_location=self.device, weights_only=True)
|
| 147 |
-
cell =
|
| 148 |
cell.load_state_dict(ckpt["state_dict"])
|
| 149 |
cell.to(self.device)
|
| 150 |
cell.eval()
|
|
|
|
| 3 |
Architecture: a recurrent network that reads the bits of ``a mod p`` MSB-first,
|
| 4 |
one per step, conditioned on ``(b mod p, p)`` in binary. The hidden state is a
|
| 5 |
quantized bit vector (a discrete bottleneck — a hard VQ layer with a fixed
|
| 6 |
+
binary codebook), and the transition function — an MLP for the 16/32/64-bit
|
| 7 |
+
cells, a weight-shared carry-aware dilated-conv TCN (TCNHornerCell) for the
|
| 8 |
+
128-bit cell — is entirely trained parameters. After the last bit, the hidden
|
| 9 |
+
state bits ARE the answer, emitted MSB-first in base 2.
|
| 10 |
|
| 11 |
Why this is interesting: for the recurrence to end on the right answer, the
|
| 12 |
trained cell must *learn* the map ``(t, bit, b, p) -> (2t + bit*b) mod p`` —
|
|
|
|
| 30 |
The two-operand reductions ``a mod p`` / ``b mod p`` in ``predict_digits`` are
|
| 31 |
the same legal input normalisation every other reference model uses.
|
| 32 |
|
| 33 |
+
The model ships one cell per bit-width (16 -> tiers 1-3, 32 -> tier 4, 64 ->
|
| 34 |
+
tier 5, and 128 -> tier 6 when present) and routes each problem to the narrowest
|
| 35 |
+
cell whose state holds the prime. For primes wider than the widest trained cell
|
| 36 |
+
it emits the honest ``[0]`` fallback without invoking the network.
|
| 37 |
"""
|
| 38 |
|
| 39 |
from __future__ import annotations
|
|
|
|
| 48 |
|
| 49 |
# Bit-widths we may ship a cell for, narrowest first. load() picks up whichever
|
| 50 |
# weights{W}.pt files are actually present, so adding a wider cell is drop-in.
|
| 51 |
+
CELL_WIDTHS = (16, 32, 64, 128)
|
| 52 |
|
| 53 |
# Default state width for the 16-bit trainer (train.py imports this).
|
| 54 |
BITS = 16
|
|
|
|
| 100 |
return self.head(self.trunk(x))
|
| 101 |
|
| 102 |
|
| 103 |
+
class _DilatedResBlock(nn.Module):
|
| 104 |
+
"""Non-causal dilated-conv residual block with per-position channel LayerNorm."""
|
| 105 |
+
|
| 106 |
+
def __init__(self, ch: int, kernel: int, dilation: int):
|
| 107 |
+
super().__init__()
|
| 108 |
+
pad = dilation * (kernel - 1) // 2
|
| 109 |
+
self.norm = nn.LayerNorm(ch)
|
| 110 |
+
self.conv1 = nn.Conv1d(ch, ch, kernel, padding=pad, dilation=dilation)
|
| 111 |
+
self.conv2 = nn.Conv1d(ch, ch, kernel, padding=pad, dilation=dilation)
|
| 112 |
+
|
| 113 |
+
def forward(self, x): # x: (N, C, L)
|
| 114 |
+
xn = self.norm(x.transpose(1, 2)).transpose(1, 2)
|
| 115 |
+
return x + self.conv2(torch.nn.functional.gelu(self.conv1(xn)))
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class TCNHornerCell(nn.Module):
|
| 119 |
+
"""Carry-aware Horner cell: a non-causal dilated TCN over the 128 bit-positions.
|
| 120 |
+
|
| 121 |
+
Same learned transition (t, bit, b, p) -> (2t + bit*b) mod p as HornerCell, but the
|
| 122 |
+
network is WEIGHT-SHARED across bit positions (one learned carry rule applied
|
| 123 |
+
everywhere) instead of a full-width MLP learning 128 separate position-functions.
|
| 124 |
+
Dilations cycle 1,2,..,max_dil so the receptive field spans all 128 bits (full carry
|
| 125 |
+
reach), non-causally (each position sees both lower and higher bits — the add-carry
|
| 126 |
+
flows LSB->MSB and the mod-p compare/borrow flows MSB->LSB). This is what lets the
|
| 127 |
+
per-step error fall well below the MLP cell's floor. forward signature matches
|
| 128 |
+
HornerCell so the inference scan in _run_cell is unchanged. Compliance is identical:
|
| 129 |
+
tokenise/scan/readout are weight-independent; ALL arithmetic is in the trained conv
|
| 130 |
+
weights (random weights -> noise)."""
|
| 131 |
+
|
| 132 |
+
def __init__(self, channels: int = 256, blocks: int = 10, bits: int = 128,
|
| 133 |
+
kernel: int = 3, max_dil: int = 64, dilations=None):
|
| 134 |
+
super().__init__()
|
| 135 |
+
self.bits = bits
|
| 136 |
+
self.inp = nn.Conv1d(4, channels, 1)
|
| 137 |
+
if dilations is None:
|
| 138 |
+
dilations, d = [], 1
|
| 139 |
+
for _ in range(blocks):
|
| 140 |
+
dilations.append(d)
|
| 141 |
+
d = 1 if d >= max_dil else d * 2
|
| 142 |
+
self.blocks = nn.ModuleList([_DilatedResBlock(channels, kernel, dd) for dd in dilations])
|
| 143 |
+
self.out = nn.Conv1d(channels, 1, 1)
|
| 144 |
+
self.config = dict(arch="tcn", channels=channels, blocks=blocks, bits=bits,
|
| 145 |
+
kernel=kernel, max_dil=max_dil, dilations=dilations)
|
| 146 |
+
|
| 147 |
+
def forward(self, tb, bit, bb, pb):
|
| 148 |
+
n = tb.shape[0]
|
| 149 |
+
a = bit.expand(n, self.bits)
|
| 150 |
+
x = torch.stack([tb, bb, pb, a], dim=1) # (N,4,128) position 0 = LSB
|
| 151 |
+
h = self.inp(x)
|
| 152 |
+
for blk in self.blocks:
|
| 153 |
+
h = blk(h)
|
| 154 |
+
return self.out(h).squeeze(1) # (N,128) logits
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _build_cell(config: dict):
|
| 158 |
+
"""Instantiate the cell class named by config['arch'] (default = MLP HornerCell)."""
|
| 159 |
+
cfg = dict(config)
|
| 160 |
+
if cfg.get("arch") == "tcn":
|
| 161 |
+
cfg.pop("arch", None)
|
| 162 |
+
return TCNHornerCell(**cfg)
|
| 163 |
+
return HornerCell(**cfg)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
def _to_bits(t: torch.Tensor, bits: int = 16) -> torch.Tensor:
|
| 167 |
"""(N,) int64 -> (N, bits) float in {0,1}, LSB-first.
|
| 168 |
|
|
|
|
| 208 |
if not path.exists():
|
| 209 |
continue
|
| 210 |
ckpt = torch.load(path, map_location=self.device, weights_only=True)
|
| 211 |
+
cell = _build_cell(ckpt.get("config", {}))
|
| 212 |
cell.load_state_dict(ckpt["state_dict"])
|
| 213 |
cell.to(self.device)
|
| 214 |
cell.eval()
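The routing rule described in the module docstring (narrowest cell whose state holds the prime, otherwise the `[0]` fallback) can be sketched without any weights. This is an illustration under assumptions: `route_width` is a hypothetical helper, and "holds the prime" is taken to mean `p < 2**W`, matching the `< 2¹⁶` / `< 2³²` / `< 2⁶⁴` / `< 2¹²⁸` ranges in the README table.

```python
CELL_WIDTHS = (16, 32, 64, 128)  # same tuple as in model.py after this commit


def route_width(p: int, available=CELL_WIDTHS):
    """Return the narrowest available width W with p < 2**W, or None.

    `available` stands in for the set of widths whose weights{W}.pt file was
    actually found by load(); None means the caller should emit the [0]
    fallback without invoking any network.
    """
    for width in sorted(available):
        if p < (1 << width):
            return width
    return None


print(route_width(65521))          # largest prime below 2**16
print(route_width(65537))          # just above 2**16
print(route_width(2**128 + 51))    # wider than the widest cell
```

Passing a shorter `available` tuple models a deployment where a wider checkpoint is missing, in which case primes above the widest loaded cell fall through to `None`.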
|