Spaces:
Running on Zero
Running on Zero
Upload 7 files
Browse files- agents/modmind/moe_gradio.py +137 -0
- agents/modmind/train_qa_link.py +250 -0
agents/modmind/moe_gradio.py
CHANGED
|
@@ -69,6 +69,10 @@ class SpikeWhaleMoE:
|
|
| 69 |
self.trained_link = None # the TRAINED bridge (train_link.py output), for the consult demo
|
| 70 |
self.bridge_asker = None # the FULL fine-tuned asker, for reproducible key-recall
|
| 71 |
self.link_meta = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self.reload()
|
| 73 |
|
| 74 |
def reload(self):
|
|
@@ -99,6 +103,7 @@ class SpikeWhaleMoE:
|
|
| 99 |
self.steps[slot] = step
|
| 100 |
self._mtime[slot] = mt
|
| 101 |
self._load_links()
|
|
|
|
| 102 |
return list(self.models)
|
| 103 |
|
| 104 |
def available(self):
|
|
@@ -116,6 +121,10 @@ class SpikeWhaleMoE:
|
|
| 116 |
self.trained_link = self.trained_link.to(device)
|
| 117 |
if self.bridge_asker is not None:
|
| 118 |
self.bridge_asker = self.bridge_asker.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
cache = getattr(self, "_merge_cache", None)
|
| 120 |
if cache is not None:
|
| 121 |
self._merge_cache = (cache[0], cache[1].to(device))
|
|
@@ -179,6 +188,80 @@ class SpikeWhaleMoE:
|
|
| 179 |
"with_latent": wl, "without_latent": nl}
|
| 180 |
return # one bridge is enough for the panel demo
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
def key_recall_available(self):
    """Report whether the key-recall demo can run.

    True only when both the fine-tuned bridge asker and the trained link
    are loaded (neither attribute is None).
    """
    if self.bridge_asker is None:
        return False
    return self.trained_link is not None
|
| 184 |
|
|
@@ -210,6 +293,36 @@ class SpikeWhaleMoE:
|
|
| 210 |
examples.append((k, out, ok))
|
| 211 |
return {"acc": correct / max(1, n), "examples": examples}
|
| 212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
def consult_available(self):
    """Report whether the consult demo can run, i.e. a trained link is loaded."""
    link = self.trained_link
    return link is not None
