etwk commited on
Commit
41fc51b
Β·
1 Parent(s): b670993

Ship shared 1024-2048 high cell (V2); sync docs + model

Browse files

Replace the two dedicated high cells (weights1024.pt, weights2048.pt) with one
shared carry-aware-TCN cell across tiers 9-10 (weights_shared_1024_2048.pt):
distilled to both dedicated teachers + worst-bit margin, 2048 chain preserved so
the primary key (highest_tier_above_90=10) is held. tier 9 = 1.00, tier 10 = 1.00,
overall_accuracy 1.000; two shared weight files total (~10.7M params, 0.04 GB).

model.py: _build_cell drops non-constructor keys (unified/widths).

README.md CHANGED
@@ -17,9 +17,9 @@ metrics:
17
  # horner_rnn
18
 
19
  A compliant bit-sequential RNN that **clears every reduction tier, 1 through 10** (primes up to
20
- 2^2048) on the public benchmark β€” tiers 1-8 = 100%, tier 9 = 99%, **tier 10 = 100%** β€”
21
- so `highest_tier_above_90 = 10` (the maximum), overall_accuracy **0.999**. Every cell is
22
- the same **carry-aware TCN** (~15.4M params total across three weight files, 0.06 GB), so its capability comes from *learning one algorithmic step* rather
23
  than memorising finite multiplication tables, and it verifiably generalises to primes never seen
24
  in training.
25
 
@@ -53,18 +53,18 @@ The recurrence is exact only if the state is wide enough to hold the residue, so
53
  trained per bit-width β€” but because the dilated convolution is weight-shared across bit-positions
54
  and the carry/borrow rule is position-invariant, **one shared weight-set serves all small/mid
