cire77 commited on
Commit
6e7450f
·
verified ·
1 Parent(s): 9e19660

tiers 5-7 recurrent reduction cell (htop90=7)

Browse files
Files changed (3) hide show
  1. manifest.json +2 -2
  2. model.py +142 -6
  3. weights.pt +2 -2
manifest.json CHANGED
@@ -2,6 +2,6 @@
2
  "entry_class": "model.EBMModMul",
3
  "output_base": 10,
4
  "framework": "pytorch",
5
- "model_description": "Two trained network families behind one interface, routed by prime size. Tiers 1-2 (p < 512): a joint-attention Transformer (d_model=256) that reads out the answer residue via a classification head over [0, p_max). Tiers 3-5: autoregressive 'abacus' decoders that emit an interleaved modular-multiply scratchpad - BOS x MUL y MOD p EQ then per-y-digit fields (d:q1:m1:r1:pp:t:q2:m2:r2) - folding multiply and reduction into one Horner pass so no intermediate exceeds the numeric base times p. Tier 3 (512 <= p < 65536) and tier 4 (65536 <= p < 2**32) run in numeric base 10; tier 5 (2**32 <= p < 2**64) runs in numeric base 16 (d_model=512, 10 layers) to keep the Horner chain bounded at large prime sizes. Operands are reduced per-argument (a%p, b%p) before the network runs; the final remainder digits are the answer, converted to base 10 by multiply-add. Emits [0] for p >= 2**64 (tiers 6+, out of the trained range).",
6
- "training_description": "Trained from random init on synthetic examples with x,y in [0,p). Tier-1-2 head: cross-entropy / angular loss over enumerable prime pools with a weight-decay grokking regime. Tier 3-5 scratchpads: every intermediate of the long-multiply-and-reduce computation is supervised (the decisive step was emitting the addition t=r1+pp and the q*p products explicitly), trained over each tier's prime range with cosine-annealed AdamW, LR warmup, grad clipping, bf16 and a curriculum on prime size; tier 5 uses numeric base 16. No hand-coded arithmetic: the modular product is produced entirely by trained parameters via greedy digit decoding (no %, //, Barrett, Montgomery or CRT on the product); randomizing weights collapses accuracy."
7
  }
 
2
  "entry_class": "model.EBMModMul",
3
  "output_base": 10,
4
  "framework": "pytorch",
5
+ "model_description": "Two trained network families behind one interface, routed by prime size. Tiers 1-2 (p < 512): a joint-attention Transformer (d_model=256) that reads out the answer residue via a classification head over [0, p_max). Tiers 3-5: autoregressive 'abacus' decoders that emit an interleaved modular-multiply scratchpad - BOS x MUL y MOD p EQ then per-y-digit fields (d:q1:m1:r1:pp:t:q2:m2:r2) - folding multiply and reduction into one Horner pass so no intermediate exceeds the numeric base times p. Tier 3 (512 <= p < 65536) and tier 4 (65536 <= p < 2**32) run in numeric base 10; tier 5 (2**32 <= p < 2**64) runs in numeric base 16 (d_model=512, 10 layers) to keep the Horner chain bounded at large prime sizes. Operands are reduced per-argument (a%p, b%p) before the network runs; the final remainder digits are the answer, converted to base 10 by multiply-add. Tiers 5-7 (2**32 <= p < 2**256): a shared, weight-tied recurrent reduction cell (a bidirectional GRU over base-2 limbs) that learns the single bounded Horner step s' = (2*s + d*x) mod p and unrolls it over the bits of b (each operand reduced per-argument first) inside its own forward pass; the same cell is applied at every step and every bit-width, so it generalizes across tiers without re-learning each chain (it subsumes the tier-5 scratchpad, falling back to it only if unbundled). For p >= 2**256 (tiers 8-10) the chain would exceed the time budget, so the model emits [0].",
6
+ "training_description": "Trained from random init on synthetic examples with x,y in [0,p). Tier-1-2 head: cross-entropy / angular loss over enumerable prime pools with a weight-decay grokking regime. Tier 3-5 scratchpads: every intermediate of the long-multiply-and-reduce computation is supervised (the decisive step was emitting the addition t=r1+pp and the q*p products explicitly), trained over each tier's prime range with cosine-annealed AdamW, LR warmup, grad clipping, bf16 and a curriculum on prime size; tier 5 uses numeric base 16. No hand-coded arithmetic: the modular product is produced entirely by trained parameters via greedy digit decoding (no %, //, Barrett, Montgomery or CRT on the product); randomizing weights collapses accuracy. Tiers 6-10 recurrent cell: trained from random init on uniformly sampled one-step Horner transitions (s, x ~ U[0,p), digit d ~ U[0,B)) across a spread of bit-lengths -- teaching the COMPLETE transition function (not trajectories), so the free-running unroll has no distribution shift -- with cosine-annealed AdamW, warmup, grad clipping, bf16, and an auxiliary quotient head; the reduction is learned (randomizing the cell's weights collapses accuracy). Acknowledgment: bit-serial interleaved modular reduction is classical prior art, and the recurrent learned-reduction framing follows the algorithmic-execution literature and concurrent competitor work (neural-horner); independently implemented here, with no novelty claimed for the mechanism."
7
  }
