"""Residue router, version 4: small-prime specialist (tiers 1-2), a lifted local-step pipeline for tier 3 (16-bit residues), the same two shared rules lifted to 32-bit limbs for tier 4 (17-32-bit primes), and lifted again to 64-bit limbs for tier 5 (33-64-bit primes, operands to 128 bits). Routing by the size of p: * p <= 251 (tiers 1-2): the v1 residue specialist. Each operand residue is looked up in a shared per-(prime, residue) table; the two vectors are combined by ADDITION (a discrete-log inductive bias: logs add under multiplication); a residual MLP trunk transforms the sum; logits come from a per-(prime, class) output table masked to the p classes of the current prime. The answer is a single base-256 digit (p <= 251 < 256). * 251 < p < 65536 (tier 3): two trained shared LOCAL-RULE step nets composed through fixed wiring. After reduction, x, y are 16-bit residues. A MULTIPLY step (the shared carry rule c' = floor((S+c)/2) over the carry-save column sums, composed closed-loop through a fixed parity readout) emits the exact 32-bit product t = x*y. A REDUCTION step (a shared per-nibble borrow/compare rule, composed through fixed restoring-division wiring) emits r = t mod p. The answer r is emitted as base-256 digits MSB-first. * 65536 <= p < 2^32 (tier 4): the SAME two rules at 32-bit geometry. After reduction, x, y are 32-bit residues. The MULTIPLY step (33x32-case carry rule over the 63 carry-save columns, parity readout widened to 64 bits) emits the 64-bit product as BITS. The REDUCTION step (the identical 512-case borrow rule) composed over 64 division positions x 9 nibbles emits r = t mod p. The answer r (< 2^32) is emitted as up to four base-256 digits. * 2^32 <= p < 2^64 (tier 5): the SAME two rules at 64-bit geometry. After reduction, x, y are 64-bit residues. 
Because a 64-bit residue and the 65-bit division register both overflow signed int64, tier 5 carries operands, p, the product, and the division register as BIT tensors -- no wide value is ever materialized as an int64 scalar. The MULTIPLY step (65x64-case carry rule over the 127 carry-save columns, parity readout widened to 128 bits) emits the 128-bit product as bits. The REDUCTION step (the identical 512-case borrow rule) composed over 128 division positions x 17 nibbles emits r = t mod p in [0, p). The answer r (< 2^64) is emitted as up to eight base-256 digits MSB-first. * p >= 2^64 (tiers 6-10): outside the trained regime; returns [0]. Nothing in the forward pass hand-codes the arithmetic over the actual (a, b, p): the carry-save column sums, the parity readout, the bit shifts, the restoring- division topology, and the ge-from-final-borrow decision are FIXED scaffold; the two NONTRIVIAL decisions -- the carry rule and the borrow/compare rule -- live in trained MLP parameters (separate nets per tier-3 / tier-4 / tier-5 geometry). Randomizing any step net's weights collapses its tier. 
""" from __future__ import annotations import json from pathlib import Path import torch import torch.nn as nn from modchallenge.interface.base_model import ModularMultiplicationModel # =========================================================================== # Tier 1-2 specialist (v1 residue net) # =========================================================================== PRIMES = ( 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, ) MAX_P = 251 class SmallResidueNet(nn.Module): def __init__(self, d_model: int = 128, hidden: int = 1024): super().__init__() offsets, acc = [], 0 for p in PRIMES: offsets.append(acc) acc += p table = acc # 6081 self.pair_emb = nn.Embedding(table, d_model) self.out_emb = nn.Embedding(table, d_model) self.prime_emb = nn.Embedding(len(PRIMES), d_model) self.trunk = nn.Sequential( nn.LayerNorm(d_model), nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, d_model), ) self.ln_out = nn.LayerNorm(d_model) self.register_buffer( "primes_t", torch.tensor(PRIMES, dtype=torch.long), persistent=False ) self.register_buffer( "offsets_t", torch.tensor(offsets, dtype=torch.long), persistent=False ) lookup = torch.full((MAX_P + 1,), -1, dtype=torch.long) for i, p in enumerate(PRIMES): lookup[p] = i self.register_buffer("prime_lookup", lookup, persistent=False) self.register_buffer( "class_grid", torch.arange(MAX_P, dtype=torch.long), persistent=False ) def forward( self, ix: torch.Tensor, iy: torch.Tensor, p_idx: torch.Tensor ) -> torch.Tensor: h = self.pair_emb(ix) + self.pair_emb(iy) + self.prime_emb(p_idx) g = self.ln_out(h + self.trunk(h)) off = self.offsets_t[p_idx] pv = self.primes_t[p_idx] grid = self.class_grid.unsqueeze(0) valid = grid < pv.unsqueeze(1) logits = (g @ self.out_emb.weight.t()).gather(1, 
off.unsqueeze(1) + grid) return logits.masked_fill(~valid, float("-inf")) @torch.no_grad() def predict( self, x: torch.Tensor, y: torch.Tensor, p: torch.Tensor ) -> torch.Tensor: p_idx = self.prime_lookup[p] off = self.offsets_t[p_idx] return self.forward(off + x, off + y, p_idx).argmax(dim=-1) # =========================================================================== # Shared step-net architecture (used by tier-3 / tier-4 / tier-5 geometries) # =========================================================================== class StepMLP(nn.Module): """Plain GELU MLP step: n_in local-state bits -> n_out logits.""" def __init__(self, n_in: int, n_out: int, width: int, depth: int): super().__init__() self.layers = nn.ModuleList([nn.Linear(n_in, width)]) for _ in range(depth - 1): self.layers.append(nn.Linear(width, width)) self.head = nn.Linear(width, n_out) self.act = nn.GELU() def forward(self, x: torch.Tensor) -> torch.Tensor: h = x for lin in self.layers: h = self.act(lin(h)) return self.head(h) # =========================================================================== # Per-tier multiply / reduction geometries # =========================================================================== # Tier 3 (16x16 -> 32-bit; 5-nibble reduction) T3_MUL_OPB = 16 T3_MUL_PRB = 32 T3_MUL_COLS = 2 * T3_MUL_OPB - 1 # 31 T3_MUL_SUMB = 5 T3_MUL_CARB = 4 T3_MUL_IN = T3_MUL_SUMB + T3_MUL_CARB # 9 T3_NIB = 4 T3_RED_NIBBLES = 5 T3_RED_IN = T3_NIB + T3_NIB + 1 # 9 T3_RED_OUT = T3_NIB + 1 # 5 T3_T_BITS = 32 # Tier 4 (32x32 -> 64-bit; 9-nibble reduction) T4_MUL_OPB = 32 T4_MUL_PRB = 64 T4_MUL_COLS = 2 * T4_MUL_OPB - 1 # 63 T4_MUL_SUMB = 6 T4_MUL_CARB = 5 T4_MUL_IN = T4_MUL_SUMB + T4_MUL_CARB # 11 T4_NIB = 4 T4_RED_NIBBLES = 9 T4_RED_IN = T4_NIB + T4_NIB + 1 # 9 T4_RED_OUT = T4_NIB + 1 # 5 T4_T_BITS = 64 # Tier 5 (64x64 -> 128-bit; 17-nibble reduction). Wide values are bit tensors. 
T5_MUL_OPB = 64
T5_MUL_PRB = 128
T5_MUL_COLS = 2 * T5_MUL_OPB - 1         # 127
T5_MUL_SUMB = 7                          # S_c <= 64
T5_MUL_CARB = 6                          # carry <= 63
T5_MUL_IN = T5_MUL_SUMB + T5_MUL_CARB    # 13
T5_NIB = 4
T5_RED_NIBBLES = 17                      # 65-bit R_pre / 64-bit p
T5_RED_IN = T5_NIB + T5_NIB + 1          # 9
T5_RED_OUT = T5_NIB + 1                  # 5
T5_T_BITS = 128


