cire77 commited on
Commit
9e19660
·
verified ·
1 Parent(s): c13b94d

tier5 route [2^32,2^64) to base-16 modmul decoder (htop90=5)

Browse files
Files changed (3) hide show
  1. manifest.json +2 -2
  2. model.py +157 -11
  3. weights.pt +2 -2
manifest.json CHANGED
@@ -2,6 +2,6 @@
2
  "entry_class": "model.EBMModMul",
3
  "output_base": 10,
4
  "framework": "pytorch",
5
- "model_description": "Two trained networks behind one interface, routed by prime size. Tiers 1-2 (p < 512): a joint-attention Transformer (d_model=256) that reads out the answer residue via a classification head over [0, p_max). Tier 3 (512 <= p < 65536): an autoregressive 'abacus' decoder (d_model=384, 8 layers) that emits an interleaved modular-multiply scratchpad - BOS x MUL y MOD p EQ then per-y-digit fields d:q1:r1:pp:t:q2:r2 - folding multiply and reduction into one Horner pass so no intermediate exceeds ~6 digits; the final remainder digits are the answer. Operands are reduced per-argument (a%p, b%p) before the network runs. Answer emitted as base-10 digits. Emits [0] for p >= 65536 (tiers 4+, out of the trained range).",
6
- "training_description": "Trained from random init on synthetic examples with x,y in [0,p). Tier-1-2 head: cross-entropy / angular loss over enumerable prime pools with a weight-decay grokking regime. Tier-3 scratchpad: every intermediate of the long-multiply-and-reduce computation is supervised (the decisive step was emitting the addition t=r1+pp explicitly), trained over the full tier-3 prime range [512,65536) with cosine-annealed AdamW, LR warmup, grad clipping and bf16. No hand-coded arithmetic: the modular product is produced entirely by trained parameters via greedy digit decoding (no %, //, Barrett, Montgomery or CRT on the product); randomizing weights collapses accuracy."
7
  }
 
2
  "entry_class": "model.EBMModMul",
3
  "output_base": 10,
4
  "framework": "pytorch",
5
+ "model_description": "Two trained network families behind one interface, routed by prime size. Tiers 1-2 (p < 512): a joint-attention Transformer (d_model=256) that reads out the answer residue via a classification head over [0, p_max). Tiers 3-5: autoregressive 'abacus' decoders that emit an interleaved modular-multiply scratchpad - BOS x MUL y MOD p EQ then per-y-digit fields (d:q1:m1:r1:pp:t:q2:m2:r2) - folding multiply and reduction into one Horner pass so no intermediate exceeds the numeric base times p. Tier 3 (512 <= p < 65536) and tier 4 (65536 <= p < 2**32) run in numeric base 10; tier 5 (2**32 <= p < 2**64) runs in numeric base 16 (d_model=512, 10 layers) to keep the Horner chain bounded at large prime sizes. Operands are reduced per-argument (a%p, b%p) before the network runs; the final remainder digits are the answer, converted to base 10 by multiply-add. Emits [0] for p >= 2**64 (tiers 6+, out of the trained range).",
6
+ "training_description": "Trained from random init on synthetic examples with x,y in [0,p). Tier-1-2 head: cross-entropy / angular loss over enumerable prime pools with a weight-decay grokking regime. Tier 3-5 scratchpads: every intermediate of the long-multiply-and-reduce computation is supervised (the decisive step was emitting the addition t=r1+pp and the q*p products explicitly), trained over each tier's prime range with cosine-annealed AdamW, LR warmup, grad clipping, bf16 and a curriculum on prime size; tier 5 uses numeric base 16. No hand-coded arithmetic: the modular product is produced entirely by trained parameters via greedy digit decoding (no %, //, Barrett, Montgomery or CRT on the product); randomizing weights collapses accuracy."
7
  }
model.py CHANGED
@@ -9,11 +9,12 @@ Compliance contract (see rules/evaluation.md):
9
  ``[0, p)``) materially determines the answer.
10
  - We emit the residue as base-10 digits (``output_base = 10``); the harness decodes.
11
 
