Buckets:
"""Build a W4A8 SVDQuant student of the klein-4B transformer (our own fake-quant).

Pipeline: collect per-channel activation abs-max over a calibration set (real latents
noised across sigmas, with their captions — same forward the eval metric uses), then for
every block Linear: smooth -> SVD low-rank (rank R) -> 4-bit residual. Saves the
quantized state + a config json that round-trips via apply_svdquant_empty.

SVD mode: WHITEN=1 (default) uses activation-aware (whitened) SVD — the low-rank branch
minimizes OUTPUT error ‖X̂(Ŵ−L)‖ via the calibration activation Gram, so rank is a real
quality knob. WHITEN=0 falls back to plain weight-SVD (the base SVDQuant paper).

Usage:
    RANK=32 ALPHA=0.5 WBITS=4 ABITS=8 WGROUP=64 N_CALIB=2048 WHITEN=1 CALIB_DIR=data/monet_calib \\
        python3 -u scripts/12_build_svdquant.py
"""
| import json | |
| import os | |
| import time | |
| import torch | |
| from flux2distill.config import Config | |
| from flux2distill.data import LatentCaptionDataset | |
| from flux2distill.losses import build_x_t | |
| from flux2distill.model_utils import load_pipeline | |
| from flux2distill import svdquant as sq | |
# Config instance built with default arguments; it is not referenced again in
# the visible part of this script.
cfg = Config()


def _env_int(name, default):
    """Return environment variable ``name`` as an int (``default`` when unset)."""
    return int(os.getenv(name, default))


def _env_flag(name, default):
    """Return a 0/1 environment variable ``name`` as a bool (``default`` when unset)."""
    return bool(_env_int(name, default))


# ---- quantization knobs, all overridable from the environment ----
RANK = _env_int("RANK", 32)
ALPHA = float(os.getenv("ALPHA", 0.5))
WBITS = _env_int("WBITS", 4)
ABITS = _env_int("ABITS", 8)
WGROUP = _env_int("WGROUP", 64)
N_CALIB = _env_int("N_CALIB", 2048)
MB = _env_int("MB", 4)  # calibration micro-batch size
WHITEN = _env_flag("WHITEN", 1)  # 1=activation-aware (whitened) SVD, 0=plain
REFINE = _env_int("REFINE", 3)  # iterative low-rank refinement iters (paper §4.2); 0=one-shot
SMOOTH = _env_flag("SMOOTH", 1)  # 1=SmoothQuant migration, 0=no smoothing (s=1, RTN floor)
CALIB_DIR = os.getenv("CALIB_DIR", "data/monet_calib")

# ---- output directory: the default name encodes rank/alpha/bit-widths/SVD mode ----
_TAG = "whiten" if WHITEN else "plain"
_DEFAULT_OUT = f"outputs/svdquant_r{RANK}_a{ALPHA}_w{WBITS}a{ABITS}_{_TAG}"
OUT = os.getenv("OUT", _DEFAULT_OUT)
os.makedirs(OUT, exist_ok=True)
def log(m):
    """Print ``m`` prefixed with the current local time as ``[HH:MM:SS]``.

    Output is flushed immediately so progress lines show up promptly.
    """
    stamp = time.strftime("%H:%M:%S")
    line = f"[{stamp}] {m}"
    print(line, flush=True)
# ---- banner: echo the effective configuration once, up front ----
_svd_mode = "whitened" if WHITEN else "plain"
_smooth_mode = "on" if SMOOTH else "OFF(RTN)"
log(
    f"=== SVDQuant build: W{WBITS}A{ABITS} rank={RANK} alpha={ALPHA} group={WGROUP} "
    f"svd={_svd_mode} refine_iters={REFINE} "
    f"smooth={_smooth_mode} "
    f"n_calib={N_CALIB} calib_dir={CALIB_DIR} -> {OUT} ==="
)

