| """Residue router, version 4: small-prime specialist (tiers 1-2), a lifted |
| local-step pipeline for tier 3 (16-bit residues), the same two shared rules |
| lifted to 32-bit limbs for tier 4 (17-32-bit primes), and lifted again to 64-bit |
| limbs for tier 5 (33-64-bit primes, operands to 128 bits). |
| |
| Routing by the size of p: |
| |
| * p <= 251 (tiers 1-2): the v1 residue specialist. Each operand residue is |
| looked up in a shared per-(prime, residue) table; the two vectors are |
| combined by ADDITION (a discrete-log inductive bias: logs add under |
| multiplication); a residual MLP trunk transforms the sum; logits come from a |
| per-(prime, class) output table masked to the p classes of the current |
| prime. The answer is a single base-256 digit (p <= 251 < 256). |
| |
| * 251 < p < 65536 (tier 3): two trained shared LOCAL-RULE step nets composed |
| through fixed wiring. After reduction, x, y are 16-bit residues. A MULTIPLY |
| step (the shared carry rule c' = floor((S+c)/2) over the carry-save column |
| sums, composed closed-loop through a fixed parity readout) emits the exact |
| 32-bit product t = x*y. A REDUCTION step (a shared per-nibble borrow/compare |
| rule, composed through fixed restoring-division wiring) emits r = t mod p. |
| The answer r is emitted as base-256 digits MSB-first. |
| |
| * 65536 <= p < 2^32 (tier 4): the SAME two rules at 32-bit geometry. After |
| reduction, x, y are 32-bit residues. The MULTIPLY step (33x32-case carry |
| rule over the 63 carry-save columns, parity readout widened to 64 bits) |
| emits the 64-bit product as BITS. The REDUCTION step (the identical 512-case |
| borrow rule) composed over 64 division positions x 9 nibbles emits |
| r = t mod p. The answer r (< 2^32) is emitted as up to four base-256 digits. |
| |
| * 2^32 <= p < 2^64 (tier 5): the SAME two rules at 64-bit geometry. After |
| reduction, x, y are 64-bit residues. Because a 64-bit residue and the 65-bit |
| division register both overflow signed int64, tier 5 carries operands, p, |
| the product, and the division register as BIT tensors -- no wide value is |
| ever materialized as an int64 scalar. The MULTIPLY step (65x64-case carry |
| rule over the 127 carry-save columns, parity readout widened to 128 bits) |
| emits the 128-bit product as bits. The REDUCTION step (the identical 512-case |
| borrow rule) composed over 128 division positions x 17 nibbles emits |
| r = t mod p in [0, p). The answer r (< 2^64) is emitted as up to eight |
| base-256 digits MSB-first. |
| |
| * p >= 2^64 (tiers 6-10): outside the trained regime; returns [0]. |
| |
| Nothing in the forward pass hand-codes the arithmetic over the actual (a, b, p): |
| the carry-save column sums, the parity readout, the bit shifts, the restoring- |
| division topology, and the ge-from-final-borrow decision are FIXED scaffold; the |
| two NONTRIVIAL decisions -- the carry rule and the borrow/compare rule -- live |
| in trained MLP parameters (separate nets per tier-3 / tier-4 / tier-5 geometry). |
| Randomizing any step net's weights collapses its tier. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from modchallenge.interface.base_model import ModularMultiplicationModel |
|
|
| |
| |
| |
|
|
# Every prime p <= 251 in increasing order: the moduli served by the
# small-prime specialist (SmallResidueNet).
PRIMES = (
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29,
    31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
    73, 79, 83, 89, 97, 101, 103, 107, 109, 113,
    127, 131, 137, 139, 149, 151, 157, 163, 167, 173,
    179, 181, 191, 193, 197, 199, 211, 223, 227, 229,
    233, 239, 241, 251,
)
# Largest modulus in the table (the tuple is sorted, so it is the last entry).
MAX_P = PRIMES[-1]
|
|
|
|
class SmallResidueNet(nn.Module):
    """Residue specialist for the moduli listed in ``PRIMES`` (p <= MAX_P).

    Operand residues and output classes are rows of shared embedding tables
    laid out prime by prime: residue ``r`` of the prime at index ``k`` lives
    in row ``offsets_t[k] + r``, where ``offsets_t[k]`` is the sum of all
    earlier primes.
    """

    def __init__(self, d_model: int = 128, hidden: int = 1024):
        super().__init__()
        n_primes = len(PRIMES)
        # First table row of each prime's slice, and the total row count.
        starts = [sum(PRIMES[:k]) for k in range(n_primes)]
        n_rows = sum(PRIMES)

        self.pair_emb = nn.Embedding(n_rows, d_model)
        self.out_emb = nn.Embedding(n_rows, d_model)
        self.prime_emb = nn.Embedding(n_primes, d_model)
        self.trunk = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )
        self.ln_out = nn.LayerNorm(d_model)

        # Index helpers are rebuilt from PRIMES on construction, so they are
        # registered as non-persistent buffers (absent from the state dict).
        prime_values = torch.tensor(PRIMES, dtype=torch.long)
        self.register_buffer("primes_t", prime_values, persistent=False)
        self.register_buffer(
            "offsets_t", torch.tensor(starts, dtype=torch.long), persistent=False
        )
        # prime value -> index into PRIMES; -1 marks values not in the table.
        index_of = torch.full((MAX_P + 1,), -1, dtype=torch.long)
        index_of[prime_values] = torch.arange(n_primes, dtype=torch.long)
        self.register_buffer("prime_lookup", index_of, persistent=False)
        self.register_buffer(
            "class_grid", torch.arange(MAX_P, dtype=torch.long), persistent=False
        )

    def forward(
        self, ix: torch.Tensor, iy: torch.Tensor, p_idx: torch.Tensor
    ) -> torch.Tensor:
        """Score every candidate residue class for each query.

        ``ix`` / ``iy`` are table rows (offset + residue) and ``p_idx`` the
        index of the prime in ``PRIMES``.  Returns one row of ``MAX_P`` logits
        per query; classes at or beyond the query's prime are ``-inf``.
        """
        summed = self.pair_emb(ix) + self.pair_emb(iy) + self.prime_emb(p_idx)
        feats = self.ln_out(summed + self.trunk(summed))
        classes = self.class_grid.unsqueeze(0)
        # Output-table rows holding this prime's classes 0 .. MAX_P - 1.
        rows = self.offsets_t[p_idx].unsqueeze(1) + classes
        scores = feats @ self.out_emb.weight.t()
        picked = scores.gather(1, rows)
        beyond_p = ~(classes < self.primes_t[p_idx].unsqueeze(1))
        return picked.masked_fill(beyond_p, float("-inf"))

    @torch.no_grad()
    def predict(
        self, x: torch.Tensor, y: torch.Tensor, p: torch.Tensor
    ) -> torch.Tensor:
        """Return the argmax residue class for residues ``x``, ``y`` modulo ``p``.

        ``p`` is looked up in ``prime_lookup``; callers are expected to pass
        only values that appear in ``PRIMES``.
        """
        prime_index = self.prime_lookup[p]
        base = self.offsets_t[prime_index]
        logits = self.forward(base + x, base + y, prime_index)
        return logits.argmax(dim=-1)
|
|
|
|
| |
| |
| |
|
|
class StepMLP(nn.Module):
    """Plain GELU MLP step: n_in local-state bits -> n_out logits.

    ``depth`` counts the hidden ``Linear`` layers (each followed by GELU);
    a final ``Linear`` head without activation produces the logits.
    """

    def __init__(self, n_in: int, n_out: int, width: int, depth: int):
        super().__init__()
        # Input projection first, then depth - 1 square hidden layers.
        stack = [nn.Linear(n_in, width)]
        stack.extend(nn.Linear(width, width) for _ in range(depth - 1))
        self.layers = nn.ModuleList(stack)
        self.head = nn.Linear(width, n_out)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every hidden layer with GELU, then the linear head."""
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden)
            hidden = self.act(hidden)
        return self.head(hidden)