12
- Routing by prime size: tiers 1-2 (p < 512) use the classification head; tier 3
13
- (512 <= p < 65536) and tier 4 (65536 <= p < 2**32) use the interleaved
14
- modular-multiply scratchpad decoder (same architecture, separately trained
15
- weights). p >= 2**32 (tiers 5+) is out of regime, so we emit ``[0]`` an honest
16
- fallback, not a guess.
 
17
 
18
  The architecture (encoder + classification/angular head) is loaded from the
19
  checkpoint's ``arch`` field, so the same wrapper serves either trained head.
@@ -230,11 +231,15 @@ def _digits_msb(n: int) -> list[int]:
230
 
231
  class AbacusDecoder(nn.Module):
232
  """Decoder-only transformer with abacus (place-within-number) embeddings.
233
- Architecture identical to training/modmul_probe.py for state_dict match."""
 
 
 
234
 
235
- def __init__(self, max_len, abacus_max, d_model=384, nhead=8, num_layers=8, dim_ff=1536):
 
236
  super().__init__()
237
- self.tok_emb = nn.Embedding(MM_VOCAB, d_model)
238
  self.pos_emb = nn.Embedding(max_len, d_model)
239
  self.abacus_emb = nn.Embedding(abacus_max, d_model)
240
  layer = nn.TransformerEncoderLayer(
@@ -243,7 +248,7 @@ class AbacusDecoder(nn.Module):
243
  )
244
  self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
245
  self.ln = nn.LayerNorm(d_model)
246
- self.head = nn.Linear(d_model, MM_VOCAB, bias=False)
247
  self.max_len = max_len
248
  self.register_buffer("pos_ids", torch.arange(max_len), persistent=False)
249
 
@@ -326,6 +331,111 @@ def _modmul_decode(model, cfg, xyp, device, chunk=128):
326
  return [o if o is not None else [0] for o in out]
327
 
328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  # ---------------------------------------------------------------------------
330
  # Submission entry class
331
  # ---------------------------------------------------------------------------
@@ -339,6 +449,8 @@ class EBMModMul(ModularMultiplicationModel):
339
  self.mm_cfg = None
340
  self.mm4 = None # tier-4 modmul scratchpad
341
  self.mm4_cfg = None
 
 
342
 
343
  def load(self, model_dir: str) -> None:
344
  if torch.cuda.is_available():
@@ -375,6 +487,17 @@ class EBMModMul(ModularMultiplicationModel):
375
  ).to(self.device)
376
  self.mm4.load_state_dict(ckpt["tier4"]["state_dict"])
377
  self.mm4.eval()
 
 
 
 
 
 
 
 
 
 
 
378
 
379
  # Per-argument identity preprocessing (each hook sees only its own argument).
380
  def preprocess_a(self, a): return a
@@ -392,6 +515,7 @@ class EBMModMul(ModularMultiplicationModel):
392
  TIER3_LO = 512
393
  TIER3_HI = 65536
394
  TIER4_HI = 2 ** 32
 
395
 
396
  @torch.no_grad()
397
  def predict_digits_batch(self, inputs):
@@ -399,16 +523,22 @@ class EBMModMul(ModularMultiplicationModel):
399
  x_rows, y_rows, p_rows, p_ints, idx = [], [], [], [], [] # tiers 1-2
400
  mm_items, mm_idx = [], [] # tier 3
401
  mm4_items, mm4_idx = [], [] # tier 4
 
402
 
403
  for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
404
  p = int(p_enc)
405
  # Out of regime (residues don't fit the trained range): honest 0.
406
- if p >= self.TIER4_HI:
407
  out[i] = [0]
408
  continue
409
  a_red = int(a_enc) % p # per-operand reduction (allowed)
410
  b_red = int(b_enc) % p
411
- if p >= self.TIER3_HI:
 
 
 
 
 
412
  if self.mm4 is not None:
413
  mm4_items.append((a_red, b_red, p)); mm4_idx.append(i)
414
  else:
@@ -445,6 +575,22 @@ class EBMModMul(ModularMultiplicationModel):
445
  for j, i in enumerate(mm4_idx):
446
  out[i] = res[j]
447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  return [o if o is not None else [0] for o in out]
449
 
450
  def max_batch_size(self) -> int:
 
9
  ``[0, p)``) materially determines the answer.
10
  - We emit the residue as base-10 digits (``output_base = 10``); the harness decodes.
11
 
