#!/usr/bin/env python3
"""gary-neuron: a ~34 KB asynchronous Neural Cellular Automaton whose per-cell
rule is a Mixture-of-Experts.

It adds integers by letting carries ripple across a 1-D mesh of cells.
Pure numpy + stdlib, no deps, no tokenizer.

Usage:
    python solve.py 1234567 + 7654321       # solve a sum
    python solve.py 9999999 1 --show        # visualise the mesh firing + carry ripple
    python solve.py --vote 9 48591 + 9732   # robust inference (ensemble over async orders)
    python solve.py                         # interactive
"""
import json
import os
import re
import sys

import numpy as np

# Directory of this script; the config and the weights are read from next to it.
D = os.path.dirname(os.path.abspath(__file__))

# Use a context manager so the config file handle is closed deterministically.
with open(f"{D}/config.json", encoding="utf-8") as _fh:
    C = json.load(_fh)

# S = number of cells (digit places), d = per-cell state width,
# K = number of experts, TOPK = experts selected per cell per step.
S, d, K, TOPK = C["S"], C["state_dim"], C["n_experts"], C["topk"]
P_UPDATE = C["p_update"]
STEPS = C.get("recommended_inference_steps", 24)

# Every array "<name>" in the archive has a companion "<name>.scale"; the
# usable weight is the stored array cast to float32 times that scale.
# The arrays are materialised inside the ``with`` so the archive can be closed
# afterwards (``z`` stays bound, but is closed once loading is done).
with np.load(f"{D}/gary-neuron.int8.npz") as z:
    W = {
        k: z[k].astype(np.float32) * z[k + ".scale"]
        for k in z.files
        if not k.endswith(".scale")
    }


def _sm(x):
    """Return the softmax of ``x`` along its last axis.

    The row maximum is subtracted before exponentiating for numerical
    stability.
    """
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)


def digits_rev(x):
    """Return the ``S`` lowest decimal digits of ``x``, least significant first.

    The result is an int64 array of length ``S``; index ``i`` holds the digit
    for 10**i.  Digits of ``x`` at or above 10**S are dropped without error.
    """
    out = np.zeros(S, np.int64)
    for i in range(S):
        out[i] = x % 10
        x //= 10
    return out


def to_int(row):
    """Inverse of ``digits_rev``: combine ``S`` little-endian digits into an int."""
    return int(sum(int(row[i]) * (10 ** i) for i in range(S)))


def mesh(A, B, steps=STEPS, p=P_UPDATE, seed=0, trace=False):
    """Run the cellular automaton on a batch of digit rows and decode the sum.

    Args:
        A, B: integer digit arrays of shape (batch, S), least significant
            digit first (as produced by ``digits_rev(...)[None]``).
        steps: number of automaton steps to run.
        p: probability that any given cell applies its update in a step.
        seed: seed for the random generator that draws the update masks.
        trace: when true, also record a per-step snapshot.

    Returns:
        ``pred``, a (batch, S) array holding the argmax output class per cell.
        With ``trace=True`` returns ``(pred, frames)`` where each frame is
        ``(decoded digits, update mask as int, top-gated expert index)``,
        each of shape (batch, S).
    """
    Bn = A.shape[0]
    rng = np.random.default_rng(seed)
    # Initial cell state: both operand-digit embeddings plus a position embedding.
    H = W["emb"][A] + W["emb"][B] + W["posemb"][None]
    frames = []
    for _ in range(steps):
        # Neighbour states along the cell axis, zero-padded at the two ends:
        # Hl[:, i] is the state of cell i-1, Hr[:, i] the state of cell i+1.
        Hl = np.zeros_like(H)
        Hl[:, 1:] = H[:, :-1]
        Hr = np.zeros_like(H)
        Hr[:, :-1] = H[:, 1:]
        perc = np.concatenate([Hl, H, Hr], -1)
        # Normalise each cell's perception vector to zero mean / unit variance.
        perc = (perc - perc.mean(-1, keepdims=True)) / np.sqrt(perc.var(-1, keepdims=True) + 1e-5)
        pf = perc.reshape(Bn * S, 3 * d)

        # Router: keep the TOPK largest logits per cell and push the rest to
        # -1e9 so their softmax weight underflows to zero.
        rl = pf @ W["Wr"] + W["br"]
        idx = np.argpartition(-rl, TOPK - 1, axis=1)[:, :TOPK]
        M = np.full_like(rl, -1e9)
        np.put_along_axis(M, idx, 0.0, axis=1)
        gate = _sm(rl + M)

        # Gate-weighted sum of expert outputs; each expert (two linear layers
        # with a ReLU in between) is evaluated only on the cells routed to it.
        mix = np.zeros((Bn * S, d), np.float32)
        for e in range(K):
            ge = gate[:, e]
            act = ge > 0
            if act.any():
                h1 = np.maximum(pf[act] @ W[f"e{e}.W1"] + W[f"e{e}.b1"], 0)
                mix[act] += ge[act, None] * (h1 @ W[f"e{e}.W2"] + W[f"e{e}.b2"])

        # Asynchronous residual update: each cell applies its delta with
        # probability p, otherwise keeps its state for this step.
        um = (rng.random((Bn, S, 1)) < p).astype(np.float32)
        H = H + um * mix.reshape(Bn, S, d)

        if trace:
            lg = H.reshape(Bn * S, d) @ W["Wo"] + W["bo"]
            frames.append((
                lg.reshape(Bn, S, -1).argmax(-1),
                um[..., 0].astype(int),
                gate.argmax(1).reshape(Bn, S),
            ))

    logits = H.reshape(Bn * S, d) @ W["Wo"] + W["bo"]
    pred = logits.reshape(Bn, S, -1).argmax(-1)
    return (pred, frames) if trace else pred


