| """Reversed-digit addition data. The reversed format (least-significant digit |
| first) is the Lee et al. 2023 trick: the model emits the LSB first, matching the |
| direction carries actually propagate, which is what makes exact addition |
| learnable at tiny scale. Every number is a fixed-width S-cell strip, zero-padded |
| in the high positions, so cell i is always digit i (10^i place).""" |
| import numpy as np |
|
|
def digits_rev(x, S):
    """Return the S lowest base-10 digits of x, least-significant digit first.

    The result is a length-S int64 array whose cell i holds the 10^i digit.
    Values shorter than S digits are zero-padded in the high cells; digits at
    10^S and above are silently dropped. Assumes x is a non-negative integer:
    negative values follow Python floor-division semantics and do not yield
    the digits of abs(x).
    """
    cells = np.zeros(S, dtype=np.int64)
    remaining = x
    for place in range(S):
        # Peel off the current lowest digit and keep the rest for the next cell.
        remaining, cells[place] = divmod(remaining, 10)
    return cells
|
|
def _uniform_below_pow10(rng, ndigits):
    """Draw a Python int uniformly from [0, 10**ndigits)."""
    # rng.integers works in int64, whose exclusive upper bound tops out at
    # 2**63; 10**18 is the largest power of ten that still fits.
    if ndigits <= 18:
        return int(rng.integers(0, 10 ** ndigits))
    # Wider operands: independent uniform digits give the same uniform
    # distribution over [0, 10**ndigits) without leaving int64 range.
    digits = rng.integers(0, 10, ndigits)
    return int("".join(map(str, digits)))


def make_batch(Bn, S, rng, maxdig=None):
    """Sample Bn random addition problems as reversed-digit grids.

    For each row and each operand a digit budget is drawn uniformly from
    [1, maxdig], then the operand is drawn uniformly from [0, 10**budget).
    Operands may therefore have fewer digits than their budget (including 0).

    Args:
        Bn: number of rows to generate.
        S: width of every digit strip (cells per number).
        rng: numpy random Generator used for all draws.
        maxdig: largest digit budget per operand. None (or 0) means S - 1,
            the widest operands whose sum is guaranteed to fit in S cells.

    Returns:
        (A, B, Y): three int64 arrays of shape (Bn, S) holding the two
        operands and their sum, least-significant digit first.

    Raises:
        ValueError: if the effective maxdig is outside [1, S - 1]. A wider
            operand could produce a sum needing more than S cells, which
            would be truncated into a wrong label.
    """
    maxdig = maxdig or (S - 1)
    if not 1 <= maxdig <= S - 1:
        raise ValueError(
            f"maxdig must be in [1, S - 1] so that sums fit in S cells; "
            f"got maxdig={maxdig}, S={S}"
        )
    la = rng.integers(1, maxdig + 1, Bn)
    lb = rng.integers(1, maxdig + 1, Bn)
    A = np.empty((Bn, S), np.int64)
    B = np.empty((Bn, S), np.int64)
    Y = np.empty((Bn, S), np.int64)
    for n in range(Bn):
        a = _uniform_below_pow10(rng, int(la[n]))
        b = _uniform_below_pow10(rng, int(lb[n]))
        # Python ints keep a + b exact regardless of operand width.
        A[n] = digits_rev(a, S)
        B[n] = digits_rev(b, S)
        Y[n] = digits_rev(a + b, S)
    return A, B, Y
|
|
def _add_rev_digits(X, Z, carry):
    """Add two reversed-digit grids (plus a per-row carry-in) cell by cell."""
    out = np.empty_like(X)
    for i in range(X.shape[1]):
        total = X[:, i] + Z[:, i] + carry
        out[:, i] = total % 10
        carry = total // 10
    # Any carry out of the last cell is dropped, i.e. the result is mod 10**S.
    return out


def gen_hard(Bn, S, rng):
    """Generate carry-heavy adversarial pairs the uniform sampler rarely makes.

    Each row picks an operand length L uniformly from [2, S - 1] and one of
    three recipes:
      * 45%: random a with b = 9 - a per digit, so a + b is all nines and a
        single +1 ripples through every cell;
      * 30%: a is all nines, b is all nines or random digits (50/50);
      * 25%: both operands are random digits.
    Independently, 60% of rows get +1 added to b. These teach the rare
    'receive a carry and emit a new most-significant digit' behaviour that
    full-length ripples (e.g. 9999999+1) depend on.

    All arithmetic is done digit-wise with explicit carries, so the result is
    exact for any S (no fixed-width integer overflow). Because L <= S - 1,
    both b + 1 and a + b always fit in S cells.

    Args:
        Bn: number of rows to generate.
        S: width of every digit strip; must be >= 3 when Bn > 0.
        rng: numpy random Generator used for all draws.

    Returns:
        (A, B, Y): three int64 arrays of shape (Bn, S) holding the two
        operands and their sum, least-significant digit first.

    Raises:
        ValueError: if Bn > 0 and S < 3 (no valid operand length exists).
    """
    if Bn > 0 and S < 3:
        raise ValueError(f"gen_hard needs S >= 3 (operands use 2..S-1 digits), got S={S}")
    A = np.zeros((Bn, S), np.int64)
    B = np.zeros((Bn, S), np.int64)
    for n in range(Bn):
        L = int(rng.integers(2, S))
        t = rng.random()
        if t < 0.45:
            a = rng.integers(0, 10, L)
            b = 9 - a
        elif t < 0.75:
            a = np.full(L, 9)
            b = np.full(L, 9) if rng.random() < 0.5 else rng.integers(0, 10, L)
        else:
            a = rng.integers(0, 10, L)
            b = rng.integers(0, 10, L)
        A[n, :L] = a
        B[n, :L] = b
    no_carry = np.zeros(Bn, np.int64)
    # Bump b by one on ~60% of rows; this is what triggers the long ripples.
    plus_one = (rng.random(Bn) < 0.6).astype(np.int64)
    B = _add_rev_digits(B, np.zeros_like(B), plus_one)
    Y = _add_rev_digits(A, B, no_carry)
    return A, B, Y
|
|
def to_int(grid):
    """Convert a reversed-digit grid of shape (Bn, S) to a list of Python ints.

    Cell i of each row is weighted by 10^i. The accumulation uses Python
    ints, so the values are exact however wide the rows are.
    """
    rows, width = grid.shape
    values = []
    for r in range(rows):
        total = 0
        # Horner evaluation, walking from the most-significant cell down.
        for place in range(width - 1, -1, -1):
            total = total * 10 + int(grid[r, place])
        values.append(total)
    return values
|
|
def exact_match(pred, Y):
    """Return the share of rows of pred that equal the matching row of Y.

    A row counts as correct only if every one of its cells agrees, so this is
    whole-number (exact-match) accuracy rather than per-digit accuracy. The
    result is a Python float.
    """
    row_is_correct = (pred == Y).all(axis=1)
    return float(row_is_correct.mean())
|
|