model.py CHANGED
@@ -436,6 +436,106 @@ def _modmul_decode_base(model, cfg, xyp, device, base, chunk=64):
436
  return [o if o is not None else [0] for o in out]
437
 
438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  # ---------------------------------------------------------------------------
440
  # Submission entry class
441
  # ---------------------------------------------------------------------------
@@ -451,6 +551,8 @@ class EBMModMul(ModularMultiplicationModel):
451
  self.mm4_cfg = None
452
  self.mm5 = None # tier-5 base-16 modmul scratchpad
453
  self.mm5_cfg = None
 
 
454
 
455
  def load(self, model_dir: str) -> None:
456
  if torch.cuda.is_available():
@@ -498,6 +600,16 @@ class EBMModMul(ModularMultiplicationModel):
498
  ).to(self.device)
499
  self.mm5.load_state_dict(ckpt["tier5"]["state_dict"])
500
  self.mm5.eval()
 
 
 
 
 
 
 
 
 
 
501
 
502
  # Per-argument identity preprocessing (each hook sees only its own argument).
503
  def preprocess_a(self, a): return a
@@ -516,6 +628,11 @@ class EBMModMul(ModularMultiplicationModel):
516
  TIER3_HI = 65536
517
  TIER4_HI = 2 ** 32
518
  TIER5_HI = 2 ** 64
 
 
 
 
 
519
 
520
  @torch.no_grad()
521
  def predict_digits_batch(self, inputs):
@@ -524,21 +641,27 @@ class EBMModMul(ModularMultiplicationModel):
524
  mm_items, mm_idx = [], [] # tier 3
525
  mm4_items, mm4_idx = [], [] # tier 4
526
  mm5_items, mm5_idx = [], [] # tier 5
 
527
 
528
  for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
529
  p = int(p_enc)
530
- # Out of regime (residues don't fit the trained range): honest 0.
531
- if p >= self.TIER5_HI:
532
- out[i] = [0]
533
- continue
534
  a_red = int(a_enc) % p # per-operand reduction (allowed)
535
  b_red = int(b_enc) % p
536
  if p >= self.TIER4_HI:
537
- if self.mm5 is not None:
 
 
 
 
 
 
 
 
538
  mm5_items.append((a_red, b_red, p)); mm5_idx.append(i)
539
  else:
540
  out[i] = [0]
541
- elif p >= self.TIER3_HI:
 
542
  if self.mm4 is not None:
543
  mm4_items.append((a_red, b_red, p)); mm4_idx.append(i)
544
  else:
@@ -591,6 +714,19 @@ class EBMModMul(ModularMultiplicationModel):
591
  for j, i in enumerate(mm5_idx):
592
  out[i] = res[j]
593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594
  return [o if o is not None else [0] for o in out]
595
 
596
  def max_batch_size(self) -> int:
 
436
  return [o if o is not None else [0] for o in out]
437
 
438
 