# -- generic carry-save / reduction wiring (parameterized by geometry) -------


def _bits(v: torch.Tensor, nb: int) -> torch.Tensor:
    """Low ``nb`` bits of each entry of 1-D integer tensor ``v``, LSB-first, as float (B, nb)."""
    return ((v.unsqueeze(1) >> torch.arange(nb, device=v.device)) & 1).float()


def _column_sums(x_bits: torch.Tensor, y_bits: torch.Tensor, opb: int, cols: int) -> torch.Tensor:
    """Carry-save column sums ``S_c = sum_{i + j = c} x_i * y_j`` as a (B, cols) tensor.

    ``x_bits``/``y_bits`` are (B, opb) LSB-first bit tensors; the result keeps
    the dtype of ``x_bits``.
    """
    n = x_bits.shape[0]
    s = torch.zeros(n, cols, dtype=x_bits.dtype, device=x_bits.device)
    for i in range(opb):
        # Partial-product row i lands in columns i .. i + opb - 1.
        s[:, i:i + opb] += x_bits[:, i:i + 1] * y_bits
    return s


def _encode_carry(s: torch.Tensor, c: torch.Tensor, sumb: int, carb: int) -> torch.Tensor:
    """Carry-step input: ``sumb`` LSB-first bits of S followed by ``carb`` bits of the carry."""
    si = torch.arange(sumb, device=s.device)
    ci = torch.arange(carb, device=c.device)
    sb = ((s.unsqueeze(1) >> si) & 1).float()
    cb = ((c.unsqueeze(1) >> ci) & 1).float()
    return torch.cat([sb, cb], dim=1)