|
|
|
|
| |
| |
| |
|
|
| |
# Tier 3 geometry: 16-bit operands, 32-bit product.
T3_MUL_OPB = 16                         # bits per operand
T3_MUL_PRB = 2 * T3_MUL_OPB             # bits in the product
T3_MUL_COLS = T3_MUL_PRB - 1            # carry-save columns
T3_MUL_SUMB = 5                         # bits encoding one column sum
T3_MUL_CARB = 4                         # bits encoding one carry
T3_MUL_IN = T3_MUL_SUMB + T3_MUL_CARB   # inputs of the carry step net
T3_NIB = 4                              # bits per nibble
T3_RED_NIBBLES = 5                      # nibbles per trial subtraction
T3_RED_IN = 2 * T3_NIB + 1              # two nibbles + borrow-in
T3_RED_OUT = T3_NIB + 1                 # difference nibble + borrow-out
T3_T_BITS = T3_MUL_PRB                  # product bits fed to the division


# Tier 4 geometry: 32-bit operands, 64-bit product.
T4_MUL_OPB = 32
T4_MUL_PRB = 2 * T4_MUL_OPB
T4_MUL_COLS = T4_MUL_PRB - 1
T4_MUL_SUMB = 6
T4_MUL_CARB = 5
T4_MUL_IN = T4_MUL_SUMB + T4_MUL_CARB
T4_NIB = 4
T4_RED_NIBBLES = 9
T4_RED_IN = 2 * T4_NIB + 1
T4_RED_OUT = T4_NIB + 1
T4_T_BITS = T4_MUL_PRB


