| |
| """gary-neuron: a ~34 KB asynchronous Neural Cellular Automaton whose per-cell |
| rule is a Mixture-of-Experts. It adds integers by letting carries ripple across |
| a 1-D mesh of cells. Pure numpy + stdlib, no deps, no tokenizer. |
| |
| Usage: |
| python solve.py 1234567 + 7654321 # solve a sum |
| python solve.py 9999999 1 --show # visualise the mesh firing + carry ripple |
| python solve.py --vote 9 48591 + 9732 # robust inference (ensemble over async orders) |
| python solve.py # interactive |
| """ |
| import json, sys, os, re |
| import numpy as np |
|
|
# Directory holding this script; config and weights are loaded relative to it.
D = os.path.dirname(os.path.abspath(__file__))

# Read the model hyper-parameters. The context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open(f"{D}/config.json", encoding="utf-8") as _config_file:
    C = json.load(_config_file)

# S: number of cells in the mesh (one per digit place)
# d: per-cell state width, K: number of experts, TOPK: experts kept per cell
S, d, K, TOPK = C["S"], C["state_dim"], C["n_experts"], C["topk"]
# Probability that a given cell applies its update on a given step.
P_UPDATE = C["p_update"]
# Default number of mesh steps; falls back to 24 when the config omits it.
STEPS = C.get("recommended_inference_steps", 24)
|
|
# Load the quantised weight archive and dequantise it eagerly. Every array
# "<name>" has a companion "<name>.scale" entry; the stored values are cast to
# float32 and multiplied by that scale. All arrays are materialised inside the
# ``with`` block, so the archive's file descriptor can be released right away
# (``z`` stays bound afterwards, but refers to the closed archive).
with np.load(f"{D}/gary-neuron.int8.npz") as z:
    W = {
        name: z[name].astype(np.float32) * z[name + ".scale"]
        for name in z.files
        if not name.endswith(".scale")
    }