def _carry_bits_to_int(bits: torch.Tensor, carb: int) -> torch.Tensor:
    """Decode (B, carb) LSB-first 0/1 values (rounded, clamped) into an integer carry."""
    w = (1 << torch.arange(carb, device=bits.device)).long()
    return (bits.round().clamp(0, 1).long() * w).sum(dim=-1)


# Column-sum bit width for each supported carry-save geometry (keyed by #columns).
_SUMB_FOR_COLS = {T3_MUL_COLS: T3_MUL_SUMB, T4_MUL_COLS: T4_MUL_SUMB, T5_MUL_COLS: T5_MUL_SUMB}


@torch.no_grad()
def _closed_loop_mul(step, col_sums, cols, carb, sumb=None):
    """Run the trained carry step column by column, feeding its decoded carry back in.

    Returns the raw carry logits as (B, cols * carb): the ``carb`` logits for
    column c (LSB first) describe the carry out of column c. ``sumb`` defaults
    to the width registered in ``_SUMB_FOR_COLS`` for this ``cols``.
    """
    if sumb is None:
        sumb = _SUMB_FOR_COLS[cols]
    n = col_sums.shape[0]
    s = col_sums.long()
    carry = torch.zeros(n, dtype=torch.long, device=s.device)
    out = torch.empty(n, cols * carb, device=col_sums.device)
    for c in range(cols):
        lg = step(_encode_carry(s[:, c], carry, sumb, carb))
        out[:, carb * c:carb * (c + 1)] = lg
        # Closed loop: the hard-decoded carry-out is the next column's carry-in.
        carry = _carry_bits_to_int((lg > 0).float(), carb)
    return out


def _routed_product_bits(carry_logits, col_parity, carb):
    """Fixed parity readout: bit_c = parity(S_c) XOR lsb(carry into c).

    Works in logit space: bit 0 is a +/-BIG constant from parity(S_0) (no
    carry-in), the middle bits flip the sign of the incoming carry's LSB logit
    when the column parity is 1, and the top bit is the LSB logit of the last
    column's carry-out. Output has ``cols + 1`` logits per row.
    """
    BIG = 20.0
    lsb = carry_logits[:, 0::carb]
    bit0 = (2.0 * col_parity[:, 0:1] - 1.0) * BIG
    mid = (1.0 - 2.0 * col_parity[:, 1:]) * lsb[:, :-1]
    bit_last = lsb[:, -1:]
    return torch.cat([bit0, mid, bit_last], dim=1)