|
| 215 |
|
|
@@ -338,6 +451,30 @@ class SpikeWhaleMoE:
|
|
| 338 |
break
|
| 339 |
return expert, tok.decode(ids[0, start:].tolist())
|
| 340 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
@torch.no_grad()
|
| 342 |
def run(self, query: str, max_new: int = 160, temperature: float = 0.8):
|
| 343 |
"""Full pass: route -> fuse latents -> generate from the winner."""
|
|
|
|
| 69 |
self.trained_link = None # the TRAINED bridge (train_link.py output), for the consult demo
|
| 70 |
self.bridge_asker = None # the FULL fine-tuned asker, for reproducible key-recall
|
| 71 |
self.link_meta = None
|
| 72 |
+
self.qa_link = None # the question->answer bridge (train_qa_link.py output)
|
| 73 |
+
self.qa_asker = None
|
| 74 |
+
self.qa_meta = None
|
| 75 |
+
self._qa_mtime = None
|
| 76 |
self.reload()
|
| 77 |
|
| 78 |
def reload(self):
|
|
|
|
| 103 |
self.steps[slot] = step
|
| 104 |
self._mtime[slot] = mt
|
| 105 |
self._load_links()
|
| 106 |
+
self._load_qa()
|
| 107 |
return list(self.models)
|
| 108 |
|
| 109 |
def available(self):
|
|
|
|
| 121 |
self.trained_link = self.trained_link.to(device)
|
| 122 |
if self.bridge_asker is not None:
|
| 123 |
self.bridge_asker = self.bridge_asker.to(device)
|
| 124 |
+
if self.qa_link is not None:
|
| 125 |
+
self.qa_link = self.qa_link.to(device)
|
| 126 |
+
if self.qa_asker is not None:
|
| 127 |
+
self.qa_asker = self.qa_asker.to(device)
|
| 128 |
cache = getattr(self, "_merge_cache", None)
|
| 129 |
if cache is not None:
|
| 130 |
self._merge_cache = (cache[0], cache[1].to(device))
|
|
|
|
| 188 |
"with_latent": wl, "without_latent": nl}
|
| 189 |
return # one bridge is enough for the panel demo
|
| 190 |
|
| 191 |
+
def _load_qa(self):
    """Load the question->answer bridge written by train_qa_link.py.

    The checkpoint holds a NEW RecursiveLink (keys prefixed ``link.``) and a fully
    fine-tuned asker (keys prefixed ``asker.``) that answers arithmetic shown only
    to the consultant. The load is mtime-cached (the file is ~190MB), so the file
    is only re-read when it changes on disk; that is what lets the panel hot-reload
    the bridge as training improves it.

    Side effects: sets ``qa_link``, ``qa_asker``, ``qa_meta`` and ``_qa_mtime``.
    All four are reset to None when the checkpoint file does not exist. When either
    specialist slot is missing from ``self.models`` the method returns without
    touching the existing state.
    """
    a_dom, c_dom = LINKS[0]
    path = CKPT_ROOT / "links" / f"qa__{a_dom}__from__{c_dom}.safetensors"
    if not path.exists():
        self.qa_link = self.qa_asker = self.qa_meta = None
        self._qa_mtime = None
        return
    mt = os.path.getmtime(path)
    if self._qa_mtime == mt and self.qa_link is not None:
        return  # same file version is already loaded
    a, c = DOMAIN2SLOT[a_dom], DOMAIN2SLOT[c_dom]
    if a not in self.models or c not in self.models:
        return
    from safetensors import safe_open

    # Read tensors AND metadata through ONE open handle. The trainer replaces this
    # file with os.replace() while the panel is running; two separate opens could
    # pair the tensors of one checkpoint with the metadata of the next.
    with safe_open(str(path), framework="pt", device=self.device) as f:
        md = f.metadata() or {}
        t = {k: f.get_tensor(k) for k in f.keys()}
    # The checkpoint stores floats as fp16; the modules are run in fp32.
    t = {k: (v.float() if v.is_floating_point() else v) for k, v in t.items()}

    link_prefix, asker_prefix = "link.", "asker."
    link = RecursiveLink(d_latent=D_LATENT).to(self.device).eval()
    link.load_state_dict({k[len(link_prefix):]: v for k, v in t.items()
                          if k.startswith(link_prefix)})
    ask = SpikeWhaleLM(specialist_config(a_dom)).to(self.device).eval()
    ask.load_state_dict({k[len(asker_prefix):]: v for k, v in t.items()
                         if k.startswith(asker_prefix)})
    self.qa_link, self.qa_asker = link, ask
    # safetensors metadata values are strings, hence the explicit conversions.
    self.qa_meta = {"asker": a, "consultant": c,
                    "ans_len": int(md.get("ans_len", 3)), "prompt": md.get("prompt", "ANS> "),
                    "holdout_exact": float(md.get("holdout_exact", "nan")),
                    "step": int(md.get("step", 0))}
    self._qa_mtime = mt
|
| 223 |
+
|
| 224 |
+
def qa_available(self):
    """Report whether the QA bridge can run: both the QA link and the QA asker are loaded."""
    if self.qa_link is None:
        return False
    return self.qa_asker is not None
|
| 226 |
+
|
| 227 |
+
def qa_info(self):
    """Return a shallow copy of the QA bridge metadata, or None when there is none.

    A copy is handed out so callers cannot mutate ``self.qa_meta``.
    """
    meta = self.qa_meta
    if not meta:
        return None
    return dict(meta)
