Buckets:
| """Evaluate a W4A8 SVDQuant student: held-out velocity-matching loss vs the teacher | |
| (directly comparable to the block-surgery numbers in RESULTS.md) + teacher-vs-quant | |
| image montages. | |
Usage: python3 -u scripts/13_eval_svdquant.py [QUANT_DIR]
  QUANT_DIR is an optional positional path (falls back to $QUANT_DIR, then
  outputs/svdquant_r32_a0.5_w4a8).
| """ | |
| import json | |
| import os | |
| import sys | |
| import torch | |
| from flux2distill.data import LatentCaptionDataset | |
| from flux2distill.losses import velocity_match_loss, build_x_t | |
| from flux2distill.model_utils import load_pipeline, load_transformer | |
| from flux2distill.eval_utils import side_by_side | |
| from flux2distill import svdquant as sq | |
QUANT_DIR = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("QUANT_DIR", "outputs/svdquant_r32_a0.5_w4a8")
# quant_config.json holds the quantization recipe (bit widths, rank, alpha, group size,
# per-layer specs) plus the summary/diag stats printed after the eval.
with open(f"{QUANT_DIR}/quant_config.json") as f:
    cfg = json.load(f)
OUT = f"{QUANT_DIR}/eval"
os.makedirs(OUT, exist_ok=True)
print(f"=== eval {QUANT_DIR}: W{cfg['w_bits']}A{cfg['a_bits']} rank={cfg['rank']} "
      f"alpha={cfg['alpha']} ===", flush=True)

# Teacher: the transformer owned by the pipeline, frozen.
pipe = load_pipeline(device="cuda")
teacher = pipe.transformer
teacher.eval().requires_grad_(False)

# Rebuild the quantized student: fresh transformer -> swap in empty SVDQuant layers
# described by cfg["specs"] -> fill them from quant_state.pt.
student = load_transformer(dtype="bfloat16", device="cuda").eval()
sq.apply_svdquant_empty(student, cfg["specs"], w_bits=cfg["w_bits"], a_bits=cfg["a_bits"],
                        w_group=cfg["w_group"], dtype=torch.bfloat16)
state = torch.load(f"{QUANT_DIR}/quant_state.pt", map_location="cuda")
# strict=False: student keys absent from quant_state.pt keep whatever load_transformer /
# apply_svdquant_empty left there. NOTE(review): presumably the checkpoint stores only
# the quantized layers -- TODO confirm; the count is reported so a bad save is visible.
missing, unexpected = student.load_state_dict(state, strict=False)
del state  # tensors were copied into the student; free the checkpoint copy
if unexpected:
    # explicit raise (not assert) so the check survives `python -O`
    raise RuntimeError(f"unexpected keys: {unexpected[:5]}")
print(f"state load: {len(missing)} student keys not present in quant_state.pt", flush=True)
student.requires_grad_(False)
# ---- fixed held-out eval batch (SAME construction as 08_train_recover for continuity) ----
ds = LatentCaptionDataset(cache_dir="data/monet_cache")
EVAL_N = 16
ev_x0 = ds.latents[:EVAL_N].to("cuda", torch.bfloat16)
ev_caps = ds.captions[:EVAL_N]
if ev_x0.shape[0] < EVAL_N or len(ev_caps) < EVAL_N:
    # Slicing silently truncates. A short cache would no longer match the EVAL_N-sized
    # sigma draw below (nor 08_train_recover's batch), so fail loudly instead.
    raise ValueError(f"eval batch needs {EVAL_N} cached samples, found "
                     f"{ev_x0.shape[0]} latents / {len(ev_caps)} captions")
with torch.no_grad():
    ev_pe, ev_tid = pipe.encode_prompt(ev_caps, device="cuda")
# Fixed-seed noise, then per-sample sigma ~ U[0, 1) from the SAME generator -- the draw
# order matters for reproducing the exact batch used in training.
ev_gen = torch.Generator(device="cuda").manual_seed(1234)
ev_eps = torch.randn(ev_x0.shape, generator=ev_gen, device="cuda", dtype=torch.float32)
ev_sigma = torch.rand(EVAL_N, generator=ev_gen, device="cuda", dtype=torch.float32)
ev_xt = build_x_t(ev_x0.float(), ev_eps, ev_sigma).to(torch.bfloat16)
# Only the position ids are kept; the latents prepare_latents returns are discarded.
# NOTE(review): args presumably (batch=1, channels=32, 512, 512, ...) -- TODO confirm.
_, img_ids = pipe.prepare_latents(1, 32, 512, 512, torch.bfloat16, "cuda",
                                  torch.Generator(device="cuda").manual_seed(0))
