Buckets:

Mercity/FluxDistill / scripts /13_eval_svdquant.py
Pranav2748's picture
download
raw
5.05 kB
"""Evaluate a W4A8 SVDQuant student: held-out velocity-matching loss vs the teacher
(directly comparable to the block-surgery numbers in RESULTS.md) + teacher-vs-quant
image montages.
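Usage: python3 -u scripts/13_eval_svdquant.py [QUANT_DIR]
QUANT_DIR is the quantized-run directory; it may also be supplied through the QUANT_DIR
environment variable and defaults to outputs/svdquant_r32_a0.5_w4a8.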
"""
import json
import os
import sys
import torch
from flux2distill.data import LatentCaptionDataset
from flux2distill.losses import velocity_match_loss, build_x_t
from flux2distill.model_utils import load_pipeline, load_transformer
from flux2distill.eval_utils import side_by_side
from flux2distill import svdquant as sq
# Directory of the quantized run: first CLI argument, else $QUANT_DIR, else the default run.
QUANT_DIR = (
    sys.argv[1]
    if len(sys.argv) > 1
    else os.environ.get("QUANT_DIR", "outputs/svdquant_r32_a0.5_w4a8")
)

# Read the quantization config through a context manager so the file handle is closed
# deterministically instead of being left open until garbage collection.
with open(f"{QUANT_DIR}/quant_config.json", encoding="utf-8") as config_file:
    cfg = json.load(config_file)

# Everything this script writes goes under <QUANT_DIR>/eval.
OUT = f"{QUANT_DIR}/eval"
os.makedirs(OUT, exist_ok=True)

print(f"=== eval {QUANT_DIR}: W{cfg['w_bits']}A{cfg['a_bits']} rank={cfg['rank']} "
      f"alpha={cfg['alpha']} ===", flush=True)
# The teacher is the transformer that ships with the pipeline, put in eval mode with
# gradients disabled.
pipe = load_pipeline(device="cuda")
teacher = pipe.transformer
teacher.eval().requires_grad_(False)

# rebuild the quantized student and load weights
student = load_transformer(dtype="bfloat16", device="cuda").eval()
sq.apply_svdquant_empty(
    student,
    cfg["specs"],
    w_bits=cfg["w_bits"],
    a_bits=cfg["a_bits"],
    w_group=cfg["w_group"],
    dtype=torch.bfloat16,
)

# NOTE(review): torch.load unpickles quant_state.pt unless weights_only is in effect --
# only point QUANT_DIR at checkpoints from trusted runs.
# The checkpoint is passed inline (not bound to a name) so the loaded copy can be freed
# as soon as load_state_dict returns.
missing, unexpected = student.load_state_dict(
    torch.load(f"{QUANT_DIR}/quant_state.pt", map_location="cuda"),
    strict=False,
)
assert not unexpected, f"unexpected keys: {unexpected[:5]}"
if missing:
    # strict=False skips keys the checkpoint lacks without raising; report them so a
    # checkpoint that is missing quantized weights does not get evaluated unnoticed.
    print(f"load_state_dict: {len(missing)} student keys absent from quant_state.pt "
          f"(left at their pre-load values), e.g. {missing[:5]}", flush=True)
student.requires_grad_(False)
# ---- fixed held-out eval batch (SAME construction as 08_train_recover for continuity) ----
# The batch is the first EVAL_N cached latents/captions plus noise and sigmas drawn from a
# generator seeded with 1234, so every run evaluates the same inputs.
EVAL_N = 16
ds = LatentCaptionDataset(cache_dir="data/monet_cache")
ev_x0 = ds.latents[:EVAL_N].to("cuda", torch.bfloat16)
ev_caps = ds.captions[:EVAL_N]

with torch.no_grad():
    ev_pe, ev_tid = pipe.encode_prompt(ev_caps, device="cuda")

# Draw order matters for reproducibility: eps first, then sigma, from the same generator.
ev_gen = torch.Generator(device="cuda").manual_seed(1234)
ev_eps = torch.randn(
    ev_x0.shape, generator=ev_gen, device="cuda", dtype=torch.float32
)
ev_sigma = torch.rand(
    EVAL_N, generator=ev_gen, device="cuda", dtype=torch.float32
)

# x_t is built in float32 and only then cast down to bfloat16 for the model.
ev_xt = build_x_t(ev_x0.float(), ev_eps, ev_sigma).to(torch.bfloat16)

# Only the second return value (img_ids) of prepare_latents is kept; the first is discarded.
_, img_ids = pipe.prepare_latents(
    1,
    32,
    512,
    512,
    torch.bfloat16,
    "cuda",
    torch.Generator(device="cuda").manual_seed(0),
)
def velocity(tf, x_t, sigma, pe, tid):
    """Run transformer ``tf`` on ``x_t`` and return its first output, cropped along dim 1.

    ``tf`` is called with keyword arguments only (``guidance=None``, ``return_dict=False``)
    and the module-level ``img_ids``; element 0 of what it returns is sliced to the first
    ``x_t.size(1)`` entries of dim 1.

    Args:
        tf: callable transformer (teacher or student).
        x_t: input passed as ``hidden_states``; must expose ``size(1)``.
        sigma: value passed as ``timestep``.
        pe: value passed as ``encoder_hidden_states``.
        tid: value passed as ``txt_ids``.

    Returns:
        ``outputs[0][:, :x_t.size(1)]``.
    """
    keep = x_t.size(1)
    outputs = tf(
        hidden_states=x_t,
        timestep=sigma,
        guidance=None,
        encoder_hidden_states=pe,
        txt_ids=tid,
        img_ids=img_ids,
        return_dict=False,
    )
    prediction = outputs[0]
    # presumably guards against outputs longer than the input along dim 1 -- TODO confirm
    return prediction[:, :keep]
# Teacher and student see the exact same noised batch, so their outputs are directly
# comparable.
# NOTE(review): the loss is computed inside the no_grad/autocast scope here -- confirm this
# matches the intended layout of the script.
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    vt = velocity(teacher, ev_xt, ev_sigma, ev_pe, ev_tid)
    vs = velocity(student, ev_xt, ev_sigma, ev_pe, ev_tid)
    loss = float(velocity_match_loss(vs, vt))

# also report the relative L2 error of the velocity field
# (computed on float32 copies; the 1e-8 keeps the division finite for a zero teacher norm)
vel_gap = vs.float() - vt.float()
teacher_norm = vt.float().norm()
rel = float(vel_gap.norm() / (teacher_norm + 1e-8))
print(f"eval_vel_loss={loss:.4f} vel_rel_err={rel:.4f}", flush=True)

# Echo the compression summary stored in the quant config next to the eval numbers.
summ = cfg["summary"]
print(
    f"quant: {summ['n_quant_layers']} layers, effective {summ['quant_MB']:.0f}MB "
    f"({summ['ratio']:.2f}x), weight-recon mean rel-err {cfg['diag']['mean_rel_err']:.4f}",
    flush=True,
)
# ---- image montages: teacher vs quant on probe prompts ----
# Prompts 0-3 are the original probes and stay in place for continuity with earlier runs.
# Prompts 4-7 were added 2026-06-01 for richer visual comparison: a second text-rendering
# case, multi-object composition/counting, a hands/face case, and a fine-texture macro
# (the capabilities quantization is most likely to bend).
PROMPTS = [
    # original probes
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a serene mountain lake at sunrise reflecting snow-capped peaks, mist over the water",
    "a photorealistic portrait of an elderly fisherman, weathered face, sharp detail",
    "a bustling tokyo street at night, neon signs, rain-slicked pavement, reflections",
    # added probes
    'a hand-lettered chalkboard cafe sign that reads "FRESH COFFEE" with small daily specials below',
    "a flat-lay breakfast table from above: three fried eggs, two strips of bacon, a glass of orange juice, and a small vase with one sunflower",
    "a close-up of a smiling young woman holding up five fingers, natural window light, sharp focus on the hand",
    "an extreme macro of a dewy spider web at dawn, water droplets catching golden light, crisp detail",
]
@torch.no_grad()
def gen(tf):
    """Generate one image per entry of ``PROMPTS`` with ``tf`` as the pipeline transformer.

    ``tf`` is assigned to ``pipe.transformer`` (the assignment is not undone afterwards).
    The pipeline then runs under bfloat16 autocast with a CUDA generator seeded to 0, so
    repeated calls start from the same seed.

    Args:
        tf: transformer to install on the shared pipeline.

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    pipe.transformer = tf
    seeded = torch.Generator(device="cuda").manual_seed(0)
    with torch.autocast("cuda", dtype=torch.bfloat16):
        result = pipe(
            prompt=PROMPTS,
            num_inference_steps=4,
            guidance_scale=1.0,
            height=512,
            width=512,
            generator=seeded,
        )
    return result.images
# Render every probe prompt with the teacher first, then with the quantized student.
t_imgs = gen(teacher)
q_imgs = gen(student)

# Write one teacher-vs-quant montage per prompt to <OUT>/cmp_<index>.png.
for idx, pair in enumerate(zip(t_imgs, q_imgs)):
    teacher_img, quant_img = pair
    montage = side_by_side(
        teacher_img,
        quant_img,
        "teacher",
        f"W{cfg['w_bits']}A{cfg['a_bits']} r{cfg['rank']}",
        PROMPTS[idx],
    )
    montage.save(f"{OUT}/cmp_{idx}.png")

print(f"saved {len(PROMPTS)} montages -> {OUT}/", flush=True)

Xet Storage Details

Size:
5.05 kB
·
Xet hash:
f4cab4b382b02f4b1762ba4035e7bf1baf35c33b6da868656ee6054132c1704f

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.