"""Residue router, version 3: small-prime specialist (tiers 1-2), a lifted local-step pipeline for tier 3 (16-bit residues), and the same two shared rules lifted to 32-bit limbs for tier 4 (17-32-bit primes, operands to 96 bits). Routing by the size of p: * p <= 251 (tiers 1-2): the v1 residue specialist. Each operand residue is looked up in a shared per-(prime, residue) table; the two vectors are combined by ADDITION (a discrete-log inductive bias: logs add under multiplication); a residual MLP trunk transforms the sum; logits come from a per-(prime, class) output table masked to the p classes of the current prime. The answer is a single base-256 digit (p <= 251 < 256). * 251 < p < 65536 (tier 3): two trained shared LOCAL-RULE step nets composed through fixed wiring. After reduction, x, y are 16-bit residues. A MULTIPLY step (the shared carry rule c' = floor((S+c)/2) over the carry-save column sums, composed closed-loop through a fixed parity readout) emits the exact 32-bit product t = x*y. A REDUCTION step (a shared per-nibble borrow/compare rule, composed through fixed restoring-division wiring) emits r = t mod p. The answer r is emitted as base-256 digits MSB-first. * 65536 <= p < 2^32 (tier 4): the SAME two rules at 32-bit geometry. After reduction, x, y are 32-bit residues. The MULTIPLY step (33x32-case carry rule over the 63 carry-save columns, parity readout widened to 64 bits) emits the 64-bit product as BITS -- the product overflows signed int64 at the top end, so the pipeline never materializes it as an integer. The REDUCTION step (the identical 512-case borrow rule) composed over 64 division positions x 9 nibbles emits r = t mod p. The answer r (< 2^32) is emitted as up to four base-256 digits MSB-first. * p >= 2^32 (tiers 5-10): outside the trained regime; returns [0]. 
Nothing in the forward pass hand-codes the arithmetic over the actual (a, b, p): the carry-save column sums, the parity readout, the bit shifts, the restoring- division topology, and the ge-from-final-borrow decision are FIXED scaffold; the two NONTRIVIAL decisions -- the carry rule and the borrow/compare rule -- live in trained MLP parameters (separate nets per tier-3 / tier-4 geometry). Randomizing any step net's weights collapses its tier. """ from __future__ import annotations import json from pathlib import Path import torch import torch.nn as nn from modchallenge.interface.base_model import ModularMultiplicationModel # =========================================================================== # Tier 1-2 specialist (v1 residue net) # =========================================================================== PRIMES = ( 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, ) MAX_P = 251 class SmallResidueNet(nn.Module): def __init__(self, d_model: int = 128, hidden: int = 1024): super().__init__() offsets, acc = [], 0 for p in PRIMES: offsets.append(acc) acc += p table = acc # 6081 self.pair_emb = nn.Embedding(table, d_model) self.out_emb = nn.Embedding(table, d_model) self.prime_emb = nn.Embedding(len(PRIMES), d_model) self.trunk = nn.Sequential( nn.LayerNorm(d_model), nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, d_model), ) self.ln_out = nn.LayerNorm(d_model) self.register_buffer( "primes_t", torch.tensor(PRIMES, dtype=torch.long), persistent=False ) self.register_buffer( "offsets_t", torch.tensor(offsets, dtype=torch.long), persistent=False ) lookup = torch.full((MAX_P + 1,), -1, dtype=torch.long) for i, p in enumerate(PRIMES): lookup[p] = i self.register_buffer("prime_lookup", lookup, persistent=False) self.register_buffer( 
"class_grid", torch.arange(MAX_P, dtype=torch.long), persistent=False ) def forward( self, ix: torch.Tensor, iy: torch.Tensor, p_idx: torch.Tensor ) -> torch.Tensor: h = self.pair_emb(ix) + self.pair_emb(iy) + self.prime_emb(p_idx) g = self.ln_out(h + self.trunk(h)) off = self.offsets_t[p_idx] pv = self.primes_t[p_idx] grid = self.class_grid.unsqueeze(0) valid = grid < pv.unsqueeze(1) logits = (g @ self.out_emb.weight.t()).gather(1, off.unsqueeze(1) + grid) return logits.masked_fill(~valid, float("-inf")) @torch.no_grad() def predict( self, x: torch.Tensor, y: torch.Tensor, p: torch.Tensor ) -> torch.Tensor: p_idx = self.prime_lookup[p] off = self.offsets_t[p_idx] return self.forward(off + x, off + y, p_idx).argmax(dim=-1) # =========================================================================== # Shared step-net architecture (used by both tier-3 and tier-4 geometries) # =========================================================================== class StepMLP(nn.Module): """Plain GELU MLP step: n_in local-state bits -> n_out logits.""" def __init__(self, n_in: int, n_out: int, width: int, depth: int): super().__init__() self.layers = nn.ModuleList([nn.Linear(n_in, width)]) for _ in range(depth - 1): self.layers.append(nn.Linear(width, width)) self.head = nn.Linear(width, n_out) self.act = nn.GELU() def forward(self, x: torch.Tensor) -> torch.Tensor: h = x for lin in self.layers: h = self.act(lin(h)) return self.head(h) # =========================================================================== # Tier 3 geometry (16x16 -> 32-bit; 5-nibble reduction) # =========================================================================== T3_MUL_OPB = 16 T3_MUL_PRB = 32 T3_MUL_COLS = 2 * T3_MUL_OPB - 1 # 31 T3_MUL_SUMB = 5 T3_MUL_CARB = 4 T3_MUL_IN = T3_MUL_SUMB + T3_MUL_CARB # 9 T3_NIB = 4 T3_RED_NIBBLES = 5 T3_RED_IN = T3_NIB + T3_NIB + 1 # 9 T3_RED_OUT = T3_NIB + 1 # 5 T3_T_BITS = 32 # =========================================================================== # Tier 4 
# ===========================================================================
# Tier 4 geometry (32x32 -> 64-bit; 9-nibble reduction)
# ===========================================================================

