Quazim0t0 commited on
Commit
73dd4cf
·
verified ·
1 Parent(s): d4da65a

Upload 7 files

Browse files
agents/modmind/moe_gradio.py CHANGED
@@ -69,6 +69,10 @@ class SpikeWhaleMoE:
69
  self.trained_link = None # the TRAINED bridge (train_link.py output), for the consult demo
70
  self.bridge_asker = None # the FULL fine-tuned asker, for reproducible key-recall
71
  self.link_meta = None
 
 
 
 
72
  self.reload()
73
 
74
  def reload(self):
@@ -99,6 +103,7 @@ class SpikeWhaleMoE:
99
  self.steps[slot] = step
100
  self._mtime[slot] = mt
101
  self._load_links()
 
102
  return list(self.models)
103
 
104
  def available(self):
@@ -116,6 +121,10 @@ class SpikeWhaleMoE:
116
  self.trained_link = self.trained_link.to(device)
117
  if self.bridge_asker is not None:
118
  self.bridge_asker = self.bridge_asker.to(device)
 
 
 
 
119
  cache = getattr(self, "_merge_cache", None)
120
  if cache is not None:
121
  self._merge_cache = (cache[0], cache[1].to(device))
@@ -179,6 +188,80 @@ class SpikeWhaleMoE:
179
  "with_latent": wl, "without_latent": nl}
180
  return # one bridge is enough for the panel demo
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  def key_recall_available(self):
183
  return self.bridge_asker is not None and self.trained_link is not None
184
 
@@ -210,6 +293,36 @@ class SpikeWhaleMoE:
210
  examples.append((k, out, ok))
211
  return {"acc": correct / max(1, n), "examples": examples}
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  def consult_available(self):
214
  return self.trained_link is not None
215
 
@@ -338,6 +451,30 @@ class SpikeWhaleMoE:
338
  break
339
  return expert, tok.decode(ids[0, start:].tolist())
340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  @torch.no_grad()
342
  def run(self, query: str, max_new: int = 160, temperature: float = 0.8):
343
  """Full pass: route -> fuse latents -> generate from the winner."""
 
69
  self.trained_link = None # the TRAINED bridge (train_link.py output), for the consult demo
70
  self.bridge_asker = None # the FULL fine-tuned asker, for reproducible key-recall
71
  self.link_meta = None
72
+ self.qa_link = None # the question->answer bridge (train_qa_link.py output)
73
+ self.qa_asker = None
74
+ self.qa_meta = None
75
+ self._qa_mtime = None
76
  self.reload()
77
 
78
  def reload(self):
 
103
  self.steps[slot] = step
104
  self._mtime[slot] = mt
105
  self._load_links()
106
+ self._load_qa()
107
  return list(self.models)
108
 
109
  def available(self):
 
121
  self.trained_link = self.trained_link.to(device)
122
  if self.bridge_asker is not None:
123
  self.bridge_asker = self.bridge_asker.to(device)
124
+ if self.qa_link is not None:
125
+ self.qa_link = self.qa_link.to(device)
126
+ if self.qa_asker is not None:
127
+ self.qa_asker = self.qa_asker.to(device)
128
  cache = getattr(self, "_merge_cache", None)
129
  if cache is not None:
130
  self._merge_cache = (cache[0], cache[1].to(device))
 
188
  "with_latent": wl, "without_latent": nl}
189
  return # one bridge is enough for the panel demo
190
 