12
+ Routing by prime size: tiers 1-2 (p < 512) use the classification head; tiers 3-5
13
+ use the interleaved modular-multiply scratchpad decoder (same architecture,
14
+ separately trained weights) — tier 3 (512 <= p < 65536) and tier 4
15
+ (65536 <= p < 2**32) in numeric base 10, tier 5 (2**32 <= p < 2**64) in numeric
16
+ base 16 (shorter Horner chain at large prime sizes). p >= 2**64 (tiers 6+) is out
17
+ of regime, so we emit ``[0]`` — an honest fallback, not a guess.
18
 
19
  The architecture (encoder + classification/angular head) is loaded from the
20
  checkpoint's ``arch`` field, so the same wrapper serves either trained head.
 
231
 
232
  class AbacusDecoder(nn.Module):
233
  """Decoder-only transformer with abacus (place-within-number) embeddings.
234
+ Architecture identical to training/modmul_probe.py / tier5_modmul.py for
235
+ state_dict match. ``vocab`` defaults to the base-10 scratchpad vocab (18) used
236
+ by tiers 3-4; the tier-5 base-16 scratchpad passes vocab=24 (16 digits + 8
237
+ specials)."""
238
 
239
+ def __init__(self, max_len, abacus_max, d_model=384, nhead=8, num_layers=8,
240
+ dim_ff=1536, vocab=MM_VOCAB):
241
  super().__init__()
242
+ self.tok_emb = nn.Embedding(vocab, d_model)
243
  self.pos_emb = nn.Embedding(max_len, d_model)
244
  self.abacus_emb = nn.Embedding(abacus_max, d_model)
245
  layer = nn.TransformerEncoderLayer(
 
248
  )
249
  self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
250
  self.ln = nn.LayerNorm(d_model)
251
+ self.head = nn.Linear(d_model, vocab, bias=False)
252
  self.max_len = max_len
253
  self.register_buffer("pos_ids", torch.arange(max_len), persistent=False)
254
 
 
331
  return [o if o is not None else [0] for o in out]
332
 
333
 