# Tier 5 geometry: 64-bit operands, 128-bit product.
T5_MUL_OPB = 64
T5_MUL_PRB = 2 * T5_MUL_OPB
T5_MUL_COLS = T5_MUL_PRB - 1
T5_MUL_SUMB = 7
T5_MUL_CARB = 6
T5_MUL_IN = T5_MUL_SUMB + T5_MUL_CARB
T5_NIB = 4
T5_RED_NIBBLES = 17
T5_RED_IN = 2 * T5_NIB + 1
T5_RED_OUT = T5_NIB + 1
T5_T_BITS = T5_MUL_PRB
|
|
|
|
| |
|
|
| def _bits(v: torch.Tensor, nb: int) -> torch.Tensor: |
| return ((v.unsqueeze(1) >> torch.arange(nb, device=v.device)) & 1).float() |
|
|
|
|
| def _column_sums(x_bits: torch.Tensor, y_bits: torch.Tensor, opb: int, cols: int) -> torch.Tensor: |
| n = x_bits.shape[0] |
| s = torch.zeros(n, cols, dtype=x_bits.dtype, device=x_bits.device) |
| for i in range(opb): |
| s[:, i:i + opb] += x_bits[:, i:i + 1] * y_bits |
| return s |
|
|
|
|
| def _encode_carry(s: torch.Tensor, c: torch.Tensor, sumb: int, carb: int) -> torch.Tensor: |
| si = torch.arange(sumb, device=s.device) |
| ci = torch.arange(carb, device=c.device) |
| sb = ((s.unsqueeze(1) >> si) & 1).float() |
| cb = ((c.unsqueeze(1) >> ci) & 1).float() |
| return torch.cat([sb, cb], dim=1) |
|
|
|
|
| def _carry_bits_to_int(bits: torch.Tensor, carb: int) -> torch.Tensor: |
| w = (1 << torch.arange(carb, device=bits.device)).long() |
| return (bits.round().clamp(0, 1).long() * w).sum(dim=-1) |
|
|
|
|
# Multiplier column count of each tier -> bit width used to encode one column
# sum; _closed_loop_mul looks the width up from its `cols` argument.
_SUMB_FOR_COLS = {
    T3_MUL_COLS: T3_MUL_SUMB,
    T4_MUL_COLS: T4_MUL_SUMB,
    T5_MUL_COLS: T5_MUL_SUMB,
}
|
|
|
|
@torch.no_grad()
def _closed_loop_mul(step, col_sums, cols, carb):
    """Run the carry step net column by column, feeding back its own carry.

    ``col_sums`` is ``(B, cols)``; the carry into column 0 is zero.  For each
    column the net sees the encoded (column sum, incoming carry) pair; its
    thresholded output (logit > 0) becomes the carry into the next column.

    Returns the raw logits as a ``(B, cols * carb)`` tensor, ``carb``
    consecutive entries per column.
    """
    batch = col_sums.shape[0]
    sums = col_sums.long()
    sum_width = _SUMB_FOR_COLS[cols]
    all_logits = torch.empty(batch, cols * carb, device=col_sums.device)
    carry_in = torch.zeros(batch, dtype=torch.long, device=sums.device)
    for col in range(cols):
        start = carb * col
        logits = step(_encode_carry(sums[:, col], carry_in, sum_width, carb))
        all_logits[:, start:start + carb] = logits
        hard_bits = (logits > 0).float()
        carry_in = _carry_bits_to_int(hard_bits, carb)
    return all_logits