|
| 229 |
+
|
| 230 |
+
@torch.no_grad()
def ask_math(self, a: int, op: str, b: int, ablate: bool = False):
    """Language answers an arithmetic question SHOWN ONLY to Math.

    The frozen consultant encodes the question, the QA RecursiveLink carries it
    across, and the QA asker decodes the answer digits autoregressively from the
    latent alone (its own input is just the stored prompt, 'ANS> ' by default --
    the question never reaches it as text).

    Args:
        a, b: operands; anything ``int()`` accepts (e.g. a float from a UI widget).
        op: one of ``"+"``, ``"-"``, ``"*"``.
        ablate: when True the injected latent is zeroed, so the asker has no
            question at all.

    Returns:
        ``{"error": msg}`` for an unavailable bridge or invalid input, otherwise a
        dict with ``question``, ``digits`` (raw decoded text), ``answer`` (digits
        without leading zeros), ``truth``, ``want`` (zero-padded truth), ``ok``
        (per-digit booleans) and ``exact``.

    NOTE(review): the range check only bounds the ANSWER by digit count; it does
    not check that the operands lie in the range the bridge was trained on.
    """
    if not self.qa_available():
        return {"error": "qa bridge not trained yet"}
    meta = self.qa_meta
    try:
        a, b = int(a), int(b)
    except (TypeError, ValueError, OverflowError):
        # None / text / inf from the UI: report it like every other bad input
        # instead of raising out of the handler.
        return {"error": "a and b must be integers"}
    if op not in ("+", "-", "*"):
        return {"error": "op must be one of + - *"}
    truth = {"+": a + b, "-": a - b, "*": a * b}[op]
    ans_len = meta["ans_len"]
    if not (0 <= truth < 10 ** ans_len):
        return {"error": "answer out of the trained range"}
    q = f"{a} {op} {b} ="
    c_ids = torch.tensor([self.toks[meta["consultant"]].encode(q, add_special_tokens=False)],
                         device=self.device)
    latent = self.models[meta["consultant"]](input_ids=c_ids).latent
    inj = self.qa_link(latent)
    if ablate:
        inj = torch.zeros_like(inj)  # same shape/dtype/device, no information
    a_tok = self.toks[meta["asker"]]
    ids = torch.tensor([a_tok.encode(meta["prompt"], add_special_tokens=False)],
                       device=self.device)
    plen = ids.shape[1]
    # Greedy decode of exactly ans_len tokens, feeding each prediction back in.
    for _ in range(ans_len):
        logits = self.qa_asker(input_ids=ids, inject_latent=inj).logits[:, -1, :]
        ids = torch.cat([ids, logits.argmax(-1, keepdim=True)], dim=1)
    digits = a_tok.decode(ids[0, plen:].tolist())
    want = f"{truth:0{ans_len}d}"
    return {"question": q, "digits": digits, "answer": digits.lstrip("0") or "0",
            "truth": truth, "want": want,
            "ok": [i < len(digits) and digits[i] == ch for i, ch in enumerate(want)],
            "exact": digits == want}
|
| 264 |
+
|
| 265 |
def key_recall_available(self):
    """Report whether the key-recall demo can run.

    True only when both the fine-tuned bridge asker and the trained link
    are loaded (neither attribute is None).
    """
    if self.bridge_asker is None:
        return False
    return self.trained_link is not None
