Buckets:
| """Generate report graphs (from training logs) + montages into outputs/report_assets/.""" | |
| import os, re | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from PIL import Image, ImageDraw | |
# Every generated figure/montage is written under this directory; create it
# up front so the later savefig/save calls never hit a missing folder.
OUT = "outputs/report_assets"
os.makedirs(OUT, exist_ok=True)
# x-axis positions used for the eval-loss points: 0, 50, ..., 300.
STEPS = list(range(0, 301, 50))
# Matches "eval_vel_loss=<number>" where <number> may be a plain decimal
# (0.308, .5, 12) or use scientific notation (1.5e-04). A bare [0-9.]+ would
# silently truncate "1.5e-04" to 1.5 and choke on a trailing sentence period.
_EVAL_LOSS_RE = re.compile(r"\beval_vel_loss=(\d*\.?\d+(?:[eE][-+]?\d+)?)")


def eval_curve(path):
    """Return every ``eval_vel_loss=<float>`` value in the log at *path*, in file order.

    Args:
        path: path to a training log.

    Returns:
        List of floats, one per matching line. This is an empty list when the log
        does not exist, so callers can skip runs that have not produced output.
    """
    if not os.path.exists(path):
        return []
    # Use a context manager so the file handle is closed deterministically.
    with open(path, encoding="utf-8", errors="replace") as fh:
        matches = (_EVAL_LOSS_RE.search(line) for line in fh)
        return [float(m.group(1)) for m in matches if m]
# "step N/T loss X" lines; X may be a plain decimal or scientific notation.
# Losses after divergence are often printed as e.g. 2.5e+03, and [0-9.]+
# would misread them as 2.5.
_TRAIN_LOSS_RE = re.compile(r"step\s+(\d+)/\d+\s+loss\s+(\d*\.?\d+(?:[eE][-+]?\d+)?)")


def train_curve(path):  # for the diverged run (07_train.py format: "step N/300 loss X")
    """Parse ``(step, loss)`` pairs from a log whose lines look like ``step N/T loss X``.

    Args:
        path: path to the training log.

    Returns:
        List of ``(int step, float loss)`` tuples in file order. This is an empty list
        when the log does not exist, matching ``eval_curve`` so callers can simply
        test the result for truthiness.
    """
    if not os.path.exists(path):
        return []
    out = []
    with open(path, encoding="utf-8", errors="replace") as fh:
        for line in fh:
            m = _TRAIN_LOSS_RE.search(line)
            if m:
                out.append((int(m.group(1)), float(m.group(2))))
    return out
# Runs overlaid on the recovery-curve plot.
# Each entry is (legend label, training-log path, matplotlib line colour).
RUNS = [
    ("per-token drop-6",
     "outputs/train_v3_adamw/train.log",
     "#888888"),
    ("linattn drop-6 (simple)",
     "outputs/train_student_linattn_adamw_constant/train.log",
     "#1f77b4"),
    ("linattn drop-6 (RoPE+conv+ws)",
     "outputs/train_student_linattn2_adamw_cosine/train.log",
     "#2ca02c"),
    ("linattn drop-8 (+focused+FFN)",
     "outputs/train_student_linattn3_adamw_cosine/train.log",
     "#ff7f0e"),
    ("linattn drop-10 (mixed FFN)",
     "outputs/train_student_linattn4_adamw_cosine/train.log",
     "#d62728"),
]
# --- Graph 1: recovery curves (eval loss vs step) ---
# x positions come from STEPS. If a log holds more eval points than STEPS covers
# (e.g. a longer or resumed run), the axis is extended at the same stride.
# Otherwise plt.plot would raise on mismatched x/y lengths.
# NOTE(review): assumes evals are evenly spaced at the STEPS stride; confirm.
plt.figure(figsize=(9, 5.5))
stride = STEPS[-1] - STEPS[-2]
for name, path, c in RUNS:
    y = eval_curve(path)
    if not y:
        continue  # missing log or no eval lines yet
    extra = [STEPS[-1] + stride * k for k in range(1, len(y) - len(STEPS) + 1)]
    x = STEPS[:len(y)] + extra
    plt.plot(x, y, marker="o", color=c, label=f"{name} (→{min(y):.3f})")
