gary-neuron / solve.py
gary23w's picture
gary-neuron: async NCA + top-2 MoE, 26k params, 99.97%/100% exact-match on 7-digit addition
57f9808 verified
#!/usr/bin/env python3
"""gary-neuron: a ~34 KB asynchronous Neural Cellular Automaton whose per-cell
rule is a Mixture-of-Experts. It adds integers by letting carries ripple across
a 1-D mesh of cells. Pure numpy + stdlib, no deps, no tokenizer.
Usage:
python solve.py 1234567 + 7654321 # solve a sum
python solve.py 9999999 1 --show # visualise the mesh firing + carry ripple
python solve.py --vote 9 48591 + 9732 # robust inference (ensemble over async orders)
python solve.py # interactive
"""
import json, sys, os, re
import numpy as np
# Directory holding this script; config.json and the weight archive are read
# from the same directory.
D = os.path.dirname(os.path.abspath(__file__))

# Load the hyper-parameters. The file is closed as soon as it is parsed, and
# decoded as UTF-8 rather than whatever the locale default happens to be.
with open(f"{D}/config.json", encoding="utf-8") as _cfg:
    C = json.load(_cfg)

# S, d, K and TOPK come from the config keys below; elsewhere in this file
# they are used as mesh width, per-cell state size, expert count and the
# number of experts kept per cell by the router.
S, d, K, TOPK = C["S"], C["state_dim"], C["n_experts"], C["topk"]
# Probability that a given cell applies its update on a given step.
P_UPDATE = C["p_update"]
# Default number of mesh steps; falls back to 24 when the key is absent.
STEPS = C.get("recommended_inference_steps", 24)

# Every array `k` in the archive is paired with a `k + ".scale"` entry. The
# weight is rebuilt as float32(array) * scale (presumably int8 quantisation,
# going by the file name). All arrays are materialised inside the `with`
# block so the archive's file handle can be released afterwards.
with np.load(f"{D}/gary-neuron.int8.npz") as z:
    W = {
        k: z[k].astype(np.float32) * z[k + ".scale"]
        for k in z.files
        if not k.endswith(".scale")
    }
def _sm(x):
e = np.exp(x - x.max(-1, keepdims=True)); return e / e.sum(-1, keepdims=True)
def digits_rev(x):
    """Return the ``S`` lowest base-10 digits of ``x``, least significant first.

    The result is an int64 array of length ``S``; index ``i`` holds the digit
    for 10**i. Digits above the ``S``-th place are dropped.
    """
    cells = np.zeros(S, np.int64)
    for place in range(S):
        # Peel off the current lowest digit and keep the rest for the next place.
        x, digit = divmod(x, 10)
        cells[place] = digit
    return cells
def to_int(row):
    """Convert ``S`` digits stored least-significant-first into a Python int.

    ``row[i]`` is treated as the coefficient of 10**i.
    """
    total = 0
    # Horner evaluation from the highest place down to the units.
    for place in reversed(range(S)):
        total = total * 10 + int(row[place])
    return total
def mesh(A, B, steps=STEPS, p=P_UPDATE, seed=0, trace=False):
    """Run the cell mesh on a batch of operand digit arrays.

    Args:
        A, B: integer arrays used to index ``W["emb"]``; the first axis is the
            batch and the second has ``S`` entries (one per cell).
        steps: number of update steps to run.
        p: per-cell probability of applying the update on each step.
        seed: seed for the random generator that decides which cells fire.
        trace: when true, also return one frame per step.

    Returns:
        ``pred`` of shape (batch, S): the argmax class read out of each cell
        after the last step. With ``trace=True`` returns ``(pred, frames)``,
        where each frame is ``(digits, fired, top_expert)``, all shaped
        (batch, S).
    """
    batch = A.shape[0]
    n_cells = batch * S
    rng = np.random.default_rng(seed)

    def readout(cells):
        # Project every cell state to class logits and keep the best class.
        scores = cells.reshape(n_cells, d) @ W["Wo"] + W["bo"]
        return scores.reshape(batch, S, -1).argmax(-1)

    # Initial cell state: embeddings of both operand digits plus a per-position term.
    state = W["emb"][A] + W["emb"][B] + W["posemb"][None]
    frames = []

    for _ in range(steps):
        # Each cell perceives its lower-index neighbour, itself and its
        # higher-index neighbour; cells at the ends see zeros beyond the edge.
        lower = np.zeros_like(state)
        lower[:, 1:] = state[:, :-1]
        upper = np.zeros_like(state)
        upper[:, :-1] = state[:, 1:]
        window = np.concatenate([lower, state, upper], -1)

        # Standardise each cell's perception vector (zero mean, unit variance).
        centred = window - window.mean(-1, keepdims=True)
        normed = centred / np.sqrt(window.var(-1, keepdims=True) + 1e-5)
        flat = normed.reshape(n_cells, 3 * d)

        # Router: keep the TOPK largest logits per cell and mask out the rest
        # before the softmax, so unselected experts get a gate of zero.
        router_logits = flat @ W["Wr"] + W["br"]
        chosen = np.argpartition(-router_logits, TOPK - 1, axis=1)[:, :TOPK]
        mask = np.full_like(router_logits, -1e9)
        np.put_along_axis(mask, chosen, 0.0, axis=1)
        gate = _sm(router_logits + mask)

        # Gate-weighted sum of expert outputs; each expert only runs on the
        # cells that routed to it.
        mix = np.zeros((n_cells, d), np.float32)
        for e in range(K):
            weight = gate[:, e]
            routed = weight > 0
            if not routed.any():
                continue
            hidden = np.maximum(flat[routed] @ W[f"e{e}.W1"] + W[f"e{e}.b1"], 0)
            mix[routed] += weight[routed, None] * (hidden @ W[f"e{e}.W2"] + W[f"e{e}.b2"])

        # Asynchronous update: every cell independently fires with probability p
        # and only firing cells add their residual.
        fired = (rng.random((batch, S, 1)) < p).astype(np.float32)
        state = state + fired * mix.reshape(batch, S, d)

        if trace:
            frames.append((
                readout(state),
                fired[..., 0].astype(int),
                gate.argmax(1).reshape(batch, S),
            ))

    pred = readout(state)
    if trace:
        return pred, frames
    return pred
