rob-rbyte-v4 / model.py
TrickyRex's picture
Upload folder using huggingface_hub
477cc48 verified
Raw
History Blame Contribute Delete
22.2 kB
"""Residue router, version 4: small-prime specialist (tiers 1-2), a lifted
local-step pipeline for tier 3 (16-bit residues), the same two shared rules
lifted to 32-bit limbs for tier 4 (17-32-bit primes), and lifted again to 64-bit
limbs for tier 5 (33-64-bit primes, operands to 128 bits).
Routing by the size of p:
* p <= 251 (tiers 1-2): the v1 residue specialist. Each operand residue is
looked up in a shared per-(prime, residue) table; the two vectors are
combined by ADDITION (a discrete-log inductive bias: logs add under
multiplication); a residual MLP trunk transforms the sum; logits come from a
per-(prime, class) output table masked to the p classes of the current
prime. The answer is a single base-256 digit (p <= 251 < 256).
* 251 < p < 65536 (tier 3): two trained shared LOCAL-RULE step nets composed
through fixed wiring. After reduction, x, y are 16-bit residues. A MULTIPLY
step (the shared carry rule c' = floor((S+c)/2) over the carry-save column
sums, composed closed-loop through a fixed parity readout) emits the exact
32-bit product t = x*y. A REDUCTION step (a shared per-nibble borrow/compare
rule, composed through fixed restoring-division wiring) emits r = t mod p.
The answer r is emitted as base-256 digits MSB-first.
* 65536 <= p < 2^32 (tier 4): the SAME two rules at 32-bit geometry. After
reduction, x, y are 32-bit residues. The MULTIPLY step (33x32-case carry
rule over the 63 carry-save columns, parity readout widened to 64 bits)
emits the 64-bit product as BITS. The REDUCTION step (the identical 512-case
borrow rule) composed over 64 division positions x 9 nibbles emits
r = t mod p. The answer r (< 2^32) is emitted as up to four base-256 digits.
* 2^32 <= p < 2^64 (tier 5): the SAME two rules at 64-bit geometry. After
reduction, x, y are 64-bit residues. Because a 64-bit residue and the 65-bit
division register both overflow signed int64, tier 5 carries operands, p,
the product, and the division register as BIT tensors -- no wide value is
ever materialized as an int64 scalar. The MULTIPLY step (65x64-case carry
rule over the 127 carry-save columns, parity readout widened to 128 bits)
emits the 128-bit product as bits. The REDUCTION step (the identical 512-case
borrow rule) composed over 128 division positions x 17 nibbles emits
r = t mod p in [0, p). The answer r (< 2^64) is emitted as up to eight
base-256 digits MSB-first.
* p >= 2^64 (tiers 6-10): outside the trained regime; returns [0].
Nothing in the forward pass hand-codes the arithmetic over the actual (a, b, p):
the carry-save column sums, the parity readout, the bit shifts, the restoring-
division topology, and the ge-from-final-borrow decision are FIXED scaffold; the
two NONTRIVIAL decisions -- the carry rule and the borrow/compare rule -- live
in trained MLP parameters (separate nets per tier-3 / tier-4 / tier-5 geometry).
Randomizing any step net's weights collapses its tier.
"""
from __future__ import annotations
import json
from pathlib import Path
import torch
import torch.nn as nn
from modchallenge.interface.base_model import ModularMultiplicationModel
# ===========================================================================
# Tier 1-2 specialist (v1 residue net)
# ===========================================================================
# Every prime p <= 251 in ascending order: 54 entries whose sum is 6081.
# SmallResidueNet gives prime number i a run of PRIMES[i] consecutive rows in
# its flat embedding tables, so that sum is also the table size.
PRIMES = (
2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61,
67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137,
139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211,
223, 227, 229, 233, 239, 241, 251,
)
# Largest prime in PRIMES. It sizes the prime lookup table (MAX_P + 1 slots)
# and the class grid (MAX_P columns) inside SmallResidueNet, and is the upper
# bound of the router's small-prime branch.
MAX_P = 251
class SmallResidueNet(nn.Module):
    """Residue specialist for the primes listed in ``PRIMES``.

    Each (prime, residue) pair owns one row in two flat embedding tables; the
    rows of prime number ``i`` start at ``offsets_t[i]`` and cover
    ``PRIMES[i]`` consecutive indices.  The two operand embeddings and a
    per-prime embedding are summed, refined by a residual MLP trunk followed
    by a LayerNorm, and scored against the output-table rows of that prime.
    """

    def __init__(self, d_model: int = 128, hidden: int = 1024):
        super().__init__()
        primes_t = torch.tensor(PRIMES, dtype=torch.long)
        # Exclusive prefix sum: the first flat row owned by each prime.
        offsets_t = torch.cumsum(primes_t, dim=0) - primes_t
        n_rows = int(primes_t.sum())  # 6081

        # Submodules are created in a fixed order and under fixed names so the
        # state-dict keys stay the same.
        self.pair_emb = nn.Embedding(n_rows, d_model)
        self.out_emb = nn.Embedding(n_rows, d_model)
        self.prime_emb = nn.Embedding(len(PRIMES), d_model)
        self.trunk = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )
        self.ln_out = nn.LayerNorm(d_model)

        # prime value -> position in PRIMES; -1 marks values that are not listed.
        prime_lookup = torch.full((MAX_P + 1,), -1, dtype=torch.long)
        prime_lookup[primes_t] = torch.arange(len(PRIMES), dtype=torch.long)

        # Non-persistent buffers: rebuilt from PRIMES, never stored in checkpoints.
        self.register_buffer("primes_t", primes_t, persistent=False)
        self.register_buffer("offsets_t", offsets_t, persistent=False)
        self.register_buffer("prime_lookup", prime_lookup, persistent=False)
        self.register_buffer(
            "class_grid", torch.arange(MAX_P, dtype=torch.long), persistent=False
        )

    def forward(
        self, ix: torch.Tensor, iy: torch.Tensor, p_idx: torch.Tensor
    ) -> torch.Tensor:
        """Return per-class logits with MAX_P columns per row.

        ``ix`` / ``iy`` are flat table rows (prime offset + residue) and
        ``p_idx`` is the position of the prime in ``PRIMES``.  Columns at or
        beyond the prime's value are filled with ``-inf``.
        """
        summed = self.pair_emb(ix) + self.pair_emb(iy) + self.prime_emb(p_idx)
        features = self.ln_out(summed + self.trunk(summed))
        all_scores = features @ self.out_emb.weight.t()

        classes = self.class_grid.unsqueeze(0)
        rows = self.offsets_t[p_idx].unsqueeze(1) + classes
        logits = all_scores.gather(1, rows)

        beyond_prime = classes >= self.primes_t[p_idx].unsqueeze(1)
        return logits.masked_fill(beyond_prime, float("-inf"))

    @torch.no_grad()
    def predict(
        self, x: torch.Tensor, y: torch.Tensor, p: torch.Tensor
    ) -> torch.Tensor:
        """Return the arg-max class for residues ``x``, ``y`` under primes ``p``.

        ``p`` holds prime VALUES; they are translated through ``prime_lookup``.
        Callers are expected to pass only values listed in ``PRIMES`` (the
        lookup yields -1 for anything else, which is not checked here).
        """
        p_idx = self.prime_lookup[p]
        base = self.offsets_t[p_idx]
        logits = self.forward(base + x, base + y, p_idx)
        return logits.argmax(dim=-1)