|
|
|
|
| def _routed_product_bits(carry_logits, col_parity, carb): |
| """Fixed parity readout: bit_c = parity(S_c) XOR lsb(carry into c).""" |
| BIG = 20.0 |
| lsb = carry_logits[:, 0::carb] |
| bit0 = (2.0 * col_parity[:, 0:1] - 1.0) * BIG |
| mid = (1.0 - 2.0 * col_parity[:, 1:]) * lsb[:, :-1] |
| bit_last = lsb[:, -1:] |
| return torch.cat([bit0, mid, bit_last], dim=1) |
|
|
|
|
@torch.no_grad()
def _composed_product_bits(step, x_bits, y_bits, opb, cols, prb, carb):
    """Trained carry step (closed loop) + parity readout -> product BITS.

    x_bits, y_bits are (B, opb) LSB-first operand bit tensors.  The result is
    a long tensor of 0/1 product bits, LSB first, one more column than
    ``cols``.  ``prb`` is accepted but not read by this function.
    """
    sums = _column_sums(x_bits, y_bits, opb, cols)
    parity = (sums.long() & 1).float()
    carry_logits = _closed_loop_mul(step, sums, cols, carb)
    product_logits = _routed_product_bits(carry_logits, parity, carb)
    return (product_logits > 0).long()
|
|
|
|
| def _encode_red(a, b, bin_, nib): |
| ai = torch.arange(nib, device=a.device) |
| aa = ((a.unsqueeze(1) >> ai) & 1).float() |
| bb = ((b.unsqueeze(1) >> ai) & 1).float() |
| cc = bin_.float().unsqueeze(1) |
| return torch.cat([aa, bb, cc], dim=1) |
|
|
|
|
| def _red_bits_to_out(bits, nib): |
| hb = (bits > 0).long() |
| w = (1 << torch.arange(nib, device=bits.device)).long() |
| d = (hb[:, :nib] * w).sum(dim=1) |
| bout = hb[:, nib] |
| return d, bout |
|
|
|
|
@torch.no_grad()
def _composed_reduce_int(step, t_bits, p, nib, nibbles, t_bits_n):
    """Restoring division by p when p and R fit signed int64 (tiers 3-4).

    R stays in [0, p) (< 2^32), so R never overflows int64 even though the full
    product does. Fixed wiring; per-nibble subtract DECISION is the trained step.

    Args:
        step: borrow/compare net; called on ``_encode_red`` inputs.
        t_bits: (B, >= t_bits_n) product bits, LSB first.
        p: (B,) long tensor of moduli.
        nib: bits per nibble.
        nibbles: nibbles compared per trial subtraction.
        t_bits_n: number of product bits to consume, from the top down.

    Returns:
        (B,) long tensor holding the remainder.
    """
    n = t_bits.shape[0]
    device = t_bits.device
    # Derive the nibble mask from `nib` so it always matches the shift width
    # (a literal 0xF is only correct for nib == 4).
    nib_mask = (1 << nib) - 1
    R = torch.zeros(n, dtype=torch.long, device=device)
    p_nib = torch.stack(
        [(p >> (nib * k)) & nib_mask for k in range(nibbles)], dim=1
    )
    # Place value of each nibble, used to reassemble the difference.
    wk = (1 << (nib * torch.arange(nibbles, device=device))).long()
    for i in range(t_bits_n - 1, -1, -1):
        # Shift the next product bit (MSB first) into the remainder.
        Rpre = (R << 1) | t_bits[:, i].long()
        borrow = torch.zeros(n, dtype=torch.long, device=device)
        diff_nib = torch.zeros(n, nibbles, dtype=torch.long, device=device)
        # Trial subtraction Rpre - p, one nibble at a time, low nibble first.
        for k in range(nibbles):
            an = (Rpre >> (nib * k)) & nib_mask
            bn = p_nib[:, k]
            lg = step(_encode_red(an, bn, borrow, nib))
            d, bout = _red_bits_to_out(lg, nib)
            diff_nib[:, k] = d
            borrow = bout
        # No final borrow means Rpre >= p: keep the difference, else restore.
        ge = borrow == 0
        diff_val = (diff_nib * wk).sum(dim=1)
        R = torch.where(ge, diff_val, Rpre)
    return R