439
+ # ---------------------------------------------------------------------------
440
+ # Tier-6+ recurrent reduction cell (shared, weight-tied; tiers 6-10).
441
+ #
442
+ # A single learned cell computes ONE bounded digit-serial Horner step
443
+ # s_{t+1} = (s_t * B + d_t * x) mod p (x = a mod p; d_t = base-B digits of b)
444
+ # and forward() unrolls it over b's digits INSIDE the forward pass. Every s_t < p
445
+ # (bounded state), and the cell is shared across all steps and all bit-widths, so it
446
+ # length-generalizes from short training chains to tiers 6-10. The reduction is produced
447
+ # entirely by trained parameters (randomizing weights collapses accuracy); the only
448
+ # arithmetic in shipped code is the per-operand a%p / b%p reduction done BEFORE the cell
449
+ # runs (same as the reference baselines) and the final base-B -> base-10 multiply-add.
450
+ # Architecture copied verbatim from training/tier6_recurrent.py for state_dict match.
451
+ # ---------------------------------------------------------------------------
452
+
453
+
454
+ def _to_limbs(n: int, base: int, K: int) -> list[int]:
455
+ """Non-negative int -> K base-B limbs, LSB-first (zero-padded high)."""
456
+ out = [0] * K
457
+ i = 0
458
+ while n > 0 and i < K:
459
+ out[i] = n % base
460
+ n //= base
461
+ i += 1
462
+ return out
463
+
464
+
465
+ def _from_limbs(limbs: list[int], base: int) -> int:
466
+ v = 0
467
+ for d in reversed(limbs):
468
+ v = v * base + int(d) # multiply-add only; no %/// on the product
469
+ return v
470
+
471
+
472
+ def _digits_msb_base(n: int, base: int) -> list[int]:
473
+ if n == 0:
474
+ return [0]
475
+ s = []
476
+ while n > 0:
477
+ s.append(n % base)
478
+ n //= base
479
+ return s[::-1]
480
+
481
+
482
+ class RecurrentReducer(nn.Module):
483
+ def __init__(self, base, d_model=256, gru_layers=2, aux_quotient=True, q_max=None):
484
+ super().__init__()
485
+ self.base = base
486
+ self.aux_quotient = aux_quotient
487
+ self.q_max = q_max if q_max is not None else 2 * base
488
+ self.E_s = nn.Embedding(base, d_model)
489
+ self.E_x = nn.Embedding(base, d_model)
490
+ self.E_p = nn.Embedding(base, d_model)
491
+ self.E_d = nn.Embedding(base, d_model)
492
+ self.gru = nn.GRU(d_model, d_model, num_layers=gru_layers,
493
+ batch_first=True, bidirectional=True)
494
+ self.ln = nn.LayerNorm(2 * d_model)
495
+ self.head = nn.Linear(2 * d_model, base)
496
+ if aux_quotient:
497
+ self.qhead = nn.Linear(2 * d_model, self.q_max)
498
+
499
+ def _encode(self, s, x, p, d):
500
+ h = self.E_s(s) + self.E_x(x) + self.E_p(p) + self.E_d(d).unsqueeze(1)
501
+ out, _ = self.gru(h)
502
+ return self.ln(out)
503
+
504
+ def step_logits(self, s, x, p, d):
505
+ return self.head(self._encode(s, x, p, d))
506
+
507
+ @torch.no_grad()
508
+ def forward(self, x, b_digits, p):
509
+ s = torch.zeros_like(x)
510
+ for t in range(b_digits.shape[1]):
511
+ s = self.step_logits(s, x, p, b_digits[:, t]).argmax(-1)
512
+ return s
513
+
514
+
515
+ @torch.no_grad()
516
+ def _recurrent_decode(model, base, items, device, chunk=64):
517
+ """Free-running (x*b_red) mod p == (a*b) mod p for each (x, b_red, p) with
518
+ x = a%p and b_red = b%p already in [0, p). Returns base-10 digit-lists (MSB-first)."""
519
+ out = [[0]] * len(items)
520
+ if not items:
521
+ return out
522
+ Kp = max(len(_digits_msb_base(p, base)) for _, _, p in items) + 1
523
+ Lb = max(len(_digits_msb_base(b, base)) for _, b, _ in items)
524
+ for s0 in range(0, len(items), chunk):
525
+ sub = items[s0:s0 + chunk]
526
+ X = torch.tensor([_to_limbs(x, base, Kp) for x, _, _ in sub],
527
+ dtype=torch.long, device=device)
528
+ P = torch.tensor([_to_limbs(p, base, Kp) for _, _, p in sub],
529
+ dtype=torch.long, device=device)
530
+ Bd = torch.tensor([[0] * (Lb - len(_digits_msb_base(b, base)))
531
+ + _digits_msb_base(b, base) for _, b, _ in sub],
532
+ dtype=torch.long, device=device)
533
+ s = model(X, Bd, P)
534
+ for j in range(len(sub)):
535
+ out[s0 + j] = int_to_decimal_digits(_from_limbs(s[j].tolist(), base))
536
+ return out
537
+
538
+
539
  # ---------------------------------------------------------------------------
540
  # Submission entry class
541
  # ---------------------------------------------------------------------------
 
551
  self.mm4_cfg = None
552
  self.mm5 = None # tier-5 base-16 modmul scratchpad
553
  self.mm5_cfg = None