|
| 267 |
|
|
|
|
| 293 |
examples.append((k, out, ok))
|
| 294 |
return {"acc": correct / max(1, n), "examples": examples}
|
| 295 |
|
| 296 |
+
@torch.no_grad()
def relay_secret(self, secret: str, ablate: bool = False):
    """Interactive bridge demo with a USER-CHOSEN key.

    The key is shown only to the consultant; the asker reads it back from the
    injected latent alone (same mechanism as key_recall, but the human picks the
    secret). Characters outside KEY_CHARS are dropped before the length check.

    Returns {"error": msg} when the bridge is unavailable or the cleaned key has
    the wrong length; otherwise {secret, recovered, ok:[per-char bool], aligned}.
    aligned=False means the tokenizer fused some characters into multi-char tokens
    the bridge never saw in training, so expect degradation.
    """
    if not self.key_recall_available():
        return {"error": "bridge unavailable"}

    meta = self.link_meta
    s = "".join([ch for ch in (secret or "") if ch in KEY_CHARS])
    key_len = meta.get("key_len", 6)
    if len(s) != key_len:
        return {"error": f"need exactly {key_len} characters (letters and digits only)"}

    asker_slot, consultant_slot = meta["asker"], meta["consultant"]
    a_tok = self.toks[asker_slot]
    c_tok = self.toks[consultant_slot]

    # The prompt is tokenized once and reused for both its length and the asker input.
    prompt_ids = a_tok.encode(meta.get("prompt", "KEY> "), add_special_tokens=False)
    plen = len(prompt_ids)
    c_ids = torch.tensor([c_tok.encode(s, add_special_tokens=False)], device=self.device)
    a_ids = torch.tensor([prompt_ids + a_tok.encode(s, add_special_tokens=False)],
                         device=self.device)
    # One token per character on both sides is what the position-wise readout assumes.
    aligned = c_ids.shape[1] == key_len and a_ids.shape[1] == plen + key_len

    latent = self.models[consultant_slot](input_ids=c_ids).latent
    inj = self.trained_link(latent)
    if ablate:
        inj = torch.zeros_like(inj)

    logits = self.bridge_asker(input_ids=a_ids, inject_latent=inj).logits
    # Position p predicts token p+1, so the key is read from plen-1 onwards.
    first = plen - 1
    pred = logits[:, first:first + key_len, :].argmax(-1)[0]
    out = a_tok.decode(pred.tolist())[:len(s)]

    per_char = []
    for pos, ch in enumerate(s):
        per_char.append(pos < len(out) and out[pos] == ch)
    return {"secret": s, "recovered": out,
            "ok": per_char,
            "aligned": aligned}
|
| 325 |
+
|
| 326 |
def consult_available(self):
    """Report whether the consult demo can run, i.e. a trained link is loaded."""
    link = self.trained_link
    return link is not None