plt.xlabel("training step"); plt.ylabel("held-out velocity-matching loss")
plt.title("Recovery curves — surrogate-only training (lower = closer to teacher)")
plt.grid(alpha=0.3); plt.legend(fontsize=8); plt.tight_layout()
plt.savefig(f"{OUT}/recovery_curves.png", dpi=130); plt.close()
# --- Graph 2: Muon vs AdamW A/B ---
# Each log contributes one eval-loss curve. A missing log, or one with no eval
# lines, is skipped; previously min([]) raised ValueError and aborted the script.
plt.figure(figsize=(8, 5))
stride = STEPS[-1] - STEPS[-2]
for name, path, c in [("AdamW @1e-4", "outputs/train_v3_adamw/train.log", "#1f77b4"),
                      ("Muon @2e-3", "outputs/train_v3_muon/train.log", "#d62728")]:
    y = eval_curve(path)
    if not y:
        continue
    # Extend x at the STEPS stride when there are more eval points than STEPS
    # entries, so x and y always have equal length.
    # NOTE(review): assumes evenly spaced evals; confirm.
    extra = [STEPS[-1] + stride * k for k in range(1, len(y) - len(STEPS) + 1)]
    x = STEPS[:len(y)] + extra
    plt.plot(x, y, marker="o", color=c, label=f"{name} (→{min(y):.3f})")
plt.xlabel("training step"); plt.ylabel("eval velocity-loss")
plt.title("Optimizer A/B (surrogate-only, identical recipe): a statistical tie")
plt.grid(alpha=0.3); plt.legend(); plt.tight_layout()
plt.savefig(f"{OUT}/muon_vs_adamw.png", dpi=130); plt.close()
# --- Graph 3: the divergence (training all weights, Muon 0.02) ---
tc = train_curve("outputs/train_v2/train.log")
if tc:
    # Transpose the (step, loss) pairs into two parallel sequences.
    steps, losses = zip(*tc)
    plt.figure(figsize=(8, 5))
    plt.plot(steps, losses, marker="o", color="#d62728")
    # Dashed reference at the first logged loss makes the blow-up easy to see.
    plt.axhline(losses[0], ls="--", color="gray", alpha=0.6, label="start")
    plt.xlabel("training step")
    plt.ylabel("training loss")
    plt.title("Failure mode: training ALL weights (Muon lr=0.02) → divergence to noise")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{OUT}/divergence.png", dpi=130)
    plt.close()
# --- Graph 4: the speed↔quality frontier ---
# Rows: (name, wall_speedup, eval_loss, params_B, color, status).
# `status` is descriptive only and is not drawn on the chart.
pts = [
    ("per-token drop6",        1.19, 0.308, 3.158, "#888888", "ok"),
    ("linattn drop6 simple",   1.15, 0.253, 3.177, "#1f77b4", "ok"),
    ("linattn drop6 upgraded", 1.15, 0.231, 3.177, "#2ca02c", "best quality"),
    ("linattn drop8 +FFN",     1.20, 0.269, 2.995, "#ff7f0e", "best colors"),
    ("linattn drop10 mixed",   1.26, 0.322, 2.737, "#d62728", "aggressive"),
]
plt.figure(figsize=(9, 5.5))
for label, speedup, loss, params_b, colour, _status in pts:
    # Bubble area grows as params_B drops further below 4.0.
    bubble = (4.0 - params_b) * 700
    plt.scatter(speedup, loss, s=bubble, color=colour, alpha=0.65, edgecolors="k")
    plt.annotate(f"{label}\n{params_b:.2f}B", (speedup, loss), fontsize=7.5, ha="center",
                 xytext=(0, 14), textcoords="offset points")
# The collapsed v1 run is marked separately with a black X; its caption sits below it.
collapsed_xy = (1.45, 0.90)
plt.scatter(*collapsed_xy, s=300, color="black", marker="X")
plt.annotate("v1 per-token drop12\nCOLLAPSED", collapsed_xy, fontsize=7.5, ha="center",
             xytext=(0, -22), textcoords="offset points")
