Buckets:
| """Careful surrogate-only recovery (supersedes 07_train.py). | |
| Research-led recipe (see plan.md): | |
| - FREEZE the entire pretrained network; train ONLY the 6 surrogate modules (~19M). | |
| - fp32 master weights on the surrogates (only the trained params), bf16 autocast compute. | |
| - AdamW @ 1e-4 (adapter/LoRA regime for diffusion), cosine decay, grad-clip 1.0. | |
  (Muon is for the later full-recovery run; pass `muon` as the 2nd positional CLI argument to A/B on surrogates only.)
- Line-buffered file logging plus flushed stdout (run with python3 -u), sample images every 50 steps,
  fixed held-out eval-loss metric, best-checkpoint saving, divergence guard.
Usage: python3 -u scripts/08_train_recover.py [steps] [adamw|muon] [base_lr]
       (env: MB, ACCUM, STUDENT_DIR, SCHED=cosine|constant)
| """ | |
| import json | |
| import math | |
| import os | |
| import sys | |
| import time | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from flux2distill.config import Config | |
| from flux2distill.data import LatentCaptionDataset, collate | |
| from flux2distill.losses import velocity_match_loss, flow_match_loss, build_x_t | |
| from flux2distill.model_utils import load_pipeline, load_transformer, param_summary | |
| from flux2distill.optim import Muon | |
| from flux2distill.surgery import attach_surrogates | |
cfg = Config()

# ---- CLI / environment configuration ----
# CLI: [steps] [adamw|muon] [base_lr]; env: MB, ACCUM, STUDENT_DIR, SCHED.
_VALID_OPTS = ("adamw", "muon")
_VALID_SCHEDS = ("cosine", "constant")

STEPS = int(sys.argv[1]) if len(sys.argv) > 1 else 300
OPT = sys.argv[2] if len(sys.argv) > 2 else "adamw"
if OPT not in _VALID_OPTS:
    # Previously any other string (e.g. "AdamW" or "opt=muon") silently ran AdamW but with the
    # Muon default LR of 2e-3, i.e. 20x the intended AdamW LR.
    sys.exit(f"unknown optimizer {OPT!r}; expected one of {_VALID_OPTS}")
# Per-optimizer default LR: 1e-4 for AdamW (adapter regime), 2e-3 for Muon.
BASE_LR = float(sys.argv[3]) if len(sys.argv) > 3 else (1e-4 if OPT == "adamw" else 2e-3)

MB = int(os.environ.get("MB", 4))        # micro-batch size (DataLoader batch_size)
ACCUM = int(os.environ.get("ACCUM", 2))  # micro-batches accumulated per optimizer step
if MB < 1 or ACCUM < 1:
    sys.exit(f"MB and ACCUM must be >= 1 (got MB={MB}, ACCUM={ACCUM})")
W_VEL, W_FM = 1.0, 0.25  # weights of the velocity-match and flow-match loss terms
CLIP = 1.0               # max global grad norm
MIN_LR_FRAC = 0.30  # cosine floors at 30% of base (visible decay + settles, no dead tail)

STUDENT_DIR = os.environ.get("STUDENT_DIR", "outputs/student_v2")
SCHED = os.environ.get("SCHED", "cosine")  # cosine | constant ("fixed lr")
if SCHED not in _VALID_SCHEDS:
    # lr_at() treats anything except "constant" as cosine; reject typos rather than mislabel the run.
    sys.exit(f"unknown SCHED {SCHED!r}; expected one of {_VALID_SCHEDS}")

TAG = f"{os.path.basename(STUDENT_DIR)}_{OPT}_{SCHED}"
OUT = f"outputs/train_{TAG}"
os.makedirs(f"{OUT}/samples", exist_ok=True)
logf = open(f"{OUT}/train.log", "w", buffering=1)  # line-buffered; closed at the end of the script
def log(msg):
    """Prefix *msg* with an HH:MM:SS timestamp, print it, and append it to train.log."""
    stamped = "[{}] {}".format(time.strftime("%H:%M:%S"), msg)
    print(stamped, flush=True)
    logf.write(f"{stamped}\n")
    logf.flush()  # force the line to disk immediately so a crash loses nothing
# One-line summary of every run setting, written first so each log is self-describing.
_run_fields = (
    f"student={STUDENT_DIR}", f"steps={STEPS}", f"opt={OPT}", f"sched={SCHED}",
    f"base_lr={BASE_LR}", f"min_lr_frac={MIN_LR_FRAC}", f"mb={MB}", f"accum={ACCUM}",
    f"clip={CLIP}", f"w_vel={W_VEL}", f"w_fm={W_FM}",
)
log("=== recovery run: " + " ".join(_run_fields) + " ===")
# ---- teacher (frozen) + student v2 ----
log("loading teacher pipeline (frozen)...")
pipe = load_pipeline(device="cuda")
teacher = pipe.transformer
teacher.eval().requires_grad_(False)

