Buckets:

Mercity/FluxDistill / scripts /12_build_svdquant.py
Pranav2748's picture
download
raw
6.59 kB
"""Build a W4A8 SVDQuant student of the klein-4B transformer (our own fake-quant).
Pipeline: collect per-channel activation abs-max over a calibration set (real latents
noised across sigmas, with their captions — same forward the eval metric uses), then for
every block Linear: smooth -> SVD low-rank (rank R) -> 4-bit residual. Saves the
quantized state + a config json that round-trips via apply_svdquant_empty.
SVD mode: WHITEN=1 (default) uses activation-aware (whitened) SVD — the low-rank branch
minimizes OUTPUT error ‖X̂(Ŵ−L)‖ via the calibration activation Gram, so rank is a real
quality knob. WHITEN=0 falls back to plain weight-SVD (the base SVDQuant paper).
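Usage:
RANK=32 ALPHA=0.5 WBITS=4 ABITS=8 WGROUP=64 N_CALIB=2048 WHITEN=1 CALIB_DIR=data/monet_calib \\
python3 -u scripts/12_build_svdquant.py
Other environment knobs: MB (calibration micro-batch size, default 4), REFINE (low-rank
refinement iterations, default 3; 0 = one-shot), SMOOTH (1 = SmoothQuant migration, 0 = off)
and OUT (output directory; defaults to a name derived from RANK/ALPHA/WBITS/ABITS and the SVD mode).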
"""
import json
import os
import time
import torch
from flux2distill.config import Config
from flux2distill.data import LatentCaptionDataset
from flux2distill.losses import build_x_t
from flux2distill.model_utils import load_pipeline
from flux2distill import svdquant as sq
# ---- run configuration: every knob is read from an environment variable ----
cfg = Config()

RANK = int(os.environ.get("RANK", 32))          # forwarded as rank= to the decomposition
ALPHA = float(os.environ.get("ALPHA", 0.5))     # forwarded as alpha= (presumably the smoothing strength)
WBITS = int(os.environ.get("WBITS", 4))         # forwarded as w_bits=
ABITS = int(os.environ.get("ABITS", 8))         # forwarded as a_bits=
WGROUP = int(os.environ.get("WGROUP", 64))      # forwarded as w_group=
N_CALIB = int(os.environ.get("N_CALIB", 2048))  # upper bound on calibration samples used
MB = int(os.environ.get("MB", 4))               # calibration micro-batch size (loop stride)
WHITEN = bool(int(os.environ.get("WHITEN", 1)))  # 1=activation-aware (whitened) SVD, 0=plain
REFINE = int(os.environ.get("REFINE", 3))  # iterative low-rank refinement iters (paper §4.2); 0=one-shot
SMOOTH = bool(int(os.environ.get("SMOOTH", 1)))  # 1=SmoothQuant migration, 0=no smoothing (s=1, RTN floor)
CALIB_DIR = os.environ.get("CALIB_DIR", "data/monet_calib")

# Fail fast on a bad micro-batch size. The calibration loop strides with
# range(0, n_calib, MB): MB=0 would only raise after the pipeline has been loaded,
# and a negative MB yields an empty range, i.e. no calibration batches would run.
if MB < 1:
    raise ValueError(f"MB must be >= 1 (got {MB})")

# The default output directory encodes rank/alpha/bit-widths and the SVD mode only;
# runs differing in other knobs share it unless OUT is set explicitly.
_TAG = "whiten" if WHITEN else "plain"
OUT = os.environ.get("OUT", f"outputs/svdquant_r{RANK}_a{ALPHA}_w{WBITS}a{ABITS}_{_TAG}")
os.makedirs(OUT, exist_ok=True)
def log(m):
    """Print ``m`` to stdout prefixed with the local time as ``[HH:MM:SS]``.

    The stream is flushed on every call so progress lines appear immediately.
    """
    stamp = time.strftime("%H:%M:%S")
    print(f"[{stamp}] {m}", flush=True)
# ---- startup banner: echo the complete run configuration in one line ----
_svd_mode = "whitened" if WHITEN else "plain"
_smooth_mode = "on" if SMOOTH else "OFF(RTN)"
log(
    f"=== SVDQuant build: W{WBITS}A{ABITS} rank={RANK} alpha={ALPHA} group={WGROUP} "
    f"svd={_svd_mode} refine_iters={REFINE} "
    f"smooth={_smooth_mode} "
    f"n_calib={N_CALIB} calib_dir={CALIB_DIR} -> {OUT} ==="
)

# ---- teacher model: take the pipeline's transformer, eval mode, gradients off ----
log("loading teacher pipeline...")
pipe = load_pipeline(device="cuda")
tf = pipe.transformer
tf.eval().requires_grad_(False)

# ---- calibration data (fall back to the small training cache if the big one is absent) ----
# The requested directory counts as present only if it contains latents.pt.
if os.path.exists(f"{CALIB_DIR}/latents.pt"):
    cache_dir = CALIB_DIR
else:
    cache_dir = "data/monet_cache"
if cache_dir != CALIB_DIR:
    log(f"WARNING: {CALIB_DIR} not found; falling back to {cache_dir}")

ds = LatentCaptionDataset(cache_dir=cache_dir)
_n_available = len(ds)
n_calib = min(N_CALIB, _n_available)  # never request more samples than the dataset holds
log(f"calibration: {n_calib} samples from {cache_dir} (have {_n_available})")

# Only the second value returned by prepare_latents (the image position ids) is kept.
# NOTE(review): positional args are presumably (batch=1, channels=32, height=512,
# width=512, dtype, device, generator) - confirm against the pipeline's signature.
_id_gen = torch.Generator(device="cuda").manual_seed(0)
_, img_ids = pipe.prepare_latents(1, 32, 512, 512, torch.bfloat16, "cuda", _id_gen)
def velocity(x_t, sigma, pe, tid):
    """Run the transformer ``tf`` on one noised batch and return its cropped output.

    Args:
        x_t: passed to the transformer as ``hidden_states``.
        sigma: passed as ``timestep``.
        pe: passed as ``encoder_hidden_states``.
        tid: passed as ``txt_ids``.

    Returns:
        The first element of the transformer's tuple output, truncated along
        dim 1 to ``x_t.size(1)`` entries.
    """
    outputs = tf(
        hidden_states=x_t,
        timestep=sigma,
        guidance=None,
        encoder_hidden_states=pe,
        txt_ids=tid,
        img_ids=img_ids,  # module-level ids computed once at startup
        return_dict=False,
    )
    pred = outputs[0]
    n_keep = x_t.size(1)
    # Keep only as many positions along dim 1 as the input had.
    return pred[:, :n_keep]
# ---- collect activation stats (abs-max + Gram if whitening) over the calibration set ----
target = sq.target_linear_names(tf)
log(f"target Linear layers to quantize: {len(target)}")
# collect_act_stats returns the stats/Gram containers plus removable handles
# (presumably forward hooks that fill the containers during the forward passes below).
stats, grams, handles = sq.collect_act_stats(tf, target, with_gram=WHITEN, gram_device="cuda")
cg = torch.Generator(device="cuda").manual_seed(0)  # fixed seed: same noise/sigma draws every run
t0 = time.time()
try:
    with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
        for i in range(0, n_calib, MB):
            done = min(i + MB, n_calib)  # samples processed once this batch finishes
            idx = list(range(i, done))
            x0 = ds.latents[idx].to("cuda", torch.bfloat16)
            caps = [ds.captions[j] for j in idx]
            pe, tid = pipe.encode_prompt(caps, device="cuda")
            # Noise and sigma are drawn in float32 (eps first, then sigma) from the shared
            # generator; x_t is built in float32 and cast to bf16 only for the forward pass.
            eps = torch.randn(x0.shape, generator=cg, device="cuda", dtype=torch.float32)
            sigma = torch.rand(x0.size(0), generator=cg, device="cuda", dtype=torch.float32)
            x_t = build_x_t(x0.float(), eps, sigma).to(torch.bfloat16)
            velocity(x_t, sigma, pe, tid)  # output discarded; the pass is run for its side effects
            if (i // MB) % 20 == 0:
                # Throughput uses the clamped sample count so a final partial batch
                # is not over-reported.
                log(f" calib {done}/{n_calib} ({done/(time.time()-t0+1e-9):.1f} smp/s)")
finally:
    # Always detach the collectors, even if a forward pass raised, so the model
    # is not left instrumented.
    for h in handles:
        h.remove()
got = sum(v is not None for v in stats.values())
log(f"collected activation stats for {got}/{len(target)} layers"
    f"{' (+Gram)' if WHITEN else ''}; calibVRAM={torch.cuda.max_memory_allocated()/1e9:.1f}GB")
# ---- decompose + quantize every target layer ----
# The residual bit-width in the message follows WBITS instead of being hard-coded to 4.
log(f"decomposing (smooth -> {'whitened ' if WHITEN else ''}SVD -> {WBITS}-bit residual)...")
specs, diags = sq.apply_svdquant_from_stats(tf, stats, rank=RANK, alpha=ALPHA,
                                            w_bits=WBITS, a_bits=ABITS, w_group=WGROUP,
                                            svd_device="cuda", grams=grams, whiten=WHITEN,
                                            refine_iters=REFINE, smooth=SMOOTH)
if not diags:
    # Without per-layer diagnostics the statistics below would divide by zero;
    # report the real problem instead.
    raise RuntimeError("SVDQuant returned no per-layer diagnostics: no target layers were quantized")

# Per-layer refinement iteration that was kept; layers lacking the key count as 0.
ref_its = [d.get("refine_best_it", 0) for d in diags.values()]
log(f"refinement: best-iter mean={sum(ref_its)/len(ref_its):.1f} max={max(ref_its)} "
    f"(0 => one-shot was best)")

# Weight-reconstruction error summary; the "median" is the upper middle element
# of the sorted list when the layer count is even.
rels = sorted(d["rel_err"] for d in diags.values())
mean_rel = sum(rels) / len(rels)
med_rel = rels[len(rels) // 2]
log(f"weight-recon rel-err: mean={mean_rel:.4f} median={med_rel:.4f} "
    f"min={rels[0]:.4f} max={rels[-1]:.4f}")
if WHITEN:
    # Output-space error is reported only for layers whose diagnostics carry it.
    orels = sorted(d["out_rel_err"] for d in diags.values() if "out_rel_err" in d)
    if orels:
        log(f"output-recon rel-err (what whitening optimizes): mean={sum(orels)/len(orels):.4f} "
            f"median={orels[len(orels)//2]:.4f} min={orels[0]:.4f} max={orels[-1]:.4f}")
summ = sq.quant_summary(tf)
log(f"quantized {summ['n_quant_layers']} layers | effective {summ['quant_MB']:.0f}MB vs "
    f"{summ['full_MB']:.0f}MB bf16 ({summ['ratio']:.2f}x smaller on the quantized weights)")
# ---- save ----
# Two artifacts go into OUT: the transformer state dict and a JSON config
# describing how it was built.
torch.save(tf.state_dict(), f"{OUT}/quant_state.pt")
config = {
    "method": "svdquant_fakequant", "svd": "whitened" if WHITEN else "plain",
    "refine_iters": REFINE, "smooth": SMOOTH,
    "w_bits": WBITS, "a_bits": ABITS, "rank": RANK,
    # n_calib / calib_dir record what was actually used (after clamping and fallback),
    # not what was requested.
    "alpha": ALPHA, "w_group": WGROUP, "n_calib": n_calib, "calib_dir": cache_dir,
    "specs": specs,
    "diag": {"mean_rel_err": mean_rel, "median_rel_err": med_rel},
    "summary": summ,
}
# Context manager guarantees the config is flushed and the handle closed
# before the DONE line is printed.
with open(f"{OUT}/quant_config.json", "w") as fh:
    json.dump(config, fh)
log(f"DONE -> {OUT}/quant_state.pt + quant_config.json "
    f"peakVRAM={torch.cuda.max_memory_allocated()/1e9:.1f}GB")

Xet Storage Details

Size:
6.59 kB
·
Xet hash:
cb9a9491ab9104f9b8c57f11c5ebf8f8176d0ce0ab4692ece8ca0c7b59163733

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.