etwk committed on
Commit 3d2c226 · 0 Parent(s)

Horner-RNN modular-multiplication model (tiers 1-5, up to 2^64)

Bit-sequential RNN that learns the Horner step (2t+bit*b) mod p; three cells
(16/32/64-bit) routed by prime size. Public benchmark: tiers 1-3 = 1.00,
tier 4 = 0.99, tier 5 = 0.64, overall 0.473. Weights weights{16,32,64}.pt are tracked as git-LFS objects.

Files changed (5)
  1. .gitattributes +1 -0
  2. README.md +124 -0
  3. manifest.json +7 -0
  4. model.py +218 -0
  5. train.py +221 -0
.gitattributes ADDED
@@ -0,0 +1 @@
+ *.pt filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,124 @@
+ ---
+ license: apache-2.0
+ library_name: pytorch
+ tags:
+ - modular-arithmetic
+ - algorithmic-reasoning
+ - rnn
+ - number-theory
+ - neural-algorithm
+ ---
+
+ # Horner-RNN: learned modular multiplication up to 2⁶⁴
+
+ A compliant **bit-sequential RNN** that computes `(a · b) mod p` for primes `p` below
+ **2⁶⁴**, by *learning the Horner step of double-and-add* rather than memorising
+ multiplication tables. Entry for the
+ [Modular Arithmetic Challenge](https://github.com/SAIRcompetition/modular-arithmetic-challenge).
+
+ - **Saturates tiers 1–4** (all primes `< 2³²`): tiers 1–3 = 1.00, tier 4 = 0.99
+ - **Tier 5** (33–64-bit primes) = 0.64 on the public benchmark
+ - **overall_accuracy 0.473**, `highest_tier_above_90 = 4`
+ - Verifiably **generalises to primes never seen in training** (held-out-prime validation
+   accuracy tracks training accuracy, so there is no memorisation gap)
+
+ ## The idea
+
+ Write `a` in bits, MSB-first; then `a·b mod p` is the N-fold iterate of one small map:
+
+ ```
+ t_0 = 0
+ t_{k+1} = (2·t_k + a_bit_k · b) mod p    # one learned step (Horner)
+ answer = t_N                              (N = cell width: 16, 32, or 64)
+ ```
+
+ The model is an RNN whose **transition function, an MLP, is trained on exactly that
+ single-step map** over binary-encoded inputs. The hidden state is a quantized bit vector
+ (a hard binary bottleneck), so the recurrence composes cleanly: if the cell is exact per
+ step, the chain is exact end-to-end. At inference the scan feeds the bits of `a mod p` one
+ per step, conditioned on `(b mod p, p)`, and the final hidden-state bits are emitted
+ MSB-first as the base-2 answer (`output_base: 2`). The scan always runs over the full cell
+ width; in the exact recurrence, leading zero bits of `a mod p` leave `t` at 0, so this
+ gives the same result as scanning only the significant bits.
+
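The recurrence above can be checked against ordinary integer arithmetic. The sketch below is a plain-Python reference for the map the trained cell is asked to reproduce, not the model itself; `horner_step` and `horner_mulmod` are illustrative names that do not appear in the repo.

```python
def horner_step(t: int, bit: int, b: int, p: int) -> int:
    """One reference Horner step: the map the cell is trained to imitate."""
    return (2 * t + bit * b) % p


def horner_mulmod(a: int, b: int, p: int, width: int) -> int:
    """Scan the bits of (a mod p) MSB-first over a fixed `width`-bit window."""
    a, b = a % p, b % p
    t = 0
    for s in range(width - 1, -1, -1):
        t = horner_step(t, (a >> s) & 1, b, p)
    return t


# Leading zero bits leave t at 0, so a wider window gives the same answer.
p = 4294967291  # the modulus used in the README's usage example
expected = (123456789 * 987654321) % p
assert horner_mulmod(123456789, 987654321, p, width=32) == expected
assert horner_mulmod(123456789, 987654321, p, width=64) == expected
```

Scanning with `width=64` instead of `width=32` changes nothing here, which is why routing to a wider cell than strictly necessary would still be arithmetically valid.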
+ The single-step function is **piecewise linear** (`2t + bit·b`, then subtract `0`, `p`, or
+ `2p`), which is why it generalises across primes where the full bilinear map
+ `(a,b) → a·b mod p` does not. Because `t < p` and `b < p`, the sum `2t + bit·b` is always
+ below `3p`, so those three cases are exhaustive.
+
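The three-case claim can be confirmed by brute force on small moduli. This is only a sanity check under the stated assumption `t, b < p`; the helper name `reduction_quotients` is made up for illustration.

```python
def reduction_quotients(p: int) -> set[int]:
    """Collect every quotient (2t + bit*b) // p over all t, b < p and bit in {0, 1}."""
    seen = set()
    for t in range(p):
        for b in range(p):
            for bit in (0, 1):
                seen.add((2 * t + bit * b) // p)
    return seen


# The reduction never needs to subtract more than 2p.
for p in (5, 7, 11, 13, 251):
    assert reduction_quotients(p) <= {0, 1, 2}
```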
+ ## Files / cells
+
+ The model ships **three cells** and routes each problem to the narrowest one whose state
+ holds the prime:
+
+ | File | Cell | Primes | Tiers | Width / depth | Params | Public benchmark |
+ |---|---|---|---|---|---|---|
+ | `weights16.pt` | 16-bit | `< 2¹⁶` | 1–3 | 4096 / 4 | ~50M | tiers 1–3 = 1.00 |
+ | `weights32.pt` | 32-bit | `< 2³²` | 4 | 6144 / 4 | ~114M | tier 4 = 0.99 |
+ | `weights64.pt` | 64-bit | `< 2⁶⁴` | 5 | 4096 / 7, residual | ~236M | tier 5 = 0.64 |
+
+ The 64-bit cell needs **depth and residual connections** that the narrower cells do not: a
+ 64-bit modular Horner step hides two long carry chains (the `2t + bit·b` addition and the
+ compare-and-subtract reduction), and exact n-bit carry propagation wants MLP depth ~log₂(n).
+ For `p ≥ 2⁶⁴` the model emits the honest `[0]` fallback without invoking the network.
+
+ Also in the repo: `model.py` (the `HornerRNN` entry class + `HornerCell`), `manifest.json`
+ (challenge manifest), `train.py` (the 16-bit trainer).
+
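The routing rule described in this section is small enough to state on its own. The following is a minimal sketch of that rule as the text describes it (narrowest width `w` with `p < 2^w`, otherwise the fallback); `route` is a hypothetical helper, whereas in `model.py` the same decision is made inline in `predict_digits_batch`.

```python
CELL_WIDTHS = (16, 32, 64)  # the widths this repo ships cells for


def route(p: int, widths=CELL_WIDTHS):
    """Return the narrowest width w with p < 2**w, or None for the [0] fallback."""
    for w in sorted(widths):
        if p < (1 << w):
            return w
    return None


assert route(65521) == 16           # below 2**16: 16-bit cell
assert route(65537) == 32           # just above 2**16: 32-bit cell
assert route((1 << 32) + 15) == 64  # above 2**32: 64-bit cell
assert route(1 << 64) is None       # outside every trained regime
```

If only some weight files are present, passing the available widths (for example `route(p, widths=(16, 32))`) reproduces the behaviour where primes wider than the widest loaded cell fall back to `[0]`.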
+ ## Usage
+
+ This is a challenge submission; the base class lives in the challenge package, so install it
+ first:
+
+ ```bash
+ pip install "git+https://github.com/SAIRcompetition/modular-arithmetic-challenge"
+ ```
+
+ Direct inference:
+
+ ```python
+ import torch
+ from model import HornerRNN  # model.py from this repo
+
+ m = HornerRNN()
+ m.load(".")  # auto-loads weights{16,32,64}.pt from this dir
+ # returns base-2 digits, MSB-first; the harness decodes them to the integer
+ digits = m.predict_digits_batch([(123456789, 987654321, 4294967291)])[0]
+ answer = int("".join(map(str, digits)), 2)
+ print(answer)  # expected: (123456789 * 987654321) % 4294967291
+ ```
+
+ Or score it with the official harness:
+
+ ```bash
+ modchallenge evaluate . --total 1100
+ ```
+
+ ## Compliance (the rules permit *learned* algorithms, not hand-coded ones)
+
+ The **scan** (tokenise `a mod p` into bits, iterate, read out the final state) is
+ architecture: it computes nothing by itself. The **arithmetic** (doubling, conditional
+ add, compare-against-`p`, carries) all lives in the trained cell weights; nothing in the
+ scan adds, multiplies, or compares against `p`. The only uses of `p` outside the cell are
+ the input normalisation (`a mod p`, `b mod p`) and the routing by prime size.
+
+ **Principle 2, measured:** perturbing the cell weights with Gaussian noise scaled to each
+ tensor's std collapses accuracy toward the floor, and a fully re-initialised (untrained)
+ cell is *at* the floor. The capability therefore resides in the trained parameters:
+
+ | noise σ (×param std) | 0 | 0.05 | 0.1 | 0.25 | 0.5 | untrained |
+ |---|---|---|---|---|---|---|
+ | tier 3 (16-bit cell) | 1.00 | 1.00 | 0.98 | 0.74 | 0.06 | 0.00 |
+ | tier 4 (32-bit cell) | 0.99 | 0.99 | 0.86 | 0.04 | 0.02 | 0.00 |
+ | tier 5 (64-bit cell) | 0.64 | 0.57 | 0.41 | 0.01 | 0.01 | 0.00 |
+
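The perturbation behind the table can be sketched without any framework. The snippet below is an assumption about the procedure as the text describes it (Gaussian noise whose standard deviation is σ times the tensor's own standard deviation); the script that produced the table is not part of this commit, and `perturb` is a hypothetical helper that works on a flat list of floats rather than a real weight tensor.

```python
import random
import statistics


def perturb(weights: list[float], sigma: float, rng: random.Random) -> list[float]:
    """Add zero-mean Gaussian noise with std = sigma * std(weights) to every entry."""
    scale = sigma * statistics.pstdev(weights)
    return [w + rng.gauss(0.0, scale) for w in weights]


rng = random.Random(0)
tensor = [rng.uniform(-1.0, 1.0) for _ in range(10_000)]  # stand-in for one weight tensor
noisy = perturb(tensor, sigma=0.25, rng=rng)

# The added noise has roughly 0.25x the spread of the original tensor.
noise = [n - w for n, w in zip(noisy, tensor)]
ratio = statistics.pstdev(noise) / statistics.pstdev(tensor)
assert 0.2 < ratio < 0.3
# sigma = 0 is the unperturbed column of the table.
assert perturb(tensor, sigma=0.0, rng=rng) == tensor
```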
+ Generalisation rather than memorisation: 10% of primes at each bit-width were held out of
+ training entirely; full-chain accuracy on them matches accuracy on the training primes.
+
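A held-out-prime split of this kind can be sketched in a few lines. This version is a simplification written for illustration: `train.py` in this commit performs the split with `torch.randperm` and always keeps the primes below 256 in the training set, both of which are omitted here, and `split_primes` is a hypothetical name.

```python
import random


def sieve_primes(limit: int) -> list[int]:
    """All primes below `limit` (Sieve of Eratosthenes)."""
    is_p = bytearray([1]) * limit
    is_p[0] = is_p[1] = 0
    for i in range(2, int(limit ** 0.5) + 1):
        if is_p[i]:
            is_p[i * i :: i] = bytearray(len(is_p[i * i :: i]))
    return [i for i in range(2, limit) if is_p[i]]


def split_primes(primes: list[int], val_frac: float = 0.1, seed: int = 1):
    """Shuffle with a fixed seed and hold out the first `val_frac` of the primes."""
    order = primes[:]
    random.Random(seed).shuffle(order)
    n_val = int(len(order) * val_frac)
    return order[n_val:], order[:n_val]  # (train, val)


train, val = split_primes(sieve_primes(1 << 16))
assert not set(train) & set(val)      # no prime appears on both sides
assert len(train) + len(val) == 6542  # number of primes below 2**16
```

Splitting by prime rather than by example is what makes the validation number a test of generalisation across moduli: every `(t, bit, b)` triple seen at validation time is paired with a `p` the cell never trained on.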
+ ## Training
+
+ Each cell is trained on single-step examples `(t, bit, b, p) → (2t + bit·b) mod p` over its
+ tier's prime range. Half of each batch is mined near the comparison boundary (where
+ `2t + bit·b` is close to a multiple of `p`), which is where errors concentrate. The loss is
+ BCE per state bit; optimisation is AdamW with cosine decay and EMA weights, checkpointed by
+ full-chain accuracy on held-out primes. The 16-bit trainer is `train.py` in this repo; the
+ remaining training code and the full write-up live in the solutions repo (link in the model
+ card metadata / challenge leaderboard).
+
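The boundary mining can be illustrated with scalar integers. The sketch below mirrors the integer arithmetic of the hard branch of `sample_batch` in `train.py`, but one example at a time and with Python's `random` instead of torch; `mine_boundary_t` and `distance_to_multiple` are illustrative names. One deliberate difference: where `train.py` clamps a negative candidate to 0, this sketch returns `None` so the property below can be asserted cleanly.

```python
import random


def mine_boundary_t(b: int, bit: int, p: int, rng: random.Random):
    """Pick t so that 2t + bit*b lands within a few units of a multiple of p."""
    q = rng.randrange(0, 3)       # which multiple to aim at: 0, p or 2p
    delta = rng.randrange(-2, 3)  # small offset around that multiple
    th = (q * p + delta - bit * b) >> 1
    return None if th < 0 else th % p


def distance_to_multiple(x: int, p: int) -> int:
    """Distance from x to the nearest multiple of p."""
    r = x % p
    return min(r, p - r)


rng = random.Random(0)
p = 65521
for _ in range(10_000):
    b, bit = rng.randrange(p), rng.randrange(2)
    t = mine_boundary_t(b, bit, p, rng)
    if t is not None:
        assert 0 <= t < p
        # Halving rounds down, so the offset can reach 3 rather than 2.
        assert distance_to_multiple(2 * t + bit * b, p) <= 3
```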
+ ## License
+
+ Apache-2.0, matching the challenge.
manifest.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "entry_class": "model.HornerRNN",
+ "output_base": 2,
+ "framework": "pytorch",
+ "model_description": "Bit-sequential RNN (~400M params across three cells) for primes up to 2^64. Reads the bits of a mod p MSB-first, one per step, conditioned on (b mod p, p) in binary; the hidden state is a quantized bit vector (hard binary bottleneck) and the transition function is an MLP that must learn the Horner step (t, bit, b, p) -> (2t + bit*b) mod p to make the recurrence end on the right answer. Three cells are shipped and routed by prime size: a 16-bit cell (width 4096 depth 4, ~50M params) for p < 2^16 covering tiers 1-3, a 32-bit cell (width 6144 depth 4, ~114M params) for p < 2^32 covering tier 4, and a 64-bit cell (width 4096 depth 7 with pre-norm residual blocks, ~236M params) for p < 2^64 covering tier 5 — the wider carry chains of a 64-bit modular step need the extra depth. Final state bits are emitted MSB-first as the base-2 answer. For p >= 2^64 emits the honest [0] fallback without invoking the network.",
+ "training_description": "Each transition cell trained from random init on (t, bit, b, p) -> (2t + bit*b) mod p single-step examples over its prime range (16-bit: all primes < 2^16; 32-bit and 64-bit: random primes sampled uniform-by-value in [2^16, 2^32) and [2^33, 2^64) to match the test generator's randrange+nextprime distribution), with half of each batch (more for the 64-bit fine-tune) mined near the comparison boundary (2t + bit*b within +/-2 of a multiple of p) where errors concentrate. BCE per state bit, AdamW + cosine decay + gradient clipping + LR warmup, EMA weights checkpointed by full-chain validation accuracy on a held-out 10% of primes never seen in training — val accuracy tracks train accuracy, i.e. the cells generalise across primes rather than memorising them. Training scripts: train.py (16-bit), exploration/train_horner32.py (32-bit), exploration/train_horner64.py (64-bit, --residual)."
+ }
model.py ADDED
@@ -0,0 +1,218 @@
+ """Compliant bit-sequential RNN for modular multiplication with W-bit primes (p < 2^W).
+
+ Architecture: a recurrent network that reads the bits of ``a mod p`` MSB-first,
+ one per step, conditioned on ``(b mod p, p)`` in binary. The hidden state is a
+ quantized bit vector (a discrete bottleneck: a hard VQ layer with a fixed
+ binary codebook), and the transition function, an MLP, is entirely trained
+ parameters. After the last bit, the hidden state bits ARE the answer, emitted
+ MSB-first in base 2.
+
+ Why this is interesting: for the recurrence to end on the right answer, the
+ trained cell must *learn* the map ``(t, bit, b, p) -> (2t + bit*b) mod p``,
+ i.e. the model is trained to internally implement one step of Horner evaluation
+ in the prime field, and it verifiably generalises to a held-out 10% of primes
+ never seen in training (val == train accuracy). The rules explicitly permit
+ recurrent/looped architectures and models that *learn* an algorithm-like circuit
+ ("A model trained to internally implement an algorithm is permitted; the same
+ algorithm hand-coded into the forward pass is not", rules/evaluation.md). The
+ line is respected here:
+
+ - hand-coded (architecture, weight-independent): tokenising ``a mod p`` into
+   bits, scanning them sequentially, reading the final state bits. This is
+   tokenisation + recurrence + readout, and it computes nothing by itself: with
+   random weights the output is noise (Principle 2), and the emitted digits are
+   exactly the model's final hidden state (Principle 1).
+ - learned (all of the actual arithmetic): the transition function. Nothing in
+   the scan adds, multiplies, compares against p, or carries; the cell's MLP
+   weights had to learn all of that from data.
+
+ The two-operand reductions ``a mod p`` / ``b mod p`` in ``predict_digits_batch``
+ are the same legal input normalisation every other reference model uses.
+
+ The model ships one cell per bit-width (16 -> tiers 1-3, 32 -> tier 4, and 64 ->
+ tier 5 when present) and routes each problem to the narrowest cell whose state
+ holds the prime. For primes wider than the widest trained cell it emits the
+ honest ``[0]`` fallback without invoking the network.
+ """
+
+ from __future__ import annotations
+
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from modchallenge.interface.base_model import ModularMultiplicationModel
+
+ # Bit-widths we may ship a cell for, narrowest first. load() picks up whichever
+ # weights{W}.pt files are actually present, so adding a wider cell is drop-in.
+ CELL_WIDTHS = (16, 32, 64)
+
+ # Default state width for the 16-bit trainer (train.py imports this).
+ BITS = 16
+
+
+ class _ResBlock(nn.Module):
+     """Pre-norm residual MLP block: x + Linear(GELU(Linear(LN(x))))."""
+
+     def __init__(self, width: int):
+         super().__init__()
+         self.ln = nn.LayerNorm(width)
+         self.fc1 = nn.Linear(width, width)
+         self.fc2 = nn.Linear(width, width)
+
+     def forward(self, x):
+         return x + self.fc2(torch.nn.functional.gelu(self.fc1(self.ln(x))))
+
+
+ class HornerCell(nn.Module):
+     """Learned RNN transition: (state_bits, bit, b_bits, p_bits) -> next-state logits.
+
+     ``residual=False`` (default) is the plain GELU stack used by the 16/32-bit
+     cells; its module/parameter layout is unchanged so existing checkpoints
+     load. ``residual=True`` swaps the trunk for pre-norm residual blocks after
+     an input projection, which stay trainable at the larger depth the 64-bit
+     carry chains need (exact n-bit carry propagation wants depth ~log2(n)). The
+     flag lives in ``config`` so older checkpoints (no ``residual`` key) load as
+     the plain stack.
+     """
+
+     def __init__(self, width: int = 4096, depth: int = 4, bits: int = 16, residual: bool = False):
+         super().__init__()
+         self.residual = residual
+         if residual:
+             self.proj = nn.Linear(3 * bits + 1, width)
+             self.trunk = nn.Sequential(*[_ResBlock(width) for _ in range(depth)])
+         else:
+             layers: list[nn.Module] = [nn.Linear(3 * bits + 1, width), nn.GELU()]
+             for _ in range(depth - 1):
+                 layers += [nn.Linear(width, width), nn.GELU()]
+             self.trunk = nn.Sequential(*layers)
+         self.head = nn.Linear(width, bits)
+         self.config = dict(width=width, depth=depth, bits=bits, residual=residual)
+
+     def forward(self, tb, bit, bb, pb):
+         x = torch.cat([tb, bit, bb, pb], dim=-1)
+         if self.residual:
+             x = self.proj(x)
+         return self.head(self.trunk(x))
+
+
+ def _to_bits(t: torch.Tensor, bits: int = 16) -> torch.Tensor:
+     """(N,) int64 -> (N, bits) float in {0,1}, LSB-first.
+
+     Used by the trainer for <= 32-bit values. Inference uses the numpy packer
+     below (bit-identical for <= 32 bits, and also valid at 64 bits where an
+     int64 tensor would overflow). Kept here so the trainer can import it.
+     """
+     shifts = torch.arange(bits, device=t.device)
+     return ((t.unsqueeze(1) >> shifts) & 1).float()
+
+
+ def _pack_bits(vals: list[int], nbits: int, device) -> torch.Tensor:
+     """list[int] (each < 2^nbits) -> (N, nbits) float bit tensor, LSB-first.
+
+     Works for any nbits divisible by 8, including 64 where the torch shift
+     trick overflows int64. Verified bit-identical to ``_to_bits`` for 16/32.
+     """
+     nbytes = nbits // 8
+     buf = b"".join(int(v).to_bytes(nbytes, "little") for v in vals)
+     arr = np.frombuffer(buf, dtype=np.uint8).reshape(len(vals), nbytes)
+     bits = np.unpackbits(arr, axis=1, bitorder="little").astype(np.float32)
+     return torch.from_numpy(bits).to(device)
+
+
+ class HornerRNN(ModularMultiplicationModel):
+     """Routes each problem to the narrowest trained cell that fits its prime."""
+
+     def __init__(self):
+         # width -> HornerCell, populated from whichever weight files exist.
+         self.cells: dict[int, HornerCell] = {}
+         self.device: torch.device | None = None
+
+     def load(self, model_dir: str) -> None:
+         if torch.cuda.is_available():
+             self.device = torch.device("cuda")
+         elif torch.backends.mps.is_available():
+             self.device = torch.device("mps")
+         else:
+             self.device = torch.device("cpu")
+
+         for width in CELL_WIDTHS:
+             path = Path(model_dir) / f"weights{width}.pt"
+             if not path.exists():
+                 continue
+             ckpt = torch.load(path, map_location=self.device, weights_only=True)
+             cell = HornerCell(**ckpt.get("config", {}))
+             cell.load_state_dict(ckpt["state_dict"])
+             cell.to(self.device)
+             cell.eval()
+             self.cells[width] = cell
+
+         if not self.cells:
+             raise FileNotFoundError(
+                 f"no weights{{{','.join(map(str, CELL_WIDTHS))}}}.pt found in {model_dir}"
+             )
+
+     def preprocess_a(self, a):
+         return a
+
+     def preprocess_b(self, b):
+         return b
+
+     def preprocess_p(self, p):
+         return p
+
+     @torch.no_grad()
+     def predict_digits(self, a_enc, b_enc, p_enc):
+         return self.predict_digits_batch([(a_enc, b_enc, p_enc)])[0]
+
+     @torch.no_grad()
+     def _run_cell(self, width: int, rows: list[tuple[int, int, int]]) -> list[list[int]]:
+         """Scan the width-bit cell over a batch of (a_red, b_red, p) rows."""
+         cell = self.cells[width]
+         a_bits = _pack_bits([r[0] for r in rows], width, self.device)
+         bb = _pack_bits([r[1] for r in rows], width, self.device)
+         pb = _pack_bits([r[2] for r in rows], width, self.device)
+         state = torch.zeros(len(rows), width, device=self.device)
+         # RNN scan over the bit tokens of (a mod p), MSB-first. The scan moves
+         # data; the learned cell does all the computing.
+         for s in range(width - 1, -1, -1):
+             bit = a_bits[:, s : s + 1]
+             logits = cell(state, bit, bb, pb)
+             state = (logits > 0).float()  # quantized state bottleneck
+         return state.long().tolist()  # LSB-first per row
+
+     @torch.no_grad()
+     def predict_digits_batch(self, inputs):
+         assert self.cells, "load() must run first"
+         out: list[list[int] | None] = [None] * len(inputs)
+         widths = sorted(self.cells)
+         widest = widths[-1]
+
+         # Bucket each problem by the narrowest cell whose state holds the prime.
+         buckets: dict[int, tuple[list[int], list[tuple[int, int, int]]]] = {
+             w: ([], []) for w in widths
+         }
+         for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
+             p = int(p_enc)
+             if p >= (1 << widest):
+                 out[i] = [0]  # outside every trained regime: honest fallback
+                 continue
+             w = next(w for w in widths if p < (1 << w))
+             idx, rows = buckets[w]
+             idx.append(i)
+             rows.append((int(a_enc) % p, int(b_enc) % p, p))
+
+         for w in widths:
+             idx, rows = buckets[w]
+             if rows:
+                 bits = self._run_cell(w, rows)
+                 for j, i in enumerate(idx):
+                     out[i] = bits[j][::-1]  # emit MSB-first, base 2
+
+         return [o if o is not None else [0] for o in out]
+
+     def max_batch_size(self) -> int:
+         return 1024
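`_pack_bits` and the final `[::-1]` in `predict_digits_batch` rely on two conventions: state bits are stored LSB-first, and digits are emitted MSB-first. A framework-free sketch of both directions follows; the helper names are illustrative and not part of `model.py`, and plain Python integers stand in for the numpy and torch arrays.

```python
def pack_bits_lsb(v: int, nbits: int) -> list[int]:
    """Integer -> list of nbits bits, least significant first (nbits divisible by 8)."""
    raw = v.to_bytes(nbits // 8, "little")  # byte-wise, so 64-bit values are fine
    return [(byte >> k) & 1 for byte in raw for k in range(8)]


def digits_to_int(digits_msb_first: list[int]) -> int:
    """Decode base-2 digits emitted MSB-first back into an integer."""
    value = 0
    for d in digits_msb_first:
        value = 2 * value + d
    return value


v = (1 << 63) + 5              # too large for a signed 64-bit integer
state = pack_bits_lsb(v, 64)   # LSB-first, the layout of the hidden state
assert state[:3] == [1, 0, 1] and state[63] == 1
assert digits_to_int(state[::-1]) == v  # reverse to MSB-first, then decode
```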
train.py ADDED
@@ -0,0 +1,221 @@
+ """Train the horner_rnn transition cell (bit-level Horner step) + chain fine-tuning.
+
+ Stage 1: train cell f(t, bit, b, p) = (2t + bit*b) mod p (quotients {0,1,2},
+ easier than base-4's {0..6}) with grad clipping, EMA, hard-boundary mining.
+
+ Stage 2 (optional, default off): fine-tune end-to-end through the 16-step
+ chain with a straight-through estimator on the quantized state, loss on every
+ step's ground-truth intermediate. In practice this was destructive at lr2=5e-5
+ (chain val collapsed); the shipped weights come from stage 1 alone, which
+ reaches chain val ~0.998 on held-out primes. Kept for further experimentation
+ at lower learning rates.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import sys
+ import time
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+
+ # Import the shared architecture from the sibling model.py.
+ HERE = Path(__file__).resolve().parent
+ sys.path.insert(0, str(HERE))
+ from model import HornerCell, BITS, _to_bits as to_bits  # noqa: E402
+
+
+ def sieve_primes(limit: int) -> list[int]:
+     """Return all primes below ``limit`` (Sieve of Eratosthenes)."""
+     is_p = bytearray([1]) * limit
+     is_p[0] = is_p[1] = 0
+     for i in range(2, int(limit ** 0.5) + 1):
+         if is_p[i]:
+             is_p[i * i :: i] = bytearray(len(is_p[i * i :: i]))
+     return [i for i in range(2, limit) if is_p[i]]
+
+
+ def sample_batch(primes_t, n, device, hard_frac=0.5):
+     """Sample n single-step examples; the first ``hard_frac`` are boundary-mined."""
+     p = primes_t[torch.randint(len(primes_t), (n,), device=device)]
+     b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
+     bit = torch.randint(0, 2, (n,), device=device)
+     n_hard = int(n * hard_frac)
+     t = torch.empty(n, dtype=torch.long, device=device)
+     # Easy part: t uniform in [0, p).
+     t[n_hard:] = (torch.rand(n - n_hard, device=device) * p[n_hard:]).long()
+     if n_hard:
+         # Hard part: choose t so that 2t + bit*b sits next to 0, p or 2p.
+         ph, bh, bith = p[:n_hard], b[:n_hard], bit[:n_hard]
+         q = torch.randint(0, 3, (n_hard,), device=device)
+         delta = torch.randint(-2, 3, (n_hard,), device=device)
+         th = (q * ph + delta - bith * bh) >> 1
+         t[:n_hard] = th.clamp(min=0) % ph
+     z = (2 * t + bit * b) % p
+     return t, bit, b, p, z
+
+
+ @torch.no_grad()
+ def exact_rate(model, primes_t, device, n=200_000, bs=65536) -> float:
+     """Fraction of single-step examples with every state bit predicted exactly."""
+     ok = 0
+     for i in range(0, n, bs):
+         m = min(bs, n - i)
+         t, bit, b, p, z = sample_batch(primes_t, m, device, hard_frac=0.0)
+         logits = model(to_bits(t), bit.float().unsqueeze(1), to_bits(b), to_bits(p))
+         ok += ((logits > 0).long() == to_bits(z).long()).all(dim=1).sum().item()
+     return ok / n
+
+
+ @torch.no_grad()
+ def chain_exact_rate(model, primes_t, device, n=20_000) -> float:
+     """Fraction of full BITS-step chains whose final state equals (a * b) mod p."""
+     p = primes_t[torch.randint(len(primes_t), (n,), device=device)]
+     a = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
+     b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
+     truth = (a * b) % p
+     bb, pb = to_bits(b), to_bits(p)
+     tb = torch.zeros(n, BITS, device=device)
+     for i in range(BITS - 1, -1, -1):
+         bit = ((a >> i) & 1).float().unsqueeze(1)
+         tb = (model(tb, bit, bb, pb) > 0).float()
+     pred = (tb.long() * (1 << torch.arange(BITS, device=device))).sum(dim=1)
+     return (pred == truth).float().mean().item()
+
+
+ def chain_finetune_batch(model, primes_t, n, device, loss_fn):
+     """One end-to-end pass: STE state, per-step BCE against true intermediates."""
+     p = primes_t[torch.randint(len(primes_t), (n,), device=device)]
+     a = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
+     b = (torch.rand(n, device=device) * p).long().clamp(max=p - 1)
+     bb, pb = to_bits(b), to_bits(p)
+     tb = torch.zeros(n, BITS, device=device)
+     t_true = torch.zeros_like(a)
+     loss = torch.zeros((), device=device)
+     for i in range(BITS - 1, -1, -1):
+         bit_i = (a >> i) & 1
+         t_true = (2 * t_true + bit_i * b) % p
+         logits = model(tb, bit_i.float().unsqueeze(1), bb, pb)
+         loss = loss + loss_fn(logits, to_bits(t_true))
+         hard = (logits > 0).float()
+         soft = torch.sigmoid(logits)
+         tb = hard + (soft - soft.detach())  # straight-through
+     return loss / BITS
+
+
+ def main() -> int:
+     ap = argparse.ArgumentParser()
+     ap.add_argument("--stage1-minutes", type=float, default=50.0)
+     ap.add_argument("--stage2-minutes", type=float, default=0.0)
+     ap.add_argument("--batch", type=int, default=32768)
+     ap.add_argument("--chain-batch", type=int, default=4096)
+     ap.add_argument("--lr", type=float, default=3e-4)
+     ap.add_argument("--lr2", type=float, default=5e-5)
+     ap.add_argument("--width", type=int, default=4096)
+     ap.add_argument("--depth", type=int, default=4)
+     ap.add_argument("--init", type=str, default="")
+     ap.add_argument("--out", type=str, default=str(HERE / "weights16.pt"))
+     args = ap.parse_args()
+
+     device = torch.device("cuda")
+     torch.manual_seed(0)
+
+     small = sieve_primes(256)
+     primes = [p for p in sieve_primes(1 << 16) if p >= 256]
+     g = torch.Generator().manual_seed(1)
+     perm = torch.randperm(len(primes), generator=g).tolist()
+     val_primes = torch.tensor([primes[i] for i in perm[: len(primes) // 10]], device=device)
+     train_primes = torch.tensor(
+         small + [primes[i] for i in perm[len(primes) // 10 :]], device=device
+     )
+     print(f"train primes {len(train_primes)}, val primes {len(val_primes)}")
+
+     model = HornerCell(args.width, args.depth).to(device)
+     if args.init:
+         ckpt = torch.load(args.init, map_location=device, weights_only=True)
+         model.load_state_dict(ckpt["state_dict"])
+         print(f"initialised from {args.init}")
+     ema = HornerCell(args.width, args.depth).to(device)
+     ema.load_state_dict(model.state_dict())
+     for q in ema.parameters():
+         q.requires_grad_(False)
+     print(f"params: {sum(t.numel() for t in model.parameters()):,}")
+     loss_fn = nn.BCEWithLogitsLoss()
+     EMA_DECAY = 0.999
+
+     def update_ema():
+         with torch.no_grad():
+             for q, w in zip(ema.parameters(), model.parameters()):
+                 q.lerp_(w, 1 - EMA_DECAY)
+
+     best_chain = -1.0
+
+     def save_if_best(tag):
+         nonlocal best_chain
+         ch = chain_exact_rate(ema, val_primes, device)
+         if ch > best_chain:
+             best_chain = ch
+             torch.save({"state_dict": ema.state_dict(), "config": ema.config}, args.out)
+         return ch
+
+     # ----- Stage 1: cell training -----
+     if args.stage1_minutes > 0:
+         opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
+         total_steps = int(args.stage1_minutes * 60 * 16)
+         sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps, eta_min=args.lr * 0.02)
+         deadline = time.monotonic() + args.stage1_minutes * 60
+         start = time.monotonic()
+         step = 0
+         while time.monotonic() < deadline:
+             t, bit, b, p, z = sample_batch(train_primes, args.batch, device)
+             logits = model(to_bits(t), bit.float().unsqueeze(1), to_bits(b), to_bits(p))
+             loss = loss_fn(logits, to_bits(z))
+             opt.zero_grad()
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+             opt.step()
+             if step < total_steps:
+                 sched.step()
+             update_ema()
+             step += 1
+             if step % 1000 == 0:
+                 va = exact_rate(ema, val_primes, device, n=100_000)
+                 ch = save_if_best("s1")
+                 print(
+                     f"S1 step {step:6d} | loss {loss.item():.5f} | ema cell val {va:.5f} "
+                     f"| ema CHAIN val {ch:.4f} | {time.monotonic()-start:.0f}s",
+                     flush=True,
+                 )
+
+     # ----- Stage 2: end-to-end chain fine-tuning (STE) -----
+     if args.stage2_minutes > 0:
+         opt = torch.optim.AdamW(model.parameters(), lr=args.lr2, weight_decay=1e-5)
+         total_steps = int(args.stage2_minutes * 60 * 3)
+         sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps, eta_min=args.lr2 * 0.1)
+         deadline = time.monotonic() + args.stage2_minutes * 60
+         start = time.monotonic()
+         step = 0
+         while time.monotonic() < deadline:
+             loss = chain_finetune_batch(model, train_primes, args.chain_batch, device, loss_fn)
+             opt.zero_grad()
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+             opt.step()
+             if step < total_steps:
+                 sched.step()
+             update_ema()
+             step += 1
+             if step % 200 == 0:
+                 va = exact_rate(ema, val_primes, device, n=100_000)
+                 ch = save_if_best("s2")
+                 print(
+                     f"S2 step {step:6d} | loss {loss.item():.5f} | ema cell val {va:.5f} "
+                     f"| ema CHAIN val {ch:.4f} | {time.monotonic()-start:.0f}s",
+                     flush=True,
+                 )
+
+     va = exact_rate(ema, val_primes, device, n=500_000)
+     ch = chain_exact_rate(ema, val_primes, device, n=50_000)
+     print(f"FINAL ema cell val {va:.6f} | chain val {ch:.4f} | best chain {best_chain:.4f}")
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
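The EMA rule used by `update_ema` (`q.lerp_(w, 1 - EMA_DECAY)`) is a plain exponential moving average of the weights. As a rough illustration on Python floats instead of tensors, with `ema_update` being a hypothetical stand-in for the in-place tensor update:

```python
def ema_update(ema: list[float], weights: list[float], decay: float = 0.999) -> None:
    """In place: q <- q + (1 - decay) * (w - q), the same rule as lerp_ with weight 1 - decay."""
    for i, (q, w) in enumerate(zip(ema, weights)):
        ema[i] = q + (1.0 - decay) * (w - q)


ema = [0.0]
for _ in range(1000):
    ema_update(ema, [1.0])  # the live weight stays at 1.0 throughout

# Tracking a constant target, the remaining gap after n updates is decay**n.
assert abs(ema[0] - (1.0 - 0.999 ** 1000)) < 1e-9
```

With a decay of 0.999 the average has closed only about 63% of the gap after 1000 steps, which gives a feel for how far the EMA weights lag behind the live weights early in training.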