T4_MUL_OPB = 32
T4_MUL_PRB = 64
T4_MUL_COLS = 2 * T4_MUL_OPB - 1  # 63
T4_MUL_SUMB = 6
T4_MUL_CARB = 5
T4_MUL_IN = T4_MUL_SUMB + T4_MUL_CARB  # 11
T4_NIB = 4
T4_RED_NIBBLES = 9
T4_RED_IN = T4_NIB + T4_NIB + 1  # 9
T4_RED_OUT = T4_NIB + 1  # 5
T4_T_BITS = 64


# -- generic carry-save / reduction wiring (parameterized by geometry) -------


def _bits(v: torch.Tensor, nb: int) -> torch.Tensor:
    """Return the low ``nb`` bits of each entry of 1-D ``v`` as (B, nb) floats, LSB first."""
    return ((v.unsqueeze(1) >> torch.arange(nb, device=v.device)) & 1).float()


def _column_sums(x_bits: torch.Tensor, y_bits: torch.Tensor, opb: int, cols: int) -> torch.Tensor:
    """Carry-save column sums: S_k = sum over i + j == k of x_i * y_j.

    Args:
        x_bits, y_bits: (B, >= opb) 0/1 float tensors, LSB first. Only the
            first ``opb`` bits of each are used.
        opb: operand bit width.
        cols: number of output columns (2 * opb - 1 for the full array).

    Returns:
        (B, cols) tensor of the column sums, in x_bits' dtype.
    """
    n = x_bits.shape[0]
    s = torch.zeros(n, cols, dtype=x_bits.dtype, device=x_bits.device)
    y_row = y_bits[:, :opb]
    for i in range(opb):
        # Partial-product row i (x_i * y) lands in columns i .. i + opb - 1.
        # This is O(opb) slice-adds instead of the O(opb**2) per-cell loop.
        s[:, i:i + opb] += x_bits[:, i:i + 1] * y_row
    return s


