Buckets:

Mercity/FluxDistill / scripts /make_quant_report_assets.py
Pranav2748's picture
download
raw
9.28 kB
"""Generate W4A8 SVDQuant report figures + montages into outputs/quant_report_assets/.
Source of truth = the REAL measured values from the 2026-06-01 full 4×3 grid (300-img calib,
data/monet_cache), hard-coded here (the /tmp build logs are ephemeral).
Run: source .venv/bin/activate && python scripts/make_quant_report_assets.py
"""
import os
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
# Every figure and montage produced by this script is written into OUT.
OUT = "outputs/quant_report_assets"
os.makedirs(OUT, exist_ok=True)

# --------------------------------------------------------------------- DATA (real, 300-calib grid)
# Variant keys of the ablation grid, in the order they are plotted.
VARIANTS = ["plain", "plain_refine", "whiten", "whiten_refine"]
# Human-readable label for each variant key (legend / montage captions).
VLABEL = {
    "plain": "plain",
    "plain_refine": "plain+refine",
    "whiten": "whiten",
    "whiten_refine": "whiten+refine",
}
# Plot colour for each variant key.
VCOL = {
    "plain": "#9e9e9e",
    "plain_refine": "#2ca02c",
    "whiten": "#1f77b4",
    "whiten_refine": "#d62728",
}
# Low-rank branch ranks covered by the grid.
RANKS = [16, 32, 64]
# Per-rank factor printed under each rank tick as "(N× smaller)".
SMALLER = {16: 3.67, 32: 3.59, 64: 3.43}
# eval velocity-loss (lower = closer to teacher), indexed as GRID[rank][variant]
GRID = {
    16: {"plain": 0.0620, "plain_refine": 0.0655, "whiten": 0.0656, "whiten_refine": 0.0556},
    32: {"plain": 0.0586, "plain_refine": 0.0574, "whiten": 0.0545, "whiten_refine": 0.0476},
    64: {"plain": 0.0487, "plain_refine": 0.0446, "whiten": 0.0588, "whiten_refine": 0.0451},
}
SURGERY_BEST = 0.231  # best block-surgery eval-loss (shelved track), cross-track reference
BEST_DIR = "outputs/abl_c300_r64_plain_refine"  # grid best (0.0446 @ 3.43x)
# One caption per comparison image cmp_0 .. cmp_7 of a run directory.
PROBES = ["storefront text", "mountain lake", "fisherman portrait", "neon street",
          "chalkboard 'FRESH COFFEE'", "breakfast flat-lay (count)", "hand / five fingers",
          "dewy spiderweb macro"]
# mechanism ablation ladder (rank-0 baselines + smoothed SVD plain cells), eval-loss
# Each rung is (x-tick label, eval-loss, bar colour). The two rank-0 baselines are not
# part of GRID, so their losses are written out here; the SVD rungs are read from GRID
# so the ladder cannot drift away from the grid when the numbers are re-measured.
LADDER = [
    ("RTN W4A8\n(no smooth,\nno SVD)", 0.0573, "#d62728"),
    ("SmoothQuant\n(rank-0,\nα=0.5)", 0.0729, "#9e9e9e"),
    ("+SVD r16\nplain (α=0.5)", GRID[16]["plain"], "#7fb8e6"),
    ("+SVD r32\nplain", GRID[32]["plain"], "#4a98d4"),
    ("+SVD r64\nplain", GRID[64]["plain"], "#1f77b4"),
    ("r64 plain\n+refine\n(grid best)", GRID[64]["plain_refine"], "#2ca02c"),
]
# --------------------------------------------------------- Graph 0: mechanism ladder
# Bar chart of the LADDER rungs (label, eval-loss, colour); every bar is annotated with
# its loss and a dashed line marks the level of the first (RTN) rung.
ladder_names = [name for name, *_ in LADDER]
ladder_losses = [loss for _, loss, _ in LADDER]
ladder_colors = [color for *_, color in LADDER]
# Read the RTN loss from the ladder itself so the reference line always sits on top of
# the RTN bar instead of repeating the number by hand.
rtn_loss = ladder_losses[0]

