Buckets:

Mercity/FluxDistill / scripts /make_report_assets.py
Pranav2748's picture
download
raw
6.29 kB
"""Generate report graphs (from training logs) + montages into outputs/report_assets/."""
import os, re
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
# Every asset produced below is written under OUT, which is created up front.
OUT = "outputs/report_assets"
os.makedirs(OUT, exist_ok=True)
# Step numbers paired positionally with the eval losses parsed from each log
# (0, 50, ..., 300). NOTE(review): assumes evals are logged every 50 steps.
STEPS = list(range(0, 301, 50))
# ``eval_vel_loss=<number>``. The number may be signed, use scientific notation
# (``2.5e-01``) or be nan/inf. The ``\b`` rejects keys such as
# ``train_eval_vel_loss=``. A trailing sentence period is not captured, so it
# cannot reach float().
_EVAL_LOSS_RE = re.compile(
    r"\beval_vel_loss=([-+]?(?:(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|(?:nan|inf)\b))"
)


def eval_curve(path):
    """Return every ``eval_vel_loss=<value>`` found in the log at ``path``.

    Only the first match on each line is used, and values are returned in
    file order. A missing log yields ``[]``, so callers can skip runs that
    have not been trained yet.
    """
    if not os.path.exists(path):
        return []
    # ``with`` closes the handle (the original leaked it). An explicit UTF-8
    # decode with replacement keeps stray non-ASCII log bytes from crashing
    # the parse on platforms whose default encoding differs.
    with open(path, encoding="utf-8", errors="replace") as fh:
        matches = (_EVAL_LOSS_RE.search(line) for line in fh)
        return [float(m.group(1)) for m in matches if m]
# ``step N/<total> loss <number>``, as printed by 07_train.py. The loss may use
# scientific notation or be nan/inf, which is exactly what a diverging run emits.
_TRAIN_LOSS_RE = re.compile(
    r"step\s+(\d+)/\d+\s+loss\s+([-+]?(?:(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|(?:nan|inf)\b))"
)


def train_curve(path):  # for the diverged run (07_train.py format: "step N/300 loss X")
    """Return ``(step, loss)`` pairs parsed from a 07_train.py log, in file order.

    A missing log yields ``[]``, matching ``eval_curve``. The caller only plots
    when the result is non-empty.
    """
    if not os.path.exists(path):
        return []
    # ``with`` closes the handle. An explicit UTF-8 decode with replacement
    # tolerates stray bytes.
    with open(path, encoding="utf-8", errors="replace") as fh:
        matches = (_TRAIN_LOSS_RE.search(line) for line in fh)
        return [(int(m.group(1)), float(m.group(2))) for m in matches if m]
# Runs shown on the recovery-curve graph: (legend label, training log path, line colour).
RUNS = [
    ("per-token drop-6", "outputs/train_v3_adamw/train.log", "#888888"),
    ("linattn drop-6 (simple)", "outputs/train_student_linattn_adamw_constant/train.log", "#1f77b4"),
    ("linattn drop-6 (RoPE+conv+ws)","outputs/train_student_linattn2_adamw_cosine/train.log", "#2ca02c"),
    ("linattn drop-8 (+focused+FFN)","outputs/train_student_linattn3_adamw_cosine/train.log", "#ff7f0e"),
    ("linattn drop-10 (mixed FFN)", "outputs/train_student_linattn4_adamw_cosine/train.log", "#d62728"),
]
# --- Graph 1: recovery curves (eval loss vs step) ---
plt.figure(figsize=(9, 5.5))
for name, path, c in RUNS:
    y = eval_curve(path)
    if not y:
        continue  # log missing or no eval lines yet: leave this run off the plot
    # The i-th eval value is plotted at STEPS[i]. Truncate both lists to the
    # shorter length: a log with more eval lines than STEPS used to make
    # plt.plot raise on mismatched x/y lengths.
    n_pts = min(len(STEPS), len(y))
    x, y = STEPS[:n_pts], y[:n_pts]
    # The legend reports the best (lowest) plotted eval loss of each run.
    plt.plot(x, y, marker="o", color=c, label=f"{name} (→{min(y):.3f})")