def _encode_carry(s: torch.Tensor, c: torch.Tensor, sumb: int, carb: int) -> torch.Tensor:
    """Encode (column sum, carry-in) as [sumb bits of s | carb bits of c], LSB first."""
    si = torch.arange(sumb, device=s.device)
    ci = torch.arange(carb, device=c.device)
    sb = ((s.unsqueeze(1) >> si) & 1).float()
    cb = ((c.unsqueeze(1) >> ci) & 1).float()
    return torch.cat([sb, cb], dim=1)


def _carry_bits_to_int(bits: torch.Tensor, carb: int) -> torch.Tensor:
    """Round ``carb`` carry bits (LSB first) to 0/1 and pack them into an integer."""
    w = (1 << torch.arange(carb, device=bits.device)).long()
    return (bits.round().clamp(0, 1).long() * w).sum(dim=-1)


@torch.no_grad()
def _closed_loop_mul(step, col_sums, cols, carb, sumb=None):
    """Run the carry step column by column, feeding its decoded carry back in.

    Args:
        step: trained carry net mapping (sumb + carb) bits -> carb logits.
        col_sums: (B, cols) carry-save column sums.
        cols: number of columns to process.
        carb: carry bit width.
        sumb: column-sum bit width. When None it is inferred from ``cols``
            (tier-4 column count -> T4_MUL_SUMB, otherwise T3_MUL_SUMB).
            This inference is kept for backward compatibility; prefer passing
            it explicitly.

    Returns:
        (B, cols * carb) raw logits. The slice [carb*c, carb*(c+1)) holds the
        carry-out logits of column c.
    """
    if sumb is None:
        sumb = T4_MUL_SUMB if cols == T4_MUL_COLS else T3_MUL_SUMB
    n = col_sums.shape[0]
    s = col_sums.long()
    carry = torch.zeros(n, dtype=torch.long, device=s.device)
    out = torch.empty(n, cols * carb, device=col_sums.device)
    for c in range(cols):
        lg = step(_encode_carry(s[:, c], carry, sumb, carb))
        out[:, carb * c:carb * (c + 1)] = lg
        # Closed loop: threshold the logits and feed the decoded carry forward.
        carry = _carry_bits_to_int((lg > 0).float(), carb)
    return out


def _routed_product_bits(carry_logits, col_parity, carb):
    """Fixed parity readout: bit_c = parity(S_c) XOR lsb(carry into c).

    Returns product-bit logits of shape (B, cols + 1), LSB first. The
    elements are:

    * bit 0 has no carry-in, so it is parity(S_0) as a +/-BIG logit;
    * the middle bits flip the sign of the incoming carry's lsb logit when
      the column parity is 1 (XOR);
    * the top bit is the lsb of the carry out of the last column.
    """
    BIG = 20.0
    lsb = carry_logits[:, 0::carb]  # lsb logit of each column's carry-out
    bit0 = (2.0 * col_parity[:, 0:1] - 1.0) * BIG
    mid = (1.0 - 2.0 * col_parity[:, 1:]) * lsb[:, :-1]
    bit_last = lsb[:, -1:]
    return torch.cat([bit0, mid, bit_last], dim=1)


@torch.no_grad()
def _composed_product_bits(step, x, y, opb, cols, prb, carb, sumb=None):
    """Trained carry step (closed loop) + parity readout -> product BITS (B, prb).

    Args:
        step: trained carry net.
        x, y: (B,) long residues, each < 2**opb.
        opb, cols, prb, carb: operand bits, column count, product bits, and
            carry bits of the geometry.
        sumb: column-sum bit width, forwarded to ``_closed_loop_mul``. When
            None it is inferred there from ``cols``.

    Returns:
        (B, prb) long 0/1 tensor, LSB first.

    Raises:
        ValueError: if the readout width (cols + 1) does not equal ``prb``.
    """
    col_sums = _column_sums(_bits(x, opb), _bits(y, opb), opb, cols)
    logits = _closed_loop_mul(step, col_sums, cols, carb, sumb)
    col_parity = (col_sums.long() & 1).float()
    bit_logits = _routed_product_bits(logits, col_parity, carb)
    if bit_logits.shape[1] != prb:
        raise ValueError(
            f"readout produced {bit_logits.shape[1]} product bits, expected {prb}"
        )
    return (bit_logits > 0).long()