# ===========================================================================
# Shared step-net architecture (used by tier-3 / tier-4 / tier-5 geometries)
# ===========================================================================
class StepMLP(nn.Module):
    """GELU multilayer perceptron used as a local step rule.

    Maps ``n_in`` input features to ``n_out`` logits through hidden Linear
    layers of size ``width`` (one input layer plus ``depth - 1`` further
    layers), each followed by GELU, and a final linear head.
    """

    def __init__(self, n_in: int, n_out: int, width: int, depth: int):
        super().__init__()
        # Fan-in of every hidden layer: the input layer, then width -> width.
        fan_ins = [n_in] + [width] * (depth - 1)
        self.layers = nn.ModuleList([nn.Linear(fan_in, width) for fan_in in fan_ins])
        self.head = nn.Linear(width, n_out)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every hidden layer with GELU, then the linear head."""
        hidden = x
        for layer in self.layers:
            hidden = self.act(layer(hidden))
        return self.head(hidden)
# ===========================================================================
# Per-tier multiply / reduction geometries
# ===========================================================================
# Naming used by all three geometries below:
#   *_MUL_OPB   operand width in bits          *_MUL_PRB   product width in bits
#   *_MUL_COLS  carry-save columns (2*OPB - 1)
#   *_MUL_SUMB  bits used to encode one column sum
#   *_MUL_CARB  bits used to encode a carry (also the multiply step's output width)
#   *_MUL_IN    multiply step input width (SUMB + CARB)
#   *_NIB       bits per reduction digit       *_RED_NIBBLES  digits per subtraction
#   *_RED_IN    reduction step input width (a-digit bits + b-digit bits + borrow-in)
#   *_RED_OUT   reduction step output width (difference-digit bits + borrow-out)
#   *_T_BITS    number of division positions (one per product bit)
# Tier 3 (16x16 -> 32-bit; 5-nibble reduction)
T3_MUL_OPB = 16
T3_MUL_PRB = 32
T3_MUL_COLS = 2 * T3_MUL_OPB - 1 # 31
T3_MUL_SUMB = 5
T3_MUL_CARB = 4
T3_MUL_IN = T3_MUL_SUMB + T3_MUL_CARB # 9
T3_NIB = 4
T3_RED_NIBBLES = 5
T3_RED_IN = T3_NIB + T3_NIB + 1 # 9
T3_RED_OUT = T3_NIB + 1 # 5
T3_T_BITS = 32
# Tier 4 (32x32 -> 64-bit; 9-nibble reduction)
T4_MUL_OPB = 32
T4_MUL_PRB = 64
T4_MUL_COLS = 2 * T4_MUL_OPB - 1 # 63
T4_MUL_SUMB = 6
T4_MUL_CARB = 5
T4_MUL_IN = T4_MUL_SUMB + T4_MUL_CARB # 11
T4_NIB = 4
T4_RED_NIBBLES = 9
T4_RED_IN = T4_NIB + T4_NIB + 1 # 9
T4_RED_OUT = T4_NIB + 1 # 5
T4_T_BITS = 64
# Tier 5 (64x64 -> 128-bit; 17-nibble reduction). Wide values are bit tensors.
T5_MUL_OPB = 64
T5_MUL_PRB = 128
T5_MUL_COLS = 2 * T5_MUL_OPB - 1 # 127
T5_MUL_SUMB = 7 # S_c <= 64
T5_MUL_CARB = 6 # carry <= 63
T5_MUL_IN = T5_MUL_SUMB + T5_MUL_CARB # 13
T5_NIB = 4
T5_RED_NIBBLES = 17 # 65-bit R_pre / 64-bit p
T5_RED_IN = T5_NIB + T5_NIB + 1 # 9
T5_RED_OUT = T5_NIB + 1 # 5
T5_T_BITS = 128
# -- generic carry-save / reduction wiring (parameterized by geometry) -------
def _bits(v: torch.Tensor, nb: int) -> torch.Tensor:
return ((v.unsqueeze(1) >> torch.arange(nb, device=v.device)) & 1).float()
def _column_sums(x_bits: torch.Tensor, y_bits: torch.Tensor, opb: int, cols: int) -> torch.Tensor:
n = x_bits.shape[0]
s = torch.zeros(n, cols, dtype=x_bits.dtype, device=x_bits.device)
for i in range(opb):
s[:, i:i + opb] += x_bits[:, i:i + 1] * y_bits
return s
def _encode_carry(s: torch.Tensor, c: torch.Tensor, sumb: int, carb: int) -> torch.Tensor:
si = torch.arange(sumb, device=s.device)
ci = torch.arange(carb, device=c.device)
sb = ((s.unsqueeze(1) >> si) & 1).float()
cb = ((c.unsqueeze(1) >> ci) & 1).float()
return torch.cat([sb, cb], dim=1)
def _carry_bits_to_int(bits: torch.Tensor, carb: int) -> torch.Tensor:
w = (1 << torch.arange(carb, device=bits.device)).long()
return (bits.round().clamp(0, 1).long() * w).sum(dim=-1)
# Column count of a multiply geometry -> number of bits used to encode one
# column sum. _closed_loop_mul receives only `cols`, so it looks the width up
# here; a column count outside the three known geometries raises KeyError.
_SUMB_FOR_COLS = {T3_MUL_COLS: T3_MUL_SUMB, T4_MUL_COLS: T4_MUL_SUMB, T5_MUL_COLS: T5_MUL_SUMB}
@torch.no_grad()
def _closed_loop_mul(step, col_sums, cols, carb):
    """Run the carry step column by column, feeding back its own hard output.

    For each column, ``step`` receives the encoded (column sum, incoming
    carry) pair and returns ``carb`` logits.  Those logits are stored in the
    result, and their sign pattern, decoded as an integer, becomes the carry
    into the next column.  Returns a tensor with ``cols * carb`` columns, the
    logits of column ``c`` occupying columns ``carb * c`` to
    ``carb * (c + 1) - 1``.
    """
    sums = col_sums.long()
    batch = col_sums.shape[0]
    sum_width = _SUMB_FOR_COLS[cols]
    all_logits = torch.empty(batch, cols * carb, device=col_sums.device)
    carry_in = torch.zeros(batch, dtype=torch.long, device=sums.device)
    for col in range(cols):
        features = _encode_carry(sums[:, col], carry_in, sum_width, carb)
        col_logits = step(features)
        start = carb * col
        all_logits[:, start:start + carb] = col_logits
        # Closed loop: the thresholded prediction is the next column's carry.
        carry_in = _carry_bits_to_int((col_logits > 0).float(), carb)
    return all_logits
def _routed_product_bits(carry_logits, col_parity, carb):
"""Fixed parity readout: bit_c = parity(S_c) XOR lsb(carry into c)."""
BIG = 20.0
lsb = carry_logits[:, 0::carb]
bit0 = (2.0 * col_parity[:, 0:1] - 1.0) * BIG
mid = (1.0 - 2.0 * col_parity[:, 1:]) * lsb[:, :-1]
bit_last = lsb[:, -1:]
return torch.cat([bit0, mid, bit_last], dim=1)
@torch.no_grad()
def _composed_product_bits(step, x_bits, y_bits, opb, cols, prb, carb):
    """Trained carry step (closed loop) + parity readout -> product bits.

    ``x_bits`` and ``y_bits`` are LSB-first operand bit tensors with ``opb``
    columns.  Returns an integer 0/1 tensor, LSB first, with ``cols + 1``
    columns.  ``prb`` names the product width for the caller's benefit and is
    not read here.
    """
    sums = _column_sums(x_bits, y_bits, opb, cols)
    carry_logits = _closed_loop_mul(step, sums, cols, carb)
    parity = torch.bitwise_and(sums.long(), 1).float()
    bit_logits = _routed_product_bits(carry_logits, parity, carb)
    return bit_logits.gt(0).long()
def _encode_red(a, b, bin_, nib):
ai = torch.arange(nib, device=a.device)
aa = ((a.unsqueeze(1) >> ai) & 1).float()
bb = ((b.unsqueeze(1) >> ai) & 1).float()
cc = bin_.float().unsqueeze(1)
return torch.cat([aa, bb, cc], dim=1)
def _red_bits_to_out(bits, nib):
hb = (bits > 0).long()
w = (1 << torch.arange(nib, device=bits.device)).long()
d = (hb[:, :nib] * w).sum(dim=1)
bout = hb[:, nib]
return d, bout
@torch.no_grad()
def _composed_reduce_int(step, t_bits, p, nib, nibbles, t_bits_n):
    """Restoring division by p when p and R fit signed int64 (tiers 3-4).

    Fixed wiring; the per-digit subtract DECISION is the trained ``step``.
    When ``step`` reproduces the exact borrow rule, the running remainder R
    stays in [0, p), so R and R_pre = 2R + bit never overflow int64 even
    though the full product would.

    Args:
        step: callable mapping the ``_encode_red`` features of one digit
            position to ``nib + 1`` logits (difference bits, then borrow-out).
        t_bits: dividend bits, LSB first, one row per sample.
        p: integer tensor of moduli, one per sample.
        nib: bits per digit.
        nibbles: digits processed per trial subtraction.
        t_bits_n: number of dividend bits to consume, from bit
            ``t_bits_n - 1`` down to bit 0.

    Returns:
        A ``torch.long`` tensor holding one remainder per sample.
    """
    n = t_bits.shape[0]
    device = t_bits.device
    # The digit mask follows `nib`, like the shifts and place values below; a
    # literal 0xF would silently truncate digits for any nib > 4.
    nib_mask = (1 << nib) - 1
    shifts = [nib * k for k in range(nibbles)]
    p_nib = torch.stack([(p >> s) & nib_mask for s in shifts], dim=1)
    wk = (1 << (nib * torch.arange(nibbles, device=device))).long()
    R = torch.zeros(n, dtype=torch.long, device=device)
    for i in range(t_bits_n - 1, -1, -1):
        # Shift the remainder left and bring down the next dividend bit.
        Rpre = (R << 1) | t_bits[:, i].long()
        borrow = torch.zeros(n, dtype=torch.long, device=device)
        diff_nib = torch.zeros(n, nibbles, dtype=torch.long, device=device)
        # Trial subtraction R_pre - p, one digit at a time, borrow rippling up.
        for k, s in enumerate(shifts):
            an = (Rpre >> s) & nib_mask
            lg = step(_encode_red(an, p_nib[:, k], borrow, nib))
            d, borrow = _red_bits_to_out(lg, nib)
            diff_nib[:, k] = d
        # No borrow out of the top digit means R_pre >= p: keep the difference.
        ge = borrow == 0
        diff_val = (diff_nib * wk).sum(dim=1)
        R = torch.where(ge, diff_val, Rpre)
    return R
def _p_nibbles_from_bits(p_bits, nib, nibbles):
n = p_bits.shape[0]
out = torch.zeros(n, nibbles, dtype=torch.long, device=p_bits.device)
wt = (1 << torch.arange(nib, device=p_bits.device)).long()
for k in range(nibbles):
chunk = p_bits[:, nib * k:nib * (k + 1)]
if chunk.shape[1] < nib:
pad = torch.zeros(n, nib - chunk.shape[1], device=p_bits.device)
chunk = torch.cat([chunk, pad], dim=1)
out[:, k] = (chunk.long() * wt).sum(dim=1)
return out
@torch.no_grad()
def _composed_reduce_bits(step, t_bits, p_bits, nib, nibbles, t_bits_n):
    """Restoring division by p with R, R_pre and p carried as BIT tensors (tier 5).

    The remainder register is ``nibbles * nib`` bits wide and is never packed
    into a single integer, so nothing overflows int64: only individual digits
    are materialised.  For every dividend bit, from position ``t_bits_n - 1``
    down to 0, the register is shifted left, the bit is brought in, and the
    trained ``step`` performs a digit-wise trial subtraction of p; the
    difference is kept when the final borrow is zero.  Returns the remainder
    as float bits, LSB first, with ``nibbles * nib`` columns.
    """
    batch = t_bits.shape[0]
    dev = t_bits.device
    reg_width = nib * nibbles
    divisor_digits = _p_nibbles_from_bits(p_bits, nib, nibbles)
    place_values = (1 << torch.arange(nib, device=dev)).long()
    positions = torch.arange(nib, device=dev)
    remainder = torch.zeros(batch, reg_width, dtype=torch.long, device=dev)
    for pos in reversed(range(t_bits_n)):
        # R_pre = 2R + bit: shift up by one column, new bit enters at column 0.
        shifted = torch.zeros(batch, reg_width, dtype=torch.long, device=dev)
        shifted[:, 1:] = remainder[:, :-1]
        shifted[:, 0] = t_bits[:, pos].long()

        borrow = torch.zeros(batch, dtype=torch.long, device=dev)
        diff_bits = torch.zeros(batch, reg_width, dtype=torch.long, device=dev)
        for k in range(nibbles):
            lo, hi = nib * k, nib * (k + 1)
            a_digit = (shifted[:, lo:hi].long() * place_values).sum(dim=1)
            logits = step(_encode_red(a_digit, divisor_digits[:, k], borrow, nib))
            digit, borrow = _red_bits_to_out(logits, nib)
            # Unpack the difference digit straight back into register bits.
            diff_bits[:, lo:hi] = (digit.unsqueeze(1) >> positions) & 1

        # Zero final borrow means R_pre >= p, so the subtraction is accepted.
        accept = (borrow == 0).unsqueeze(1)
        remainder = torch.where(accept, diff_bits, shifted)
    return remainder.float()
# -- bigint <-> bit-tensor helpers (tier 5: residues exceed signed int64) ----
def _int_bits(values, nb: int) -> torch.Tensor:
out = torch.zeros(len(values), nb, dtype=torch.float32)
for r, v in enumerate(values):
v = int(v)
b = 0
while v and b < nb:
out[r, b] = float(v & 1)
v >>= 1
b += 1
return out
def _bits_to_ints(bits: torch.Tensor) -> list[int]:
hb = (bits > 0.5).long().tolist()
out = []
for row in hb:
v = 0
for b, bit in enumerate(row):
if bit:
v |= (1 << b)
out.append(v)
return out
# ===========================================================================
# Router
# ===========================================================================
# Inclusive modulus ranges served by each lifted tier. The router compares
# with `T?_MIN_P <= p <= T?_MAX_P`; the ranges are contiguous and start right
# after the small-prime range.
T3_MIN_P = MAX_P + 1 # 252
T3_MAX_P = (1 << 16) - 1
T4_MIN_P = 1 << 16
T4_MAX_P = (1 << 32) - 1
T5_MIN_P = 1 << 32
T5_MAX_P = (1 << 64) - 1
class ResidueRouterV1(ModularMultiplicationModel):
    """Router over per-tier specialists, selected by the size of p.

    Kept the class name ``ResidueRouterV1`` so the manifest entry_class is
    stable across versions; this is v4 (tiers 1-5).
    """

    def __init__(self):
        # Each specialist stays None until load() finds its section in
        # config.json; requests for a tier that was not loaded get [0].
        self.small: SmallResidueNet | None = None
        self.t3_mul: StepMLP | None = None
        self.t3_red: StepMLP | None = None
        self.t4_mul: StepMLP | None = None
        self.t4_red: StepMLP | None = None
        self.t5_mul: StepMLP | None = None
        self.t5_red: StepMLP | None = None

    def load(self, model_dir: str) -> None:
        """Build and load every specialist described in ``config.json``.

        Reads ``config.json`` and ``weights.safetensors`` from ``model_dir``.
        The ``small`` section is restored from the ``small.``-prefixed entries
        of ``weights.safetensors``; each ``t3`` / ``t4`` / ``t5`` section is
        restored from its own ``t?_mul.safetensors`` / ``t?_red.safetensors``
        pair.  All state dicts are loaded strictly and the nets are put in
        eval mode.
        """
        from safetensors.torch import load_file

        torch.manual_seed(0)
        model_dir = Path(model_dir)
        config = json.loads((model_dir / "config.json").read_text())
        tensors = load_file(str(model_dir / "weights.safetensors"))

        if "small" in config:
            net = SmallResidueNet(**config["small"])
            state = {k[len("small."):]: v for k, v in tensors.items() if k.startswith("small.")}
            net.load_state_dict(state, strict=True)
            net.eval()
            self.small = net

        if "t3" in config:
            # Tier 3 uses one shared width for both step nets.
            w, d = config["t3"]["width"], config["t3"]["depth"]
            mul = StepMLP(T3_MUL_IN, T3_MUL_CARB, w, d)
            red = StepMLP(T3_RED_IN, T3_RED_OUT, w, d)
            mul.load_state_dict(load_file(str(model_dir / "t3_mul.safetensors")), strict=True)
            red.load_state_dict(load_file(str(model_dir / "t3_red.safetensors")), strict=True)
            mul.eval()
            red.eval()
            self.t3_mul, self.t3_red = mul, red

        if "t4" in config:
            mw, rw, d = config["t4"]["mul_width"], config["t4"]["red_width"], config["t4"]["depth"]
            mul = StepMLP(T4_MUL_IN, T4_MUL_CARB, mw, d)
            red = StepMLP(T4_RED_IN, T4_RED_OUT, rw, d)
            mul.load_state_dict(load_file(str(model_dir / "t4_mul.safetensors")), strict=True)
            red.load_state_dict(load_file(str(model_dir / "t4_red.safetensors")), strict=True)
            mul.eval()
            red.eval()
            self.t4_mul, self.t4_red = mul, red

        if "t5" in config:
            mw, rw, d = config["t5"]["mul_width"], config["t5"]["red_width"], config["t5"]["depth"]
            mul = StepMLP(T5_MUL_IN, T5_MUL_CARB, mw, d)
            red = StepMLP(T5_RED_IN, T5_RED_OUT, rw, d)
            mul.load_state_dict(load_file(str(model_dir / "t5_mul.safetensors")), strict=True)
            red.load_state_dict(load_file(str(model_dir / "t5_red.safetensors")), strict=True)
            mul.eval()
            red.eval()
            self.t5_mul, self.t5_red = mul, red

    def preprocess_a(self, a):
        """Return ``a`` unchanged; parsing happens in predict_digits_batch."""
        return a

    def preprocess_b(self, b):
        """Return ``b`` unchanged; parsing happens in predict_digits_batch."""
        return b

    def preprocess_p(self, p):
        """Return ``p`` unchanged; parsing happens in predict_digits_batch."""
        return p

    @torch.no_grad()
    def predict_digits(self, a_enc, b_enc, p_enc):
        """Predict one sample by delegating to the batch path."""
        return self.predict_digits_batch([(a_enc, b_enc, p_enc)])[0]

    @torch.no_grad()
    def predict_digits_batch(self, inputs):
        """Predict base-256 digits (MSB first) for a batch of (a, b, p) triples.

        Every triple is parsed with ``int()``, both operands are reduced
        modulo p, and the sample is routed to the specialist whose range
        contains p.  A sample gets the fallback answer ``[0]`` when it cannot
        be parsed, when p is not positive, or when no loaded specialist covers
        p.  Results are returned in input order.
        """
        out: list[list[int] | None] = [None] * len(inputs)
        s_x, s_y, s_p, s_idx = [], [], [], []  # tier 1-2
        t3_x, t3_y, t3_p, t3_idx = [], [], [], []  # tier 3
        t4_x, t4_y, t4_p, t4_idx = [], [], [], []  # tier 4
        t5_x, t5_y, t5_p, t5_idx = [], [], [], []  # tier 5

        for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
            try:
                p = int(p_enc)
                a = int(a_enc)
                b = int(b_enc)
            except (ValueError, TypeError, OverflowError):
                # OverflowError: int() of an infinite float.
                out[i] = [0]
                continue
            if p <= 0:
                # p == 0 would raise ZeroDivisionError in the reduction below
                # and abort the whole batch; negative p matches no tier anyway.
                out[i] = [0]
                continue
            # Operand normalization: a with p, then b with p (never all three).
            xr = a % p
            yr = b % p

            if self.small is not None and 2 <= p <= MAX_P and int(self.small.prime_lookup[p]) >= 0:
                s_x.append(xr); s_y.append(yr); s_p.append(p); s_idx.append(i)
            elif self.t3_mul is not None and T3_MIN_P <= p <= T3_MAX_P:
                t3_x.append(xr); t3_y.append(yr); t3_p.append(p); t3_idx.append(i)
            elif self.t4_mul is not None and T4_MIN_P <= p <= T4_MAX_P:
                t4_x.append(xr); t4_y.append(yr); t4_p.append(p); t4_idx.append(i)
            elif self.t5_mul is not None and T5_MIN_P <= p <= T5_MAX_P:
                t5_x.append(xr); t5_y.append(yr); t5_p.append(p); t5_idx.append(i)
            else:
                out[i] = [0]  # outside the trained regime -> honest fallback

        if s_idx:
            preds = self.small.predict(
                torch.tensor(s_x, dtype=torch.long),
                torch.tensor(s_y, dtype=torch.long),
                torch.tensor(s_p, dtype=torch.long),
            ).tolist()
            for j, i in enumerate(s_idx):
                out[i] = [int(preds[j])]

        if t3_idx:
            xb = _bits(torch.tensor(t3_x, dtype=torch.long), T3_MUL_OPB)
            yb = _bits(torch.tensor(t3_y, dtype=torch.long), T3_MUL_OPB)
            p_t = torch.tensor(t3_p, dtype=torch.long)
            tb = _composed_product_bits(self.t3_mul, xb, yb, T3_MUL_OPB, T3_MUL_COLS,
                                        T3_MUL_PRB, T3_MUL_CARB)
            r = _composed_reduce_int(self.t3_red, tb, p_t, T3_NIB, T3_RED_NIBBLES, T3_T_BITS)
            for j, i in enumerate(t3_idx):
                out[i] = _digits_msb(int(r[j].item()))

        if t4_idx:
            xb = _bits(torch.tensor(t4_x, dtype=torch.long), T4_MUL_OPB)
            yb = _bits(torch.tensor(t4_y, dtype=torch.long), T4_MUL_OPB)
            p_t = torch.tensor(t4_p, dtype=torch.long)
            tb = _composed_product_bits(self.t4_mul, xb, yb, T4_MUL_OPB, T4_MUL_COLS,
                                        T4_MUL_PRB, T4_MUL_CARB)
            r = _composed_reduce_int(self.t4_red, tb, p_t, T4_NIB, T4_RED_NIBBLES, T4_T_BITS)
            for j, i in enumerate(t4_idx):
                out[i] = _digits_msb(int(r[j].item()))

        if t5_idx:
            # 64-bit residues overflow signed int64: carry x, y, p as bit tensors.
            xb = _int_bits(t5_x, T5_MUL_OPB)
            yb = _int_bits(t5_y, T5_MUL_OPB)
            pb = _int_bits(t5_p, T5_RED_NIBBLES * T5_NIB)
            tb = _composed_product_bits(self.t5_mul, xb, yb, T5_MUL_OPB, T5_MUL_COLS,
                                        T5_MUL_PRB, T5_MUL_CARB)
            r_bits = _composed_reduce_bits(self.t5_red, tb, pb, T5_NIB, T5_RED_NIBBLES, T5_T_BITS)
            # Only the low 64 register bits are read back as the answer.
            r_vals = _bits_to_ints(r_bits[:, :T5_MUL_OPB])
            for j, i in enumerate(t5_idx):
                out[i] = _digits_msb(r_vals[j])

        return [o if o is not None else [0] for o in out]

    def max_batch_size(self) -> int:
        """Return the batch size limit reported by this model (512)."""
        return 512
def _digits_msb(v: int) -> list[int]:
"""Base-256 digits, MSB-first; at least one digit."""
if v == 0:
return [0]
ds = []
while v > 0:
ds.append(v & 0xFF)
v >>= 8
return ds[::-1]