plt.figure(figsize=(10, 5.5))
xs = range(len(LADDER))
plt.bar(xs, ladder_losses, color=ladder_colors, edgecolor="k", linewidth=0.6, width=0.62)
for i, loss in enumerate(ladder_losses):
    # value label just above the bar top
    plt.text(i, loss + 0.0008, f"{loss:.4f}", ha="center", fontsize=9.5, fontweight="bold")
plt.xticks(list(xs), ladder_names, fontsize=8.5)
plt.ylabel("held-out velocity-matching loss (lower = closer to teacher)")
plt.ylim(0, 0.083)
plt.title("Mechanism ablation — what each piece buys (300-img calib, W4A8)\n"
          "⚠️ SmoothQuant (α=0.5) HURTS: RTN floor (no smooth) beats it AND the smoothed r16/r32 SVD cells")
plt.grid(axis="y", alpha=0.3)
plt.axhline(rtn_loss, color="#d62728", ls="--", lw=1, alpha=0.6)
plt.tight_layout()
plt.savefig(f"{OUT}/mechanism_ladder.png", dpi=140)
plt.close()
# --------------------------------------------------------- Graph 1: the 4×3 grid (headline)
# Grouped bar chart: one group per rank, one bar per variant inside each group, with the
# loss value written vertically above every bar.
fig, ax = plt.subplots(figsize=(10, 5.5))
rank_slots = list(range(len(RANKS)))
n_variants = len(VARIANTS)
group_width = 0.8
bar_width = group_width / n_variants
for idx, variant in enumerate(VARIANTS):
    # offsets are symmetric about 0, which centres each group on its rank slot
    shift = (idx - (n_variants - 1) / 2) * bar_width
    centers = [slot + shift for slot in rank_slots]
    losses = [GRID[rank][variant] for rank in RANKS]
    ax.bar(
        centers,
        losses,
        width=bar_width * 0.95,
        color=VCOL[variant],
        edgecolor="k",
        linewidth=0.5,
        label=VLABEL[variant],
    )
    for cx, loss in zip(centers, losses):
        ax.text(cx, loss + 0.0007, f"{loss:.4f}", ha="center", va="bottom",
                fontsize=7.5, rotation=90)
ax.set_xticks(rank_slots)
ax.set_xticklabels([f"rank {rank}\n({SMALLER[rank]}× smaller)" for rank in RANKS])
ax.set_ylabel("held-out velocity-matching loss (lower = closer to teacher)")
ax.set_ylim(0, 0.075)
ax.set_title("W4A8 SVDQuant — full method×rank grid (300-img calib)\n"
             "refine = reliable; whitening alone is non-monotonic; r64 plain+refine = best (0.0446)")
ax.legend(ncol=4, fontsize=9, loc="upper center")
ax.grid(axis="y", alpha=0.3)
plt.tight_layout()
plt.savefig(f"{OUT}/grid_4x3.png", dpi=140)
plt.close()
# ------------------------------------------------- Graph 2: refine & whiten deltas vs plain
# Line chart of each non-plain variant's eval-loss relative to "plain" at the same rank,
# expressed as a percentage; the dashed zero line is the plain baseline.
fig, ax = plt.subplots(figsize=(9, 5))
upgrade_styles = [
    ("plain_refine", "#2ca02c", "o"),
    ("whiten", "#1f77b4", "s"),
    ("whiten_refine", "#d62728", "^"),
]
for variant, color, marker in upgrade_styles:
    pct_change = []
    for rank in RANKS:
        baseline = GRID[rank]["plain"]
        pct_change.append(100 * (GRID[rank][variant] - baseline) / baseline)
    ax.plot(RANKS, pct_change, marker + "-", color=color, lw=2, ms=9, label=VLABEL[variant])
    for rank, delta in zip(RANKS, pct_change):
        # signed, whole-percent label a few points above the marker
        ax.annotate(f"{delta:+.0f}%", (rank, delta), textcoords="offset points",
                    xytext=(0, 8), ha="center", fontsize=8, color=color)