191
+ def _load_qa(self):
192
+ """Load the question->answer bridge (train_qa_link.py output): a NEW RecursiveLink
193
+ + a fully fine-tuned asker that answer arithmetic shown only to the consultant.
194
+ mtime-cached (the file is ~190MB) and hot-reloaded as training improves it."""
195
+ a_dom, c_dom = LINKS[0]
196
+ path = CKPT_ROOT / "links" / f"qa__{a_dom}__from__{c_dom}.safetensors"
197
+ if not path.exists():
198
+ self.qa_link = self.qa_asker = self.qa_meta = None
199
+ self._qa_mtime = None
200
+ return
201
+ mt = os.path.getmtime(path)
202
+ if self._qa_mtime == mt and self.qa_link is not None:
203
+ return
204
+ a, c = DOMAIN2SLOT[a_dom], DOMAIN2SLOT[c_dom]
205
+ if a not in self.models or c not in self.models:
206
+ return
207
+ from safetensors.torch import load_file
208
+ from safetensors import safe_open
209
+ t = load_file(str(path), device=self.device)
210
+ t = {k: (v.float() if v.is_floating_point() else v) for k, v in t.items()}
211
+ with safe_open(str(path), framework="pt") as f:
212
+ md = f.metadata() or {}
213
+ link = RecursiveLink(d_latent=D_LATENT).to(self.device).eval()
214
+ link.load_state_dict({k[5:]: v for k, v in t.items() if k.startswith("link.")})
215
+ ask = SpikeWhaleLM(specialist_config(a_dom)).to(self.device).eval()
216
+ ask.load_state_dict({k[6:]: v for k, v in t.items() if k.startswith("asker.")})
217
+ self.qa_link, self.qa_asker = link, ask
218
+ self.qa_meta = {"asker": a, "consultant": c,
219
+ "ans_len": int(md.get("ans_len", 3)), "prompt": md.get("prompt", "ANS> "),
220
+ "holdout_exact": float(md.get("holdout_exact", "nan")),
221
+ "step": int(md.get("step", 0))}
222
+ self._qa_mtime = mt
223
+
224
+ def qa_available(self):
225
+ return self.qa_link is not None and self.qa_asker is not None
226
+
227
+ def qa_info(self):
228
+ return dict(self.qa_meta) if self.qa_meta else None
229
+
230
+ @torch.no_grad()
231
+ def ask_math(self, a: int, op: str, b: int, ablate: bool = False):
232
+ """Language answers an arithmetic question SHOWN ONLY to Math: the frozen
233
+ consultant encodes the question, the QA RecursiveLink carries it across, and the
234
+ QA asker decodes the answer digits autoregressively from the latent alone (its
235
+ own input is just the 'ANS> ' prompt -- the question never reaches it as text).
236
+ ablate=True zeros the latent: the asker then has no question at all."""
237
+ if not self.qa_available():
238
+ return {"error": "qa bridge not trained yet"}
239
+ meta = self.qa_meta
240
+ a, b = int(a), int(b)
241
+ if op not in ("+", "-", "*"):
242
+ return {"error": "op must be one of + - *"}
243
+ truth = {"+": a + b, "-": a - b, "*": a * b}[op]
244
+ if not (0 <= truth < 10 ** meta["ans_len"]):
245
+ return {"error": "answer out of the trained range"}
246
+ q = f"{a} {op} {b} ="
247
+ c_ids = torch.tensor([self.toks[meta["consultant"]].encode(q, add_special_tokens=False)],
248
+ device=self.device)
249
+ latent = self.models[meta["consultant"]](input_ids=c_ids).latent
250
+ inj = torch.zeros_like(self.qa_link(latent)) if ablate else self.qa_link(latent)
251
+ a_tok = self.toks[meta["asker"]]
252
+ ids = torch.tensor([a_tok.encode(meta["prompt"], add_special_tokens=False)],
253
+ device=self.device)
254
+ plen = ids.shape[1]
255
+ for _ in range(meta["ans_len"]):
256
+ logits = self.qa_asker(input_ids=ids, inject_latent=inj).logits[:, -1, :]
257
+ ids = torch.cat([ids, logits.argmax(-1, keepdim=True)], dim=1)
258
+ digits = a_tok.decode(ids[0, plen:].tolist())
259
+ want = f"{truth:0{meta['ans_len']}d}"
260
+ return {"question": q, "digits": digits, "answer": digits.lstrip("0") or "0",
261
+ "truth": truth, "want": want,
262
+ "ok": [i < len(digits) and digits[i] == ch for i, ch in enumerate(want)],
263
+ "exact": digits == want}
264
+
265
  def key_recall_available(self):
266
  return self.bridge_asker is not None and self.trained_link is not None
267
 
 
293
  examples.append((k, out, ok))
294
  return {"acc": correct / max(1, n), "examples": examples}
295
 
