Buckets:
| """Build the ~2B student from klein-4B: | |
| capture single-block I/O -> SVD-energy block selection -> build + lstsq-init surrogates | |
| -> short calibration fit -> save student state_dict + selection metadata. | |
| """ | |
| import json | |
| import os | |
| import time | |
| import torch | |
| from flux2distill.config import Config | |
| from flux2distill.model_utils import load_pipeline, param_summary | |
| from flux2distill.surgery import capture_single_block_io, select_blocks_svd_energy, build_student | |
| from flux2distill.calibration import fit_surrogate | |
| cfg = Config() | |
| OUT = "outputs/student" | |
| os.makedirs(OUT, exist_ok=True) | |
| DO_FIT = True | |
| CALIB_PROMPTS = [ | |
| 'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"', | |
| "exactly five ripe red apples arranged in a row on a wooden table", | |
| "a close-up portrait of an elderly fisherman with a weathered face, natural light", | |
| "a macro photograph of dew drops on a spider web at dawn", | |
| "a futuristic city skyline at night with neon lights and flying cars", | |
| "an oil painting of a bowl of fruit in the style of a dutch still life", | |
| ] | |
| print("loading teacher pipeline...") | |
| pipe = load_pipeline(device="cuda") | |
| tf = pipe.transformer | |
| print("teacher transformer params:", param_summary(tf)) | |
| print(f"\ncapturing single-block I/O on {len(CALIB_PROMPTS)} calib prompts...") | |
| t0 = time.time() | |
| io = capture_single_block_io(pipe, CALIB_PROMPTS, num_inference_steps=4, | |
| max_tokens_per_block=12000, seed=0) | |
| print(f" captured in {time.time()-t0:.1f}s; tokens/block={io[0]['X'].shape[0]}, d={io[0]['X'].shape[1]}") | |
| print(f"\nselecting blocks (rank={cfg.surgery.rank}, keep {cfg.surgery.keep_single} full)...") | |
| keep_idx, surr_idx, stats = select_blocks_svd_energy(io, cfg.surgery.rank, cfg.surgery.keep_single) | |
| print(f" KEEP full ({len(keep_idx)}): {keep_idx}") | |
| print(f" SURROGATE ({len(surr_idx)}): {surr_idx}") | |
| print(" per-block [block: captured_ratio | delta_rms]:") | |
| for s in sorted(stats, key=lambda d: d['block']): | |
| tag = "keep" if s['block'] in keep_idx else "SURR" | |
| print(f" blk{s['block']:2d} [{tag}] captured@{cfg.surgery.rank}={s['captured_ratio']:.4f} rms={s['delta_rms']:.4f}") | |
| print("\nbuilding student (lstsq warm-start)...") | |
| errs = build_student(tf, surr_idx, io, rank=cfg.surgery.rank, act=cfg.surgery.act, device="cuda") | |
| print(" lstsq reconstruction rel-err per surrogate:") | |
| for i in surr_idx: | |
| print(f" blk{i:2d}: {errs[i]:.4f}") | |
| print(f" mean lstsq rel-err: {sum(errs.values())/len(errs):.4f}") | |
| if DO_FIT: | |
| print("\ncalibration fit (closing the GELU gap)...") | |
| fit_results = {} | |
| for i in surr_idx: | |
| sur = tf.single_transformer_blocks[i] | |
| ie, fe = fit_surrogate(sur, io[i]["X"], io[i]["Delta"], steps=200, lr=1e-3) | |
| sur.to(dtype=torch.bfloat16) # back to model dtype | |
| fit_results[i] = (ie, fe) | |
| print(f" blk{i:2d}: {ie:.4f} -> {fe:.4f}") | |
| print(f" mean post-fit rel-err: {sum(f for _, f in fit_results.values())/len(fit_results):.4f}") | |
| print("\nstudent params:", param_summary(tf)) | |
| # Save student state + selection metadata. | |
| torch.save(tf.state_dict(), f"{OUT}/student_state.pt") | |
| meta = { | |
| "keep_idx": keep_idx, "surrogate_idx": surr_idx, | |
| "rank": cfg.surgery.rank, "act": cfg.surgery.act, | |
| "lstsq_rel_err": {str(k): v for k, v in errs.items()}, | |
| "stats": stats, | |
| "param_summary": param_summary(tf), | |
| } | |
| with open(f"{OUT}/selection.json", "w") as f: | |
| json.dump(meta, f, indent=2) | |
| print(f"\nsaved student_state.pt + selection.json to {OUT}/") | |
| # Smoke gen with the student (pipe.transformer is now the student). | |
| os.makedirs("outputs/student_smoke", exist_ok=True) | |
| SMOKE = [ | |
| 'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"', | |
| "a serene mountain lake at sunrise reflecting snow-capped peaks, mist over the water", | |
| ] | |
| gen = torch.Generator(device="cuda").manual_seed(0) | |
| out = pipe(prompt=SMOKE, num_inference_steps=4, guidance_scale=1.0, height=512, width=512, generator=gen) | |
| for i, im in enumerate(out.images): | |
| im.save(f"outputs/student_smoke/student_{i}.png") | |
| print("saved student smoke images to outputs/student_smoke/ (pre-training, warm-start only)") | |
Xet Storage Details
- Size:
- 4.16 kB
- Xet hash:
- 421783289e3e642158f5402ffb74f2959226a4c24c966f30fd6ea3f2891ab26e
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.