log(f"loading student from {STUDENT_DIR}...")
with open(f"{STUDENT_DIR}/selection.json") as fh:  # close the handle instead of leaking it
    sel = json.load(fh)
student = load_transformer(dtype="bfloat16", device="cuda")
attach_surrogates(student, sel["surrogate_idx"], kind=sel.get("kind", "lowrank"),
                  rank=sel.get("rank", 512), act=sel.get("act", "gelu"),
                  heads=sel.get("heads", 4), head_dim=sel.get("head_dim", 128),
                  conv_kernel=sel.get("conv_kernel", 5), ffn_hidden=sel.get("ffn_hidden", 1024),
                  ffn_idx=sel.get("ffn_idx", None),
                  device="cuda", dtype=torch.bfloat16)

# strict=False tolerates a partial checkpoint. It would also silently leave a surrogate at
# whatever attach_surrogates initialised it to if its weights were absent, which would make
# the "warm-start" baseline below meaningless. So report the mismatch and fail on missing
# surrogate weights.
load_res = student.load_state_dict(
    torch.load(f"{STUDENT_DIR}/student_state.pt", map_location="cuda"), strict=False)
surrogate_prefixes = tuple(f"single_transformer_blocks.{i}." for i in sel["surrogate_idx"])
missing_surr = [k for k in load_res.missing_keys if k.startswith(surrogate_prefixes)]
log(f"student_state.pt: {len(load_res.missing_keys)} missing / "
    f"{len(load_res.unexpected_keys)} unexpected keys (strict=False)")
if missing_surr:
    raise RuntimeError(f"student_state.pt lacks {len(missing_surr)} surrogate weights "
                       f"(e.g. {missing_surr[:5]}); refusing to train an un-warm-started surrogate")
student.eval()

# FREEZE everything; train ONLY surrogate modules (cast those to fp32 = master weights).
student.requires_grad_(False)
trainable = []
for i in sel["surrogate_idx"]:
    blk = student.single_transformer_blocks[i].float()  # in-place cast: fp32 master for trained params
    blk.requires_grad_(True)
    trainable.extend(blk.parameters())
n_train = sum(p.numel() for p in trainable)
n_total = sum(p.numel() for p in student.parameters())
log(f"trainable: {n_train/1e6:.2f}M / {n_total/1e9:.3f}B ({100*n_train/n_total:.2f}%) "
    f"surrogates={sel['surrogate_idx']}")
# Explicit check rather than `assert`, which is stripped under `python -O`.
if not 0 < n_train < 0.05 * n_total:
    raise RuntimeError(f"expected only surrogates trainable (0 < n < 5% of model), got {n_train}")

# Image-token position ids, computed once and reused by velocity() for every forward.
# 512x512 matches sample(); NOTE(review): positional arg order of prepare_latents assumed
# to be (batch, channels, H, W, dtype, device, generator) -- confirm against the pipeline.
_, img_ids = pipe.prepare_latents(1, 32, 512, 512, torch.bfloat16, "cuda",
                                  torch.Generator(device="cuda").manual_seed(0))
def velocity(tf, x_t, sigma, prompt_embeds, text_ids):
    """Query transformer *tf* at noise level *sigma* and keep only the leading image tokens.

    The first element of the (return_dict=False) output is sliced along dim 1 to
    ``x_t.size(1)`` entries. Position ids for the image tokens come from the module-level
    ``img_ids``.
    """
    n_img_tokens = x_t.size(1)
    outputs = tf(
        hidden_states=x_t,
        timestep=sigma,
        guidance=None,
        encoder_hidden_states=prompt_embeds,
        txt_ids=text_ids,
        img_ids=img_ids,
        return_dict=False,
    )
    prediction = outputs[0]
    return prediction[:, :n_img_tokens]
# ---- optimizer over the surrogate params only (its lr is overwritten each step from lr_at) ----
weight_decay = cfg.train.weight_decay
if OPT != "muon":
    opt = torch.optim.AdamW(trainable, lr=BASE_LR, betas=(0.9, 0.999), weight_decay=weight_decay)
else:
    opt = Muon(trainable, lr=BASE_LR, momentum=0.95, weight_decay=weight_decay)