@torch.no_grad()
def _composed_product_bits(step, x_bits, y_bits, opb, cols, prb, carb):
    """Trained carry step (closed loop) + parity readout -> product BITS (B, prb).

    x_bits, y_bits are (B, opb) LSB-first operand bit tensors.

    Raises:
        ValueError: if the wiring does not yield exactly ``prb`` product bits
            (i.e. ``cols``/``prb`` do not describe the same geometry).
    """
    col_sums = _column_sums(x_bits, y_bits, opb, cols)
    logits = _closed_loop_mul(step, col_sums, cols, carb)
    col_parity = (col_sums.long() & 1).float()
    bit_logits = _routed_product_bits(logits, col_parity, carb)
    bits = (bit_logits > 0).long()  # (B, prb) bits LSB first
    if bits.shape[1] != prb:
        raise ValueError(f"product wiring produced {bits.shape[1]} bits, expected {prb}")
    return bits


def _encode_red(a, b, bin_, nib):
    """Reduction-step input: ``nib`` bits of a, ``nib`` bits of b, then the borrow-in bit."""
    ai = torch.arange(nib, device=a.device)
    aa = ((a.unsqueeze(1) >> ai) & 1).float()
    bb = ((b.unsqueeze(1) >> ai) & 1).float()
    cc = bin_.float().unsqueeze(1)
    return torch.cat([aa, bb, cc], dim=1)


def _red_bits_to_out(bits, nib):
    """Decode reduction-step logits into (difference nibble, borrow-out)."""
    hb = (bits > 0).long()
    w = (1 << torch.arange(nib, device=bits.device)).long()
    d = (hb[:, :nib] * w).sum(dim=1)
    bout = hb[:, nib]
    return d, bout


@torch.no_grad()
def _composed_reduce_int(step, t_bits, p, nib, nibbles, t_bits_n):
    """Restoring division by p when p and R fit signed int64 (tiers 3-4).

    R stays in [0, p) (< 2^32), so R never overflows int64 even though the
    full tier-4 product does (the product is only ever consumed bit by bit,
    MSB first). Fixed wiring; the per-nibble subtract DECISION is the
    trained step.
    """
    n = t_bits.shape[0]
    device = t_bits.device
    mask = (1 << nib) - 1
    R = torch.zeros(n, dtype=torch.long, device=device)
    p_nib = torch.stack([(p >> (nib * k)) & mask for k in range(nibbles)], dim=1)
    wk = (1 << (nib * torch.arange(nibbles, device=device))).long()
    for i in range(t_bits_n - 1, -1, -1):
        Rpre = (R << 1) | t_bits[:, i].long()
        borrow = torch.zeros(n, dtype=torch.long, device=device)
        diff_nib = torch.zeros(n, nibbles, dtype=torch.long, device=device)
        for k in range(nibbles):
            an = (Rpre >> (nib * k)) & mask
            lg = step(_encode_red(an, p_nib[:, k], borrow, nib))
            d, bout = _red_bits_to_out(lg, nib)
            diff_nib[:, k] = d
            borrow = bout
        # No final borrow <=> Rpre >= p: keep the difference, else restore.
        ge = borrow == 0
        diff_val = (diff_nib * wk).sum(dim=1)
        R = torch.where(ge, diff_val, Rpre)
    return R


def _p_nibbles_from_bits(p_bits, nib, nibbles):
    """Pack LSB-first bits of p into (B, nibbles) integer nibbles, zero-padding a short tail."""
    n = p_bits.shape[0]
    out = torch.zeros(n, nibbles, dtype=torch.long, device=p_bits.device)
    wt = (1 << torch.arange(nib, device=p_bits.device)).long()
    for k in range(nibbles):
        chunk = p_bits[:, nib * k:nib * (k + 1)]
        if chunk.shape[1] < nib:
            pad = torch.zeros(n, nib - chunk.shape[1], dtype=p_bits.dtype, device=p_bits.device)
            chunk = torch.cat([chunk, pad], dim=1)
        out[:, k] = (chunk.long() * wt).sum(dim=1)
    return out