def _encode_red(a, b, bin_, nib):
    """Encode (a nibble, b nibble, borrow-in) as [nib bits of a | nib bits of b | borrow]."""
    ai = torch.arange(nib, device=a.device)
    aa = ((a.unsqueeze(1) >> ai) & 1).float()
    bb = ((b.unsqueeze(1) >> ai) & 1).float()
    cc = bin_.float().unsqueeze(1)
    return torch.cat([aa, bb, cc], dim=1)


def _red_bits_to_out(bits, nib):
    """Decode reduction-step logits into (difference nibble, borrow-out).

    Logits > 0 are treated as 1. The first ``nib`` entries form the
    difference nibble (LSB first); entry ``nib`` is the borrow-out.
    """
    hb = (bits > 0).long()
    w = (1 << torch.arange(nib, device=bits.device)).long()
    d = (hb[:, :nib] * w).sum(dim=1)
    bout = hb[:, nib]
    return d, bout


@torch.no_grad()
def _composed_reduce_bits(step, t_bits, p, nib, nibbles, t_bits_n):
    """Restoring division of the bit-represented product by p -> r (B,).

    R is kept in [0, p) whenever the step net is exact, so it never overflows
    int64 even when the full product would. The bit shifts, the
    ge-from-final-borrow decision, and the keep/replace of R are fixed wiring;
    only the per-nibble subtract DECISION comes from the trained ``step``.

    Args:
        step: trained borrow net mapping (2 * nib + 1) bits -> (nib + 1) logits.
        t_bits: (B, >= t_bits_n) product bits, LSB first.
        p: (B,) long moduli.
        nib: nibble width in bits. The nibble mask is derived from it.
        nibbles: number of nibbles spanned by R << 1 | bit.
        t_bits_n: number of product bits to consume (MSB first).
    """
    n = t_bits.shape[0]
    device = t_bits.device
    mask = (1 << nib) - 1
    p = p.to(device)
    R = torch.zeros(n, dtype=torch.long, device=device)
    p_nib = torch.stack([(p >> (nib * k)) & mask for k in range(nibbles)], dim=1)
    wk = (1 << (nib * torch.arange(nibbles, device=device))).long()
    for i in range(t_bits_n - 1, -1, -1):
        # Shift in the next product bit, MSB first.
        Rpre = (R << 1) | t_bits[:, i].long()
        # Ripple Rpre - p nibble by nibble through the trained borrow rule.
        borrow = torch.zeros(n, dtype=torch.long, device=device)
        diff_nib = torch.zeros(n, nibbles, dtype=torch.long, device=device)
        for k in range(nibbles):
            an = (Rpre >> (nib * k)) & mask
            lg = step(_encode_red(an, p_nib[:, k], borrow, nib))
            d, borrow = _red_bits_to_out(lg, nib)
            diff_nib[:, k] = d
        # No final borrow  <=>  Rpre >= p: keep the difference, else restore.
        ge = borrow == 0
        diff_val = (diff_nib * wk).sum(dim=1)
        R = torch.where(ge, diff_val, Rpre)
    return R