plt.xlabel("training step")
plt.ylabel("held-out velocity-matching loss")
plt.title("Recovery curves — surrogate-only training (lower = closer to teacher)")
plt.grid(alpha=0.3)
plt.legend(fontsize=8)
plt.tight_layout()
plt.savefig(f"{OUT}/recovery_curves.png", dpi=130)
plt.close()
# --- Graph 2: Muon vs AdamW A/B ---
# Same recipe, two optimizers: (legend label, training log path, line colour).
plt.figure(figsize=(8, 5))
for name, path, c in [("AdamW @1e-4", "outputs/train_v3_adamw/train.log", "#1f77b4"),
                      ("Muon @2e-3", "outputs/train_v3_muon/train.log", "#d62728")]:
    y = eval_curve(path)
    if not y:
        # Missing or empty log: skip the run. Previously min([]) below raised
        # ValueError and aborted every later graph and the montage.
        continue
    # Pair evals with STEPS positionally, truncating to the shorter list so a
    # log with extra eval lines cannot cause a length mismatch in plt.plot.
    n_pts = min(len(STEPS), len(y))
    x, y = STEPS[:n_pts], y[:n_pts]
    plt.plot(x, y, marker="o", color=c, label=f"{name} (→{min(y):.3f})")
plt.xlabel("training step")
plt.ylabel("eval velocity-loss")
plt.title("Optimizer A/B (surrogate-only, identical recipe): a statistical tie")
plt.grid(alpha=0.3)
plt.legend()
plt.tight_layout()
plt.savefig(f"{OUT}/muon_vs_adamw.png", dpi=130)
plt.close()
# --- Graph 3: the divergence (training all weights, Muon 0.02) ---
def _plot_divergence(curve):
    """Plot ``(step, loss)`` pairs and save them as divergence.png under OUT.

    Does nothing when ``curve`` is empty. A dashed grey horizontal line marks
    the loss at the first logged step, as a reference for the blow-up.
    """
    if not curve:
        return
    steps, losses = zip(*curve)
    plt.figure(figsize=(8, 5))
    plt.plot(steps, losses, marker="o", color="#d62728")
    plt.axhline(losses[0], ls="--", color="gray", alpha=0.6, label="start")
    plt.xlabel("training step")
    plt.ylabel("training loss")
    plt.title("Failure mode: training ALL weights (Muon lr=0.02) → divergence to noise")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{OUT}/divergence.png", dpi=130)
    plt.close()


tc = train_curve("outputs/train_v2/train.log")
_plot_divergence(tc)
# --- Graph 4: the speed↔quality frontier ---
# Each point: (label, wall-clock speedup, eval loss, params in billions, colour,
# status tag). The status tag is carried as metadata only and is never drawn.
pts = [
    ("per-token drop6", 1.19, 0.308, 3.158, "#888888", "ok"),
    ("linattn drop6 simple", 1.15, 0.253, 3.177, "#1f77b4", "ok"),
    ("linattn drop6 upgraded", 1.15, 0.231, 3.177, "#2ca02c", "best quality"),
    ("linattn drop8 +FFN", 1.20, 0.269, 2.995, "#ff7f0e", "best colors"),
    ("linattn drop10 mixed", 1.26, 0.322, 2.737, "#d62728", "aggressive"),
]
fig = plt.figure(figsize=(9, 5.5))
ax = fig.gca()
for label, speedup, loss, params_b, color, _status in pts:
    # Marker area is (4.0 - params) * 700, so smaller students draw bigger bubbles.
    bubble_area = (4.0 - params_b) * 700
    ax.scatter(speedup, loss, s=bubble_area, color=color, alpha=0.65, edgecolors="k")
    ax.annotate(f"{label}\n{params_b:.2f}B", (speedup, loss), fontsize=7.5, ha="center",
                xytext=(0, 14), textcoords="offset points")
# The collapsed v1 run is drawn separately as a black X, labelled below the marker.
collapsed_xy = (1.45, 0.90)
ax.scatter(*collapsed_xy, s=300, color="black", marker="X")
ax.annotate("v1 per-token drop12\nCOLLAPSED", collapsed_xy, fontsize=7.5, ha="center",
            xytext=(0, -22), textcoords="offset points")