|
| 328 |
|
|
|
|
| 451 |
break
|
| 452 |
return expert, tok.decode(ids[0, start:].tolist())
|
| 453 |
|
| 454 |
+
@torch.no_grad()
|
| 455 |
+
def generate_stream(self, query: str, expert: str | None = None, max_new: int = 160,
|
| 456 |
+
temperature: float = 0.8, top_k: int = 40, chunk: int = 4):
|
| 457 |
+
"""Like generate(), but yields (expert, text_so_far) as tokens arrive, so the UI
|
| 458 |
+
can show generation live instead of freezing until the whole thing is done."""
|
| 459 |
+
if expert is None:
|
| 460 |
+
expert, _, _ = self.route(query)
|
| 461 |
+
m, tok = self.models[expert], self.toks[expert]
|
| 462 |
+
ids = self._ids(expert, query)
|
| 463 |
+
start = ids.shape[1]
|
| 464 |
+
ctx_max = int(getattr(m.config, "max_position_embeddings", 2048))
|
| 465 |
+
for i in range(max_new):
|
| 466 |
+
logits = m(input_ids=ids[:, -ctx_max:]).logits[:, -1, :] / max(1e-5, temperature)
|
| 467 |
+
if top_k:
|
| 468 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 469 |
+
logits[logits < v[:, [-1]]] = -float("inf")
|
| 470 |
+
nxt = torch.multinomial(F.softmax(logits, dim=-1), 1)
|
| 471 |
+
ids = torch.cat([ids, nxt], dim=1)
|
| 472 |
+
done = tok.eos_token_id is not None and int(nxt.item()) == tok.eos_token_id
|
| 473 |
+
if done or i % chunk == chunk - 1 or i == max_new - 1:
|
| 474 |
+
yield expert, tok.decode(ids[0, start:].tolist())
|
| 475 |
+
if done:
|
| 476 |
+
break
|
| 477 |
+
|
| 478 |
@torch.no_grad()
|
| 479 |
def run(self, query: str, max_new: int = 160, temperature: float = 0.8):
|
| 480 |
"""Full pass: route -> fuse latents -> generate from the winner."""
|
agents/modmind/train_qa_link.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
train_qa_link.py -- train a question->answer latent bridge (the upgrade the panel's
|
| 3 |
+
"Tell Math a secret" demo points at).
|
| 4 |
+
|
| 5 |
+
Task: an arithmetic question ("23 + 54 =") is shown ONLY to the frozen Math/reasoning
|
| 6 |
+
specialist, which encodes it to its 256-d output latent. A NEW RecursiveLink + a
|
| 7 |
+
fine-tuned Language asker must emit the ANSWER ("077", zero-padded digits) reading
|
| 8 |
+
nothing but that latent: asker input is just "ANS> " + answer digits (teacher-forced
|
| 9 |
+
in training; decoded autoregressively at eval). With --holdout N (default 0: none), N% of problems are HELD OUT
|
| 10 |
+
of training, so eval accuracy on them is generalization, not memorization. Ablating
|
| 11 |
+
the latent removes the question entirely -> accuracy collapses to the digit prior.
|
| 12 |
+
|
| 13 |
+
Saves links/qa__language__from__reasoning.safetensors in the same key style as the
|
| 14 |
+
key-recall bridge (link./ali./asker. + metadata) for moe_gradio.py to load.
|
| 15 |
+
|
| 16 |
+
Run: python agents/modmind/train_qa_link.py [--steps 4000] [--device cuda]
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import argparse
|
| 21 |
+
import hashlib
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import random
|
| 25 |
+
import sys
|
| 26 |
+
import time
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
+
|
| 31 |
+
HERE = os.path.dirname(os.path.abspath(__file__))
|
| 32 |
+
sys.path.insert(0, HERE)
|
| 33 |
+
|
| 34 |
+
from model import RecursiveLink, SpikeWhaleLM # noqa: E402
|
| 35 |
+
from specialist_presets import specialist_config # noqa: E402
|
| 36 |
+
from spike_tokenizer import SpikeTokenizer # noqa: E402
|
| 37 |
+
|
| 38 |
+
# Domain names of the two specialists: the asker is fine-tuned in main(), the
# consultant is frozen there.
ASKER, CONSULTANT = "language", "reasoning"
# Width handed to RecursiveLink(d_latent=...) and used to size the latent buffer.
D_LATENT = 256
# Text the asker receives before it must emit the answer digits.
PROMPT = "ANS> "
ANS_LEN = 3  # answers zero-padded to 3 digits ("077")
# NOTE(review): not referenced elsewhere in this file -- the actual split is
# driven by the --holdout CLI flag, whose default is 0.
HOLDOUT_PCT = 8  # % of problems held out of training entirely
# Checkpoint path written by save(); main() derives its resume file from it.
OUT = os.path.join(HERE, "links", f"qa__{ASKER}__from__{CONSULTANT}.safetensors")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ---- the problem space --------------------------------------------------------
|
| 47 |
+
def all_problems():
    """Enumerate every (a, op, b) the bridge is trained/evaluated on.

    Two-digit operands (10..99) for "+" and for "-" (only when a >= b, so the
    difference is never negative), then operands 2..12 for "*". Answers therefore
    span 0..198. The order is: for each (a, b) pair the "+" problem, immediately
    followed by its "-" problem when present; all "*" problems come last.
    """
    add_sub = [
        (lhs, sign, rhs)
        for lhs in range(10, 100)
        for rhs in range(10, 100)
        for sign in ("+", "-")
        if sign == "+" or lhs >= rhs
    ]
    times = [(lhs, "*", rhs) for lhs in range(2, 13) for rhs in range(2, 13)]
    return add_sub + times
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def answer(a, op, b):
    """Return the result of ``a op b`` for op in "+", "-", "*".

    Raises KeyError for any other operator.
    """
    total, difference, product = a + b, a - b, a * b
    results = {"+": total, "-": difference, "*": product}
    return results[op]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def is_holdout(a, op, b, pct):
    """Decide deterministically whether problem (a, op, b) is held out of training.

    The decision depends only on the problem text, so the split is identical on
    every run. ``pct <= 0`` disables the holdout entirely.

    NOTE(review): the bucket comes from ONE md5 byte (0..255) taken modulo 100, so
    residues 0..55 occur three times and 56..99 twice; the held-out share is
    therefore somewhat above ``pct`` (about 9.4% for pct=8). Changing the hash
    would change which problems are held out for existing checkpoints.
    """
    if pct > 0:
        key = f"{a}{op}{b}".encode()
        first_byte = hashlib.md5(key).digest()[0]
        return first_byte % 100 < pct
    return False
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def render(a, op, b):
    """Format a problem as the question text shown to the consultant, e.g. "23 + 54 ="."""
    question = f"{a} {op} {b} ="
    return question
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ---- model loading (same pattern as moe_gradio.py) ------------------------------
|
| 77 |
+
def load_specialist(domain, device):
    """Build one specialist and its tokenizer from the files under HERE/<domain>.

    Reads ``checkpoints/model.safetensors`` (floating-point tensors are upcast to
    fp32 before loading) and ``tokenizer.json``. Returns ``(model, tokenizer)``;
    the model is moved to ``device`` and left in whatever mode the constructor
    leaves it in.
    """
    from safetensors.torch import load_file

    domain_dir = os.path.join(HERE, domain)
    weights_path = os.path.join(domain_dir, "checkpoints", "model.safetensors")
    model = SpikeWhaleLM(specialist_config(domain)).to(device)

    state = {}
    for name, tensor in load_file(weights_path, device=device).items():
        state[name] = tensor.float() if tensor.is_floating_point() else tensor
    model.load_state_dict(state)

    tokenizer = SpikeTokenizer(vocab_file=os.path.join(domain_dir, "tokenizer.json"))
    return model, tokenizer
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ---- training -------------------------------------------------------------------
|
| 90 |
+
def main():
    """Train the QA link and the asker against the frozen consultant.

    Parses the CLI arguments, loads both specialists, freezes the consultant and
    runs a teacher-forced training loop over (a, op, b) problems. Every
    --eval-every steps (and on the last step) it evaluates by autoregressive
    decode, calls save() when the score beats the best so far, and writes a resume
    checkpoint next to OUT. With --holdout 0 (the default) every problem is trained
    on and the reported accuracy is coverage of the whole table.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--steps", type=int, default=4000)
    ap.add_argument("--batch", type=int, default=128)
    ap.add_argument("--link-lr", type=float, default=1e-3)
    ap.add_argument("--asker-lr", type=float, default=1e-4)
    ap.add_argument("--asker-wd", type=float, default=0.0)
    ap.add_argument("--holdout", type=int, default=0,
                    help="%% of problems held out of training (0 = train on ALL, the lookup-table demo)")
    ap.add_argument("--eval-every", type=int, default=200)
    ap.add_argument("--eval-n", type=int, default=256)
    ap.add_argument("--eval-chunk", type=int, default=64)  # keep eval VRAM peaks small
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--fresh", action="store_true", help="ignore last.pt and start over")
    args = ap.parse_args()
    dev = args.device
    random.seed(args.seed); torch.manual_seed(args.seed)

    print(f"[qa-link] device={dev}", flush=True)
    consultant, c_tok = load_specialist(CONSULTANT, dev)
    asker, a_tok = load_specialist(ASKER, dev)
    # The consultant is never trained: eval mode and no gradients.
    consultant.eval()
    for p in consultant.parameters():
        p.requires_grad_(False)

    # answer digits must be single tokens for the asker (position-aligned readout)
    digit_ids = []
    for d in "0123456789":
        ids = a_tok.encode(d, add_special_tokens=False)
        assert len(ids) == 1, f"digit {d!r} is not a single token: {ids}"
        digit_ids.append(ids[0])
    prompt_ids = a_tok.encode(PROMPT, add_special_tokens=False)
    plen = len(prompt_ids)
    print(f"[qa-link] prompt {PROMPT!r} = {plen} tokens; digits map to single tokens", flush=True)

    link = RecursiveLink(d_latent=D_LATENT).to(dev)
    # Two parameter groups so the fresh link and the pretrained asker get their own
    # learning rate and weight decay.
    opt = torch.optim.AdamW([
        {"params": list(link.parameters()), "lr": args.link_lr, "weight_decay": 0.0},
        {"params": list(asker.parameters()), "lr": args.asker_lr, "weight_decay": args.asker_wd},
    ])

    probs = all_problems()
    train_pool = [p for p in probs if not is_holdout(*p, args.holdout)]
    eval_pool = [p for p in probs if is_holdout(*p, args.holdout)]
    memorize = args.holdout <= 0
    if memorize:
        eval_pool = train_pool  # no holdout: "accuracy" = coverage of the whole table
        print(f"[qa-link] MEMORIZE mode: training on ALL {len(train_pool)} problems (no holdout)", flush=True)
    else:
        print(f"[qa-link] {len(train_pool)} train problems, {len(eval_pool)} held out", flush=True)
    label = "accuracy" if memorize else "held-out exact"

    @torch.no_grad()
    def encode_questions(batch):
        """Frozen consultant -> latents. Bucketed by token length (latent is a
        mean-pool over positions, so padding would corrupt it).

        Returns a (len(batch), D_LATENT) tensor whose rows follow the order of
        ``batch``."""
        idss = [c_tok.encode(render(*p), add_special_tokens=False) for p in batch]
        lat = torch.zeros(len(batch), D_LATENT, device=dev)
        by_len = {}
        for i, ids in enumerate(idss):
            by_len.setdefault(len(ids), []).append(i)
        for L, idx in by_len.items():
            c_ids = torch.tensor([idss[i] for i in idx], device=dev)
            # scatter each bucket's latents back to the rows they came from
            lat[idx] = consultant(input_ids=c_ids).latent
        return lat

    def ans_tokens(p):
        # token ids of the zero-padded answer, one token per digit
        return [digit_ids[int(ch)] for ch in f"{answer(*p):0{ANS_LEN}d}"]

    @torch.no_grad()
    def evaluate(pool, n, ablate=False):
        """Autoregressive 3-digit decode (full-vocab argmax, no teacher forcing).
        Chunked to keep VRAM peaks small.

        Scores a random sample of at most ``n`` problems from ``pool`` and returns
        (exact-match rate, per-digit accuracy). ablate=True replaces the injected
        latent with zeros. Draws from the global ``random`` state and puts the
        asker back into train mode before returning."""
        asker.eval()
        sample = random.sample(pool, min(n, len(pool)))
        hit_e = hit_d = 0
        for o in range(0, len(sample), args.eval_chunk):
            chunk = sample[o:o + args.eval_chunk]
            lat = encode_questions(chunk)
            inj = torch.zeros_like(link(lat)) if ablate else link(lat)
            ids = torch.tensor([prompt_ids] * len(chunk), device=dev)
            for _ in range(ANS_LEN):
                logits = asker(input_ids=ids, inject_latent=inj).logits[:, -1, :]
                ids = torch.cat([ids, logits.argmax(-1, keepdim=True)], dim=1)
            pred = ids[:, plen:]
            tgt = torch.tensor([ans_tokens(p) for p in chunk], device=dev)
            hit_e += int((pred == tgt).all(dim=1).sum())
            hit_d += int((pred == tgt).sum())
        asker.train()
        return hit_e / len(sample), hit_d / (len(sample) * ANS_LEN)

    # resume from last.pt if a previous run died mid-flight
    last_pt = OUT + ".last.pt"
    best, start_step = -1.0, 0
    if os.path.exists(last_pt) and not args.fresh:
        # weights_only=False: the file also carries the optimizer state and plain
        # Python values; it is only ever written by this script.
        st = torch.load(last_pt, map_location=dev, weights_only=False)
        link.load_state_dict(st["link"]); asker.load_state_dict(st["asker"])
        opt.load_state_dict(st["opt"]); best, start_step = st["best"], st["step"]
        # NOTE(review): the message says "held-out" even in memorize mode, and
        # `best` is reused as-is even if --holdout differs from the earlier run.
        print(f"[qa-link] resumed from step {start_step} (best held-out {best*100:.1f}%)", flush=True)

    t0 = time.time()
    asker.train()
    for step in range(start_step + 1, args.steps + 1):
        batch = random.sample(train_pool, args.batch)
        lat = encode_questions(batch)
        inj = link(lat)
        # teacher forcing: the asker sees prompt + gold answer digits
        a_ids = torch.tensor([prompt_ids + ans_tokens(p) for p in batch], device=dev)
        labels = a_ids.clone()
        labels[:, :plen] = -100  # loss only on the answer digits
        out = asker(input_ids=a_ids, labels=labels, inject_latent=inj)
        opt.zero_grad(); out.loss.backward(); opt.step()

        if step % args.eval_every == 0 or step == args.steps:
            ex, pd = evaluate(eval_pool, args.eval_n)
            extra = "" if memorize else f" train exact {evaluate(train_pool, args.eval_n)[0]*100:5.1f}%"
            print(f"[qa-link] step {step:5d} loss {out.loss.item():.4f} "
                  f"{label} {ex*100:5.1f}% (digits {pd*100:5.1f}%){extra} "
                  f"[{time.time()-t0:.0f}s]", flush=True)
            if ex > best:
                best = ex
                save(link, asker, ex, step, args, memorize)
                print(f"[qa-link] saved -> {OUT} ({label} {ex*100:.1f}%)", flush=True)
            # resume checkpoint every eval, so a crash never loses more than eval_every steps
            # (written to a temp name first, then renamed over the old checkpoint)
            torch.save({"link": link.state_dict(), "asker": asker.state_dict(),
                        "opt": opt.state_dict(), "best": best, "step": step}, last_pt + ".tmp")
            os.replace(last_pt + ".tmp", last_pt)

    # Report the ablation of the LAST-STEP weights for the log.
    # NOTE(review): an earlier comment here said ablation numbers for the best
    # bridge are written at save(); save() in this file stores no ablation metrics.
    ex_a, pd_a = evaluate(eval_pool, args.eval_n, ablate=True)
    print(f"[qa-link] ablated (latent cut): exact {ex_a*100:.1f}% / digits {pd_a*100:.1f}%", flush=True)
    print(f"[qa-link] done. best {label} {best*100:.1f}%", flush=True)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def save(link, asker, acc, step, args, memorize):
    """Write the bridge checkpoint to OUT as a single safetensors file.

    Tensors are stored under three prefixes: ``link.`` (the RecursiveLink),
    ``ali.`` (the asker's latent_inject module) and ``asker.`` (the full asker).
    Floating-point tensors go to CPU as fp16; non-float asker tensors keep their
    dtype. Metadata records the run settings and ``acc`` (as ``holdout_exact``).
    The file is written to a temp name and renamed, so readers never see a
    partial file.
    """
    from safetensors.torch import save_file

    def _cpu_half(tensor):
        # fp16 copy on CPU, contiguous for serialisation
        return tensor.detach().to("cpu", torch.float16).contiguous()

    os.makedirs(os.path.dirname(OUT), exist_ok=True)

    tensors = {}
    for name, weight in link.state_dict().items():
        tensors["link." + name] = _cpu_half(weight)
    for name, weight in asker.model.latent_inject.state_dict().items():
        tensors["ali." + name] = _cpu_half(weight)
    for name, weight in asker.state_dict().items():
        if weight.is_floating_point():
            tensors["asker." + name] = _cpu_half(weight)
        else:
            tensors["asker." + name] = weight.detach().cpu().contiguous()

    # safetensors metadata must be str -> str, hence the str()/json conversions
    metadata = {
        "kind": "qa", "ans_len": str(ANS_LEN), "prompt": PROMPT,
        "asker": ASKER, "consultant": CONSULTANT,
        "mode": "memorize" if memorize else "generalize",
        "holdout_pct": str(args.holdout), "step": str(step),
        # accuracy over the whole table (memorize) or held-out set (generalize)
        "holdout_exact": f"{acc:.4f}",
        "ops": json.dumps(["+", "-", "*"]),
    }
    scratch = OUT + ".tmp"
    save_file(tensors, scratch, metadata=metadata)
    os.replace(scratch, OUT)  # atomic: the panel hot-reloads this file while we train
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# Script entry point: start training only when run directly, not on import.
if __name__ == "__main__":
    main()
|