Buckets:

Mercity/FluxDistill / scripts /03_build_student.py
Pranav2748's picture
download
raw
4.16 kB
"""Build the ~2B student from klein-4B:
capture single-block I/O -> SVD-energy block selection -> build + lstsq-init surrogates
-> short calibration fit -> save student state_dict + selection metadata.
"""
import json
import os
import time
import torch
from flux2distill.config import Config
from flux2distill.model_utils import load_pipeline, param_summary
from flux2distill.surgery import capture_single_block_io, select_blocks_svd_energy, build_student
from flux2distill.calibration import fit_surrogate
# Run-wide settings; surgery rank / keep counts / activation are read from cfg.surgery.
cfg = Config()

# Output directory for the student state_dict and the selection metadata.
OUT = "outputs/student"
os.makedirs(OUT, exist_ok=True)

# When True, every surrogate gets a short calibration fit after its lstsq warm-start.
DO_FIT = True

# Prompts run through the teacher to capture single-block inputs/targets.
CALIB_PROMPTS = [
    "a vintage bookshop storefront with a wooden sign that reads \"THE OPEN PAGE\"",
    "exactly five ripe red apples arranged in a row on a wooden table",
    "a close-up portrait of an elderly fisherman with a weathered face, natural light",
    "a macro photograph of dew drops on a spider web at dawn",
    "a futuristic city skyline at night with neon lights and flying cars",
    "an oil painting of a bowl of fruit in the style of a dutch still life",
]
# Load the teacher pipeline on the GPU. Its transformer (`tf`) is edited in place
# further down, so after surgery `pipe` runs the student.
print("loading teacher pipeline...")
pipe = load_pipeline(device="cuda")
tf = pipe.transformer
teacher_summary = param_summary(tf)
print("teacher transformer params:", teacher_summary)
# Run the teacher on the calibration prompts and record, for each single-stream block,
# the tensors later used as fit input ("X") and target ("Delta").
# NOTE(review): max_tokens_per_block presumably caps the rows kept per block -- confirm
# against capture_single_block_io.
print(f"\ncapturing single-block I/O on {len(CALIB_PROMPTS)} calib prompts...")
# perf_counter is monotonic, so the elapsed time is unaffected by wall-clock adjustments.
t0 = time.perf_counter()
io = capture_single_block_io(pipe, CALIB_PROMPTS, num_inference_steps=4,
                             max_tokens_per_block=12000, seed=0)
elapsed = time.perf_counter() - t0
print(f" captured in {elapsed:.1f}s; tokens/block={io[0]['X'].shape[0]}, d={io[0]['X'].shape[1]}")
# Score each single block by how much of its captured update a rank-`rank`
# approximation retains (SVD energy). `keep_single` blocks stay full; the rest
# are listed as surrogate candidates.
surgery_cfg = cfg.surgery
print(f"\nselecting blocks (rank={surgery_cfg.rank}, keep {surgery_cfg.keep_single} full)...")
keep_idx, surr_idx, stats = select_blocks_svd_energy(io, surgery_cfg.rank, surgery_cfg.keep_single)
print(f" KEEP full ({len(keep_idx)}): {keep_idx}")
print(f" SURROGATE ({len(surr_idx)}): {surr_idx}")
print(" per-block [block: captured_ratio | delta_rms]:")
for entry in sorted(stats, key=lambda rec: rec['block']):
    blk = entry['block']
    if blk in keep_idx:
        tag = "keep"
    else:
        tag = "SURR"
    print(f" blk{blk:2d} [{tag}] captured@{surgery_cfg.rank}={entry['captured_ratio']:.4f} rms={entry['delta_rms']:.4f}")
# Swap each block in surr_idx for a rank-`rank` surrogate (activation cfg.surgery.act),
# warm-started by least squares on the captured I/O. `errs` is indexed by block index
# and holds the reconstruction relative error of each surrogate.
print("\nbuilding student (lstsq warm-start)...")
errs = build_student(tf, surr_idx, io, rank=cfg.surgery.rank, act=cfg.surgery.act, device="cuda")
print(" lstsq reconstruction rel-err per surrogate:")
for i in surr_idx:
    print(f" blk{i:2d}: {errs[i]:.4f}")