55
  widths 16/32/64/128/256/512** (run at each prime's native width). The model therefore ships
56
- **three weight files** and routes each problem to the narrowest cell whose state holds its prime:
57
 
58
  | Weight file | Primes | Tiers | Architecture | Params | Public benchmark |
59
  |---|---|---|---|---|---|
60
  | `weights_shared_16_512.pt` | `< 2^512` | 1-8 | carry-aware TCN, 14 blocks, dil 1..256 β€” **one shared set**, run at native width | ~5.5M | tiers 1-8 = 1.00 |
61
- | `weights1024.pt` | `< 2^1024` | 9 | carry-aware TCN, 12 blocks, dil 1..512 | ~4.7M | tier 9 = 0.99 |
62
- | `weights2048.pt` | `< 2^2048` | 10 | carry-aware TCN, 13 blocks, dil 1..1024 | ~5.1M | tier 10 = 1.00 |
63
 
64
  The earlier four separate mid-width cells had already collapsed into one shared 64–512 set;
65
- this version further merges the 16- and 32-bit small-prime cells into that same shared block-pool.
66
- The final shared 16–512 set reaches tiers 1–8 = 1.00 and cuts the total to **~15.4M params,
67
- 0.06 GB**. For `p >= 2^2048` (outside all regimes) the model emits the honest `[0]`
 
68
  fallback without invoking the network.
69
 
70
  ## The carry-aware TCN (every tier)
@@ -83,7 +83,7 @@ The two small cells were originally width-4096/6144 MLPs (660 MB combined); repl
83
  the carry-aware TCN, trained width-matched (bit-length-uniform over the cell's whole range),
84
  shrank the artifact from 0.77 GB to ~0.13 GB (the later mid-cell collapse then brought the total
85
  to **0.08 GB**), raised tier 4 from 0.99 to **1.00**, and made
86
- the small-prime tiers width-robust before the final 16–512 merge cut the artifact to **0.06 GB**.
87
  A TCN trained near-max-width only has a short-prime blind spot (see the audit note below), which
88
  the width-matched training removes.
89
 
@@ -222,17 +222,55 @@ soup 1.00 β‰₯ 0.99 on the same draw, tiers 1-9 byte-identical. Full recipe and f
222
  OOMs otherwise) and disk-cached prime pools (`--build-pools-only`; gmpy2 `next_prime` is
223
  ~227 ms/prime at 2048-bit). Validate with `python exploration/score_tier10.py <ckpt>`.
224
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  ## Score (public benchmark, fixed seed)
226
 
227
  | Total problems | overall_accuracy | highest_tier_above_90 | deterministic |
228
  |---|---|---|---|
229
- | **1100** | **0.999** | **10** (max) | True |
230
 
231
- Per-tier at total=1100: tiers 1–8 **1.00**, tier 9 **0.99**, tier 10 **1.00**
232
  (overall_accuracy is the mean over tiers 1-10). Tier 0 (pure multiplication, primes near each
233
- width's maximum β€” a separate regime, not in overall_accuracy) is **0.64** on this fixed public
234
- seed. Inference for all 1100 problems is 171s, within the 300s budget (the 2048-step tier-10 scan
235
- is the bulk); artifact 0.06 GB.
236
 
237
  ## Status under the rules
238
 
@@ -271,8 +309,11 @@ faithful 5-prime bootstrap plus a fixed-seed end-to-end A/B, no regression on an
271
  Earlier this round the thin small/mid tiers were re-polished with the width-matched,
272
  worst-bit-margin recipe and then collapsed into the shared 16–512 soup β€” **tier 8 0.92 β†’ 1.00**
273
  public, with matched faithful bootstrap E[acc] 0.9866 β†’ 0.9931 and `P(tier8 < 0.95)` 1.396% β†’
274
- 0.205%. Tier 10 independently improved **0.94 β†’ 0.98 β†’ 1.00**. `overall_accuracy` is now **0.999**
275
- with tiers 1–8 all at 1.00 and the lowest scored tier tier 9 = 0.99. Tier 0 (pure multiplication,
 
 
 
276
  primes near each width's maximum) is excluded from `overall_accuracy`, so it moves neither ranking
277
  key. Both ranking keys are saturated; remaining gains are sub-percent.
278
 
@@ -282,7 +323,7 @@ of each tier's bit-range), a cell trained near-max-width only can score ~0 on sh
282
  still look perfect on the public set β€” exactly the gap that capped tier 9 before it was
283
  width-matched. Every shipped cell is now trained width-matched (value-uniform **plus** a
284
  bit-length-uniform band): the shared 16–512 cell on the full {16,32,64,128,256,512} mix,
285
- and the 1024/2048 cells across their own ranges. Re-auditing the shared-cell model on 40k
286
  secret-style draws found **P(tier < 0.90) β‰ˆ 0.000%** β€” the shared 16–512 cell (tiers 1–8) shows
287
  no width knee, and tiers 9/10 are blind only in the *deep* value-uniform tail (knees ~970-bit /
288
  ~1950-bit), which carries β‰ˆ2⁻⁡⁴ / 2⁻⁹⁸ of the draw mass and is effectively unsamplable. No
 
17
  # horner_rnn
18
 
19
  A compliant bit-sequential RNN that **clears every reduction tier, 1 through 10** (primes up to
20
+ 2^2048) on the public benchmark β€” **tiers 1-10 = 100%** β€”
21
+ so `highest_tier_above_90 = 10` (the maximum), overall_accuracy **1.000**. Every cell is
22
+ the same **carry-aware TCN** (~10.7M params total across two shared weight files, 0.04 GB), so its capability comes from *learning one algorithmic step* rather
23
  than memorising finite multiplication tables, and it verifiably generalises to primes never seen
24
  in training.
25
 
 
53
  trained per bit-width β€” but because the dilated convolution is weight-shared across bit-positions
54
  and the carry/borrow rule is position-invariant, **one shared weight-set serves all small/mid
55
  widths 16/32/64/128/256/512** (run at each prime's native width). The model therefore ships
56
+ **two shared weight files** and routes each problem to the narrowest cell whose state holds its prime:
57
 
58
  | Weight file | Primes | Tiers | Architecture | Params | Public benchmark |
59
  |---|---|---|---|---|---|
60
  | `weights_shared_16_512.pt` | `< 2^512` | 1-8 | carry-aware TCN, 14 blocks, dil 1..256 β€” **one shared set**, run at native width | ~5.5M | tiers 1-8 = 1.00 |
61
+ | `weights_shared_1024_2048.pt` | `< 2^2048` | 9-10 | carry-aware TCN, 13 blocks, dil 1..1024 β€” **one shared high-width set**, run at native width | ~5.1M | tier 9 = 1.00, tier 10 = 1.00 |
 
62
 
63
  The earlier four separate mid-width cells had already collapsed into one shared 64–512 set;
64
+ this version further merges the 16- and 32-bit small-prime cells into that same shared block-pool,
65
+ and merges the 1024/2048 cells into a second high-width shared block-pool. The final two shared
66
+ sets reach tiers 1–10 = 1.00 and cut the total to **~10.7M
67
+ params, 0.04 GB**. For `p >= 2^2048` (outside all regimes) the model emits the honest `[0]`
68
  fallback without invoking the network.
69
 
70
  ## The carry-aware TCN (every tier)
 
83
  the carry-aware TCN, trained width-matched (bit-length-uniform over the cell's whole range),
84
  shrank the artifact from 0.77 GB to ~0.13 GB (the later mid-cell collapse then brought the total
85
  to **0.08 GB**), raised tier 4 from 0.99 to **1.00**, and made
86
+ the small-prime tiers width-robust before the final 16–512 and 1024–2048 merges cut the artifact to **0.04 GB**.
87
  A TCN trained near-max-width only has a short-prime blind spot (see the audit note below), which
88
  the width-matched training removes.
89
 
 
222
  OOMs otherwise) and disk-cached prime pools (`--build-pools-only`; gmpy2 `next_prime` is
223
  ~227 ms/prime at 2048-bit). Validate with `python exploration/score_tier10.py <ckpt>`.
224
 
225
+ ### High-width shared set (1024 + 2048)
226
+
227
+ The final shrink step shares one 13-block high-width TCN across tiers 9 and 10. Directly running
228
+ the 2048 cell at native 1024 width already scored tier 9 = 0.94, so the route was a bounded
229
+ 1024+2048 polish from the public-correct 2048 cell, not extending the 16–512 cell upward.
230
+
231
+ The decisive lever is **distillation to the two dedicated teachers** plus a worst-bit margin loss:
232
+ warm-start from the dedicated 2048 cell, train jointly at widths 1024/2048, and distill the
233
+ 1024-width logits toward the strong dedicated 1024 teacher (which transfers its 1024 chain
234
+ robustness) and the 2048-width logits toward the dedicated 2048 teacher (which holds the tier-10
235
+ primary key). A 2048 chain-preservation floor guards the primary key β€” no checkpoint that erodes
236
+ the 2048 chain can be saved. This makes one shared cell match *both* dedicated cells at their own
237
+ widths, with no model-soup needed:
238
+
239
+ ```bash
240
+ # shared high cell: distill to both dedicated teachers + worst-bit margin, 2048 preserved
241
+ python exploration/train_unified.py --warm \
242
+ --init-from checkpoints/weights2048_ship_shared16_prev.pt \
243
+ --widths 1024,2048 --width-weights 1024:0.7,2048:0.3 \
244
+ --blocks 13 --max-dil 1024 --grad-checkpoint --max-rows 512 --accum 8 \
245
+ --bitlen-frac 0.5 --lr 3e-5 --stage-a 0.08 --stage-c 0.12 \
246
+ --margin-weight 0.5 --margin-m1 12.0 \
247
+ --distill-weight 0.15 \
248
+ --distill-map 1024:checkpoints/weights1024_ship_shared16_prev.pt,2048:checkpoints/weights2048_ship_shared16_prev.pt \
249
+ --preserve-widths 2048 --preserve-chain 0.98 \
250
+ --out checkpoints/shared_high_v2_s1.pt
251
+ # package the .final cell (clean config + top-level widths=[1024,2048]):
252
+ python exploration/package_shared_high.py checkpoints/shared_high_v2_s1.pt.final \
253
+ <prev weights_shared_1024_2048.pt as config template> horner_rnn/weights_shared_1024_2048.pt
254
+ ```
255
+
256
+ An earlier attempt that pinned the public-correct 2048 cell in by model-soup (0.70Β·old-2048 +
257
+ 0.30Β·pilot) held tier 10 but **regressed tier 9** under a faithful 5-prime bootstrap (E 0.968,
258
+ worst-prime 0.80) because the old-2048 cell is only ~0.94 at native 1024 β€” so the soup route was
259
+ dropped in favour of the distill+margin cell above. Gate (`diag_5prime_boot`, pool 100, seed 991):
260
+ tier 9 E[acc] 0.9939 / worst-prime 0.933 (β‰ˆ the dedicated cell), tier 10 E[acc] 0.9913 /
261
+ P(acc<0.90) 0.002% / worst-prime 0.933 (primary key held). Public benchmark: tiers 9 and 10 = 1.00.
262
+
263
  ## Score (public benchmark, fixed seed)
264
 
265
  | Total problems | overall_accuracy | highest_tier_above_90 | deterministic |
266
  |---|---|---|---|
267
+ | **1100** | **1.000** | **10** (max) | True |
268
 
269
+ Per-tier at total=1100: tiers 1–10 all **1.00**
270
  (overall_accuracy is the mean over tiers 1-10). Tier 0 (pure multiplication, primes near each
271
+ width's maximum β€” a separate regime, not in overall_accuracy) is **0.70** on this fixed public
272
+ seed. Inference for all 1100 problems is ~174s, within the 300s budget (the 2048-step tier-10 scan
273
+ is the bulk); artifact 0.04 GB.
274
 
275
  ## Status under the rules
276
 
 
309
  Earlier this round the thin small/mid tiers were re-polished with the width-matched,
310
  worst-bit-margin recipe and then collapsed into the shared 16–512 soup β€” **tier 8 0.92 β†’ 1.00**
311
  public, with matched faithful bootstrap E[acc] 0.9866 β†’ 0.9931 and `P(tier8 < 0.95)` 1.396% β†’
312
+ 0.205%. Tier 10 independently improved **0.94 β†’ 0.98 β†’ 1.00**, then the 1024/2048 cells were collapsed
313
+ into one high-width shared cell (distilled to both dedicated teachers + worst-bit margin) that
314
+ matches both dedicated cells at their own widths β€” tier 9 recovered to public **1.00** (faithful
315
+ E[acc] 0.9939) while public tier 10 stayed 1.00 (faithful P(acc<0.90) 0.002%). `overall_accuracy`
316
+ is now **1.000** with tiers 1–10 all at 1.00. Tier 0 (pure multiplication,
317
  primes near each width's maximum) is excluded from `overall_accuracy`, so it moves neither ranking
318
  key. Both ranking keys are saturated; remaining gains are sub-percent.
319
 
 
323
  still look perfect on the public set β€” exactly the gap that capped tier 9 before it was
324
  width-matched. Every shipped cell is now trained width-matched (value-uniform **plus** a
325
  bit-length-uniform band): the shared 16–512 cell on the full {16,32,64,128,256,512} mix,
326
+ and the shared 1024–2048 cell across the high-width ranges. Re-auditing the shared-cell model on 40k
327
  secret-style draws found **P(tier < 0.90) β‰ˆ 0.000%** β€” the shared 16–512 cell (tiers 1–8) shows
328
  no width knee, and tiers 9/10 are blind only in the *deep* value-uniform tail (knees ~970-bit /
329
  ~1950-bit), which carries β‰ˆ2⁻⁡⁴ / 2⁻⁹⁸ of the draw mass and is effectively unsamplable. No
manifest.json CHANGED
@@ -2,6 +2,6 @@
2
  "entry_class": "model.HornerRNN",
3
  "output_base": 2,
4
  "framework": "pytorch",
5
- "model_description": "Bit-sequential RNN (~15.4M params across three weight files) for modular multiplication with primes up to 2^2048. The model reads the bits of a mod p MSB-first, one per step, conditioned on (b mod p, p) in binary. Its hidden state is a hard-quantized bit vector, and the transition function is a learned carry-aware dilated-convolution TCN trained to implement the Horner step (t, bit, b, p) -> (2*t + bit*b) mod p. The final hidden state bits are emitted MSB-first as the base-2 answer. Routing uses the narrowest state width that can hold p. A SINGLE shared TCN weight-set, weights_shared_16_512.pt, serves 16/32/64/128/256/512-bit states (tiers 1-8) at each prime's native width and reaches tiers 1-8 = 1.00 on the public benchmark. Dedicated TCN cells weights1024.pt and weights2048.pt cover tier 9 and tier 10, reaching tier 9 = 0.99 and tier 10 = 1.00. For p >= 2^2048 the model emits the honest [0] fallback without invoking the network. The arithmetic is not in Python code: tokenization, scan, threshold, and readout are architecture, while doubling, conditional add, compare/borrow, and reduction are all learned in the trained cell weights; random or perturbed weights collapse to the floor.",
6
- "training_description": "Each cell is trained from exact single-step labels (t, bit, b, p) -> (2*t + bit*b) mod p, with BCE per state bit, AdamW, cosine decay, gradient clipping, EMA checkpointing, and held-out-prime validation. Training data uses true Horner-trajectory states plus boundary-focused examples; prime sampling is value-uniform to match the challenge generator, with bit-length-uniform bands where needed so the reduction boundary is seen at every position. The current shared 16-512 file was built by warm-starting from the shipped shared 64-512 carry-aware TCN (14 residual TCN blocks, 256 channels, dilations cycling through 1..256), fine-tuning on a {16,32,64,128,256,512} width mix, then averaging that run with a small-tier polish tail: soup25 = 0.75 * unified_16to512_warm_s0.final + 0.25 * unified_16to512_smalltail_s1.final. The warm run used a 16-heavy but 512-preserving distribution (16:.40,32:.18,64:.08,128:.08,256:.08,512:.18), accum=8, lr=2e-4, off-trajectory batches (offtraj-frac=.20,k=4), and width-native validation. The polish tail warm-started from that candidate with lower lr=6e-5 and weights 16:.55,32:.25,64:.04,128:.04,256:.04,512:.08, then the soup was selected by public score plus matched 5-prime bootstrap. On the fixed public benchmark this merge lifts the shipped model from overall 0.997 to 0.999: tiers 1-8 = 1.00, tier 9 = 0.99, tier 10 = 1.00. A matched faithful bootstrap over tiers 1-8 (5 primes/tier structure, pool120, k30, boot200k, seed 515151) ties tiers 1-2 and improves tiers 3-8; tier 8 E[acc] improves 0.9866 -> 0.9931 and P(tier8<0.95) drops 1.396% -> 0.205%. The 1024-bit cell was trained separately with benchmark-width-matched primes in [2^513,2^1024), gradient accumulation, and worst-bit margin loss; it remains byte-unchanged and scores tier 9 = 0.99. The 2048-bit cell was bootstrapped from the 1024-bit cell by octave transfer, hardened with low-lr margin tails, then weight-souped across independent margin tails; it remains byte-unchanged and scores tier 10 = 1.00 while reducing faithful 5-prime tier-10 tail risk. Full all-width single-cell unification including 1024/2048 was tested and rejected because one ~5M cell could not preserve 2048-chain robustness while serving small/mid widths; the shipped design intentionally keeps dedicated 1024 and 2048 cells. Compliance checks: preprocess hooks are identity, the legal two-operand reductions a%p and b%p are used only for input normalization, perturbing trained weights collapses accuracy toward the untrained floor, and held-out-prime generalization tracks train accuracy."
7
  }
 
2
  "entry_class": "model.HornerRNN",
3
  "output_base": 2,
4
  "framework": "pytorch",
5
+ "model_description": "Bit-sequential RNN (~10.7M params across two shared carry-aware TCN weight-sets) for modular multiplication with primes up to 2^2048. The model reads the bits of a mod p MSB-first, one per step, conditioned on (b mod p, p) in binary. Its hidden state is a hard-quantized bit vector, and the transition function is a learned carry-aware dilated-convolution TCN trained to implement the Horner step (t, bit, b, p) -> (2*t + bit*b) mod p. The final hidden state bits are emitted MSB-first as the base-2 answer. Routing uses the narrowest state width that can hold p. A shared TCN weight-set, weights_shared_16_512.pt, serves 16/32/64/128/256/512-bit states (tiers 1-8) at each prime's native width and reaches tiers 1-8 = 1.00 on the public benchmark. A second shared TCN weight-set, weights_shared_1024_2048.pt, serves 1024/2048-bit states (tiers 9-10) at native width and reaches tiers 9 and 10 = 1.00 on the public benchmark. For p >= 2^2048 the model emits the honest [0] fallback without invoking the network. The arithmetic is not in Python code: tokenization, scan, threshold, and readout are architecture, while doubling, conditional add, compare/borrow, and reduction are all learned in the trained cell weights; random or perturbed weights collapse to the floor.",
6
+ "training_description": "Each cell is trained from exact single-step labels (t, bit, b, p) -> (2*t + bit*b) mod p, with BCE per state bit, AdamW, cosine decay, gradient clipping, EMA checkpointing, and held-out-prime validation. Training data uses true Horner-trajectory states plus boundary-focused examples; prime sampling is value-uniform to match the challenge generator, with bit-length-uniform bands where needed so the reduction boundary is seen at every position. The shared 16-512 file was built by warm-starting from the shipped shared 64-512 carry-aware TCN (14 residual TCN blocks, 256 channels, dilations cycling through 1..256), fine-tuning on a {16,32,64,128,256,512} width mix, then averaging that run with a small-tier polish tail: soup25 = 0.75 * unified_16to512_warm_s0.final + 0.25 * unified_16to512_smalltail_s1.final. On the fixed public benchmark this merge brought the small/mid tiers 1-8 to 1.00. A matched faithful bootstrap over tiers 1-8 (5 primes/tier structure, pool120, k30, boot200k, seed 515151) ties tiers 1-2 and improves tiers 3-8; tier 8 E[acc] improves 0.9866 -> 0.9931 and P(tier8<0.95) drops 1.396% -> 0.205%. The shared 1024-2048 file was built by warm-starting from the public-correct 2048-bit TCN (13 blocks, max_dil 1024) and training jointly at widths 1024 and 2048 with logit-distillation to BOTH dedicated teachers (the 1024-width logits toward the strong dedicated 1024 cell, which transfers its 1024 chain robustness; the 2048-width logits toward the dedicated 2048 cell, which holds the tier-10 primary key) plus a worst-bit margin loss, under a 2048 chain-preservation floor so no tier-10-eroding checkpoint can be saved. This makes one shared cell match both dedicated cells at their own widths without any model-soup. An earlier soup route (0.70 * old weights2048 + 0.30 * a 1024/2048 pilot) held tier 10 but regressed tier 9 under a faithful 5-prime bootstrap (E 0.968, worst-prime 0.80, because the old 2048 cell is only ~0.94 at native 1024), so it was dropped. Faithful gate (diag_5prime_boot, pool 100, seed 991): tier 9 E[acc] 0.9939 / worst-prime 0.933 (matching the dedicated 1024 cell), tier 10 E[acc] 0.9913 / P(acc<0.90) 0.002% / worst-prime 0.933 (primary key held). Public benchmark: overall_accuracy 1.00, tiers 1-10 all 1.00, highest_tier_above_90 = 10, deterministic. Full all-width single-cell unification across 16..2048 was tested and rejected because one ~5M cell could not preserve 2048-chain robustness while serving small/mid widths; the shipped design intentionally keeps two adjacent shared groups. Compliance checks: preprocess hooks are identity, the legal two-operand reductions a%p and b%p are used only for input normalization, perturbing trained weights collapses accuracy toward the untrained floor, and held-out-prime generalization tracks train accuracy."
7
  }
model.py CHANGED
@@ -32,9 +32,9 @@ the same legal input normalisation every other reference model uses.
32
 
33
  Routing: each problem goes to the narrowest cell whose state holds the prime.
34
  A SINGLE shared carry-aware TCN weight-set covers 16/32/64/128/256/512-bit
35
- primes (tiers 1-8), run at each prime's native width; dedicated TCN cells cover
36
- 1024 (tier 9) and 2048 (tier 10). For primes wider than the widest trained cell
37
- it emits the honest ``[0]`` fallback without invoking the network.
38
  """
39
 
40
  from __future__ import annotations
@@ -171,6 +171,11 @@ class TCNHornerCell(nn.Module):
171
  def _build_cell(config: dict):
172
  """Instantiate the cell class named by config['arch'] (default = MLP HornerCell)."""
173
  cfg = dict(config)
 
 
 
 
 
174
  if cfg.get("arch") == "tcn":
175
  cfg.pop("arch", None)
176
  return TCNHornerCell(**cfg)
@@ -233,9 +238,9 @@ class HornerRNN(ModularMultiplicationModel):
233
  md = Path(model_dir)
234
 
235
  # Shared multi-width cells: ONE weight-set serving several adjacent widths
236
- # (config-declared `widths`). The 16-512 carry-aware TCN ships this way β€” the
237
- # same trained weights run at each prime's native width (see TCNHornerCell.forward),
238
- # matching/beating the prior small/mid cells it replaces.
239
  for shared in sorted(md.glob("weights_shared_*.pt")):
240
  ckpt = torch.load(shared, map_location=self.device, weights_only=True)
241
  cell = _build_cell(ckpt.get("config", {}))
 
32
 
33
  Routing: each problem goes to the narrowest cell whose state holds the prime.
34
  A SINGLE shared carry-aware TCN weight-set covers 16/32/64/128/256/512-bit
35
+ primes (tiers 1-8), and a second shared TCN weight-set covers 1024/2048-bit
36
+ primes (tiers 9-10), both run at each prime's native width. For primes wider than
37
+ the widest trained cell it emits the honest ``[0]`` fallback without invoking the network.
38
  """
39
 
40
  from __future__ import annotations
 
171
  def _build_cell(config: dict):
172
  """Instantiate the cell class named by config['arch'] (default = MLP HornerCell)."""
173
  cfg = dict(config)
174
+ # Tolerate non-constructor metadata that shared/training checkpoints may carry:
175
+ # `unified` is a training-only marker and `widths` (the shared-set width list)
176
+ # lives as a top-level checkpoint key, not a cell-constructor argument.
177
+ cfg.pop("unified", None)
178
+ cfg.pop("widths", None)
179
  if cfg.get("arch") == "tcn":
180
  cfg.pop("arch", None)
181
  return TCNHornerCell(**cfg)
 
238
  md = Path(model_dir)
239
 
240
  # Shared multi-width cells: ONE weight-set serving several adjacent widths
241
+ # (config-declared `widths`). The 16-512 and 1024-2048 carry-aware TCNs
242
+ # ship this way β€” the same trained weights run at each prime's native width
243
+ # (see TCNHornerCell.forward), matching/beating the cells they replace.
244
  for shared in sorted(md.glob("weights_shared_*.pt")):
245
  ckpt = torch.load(shared, map_location=self.device, weights_only=True)
246
  cell = _build_cell(ckpt.get("config", {}))
weights2048.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6e8e3e38e1eb284917c48587cb7f1c6858940ed82516fcb98880d1a6c9668969
3
- size 20531981
 
 
 
 
weights1024.pt β†’ weights_shared_1024_2048.pt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b21efe3e2094458cb770d38d433f9ac7a23293d66df8dfff79a57142d99005db
3
- size 18957689
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b9dd97ce05468d7b50ed8a1453bb6f5da18315d341d31e23f96b5283c338691
3
+ size 20533517