|
|
|
|
| def _p_nibbles_from_bits(p_bits, nib, nibbles): |
| n = p_bits.shape[0] |
| out = torch.zeros(n, nibbles, dtype=torch.long, device=p_bits.device) |
| wt = (1 << torch.arange(nib, device=p_bits.device)).long() |
| for k in range(nibbles): |
| chunk = p_bits[:, nib * k:nib * (k + 1)] |
| if chunk.shape[1] < nib: |
| pad = torch.zeros(n, nib - chunk.shape[1], device=p_bits.device) |
| chunk = torch.cat([chunk, pad], dim=1) |
| out[:, k] = (chunk.long() * wt).sum(dim=1) |
| return out |
|
|
|
|
@torch.no_grad()
def _composed_reduce_bits(step, t_bits, p_bits, nib, nibbles, t_bits_n):
    """Restoring division by p with R, R_pre, p carried as BIT tensors (tier 5).

    Nothing overflows int64: R has up to 64 bits, R_pre = 2R + bit up to 65 bits
    -> 17 nibbles. Returns the remainder as (B, nibbles*nib) bits LSB-first
    (float tensor).
    """
    batch = t_bits.shape[0]
    dev = t_bits.device
    reg_bits = nibbles * nib
    bit_idx = torch.arange(nib, device=dev)
    nib_weights = (1 << bit_idx).long()
    divisor_nib = _p_nibbles_from_bits(p_bits, nib, nibbles)
    rem = torch.zeros(batch, reg_bits, dtype=torch.long, device=dev)
    for pos in reversed(range(t_bits_n)):
        # R_pre = 2R + bit: every bit moves one place up (the top register
        # bit is dropped) and the next product bit enters at position 0.
        shifted = torch.cat([t_bits[:, pos:pos + 1].long(), rem[:, :-1]], dim=1)
        shifted_nib = (shifted.reshape(batch, nibbles, nib) * nib_weights).sum(dim=2)
        borrow = torch.zeros(batch, dtype=torch.long, device=dev)
        diff_nib = torch.zeros(batch, nibbles, dtype=torch.long, device=dev)
        # Trial subtraction R_pre - p, one nibble at a time, low nibble first.
        for k in range(nibbles):
            logits = step(_encode_red(shifted_nib[:, k], divisor_nib[:, k], borrow, nib))
            digit, borrow = _red_bits_to_out(logits, nib)
            diff_nib[:, k] = digit
        # Expand the difference nibbles back into LSB-first bits.
        diff_bits = ((diff_nib.unsqueeze(2) >> bit_idx) & 1).reshape(batch, reg_bits)
        # No final borrow means R_pre >= p: keep the difference, else restore.
        fits = (borrow == 0).unsqueeze(1)
        rem = torch.where(fits, diff_bits, shifted)
    return rem.float()