def solve(a, b, vote=1, steps=STEPS):
    """Add ``a`` and ``b`` with the mesh and return the result as an int.

    With ``vote <= 1`` a single run with seed 0 is used. Otherwise the mesh is
    run once per seed in ``range(vote)`` and each digit position takes the
    value predicted most often (ties go to the smaller digit).
    """
    A = digits_rev(a)[None]
    B = digits_rev(b)[None]

    if vote <= 1:
        return to_int(mesh(A, B, steps=steps, seed=0)[0])

    runs = [mesh(A, B, steps=steps, seed=s)[0] for s in range(vote)]
    preds = np.stack(runs)  # (vote, S)
    # tallies[g, i] = number of runs that predicted digit g at position i.
    tallies = (preds[None] == np.arange(10)[:, None, None]).sum(1)  # (10, S)
    return to_int(tallies.argmax(0))
def show(a, b, steps=STEPS):
    """Print a step-by-step trace of the mesh adding ``a`` and ``b``.

    One line is printed per step with the digits currently read out of each
    cell (highest place first), which cells fired and with which top expert,
    and the integer those digits spell. A final line compares the answer with
    ``a + b``.
    """
    A, B = digits_rev(a)[None], digits_rev(b)[None]
    pred, frames = mesh(A, B, steps=steps, seed=0, trace=True)

    # Cells are printed from the highest place down to the units.
    places = list(range(S - 1, -1, -1))
    rule = " " + "-" * (4 * S + 20)

    hdr = " ".join(f"{i}" for i in places)
    print(f"\n {a} + {b} (mesh = {S} cells, {K} experts, top-{TOPK}, async p={P_UPDATE}, {steps} steps)")
    print(f" digit place (10^): {hdr}")
    print(rule)

    for t, (pr, upd, exp) in enumerate(frames):
        dig = " ".join(str(pr[0, i]) for i in places)
        marks = []
        for i in places:
            # Show the top expert index for cells that fired, a dot otherwise.
            marks.append(str(exp[0, i]) if upd[0, i] else "·")
        fire = " ".join(marks)
        print(f" step {t:2d} digits: {dig} | fired(expert#): {fire} = {to_int(pr[0])}")

    print(rule)
    ans = to_int(pred[0])
    truth = a + b
    print(f" => {a} + {b} = {ans} {'OK' if ans == truth else 'X (truth ' + str(truth) + ')'}")
    print(" '·' = cell did not fire this step; digits settle as the carry ripples low->high.\n")
def parse(s):
    """Extract the first two runs of digits from ``s`` as a pair of ints.

    Returns ``(first, second)``, or ``None`` when ``s`` contains fewer than two
    digit runs. Any further numbers in ``s`` are ignored.
    """
    found = re.findall(r"\d+", s)
    if len(found) >= 2:
        first, second = found[:2]
        return int(first), int(second)
    return None
if __name__ == "__main__":
    def _pop_int_flag(argv, flag, default):
        """Remove ``flag <int>`` from ``argv`` in place and return the int.

        Returns ``default`` when the flag is absent. A missing or non-integer
        value prints a short message and exits with status 2 instead of
        surfacing an IndexError/ValueError traceback.
        """
        if flag not in argv:
            return default
        at = argv.index(flag)
        try:
            value = int(argv[at + 1])
        except (IndexError, ValueError):
            print(f"{flag} needs an integer value, e.g. {flag} 9", file=sys.stderr)
            sys.exit(2)
        del argv[at:at + 2]
        return value

    def _warn_if_too_wide(a, b):
        """Warn when a + b needs more digits than the mesh has cells."""
        if a + b >= 10 ** S:
            print(f"(sum exceeds {S} digits - this mesh has {S} cells; train a wider strip for bigger sums)")

    args = sys.argv[1:]
    doshow = "--show" in args
    if doshow:
        args.remove("--show")
    vote = _pop_int_flag(args, "--vote", 1)
    steps = _pop_int_flag(args, "--steps", STEPS)

    if args:
        # One-shot mode: whatever is left on the command line is the sum.
        pr = parse(" ".join(args))
        if not pr:
            print("give me two non-negative integers, e.g. python solve.py 123 + 456")
            sys.exit()
        a, b = pr
        _warn_if_too_wide(a, b)
        if doshow:
            show(a, b, steps=steps)
        else:
            print(solve(a, b, vote=vote, steps=steps))
    else:
        # Interactive mode: read sums until EOF or ctrl-c.
        print("gary-neuron (~34 KB). Type a sum like '1234 + 5678'. Add '#' for the mesh view. ctrl-c to exit.")
        while True:
            try:
                msg = input("\nsum: ")
            except (EOFError, KeyboardInterrupt):
                break
            sh = "#" in msg
            pr = parse(msg)
            if not pr:
                continue
            # Same overflow notice as one-shot mode, so a truncated answer is
            # never printed without explanation.
            _warn_if_too_wide(*pr)
            if sh:
                show(*pr)
            else:
                print(f" = {solve(*pr, vote=9)}")