def lr_at(step, *, base_lr=None, steps=None, min_frac=None, sched=None):
    """Return the learning rate for optimizer step *step* (the training loop passes 1..STEPS).

    sched == "constant": always ``base_lr`` ("fixed lr").
    Otherwise: half-cosine decay, no warmup. It runs from ``base_lr`` at step 0 down to a
    floor of ``min_frac * base_lr`` at step ``steps``, and stays at the floor for any later step.

    The keyword-only overrides default to the module settings BASE_LR, STEPS, MIN_LR_FRAC and
    SCHED, so ``lr_at(step)`` behaves as before while the schedule can also be evaluated
    standalone (e.g. for plotting or tests).
    """
    if sched is None:
        sched = SCHED
    if base_lr is None:
        base_lr = BASE_LR
    if sched == "constant":
        return base_lr
    if steps is None:
        steps = STEPS
    if min_frac is None:
        min_frac = MIN_LR_FRAC
    min_lr = min_frac * base_lr
    # min(step, steps): don't let the cosine climb back up past the end.
    # max(steps, 1): avoid division by zero when steps == 0.
    angle = math.pi * min(step, steps) / max(steps, 1)
    return min_lr + (base_lr - min_lr) * 0.5 * (1 + math.cos(angle))
# ---- data: fixed held-out eval batch + infinite training stream over the remaining samples ----
EVAL_N = 16
ds = LatentCaptionDataset()
if len(ds) <= EVAL_N:
    raise RuntimeError(f"dataset has {len(ds)} samples; need more than EVAL_N={EVAL_N} "
                       f"to hold out an eval batch and still train")
# Train only on samples EVAL_N.. so the eval metric below is genuinely held out. Previously
# the first EVAL_N samples were also drawn as training data.
# NOTE(review): assumes ds[i] corresponds to (ds.latents[i], ds.captions[i]) -- confirm in
# LatentCaptionDataset.
train_ds = torch.utils.data.Subset(ds, range(EVAL_N, len(ds)))
dl = DataLoader(train_ds, batch_size=MB, shuffle=True, collate_fn=collate, drop_last=True)
if len(dl) == 0:
    # With drop_last=True and fewer samples than MB, dl yields nothing and loader() would
    # busy-loop forever instead of failing.
    raise RuntimeError(f"training split has {len(train_ds)} samples, fewer than MB={MB}")


def loader():
    """Yield training batches forever, re-iterating (and so reshuffling) `dl` on each pass."""
    while True:
        yield from dl


gen_iter = loader()
log(f"dataset: {len(ds)} samples ({len(train_ds)} train / {EVAL_N} eval); eff batch {MB*ACCUM}")

# ---- fixed held-out eval batch (objective, apples-to-apples metric) ----
ev_x0 = ds.latents[:EVAL_N].to("cuda", torch.bfloat16)
ev_caps = ds.captions[:EVAL_N]
with torch.no_grad():
    ev_pe, ev_tid = pipe.encode_prompt(ev_caps, device="cuda")
# Seeded noise and sigmas so every eval point sees exactly the same inputs.
ev_gen = torch.Generator(device="cuda").manual_seed(1234)
ev_eps = torch.randn(ev_x0.shape, generator=ev_gen, device="cuda", dtype=torch.float32)
ev_sigma = torch.rand(EVAL_N, generator=ev_gen, device="cuda", dtype=torch.float32)
ev_xt = build_x_t(ev_x0.float(), ev_eps, ev_sigma).to(torch.bfloat16)
@torch.no_grad()
def eval_loss():
    """Return the student-vs-teacher velocity-match loss on the fixed eval batch, as a float.

    Runs under no_grad. The surrogate parameters have requires_grad=True, so without it this
    forward would record an autograd graph (extra activation memory) that is never used.
    Leaves the student in eval mode; the training loop calls student.train() every step.
    """
    student.eval()
    with torch.autocast("cuda", dtype=torch.bfloat16):
        vt = velocity(teacher, ev_xt, ev_sigma, ev_pe, ev_tid)
        vs = velocity(student, ev_xt, ev_sigma, ev_pe, ev_tid)
    return float(velocity_match_loss(vs, vt))
SAMPLE_PROMPTS = [
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a serene mountain lake at sunrise reflecting snow-capped peaks, mist over the water",
]


def sample(tag):
    """Render SAMPLE_PROMPTS with the student and save them as {OUT}/samples/{tag}_{i}.png.

    Uses 4 inference steps, guidance_scale=1.0, 512x512, and a fixed seed, so images are
    comparable across calls. The student is swapped into the pipeline only temporarily: the
    previous transformer (the teacher) is restored in `finally`, so a failure while sampling
    cannot leave the pipeline pointing at the student.
    """
    bak = pipe.transformer
    student.eval()
    pipe.transformer = student
    try:
        g = torch.Generator(device="cuda").manual_seed(0)
        # fp32 master surrogates -> run under autocast so their matmuls execute in bf16.
        # no_grad: the surrogates require grad; don't record a graph while sampling.
        with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
            imgs = pipe(prompt=SAMPLE_PROMPTS, num_inference_steps=4, guidance_scale=1.0,
                        height=512, width=512, generator=g).images
        for i, im in enumerate(imgs):
            im.save(f"{OUT}/samples/{tag}_{i}.png")
    finally:
        pipe.transformer = bak