|
|
|
|
| |
|
|
| def _int_bits(values, nb: int) -> torch.Tensor: |
| out = torch.zeros(len(values), nb, dtype=torch.float32) |
| for r, v in enumerate(values): |
| v = int(v) |
| b = 0 |
| while v and b < nb: |
| out[r, b] = float(v & 1) |
| v >>= 1 |
| b += 1 |
| return out |
|
|
|
|
| def _bits_to_ints(bits: torch.Tensor) -> list[int]: |
| hb = (bits > 0.5).long().tolist() |
| out = [] |
| for row in hb: |
| v = 0 |
| for b, bit in enumerate(row): |
| if bit: |
| v |= (1 << b) |
| out.append(v) |
| return out |
|
|
|
|
| |
| |
| |
|
|
# Inclusive modulus ranges of the tier-3/4/5 pipelines; each tier starts
# right after the previous one ends.
T3_MIN_P = MAX_P + 1
T3_MAX_P = 2 ** 16 - 1
T4_MIN_P = T3_MAX_P + 1
T4_MAX_P = 2 ** 32 - 1
T5_MIN_P = T4_MAX_P + 1
T5_MAX_P = 2 ** 64 - 1
|
|
|
|
class ResidueRouterV1(ModularMultiplicationModel):
    """Router over per-tier specialists, selected by the size of p.

    Kept the class name ``ResidueRouterV1`` so the manifest entry_class is
    stable across versions; this is v4 (tiers 1-5).

    A query that cannot be routed -- a field ``int()`` rejects, ``p < 2``, or
    a ``p`` that no loaded specialist covers -- is answered with ``[0]``.
    """

    def __init__(self):
        # Every specialist stays None until load() finds its section in
        # config.json; a tier whose nets are None is answered with [0].
        self.small: SmallResidueNet | None = None
        self.t3_mul: StepMLP | None = None
        self.t3_red: StepMLP | None = None
        self.t4_mul: StepMLP | None = None
        self.t4_red: StepMLP | None = None
        self.t5_mul: StepMLP | None = None
        self.t5_red: StepMLP | None = None

    @staticmethod
    def _load_step_pair(model_dir, tag, mul_io, red_io, mul_width, red_width, depth):
        """Build one tier's (multiply, reduce) step nets and load their weights.

        Weights are read from ``<tag>_mul.safetensors`` and
        ``<tag>_red.safetensors`` inside ``model_dir``; ``mul_io`` / ``red_io``
        are ``(n_in, n_out)`` pairs.  Both nets are returned in eval mode.
        """
        from safetensors.torch import load_file

        mul = StepMLP(mul_io[0], mul_io[1], mul_width, depth)
        red = StepMLP(red_io[0], red_io[1], red_width, depth)
        mul.load_state_dict(load_file(str(model_dir / f"{tag}_mul.safetensors")), strict=True)
        red.load_state_dict(load_file(str(model_dir / f"{tag}_red.safetensors")), strict=True)
        mul.eval()
        red.eval()
        return mul, red

    def load(self, model_dir: str) -> None:
        """Load every specialist that has a section in ``config.json``.

        Recognised sections are ``"small"``, ``"t3"``, ``"t4"`` and ``"t5"``;
        sections that are absent leave the matching specialist unset.
        """
        from safetensors.torch import load_file

        torch.manual_seed(0)
        model_dir = Path(model_dir)
        config = json.loads((model_dir / "config.json").read_text())

        if "small" in config:
            # Only the small net lives in weights.safetensors (under the
            # "small." prefix), so the file is read only when it is needed.
            prefix = "small."
            tensors = load_file(str(model_dir / "weights.safetensors"))
            state = {
                key[len(prefix):]: value
                for key, value in tensors.items()
                if key.startswith(prefix)
            }
            net = SmallResidueNet(**config["small"])
            net.load_state_dict(state, strict=True)
            net.eval()
            self.small = net

        if "t3" in config:
            cfg = config["t3"]
            # Tier 3 uses a single width for both step nets.
            self.t3_mul, self.t3_red = self._load_step_pair(
                model_dir, "t3",
                (T3_MUL_IN, T3_MUL_CARB), (T3_RED_IN, T3_RED_OUT),
                cfg["width"], cfg["width"], cfg["depth"],
            )

        if "t4" in config:
            cfg = config["t4"]
            self.t4_mul, self.t4_red = self._load_step_pair(
                model_dir, "t4",
                (T4_MUL_IN, T4_MUL_CARB), (T4_RED_IN, T4_RED_OUT),
                cfg["mul_width"], cfg["red_width"], cfg["depth"],
            )

        if "t5" in config:
            cfg = config["t5"]
            self.t5_mul, self.t5_red = self._load_step_pair(
                model_dir, "t5",
                (T5_MUL_IN, T5_MUL_CARB), (T5_RED_IN, T5_RED_OUT),
                cfg["mul_width"], cfg["red_width"], cfg["depth"],
            )

    def preprocess_a(self, a):
        """Pass the first operand through unchanged."""
        return a

    def preprocess_b(self, b):
        """Pass the second operand through unchanged."""
        return b

    def preprocess_p(self, p):
        """Pass the modulus through unchanged."""
        return p

    @torch.no_grad()
    def predict_digits(self, a_enc, b_enc, p_enc):
        """Answer a single query by running it as a batch of one."""
        return self.predict_digits_batch([(a_enc, b_enc, p_enc)])[0]

    @torch.no_grad()
    def predict_digits_batch(self, inputs):
        """Route each ``(a, b, p)`` query to its tier and return digit lists.

        Returns one list of base-256 digits (MSB first) per input, in input
        order.  Queries that cannot be routed are answered with ``[0]``.
        """
        out: list[list[int] | None] = [None] * len(inputs)
        s_x, s_y, s_p, s_idx = [], [], [], []
        t3_x, t3_y, t3_p, t3_idx = [], [], [], []
        t4_x, t4_y, t4_p, t4_idx = [], [], [], []
        t5_x, t5_y, t5_p, t5_idx = [], [], [], []

        for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
            try:
                p = int(p_enc)
                a = int(a_enc)
                b = int(b_enc)
            except (ValueError, TypeError, OverflowError):
                # OverflowError: int() of an infinite float.
                out[i] = [0]
                continue
            if p < 2:
                # No tier serves p < 2, and p == 0 would raise
                # ZeroDivisionError in the reduction below.
                out[i] = [0]
                continue

            # Operands are reduced mod p before routing.
            xr = a % p
            yr = b % p

            if self.small is not None and p <= MAX_P and int(self.small.prime_lookup[p]) >= 0:
                s_x.append(xr); s_y.append(yr); s_p.append(p); s_idx.append(i)
            elif self.t3_mul is not None and T3_MIN_P <= p <= T3_MAX_P:
                t3_x.append(xr); t3_y.append(yr); t3_p.append(p); t3_idx.append(i)
            elif self.t4_mul is not None and T4_MIN_P <= p <= T4_MAX_P:
                t4_x.append(xr); t4_y.append(yr); t4_p.append(p); t4_idx.append(i)
            elif self.t5_mul is not None and T5_MIN_P <= p <= T5_MAX_P:
                t5_x.append(xr); t5_y.append(yr); t5_p.append(p); t5_idx.append(i)
            else:
                out[i] = [0]

        if s_idx:
            preds = self.small.predict(
                torch.tensor(s_x, dtype=torch.long),
                torch.tensor(s_y, dtype=torch.long),
                torch.tensor(s_p, dtype=torch.long),
            ).tolist()
            for i, pred in zip(s_idx, preds):
                out[i] = [int(pred)]

        if t3_idx:
            xb = _bits(torch.tensor(t3_x, dtype=torch.long), T3_MUL_OPB)
            yb = _bits(torch.tensor(t3_y, dtype=torch.long), T3_MUL_OPB)
            p_t = torch.tensor(t3_p, dtype=torch.long)
            tb = _composed_product_bits(self.t3_mul, xb, yb, T3_MUL_OPB, T3_MUL_COLS,
                                        T3_MUL_PRB, T3_MUL_CARB)
            r = _composed_reduce_int(self.t3_red, tb, p_t, T3_NIB, T3_RED_NIBBLES, T3_T_BITS)
            for i, value in zip(t3_idx, r.tolist()):
                out[i] = _digits_msb(int(value))

        if t4_idx:
            xb = _bits(torch.tensor(t4_x, dtype=torch.long), T4_MUL_OPB)
            yb = _bits(torch.tensor(t4_y, dtype=torch.long), T4_MUL_OPB)
            p_t = torch.tensor(t4_p, dtype=torch.long)
            tb = _composed_product_bits(self.t4_mul, xb, yb, T4_MUL_OPB, T4_MUL_COLS,
                                        T4_MUL_PRB, T4_MUL_CARB)
            r = _composed_reduce_int(self.t4_red, tb, p_t, T4_NIB, T4_RED_NIBBLES, T4_T_BITS)
            for i, value in zip(t4_idx, r.tolist()):
                out[i] = _digits_msb(int(value))

        if t5_idx:
            # Tier-5 values can exceed int64, so operands and p go straight
            # from Python ints to bit tensors.
            xb = _int_bits(t5_x, T5_MUL_OPB)
            yb = _int_bits(t5_y, T5_MUL_OPB)
            pb = _int_bits(t5_p, T5_RED_NIBBLES * T5_NIB)
            tb = _composed_product_bits(self.t5_mul, xb, yb, T5_MUL_OPB, T5_MUL_COLS,
                                        T5_MUL_PRB, T5_MUL_CARB)
            r_bits = _composed_reduce_bits(self.t5_red, tb, pb, T5_NIB, T5_RED_NIBBLES, T5_T_BITS)
            # Only the low T5_MUL_OPB bits of the register are read back.
            r_vals = _bits_to_ints(r_bits[:, :T5_MUL_OPB])
            for i, value in zip(t5_idx, r_vals):
                out[i] = _digits_msb(value)

        return [o if o is not None else [0] for o in out]

    def max_batch_size(self) -> int:
        """Return the largest batch size this model accepts (512)."""
        return 512
|
|
|
|
| def _digits_msb(v: int) -> list[int]: |
| """Base-256 digits, MSB-first; at least one digit.""" |
| if v == 0: |
| return [0] |
| ds = [] |
| while v > 0: |
| ds.append(v & 0xFF) |
| v >>= 8 |
| return ds[::-1] |
|
|