def solve(a, b, vote=1, steps=STEPS):
    """Return the mesh's answer for ``a + b`` as an int.

    With ``vote <= 1`` a single run with seed 0 is used.  Otherwise the mesh
    is run once per seed in ``range(vote)`` and each digit place takes the
    most frequent digit across runs (ties go to the smaller digit).
    """
    A = digits_rev(a)[None]
    B = digits_rev(b)[None]
    if vote <= 1:
        return to_int(mesh(A, B, steps=steps, seed=0)[0])
    preds = np.stack([mesh(A, B, steps=steps, seed=s)[0] for s in range(vote)])  # (vote, S)
    # Count votes for each digit 0..9 per place, then pick the winner.
    maj = np.stack([(preds == g).sum(0) for g in range(10)]).argmax(0)  # (S,)
    return to_int(maj)


def show(a, b, steps=STEPS):
    """Print a step-by-step view of the mesh solving ``a + b`` (seed 0).

    Each step shows the currently decoded digits (most significant first),
    which cells updated and through which top-gated expert, and the decoded
    integer.  A final line compares the answer with the true sum.
    """
    A = digits_rev(a)[None]
    B = digits_rev(b)[None]
    pred, frames = mesh(A, B, steps=steps, seed=0, trace=True)
    hdr = " ".join(f"{i}" for i in range(S - 1, -1, -1))
    print(f"\n {a} + {b} (mesh = {S} cells, {K} experts, top-{TOPK}, async p={P_UPDATE}, {steps} steps)")
    print(f" digit place (10^): {hdr}")
    print(" " + "-" * (4 * S + 20))
    for t, (pr, upd, exp) in enumerate(frames):
        dig = " ".join(str(pr[0, i]) for i in range(S - 1, -1, -1))
        fire = " ".join((str(exp[0, i]) if upd[0, i] else "·") for i in range(S - 1, -1, -1))
        print(f" step {t:2d} digits: {dig} | fired(expert#): {fire} = {to_int(pr[0])}")
    print(" " + "-" * (4 * S + 20))
    ans = to_int(pred[0])
    truth = a + b
    print(f" => {a} + {b} = {ans} {'OK' if ans == truth else 'X (truth ' + str(truth) + ')'}")
    print(" '·' = cell did not fire this step; digits settle as the carry ripples low->high.\n")


def parse(s):
    """Return the first two digit runs in ``s`` as an ``(int, int)`` pair.

    Returns ``None`` when ``s`` contains fewer than two runs of digits.
    Everything else in the string (operators, spaces, '#') is ignored.
    """
    n = re.findall(r"\d+", s)
    if len(n) < 2:
        return None
    return int(n[0]), int(n[1])


def _pop_int_option(args, flag, default):
    """Remove ``flag <int>`` from ``args`` in place and return the integer.

    Returns ``default`` when ``flag`` is absent.  Exits with an error message
    (status 1) when the value is missing or is not an integer, instead of
    surfacing a raw IndexError / ValueError traceback.
    """
    if flag not in args:
        return default
    i = args.index(flag)
    try:
        value = int(args[i + 1])
    except (IndexError, ValueError):
        sys.exit(f"{flag} needs an integer argument, e.g. {flag} 9")
    del args[i:i + 2]
    return value


def _warn_if_overflow(a, b):
    """Print a notice when ``a + b`` needs more digits than the mesh has cells."""
    if a + b >= 10 ** S:
        print(f"(sum exceeds {S} digits - this mesh has {S} cells; train a wider strip for bigger sums)")


def main(argv=None):
    """Command-line entry point: solve one sum, or start the interactive prompt.

    Args:
        argv: argument list without the program name; defaults to
            ``sys.argv[1:]``.
    """
    args = list(sys.argv[1:] if argv is None else argv)
    doshow = "--show" in args
    if doshow:
        args.remove("--show")
    # ``None`` means "not given": one-shot mode then uses a single run, while
    # the interactive prompt keeps its 9-way vote.
    vote = _pop_int_option(args, "--vote", None)
    steps = _pop_int_option(args, "--steps", STEPS)

    if args:
        pr = parse(" ".join(args))
        if not pr:
            print("give me two non-negative integers, e.g. python solve.py 123 + 456")
            sys.exit(1)  # invalid input is an error, so do not report success
        a, b = pr
        _warn_if_overflow(a, b)
        if doshow:
            show(a, b, steps=steps)
        else:
            print(solve(a, b, vote=1 if vote is None else vote, steps=steps))
        return

    print("gary-neuron (~34 KB). Type a sum like '1234 + 5678'. Add '#' for the mesh view. ctrl-c to exit.")
    interactive_vote = 9 if vote is None else vote
    while True:
        try:
            msg = input("\nsum: ")
        except (EOFError, KeyboardInterrupt):
            break
        pr = parse(msg)
        if not pr:
            continue
        _warn_if_overflow(*pr)
        if doshow or "#" in msg:
            show(*pr, steps=steps)
        else:
            print(f" = {solve(*pr, vote=interactive_vote, steps=steps)}")


if __name__ == "__main__":
    main()