|
|
def _sm(x):
    """Softmax over the last axis of ``x``.

    The maximum along that axis is subtracted before exponentiating; this
    does not change the result but keeps ``np.exp`` from overflowing.
    """
    shifted = x - x.max(-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(-1, keepdims=True)
|
|
def digits_rev(x):
    """Return the decimal digits of ``x`` as an int64 array of length ``S``.

    Index 0 holds the ones digit, index 1 the tens digit, and so on. Places
    above the number's length are zero; digits beyond the ``S``-th place are
    dropped.
    """
    digits = np.zeros(S, np.int64)
    remaining = x
    for place in range(S):
        # Peel off the lowest remaining digit and keep the quotient.
        remaining, digits[place] = divmod(remaining, 10)
    return digits
|
|
def to_int(row):
    """Convert ``S`` lowest-place-first digits back into a Python int.

    ``row[i]`` is the coefficient of ``10 ** i``; only the first ``S``
    entries are read.
    """
    value = 0
    # Horner evaluation, walking from the highest place down to the ones digit.
    for place in reversed(range(S)):
        value = value * 10 + int(row[place])
    return value
|
|
def mesh(A, B, steps=STEPS, p=P_UPDATE, seed=0, trace=False):
    """Run the cellular automaton on a batch of digit strips.

    Args:
        A, B: integer arrays of shape ``(batch, S)`` holding the operands'
            digits, lowest place first (they index the ``emb`` table).
        steps: number of update steps to run.
        p: probability that a given cell applies its update on a given step.
        seed: seed for the generator that draws the per-step update masks.
        trace: when true, also record a snapshot after every step.

    Returns:
        ``pred``, a ``(batch, S)`` array with the arg-max output class of each
        cell after the last step. With ``trace=True`` returns
        ``(pred, frames)``, where each frame is a tuple
        ``(digits, fired, expert)`` of ``(batch, S)`` arrays: the read-out
        after that step, a 0/1 mask of cells that applied their update, and
        the index of each cell's highest-weighted expert.
    """
    batch = A.shape[0]
    n_cells = batch * S
    rng = np.random.default_rng(seed)

    def readout(state):
        # Project every cell's state to output scores and take the arg-max.
        scores = state.reshape(n_cells, d) @ W["Wo"] + W["bo"]
        return scores.reshape(batch, S, -1).argmax(-1)

    # Initial state: embeddings of both operand digits plus a position term.
    state = W["emb"][A] + W["emb"][B] + W["posemb"][None]
    frames = []
    for _ in range(steps):
        # Each cell perceives its index-1 neighbour, itself and its index+1
        # neighbour; cells at the strip's ends see zeros on the missing side.
        lower = np.zeros_like(state)
        lower[:, 1:] = state[:, :-1]
        upper = np.zeros_like(state)
        upper[:, :-1] = state[:, 1:]
        window = np.concatenate([lower, state, upper], -1)

        # Normalise each perception vector to zero mean / unit variance.
        centred = window - window.mean(-1, keepdims=True)
        normed = centred / np.sqrt(window.var(-1, keepdims=True) + 1e-5)
        flat = normed.reshape(n_cells, 3 * d)

        # Router: keep the TOPK highest logits per cell and push the others to
        # -1e9 so their softmax weight becomes zero.
        router = flat @ W["Wr"] + W["br"]
        keep = np.argpartition(-router, TOPK - 1, axis=1)[:, :TOPK]
        mask = np.full_like(router, -1e9)
        np.put_along_axis(mask, keep, 0.0, axis=1)
        gate = _sm(router + mask)

        # Gate-weighted sum of the selected experts (two-layer ReLU networks),
        # evaluated only on the cells routed to each expert.
        delta = np.zeros((n_cells, d), np.float32)
        for e in range(K):
            weight = gate[:, e]
            routed = weight > 0
            if not routed.any():
                continue
            hidden = np.maximum(flat[routed] @ W[f"e{e}.W1"] + W[f"e{e}.b1"], 0)
            delta[routed] += weight[routed, None] * (hidden @ W[f"e{e}.W2"] + W[f"e{e}.b2"])

        # Asynchronous update: each cell independently applies its delta with
        # probability p on this step.
        fired = (rng.random((batch, S, 1)) < p).astype(np.float32)
        state = state + fired * delta.reshape(batch, S, d)

        if trace:
            frames.append((
                readout(state),
                fired[..., 0].astype(int),
                gate.argmax(1).reshape(batch, S),
            ))

    pred = readout(state)
    if trace:
        return pred, frames
    return pred
|
|
def solve(a, b, vote=1, steps=STEPS):
    """Add ``a`` and ``b`` with the mesh and return the result as an int.

    With ``vote <= 1`` a single run with seed 0 is used. Otherwise the mesh
    is run once per seed in ``range(vote)`` and each digit place takes the
    value predicted most often (ties go to the smaller digit).
    """
    A, B = (digits_rev(operand)[None] for operand in (a, b))
    if vote <= 1:
        single = mesh(A, B, steps=steps, seed=0)
        return to_int(single[0])

    runs = [mesh(A, B, steps=steps, seed=s)[0] for s in range(vote)]
    preds = np.stack(runs)
    # counts[g][i] = number of runs that predicted digit g at place i.
    counts = [(preds == digit).sum(0) for digit in range(10)]
    majority = np.stack(counts).argmax(0)
    return to_int(majority)
|
|
def show(a, b, steps=STEPS):
    """Print a step-by-step trace of the mesh adding ``a`` and ``b``.

    One line is printed per step with the digits read out after that step
    (highest place first), which cells applied their update and the index of
    their top expert, and the integer the strip currently encodes. A final
    line compares the answer with ``a + b``.
    """
    A = digits_rev(a)[None]
    B = digits_rev(b)[None]
    pred, frames = mesh(A, B, steps=steps, seed=0, trace=True)

    # Columns are printed from the highest digit place down to the ones place.
    places = list(reversed(range(S)))
    rule = " " + "-" * (4 * S + 20)
    hdr = " ".join(str(place) for place in places)

    print(f"\n {a} + {b} (mesh = {S} cells, {K} experts, top-{TOPK}, async p={P_UPDATE}, {steps} steps)")
    print(f" digit place (10^): {hdr}")
    print(rule)
    for t, (digits, fired, expert) in enumerate(frames):
        dig = " ".join(str(digits[0, place]) for place in places)
        marks = []
        for place in places:
            if fired[0, place]:
                marks.append(str(expert[0, place]))
            else:
                marks.append("·")
        fire = " ".join(marks)
        print(f" step {t:2d} digits: {dig} | fired(expert#): {fire} = {to_int(digits[0])}")
    print(rule)

    ans = to_int(pred[0])
    truth = a + b
    if ans == truth:
        verdict = "OK"
    else:
        verdict = "X (truth " + str(truth) + ")"
    print(f" => {a} + {b} = {ans} {verdict}")
    print(" '·' = cell did not fire this step; digits settle as the carry ripples low->high.\n")
|
|
def parse(s):
    """Pull the two operands out of free-form text.

    Returns ``(a, b)`` built from the first two runs of digits found in
    ``s``, or ``None`` when fewer than two are present. Everything else in
    the text (operators, signs, further numbers) is ignored.
    """
    found = re.findall(r"\d+", s)
    if len(found) >= 2:
        first, second = found[:2]
        return int(first), int(second)
    return None
|
|
def _pop_int_flag(args, flag, default):
    """Remove ``flag`` and its integer value from ``args`` in place.

    Returns ``default`` when the flag is absent. When the flag is the last
    argument or its value is not an integer, exits with a usage message
    (printed to stderr, exit status 1) rather than an IndexError/ValueError
    traceback.
    """
    if flag not in args:
        return default
    i = args.index(flag)
    try:
        value = int(args[i + 1])
    except (IndexError, ValueError):
        sys.exit(f"{flag} needs an integer value, e.g. {flag} 9")
    del args[i:i + 2]
    return value


def _warn_if_too_wide(a, b):
    """Print a note when ``a + b`` needs more digits than the mesh has cells."""
    if a + b >= 10 ** S:
        print(f"(sum exceeds {S} digits - this mesh has {S} cells; train a wider strip for bigger sums)")


if __name__ == "__main__":
    args = sys.argv[1:]
    doshow = "--show" in args
    if doshow:
        args.remove("--show")
    # Remember whether --vote was given: interactive mode defaults to a 9-way
    # vote, one-shot mode to a single run, and an explicit flag overrides both.
    vote_given = "--vote" in args
    vote = _pop_int_flag(args, "--vote", 1)
    steps = _pop_int_flag(args, "--steps", STEPS)

    if args:
        # One-shot mode: whatever is left on the command line is the sum.
        pr = parse(" ".join(args))
        if not pr:
            print("give me two non-negative integers, e.g. python solve.py 123 + 456")
            sys.exit()  # exit status 0 kept as before for this path
        a, b = pr
        # The note is advisory only; the sum is still attempted afterwards.
        _warn_if_too_wide(a, b)
        if doshow:
            show(a, b, steps=steps)
        else:
            print(solve(a, b, vote=vote, steps=steps))
    else:
        # Interactive mode: honour --steps / --vote / --show given at launch
        # instead of silently discarding them.
        print("gary-neuron (~34 KB). Type a sum like '1234 + 5678'. Add '#' for the mesh view. ctrl-c to exit.")
        interactive_vote = vote if vote_given else 9
        while True:
            try:
                msg = input("\nsum: ")
            except (EOFError, KeyboardInterrupt):
                break
            pr = parse(msg)
            if not pr:
                continue
            a, b = pr
            _warn_if_too_wide(a, b)
            if doshow or "#" in msg:
                show(a, b, steps=steps)
            else:
                print(f" = {solve(a, b, vote=interactive_vote, steps=steps)}")
|
|