@torch.no_grad()
def _composed_reduce_bits(step, t_bits, p_bits, nib, nibbles, t_bits_n):
    """Restoring division by p with R, R_pre, p carried as BIT tensors (tier 5).

    Nothing overflows int64: R has up to 64 bits, R_pre = 2R + bit up to 65
    bits -> 17 nibbles. Returns the remainder as (B, nibbles*nib) bits
    LSB-first (float).
    """
    n = t_bits.shape[0]
    device = t_bits.device
    RB = nibbles * nib
    R = torch.zeros(n, RB, dtype=torch.long, device=device)
    p_nib = _p_nibbles_from_bits(p_bits, nib, nibbles)
    wt = (1 << torch.arange(nib, device=device)).long()
    di = torch.arange(nib, device=device)
    for i in range(t_bits_n - 1, -1, -1):
        # R_pre = 2R + next product bit, done as a one-position bit shift.
        Rpre = torch.zeros(n, RB, dtype=torch.long, device=device)
        Rpre[:, 1:] = R[:, :-1]
        Rpre[:, 0] = t_bits[:, i].long()
        borrow = torch.zeros(n, dtype=torch.long, device=device)
        diff_nib = torch.zeros(n, nibbles, dtype=torch.long, device=device)
        for k in range(nibbles):
            an = (Rpre[:, nib * k:nib * (k + 1)].long() * wt).sum(dim=1)
            lg = step(_encode_red(an, p_nib[:, k], borrow, nib))
            d, bout = _red_bits_to_out(lg, nib)
            diff_nib[:, k] = d
            borrow = bout
        ge = borrow == 0
        diff_bits = torch.zeros(n, RB, dtype=torch.long, device=device)
        for k in range(nibbles):
            diff_bits[:, nib * k:nib * (k + 1)] = (diff_nib[:, k:k + 1] >> di) & 1
        R = torch.where(ge.unsqueeze(1), diff_bits, Rpre)
    return R.float()


# -- bigint <-> bit-tensor helpers (tier 5: residues exceed signed int64) ----


def _int_bits(values, nb: int) -> torch.Tensor:
    """Low ``nb`` bits of each Python int in ``values``, LSB-first, as float (len, nb)."""
    out = torch.zeros(len(values), nb, dtype=torch.float32)
    for r, v in enumerate(values):
        v = int(v)
        b = 0
        while v and b < nb:
            out[r, b] = float(v & 1)
            v >>= 1
            b += 1
    return out


def _bits_to_ints(bits: torch.Tensor) -> list[int]:
    """Inverse of ``_int_bits``: (B, nb) LSB-first bits (> 0.5 counts as 1) -> Python ints."""
    hb = (bits > 0.5).long().tolist()
    out = []
    for row in hb:
        v = 0
        for b, bit in enumerate(row):
            if bit:
                v |= 1 << b
        out.append(v)
    return out


# ===========================================================================
# Router
# ===========================================================================

T3_MIN_P = MAX_P + 1         # 252
T3_MAX_P = (1 << 16) - 1
T4_MIN_P = 1 << 16
T4_MAX_P = (1 << 32) - 1
T5_MIN_P = 1 << 32
T5_MAX_P = (1 << 64) - 1

# Per-tier geometry, in the order the *_tier_residues helpers expect:
# (opb, cols, prb, carb, nib, nibbles, t_bits_n).
_T3_GEOM = (T3_MUL_OPB, T3_MUL_COLS, T3_MUL_PRB, T3_MUL_CARB, T3_NIB, T3_RED_NIBBLES, T3_T_BITS)
_T4_GEOM = (T4_MUL_OPB, T4_MUL_COLS, T4_MUL_PRB, T4_MUL_CARB, T4_NIB, T4_RED_NIBBLES, T4_T_BITS)
_T5_GEOM = (T5_MUL_OPB, T5_MUL_COLS, T5_MUL_PRB, T5_MUL_CARB, T5_NIB, T5_RED_NIBBLES, T5_T_BITS)