# ---- model: take the pipeline's transformer, in eval mode with grads disabled ----
log("loading teacher pipeline...")
pipe = load_pipeline(device="cuda")
tf = pipe.transformer
tf.eval().requires_grad_(False)

# ---- calibration data (fall back to the small training cache if the big one is absent) ----
_have_calib = os.path.exists(f"{CALIB_DIR}/latents.pt")
cache_dir = CALIB_DIR if _have_calib else "data/monet_cache"
if cache_dir != CALIB_DIR:
    log(f"WARNING: {CALIB_DIR} not found; falling back to {cache_dir}")
ds = LatentCaptionDataset(cache_dir=cache_dir)
n_calib = min(N_CALIB, len(ds))  # never ask for more samples than the dataset holds
log(f"calibration: {n_calib} samples from {cache_dir} (have {len(ds)})")

# Only the second value returned by prepare_latents (bound to img_ids) is kept;
# it is reused for every calibration forward pass.
_ids_gen = torch.Generator(device="cuda").manual_seed(0)
_, img_ids = pipe.prepare_latents(1, 32, 512, 512, torch.bfloat16, "cuda", _ids_gen)
def velocity(x_t, sigma, pe, tid):
    """Run the transformer ``tf`` once and return its first output, cropped to ``x_t``.

    ``x_t`` is passed as ``hidden_states``, ``sigma`` as ``timestep``, ``pe`` as
    ``encoder_hidden_states`` and ``tid`` as ``txt_ids``; the module-level
    ``img_ids`` is passed along and ``guidance`` is disabled. The result is sliced
    along dim 1 to the first ``x_t.size(1)`` positions.
    """
    outputs = tf(
        hidden_states=x_t,
        timestep=sigma,
        guidance=None,
        encoder_hidden_states=pe,
        txt_ids=tid,
        img_ids=img_ids,
        return_dict=False,
    )
    first = outputs[0]
    keep = x_t.size(1)
    return first[:, :keep]