def _local_pipeline_residues(
    mul, red, xs, ys, ps, *, opb, cols, prb, carb, sumb, nib, nibbles, t_bits_n
) -> list[int]:
    """Run multiply-then-reduce for one step-net geometry.

    Returns the pipeline's r per row (t mod p when the step nets are exact).
    """
    x_t = torch.tensor(xs, dtype=torch.long)
    y_t = torch.tensor(ys, dtype=torch.long)
    p_t = torch.tensor(ps, dtype=torch.long)
    tb = _composed_product_bits(mul, x_t, y_t, opb, cols, prb, carb, sumb)
    return _composed_reduce_bits(red, tb, p_t, nib, nibbles, t_bits_n).tolist()


def _load_step_net(path, n_in, n_out, width, depth, load_file) -> StepMLP:
    """Build a StepMLP, strictly load its weights from ``path``, and switch to eval mode."""
    net = StepMLP(n_in, n_out, width, depth)
    net.load_state_dict(_remap(load_file(str(path))), strict=True)
    net.eval()
    return net


# ===========================================================================
# Router
# ===========================================================================

T3_MIN_P = MAX_P + 1  # 252
T3_MAX_P = (1 << 16) - 1  # tier-3 primes are 9-16 bits
T4_MIN_P = 1 << 16  # 65536
T4_MAX_P = (1 << 32) - 1  # tier-4 primes are 17-32 bits