296
+ @torch.no_grad()
297
+ def relay_secret(self, secret: str, ablate: bool = False):
298
+ """Interactive bridge demo: a USER-CHOSEN key is shown only to the consultant;
299
+ the asker reads it back from the latent alone (same mechanism as key_recall, but
300
+ the human picks the secret). Returns {secret, recovered, ok:[per-char bool],
301
+ aligned} -- aligned=False means the tokenizer fused some characters into
302
+ multi-char tokens the bridge never saw in training, so expect degradation."""
303
+ if not self.key_recall_available():
304
+ return {"error": "bridge unavailable"}
305
+ s = "".join(ch for ch in (secret or "") if ch in KEY_CHARS)
306
+ key_len = self.link_meta.get("key_len", 6)
307
+ if len(s) != key_len:
308
+ return {"error": f"need exactly {key_len} characters (letters and digits only)"}
309
+ a, c = self.link_meta["asker"], self.link_meta["consultant"]
310
+ a_tok, c_tok = self.toks[a], self.toks[c]
311
+ prompt = self.link_meta.get("prompt", "KEY> ")
312
+ plen = len(a_tok.encode(prompt, add_special_tokens=False))
313
+ c_ids = torch.tensor([c_tok.encode(s, add_special_tokens=False)], device=self.device)
314
+ a_ids = torch.tensor([a_tok.encode(prompt, add_special_tokens=False)
315
+ + a_tok.encode(s, add_special_tokens=False)], device=self.device)
316
+ aligned = c_ids.shape[1] == key_len and a_ids.shape[1] == plen + key_len
317
+ latent = self.models[c](input_ids=c_ids).latent
318
+ inj = torch.zeros_like(self.trained_link(latent)) if ablate else self.trained_link(latent)
319
+ logits = self.bridge_asker(input_ids=a_ids, inject_latent=inj).logits
320
+ pred = logits[:, plen - 1:plen - 1 + key_len, :].argmax(-1)[0]
321
+ out = a_tok.decode(pred.tolist())[:len(s)]
322
+ return {"secret": s, "recovered": out,
323
+ "ok": [i < len(out) and out[i] == ch for i, ch in enumerate(s)],
324
+ "aligned": aligned}
325
+
326
  def consult_available(self):
327
  return self.trained_link is not None
328
 
 
451
  break
452
  return expert, tok.decode(ids[0, start:].tolist())
453
 
454
+ @torch.no_grad()
455
+ def generate_stream(self, query: str, expert: str | None = None, max_new: int = 160,
456
+ temperature: float = 0.8, top_k: int = 40, chunk: int = 4):
457
+ """Like generate(), but yields (expert, text_so_far) as tokens arrive, so the UI
458
+ can show generation live instead of freezing until the whole thing is done."""
459
+ if expert is None:
460
+ expert, _, _ = self.route(query)
461
+ m, tok = self.models[expert], self.toks[expert]
462
+ ids = self._ids(expert, query)
463
+ start = ids.shape[1]
464
+ ctx_max = int(getattr(m.config, "max_position_embeddings", 2048))
465
+ for i in range(max_new):
466
+ logits = m(input_ids=ids[:, -ctx_max:]).logits[:, -1, :] / max(1e-5, temperature)
467
+ if top_k:
468
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
469
+ logits[logits < v[:, [-1]]] = -float("inf")
470
+ nxt = torch.multinomial(F.softmax(logits, dim=-1), 1)
471
+ ids = torch.cat([ids, nxt], dim=1)
472
+ done = tok.eos_token_id is not None and int(nxt.item()) == tok.eos_token_id
473
+ if done or i % chunk == chunk - 1 or i == max_new - 1:
474
+ yield expert, tok.decode(ids[0, start:].tolist())
475
+ if done:
476
+ break
477
+
478
  @torch.no_grad()
479
  def run(self, query: str, max_new: int = 160, temperature: float = 0.8):
480
  """Full pass: route -> fuse latents -> generate from the winner."""
