etwk committed on
Commit ·
3d2c226
0
Parent(s):
Horner-RNN modular-multiplication model (tiers 1-5, up to 2^64)
Bit-sequential RNN that learns the Horner step (2t+bit*b) mod p; three cells
(16/32/64-bit) routed by prime size. Public benchmark: tiers 1-3 = 1.00,
tier 4 = 0.99, tier 5 = 0.64, overall 0.473. Weights weights{16,32,64}.pt are tracked as git-LFS objects.
- .gitattributes +1 -0
- README.md +124 -0
- manifest.json +7 -0
- model.py +218 -0
- train.py +221 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
*.pt filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,124 @@
---
license: apache-2.0
library_name: pytorch
tags:
- modular-arithmetic
- algorithmic-reasoning
- rnn
- number-theory
- neural-algorithm
---

# Horner-RNN: learned modular multiplication up to 2⁶⁴

A compliant **bit-sequential RNN** that computes `(a · b) mod p` for primes `p` below
**2⁶⁴** by *learning the Horner step of double-and-add* rather than memorising
multiplication tables. Entry for the
[Modular Arithmetic Challenge](https://github.com/SAIRcompetition/modular-arithmetic-challenge).

- **Saturates tiers 1–4** (all primes `< 2³²`): tiers 1–3 = 1.00, tier 4 = 0.99
- **Tier 5** (33–64-bit primes) = 0.64 on the public benchmark
- **overall_accuracy 0.473**, `highest_tier_above_90 = 4`
- Verifiably **generalises to primes never seen in training** (held-out-prime validation
  accuracy tracks training accuracy, so there is no memorisation gap)

## The idea

Write `a` in bits, MSB-first; then `a·b mod p` is the N-fold iterate of one small map:

```
t_0     = 0
t_{k+1} = (2·t_k + a_bit_k · b) mod p      # one learned step (Horner)
answer  = t_N                              # N = bits scanned (the cell width); leading zero bits leave t at 0
```

The model is an RNN whose **transition function (an MLP) is trained on exactly that
single-step map** over binary-encoded inputs. The hidden state is a quantized bit vector
(a hard binary bottleneck), so the recurrence composes cleanly: if the cell is exact per
step, the chain is exact end-to-end. At inference the scan feeds the bits of `a mod p` one
per step, conditioned on `(b mod p, p)`, and the final hidden-state bits are emitted
MSB-first as the base-2 answer (`output_base: 2`).
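
The recurrence and the scan can be checked with plain integers. The sketch below is an illustration only, not the repository's code: `exact_step` is a hypothetical stand-in for the trained cell (it does the arithmetic directly, which the real model must learn instead), and `scan` mirrors the MSB-first loop with an LSB-first bit-vector state.

```python
import random


def exact_step(t: int, bit: int, b: int, p: int) -> int:
    """Hypothetical stand-in for the learned cell: one Horner step mod p."""
    return (2 * t + bit * b) % p


def to_bits(x: int, width: int) -> list[int]:
    """Integer -> list of `width` bits, LSB-first."""
    return [(x >> i) & 1 for i in range(width)]


def from_bits(bits: list[int]) -> int:
    """LSB-first bit list -> integer."""
    return sum(bit << i for i, bit in enumerate(bits))


def scan(a: int, b: int, p: int, width: int) -> list[int]:
    """Feed the bits of a mod p MSB-first; return the answer bits MSB-first."""
    a, b = a % p, b % p
    state = [0] * width                       # t_0 = 0 as a bit vector
    for i in range(width - 1, -1, -1):        # MSB-first over the bits of a
        bit = (a >> i) & 1
        t_next = exact_step(from_bits(state), bit, b, p)
        state = to_bits(t_next, width)        # the state is always a bit vector
    return state[::-1]                        # emit MSB-first


digits = scan(123456789, 987654321, 4294967291, width=32)
answer = int("".join(map(str, digits)), 2)
print(answer == (123456789 * 987654321) % 4294967291)
```

Because the state passed between steps is an exact bit vector, a cell that is correct on every single step gives a correct final answer; the only thing the trained network has to supply is `exact_step`.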

The single-step function is **piecewise linear**: because `t, b < p`, the sum
`2t + bit·b` is below `3p`, so the reduction subtracts exactly one of `0`, `p`, or `2p`.
That is why it generalises across primes where the full bilinear map
`(a,b) → a·b mod p` does not.
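
The three-branch claim is easy to confirm exhaustively for a small prime. This is a minimal check, assuming reduced inputs `0 <= t, b < p`; the helper name `step_quotient` and the choice `p = 251` are illustrative and not taken from the repository.

```python
def step_quotient(t: int, bit: int, b: int, p: int) -> int:
    """Number of copies of p that one Horner step has to subtract."""
    return (2 * t + bit * b) // p


p = 251  # small prime, chosen only so the check can be exhaustive
quotients = {
    step_quotient(t, bit, b, p)
    for t in range(p)
    for b in range(p)
    for bit in (0, 1)
}
print(sorted(quotients))  # [0, 1, 2]
```

So for every reduced input the step is `2t + bit·b - q·p` with `q` in `{0, 1, 2}`: three linear pieces selected by a comparison against `p` and `2p`.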

## Files / cells

The model ships **three cells** and routes each problem to the narrowest one whose state
holds the prime:

| File | Cell | Primes | Tiers | Width / depth | Params | Public benchmark |
|---|---|---|---|---|---|---|
| `weights16.pt` | 16-bit | `< 2¹⁶` | 1–3 | 4096 / 4 | ~50M | tiers 1–3 = 1.00 |
| `weights32.pt` | 32-bit | `< 2³²` | 4 | 6144 / 4 | ~114M | tier 4 = 0.99 |
| `weights64.pt` | 64-bit | `< 2⁶⁴` | 5 | 4096 / 7, residual | ~236M | tier 5 = 0.64 |
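
The routing rule and the parameter counts in the table can be reproduced with a few lines of arithmetic. This is a sketch under the assumption that each cell follows the `HornerCell` layout in `model.py` (input of `3·bits + 1` features, a stack of `Linear` layers or pre-norm residual blocks, and a `bits`-wide head); `pick_width`, `linear_params`, and `cell_params` are illustrative helper names, not part of the repository.

```python
CELL_WIDTHS = (16, 32, 64)


def pick_width(p: int, widths=CELL_WIDTHS):
    """Narrowest cell width w with p < 2**w, or None if no cell fits."""
    for w in sorted(widths):
        if p < (1 << w):
            return w
    return None


def linear_params(n_in: int, n_out: int) -> int:
    """Weight matrix plus bias of one fully connected layer."""
    return n_in * n_out + n_out


def cell_params(bits: int, width: int, depth: int, residual: bool) -> int:
    """Parameter count of one cell, assuming the HornerCell layout."""
    n_input = 3 * bits + 1                     # state, b and p bits, plus one bit of a
    total = linear_params(n_input, width)      # first layer / input projection
    if residual:
        # each block: LayerNorm (scale + shift) and two width x width layers
        total += depth * (2 * width + 2 * linear_params(width, width))
    else:
        total += (depth - 1) * linear_params(width, width)
    return total + linear_params(width, bits)  # output head


for bits, width, depth, residual in [(16, 4096, 4, False), (32, 6144, 4, False), (64, 4096, 7, True)]:
    print(bits, f"{cell_params(bits, width, depth, residual):,}")
```

Under that assumption the counts come out at roughly 50.6M, 114.1M, and 236.1M, which matches the rounded figures in the table.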

The 64-bit cell needs **depth and residual connections** the narrower cells do not: a 64-bit
modular Horner step hides two long carry chains (the `2t + bit·b` addition and the
compare-and-subtract reduction), and exact n-bit carry propagation wants MLP depth ~log₂(n).
For `p ≥ 2⁶⁴` the model emits the honest `[0]` fallback without invoking the network.
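
The log₂(n) depth estimate has a classical analogue in hardware: a parallel-prefix (Kogge-Stone style) adder resolves every carry of an n-bit addition in ⌈log₂ n⌉ rounds of local combining. The sketch below is offered as intuition for the estimate only; it is not a description of what the trained MLP computes internally, and `prefix_add` is a hypothetical helper name.

```python
import random


def prefix_add(a: int, b: int, n: int):
    """Add two n-bit integers mod 2**n with a parallel-prefix carry scheme.

    Returns (sum, rounds). Python ints are used as n-bit vectors, so each
    round updates all bit positions at once.
    """
    mask = (1 << n) - 1
    propagate0 = (a ^ b) & mask       # bit i passes an incoming carry along
    generate = a & b & mask           # bit i creates a carry by itself
    propagate = propagate0
    d, rounds = 1, 0
    while d < n:
        # combine each position with the group that ends d positions below it
        generate = generate | (propagate & (generate << d)) & mask
        propagate = propagate & (propagate << d) & mask
        d *= 2
        rounds += 1
    carries_in = (generate << 1) & mask   # carry into bit i is carry out of bit i-1
    return propagate0 ^ carries_in, rounds


total, rounds = prefix_add(0xFFFF_FFFF_FFFF_FFFF, 1, 64)
print(total, rounds)  # 0 6
```

A 16-bit addition needs 4 such rounds and a 64-bit addition needs 6, which is consistent with the wider cell wanting a deeper stack.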

Also in the repo: `model.py` (the `HornerRNN` entry class + `HornerCell`), `manifest.json`
(challenge manifest), `train.py` (the 16-bit trainer).

## Usage

This is a challenge submission; the base class lives in the challenge package, so install it
first:

```bash
pip install "git+https://github.com/SAIRcompetition/modular-arithmetic-challenge"
```

Direct inference:

```python
import torch
from model import HornerRNN          # model.py from this repo

m = HornerRNN()
m.load(".")                          # auto-loads weights{16,32,64}.pt from this dir
# returns base-2 digits, MSB-first; the harness decodes them to the integer
digits = m.predict_digits_batch([(123456789, 987654321, 4294967291)])[0]
answer = int("".join(map(str, digits)), 2)
print(answer)                        # should equal (123456789 * 987654321) % 4294967291
```

Or score it with the official harness:

```bash
modchallenge evaluate . --total 1100
```

## Compliance (the rules permit *learned* algorithms, not hand-coded ones)

The **scan** (tokenise `a mod p` into bits, iterate, read out the final state) is
architecture: it computes nothing by itself. The **arithmetic** (doubling, conditional
add, compare-against-`p`, carries) all lives in the trained cell weights. Apart from the
input normalisation (`a mod p`, `b mod p`) and the routing by prime size, nothing in the
inference code adds, multiplies, or compares against `p`.

**Principle 2, measured.** Perturbing the cell weights with Gaussian noise scaled to each
tensor's std collapses accuracy toward the floor, and a fully re-initialised (untrained)
cell is *at* the floor. The capability therefore resides in the trained parameters:

| noise σ (× param std) | 0 | 0.05 | 0.1 | 0.25 | 0.5 | untrained |
|---|---|---|---|---|---|---|
| tier 3 (16-bit cell) | 1.00 | 1.00 | 0.98 | 0.74 | 0.06 | 0.00 |
| tier 4 (32-bit cell) | 0.99 | 0.99 | 0.86 | 0.04 | 0.02 | 0.00 |
| tier 5 (64-bit cell) | 0.64 | 0.57 | 0.41 | 0.01 | 0.01 | 0.00 |
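
The perturbation script itself is not part of this commit, so the following is only a guess at the procedure the table describes: add zero-mean Gaussian noise to every weight tensor, with a standard deviation of σ times that tensor's own standard deviation. Plain Python lists stand in for weight tensors here, and `perturb` is a hypothetical helper; with PyTorch the same update would be applied to each parameter tensor of the cell.

```python
import random
import statistics


def perturb(params: dict[str, list[float]], sigma: float, rng: random.Random) -> dict[str, list[float]]:
    """Return a copy of `params` with noise of std sigma * std(tensor) added."""
    noisy = {}
    for name, values in params.items():
        scale = sigma * statistics.pstdev(values)   # noise relative to this tensor's spread
        noisy[name] = [v + rng.gauss(0.0, scale) for v in values]
    return noisy


rng = random.Random(0)
toy_params = {"trunk.weight": [rng.gauss(0.0, 2.0) for _ in range(10_000)]}
noisy_params = perturb(toy_params, sigma=0.25, rng=random.Random(1))
delta = [n - v for n, v in zip(noisy_params["trunk.weight"], toy_params["trunk.weight"])]
print(round(statistics.pstdev(delta) / statistics.pstdev(toy_params["trunk.weight"]), 2))
```

Scaling by each tensor's own spread keeps the relative disturbance comparable across layers whose weights live on different scales, which is what makes a single σ column meaningful for all three cells.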

Generalisation against memorisation: 10% of primes at each bit-width were held out of
training entirely; chain accuracy on them matches the training primes.

## Training

Each cell is trained on single-step examples `(t, bit, b, p) → (2t + bit·b) mod p` over its
tier's prime range, with half of each batch mined near the comparison boundary
(`2t + bit·b` within ±2 of a multiple of `p`), where errors concentrate. The loss is BCE per
state bit; optimisation uses AdamW with cosine decay and EMA weights, checkpointed by
full-chain accuracy on held-out primes. Training code and the full write-up live in the
solutions repo (link in the model card metadata / challenge leaderboard).
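
Boundary mining can be restated for a single example in plain Python. The sketch below follows the batched `sample_batch` in `train.py` (pick a multiple `q·p`, an offset in `[-2, 2]`, and solve for `t`); the scalar function names are illustrative. Note that the clamp to `t >= 0` means some examples (roughly those with `q = 0` and `bit = 1`) do not end up near a boundary.

```python
import random


def mine_boundary_example(p: int, rng: random.Random):
    """Sample (t, bit, b, z) with 2t + bit*b close to a multiple of p."""
    b = rng.randrange(p)
    bit = rng.randrange(2)
    q = rng.randrange(3)             # aim at 0, p or 2p
    delta = rng.randrange(-2, 3)     # offset in [-2, 2] around that multiple
    t = max(0, (q * p + delta - bit * b) >> 1) % p
    z = (2 * t + bit * b) % p        # training target for the cell
    return t, bit, b, z


def distance_to_multiple(s: int, p: int) -> int:
    """Distance from s to the nearest multiple of p."""
    r = s % p
    return min(r, p - r)


rng = random.Random(0)
p = 65521
samples = [mine_boundary_example(p, rng) for _ in range(2000)]
near = sum(distance_to_multiple(2 * t + bit * b, p) <= 3 for t, bit, b, _ in samples)
print(near / len(samples))
```

Most mined examples sit within a few units of the point where the reduction switches between subtracting `0`, `p`, and `2p`, which is exactly where a nearly correct cell is most likely to flip a bit.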

## License

Apache-2.0, matching the challenge.
manifest.json
ADDED
@@ -0,0 +1,7 @@
{
  "entry_class": "model.HornerRNN",
  "output_base": 2,
  "framework": "pytorch",
  "model_description": "Bit-sequential RNN (~400M params across three cells) for primes up to 2^64. Reads the bits of a mod p MSB-first, one per step, conditioned on (b mod p, p) in binary; the hidden state is a quantized bit vector (hard binary bottleneck) and the transition function is an MLP that must learn the Horner step (t, bit, b, p) -> (2t + bit*b) mod p to make the recurrence end on the right answer. Three cells are shipped and routed by prime size: a 16-bit cell (width 4096 depth 4, ~50M params) for p < 2^16 covering tiers 1-3, a 32-bit cell (width 6144 depth 4, ~114M params) for p < 2^32 covering tier 4, and a 64-bit cell (width 4096 depth 7 with pre-norm residual blocks, ~236M params) for p < 2^64 covering tier 5; the wider carry chains of a 64-bit modular step need the extra depth. Final state bits are emitted MSB-first as the base-2 answer. For p >= 2^64 emits the honest [0] fallback without invoking the network.",
  "training_description": "Each transition cell trained from random init on (t, bit, b, p) -> (2t + bit*b) mod p single-step examples over its prime range (16-bit: all primes < 2^16; 32-bit and 64-bit: random primes sampled uniform-by-value in [2^16, 2^32) and [2^33, 2^64) to match the test generator's randrange+nextprime distribution), with half of each batch (more for the 64-bit fine-tune) mined near the comparison boundary (2t + bit*b within +/-2 of a multiple of p) where errors concentrate. BCE per state bit, AdamW + cosine decay + gradient clipping + LR warmup, EMA weights checkpointed by full-chain validation accuracy on a held-out 10% of primes never seen in training; val accuracy tracks train accuracy, i.e. the cells generalise across primes rather than memorising them. Training scripts: train.py (16-bit), exploration/train_horner32.py (32-bit), exploration/train_horner64.py (64-bit, --residual)."
}
model.py
ADDED
@@ -0,0 +1,218 @@
"""Compliant bit-sequential RNN for modular multiplication with primes below 2^W.

Architecture: a recurrent network that reads the bits of ``a mod p`` MSB-first,
one per step, conditioned on ``(b mod p, p)`` in binary. The hidden state is a
quantized bit vector (a discrete bottleneck: a hard VQ layer with a fixed
binary codebook), and the transition function (an MLP) is entirely trained
parameters. After the last bit, the hidden state bits ARE the answer, emitted
MSB-first in base 2.

Why this is interesting: for the recurrence to end on the right answer, the
trained cell must *learn* the map ``(t, bit, b, p) -> (2t + bit*b) mod p``,
i.e. the model is trained to internally implement one step of Horner evaluation
in the prime field, and it verifiably generalises to a held-out 10% of primes
never seen in training (val == train accuracy). The rules explicitly permit
recurrent/looped architectures and models that *learn* an algorithm-like circuit
("A model trained to internally implement an algorithm is permitted; the same
algorithm hand-coded into the forward pass is not", rules/evaluation.md). The
line is respected here:

- hand-coded (architecture, weight-independent): tokenising ``a mod p`` into
  bits, scanning them sequentially, reading the final state bits. This is
  tokenisation + recurrence + readout; it computes nothing by itself: with
  random weights the output is noise (Principle 2), and the emitted digits are
  exactly the model's final hidden state (Principle 1).
- learned (all of the actual arithmetic): the transition function. Nothing in
  the code adds, multiplies, compares against p, or carries; the cell's MLP
  weights had to learn all of that from data.

The two operand reductions ``a mod p`` / ``b mod p`` in ``predict_digits_batch``
are the same legal input normalisation every other reference model uses.

The model ships one cell per bit-width (16 -> tiers 1-3, 32 -> tier 4, and 64 ->
tier 5 when present) and routes each problem to the narrowest cell whose state
holds the prime. For primes wider than the widest trained cell it emits the
honest ``[0]`` fallback without invoking the network.
"""

from __future__ import annotations

from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

from modchallenge.interface.base_model import ModularMultiplicationModel

# Bit-widths we may ship a cell for, narrowest first. load() picks up whichever
# weights{W}.pt files are actually present, so adding a wider cell is drop-in.
CELL_WIDTHS = (16, 32, 64)

# Default state width for the 16-bit trainer (train.py imports this).
BITS = 16


class _ResBlock(nn.Module):
    """Pre-norm residual MLP block: x + Linear(GELU(Linear(LN(x))))."""

    def __init__(self, width: int):
        super().__init__()
        self.ln = nn.LayerNorm(width)
        self.fc1 = nn.Linear(width, width)
        self.fc2 = nn.Linear(width, width)

    def forward(self, x):
        return x + self.fc2(torch.nn.functional.gelu(self.fc1(self.ln(x))))


class HornerCell(nn.Module):
    """Learned RNN transition: (state_bits, bit, b_bits, p_bits) -> next-state logits.

    ``residual=False`` (default) is the plain GELU stack used by the 16/32-bit
    cells; its module/parameter layout is unchanged so existing checkpoints
    load. ``residual=True`` swaps the trunk for pre-norm residual blocks after
    an input projection, which stay trainable at the larger depth the 64-bit
    carry chains need (exact n-bit carry propagation wants depth ~log2(n)). The
    flag lives in ``config`` so older checkpoints (no ``residual`` key) load as
    the plain stack.
    """

    def __init__(self, width: int = 4096, depth: int = 4, bits: int = 16, residual: bool = False):
        super().__init__()
        self.residual = residual
        if residual:
            self.proj = nn.Linear(3 * bits + 1, width)
            self.trunk = nn.Sequential(*[_ResBlock(width) for _ in range(depth)])
        else:
            layers: list[nn.Module] = [nn.Linear(3 * bits + 1, width), nn.GELU()]
            for _ in range(depth - 1):
                layers += [nn.Linear(width, width), nn.GELU()]
            self.trunk = nn.Sequential(*layers)
        self.head = nn.Linear(width, bits)
        self.config = dict(width=width, depth=depth, bits=bits, residual=residual)

    def forward(self, tb, bit, bb, pb):
        x = torch.cat([tb, bit, bb, pb], dim=-1)
        if self.residual:
            x = self.proj(x)
        return self.head(self.trunk(x))


def _to_bits(t: torch.Tensor, bits: int = 16) -> torch.Tensor:
    """(N,) int64 -> (N, bits) float in {0,1}, LSB-first.

    Used by the trainer for <= 32-bit values. Inference uses the numpy packer
    below (bit-identical for <= 32 bits, and also valid at 64 bits where an
    int64 tensor would overflow). Kept here so the trainer can import it.
    """
    shifts = torch.arange(bits, device=t.device)
    return ((t.unsqueeze(1) >> shifts) & 1).float()


def _pack_bits(vals: list[int], nbits: int, device) -> torch.Tensor:
    """list[int] (each < 2^nbits) -> (N, nbits) float bit tensor, LSB-first.

    Works for any nbits divisible by 8, including 64 where the torch shift
    trick overflows int64. Verified bit-identical to ``_to_bits`` for 16/32.
    """
    nbytes = nbits // 8
    buf = b"".join(int(v).to_bytes(nbytes, "little") for v in vals)
    arr = np.frombuffer(buf, dtype=np.uint8).reshape(len(vals), nbytes)
    bits = np.unpackbits(arr, axis=1, bitorder="little").astype(np.float32)
    return torch.from_numpy(bits).to(device)


class HornerRNN(ModularMultiplicationModel):
    """Routes each problem to the narrowest trained cell that fits its prime."""

    def __init__(self):
        # width -> HornerCell, populated from whichever weight files exist.
        self.cells: dict[int, HornerCell] = {}
        self.device: torch.device | None = None

    def load(self, model_dir: str) -> None:
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")

        for width in CELL_WIDTHS:
            path = Path(model_dir) / f"weights{width}.pt"
            if not path.exists():
                continue
            ckpt = torch.load(path, map_location=self.device, weights_only=True)
            cell = HornerCell(**ckpt.get("config", {}))
            cell.load_state_dict(ckpt["state_dict"])
            cell.to(self.device)
            cell.eval()
            self.cells[width] = cell

        if not self.cells:
            raise FileNotFoundError(
                f"no weights{{{','.join(map(str, CELL_WIDTHS))}}}.pt found in {model_dir}"
            )

    def preprocess_a(self, a):
        return a

    def preprocess_b(self, b):
        return b

    def preprocess_p(self, p):
        return p

    @torch.no_grad()
    def predict_digits(self, a_enc, b_enc, p_enc):
        return self.predict_digits_batch([(a_enc, b_enc, p_enc)])[0]

    @torch.no_grad()
    def _run_cell(self, width: int, rows: list[tuple[int, int, int]]) -> list[list[int]]:
        """Scan the width-bit cell over a batch of (a_red, b_red, p) rows."""
        cell = self.cells[width]
        a_bits = _pack_bits([r[0] for r in rows], width, self.device)
        bb = _pack_bits([r[1] for r in rows], width, self.device)
        pb = _pack_bits([r[2] for r in rows], width, self.device)
        state = torch.zeros(len(rows), width, device=self.device)
        # RNN scan over the bit tokens of (a mod p), MSB-first. The scan moves
        # data; the learned cell does all the computing.
        for s in range(width - 1, -1, -1):
            bit = a_bits[:, s : s + 1]
            logits = cell(state, bit, bb, pb)
            state = (logits > 0).float()  # quantized state bottleneck
        return state.long().tolist()  # LSB-first per row

    @torch.no_grad()
    def predict_digits_batch(self, inputs):
        assert self.cells, "load() must run first"
        out: list[list[int] | None] = [None] * len(inputs)
        widths = sorted(self.cells)
        widest = widths[-1]

        # Bucket each problem by the narrowest cell whose state holds the prime.
        buckets: dict[int, tuple[list[int], list[tuple[int, int, int]]]] = {
            w: ([], []) for w in widths
        }
        for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
            p = int(p_enc)
            if p >= (1 << widest):
                out[i] = [0]  # outside every trained regime: honest fallback
                continue
            w = next(w for w in widths if p < (1 << w))
            idx, rows = buckets[w]
            idx.append(i)
            rows.append((int(a_enc) % p, int(b_enc) % p, p))

        for w in widths:
            idx, rows = buckets[w]
            if rows:
                bits = self._run_cell(w, rows)
                for j, i in enumerate(idx):
                    out[i] = bits[j][::-1]  # emit MSB-first, base 2

        return [o if o is not None else [0] for o in out]

    def max_batch_size(self) -> int:
        return 1024
train.py
ADDED
@@ -0,0 +1,221 @@
"""Train the horner_rnn transition cell (bit-level Horner step) + chain fine-tuning.

Stage 1: train cell f(t, bit, b, p) = (2t + bit*b) mod p  (quotients {0,1,2},
easier than base-4's {0..6}) with grad clipping, EMA, hard-boundary mining.

Stage 2 (optional, default off): fine-tune end-to-end through the 16-step
chain with a straight-through estimator on the quantized state, loss on every
step's ground-truth intermediate. In practice this was destructive at lr2=5e-5
(chain val collapsed); the shipped weights come from stage 1 alone, which
reaches chain val ~0.998 on held-out primes. Kept for further experimentation
at lower learning rates.
"""

from __future__ import annotations

import argparse
import sys
import time
from pathlib import Path

import torch
import torch.nn as nn

# Import the shared architecture from the sibling model.py.
HERE = Path(__file__).resolve().parent
sys.path.insert(0, str(HERE))
from model import HornerCell, BITS, _to_bits as to_bits  # noqa: E402


def sieve_primes(limit: int) -> list[int]:
    """All primes below ``limit`` (sieve of Eratosthenes)."""
    is_p = bytearray([1]) * limit
    is_p[0] = is_p[1] = 0
    for i in range(2, int(limit ** 0.5) + 1):
        if is_p[i]:
            is_p[i * i :: i] = bytearray(len(is_p[i * i :: i]))
    return [i for i in range(2, limit) if is_p[i]]


def sample_batch(primes_t, n, device, hard_frac=0.5):
    """Sample n single-step examples; the first hard_frac are boundary-mined."""
    p = primes_t[torch.randint(len(primes_t), (n,), device=device)]
    b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    bit = torch.randint(0, 2, (n,), device=device)
    n_hard = int(n * hard_frac)
    t = torch.empty(n, dtype=torch.long, device=device)
    # Uniform t in [0, p); clamped like b so float rounding cannot yield t == p.
    t[n_hard:] = (torch.rand(n - n_hard, device=device) * p[n_hard:]).long().clamp(max=p[n_hard:] - 1)
    if n_hard:
        # Hard examples: choose t so that 2t + bit*b lands within +/-2 of q*p.
        ph, bh, bith = p[:n_hard], b[:n_hard], bit[:n_hard]
        q = torch.randint(0, 3, (n_hard,), device=device)
        delta = torch.randint(-2, 3, (n_hard,), device=device)
        th = (q * ph + delta - bith * bh) >> 1
        t[:n_hard] = th.clamp(min=0) % ph
    z = (2 * t + bit * b) % p
    return t, bit, b, p, z


@torch.no_grad()
def exact_rate(model, primes_t, device, n=200_000, bs=65536) -> float:
    """Fraction of single steps where every predicted state bit is correct."""
    ok = 0
    for i in range(0, n, bs):
        m = min(bs, n - i)
        t, bit, b, p, z = sample_batch(primes_t, m, device, hard_frac=0.0)
        logits = model(to_bits(t), bit.float().unsqueeze(1), to_bits(b), to_bits(p))
        ok += ((logits > 0).long() == to_bits(z).long()).all(dim=1).sum().item()
    return ok / n


@torch.no_grad()
def chain_exact_rate(model, primes_t, device, n=20_000) -> float:
    """Fraction of full BITS-step chains that end on (a * b) % p."""
    p = primes_t[torch.randint(len(primes_t), (n,), device=device)]
    a = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    truth = (a * b) % p
    bb, pb = to_bits(b), to_bits(p)
    tb = torch.zeros(n, BITS, device=device)
    for i in range(BITS - 1, -1, -1):
        bit = ((a >> i) & 1).float().unsqueeze(1)
        tb = (model(tb, bit, bb, pb) > 0).float()
    pred = (tb.long() * (1 << torch.arange(BITS, device=device))).sum(dim=1)
    return (pred == truth).float().mean().item()


def chain_finetune_batch(model, primes_t, n, device, loss_fn):
    """One end-to-end pass: STE state, per-step BCE against true intermediates."""
    p = primes_t[torch.randint(len(primes_t), (n,), device=device)]
    a = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
    bb, pb = to_bits(b), to_bits(p)
    tb = torch.zeros(n, BITS, device=device)
    t_true = torch.zeros_like(a)
    loss = torch.zeros((), device=device)
    for i in range(BITS - 1, -1, -1):
        bit_i = (a >> i) & 1
        t_true = (2 * t_true + bit_i * b) % p
        logits = model(tb, bit_i.float().unsqueeze(1), bb, pb)
        loss = loss + loss_fn(logits, to_bits(t_true))
        hard = (logits > 0).float()
        soft = torch.sigmoid(logits)
        tb = hard + (soft - soft.detach())  # straight-through
    return loss / BITS


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--stage1-minutes", type=float, default=50.0)
    ap.add_argument("--stage2-minutes", type=float, default=0.0)
    ap.add_argument("--batch", type=int, default=32768)
    ap.add_argument("--chain-batch", type=int, default=4096)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--lr2", type=float, default=5e-5)
    ap.add_argument("--width", type=int, default=4096)
    ap.add_argument("--depth", type=int, default=4)
    ap.add_argument("--init", type=str, default="")
    ap.add_argument("--out", type=str, default=str(HERE / "weights16.pt"))
    args = ap.parse_args()

    device = torch.device("cuda")
    torch.manual_seed(0)

    small = sieve_primes(256)
    primes = [p for p in sieve_primes(1 << 16) if p >= 256]
    g = torch.Generator().manual_seed(1)
    perm = torch.randperm(len(primes), generator=g).tolist()
    val_primes = torch.tensor([primes[i] for i in perm[: len(primes) // 10]], device=device)
    train_primes = torch.tensor(
        small + [primes[i] for i in perm[len(primes) // 10 :]], device=device
    )
    print(f"train primes {len(train_primes)}, val primes {len(val_primes)}")

    model = HornerCell(args.width, args.depth).to(device)
    if args.init:
        ckpt = torch.load(args.init, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["state_dict"])
        print(f"initialised from {args.init}")
    ema = HornerCell(args.width, args.depth).to(device)
    ema.load_state_dict(model.state_dict())
    for q in ema.parameters():
        q.requires_grad_(False)
    print(f"params: {sum(t.numel() for t in model.parameters()):,}")
    loss_fn = nn.BCEWithLogitsLoss()
    EMA_DECAY = 0.999

    def update_ema():
        with torch.no_grad():
            for q, w in zip(ema.parameters(), model.parameters()):
                q.lerp_(w, 1 - EMA_DECAY)

    best_chain = -1.0

    def save_if_best(tag):
        nonlocal best_chain
        ch = chain_exact_rate(ema, val_primes, device)
        if ch > best_chain:
            best_chain = ch
            torch.save({"state_dict": ema.state_dict(), "config": ema.config}, args.out)
        return ch

    # ----- Stage 1: cell training -----
    if args.stage1_minutes > 0:
        opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
        total_steps = int(args.stage1_minutes * 60 * 16)
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps, eta_min=args.lr * 0.02)
        deadline = time.monotonic() + args.stage1_minutes * 60
        start = time.monotonic()
        step = 0
        while time.monotonic() < deadline:
            t, bit, b, p, z = sample_batch(train_primes, args.batch, device)
            logits = model(to_bits(t), bit.float().unsqueeze(1), to_bits(b), to_bits(p))
            loss = loss_fn(logits, to_bits(z))
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            if step < total_steps:
                sched.step()
            update_ema()
            step += 1
            if step % 1000 == 0:
                va = exact_rate(ema, val_primes, device, n=100_000)
                ch = save_if_best("s1")
                print(
                    f"S1 step {step:6d} | loss {loss.item():.5f} | ema cell val {va:.5f} "
                    f"| ema CHAIN val {ch:.4f} | {time.monotonic()-start:.0f}s",
                    flush=True,
                )

    # ----- Stage 2: end-to-end chain fine-tuning (STE) -----
    if args.stage2_minutes > 0:
        opt = torch.optim.AdamW(model.parameters(), lr=args.lr2, weight_decay=1e-5)
        total_steps = int(args.stage2_minutes * 60 * 3)
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps, eta_min=args.lr2 * 0.1)
        deadline = time.monotonic() + args.stage2_minutes * 60
        start = time.monotonic()
        step = 0
        while time.monotonic() < deadline:
            loss = chain_finetune_batch(model, train_primes, args.chain_batch, device, loss_fn)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            if step < total_steps:
                sched.step()
            update_ema()
            step += 1
            if step % 200 == 0:
                va = exact_rate(ema, val_primes, device, n=100_000)
                ch = save_if_best("s2")
                print(
                    f"S2 step {step:6d} | loss {loss.item():.5f} | ema cell val {va:.5f} "
                    f"| ema CHAIN val {ch:.4f} | {time.monotonic()-start:.0f}s",
                    flush=True,
                )

    va = exact_rate(ema, val_primes, device, n=500_000)
    ch = chain_exact_rate(ema, val_primes, device, n=50_000)
    print(f"FINAL ema cell val {va:.6f} | chain val {ch:.4f} | best chain {best_chain:.4f}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())