class ResidueRouterV1(ModularMultiplicationModel):
    """Router over per-tier specialists, selected by the size of p.

    The class name ``ResidueRouterV1`` is kept so the manifest entry_class is
    stable across versions; this is v3 (tiers 1-4).
    """

    def __init__(self):
        # NOTE(review): the base class __init__ is not called; confirm
        # ModularMultiplicationModel does not require initialization.
        self.small: SmallResidueNet | None = None
        self.t3_mul: StepMLP | None = None
        self.t3_red: StepMLP | None = None
        self.t4_mul: StepMLP | None = None
        self.t4_red: StepMLP | None = None

    def load(self, model_dir: str) -> None:
        """Load every specialist named in ``config.json`` under ``model_dir``.

        Each tier is installed only after both of its nets load successfully.
        """
        from safetensors.torch import load_file

        # Seed the global RNG before constructing modules. Their random init
        # is then fully replaced by the strict state-dict loads below.
        torch.manual_seed(0)
        model_dir = Path(model_dir)
        config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
        tensors = load_file(str(model_dir / "weights.safetensors"))

        if "small" in config:
            net = SmallResidueNet(**config["small"])
            prefix = "small."
            state = {k[len(prefix):]: v for k, v in tensors.items() if k.startswith(prefix)}
            net.load_state_dict(state, strict=True)
            net.eval()
            self.small = net

        if "t3" in config:
            w, d = config["t3"]["width"], config["t3"]["depth"]
            mul = _load_step_net(model_dir / "t3_mul.safetensors", T3_MUL_IN, T3_MUL_CARB, w, d, load_file)
            red = _load_step_net(model_dir / "t3_red.safetensors", T3_RED_IN, T3_RED_OUT, w, d, load_file)
            self.t3_mul, self.t3_red = mul, red

        if "t4" in config:
            cfg = config["t4"]
            mw, rw, d = cfg["mul_width"], cfg["red_width"], cfg["depth"]
            mul = _load_step_net(model_dir / "t4_mul.safetensors", T4_MUL_IN, T4_MUL_CARB, mw, d, load_file)
            red = _load_step_net(model_dir / "t4_red.safetensors", T4_RED_IN, T4_RED_OUT, rw, d, load_file)
            self.t4_mul, self.t4_red = mul, red

    def preprocess_a(self, a):
        return a

    def preprocess_b(self, b):
        return b

    def preprocess_p(self, p):
        return p

    @torch.no_grad()
    def predict_digits(self, a_enc, b_enc, p_enc):
        return self.predict_digits_batch([(a_enc, b_enc, p_enc)])[0]

    @torch.no_grad()
    def predict_digits_batch(self, inputs):
        """Return base-256 MSB-first digits of (a * b) mod p for each input triple.

        Inputs that cannot be parsed, have p < 2, or fall outside every loaded
        tier get the fallback answer [0].
        """
        out: list[list[int] | None] = [None] * len(inputs)
        s_x, s_y, s_p, s_idx = [], [], [], []      # tier 1-2
        t3_x, t3_y, t3_p, t3_idx = [], [], [], []  # tier 3
        t4_x, t4_y, t4_p, t4_idx = [], [], [], []  # tier 4

        for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
            try:
                p = int(p_enc)
            except (ValueError, TypeError, OverflowError):
                out[i] = [0]
                continue
            if p < 2:
                # p == 0 would raise ZeroDivisionError in the reduction below
                # and abort the whole batch. No tier handles p < 2, and for
                # p == 1 the true answer is 0 anyway.
                out[i] = [0]
                continue
            # Operand normalization: a with p, then b with p (never all three).
            try:
                xr = int(a_enc) % p
                yr = int(b_enc) % p
            except (ValueError, TypeError, OverflowError):
                out[i] = [0]
                continue

            if self.small is not None and p <= MAX_P and int(self.small.prime_lookup[p]) >= 0:
                s_x.append(xr); s_y.append(yr); s_p.append(p); s_idx.append(i)
            elif self.t3_mul is not None and T3_MIN_P <= p <= T3_MAX_P:
                t3_x.append(xr); t3_y.append(yr); t3_p.append(p); t3_idx.append(i)
            elif self.t4_mul is not None and T4_MIN_P <= p <= T4_MAX_P:
                t4_x.append(xr); t4_y.append(yr); t4_p.append(p); t4_idx.append(i)
            else:
                out[i] = [0]  # outside the trained regime -> honest fallback

        if s_idx:
            preds = self.small.predict(
                torch.tensor(s_x, dtype=torch.long),
                torch.tensor(s_y, dtype=torch.long),
                torch.tensor(s_p, dtype=torch.long),
            ).tolist()
            for i, pred in zip(s_idx, preds):
                out[i] = [int(pred)]

        if t3_idx:
            rs = _local_pipeline_residues(
                self.t3_mul, self.t3_red, t3_x, t3_y, t3_p,
                opb=T3_MUL_OPB, cols=T3_MUL_COLS, prb=T3_MUL_PRB, carb=T3_MUL_CARB,
                sumb=T3_MUL_SUMB, nib=T3_NIB, nibbles=T3_RED_NIBBLES, t_bits_n=T3_T_BITS,
            )
            for i, r in zip(t3_idx, rs):
                out[i] = _digits_msb(int(r))

        if t4_idx:
            rs = _local_pipeline_residues(
                self.t4_mul, self.t4_red, t4_x, t4_y, t4_p,
                opb=T4_MUL_OPB, cols=T4_MUL_COLS, prb=T4_MUL_PRB, carb=T4_MUL_CARB,
                sumb=T4_MUL_SUMB, nib=T4_NIB, nibbles=T4_RED_NIBBLES, t_bits_n=T4_T_BITS,
            )
            for i, r in zip(t4_idx, rs):
                out[i] = _digits_msb(int(r))

        return [o if o is not None else [0] for o in out]

    def max_batch_size(self) -> int:
        return 512


def _digits_msb(v: int) -> list[int]:
    """Base-256 digits of non-negative ``v``, MSB first; always at least one digit.

    Raises:
        ValueError: if ``v`` is negative. Previously a negative value returned
            an empty list, contradicting the at-least-one-digit contract.
    """
    if v < 0:
        raise ValueError(f"expected a non-negative residue, got {v}")
    return list(v.to_bytes(max(1, (v.bit_length() + 7) // 8), "big"))


def _remap(state: dict) -> dict:
    """Map the trained step-net state-dict keys onto the StepMLP layout.

    The training-side MulCarryStep/RedBorrowStep store layers under the same
    ``layers.*`` / ``head.*`` names as StepMLP, so this is the identity. It is
    kept as a seam in case a future export renames keys.
    """
    return state