File size: 6,421 Bytes
"""Bit-serial learned reducer (general width) for the Modular Arithmetic Challenge.
Same design as bit-serial-v1/v2: one shared, p-conditioned transition cell that
learned s' = (2*s + d*x) mod p, applied in a fixed bit-serial Horner loop (reduce a,
reduce b, multiply). The arithmetic is in the trained cell; the loop only sequences
bits. Randomising the weights collapses accuracy to chance.
This version generalises the state width to L (read from the checkpoint), so it
covers tiers up to whatever L the weights were trained for. Bit extraction uses
32-bit limbs (`to_bits_limbs`) so a modulus p >= 2^63 never overflows an int64
tensor (needed at L >= 64). State is carried as bits between steps; the harness
decoder reconstructs the integer answer from the emitted base-2 digits.
Regime: primes p < 2^L and operands up to 4*L bits. Outside it the model abstains
and emits [0] -- the honest fallback.
"""
from __future__ import annotations
from pathlib import Path
import torch
from torch import nn
from modchallenge.interface.base_model import ModularMultiplicationModel
_MASK32 = (1 << 32) - 1
def _to_bits_small(vals: torch.Tensor, width: int) -> torch.Tensor:
shifts = torch.arange(width - 1, -1, -1, device=vals.device)
return (vals[:, None] >> shifts[None, :]) & 1
def to_bits_limbs(ints, dev, width: int) -> torch.Tensor:
    """Convert python ints (< 2^width) to an (N, width) MSB-first bit tensor.

    Each value is cut into 32-bit limbs on the Python side before it reaches
    torch, so no int64 tensor ever holds a value >= 2^32 and arbitrarily wide
    integers are handled without overflow.
    """
    limb_count = (width + 31) // 32
    limb_bits = []
    # Walk the shift amounts from the most-significant limb down to limb 0.
    for shift in range(32 * (limb_count - 1), -1, -32):
        chunk = [(value >> shift) & _MASK32 for value in ints]
        limb_tensor = torch.tensor(chunk, dtype=torch.int64, device=dev)
        limb_bits.append(_to_bits_small(limb_tensor, 32))
    all_bits = torch.cat(limb_bits, dim=1)
    # Drop the padding columns in front when width is not a multiple of 32
    # (the slice starts at 0 and keeps everything when it is).
    return all_bits[:, limb_count * 32 - width:]