ax.axhline(0, color="#9e9e9e", lw=1.5, ls="--", label="plain (baseline)")
ax.set_xticks(RANKS)
ax.set_xlabel("low-rank branch rank")
ax.set_ylabel("eval-loss change vs plain SVD (%) — negative = better")
ax.set_title("Each upgrade vs plain, by rank: whitening alone is unreliable\n"
             "(hurts r16 & r64, helps r32); refine reliable except r16; combo strongest at low rank")
ax.grid(alpha=0.3)
ax.legend(fontsize=9)
plt.tight_layout()
plt.savefig(f"{OUT}/deltas_vs_plain.png", dpi=140)
plt.close()
# ------------------------------------------------- Graph 3: cross-track context (why we pivoted)
# Bar chart putting three block-surgery results next to three W4A8 grid cells on the
# same loss axis.
plt.figure(figsize=(9, 5))
labels = ["surgery\ndrop-10", "surgery\ndrop-8", "surgery\ndrop-6\n(best)",
          "W4A8 r16\nwhiten+ref", "W4A8 r32\nwhiten+ref", "W4A8 r64\nplain+ref\n(best)"]
# The "(best)" surgery bar uses the shared SURGERY_BEST constant rather than a second
# hand-typed copy of the number; the W4A8 bars are read straight from GRID.
vals = [0.322, 0.269, SURGERY_BEST, GRID[16]["whiten_refine"], GRID[32]["whiten_refine"],
        GRID[64]["plain_refine"]]
cols = ["#bbbbbb", "#bbbbbb", "#999999", "#1f77b4", "#2ca02c", "#d62728"]
plt.bar(range(len(vals)), vals, color=cols, edgecolor="k", linewidth=0.6)
for i, v in enumerate(vals):
    # value label just above the bar top
    plt.text(i, v + 0.006, f"{v:.3f}", ha="center", fontsize=9, fontweight="bold")
plt.xticks(range(len(vals)), labels, fontsize=8.5)
plt.ylabel("held-out velocity-matching loss (same metric, lower=better)")
plt.title("Why we pivoted: W4A8 SVDQuant lands ~4–5× closer to the teacher than the\n"
          "entire (shelved) block-surgery frontier — at 3.4–3.7× weight compression")