# ---- collect activation stats (abs-max + Gram if whitening) over the calibration set ----
target = sq.target_linear_names(tf)
log(f"target Linear layers to quantize: {len(target)}")
# collect_act_stats returns the stats/Gram containers plus removable handles
# (presumably forward hooks on the target layers — confirm in flux2distill.svdquant).
stats, grams, handles = sq.collect_act_stats(tf, target, with_gram=WHITEN, gram_device="cuda")
cg = torch.Generator(device="cuda").manual_seed(0)  # fixed seed -> reproducible noise/sigmas
t0 = time.time()
try:
    with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
        for i in range(0, n_calib, MB):
            done = min(i + MB, n_calib)  # samples processed once this batch finishes
            idx = list(range(i, done))
            x0 = ds.latents[idx].to("cuda", torch.bfloat16)
            caps = [ds.captions[j] for j in idx]
            pe, tid = pipe.encode_prompt(caps, device="cuda")
            # Noise and sigma are drawn in float32 (eps first, then sigma) and the
            # noised latent is cast back to bf16 for the forward pass.
            eps = torch.randn(x0.shape, generator=cg, device="cuda", dtype=torch.float32)
            sigma = torch.rand(x0.size(0), generator=cg, device="cuda", dtype=torch.float32)
            x_t = build_x_t(x0.float(), eps, sigma).to(torch.bfloat16)
            # The output is discarded: the forward pass is run only so the
            # collectors attached above can observe the activations.
            velocity(x_t, sigma, pe, tid)
            if (i // MB) % 20 == 0:
                # Rate uses the true processed count, so a truncated final batch
                # is not over-counted.
                log(f" calib {done}/{n_calib} ({done/(time.time()-t0+1e-9):.1f} smp/s)")
finally:
    # Always detach the collectors, even if a calibration forward pass raised.
    for h in handles:
        h.remove()
got = sum(v is not None for v in stats.values())
log(f"collected activation stats for {got}/{len(target)} layers"
    f"{' (+Gram)' if WHITEN else ''}; calibVRAM={torch.cuda.max_memory_allocated()/1e9:.1f}GB")
# ---- decompose + quantize every target layer ----
# The progress message reflects the configured pipeline (SMOOTH / WHITEN / WBITS)
# rather than hard-coding "smooth" and "4-bit"; with the defaults the text is unchanged.
log(f"decomposing ({'smooth -> ' if SMOOTH else ''}{'whitened ' if WHITEN else ''}SVD -> "
    f"{WBITS}-bit residual)...")
specs, diags = sq.apply_svdquant_from_stats(tf, stats, rank=RANK, alpha=ALPHA,
                                            w_bits=WBITS, a_bits=ABITS, w_group=WGROUP,
                                            svd_device="cuda", grams=grams, whiten=WHITEN,
                                            refine_iters=REFINE, smooth=SMOOTH)
if not diags:
    # Without any per-layer diagnostics the mean/median/max below are undefined;
    # fail with an explicit message instead of an opaque ZeroDivisionError.
    raise RuntimeError("apply_svdquant_from_stats returned no per-layer diagnostics "
                       "(were any target layers found?)")

# Per-layer "refine_best_it" entries; layers without the key count as 0.
ref_its = [d.get("refine_best_it", 0) for d in diags.values()]
log(f"refinement: best-iter mean={sum(ref_its)/len(ref_its):.1f} max={max(ref_its)} "
    f"(0 => one-shot was best)")

# Per-layer "rel_err" distribution. The "median" is the upper-middle element of
# the sorted list (no averaging for even counts).
rels = sorted(d["rel_err"] for d in diags.values())
mean_rel = sum(rels) / len(rels)
med_rel = rels[len(rels) // 2]
log(f"weight-recon rel-err: mean={mean_rel:.4f} median={med_rel:.4f} "
    f"min={rels[0]:.4f} max={rels[-1]:.4f}")
if WHITEN:
    # "out_rel_err" is reported only for the layers whose diagnostics include it.
    orels = sorted(d["out_rel_err"] for d in diags.values() if "out_rel_err" in d)
    if orels:
        log(f"output-recon rel-err (what whitening optimizes): mean={sum(orels)/len(orels):.4f} "
            f"median={orels[len(orels)//2]:.4f} min={orels[0]:.4f} max={orels[-1]:.4f}")
summ = sq.quant_summary(tf)
log(f"quantized {summ['n_quant_layers']} layers | effective {summ['quant_MB']:.0f}MB vs "
    f"{summ['full_MB']:.0f}MB bf16 ({summ['ratio']:.2f}x smaller on the quantized weights)")
# ---- save ----
# Weights: the transformer's state dict after quantization was applied.
torch.save(tf.state_dict(), f"{OUT}/quant_state.pt")
# Config: the build settings, the per-layer specs and the summary diagnostics.
config = {
    "method": "svdquant_fakequant", "svd": "whitened" if WHITEN else "plain",
    "refine_iters": REFINE, "smooth": SMOOTH,
    "w_bits": WBITS, "a_bits": ABITS, "rank": RANK,
    "alpha": ALPHA, "w_group": WGROUP, "n_calib": n_calib, "calib_dir": cache_dir,
    "specs": specs,
    "diag": {"mean_rel_err": mean_rel, "median_rel_err": med_rel},
    "summary": summ,
}
# Use a context manager so the file is flushed and closed deterministically,
# rather than relying on garbage collection of an anonymous file object.
with open(f"{OUT}/quant_config.json", "w", encoding="utf-8") as fh:
    json.dump(config, fh)
log(f"DONE -> {OUT}/quant_state.pt + quant_config.json "
    f"peakVRAM={torch.cuda.max_memory_allocated()/1e9:.1f}GB")
Xet Storage Details
- Size:
- 6.59 kB
- Xet hash:
- cb9a9491ab9104f9b8c57f11c5ebf8f8176d0ce0ab4692ece8ca0c7b59163733
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.