334
+ # ---------------------------------------------------------------------------
335
+ # Tier-5 base-16 modular-multiply scratchpad (autoregressive).
336
+ #
337
+ # Same AbacusDecoder architecture and 9-field Horner scratchpad as tiers 3-4, but
338
+ # trained in numeric BASE 16 (so the per-step partial products / quotient digits
339
+ # stay easy while the chain length is bounded). tier-5 primes are 33-64 bit, so the
340
+ # chain is ~16 base-16 Horner blocks (~1853 tokens). Vocab: digits 0..15 then
341
+ # PAD,BOS,MUL,MOD,EQ,COLON,STEP,EOS = base..base+7 (see tier5_modmul.make_vocab).
342
+ # The decoded answer is a BASE-16 residue; we convert it to base-10 digits with
343
+ # multiply-add only (no %, //, Barrett/Montgomery/CRT on the product) so it matches
344
+ # the global output_base=10. Compliance is unchanged: the only modular reduction in
345
+ # shipped code is the per-operand int(a)%p / int(b)%p done before the network runs.
346
+ # ---------------------------------------------------------------------------
347
+
348
+
349
+ def _make_vocab_base(base: int) -> dict:
350
+ """Base-B scratchpad vocab, matching training/tier5_modmul.make_vocab."""
351
+ PAD, BOS, MUL, MOD, EQ, COLON, STEP, EOS = (
352
+ base, base + 1, base + 2, base + 3, base + 4, base + 5, base + 6, base + 7)
353
+ return dict(PAD=PAD, BOS=BOS, MUL=MUL, MOD=MOD, EQ=EQ, COLON=COLON, STEP=STEP,
354
+ EOS=EOS, VOCAB=base + 8,
355
+ SPECIALS={PAD, BOS, MUL, MOD, EQ, COLON, STEP, EOS})
356
+
357
+
358
+ def _digits_base_msb(n: int, base: int) -> list[int]:
359
+ if n == 0:
360
+ return [0]
361
+ s = []
362
+ while n > 0:
363
+ s.append(n % base)
364
+ n //= base
365
+ return s[::-1]
366
+
367
+
368
+ def _base_to_int(ds: list[int], base: int) -> int:
369
+ v = 0
370
+ for d in ds:
371
+ v = v * base + d # multiply-add only; no %/// on the product
372
+ return v
373
+
374
+
375
+ @torch.no_grad()
376
+ def _modmul_decode_base(model, cfg, xyp, device, base, chunk=64):
377
+ """Greedy-decode (x*y) mod p in numeric base ``base`` for each (x, y, p) with
378
+ x, y already in [0, p). Returns base-10 digit-lists (MSB-first), or [0] if
379
+ unparseable. Mirrors _modmul_decode but base-parametrized; the final base-``base``
380
+ residue is re-expressed in base 10 via multiply-add (compliant)."""
381
+ V = _make_vocab_base(base)
382
+ PAD, EOS, COLON = V["PAD"], V["EOS"], V["COLON"]
383
+ max_len, abmax = cfg["max_len"], cfg["abacus_max"]
384
+ specials = torch.tensor(sorted(V["SPECIALS"]), device=device)
385
+ out: list[list[int] | None] = [None] * len(xyp)
386
+
387
+ groups = defaultdict(list)
388
+ prompts = []
389
+ for i, (x, y, p) in enumerate(xyp):
390
+ xd, yd, pd = (_digits_base_msb(x, base), _digits_base_msb(y, base),
391
+ _digits_base_msb(p, base))
392
+ toks = [V["BOS"]] + xd + [V["MUL"]] + yd + [V["MOD"]] + pd + [V["EQ"]]
393
+ abac = ([0] + list(range(len(xd))) + [0] + list(range(len(yd)))
394
+ + [0] + list(range(len(pd))) + [0])
395
+ groups[len(toks)].append(i)
396
+ prompts.append((toks, abac))
397
+
398
+ for L, idxs in groups.items():
399
+ for s in range(0, len(idxs), chunk):
400
+ sub = idxs[s:s + chunk]
401
+ g = len(sub)
402
+ toks = torch.tensor([prompts[i][0] for i in sub], dtype=torch.long, device=device)
403
+ abac = torch.tensor([prompts[i][1] for i in sub], dtype=torch.long, device=device)
404
+ seg = torch.zeros(g, dtype=torch.long, device=device)
405
+ done = torch.zeros(g, dtype=torch.bool, device=device)
406
+ gen = [[] for _ in range(g)]
407
+ while toks.shape[1] < max_len and not bool(done.all()):
408
+ nxt = model(toks, abac)[:, -1].argmax(-1)
409
+ nxt = torch.where(done, torch.full_like(nxt, PAD), nxt)
410
+ is_sp = (nxt.unsqueeze(1) == specials).any(1)
411
+ new_abac = torch.where(is_sp, torch.zeros_like(seg),
412
+ torch.clamp(seg, max=abmax - 1))
413
+ seg = torch.where(is_sp, torch.zeros_like(seg), seg + 1)
414
+ nc, dc = nxt.tolist(), done.tolist()
415
+ for j in range(g):
416
+ if not dc[j] and nc[j] != EOS and nc[j] != PAD:
417
+ gen[j].append(nc[j])
418
+ toks = torch.cat([toks, nxt.unsqueeze(1)], dim=1)
419
+ abac = torch.cat([abac, new_abac.unsqueeze(1)], dim=1)
420
+ done = done | (nxt == EOS)
421
+ for j, i in enumerate(sub):
422
+ gj = gen[j]
423
+ if COLON in gj:
424
+ k = len(gj) - 1 - gj[::-1].index(COLON)
425
+ ans = [d for d in gj[k + 1:] if d < base]
426
+ out[i] = int_to_decimal_digits(_base_to_int(ans, base)) if ans else [0]
427
+ else:
428
+ out[i] = [0]
429
+ del toks, abac, seg, done, gen
430
+ # Only MPS needs the cache drop (it never frees mid-run and OOMs on
431
+ # long chains). On CUDA, empty_cache() synchronizes + forces the
432
+ # allocator to re-acquire buffers every chunk -- a ~5x slowdown over
433
+ # the ~1853-step base-16 decode -- and tier-5 peak is <3 GB anyway.
434
+ if device.type == "mps":
435
+ torch.mps.empty_cache()
436
+ return [o if o is not None else [0] for o in out]
437
+
438
+
439
  # ---------------------------------------------------------------------------
