# Source: gary-neuron/training/data.py (uploaded by gary23w)
# gary-neuron: async NCA + top-2 MoE, 26k params, 99.97%/100% exact-match on 7-digit addition
# Commit 57f9808 (verified)
"""Reversed-digit addition data. The reversed format (least-significant digit
first) is the Lee et al. 2023 trick: the model emits the LSB first, matching the
direction carries actually propagate, which is what makes exact addition
learnable at tiny scale. Every number is a fixed-width S-cell strip, zero-padded
in the high positions, so cell i is always digit i (10^i place)."""
import numpy as np
def digits_rev(x, S):
    """Return the S lowest decimal digits of x, least-significant first.

    The result is an int64 array of length S in which cell i holds the
    10**i digit. Values with fewer than S digits are zero-padded in the high
    cells; digits beyond the S-th are silently dropped.
    """
    cells = np.zeros(S, dtype=np.int64)
    remaining = x
    for place in range(S):
        # Peel off the current lowest digit and keep the rest for the next cell.
        remaining, cells[place] = divmod(remaining, 10)
    return cells
def make_batch(Bn, S, rng, maxdig=None):
    """Sample a batch of random addition problems as reversed-digit grids.

    Returns (A, B, Y): three int64 arrays of shape (Bn, S) holding the two
    operands and their sum, least-significant digit in column 0. Each operand
    gets a digit budget drawn uniformly from 1..maxdig (inclusive) and is then
    drawn uniformly from [0, 10**budget). A falsy maxdig (None or 0) falls
    back to S - 1, which keeps every sum below 10**S so it fits in S cells.

    rng must provide an ``integers(low, high, size)`` method (presumably a
    numpy.random.Generator -- confirm against the caller).
    """
    digit_cap = maxdig or (S - 1)
    len_a = rng.integers(1, digit_cap + 1, Bn)
    len_b = rng.integers(1, digit_cap + 1, Bn)
    A = np.empty((Bn, S), np.int64)
    B = np.empty((Bn, S), np.int64)
    Y = np.empty((Bn, S), np.int64)
    for row, (na, nb) in enumerate(zip(len_a, len_b)):
        # Draw order (a first, then b) is kept fixed so a seeded rng
        # reproduces the same batch.
        a = int(rng.integers(0, 10 ** int(na)))
        b = int(rng.integers(0, 10 ** int(nb)))
        A[row] = digits_rev(a, S)
        B[row] = digits_rev(b, S)
        Y[row] = digits_rev(a + b, S)
    return A, B, Y
def gen_hard(Bn, S, rng):
    """Generate carry-heavy adversarial pairs the uniform sampler rarely produces.

    Each row picks a chain length L in [2, S-1] and fills the L low digits of
    both operands in one of three ways:

    * with probability 0.45, a sum-to-9 chain (b = 9 - a digit-wise), so any
      incoming carry ripples through the whole chain;
    * with probability 0.30, an all-nines first operand paired with either
      all nines or random digits;
    * otherwise, plain random digits.

    Afterwards 1 is added to the second operand with probability 0.6; this is
    the trigger that sets off the ripple. Both operands stay <= 10**(S-1), so
    every sum is < 10**S and fits in S cells.

    Returns (A, B, Y): int64 arrays of shape (Bn, S), least-significant digit
    first, with Y the digits of A + B. An empty batch (Bn == 0) yields three
    (0, S) arrays. S must be at least 3 when Bn > 0, because the chain length
    is drawn from the half-open range [2, S).

    rng must provide ``integers`` and ``random`` (presumably a
    numpy.random.Generator -- confirm against the caller).
    """
    A = np.zeros((Bn, S), np.int64)
    B = np.zeros((Bn, S), np.int64)
    for n in range(Bn):
        L = int(rng.integers(2, S))  # chain length, 2 <= L <= S-1
        t = rng.random()
        if t < 0.45:  # sum-to-9 chain: any carry ripples the whole way
            a = rng.integers(0, 10, L)
            b = 9 - a
        elif t < 0.75:  # all-nines operand(s)
            a = np.full(L, 9)
            b = np.full(L, 9) if rng.random() < 0.5 else rng.integers(0, 10, L)
        else:  # random short
            a = rng.integers(0, 10, L)
            b = rng.integers(0, 10, L)
        A[n, :L] = a
        B[n, :L] = b

    # Explicit int64: the platform default integer can be 32-bit, which would
    # overflow the higher place values.
    pw = 10 ** np.arange(S, dtype=np.int64)
    ai = (A * pw).sum(axis=1)
    bi = (B * pw).sum(axis=1)
    bi = bi + (rng.random(Bn) < 0.6).astype(np.int64)  # +1 trigger sets off the ripple
    yi = ai + bi

    # A already holds the digits of ai. B must be re-encoded because the +1
    # trigger may have carried through it. Broadcasting keeps the (Bn, S)
    # int64 shape even when Bn == 0.
    B = (bi[:, None] // pw) % 10
    Y = (yi[:, None] // pw) % 10
    return A, B, Y
def to_int(grid):
    """Convert a reversed-digit grid of shape (Bn, S) to a list of Python ints.

    Column i of each row is taken as the 10**i digit. Accumulation uses
    Python integers, so wide rows do not overflow.
    """
    n_rows, width = grid.shape
    values = []
    for row in range(n_rows):
        total = 0
        # Horner's rule from the most-significant cell down to cell 0.
        for place in reversed(range(width)):
            total = total * 10 + int(grid[row, place])
        values.append(total)
    return values
def exact_match(pred, Y):
    """Return the fraction of rows in which every digit of pred equals Y.

    A row counts as correct only if all of its cells match, so this is true
    exact-match accuracy rather than per-digit accuracy.
    """
    digit_hits = pred == Y
    row_hits = digit_hits.all(axis=1)
    return float(row_hits.mean())