tiers 5-7 recurrent reduction cell (htop90=7)
Browse files- manifest.json +2 -2
- model.py +142 -6
- weights.pt +2 -2
manifest.json
CHANGED
|
@@ -2,6 +2,6 @@
|
|
| 2 |
"entry_class": "model.EBMModMul",
|
| 3 |
"output_base": 10,
|
| 4 |
"framework": "pytorch",
|
| 5 |
-
"model_description": "Two trained network families behind one interface, routed by prime size. Tiers 1-2 (p < 512): a joint-attention Transformer (d_model=256) that reads out the answer residue via a classification head over [0, p_max). Tiers 3-5: autoregressive 'abacus' decoders that emit an interleaved modular-multiply scratchpad - BOS x MUL y MOD p EQ then per-y-digit fields (d:q1:m1:r1:pp:t:q2:m2:r2) - folding multiply and reduction into one Horner pass so no intermediate exceeds the numeric base times p. Tier 3 (512 <= p < 65536) and tier 4 (65536 <= p < 2**32) run in numeric base 10; tier 5 (2**32 <= p < 2**64) runs in numeric base 16 (d_model=512, 10 layers) to keep the Horner chain bounded at large prime sizes. Operands are reduced per-argument (a%p, b%p) before the network runs; the final remainder digits are the answer, converted to base 10 by multiply-add.
|
| 6 |
-
"training_description": "Trained from random init on synthetic examples with x,y in [0,p). Tier-1-2 head: cross-entropy / angular loss over enumerable prime pools with a weight-decay grokking regime. Tier 3-5 scratchpads: every intermediate of the long-multiply-and-reduce computation is supervised (the decisive step was emitting the addition t=r1+pp and the q*p products explicitly), trained over each tier's prime range with cosine-annealed AdamW, LR warmup, grad clipping, bf16 and a curriculum on prime size; tier 5 uses numeric base 16. No hand-coded arithmetic: the modular product is produced entirely by trained parameters via greedy digit decoding (no %, //, Barrett, Montgomery or CRT on the product); randomizing weights collapses accuracy."
|
| 7 |
}
|
|
|
|
| 2 |
"entry_class": "model.EBMModMul",
|
| 3 |
"output_base": 10,
|
| 4 |
"framework": "pytorch",
|
| 5 |
+
"model_description": "Three trained network families behind one interface, routed by prime size. Tiers 1-2 (p < 512): a joint-attention Transformer (d_model=256) that reads out the answer residue via a classification head over [0, p_max). Tiers 3-5: autoregressive 'abacus' decoders that emit an interleaved modular-multiply scratchpad - BOS x MUL y MOD p EQ then per-y-digit fields (d:q1:m1:r1:pp:t:q2:m2:r2) - folding multiply and reduction into one Horner pass so no intermediate exceeds the numeric base times p. Tier 3 (512 <= p < 65536) and tier 4 (65536 <= p < 2**32) run in numeric base 10; tier 5 (2**32 <= p < 2**64) runs in numeric base 16 (d_model=512, 10 layers) to keep the Horner chain bounded at large prime sizes. Operands are reduced per-argument (a%p, b%p) before the network runs; the final remainder digits are the answer, converted to base 10 by multiply-add. Tiers 5-7 (2**32 <= p < 2**256): a shared, weight-tied recurrent reduction cell (a bidirectional GRU over base-2 limbs) that learns the single bounded Horner step s' = (2*s + d*x) mod p and unrolls it over the bits of b (each operand reduced per-argument first) inside its own forward pass; the same cell is applied at every step and every bit-width, so it generalizes across tiers without re-learning each chain (it subsumes the tier-5 scratchpad, falling back to it only if unbundled). For p >= 2**256 (tiers 8-10) the chain would exceed the time budget, so the model emits [0].",
|
| 6 |
+
"training_description": "Trained from random init on synthetic examples with x,y in [0,p). Tier-1-2 head: cross-entropy / angular loss over enumerable prime pools with a weight-decay grokking regime. Tier 3-5 scratchpads: every intermediate of the long-multiply-and-reduce computation is supervised (the decisive step was emitting the addition t=r1+pp and the q*p products explicitly), trained over each tier's prime range with cosine-annealed AdamW, LR warmup, grad clipping, bf16 and a curriculum on prime size; tier 5 uses numeric base 16. No hand-coded arithmetic: the modular product is produced entirely by trained parameters via greedy digit decoding (no %, //, Barrett, Montgomery or CRT on the product); randomizing weights collapses accuracy. Tiers 5-7 recurrent cell: trained from random init on uniformly sampled one-step Horner transitions (s, x ~ U[0,p), digit d ~ U[0,B)) across a spread of bit-lengths -- teaching the COMPLETE transition function (not trajectories), so the free-running unroll has no distribution shift -- with cosine-annealed AdamW, warmup, grad clipping, bf16, and an auxiliary quotient head; the reduction is learned (randomizing the cell's weights collapses accuracy). Acknowledgment: bit-serial interleaved modular reduction is classical prior art, and the recurrent learned-reduction framing follows the algorithmic-execution literature and concurrent competitor work (neural-horner); independently implemented here, with no novelty claimed for the mechanism."
|
| 7 |
}
|
model.py
CHANGED
|
@@ -436,6 +436,106 @@ def _modmul_decode_base(model, cfg, xyp, device, base, chunk=64):
|
|
| 436 |
return [o if o is not None else [0] for o in out]
|
| 437 |
|
| 438 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
# ---------------------------------------------------------------------------
|
| 440 |
# Submission entry class
|
| 441 |
# ---------------------------------------------------------------------------
|
|
@@ -451,6 +551,8 @@ class EBMModMul(ModularMultiplicationModel):
|
|
| 451 |
self.mm4_cfg = None
|
| 452 |
self.mm5 = None # tier-5 base-16 modmul scratchpad
|
| 453 |
self.mm5_cfg = None
|
|
|
|
|
|
|
| 454 |
|
| 455 |
def load(self, model_dir: str) -> None:
|
| 456 |
if torch.cuda.is_available():
|
|
@@ -498,6 +600,16 @@ class EBMModMul(ModularMultiplicationModel):
|
|
| 498 |
).to(self.device)
|
| 499 |
self.mm5.load_state_dict(ckpt["tier5"]["state_dict"])
|
| 500 |
self.mm5.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
|
| 502 |
# Per-argument identity preprocessing (each hook sees only its own argument).
|
| 503 |
def preprocess_a(self, a): return a
|
|
@@ -516,6 +628,11 @@ class EBMModMul(ModularMultiplicationModel):
|
|
| 516 |
TIER3_HI = 65536
|
| 517 |
TIER4_HI = 2 ** 32
|
| 518 |
TIER5_HI = 2 ** 64
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
|
| 520 |
@torch.no_grad()
|
| 521 |
def predict_digits_batch(self, inputs):
|
|
@@ -524,21 +641,27 @@ class EBMModMul(ModularMultiplicationModel):
|
|
| 524 |
mm_items, mm_idx = [], [] # tier 3
|
| 525 |
mm4_items, mm4_idx = [], [] # tier 4
|
| 526 |
mm5_items, mm5_idx = [], [] # tier 5
|
|
|
|
| 527 |
|
| 528 |
for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
|
| 529 |
p = int(p_enc)
|
| 530 |
-
# Out of regime (residues don't fit the trained range): honest 0.
|
| 531 |
-
if p >= self.TIER5_HI:
|
| 532 |
-
out[i] = [0]
|
| 533 |
-
continue
|
| 534 |
a_red = int(a_enc) % p # per-operand reduction (allowed)
|
| 535 |
b_red = int(b_enc) % p
|
| 536 |
if p >= self.TIER4_HI:
|
| 537 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
mm5_items.append((a_red, b_red, p)); mm5_idx.append(i)
|
| 539 |
else:
|
| 540 |
out[i] = [0]
|
| 541 |
-
|
|
|
|
| 542 |
if self.mm4 is not None:
|
| 543 |
mm4_items.append((a_red, b_red, p)); mm4_idx.append(i)
|
| 544 |
else:
|
|
@@ -591,6 +714,19 @@ class EBMModMul(ModularMultiplicationModel):
|
|
| 591 |
for j, i in enumerate(mm5_idx):
|
| 592 |
out[i] = res[j]
|
| 593 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
return [o if o is not None else [0] for o in out]
|
| 595 |
|
| 596 |
def max_batch_size(self) -> int:
|
|
|
|
| 436 |
return [o if o is not None else [0] for o in out]
|
| 437 |
|
| 438 |
|
| 439 |
+
# ---------------------------------------------------------------------------
|
| 440 |
+
# Recurrent reduction cell (shared, weight-tied; routed for tiers 5-7, 2**32 <= p < 2**256).
|
| 441 |
+
#
|
| 442 |
+
# A single learned cell computes ONE bounded digit-serial Horner step
|
| 443 |
+
# s_{t+1} = (s_t * B + d_t * x) mod p (x = a mod p; d_t = base-B digits of b)
|
| 444 |
+
# and forward() unrolls it over b's digits INSIDE the forward pass. Every s_t < p
|
| 445 |
+
# (bounded state), and the cell is shared across all steps and all bit-widths, so it
|
| 446 |
+
# length-generalizes from short training chains to tiers 6-10. The reduction is produced
|
| 447 |
+
# entirely by trained parameters (randomizing weights collapses accuracy); the only
|
| 448 |
+
# arithmetic in shipped code is the per-operand a%p / b%p reduction done BEFORE the cell
|
| 449 |
+
# runs (same as the reference baselines) and the final base-B -> base-10 multiply-add.
|
| 450 |
+
# Architecture copied verbatim from training/tier6_recurrent.py for state_dict match.
|
| 451 |
+
# ---------------------------------------------------------------------------
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def _to_limbs(n: int, base: int, K: int) -> list[int]:
|
| 455 |
+
"""Non-negative int -> K base-B limbs, LSB-first (zero-padded high)."""
|
| 456 |
+
out = [0] * K
|
| 457 |
+
i = 0
|
| 458 |
+
while n > 0 and i < K:
|
| 459 |
+
out[i] = n % base
|
| 460 |
+
n //= base
|
| 461 |
+
i += 1
|
| 462 |
+
return out
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
def _from_limbs(limbs: list[int], base: int) -> int:
|
| 466 |
+
v = 0
|
| 467 |
+
for d in reversed(limbs):
|
| 468 |
+
v = v * base + int(d) # multiply-add only; no %/// on the product
|
| 469 |
+
return v
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def _digits_msb_base(n: int, base: int) -> list[int]:
|
| 473 |
+
if n == 0:
|
| 474 |
+
return [0]
|
| 475 |
+
s = []
|
| 476 |
+
while n > 0:
|
| 477 |
+
s.append(n % base)
|
| 478 |
+
n //= base
|
| 479 |
+
return s[::-1]
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
class RecurrentReducer(nn.Module):
    """Weight-tied cell applied once per digit of ``b`` to update a limb state.

    Four embedding tables (state ``s``, operand ``x``, modulus ``p``, current
    digit ``d``) are summed, passed through a bidirectional GRU and a
    LayerNorm, and projected to ``base`` logits per position. Submodule names
    and creation order are kept as-is so existing checkpoints keep loading.
    """

    def __init__(self, base, d_model=256, gru_layers=2, aux_quotient=True, q_max=None):
        super().__init__()
        self.base = base
        self.aux_quotient = aux_quotient
        # Width of the auxiliary quotient head; defaults to twice the base.
        self.q_max = 2 * base if q_max is None else q_max

        self.E_s = nn.Embedding(base, d_model)
        self.E_x = nn.Embedding(base, d_model)
        self.E_p = nn.Embedding(base, d_model)
        self.E_d = nn.Embedding(base, d_model)

        # Bidirectional GRU doubles the feature width seen by the heads.
        width = 2 * d_model
        self.gru = nn.GRU(
            d_model,
            d_model,
            num_layers=gru_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.ln = nn.LayerNorm(width)
        self.head = nn.Linear(width, base)
        if aux_quotient:
            # Not used by forward(); present so state_dict keys match.
            self.qhead = nn.Linear(width, self.q_max)

    def _encode(self, s, x, p, d):
        """Embed the inputs, sum them, and return the normalised GRU features."""
        # The digit embedding gets a singleton axis at dim 1 so it broadcasts
        # across every position of the other three embeddings.
        digit_vec = self.E_d(d).unsqueeze(1)
        summed = self.E_s(s) + self.E_x(x) + self.E_p(p) + digit_vec
        features, _hidden = self.gru(summed)
        return self.ln(features)

    def step_logits(self, s, x, p, d):
        """Return per-position logits over the ``base`` values of the next state."""
        encoded = self._encode(s, x, p, d)
        return self.head(encoded)

    @torch.no_grad()
    def forward(self, x, b_digits, p):
        """Unroll the cell over the columns of ``b_digits`` with greedy decoding.

        The state starts as zeros shaped like ``x``; after each column it is
        replaced by the argmax of the step logits. Returns the final state.
        """
        state = torch.zeros_like(x)
        n_steps = b_digits.shape[1]
        for t in range(n_steps):
            digit = b_digits.select(1, t)
            state = self.step_logits(state, x, p, digit).argmax(-1)
        return state
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
@torch.no_grad()
def _recurrent_decode(model, base, items, device, chunk=64):
    """Decode (x * b_red) mod p with the recurrent cell for a list of items.

    Args:
        model: Callable taking ``(X, Bd, P)`` index tensors and returning the
            final limb state, one row per item.
        base: Numeric base of the limbs and of the digits of ``b_red``.
        items: ``(x, b_red, p)`` tuples with ``x = a % p`` and
            ``b_red = b % p`` already reduced into ``[0, p)``.
        device: Torch device the input tensors are created on.
        chunk: Maximum number of items per model call.

    Returns:
        One base-10 digit list (MSB-first) per item, in input order; an empty
        list when ``items`` is empty.
    """
    if not items:
        return []
    # Expand each b once; the digits are needed for both the width and the rows.
    b_digits = [_digits_msb_base(b, base) for _, b, _ in items]
    # One limb more than the widest modulus, shared by every chunk.
    Kp = max(len(_digits_msb_base(p, base)) for _, _, p in items) + 1
    Lb = max(len(digits) for digits in b_digits)
    # A distinct placeholder per slot (not one aliased list repeated).
    out = [[0] for _ in items]
    for s0 in range(0, len(items), chunk):
        sub = items[s0:s0 + chunk]
        X = torch.tensor([_to_limbs(x, base, Kp) for x, _, _ in sub],
                         dtype=torch.long, device=device)
        P = torch.tensor([_to_limbs(p, base, Kp) for _, _, p in sub],
                         dtype=torch.long, device=device)
        # Left-pad with zero digits so every row has Lb MSB-first digits.
        Bd = torch.tensor([[0] * (Lb - len(digits)) + digits
                           for digits in b_digits[s0:s0 + chunk]],
                          dtype=torch.long, device=device)
        s = model(X, Bd, P)
        # One host transfer per chunk instead of one per row.
        for j, limbs in enumerate(s.tolist()):
            out[s0 + j] = int_to_decimal_digits(_from_limbs(limbs, base))
    return out
|
| 537 |
+
|
| 538 |
+
|
| 539 |
# ---------------------------------------------------------------------------
|
| 540 |
# Submission entry class
|
| 541 |
# ---------------------------------------------------------------------------
|
|
|
|
| 551 |
self.mm4_cfg = None
|
| 552 |
self.mm5 = None # tier-5 base-16 modmul scratchpad
|
| 553 |
self.mm5_cfg = None
|
| 554 |
+
self.mm6 = None # tier-6+ recurrent reduction cell
|
| 555 |
+
self.mm6_cfg = None
|
| 556 |
|
| 557 |
def load(self, model_dir: str) -> None:
|
| 558 |
if torch.cuda.is_available():
|
|
|
|
| 600 |
).to(self.device)
|
| 601 |
self.mm5.load_state_dict(ckpt["tier5"]["state_dict"])
|
| 602 |
self.mm5.eval()
|
| 603 |
+
# Tiers 6-10: the shared recurrent reduction cell (length-generalizes).
|
| 604 |
+
if "tier6" in ckpt:
|
| 605 |
+
c6 = ckpt["tier6"]["config"]
|
| 606 |
+
self.mm6_cfg = c6
|
| 607 |
+
self.mm6 = RecurrentReducer(
|
| 608 |
+
c6["base"], d_model=c6["d_model"], gru_layers=c6["gru_layers"],
|
| 609 |
+
aux_quotient=c6.get("aux_quotient", True),
|
| 610 |
+
).to(self.device)
|
| 611 |
+
self.mm6.load_state_dict(ckpt["tier6"]["state_dict"])
|
| 612 |
+
self.mm6.eval()
|
| 613 |
|
| 614 |
# Per-argument identity preprocessing (each hook sees only its own argument).
|
| 615 |
def preprocess_a(self, a): return a
|
|
|
|
| 628 |
TIER3_HI = 65536
|
| 629 |
TIER4_HI = 2 ** 32
|
| 630 |
TIER5_HI = 2 ** 64
|
| 631 |
+
# The recurrent cell handles tiers 6-7 (p < 2^256) accurately AND within budget.
|
| 632 |
+
# For p >= 2^256 (tiers 8-10) and tier-0's huge-p multiplications, the Horner chain
|
| 633 |
+
# would be thousands of steps and eat the whole 300s budget (starving later tiers),
|
| 634 |
+
# so we emit a fast [0] instead. This cap is what keeps the full 1100-set under 300s.
|
| 635 |
+
TIER7_HI = 2 ** 256
|
| 636 |
|
| 637 |
@torch.no_grad()
|
| 638 |
def predict_digits_batch(self, inputs):
|
|
|
|
| 641 |
mm_items, mm_idx = [], [] # tier 3
|
| 642 |
mm4_items, mm4_idx = [], [] # tier 4
|
| 643 |
mm5_items, mm5_idx = [], [] # tier 5
|
| 644 |
+
mm6_items, mm6_idx = [], [] # tiers 6-10
|
| 645 |
|
| 646 |
for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
|
| 647 |
p = int(p_enc)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 648 |
a_red = int(a_enc) % p # per-operand reduction (allowed)
|
| 649 |
b_red = int(b_enc) % p
|
| 650 |
if p >= self.TIER4_HI:
|
| 651 |
+
# Tiers 5-7: the recurrent reduction cell (p in [2^32, 2^256)), CAPPED at
|
| 652 |
+
# p < 2^256 so the Horner chain over b_red stays bounded (<=256 steps);
|
| 653 |
+
# beyond that emit a fast [0] (budget protection -- see TIER7_HI). The cell
|
| 654 |
+
# subsumes the old base-16 tier-5 scratchpad (a ~0.6 coin flip) at ~1.0 and
|
| 655 |
+
# adds tiers 6-7. Falls back to the base-16 scratchpad for tier 5 only if
|
| 656 |
+
# the cell isn't bundled.
|
| 657 |
+
if self.mm6 is not None and p < self.TIER7_HI:
|
| 658 |
+
mm6_items.append((a_red, b_red, p)); mm6_idx.append(i)
|
| 659 |
+
elif p < self.TIER5_HI and self.mm5 is not None:
|
| 660 |
mm5_items.append((a_red, b_red, p)); mm5_idx.append(i)
|
| 661 |
else:
|
| 662 |
out[i] = [0]
|
| 663 |
+
continue
|
| 664 |
+
if p >= self.TIER3_HI:
|
| 665 |
if self.mm4 is not None:
|
| 666 |
mm4_items.append((a_red, b_red, p)); mm4_idx.append(i)
|
| 667 |
else:
|
|
|
|
| 714 |
for j, i in enumerate(mm5_idx):
|
| 715 |
out[i] = res[j]
|
| 716 |
|
| 717 |
+
if mm6_items:
|
| 718 |
+
# Tiers 6-10: free-running recurrent unroll, batched. bf16 on CUDA to match
|
| 719 |
+
# training precision and bound the long-chain memory/throughput.
|
| 720 |
+
if self.device.type == "cuda":
|
| 721 |
+
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 722 |
+
res = _recurrent_decode(self.mm6, self.mm6_cfg["base"], mm6_items,
|
| 723 |
+
self.device)
|
| 724 |
+
else:
|
| 725 |
+
res = _recurrent_decode(self.mm6, self.mm6_cfg["base"], mm6_items,
|
| 726 |
+
self.device)
|
| 727 |
+
for j, i in enumerate(mm6_idx):
|
| 728 |
+
out[i] = res[j]
|
| 729 |
+
|
| 730 |
return [o if o is not None else [0] for o in out]
|
| 731 |
|
| 732 |
def max_batch_size(self) -> int:
|
weights.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fcc6145bbfe7c0b63305fa5915ae5e65118c610bc69e18e51bcb13c03185f736
|
| 3 |
+
size 149206895
|