def _parse_operands(a_enc, b_enc, p_enc):
    """Parse one (a, b, p) triple into ``(a mod p, b mod p, p)``.

    Returns None when any value does not convert to an int (including
    infinite floats, which raise OverflowError) or when p < 2: no tier
    serves such moduli, and p == 0 would otherwise raise ZeroDivisionError.
    """
    try:
        p = int(p_enc)
    except (ValueError, TypeError, OverflowError):
        return None
    if p < 2:
        return None
    # Operand normalization: a with p, then b with p (never all three).
    try:
        return int(a_enc) % p, int(b_enc) % p, p
    except (ValueError, TypeError, OverflowError):
        return None


@torch.no_grad()
def _int_tier_residues(mul, red, xs, ys, ps, opb, cols, prb, carb, nib, nibbles, t_bits_n):
    """Tiers 3-4: composed x*y mod p with operands and p held as int64 tensors."""
    xb = _bits(torch.tensor(xs, dtype=torch.long), opb)
    yb = _bits(torch.tensor(ys, dtype=torch.long), opb)
    p_t = torch.tensor(ps, dtype=torch.long)
    tb = _composed_product_bits(mul, xb, yb, opb, cols, prb, carb)
    r = _composed_reduce_int(red, tb, p_t, nib, nibbles, t_bits_n)
    return [int(v) for v in r.tolist()]


@torch.no_grad()
def _bits_tier_residues(mul, red, xs, ys, ps, opb, cols, prb, carb, nib, nibbles, t_bits_n):
    """Tier 5: composed x*y mod p with operands, p and remainder carried as bit tensors."""
    # 64-bit residues overflow signed int64: carry x, y, p as bit tensors.
    xb = _int_bits(xs, opb)
    yb = _int_bits(ys, opb)
    pb = _int_bits(ps, nibbles * nib)
    tb = _composed_product_bits(mul, xb, yb, opb, cols, prb, carb)
    r_bits = _composed_reduce_bits(red, tb, pb, nib, nibbles, t_bits_n)
    # R < p < 2^opb, so only the low ``opb`` register bits can be set.
    return _bits_to_ints(r_bits[:, :opb])