# faint green band over the 0–0.06 loss range
plt.axhspan(0, 0.06, color="#2ca02c", alpha=0.06)
plt.grid(axis="y", alpha=0.3)
plt.tight_layout()
plt.savefig(f"{OUT}/cross_track.png", dpi=140)
plt.close()
# ------------------------------------------------- Graph 4: where the bits / compute go (rank-indep)
# Two panels for one 3072×3072 layer: (left) storage in MB of each component at rank 32
# against the bf16 size of the layer, (right) low-rank branch FLOPs as a share of the
# main GEMM for several ranks.
fig, ax = plt.subplots(1, 2, figsize=(11, 4.2))
d_in = d_out = 3072
bf16_mb = d_in * d_out * 2 / 1e6  # 2 bytes per weight
comp = [("4-bit residual", d_in * d_out * 4 / 8 / 1e6),         # 4 bits per weight
        ("group scales", d_out * (d_in // 64) * 2 / 1e6),      # one 2-byte scale per block of 64
        ("low-rank (bf16,r32)", (32 * d_in + d_out * 32) * 2 / 1e6),  # both rank-32 factors
        ("smooth", d_in * 2 / 1e6)]                             # 2 bytes per channel
ax[0].bar([a for a, _ in comp], [b for _, b in comp],
          color=["#1f77b4", "#9cf", "#2ca02c", "#ccc"], edgecolor="k", linewidth=0.5)
ax[0].axhline(bf16_mb, ls="--", color="k", alpha=0.5)
ax[0].text(2.4, bf16_mb + 0.4, f"bf16 baseline {bf16_mb:.1f} MB",
           fontsize=8, ha="center")
ax[0].set_ylabel("MB (one 3072×3072 layer)")
ax[0].set_title("Where the bytes go (rank-32): 4-bit residual dominates")
for i, (_, b) in enumerate(comp):
    ax[0].text(i, b + 0.3, f"{b:.2f}", ha="center", fontsize=8)
ax[0].tick_params(axis="x", labelsize=8)

rr = [16, 32, 64, 128]
# The branch applies two projections (r×in, then out×r) — the same two factors the byte
# budget above counts — so it costs 2·r·(in + out) FLOPs against 2·in·out for the main
# GEMM. The ratio is r·(in + out)/(in·out), i.e. 2r/3072 for this square layer; using
# r/3072 alone would account for only one of the two projections.
share = [100.0 * r * (d_in + d_out) / (d_in * d_out) for r in rr]
ax[1].plot(rr, share, "o-", color="#d62728", lw=2, ms=8)
for r, s in zip(rr, share):
    ax[1].annotate(f"{s:.1f}%", (r, s), textcoords="offset points", xytext=(0, 9),
                   ha="center", fontsize=9)
ax[1].set_xlabel("rank")
ax[1].set_ylabel("low-rank branch FLOPs (% of the 4-bit GEMM)")
ax[1].set_title("Rank ≈ free on compute: branch FLOPs ∝ r/in_dim")
ax[1].set_xticks(rr)
ax[1].grid(alpha=0.3)
plt.tight_layout()
plt.savefig(f"{OUT}/byte_compute_budget.png", dpi=140)
plt.close()
# --------------------------------------------------------------------- Montage helper
def labeled_montage(rows, path):
    """Stack labelled images vertically into a single montage saved at ``path``.

    Args:
        rows: iterable of ``(label, image_path)`` pairs, listed top to bottom.
            Rows whose image file does not exist are left out of the montage and
            reported on stdout.
        path: destination file for the montage.

    Returns:
        None. When none of the rows has an existing image, nothing is written and
        a skip notice is printed instead.

    Every image is rescaled (aspect ratio kept) to the width of the widest one and
    is preceded by a 24 px white strip carrying its label.
    """
    pad_top = 24
    loaded, missing = [], []
    for lab, p in rows:
        if not os.path.exists(p):
            missing.append(p)
            continue
        # copy() decodes the pixels into memory, so the file handle can be released
        # as soon as the ``with`` block ends instead of lingering until GC.
        with Image.open(p) as im:
            loaded.append((lab, im.copy()))
    if not loaded:
        print(" (skip montage, no images:", path, ")")
        return
    for p in missing:
        # a partial montage is still produced; make the gap visible in the log
        print(" (montage row skipped, missing image:", p, ")")

    w = max(im.width for _, im in loaded)
    # max(1, ...) keeps the target height valid for extremely wide, thin images whose
    # scaled height would otherwise truncate to 0.
    scaled = [(lab, im.resize((w, max(1, int(im.height * w / im.width)))))
              for lab, im in loaded]
    H = sum(s.height for _, s in scaled) + pad_top * len(scaled)
    canvas = Image.new("RGB", (w, H), "white")
    d = ImageDraw.Draw(canvas)
    y = 0
    for lab, s in scaled:
        d.text((4, y + 5), lab, fill="black")  # label inside the strip above the image
        canvas.paste(s, (0, y + pad_top))
        y += s.height + pad_top
    canvas.save(path)
# Grid best (r64 plain+refine) across all 8 probes — teacher | quant per row
best_rows = [
    (f"{probe} — teacher | W4A8 r64 plain+refine (0.0446)", f"{BEST_DIR}/eval/cmp_{idx}.png")
    for idx, probe in enumerate(PROBES)
]
labeled_montage(best_rows, f"{OUT}/montage_best_allprompts.png")

# Method comparison at r64 on the text probe (cmp_0): all 4 variants
method_rows = []
for variant in VARIANTS:
    caption = f"{VLABEL[variant]} ({GRID[64][variant]:.4f}) — teacher | quant"
    method_rows.append((caption, f"outputs/abl_c300_r64_{variant}/eval/cmp_0.png"))
labeled_montage(method_rows, f"{OUT}/montage_method_r64_text.png")

# Finish with an alphabetical listing of everything now present in the output directory.
print("quant report assets written to", OUT)
for asset_name in sorted(os.listdir(OUT)):
    print(" ", asset_name)

Xet Storage Details

Size:
9.28 kB
·
Xet hash:
cf3bdbdab2d2f2a81fd3b76f5fc3b838dd0e89b4ca1a0f757f506e5248a8bca8

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.