agents/modmind/train_qa_link.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ train_qa_link.py -- train a question->answer latent bridge (the upgrade the panel's
3
+ "Tell Math a secret" demo points at).
4
+
5
+ Task: an arithmetic question ("23 + 54 =") is shown ONLY to the frozen Math/reasoning
6
+ specialist, which encodes it to its 256-d output latent. A NEW RecursiveLink + a
7
+ fine-tuned Language asker must emit the ANSWER ("077", zero-padded digits) reading
8
+ nothing but that latent: asker input is just "ANS> " + answer digits (teacher-forced
9
+ in training; decoded autoregressively at eval). 8% of (a, op, b) problems are HELD OUT
10
+ of training, so eval accuracy on them is generalization, not memorization. Ablating
11
+ the latent removes the question entirely -> accuracy collapses to the digit prior.
12
+
13
+ Saves links/qa__language__from__reasoning.safetensors in the same key style as the
14
+ key-recall bridge (link./ali./asker. + metadata) for moe_gradio.py to load.
15
+
16
+ Run: python agents/modmind/train_qa_link.py [--steps 4000] [--device cuda]
17
+ """
18
+ from __future__ import annotations
19
+
20
+ import argparse
21
+ import hashlib
22
+ import json
23
+ import os
24
+ import random
25
+ import sys
26
+ import time
27
+
28
+ import torch
29
+ import torch.nn.functional as F
30
+
31
+ HERE = os.path.dirname(os.path.abspath(__file__))
32
+ sys.path.insert(0, HERE)
33
+
34
+ from model import RecursiveLink, SpikeWhaleLM # noqa: E402
35
+ from specialist_presets import specialist_config # noqa: E402
36
+ from spike_tokenizer import SpikeTokenizer # noqa: E402
37
+
38
+ ASKER, CONSULTANT = "language", "reasoning"
39
+ D_LATENT = 256
40
+ PROMPT = "ANS> "
41
+ ANS_LEN = 3 # answers zero-padded to 3 digits ("077")
42
+ HOLDOUT_PCT = 8 # % of problems held out of training entirely
43
+ OUT = os.path.join(HERE, "links", f"qa__{ASKER}__from__{CONSULTANT}.safetensors")
44
+
45
+
46
+ # ---- the problem space --------------------------------------------------------
47
+ def all_problems():
48
+ """Every (a, op, b) the bridge is trained/evaluated on. Answers are 0..198."""
49
+ probs = []
50
+ for a in range(10, 100):
51
+ for b in range(10, 100):
52
+ probs.append((a, "+", b))
53
+ if a >= b:
54
+ probs.append((a, "-", b))
55
+ for a in range(2, 13):
56
+ for b in range(2, 13):
57
+ probs.append((a, "*", b))
58
+ return probs
59
+
60
+
61
+ def answer(a, op, b):
62
+ return {"+": a + b, "-": a - b, "*": a * b}[op]
63
+
64
+
65
+ def is_holdout(a, op, b, pct):
66
+ if pct <= 0:
67
+ return False
68
+ h = hashlib.md5(f"{a}{op}{b}".encode()).digest()[0]
69
+ return h % 100 < pct
70
+
71
+
72
+ def render(a, op, b):
73
+ return f"{a} {op} {b} ="
74
+
75
+
76
+ # ---- model loading (same pattern as moe_gradio.py) ------------------------------
77
+ def load_specialist(domain, device):
78
+ from safetensors.torch import load_file
79
+ ck = os.path.join(HERE, domain, "checkpoints", "model.safetensors")
80
+ cfg = specialist_config(domain)
81
+ m = SpikeWhaleLM(cfg).to(device)
82
+ sd = load_file(ck, device=device)
83
+ sd = {k: (v.float() if v.is_floating_point() else v) for k, v in sd.items()}
84
+ m.load_state_dict(sd)
85
+ tok = SpikeTokenizer(vocab_file=os.path.join(HERE, domain, "tokenizer.json"))
86
+ return m, tok
87
+
88
+
89
+ # ---- training -------------------------------------------------------------------
90
+ def main():
91
+ ap = argparse.ArgumentParser()
92
+ ap.add_argument("--steps", type=int, default=4000)
93
+ ap.add_argument("--batch", type=int, default=128)
94
+ ap.add_argument("--link-lr", type=float, default=1e-3)
95
+ ap.add_argument("--asker-lr", type=float, default=1e-4)
96
+ ap.add_argument("--asker-wd", type=float, default=0.0)
97
+ ap.add_argument("--holdout", type=int, default=0,
98
+ help="%% of problems held out of training (0 = train on ALL, the lookup-table demo)")
99
+ ap.add_argument("--eval-every", type=int, default=200)
100
+ ap.add_argument("--eval-n", type=int, default=256)
101
+ ap.add_argument("--eval-chunk", type=int, default=64) # keep eval VRAM peaks small
102
+ ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
103
+ ap.add_argument("--seed", type=int, default=0)
104
+ ap.add_argument("--fresh", action="store_true", help="ignore last.pt and start over")
105
+ args = ap.parse_args()
106
+ dev = args.device
107
+ random.seed(args.seed); torch.manual_seed(args.seed)
108
+
109
+ print(f"[qa-link] device={dev}", flush=True)
110
+ consultant, c_tok = load_specialist(CONSULTANT, dev)
111
+ asker, a_tok = load_specialist(ASKER, dev)
112
+ consultant.eval()
113
+ for p in consultant.parameters():
114
+ p.requires_grad_(False)
115
+
116
+ # answer digits must be single tokens for the asker (position-aligned readout)
117
+ digit_ids = []
118
+ for d in "0123456789":
119
+ ids = a_tok.encode(d, add_special_tokens=False)
120
+ assert len(ids) == 1, f"digit {d!r} is not a single token: {ids}"
121
+ digit_ids.append(ids[0])
122
+ prompt_ids = a_tok.encode(PROMPT, add_special_tokens=False)
123
+ plen = len(prompt_ids)
124
+ print(f"[qa-link] prompt {PROMPT!r} = {plen} tokens; digits map to single tokens", flush=True)
125
+
126
+ link = RecursiveLink(d_latent=D_LATENT).to(dev)
127
+ opt = torch.optim.AdamW([
128
+ {"params": list(link.parameters()), "lr": args.link_lr, "weight_decay": 0.0},
129
+ {"params": list(asker.parameters()), "lr": args.asker_lr, "weight_decay": args.asker_wd},
130
+ ])
131
+
132
+ probs = all_problems()
133
+ train_pool = [p for p in probs if not is_holdout(*p, args.holdout)]
134
+ eval_pool = [p for p in probs if is_holdout(*p, args.holdout)]
135
+ memorize = args.holdout <= 0
136
+ if memorize:
137
+ eval_pool = train_pool # no holdout: "accuracy" = coverage of the whole table
138
+ print(f"[qa-link] MEMORIZE mode: training on ALL {len(train_pool)} problems (no holdout)", flush=True)
139
+ else:
140
+ print(f"[qa-link] {len(train_pool)} train problems, {len(eval_pool)} held out", flush=True)
141
+ label = "accuracy" if memorize else "held-out exact"
142
+
143
+ @torch.no_grad()
144
+ def encode_questions(batch):
145
+ """Frozen consultant -> latents. Bucketed by token length (latent is a
146
+ mean-pool over positions, so padding would corrupt it)."""
147
+ idss = [c_tok.encode(render(*p), add_special_tokens=False) for p in batch]
148
+ lat = torch.zeros(len(batch), D_LATENT, device=dev)
149
+ by_len = {}
150
+ for i, ids in enumerate(idss):
151
+ by_len.setdefault(len(ids), []).append(i)
152
+ for L, idx in by_len.items():
153
+ c_ids = torch.tensor([idss[i] for i in idx], device=dev)
154
+ lat[idx] = consultant(input_ids=c_ids).latent
155
+ return lat
156
+
157
+ def ans_tokens(p):
158
+ return [digit_ids[int(ch)] for ch in f"{answer(*p):0{ANS_LEN}d}"]
159
+
160
+ @torch.no_grad()
161
+ def evaluate(pool, n, ablate=False):
162
+ """Autoregressive 3-digit decode (full-vocab argmax, no teacher forcing).
163
+ Chunked to keep VRAM peaks small."""
164
+ asker.eval()
165
+ sample = random.sample(pool, min(n, len(pool)))
166
+ hit_e = hit_d = 0
167
+ for o in range(0, len(sample), args.eval_chunk):
168
+ chunk = sample[o:o + args.eval_chunk]
169
+ lat = encode_questions(chunk)
170
+ inj = torch.zeros_like(link(lat)) if ablate else link(lat)
171
+ ids = torch.tensor([prompt_ids] * len(chunk), device=dev)
172
+ for _ in range(ANS_LEN):
173
+ logits = asker(input_ids=ids, inject_latent=inj).logits[:, -1, :]
174
+ ids = torch.cat([ids, logits.argmax(-1, keepdim=True)], dim=1)
175
+ pred = ids[:, plen:]
176
+ tgt = torch.tensor([ans_tokens(p) for p in chunk], device=dev)
177
+ hit_e += int((pred == tgt).all(dim=1).sum())
178
+ hit_d += int((pred == tgt).sum())
179
+ asker.train()
180
+ return hit_e / len(sample), hit_d / (len(sample) * ANS_LEN)
181
+
182
+ # resume from last.pt if a previous run died mid-flight
183
+ last_pt = OUT + ".last.pt"
184
+ best, start_step = -1.0, 0
185
+ if os.path.exists(last_pt) and not args.fresh:
186
+ st = torch.load(last_pt, map_location=dev, weights_only=False)
187
+ link.load_state_dict(st["link"]); asker.load_state_dict(st["asker"])
188
+ opt.load_state_dict(st["opt"]); best, start_step = st["best"], st["step"]
189
+ print(f"[qa-link] resumed from step {start_step} (best held-out {best*100:.1f}%)", flush=True)
190
+
191
+ t0 = time.time()
192
+ asker.train()
193
+ for step in range(start_step + 1, args.steps + 1):
194
+ batch = random.sample(train_pool, args.batch)
195
+ lat = encode_questions(batch)
196
+ inj = link(lat)
197
+ a_ids = torch.tensor([prompt_ids + ans_tokens(p) for p in batch], device=dev)
198
+ labels = a_ids.clone()
199
+ labels[:, :plen] = -100 # loss only on the answer digits
200
+ out = asker(input_ids=a_ids, labels=labels, inject_latent=inj)
201
+ opt.zero_grad(); out.loss.backward(); opt.step()
202
+
203
+ if step % args.eval_every == 0 or step == args.steps:
204
+ ex, pd = evaluate(eval_pool, args.eval_n)
205
+ extra = "" if memorize else f" train exact {evaluate(train_pool, args.eval_n)[0]*100:5.1f}%"
206
+ print(f"[qa-link] step {step:5d} loss {out.loss.item():.4f} "
207
+ f"{label} {ex*100:5.1f}% (digits {pd*100:5.1f}%){extra} "
208
+ f"[{time.time()-t0:.0f}s]", flush=True)
209
+ if ex > best:
210
+ best = ex
211
+ save(link, asker, ex, step, args, memorize)
212
+ print(f"[qa-link] saved -> {OUT} ({label} {ex*100:.1f}%)", flush=True)
213
+ # resume checkpoint every eval, so a crash never loses more than eval_every steps
214
+ torch.save({"link": link.state_dict(), "asker": asker.state_dict(),
215
+ "opt": opt.state_dict(), "best": best, "step": step}, last_pt + ".tmp")
216
+ os.replace(last_pt + ".tmp", last_pt)
217
+
218
+ # final ablation numbers from the BEST saved bridge are written at save();
219
+ # report the last-step ablation here for the log.
220
+ ex_a, pd_a = evaluate(eval_pool, args.eval_n, ablate=True)
221
+ print(f"[qa-link] ablated (latent cut): exact {ex_a*100:.1f}% / digits {pd_a*100:.1f}%", flush=True)
222
+ print(f"[qa-link] done. best {label} {best*100:.1f}%", flush=True)
223
+
224
+
225
+ def save(link, asker, acc, step, args, memorize):
226
+ from safetensors.torch import save_file
227
+ os.makedirs(os.path.dirname(OUT), exist_ok=True)
228
+ t = {}
229
+ for k, v in link.state_dict().items():
230
+ t["link." + k] = v.detach().to("cpu", torch.float16).contiguous()
231
+ for k, v in asker.model.latent_inject.state_dict().items():
232
+ t["ali." + k] = v.detach().to("cpu", torch.float16).contiguous()
233
+ for k, v in asker.state_dict().items():
234
+ t["asker." + k] = (v.detach().to("cpu", torch.float16).contiguous()
235
+ if v.is_floating_point() else v.detach().cpu().contiguous())
236
+ tmp = OUT + ".tmp"
237
+ save_file(t, tmp, metadata={
238
+ "kind": "qa", "ans_len": str(ANS_LEN), "prompt": PROMPT,
239
+ "asker": ASKER, "consultant": CONSULTANT,
240
+ "mode": "memorize" if memorize else "generalize",
241
+ "holdout_pct": str(args.holdout), "step": str(step),
242
+ # accuracy over the whole table (memorize) or held-out set (generalize)
243
+ "holdout_exact": f"{acc:.4f}",
244
+ "ops": json.dumps(["+", "-", "*"]),
245
+ })
246
+ os.replace(tmp, OUT) # atomic: the panel hot-reloads this file while we train
247
+
248
+
249
+ if __name__ == "__main__":
250
+ main()