"""Reversed-digit addition data. The reversed format (least-significant digit first) is the Lee et al. 2023 trick: the model emits the LSB first, matching the direction carries actually propagate, which is what makes exact addition learnable at tiny scale. Every number is a fixed-width S-cell strip, zero-padded in the high positions, so cell i is always digit i (10^i place).""" import numpy as np def digits_rev(x, S): out = np.zeros(S, dtype=np.int64) for i in range(S): out[i] = x % 10; x //= 10 return out def make_batch(Bn, S, rng, maxdig=None): """A,B,Y int grids (Bn,S), reversed digits. Operand lengths ~Uniform[1,maxdig].""" maxdig = maxdig or (S - 1) la = rng.integers(1, maxdig + 1, Bn) lb = rng.integers(1, maxdig + 1, Bn) A = np.empty((Bn, S), np.int64); B = np.empty((Bn, S), np.int64); Y = np.empty((Bn, S), np.int64) for n in range(Bn): a = int(rng.integers(0, 10 ** int(la[n]))) b = int(rng.integers(0, 10 ** int(lb[n]))) A[n] = digits_rev(a, S); B[n] = digits_rev(b, S); Y[n] = digits_rev(a + b, S) return A, B, Y def gen_hard(Bn, S, rng): """Carry-heavy adversarial pairs the uniform sampler almost never produces: long ripple chains. These teach the rare 'receive a carry and emit a new most-significant digit' behaviour that full-length ripples (e.g. 9999999+1) depend on. All sums are constructed to fit in S cells.""" A = np.zeros((Bn, S), np.int64); B = np.zeros((Bn, S), np.int64) for n in range(Bn): L = int(rng.integers(2, S)) # chain length, <= S-1 t = rng.random() if t < 0.45: # sum-to-9 chain: any carry ripples the whole way a = rng.integers(0, 10, L); b = 9 - a elif t < 0.75: # all-nines operand(s) a = np.full(L, 9) b = np.full(L, 9) if rng.random() < 0.5 else rng.integers(0, 10, L) else: # random short a = rng.integers(0, 10, L); b = rng.integers(0, 10, L) A[n, :L] = a; B[n, :L] = b pw = 10 ** np.arange(S) ai = (A * pw).sum(1); bi = (B * pw).sum(1) bi = bi + (rng.random(Bn) < 0.6).astype(np.int64) # +1 trigger sets off the ripple yi = ai + bi A = np.array([digits_rev(int(x), S) for x in ai]) B = np.array([digits_rev(int(x), S) for x in bi]) Y = np.array([digits_rev(int(x), S) for x in yi]) return A, B, Y def to_int(grid): """reversed digit grid (Bn,S) -> python ints""" Bn, S = grid.shape return [int(sum(int(grid[n, i]) * (10 ** i) for i in range(S))) for n in range(Bn)] def exact_match(pred, Y): """fraction of rows where ALL digits match (true exact-match accuracy).""" return float((pred == Y).all(axis=1).mean())