440
  # Submission entry class
441
  # ---------------------------------------------------------------------------
 
449
  self.mm_cfg = None
450
  self.mm4 = None # tier-4 modmul scratchpad
451
  self.mm4_cfg = None
452
+ self.mm5 = None # tier-5 base-16 modmul scratchpad
453
+ self.mm5_cfg = None
454
 
455
  def load(self, model_dir: str) -> None:
456
  if torch.cuda.is_available():
 
487
  ).to(self.device)
488
  self.mm4.load_state_dict(ckpt["tier4"]["state_dict"])
489
  self.mm4.eval()
490
+ # Tier 5: base-16 scratchpad, trained on primes in [2**33, 2**64).
491
+ if "tier5" in ckpt:
492
+ c5 = ckpt["tier5"]["config"]
493
+ self.mm5_cfg = c5
494
+ self.mm5 = AbacusDecoder(
495
+ max_len=c5["max_len"], abacus_max=c5["abacus_max"], d_model=c5["d_model"],
496
+ nhead=c5["nhead"], num_layers=c5["layers"], dim_ff=c5["dim_ff"],
497
+ vocab=c5["base"] + 8,
498
+ ).to(self.device)
499
+ self.mm5.load_state_dict(ckpt["tier5"]["state_dict"])
500
+ self.mm5.eval()
501
 
502
  # Per-argument identity preprocessing (each hook sees only its own argument).
503
  def preprocess_a(self, a): return a
 
515
  TIER3_LO = 512
516
  TIER3_HI = 65536
517
  TIER4_HI = 2 ** 32
518
+ TIER5_HI = 2 ** 64
519
 
520
  @torch.no_grad()
521
  def predict_digits_batch(self, inputs):
 
523
  x_rows, y_rows, p_rows, p_ints, idx = [], [], [], [], [] # tiers 1-2
524
  mm_items, mm_idx = [], [] # tier 3
525
  mm4_items, mm4_idx = [], [] # tier 4
526
+ mm5_items, mm5_idx = [], [] # tier 5
527
 
528
  for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
529
  p = int(p_enc)
530
  # Out of regime (residues don't fit the trained range): honest 0.
531
+ if p >= self.TIER5_HI:
532
  out[i] = [0]
533
  continue
534
  a_red = int(a_enc) % p # per-operand reduction (allowed)
535
  b_red = int(b_enc) % p
536
+ if p >= self.TIER4_HI:
537
+ if self.mm5 is not None:
538
+ mm5_items.append((a_red, b_red, p)); mm5_idx.append(i)
539
+ else:
540
+ out[i] = [0]
541
+ elif p >= self.TIER3_HI:
542
  if self.mm4 is not None:
543
  mm4_items.append((a_red, b_red, p)); mm4_idx.append(i)
544
  else:
 
575
  for j, i in enumerate(mm4_idx):
576
  out[i] = res[j]
577
 
578
+ if mm5_items:
579
+ # Tier-5 base-16 chains are ~1853 tokens. The model was trained under
580
+ # bf16 autocast, and decoding in bf16 (not fp32) is what keeps the long
581
+ # attention both fast (~125s/100 vs ~470-660s in fp32) and within memory
582
+ # (<3 GB vs a fp32 OOM at 1853-length attention). Match training precision.
583
+ if self.device.type == "cuda":
584
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
585
+ res = _modmul_decode_base(self.mm5, self.mm5_cfg, mm5_items,
586
+ self.device, base=self.mm5_cfg["base"],
587
+ chunk=64)
588
+ else:
589
+ res = _modmul_decode_base(self.mm5, self.mm5_cfg, mm5_items,
590
+ self.device, base=self.mm5_cfg["base"], chunk=64)
591
+ for j, i in enumerate(mm5_idx):
592
+ out[i] = res[j]
593
+
594
  return [o if o is not None else [0] for o in out]
595
 
596
  def max_batch_size(self) -> int:
weights.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6d68b8f3ce76c7421b381a5c2465218d306323231386600f6ed643b889225ca6
3
- size 141288197
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:165d028f8af0a9efb45b755a20e855701b3e6594bcbb69b63cdc356544db8303
3
+ size 271364937