class ResidueRouterV1(ModularMultiplicationModel):
    """Router over per-tier specialists, selected by the size of p.

    The class name ``ResidueRouterV1`` is kept so the manifest entry_class is
    stable across versions; this is v4 (tiers 1-5).
    """

    def __init__(self):
        self.small: SmallResidueNet | None = None
        self.t3_mul: StepMLP | None = None
        self.t3_red: StepMLP | None = None
        self.t4_mul: StepMLP | None = None
        self.t4_red: StepMLP | None = None
        self.t5_mul: StepMLP | None = None
        self.t5_red: StepMLP | None = None

    def load(self, model_dir: str) -> None:
        """Build and strictly load every specialist whose section is in config.json.

        The tier-1-2 net reads ``small.``-prefixed keys from
        ``weights.safetensors``; each tier's step nets come from
        ``t{3,4,5}_mul.safetensors`` / ``t{3,4,5}_red.safetensors``.
        """
        from safetensors.torch import load_file

        # NOTE: seeds torch's global RNG. Every module parameter built below is
        # then overwritten by a strict state-dict load.
        torch.manual_seed(0)
        model_dir = Path(model_dir)
        config = json.loads((model_dir / "config.json").read_text())
        tensors = load_file(str(model_dir / "weights.safetensors"))

        def step_net(fname, n_in, n_out, width, depth):
            net = StepMLP(n_in, n_out, width, depth)
            net.load_state_dict(load_file(str(model_dir / fname)), strict=True)
            return net.eval()

        if "small" in config:
            net = SmallResidueNet(**config["small"])
            prefix = "small."
            state = {k[len(prefix):]: v for k, v in tensors.items() if k.startswith(prefix)}
            net.load_state_dict(state, strict=True)
            self.small = net.eval()
        # Each pair is built fully before either attribute is assigned.
        if "t3" in config:
            w, d = config["t3"]["width"], config["t3"]["depth"]
            self.t3_mul, self.t3_red = (
                step_net("t3_mul.safetensors", T3_MUL_IN, T3_MUL_CARB, w, d),
                step_net("t3_red.safetensors", T3_RED_IN, T3_RED_OUT, w, d),
            )
        if "t4" in config:
            cfg = config["t4"]
            mw, rw, d = cfg["mul_width"], cfg["red_width"], cfg["depth"]
            self.t4_mul, self.t4_red = (
                step_net("t4_mul.safetensors", T4_MUL_IN, T4_MUL_CARB, mw, d),
                step_net("t4_red.safetensors", T4_RED_IN, T4_RED_OUT, rw, d),
            )
        if "t5" in config:
            cfg = config["t5"]
            mw, rw, d = cfg["mul_width"], cfg["red_width"], cfg["depth"]
            self.t5_mul, self.t5_red = (
                step_net("t5_mul.safetensors", T5_MUL_IN, T5_MUL_CARB, mw, d),
                step_net("t5_red.safetensors", T5_RED_IN, T5_RED_OUT, rw, d),
            )

    def preprocess_a(self, a):
        return a

    def preprocess_b(self, b):
        return b

    def preprocess_p(self, p):
        return p

    @torch.no_grad()
    def predict_digits(self, a_enc, b_enc, p_enc):
        return self.predict_digits_batch([(a_enc, b_enc, p_enc)])[0]

    def _tier_for(self, p: int) -> str | None:
        """Name of the loaded tier that serves modulus ``p``, or None if none does."""
        if self.small is not None and 2 <= p <= MAX_P and int(self.small.prime_lookup[p]) >= 0:
            return "small"
        if self.t3_mul is not None and T3_MIN_P <= p <= T3_MAX_P:
            return "t3"
        if self.t4_mul is not None and T4_MIN_P <= p <= T4_MAX_P:
            return "t4"
        if self.t5_mul is not None and T5_MIN_P <= p <= T5_MAX_P:
            return "t5"
        return None

    @torch.no_grad()
    def predict_digits_batch(self, inputs):
        """Predict a*b mod p for each (a_enc, b_enc, p_enc) triple in ``inputs``.

        Each answer is a list of base-256 digits, MSB-first. Triples that do
        not parse as integers, have p < 2, or whose p no loaded tier serves
        get the honest fallback ``[0]``.
        """
        out: list[list[int] | None] = [None] * len(inputs)
        # tier -> (x residues, y residues, moduli, positions in ``inputs``)
        buckets: dict[str, tuple[list[int], list[int], list[int], list[int]]] = {
            tier: ([], [], [], []) for tier in ("small", "t3", "t4", "t5")
        }
        for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
            parsed = _parse_operands(a_enc, b_enc, p_enc)
            tier = None if parsed is None else self._tier_for(parsed[2])
            if tier is None:
                out[i] = [0]  # malformed input or outside the trained regime
                continue
            for column, value in zip(buckets[tier], (*parsed, i)):
                column.append(value)

        xs, ys, ps, idx = buckets["small"]
        if idx:
            preds = self.small.predict(
                torch.tensor(xs, dtype=torch.long),
                torch.tensor(ys, dtype=torch.long),
                torch.tensor(ps, dtype=torch.long),
            ).tolist()
            for i, pred in zip(idx, preds):
                out[i] = [int(pred)]

        # Tiers 3-4 hold values as int64 tensors; tier 5 must use bit tensors.
        wide_tiers = (
            ("t3", _int_tier_residues, self.t3_mul, self.t3_red, _T3_GEOM),
            ("t4", _int_tier_residues, self.t4_mul, self.t4_red, _T4_GEOM),
            ("t5", _bits_tier_residues, self.t5_mul, self.t5_red, _T5_GEOM),
        )
        for tier, run, mul, red, geom in wide_tiers:
            xs, ys, ps, idx = buckets[tier]
            if not idx:
                continue
            for i, r in zip(idx, run(mul, red, xs, ys, ps, *geom)):
                out[i] = _digits_msb(r)

        return [o if o is not None else [0] for o in out]

    def max_batch_size(self) -> int:
        return 512


def _digits_msb(v: int) -> list[int]:
    """Base-256 digits, MSB-first; at least one digit."""
    if v == 0:
        return [0]
    ds = []
    while v > 0:
        ds.append(v & 0xFF)
        v >>= 8
    return ds[::-1]