ax.set_xlabel("wall-clock speedup vs teacher (512/4-step, batch-1, A100)")
ax.set_ylabel("eval velocity-loss (lower = better)")
ax.set_title("The speed ↔ quality frontier (bubble size = smaller model)")
ax.grid(alpha=0.3)
fig.tight_layout()
fig.savefig(f"{OUT}/frontier.png", dpi=130)
plt.close(fig)
# --- Graph 5: compute breakdown (measured) + warm-start quality ---
# Left panel: measured share (%) of transformer compute per block type.
compute_labels = ["20 single\nblocks", "5 double\nblocks"]
compute_pct = [78, 21]
# Right panel: (bar label, warm-start residual relative error).
ws = [("per-token", 0.90), ("linattn\n(identity init)", 1.00), ("+RoPE+conv\n+warmstart", 0.65),
      ("+FFN\n(warmstart)", 0.54)]

fig, ax = plt.subplots(1, 2, figsize=(11, 4.2))
ax_compute, ax_ws = ax

ax_compute.bar(compute_labels, compute_pct, color=["#1f77b4", "#ff7f0e"])
ax_compute.set_ylabel("% of transformer compute (measured)")
ax_compute.set_title("Where compute goes (per-block: double 1.08× single)")
for xpos, pct in enumerate(compute_pct):
    ax_compute.text(xpos, pct + 1, f"{pct}%", ha="center")

ws_labels = [label for label, _ in ws]
ws_errs = [err for _, err in ws]
ax_ws.bar(ws_labels, ws_errs, color="#2ca02c")
ax_ws.set_ylabel("warm-start residual rel-err (lower=better)")
ax_ws.set_title("Surrogate warm-start: how much of the block it captures")
for xpos, err in enumerate(ws_errs):
    ax_ws.text(xpos, err + 0.01, f"{err:.2f}", ha="center", fontsize=8)

fig.tight_layout()
fig.savefig(f"{OUT}/breakdown.png", dpi=130)
plt.close(fig)
# --- Montage: v1 collapse (teacher | v1 per-token drop12) ---
def montage(cols, rows, path, cell=320, pad=26):
    """Tile images into a captioned grid and save it to ``path``.

    Args:
        cols: column captions, drawn in a ``pad``-high strip below the last row.
        rows: sequence of ``(row_label, files)`` pairs. ``files[i]`` is placed
            in column ``i``, and ``row_label`` is drawn in the strip above the row.
        path: destination file for the finished grid.
        cell: every image is resized to ``cell`` x ``cell`` pixels.
        pad: height in pixels of each caption strip.

    Files that do not exist are skipped and leave a white cell, so a partially
    generated sample set still yields a montage.
    """
    row_h = cell + pad
    # Size the canvas for the widest row as well as the captions, so a row with
    # more files than captions is no longer clipped at the right edge.
    n_cols = max([len(cols)] + [len(files) for _, files in rows])
    canvas = Image.new("RGB", (n_cols * cell, len(rows) * row_h + pad), "white")
    draw = ImageDraw.Draw(canvas)
    for ri, (row_label, files) in enumerate(rows):
        top = ri * row_h
        draw.text((6, top + 6), row_label, fill="black")
        for ci, f in enumerate(files):
            if not os.path.exists(f):
                continue  # best-effort: leave the cell blank
            # The context manager releases the file handle Image.open holds.
            with Image.open(f) as im:
                canvas.paste(im.resize((cell, cell)), (ci * cell, top + pad))
    # Column captions are drawn once, after all rows, in the bottom strip. They
    # are drawn even when ``rows`` is empty.
    caption_y = len(rows) * row_h + 4
    for ci, caption in enumerate(cols):
        draw.text((ci * cell + 6, caption_y), caption, fill="black")
    canvas.save(path)
# Teacher samples next to the collapsed v1 per-token drop-12 student, one row per prompt.
collapse_cols = ["teacher 4B", "v1 per-token drop-12"]
collapse_rows = [
    ("bookshop", ["outputs/teacher_smoke/teacher_0.png", "outputs/student_smoke/student_0.png"]),
    ("lake", ["outputs/teacher_smoke/teacher_1.png", "outputs/student_smoke/student_1.png"]),
]
montage(collapse_cols, collapse_rows, f"{OUT}/v1_collapse.png")

# Final summary: where the assets went and what is now in that directory.
written = os.listdir(OUT)
print("report assets written to", OUT)
print(written)

Xet Storage Details

Size:
6.29 kB
·
Xet hash:
97330e481f77ec0d71b695ced43e91d46c3837a084bf5b70ebc45526c7675030

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.