554
+ self.mm6 = None # tier-6+ recurrent reduction cell
555
+ self.mm6_cfg = None
556
 
557
  def load(self, model_dir: str) -> None:
558
  if torch.cuda.is_available():
 
600
  ).to(self.device)
601
  self.mm5.load_state_dict(ckpt["tier5"]["state_dict"])
602
  self.mm5.eval()
603
+ # Tiers 6-10: the shared recurrent reduction cell (length-generalizes).
604
+ if "tier6" in ckpt:
605
+ c6 = ckpt["tier6"]["config"]
606
+ self.mm6_cfg = c6
607
+ self.mm6 = RecurrentReducer(
608
+ c6["base"], d_model=c6["d_model"], gru_layers=c6["gru_layers"],
609
+ aux_quotient=c6.get("aux_quotient", True),
610
+ ).to(self.device)
611
+ self.mm6.load_state_dict(ckpt["tier6"]["state_dict"])
612
+ self.mm6.eval()
613
 
614
  # Per-argument identity preprocessing (each hook sees only its own argument).
615
  def preprocess_a(self, a): return a
 
628
  TIER3_HI = 65536
629
  TIER4_HI = 2 ** 32
630
  TIER5_HI = 2 ** 64
631
+ # The recurrent cell handles tiers 6-7 (p < 2^256) accurately AND within budget.
632
+ # For p >= 2^256 (tiers 8-10) and tier-0's huge-p multiplications, the Horner chain
633
+ # would be thousands of steps and eat the whole 300s budget (starving later tiers),
634
+ # so we emit a fast [0] instead. This cap is what keeps the full 1100-set under 300s.
635
+ TIER7_HI = 2 ** 256
636
 
637
  @torch.no_grad()
638
  def predict_digits_batch(self, inputs):
 
641
  mm_items, mm_idx = [], [] # tier 3
642
  mm4_items, mm4_idx = [], [] # tier 4
643
  mm5_items, mm5_idx = [], [] # tier 5
644
+ mm6_items, mm6_idx = [], [] # tiers 6-10
645
 
646
  for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
647
  p = int(p_enc)
 
 
 
 
648
  a_red = int(a_enc) % p # per-operand reduction (allowed)
649
  b_red = int(b_enc) % p
650
  if p >= self.TIER4_HI:
651
+ # Tiers 5-7: the recurrent reduction cell (p in [2^32, 2^256)), CAPPED at
652
+ # p < 2^256 so the Horner chain over b_red stays bounded (<=256 steps);
653
+ # beyond that emit a fast [0] (budget protection -- see TIER7_HI). The cell
654
+ # subsumes the old base-16 tier-5 scratchpad (a ~0.6 coin flip) at ~1.0 and
655
+ # adds tiers 6-7. Falls back to the base-16 scratchpad for tier 5 only if
656
+ # the cell isn't bundled.
657
+ if self.mm6 is not None and p < self.TIER7_HI:
658
+ mm6_items.append((a_red, b_red, p)); mm6_idx.append(i)
659
+ elif p < self.TIER5_HI and self.mm5 is not None:
660
  mm5_items.append((a_red, b_red, p)); mm5_idx.append(i)
661
  else:
662
  out[i] = [0]
663
+ continue
664
+ if p >= self.TIER3_HI:
665
  if self.mm4 is not None:
666
  mm4_items.append((a_red, b_red, p)); mm4_idx.append(i)
667
  else:
 
714
  for j, i in enumerate(mm5_idx):
715
  out[i] = res[j]
716
 
717
+ if mm6_items:
718
+ # Tiers 6-10: free-running recurrent unroll, batched. bf16 on CUDA to match
719
+ # training precision and bound the long-chain memory/throughput.
720
+ if self.device.type == "cuda":
721
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
722
+ res = _recurrent_decode(self.mm6, self.mm6_cfg["base"], mm6_items,
723
+ self.device)
724
+ else:
725
+ res = _recurrent_decode(self.mm6, self.mm6_cfg["base"], mm6_items,
726
+ self.device)
727
+ for j, i in enumerate(mm6_idx):
728
+ out[i] = res[j]
729
+
730
  return [o if o is not None else [0] for o in out]
731
 
732
  def max_batch_size(self) -> int:
weights.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:165d028f8af0a9efb45b755a20e855701b3e6594bcbb69b63cdc356544db8303
3
- size 271364937
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fcc6145bbfe7c0b63305fa5915ae5e65118c610bc69e18e51bcb13c03185f736
3
+ size 149206895