class Cell(nn.Module):
    """Transition cell: per-position features plus one digit -> per-position logits.

    ``feat`` carries 3 features per position (the caller stacks state, operand
    and modulus bits on the last axis); ``d`` is a 0/1 index per batch row that
    selects an embedding added to every position. A 2-layer bidirectional GRU
    runs over the positions and a linear head emits one logit per position.
    """

    def __init__(self, dmodel: int = 96, hidden: int = 128):
        super().__init__()
        # Attribute names are part of the checkpoint's state_dict keys.
        self.in_proj = nn.Linear(3, dmodel)
        self.d_emb = nn.Embedding(2, dmodel)
        self.gru = nn.GRU(
            dmodel,
            hidden,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # Bidirectional output concatenates both directions: 2 * hidden.
        self.head = nn.Linear(2 * hidden, 1)

    def forward(self, feat, d):
        """Return logits with the trailing singleton axis of the head removed."""
        digit_term = self.d_emb(d)[:, None, :]  # broadcast over positions
        projected = self.in_proj(feat) + digit_term
        hidden_seq = self.gru(projected)[0]
        logits = self.head(hidden_seq)
        return logits.squeeze(-1)
def _bits_of(n: int) -> list[int]:
if n <= 0:
return [0]
out: list[int] = []
while n > 0:
out.append(n & 1)
n >>= 1
out.reverse()
return out
class BitSerialReducer(ModularMultiplicationModel):
    """Compute (a * b) mod p by sequencing a learned transition ``Cell`` over bits.

    Each operand is reduced with one ``_step`` per operand bit, then the two
    residues are combined with one ``_step`` per state bit. Inputs outside the
    supported regime (``p < 2``, ``p >= 2**L`` or an operand longer than
    ``4 * L`` bits) are answered with ``[0]``.
    """

    def __init__(self) -> None:
        self.model: Cell | None = None
        self.device: torch.device | None = None
        # State width the weights support; replaced by the checkpoint's "L" in load().
        self.L = 32
        # Working state width of the batch in flight; set by predict_digits_batch().
        self._Leff = 32

    def load(self, model_dir: str) -> None:
        """Load ``weights.pt`` from ``model_dir`` and put the cell in eval mode.

        The device is CUDA when available, otherwise MPS, otherwise CPU. The
        checkpoint supplies ``"state_dict"`` and optionally ``"L"`` (default 32)
        and ``"config"`` (keyword arguments for ``Cell``, default none).
        """
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        ckpt = torch.load(
            Path(model_dir) / "weights.pt",
            map_location=self.device,
            weights_only=True,
        )
        self.L = int(ckpt.get("L", 32))
        self.model = Cell(**ckpt.get("config", {}))
        self.model.load_state_dict(ckpt["state_dict"])
        self.model.to(self.device)
        self.model.eval()

    def preprocess_a(self, a):
        """Encode operand ``a`` as its MSB-first list of binary digits."""
        return _bits_of(int(a))

    def preprocess_b(self, b):
        """Encode operand ``b`` as its MSB-first list of binary digits."""
        return _bits_of(int(b))

    def preprocess_p(self, p):
        """Encode the modulus as a plain python int."""
        return int(p)

    @torch.no_grad()
    def predict_digits(self, a_enc, b_enc, p_enc):
        """Predict one example by running it as a batch of size one."""
        return self.predict_digits_batch([(a_enc, b_enc, p_enc)])[0]

    @torch.no_grad()
    def predict_digits_batch(self, inputs):
        """Predict the digit lists for a batch of ``(a_enc, b_enc, p_enc)`` triples.

        Returns one list per input, in input order. Examples outside the
        supported regime get ``[0]``; the others get the final state bits
        (``self._Leff`` digits each).

        Raises:
            RuntimeError: if an in-regime example is present but ``load()`` has
                not been called.
        """
        # Materialise once: the batch is traversed twice below, which would
        # silently yield nothing on the second pass for a generator argument.
        inputs = list(inputs)
        L = self.L
        max_op = 4 * L
        out: list[list[int]] = [[0] for _ in inputs]
        idx, a_lists, b_lists, p_vals = [], [], [], []
        for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
            p = int(p_enc)
            a_bits = list(a_enc)
            b_bits = list(b_enc)
            # Abstain (keep the [0] placeholder) outside the supported regime.
            if p < 2 or p >= (1 << L) or len(a_bits) > max_op or len(b_bits) > max_op:
                continue
            idx.append(i)
            a_lists.append(a_bits)
            b_lists.append(b_bits)
            p_vals.append(p)
        if not idx:
            return out
        if self.model is None:
            # Fail with a clear message instead of "'NoneType' is not callable"
            # from deep inside _step.
            raise RuntimeError("BitSerialReducer.load() must be called before predicting")
        dev = self.device
        maxp = max(p.bit_length() for p in p_vals)
        # Round the widest modulus up to whole 32-bit limbs (at least 32), capped at L.
        self._Leff = min(self.L, max(32, ((maxp + 31) // 32) * 32))
        p_bits = to_bits_limbs(p_vals, dev, self._Leff).float()
        ra = self._reduce(a_lists, p_bits, dev)
        rb = self._reduce(b_lists, p_bits, dev)
        prod = self._mul(ra, rb, p_bits)
        prod_list = prod.long().tolist()
        for j, i in enumerate(idx):
            out[i] = [int(x) for x in prod_list[j]]
        return out

    def max_batch_size(self) -> int:
        """Return the largest batch this model accepts per call (256)."""
        return 256

    def _step(self, s_bits, x_bits, p_bits, d):
        """Apply the cell once and threshold its logits into the next state bits.

        ``s_bits``, ``x_bits`` and ``p_bits`` are stacked on a new last axis as
        the cell's three per-position features; ``d`` is the digit consumed by
        this step. Returns a float tensor of 0.0/1.0 values.
        """
        feat = torch.stack([s_bits, x_bits, p_bits], dim=-1)
        if self.device is not None and self.device.type == "cuda":
            # bf16 for the GRU (~2x at L=2048); threshold in fp32 so the discrete
            # decision is unchanged (logits are saturated, far from 0).
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = self.model(feat, d)
            return (torch.sigmoid(logits.float()) > 0.5).float()
        return (torch.sigmoid(self.model(feat, d)) > 0.5).float()

    def _reduce(self, bit_lists, p_bits, dev):
        """Feed each operand's bits through ``_step``, one column per step.

        Operands are right-aligned (zero-padded on the left) to the longest
        one in the batch. The state starts at all zeros and the ``x`` input
        is the constant 1 encoded in ``self._Leff`` bits.
        """
        n = len(bit_lists)
        width = max(len(b) for b in bit_lists)
        # Pad on the host and transfer once instead of one device copy per row.
        rows = [[0] * (width - len(bl)) + list(bl) for bl in bit_lists]
        padded = torch.tensor(rows, dtype=torch.long, device=dev)
        s_bits = torch.zeros((n, self._Leff), device=dev)
        x_bits = to_bits_limbs([1] * n, dev, self._Leff).float()
        for pos in range(width):
            s_bits = self._step(s_bits, x_bits, p_bits, padded[:, pos])
        return s_bits

    def _mul(self, ra_bits, rb_bits, p_bits):
        """Combine two reduced operands: one ``_step`` per bit column of ``rb_bits``.

        The state starts at all zeros, ``ra_bits`` is the ``x`` input of every
        step and column ``k`` of ``rb_bits`` is the digit for step ``k``.
        """
        n = ra_bits.shape[0]
        s_bits = torch.zeros((n, self._Leff), device=ra_bits.device)
        rb_long = rb_bits.long()
        for k in range(self._Leff):
            s_bits = self._step(s_bits, ra_bits, p_bits, rb_long[:, k])
        return s_bits
|