plt.xlabel("wall-clock speedup vs teacher (512/4-step, batch-1, A100)")
plt.ylabel("eval velocity-loss (lower = better)")
plt.title("The speed ↔ quality frontier (bubble size = smaller model)")
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig(f"{OUT}/frontier.png", dpi=130)
plt.close()
# --- Graph 5: compute breakdown (measured) + warm-start quality ---
fig, ax = plt.subplots(1, 2, figsize=(11, 4.2))
left, right = ax

# Left panel: share of transformer compute. One list feeds both bars and labels.
shares = [78, 21]
left.bar(["20 single\nblocks", "5 double\nblocks"], shares, color=["#1f77b4", "#ff7f0e"])
left.set_ylabel("% of transformer compute (measured)")
left.set_title("Where compute goes (per-block: double 1.08× single)")
for idx, pct in enumerate(shares):
    left.text(idx, pct + 1, f"{pct}%", ha="center")

# Right panel: warm-start residual relative error per surrogate variant.
ws = [("per-token", 0.90), ("linattn\n(identity init)", 1.00), ("+RoPE+conv\n+warmstart", 0.65),
      ("+FFN\n(warmstart)", 0.54)]
variant_names = [label for label, _ in ws]
rel_errs = [err for _, err in ws]
right.bar(variant_names, rel_errs, color="#2ca02c")
right.set_ylabel("warm-start residual rel-err (lower=better)")
right.set_title("Surrogate warm-start: how much of the block it captures")
for idx, err in enumerate(rel_errs):
    right.text(idx, err + 0.01, f"{err:.2f}", ha="center", fontsize=8)

fig.tight_layout()
fig.savefig(f"{OUT}/breakdown.png", dpi=130)
plt.close(fig)
| # --- Montage: v1 collapse (teacher | v1 per-token drop12) --- | |
def montage(cols, rows, path, cell=320, pad=26):
    """Tile images into a labelled grid and save it to *path*.

    Args:
        cols: column captions drawn along the bottom edge. ``len(cols)`` sets the
            canvas width (``len(cols) * cell``).
        rows: sequence of ``(row_label, files)``. Each file is resized to a
            ``cell`` x ``cell`` square and pasted left to right. A file that does not
            exist leaves its cell white (best effort).
        path: output image path passed to ``Image.save``.
        cell: side length in pixels of each image cell.
        pad: height of the strip above each row (holds the row label) and of the
            bottom strip (holds the column captions).
    """
    canvas = Image.new("RGB", (len(cols) * cell, len(rows) * (cell + pad) + pad), "white")
    draw = ImageDraw.Draw(canvas)
    for ri, (rlab, files) in enumerate(rows):
        top = ri * (cell + pad)
        draw.text((6, top + 6), rlab, fill="black")
        for ci, f in enumerate(files):
            if not os.path.exists(f):
                continue  # leave the cell blank rather than failing the report
            # Close each source image promptly. Image.open keeps the file handle
            # open until the image is closed, which leaked one handle per tile.
            with Image.open(f) as im:
                canvas.paste(im.resize((cell, cell)), (ci * cell, top + pad))
    # Column captions sit in the bottom pad strip. Drawing them after the row loop
    # means they are drawn even when there are no rows.
    caption_y = len(rows) * (cell + pad) + 4
    for ci, cl in enumerate(cols):
        draw.text((ci * cell + 6, caption_y), cl, fill="black")
    canvas.save(path)
# Side-by-side teacher vs. v1 student samples showing the collapse.
montage(
    cols=["teacher 4B", "v1 per-token drop-12"],
    rows=[
        ("bookshop", ["outputs/teacher_smoke/teacher_0.png", "outputs/student_smoke/student_0.png"]),
        ("lake", ["outputs/teacher_smoke/teacher_1.png", "outputs/student_smoke/student_1.png"]),
    ],
    path=f"{OUT}/v1_collapse.png",
)
# Echo the output directory and its contents so the run log lists every asset.
print("report assets written to", OUT)
print(os.listdir(OUT))
Xet Storage Details
- Size: 6.29 kB
- Xet hash: 97330e481f77ec0d71b695ced43e91d46c3837a084bf5b70ebc45526c7675030
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.