if errs:
    print(f" mean lstsq rel-err: {sum(errs.values())/len(errs):.4f}")
else:
    # No surrogates were selected (e.g. every block kept full); avoid dividing by zero.
    print(" mean lstsq rel-err: n/a (no surrogate blocks)")
if DO_FIT:
    print("\ncalibration fit (closing the GELU gap)...")
    # Kept blocks are untouched teacher blocks, so their parameter dtype is the model
    # dtype. Fall back to bfloat16 (the previously hard-coded value) when none are kept.
    kept_params = (p for j in keep_idx for p in tf.single_transformer_blocks[j].parameters())
    first_kept = next(kept_params, None)
    model_dtype = first_kept.dtype if first_kept is not None else torch.bfloat16
    fit_results = {}
    for i in surr_idx:
        sur = tf.single_transformer_blocks[i]
        # Fit the surrogate to map the captured X to Delta; returns (initial, final) rel-err.
        ie, fe = fit_surrogate(sur, io[i]["X"], io[i]["Delta"], steps=200, lr=1e-3)
        sur.to(dtype=model_dtype)  # back to model dtype
        fit_results[i] = (ie, fe)
        print(f" blk{i:2d}: {ie:.4f} -> {fe:.4f}")
    if fit_results:
        mean_post_fit = sum(final for _, final in fit_results.values()) / len(fit_results)
        print(f" mean post-fit rel-err: {mean_post_fit:.4f}")
    else:
        # No surrogates to fit; avoid dividing by zero.
        print(" mean post-fit rel-err: n/a (no surrogate blocks)")
# Computed once and reused for both the log line and the metadata file.
student_summary = param_summary(tf)
print("\nstudent params:", student_summary)
# Save student state + selection metadata.
torch.save(tf.state_dict(), os.path.join(OUT, "student_state.pt"))
meta = {
    "keep_idx": keep_idx, "surrogate_idx": surr_idx,
    "rank": cfg.surgery.rank, "act": cfg.surgery.act,
    "lstsq_rel_err": {str(k): v for k, v in errs.items()},
    "stats": stats,
    "param_summary": student_summary,
}
if DO_FIT:
    # The saved weights are post-calibration, so record those errors as well.
    # float() keeps json.dump working whether fit_surrogate returns Python floats,
    # numpy scalars or 0-d tensors.
    meta["fit_rel_err"] = {
        str(k): {"init": float(ie), "final": float(fe)}
        for k, (ie, fe) in fit_results.items()
    }
with open(os.path.join(OUT, "selection.json"), "w", encoding="utf-8") as f:
    json.dump(meta, f, indent=2)
print(f"\nsaved student_state.pt + selection.json to {OUT}/")
# Smoke gen with the student (pipe.transformer is now the student): render two
# 512x512 images in 4 steps with a fixed seed and save them as PNGs.
SMOKE_DIR = "outputs/student_smoke"
os.makedirs(SMOKE_DIR, exist_ok=True)
SMOKE = [
    "a vintage bookshop storefront with a wooden sign that reads \"THE OPEN PAGE\"",
    "a serene mountain lake at sunrise reflecting snow-capped peaks, mist over the water",
]
smoke_gen = torch.Generator(device="cuda").manual_seed(0)
smoke_out = pipe(
    prompt=SMOKE,
    num_inference_steps=4,
    guidance_scale=1.0,
    height=512,
    width=512,
    generator=smoke_gen,
)
for idx, image in enumerate(smoke_out.images):
    image.save(f"{SMOKE_DIR}/student_{idx}.png")
print("saved student smoke images to outputs/student_smoke/ (pre-training, warm-start only)")

Xet Storage Details

Size:
4.16 kB
·
Xet hash:
421783289e3e642158f5402ffb74f2959226a4c24c966f30fd6ea3f2891ab26e

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.