def velocity(tf, x_t, sigma, pe, tid):
    """Run transformer ``tf`` once and return its prediction for the image tokens.

    Guidance is disabled and the module-level ``img_ids`` are used as image position ids.
    The first element of the (non-dict) output is trimmed to the first ``x_t.size(1)``
    positions along dim 1.
    """
    outputs = tf(
        hidden_states=x_t,
        timestep=sigma,
        guidance=None,
        encoder_hidden_states=pe,
        txt_ids=tid,
        img_ids=img_ids,
        return_dict=False,
    )
    prediction = outputs[0]
    n_image_tokens = x_t.size(1)
    return prediction[:, :n_image_tokens]
# Teacher and student velocities on the identical held-out batch, no grad, bf16 autocast.
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    vt = velocity(teacher, ev_xt, ev_sigma, ev_pe, ev_tid)
    vs = velocity(student, ev_xt, ev_sigma, ev_pe, ev_tid)
    loss = float(velocity_match_loss(vs, vt))
    # relative L2 error of the student velocity field against the teacher's (fp32)
    vt_fp32 = vt.float()
    diff_norm = (vs.float() - vt_fp32).norm()
    rel = float(diff_norm / (vt_fp32.norm() + 1e-8))
print(f"eval_vel_loss={loss:.4f} vel_rel_err={rel:.4f}", flush=True)

# Static stats recorded in quant_config.json at quantization time.
summ = cfg["summary"]
print(
    f"quant: {summ['n_quant_layers']} layers, effective {summ['quant_MB']:.0f}MB "
    f"({summ['ratio']:.2f}x), weight-recon mean rel-err {cfg['diag']['mean_rel_err']:.4f}",
    flush=True,
)
# ---- image montages: teacher vs quant on probe prompts ----
# Original probes, kept first so montage indices 0-3 stay comparable with earlier runs.
_BASE_PROBES = [
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a serene mountain lake at sunrise reflecting snow-capped peaks, mist over the water",
    "a photorealistic portrait of an elderly fisherman, weathered face, sharp detail",
    "a bustling tokyo street at night, neon signs, rain-slicked pavement, reflections",
]
# Added 2026-06-01 for a richer visual comparison, targeting capabilities quantization is
# most likely to bend: a 2nd text-rendering case, multi-object composition/counting, a
# hands/face case, and a fine-texture macro.
_EXTRA_PROBES = [
    'a hand-lettered chalkboard cafe sign that reads "FRESH COFFEE" with small daily specials below',
    "a flat-lay breakfast table from above: three fried eggs, two strips of bacon, a glass of orange juice, and a small vase with one sunflower",
    "a close-up of a smiling young woman holding up five fingers, natural window light, sharp focus on the hand",
    "an extreme macro of a dewy spider web at dawn, water droplets catching golden light, crisp detail",
]
# Must stay a list: the pipeline is called with prompt=PROMPTS.
PROMPTS = _BASE_PROBES + _EXTRA_PROBES
def gen(tf):
    """Generate one image per entry in ``PROMPTS`` with transformer ``tf`` swapped into ``pipe``.

    Every call re-seeds a fresh CUDA generator with 0 and uses 4 inference steps,
    guidance_scale=1.0, 512x512 output, under bf16 autocast. Teacher and student runs
    therefore use identically seeded sampling.

    The pipeline's previous transformer is restored afterwards, even on error, so ``gen``
    does not leave ``pipe`` pointing at whichever model ran last.

    Returns the pipeline output's ``.images``.
    """
    prev_tf = pipe.transformer
    pipe.transformer = tf
    try:
        g = torch.Generator(device="cuda").manual_seed(0)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            return pipe(prompt=PROMPTS, num_inference_steps=4, guidance_scale=1.0,
                        height=512, width=512, generator=g).images
    finally:
        pipe.transformer = prev_tf
t_imgs = gen(teacher)
q_imgs = gen(student)
# Right-hand label, e.g. "W4A8 r32"; cfg keys were already read during setup.
quant_label = f"W{cfg['w_bits']}A{cfg['a_bits']} r{cfg['rank']}"
for idx, (prompt, t_img, q_img) in enumerate(zip(PROMPTS, t_imgs, q_imgs)):
    montage = side_by_side(t_img, q_img, "teacher", quant_label, prompt)
    montage.save(f"{OUT}/cmp_{idx}.png")
print(f"saved {len(PROMPTS)} montages -> {OUT}/", flush=True)
Xet Storage Details
- Size:
- 5.05 kB
- Xet hash:
- f4cab4b382b02f4b1762ba4035e7bf1baf35c33b6da868656ee6054132c1704f
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.