# ---- baseline before any training ----
ev0 = eval_loss()
log(f"step 0 eval_vel_loss={ev0:.4f} (baseline, surrogates=warm-start)")
sample("step0000")
best = float("inf")
LOG_EVERY = 10   # steps between training-loss log lines
EVAL_EVERY = 50  # steps between eval / sample / best-checkpoint passes
t0 = time.time()
run = run_v = run_f = 0.0  # per-step losses summed since the last log line
for step in range(1, STEPS + 1):
    lr = lr_at(step)
    for group in opt.param_groups:
        group["lr"] = lr
    student.train()
    opt.zero_grad(set_to_none=True)
    acc = acc_v = acc_f = 0.0  # this step's losses, averaged over the ACCUM micro-batches
    for _ in range(ACCUM):
        x0, caps = next(gen_iter)
        x0 = x0.to("cuda", torch.bfloat16)
        with torch.no_grad():
            pe, tid = pipe.encode_prompt(caps, device="cuda")
        eps = torch.randn_like(x0)
        sigma = torch.rand(x0.size(0), device="cuda", dtype=torch.float32)
        x_t = build_x_t(x0.float(), eps.float(), sigma).to(torch.bfloat16)
        with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
            vt = velocity(teacher, x_t, sigma, pe, tid)  # frozen teacher target
        with torch.autocast("cuda", dtype=torch.bfloat16):
            vs = velocity(student, x_t, sigma, pe, tid)
        lv = velocity_match_loss(vs, vt)
        lf = flow_match_loss(vs, eps, x0)
        loss = (W_VEL * lv + W_FM * lf) / ACCUM  # /ACCUM: gradients sum over micro-batches
        loss.backward()
        acc += loss.item()
        acc_v += lv.item() / ACCUM
        acc_f += lf.item() / ACCUM
    gnorm = torch.nn.utils.clip_grad_norm_(trainable, CLIP).item()
    if math.isfinite(gnorm):
        opt.step()
    else:
        # A NaN/Inf gradient would permanently corrupt the fp32 master weights and optimizer state.
        log(f"step {step:4d} non-finite grad norm ({gnorm}); skipping optimizer step")
    run += acc
    run_v += acc_v
    run_f += acc_f
    if step % LOG_EVERY == 0:
        n = LOG_EVERY
        # it/s is wall-clock since t0, so it includes eval and sampling time.
        log(f"step {step:4d} loss={run/n:.4f} (vel={run_v/n:.4f} fm={run_f/n:.4f}) "
            f"lr={lr:.2e} gnorm={gnorm:.3f} {step/(time.time()-t0):.2f}it/s")
        run = run_v = run_f = 0.0
    if step % EVAL_EVERY == 0 or step == STEPS:
        ev = eval_loss()
        improved = " *BEST*" if ev < best else ""
        log(f" eval_vel_loss={ev:.4f} (baseline {ev0:.4f}){improved}")
        sample(f"step{step:04d}")
        if ev < best:
            best = ev
            torch.save(student.state_dict(), f"{OUT}/student_best.pt")
        # Divergence guard. NaN compares False with everything, so the ratio test alone would
        # never fire on NaN; check finiteness explicitly.
        if not math.isfinite(ev):
            log(f"DIVERGENCE: non-finite eval_loss {ev}; stopping early.")
            break
        if ev > 3 * ev0 and step > EVAL_EVERY:
            log(f"DIVERGENCE: eval_loss {ev:.3f} > 3x baseline {ev0:.3f}; stopping early.")
            break
torch.save(student.state_dict(), f"{OUT}/student_final.pt")
improvement = 100 * (ev0 - best) / ev0 if ev0 else float("nan")  # guard a zero baseline
log(f"DONE best_eval_vel_loss={best:.4f} baseline={ev0:.4f} "
    f"improvement={improvement:.1f}% peakVRAM={torch.cuda.max_memory_allocated()/1e9:.1f}GB")
log("TRAIN_DONE")
logf.close()
Xet Storage Details
- Size:
- 9.08 kB
- Xet hash:
- 655a28eb9c01f0515611fb07f56039273d6cbcf62f17ea81583086a4f2aabcab
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.