Buckets:

Mercity/FluxDistill / scripts /08_train_recover.py
Pranav2748's picture
download
raw
9.08 kB
"""Careful surrogate-only recovery (supersedes 07_train.py).
Research-led recipe (see plan.md):
- FREEZE the entire pretrained network; train ONLY the 6 surrogate modules (~19M).
- fp32 master weights on the surrogates (only the trained params), bf16 autocast compute.
- AdamW @ 1e-4 (adapter/LoRA regime for diffusion), cosine decay, grad-clip 1.0.
(Muon is for the later full-recovery run; pass `muon` as the second CLI argument to A/B it on surrogates only.)
- Unbuffered file logging (run with python3 -u), per-step image logging, fixed held-out
eval-loss metric, best-checkpoint saving, divergence guard.
Usage: python3 -u scripts/08_train_recover.py [steps] [adamw|muon] [base_lr]
"""
import json
import math
import os
import sys
import time
import torch
from torch.utils.data import DataLoader
from flux2distill.config import Config
from flux2distill.data import LatentCaptionDataset, collate
from flux2distill.losses import velocity_match_loss, flow_match_loss, build_x_t
from flux2distill.model_utils import load_pipeline, load_transformer, param_summary
from flux2distill.optim import Muon
from flux2distill.surgery import attach_surrogates
cfg = Config()

# ---- run configuration: CLI is [steps] [adamw|muon] [base_lr]; the rest comes from env vars ----
_VALID_OPTS = ("adamw", "muon")
_VALID_SCHEDS = ("cosine", "constant")
_DEFAULT_LR = {"adamw": 1e-4, "muon": 2e-3}  # per-optimizer default when base_lr is not given

STEPS = int(sys.argv[1]) if len(sys.argv) > 1 else 300
OPT = sys.argv[2] if len(sys.argv) > 2 else "adamw"
if OPT not in _VALID_OPTS:
    # Anything other than "muon" used to fall through to AdamW while still
    # picking up the 2e-3 default LR, so a typo silently trained AdamW at 20x
    # the intended rate. Fail loudly instead.
    raise SystemExit(f"unknown optimizer {OPT!r}; expected one of {_VALID_OPTS}")
BASE_LR = float(sys.argv[3]) if len(sys.argv) > 3 else _DEFAULT_LR[OPT]

MB = int(os.environ.get("MB", 4))        # micro-batch size per forward/backward
ACCUM = int(os.environ.get("ACCUM", 2))  # gradient-accumulation micro-steps per optimizer step
if MB < 1 or ACCUM < 1:
    raise SystemExit(f"MB and ACCUM must be >= 1 (got MB={MB}, ACCUM={ACCUM})")

W_VEL, W_FM = 1.0, 0.25  # weights of the velocity-match and flow-match loss terms
CLIP = 1.0               # max gradient norm
MIN_LR_FRAC = 0.30  # cosine floors at 30% of base (visible decay + settles, no dead tail)

STUDENT_DIR = os.environ.get("STUDENT_DIR", "outputs/student_v2")
SCHED = os.environ.get("SCHED", "cosine")  # cosine | constant ("fixed lr")
if SCHED not in _VALID_SCHEDS:
    # Any value other than "constant" used to be treated as cosine without notice.
    raise SystemExit(f"unknown SCHED {SCHED!r}; expected one of {_VALID_SCHEDS}")

TAG = f"{os.path.basename(STUDENT_DIR)}_{OPT}_{SCHED}"
OUT = f"outputs/train_{TAG}"
os.makedirs(f"{OUT}/samples", exist_ok=True)
logf = open(f"{OUT}/train.log", "w", buffering=1)  # line-buffered; truncates any previous log for this TAG
def log(msg):
    """Prefix `msg` with the wall-clock time, print it, and append it to the run log.

    Both stdout and the log file are flushed on every call so progress stays
    visible while the run is in flight.
    """
    stamped = "[{}] {}".format(time.strftime("%H:%M:%S"), msg)
    print(stamped, flush=True)
    logf.write(f"{stamped}\n")
    logf.flush()
log(f"=== recovery run: student={STUDENT_DIR} steps={STEPS} opt={OPT} sched={SCHED} "
    f"base_lr={BASE_LR} min_lr_frac={MIN_LR_FRAC} mb={MB} accum={ACCUM} clip={CLIP} "
    f"w_vel={W_VEL} w_fm={W_FM} ===")

# ---- teacher (frozen) + student v2 ----
log("loading teacher pipeline (frozen)...")
pipe = load_pipeline(device="cuda")
teacher = pipe.transformer
teacher.eval().requires_grad_(False)

log(f"loading student from {STUDENT_DIR}...")
# Context manager so the selection file handle is closed deterministically.
with open(f"{STUDENT_DIR}/selection.json", encoding="utf-8") as sel_file:
    sel = json.load(sel_file)
student = load_transformer(dtype="bfloat16", device="cuda")
attach_surrogates(student, sel["surrogate_idx"], kind=sel.get("kind", "lowrank"),
                  rank=sel.get("rank", 512), act=sel.get("act", "gelu"),
                  heads=sel.get("heads", 4), head_dim=sel.get("head_dim", 128),
                  conv_kernel=sel.get("conv_kernel", 5), ffn_hidden=sel.get("ffn_hidden", 1024),
                  ffn_idx=sel.get("ffn_idx", None),
                  device="cuda", dtype=torch.bfloat16)
student_state = torch.load(f"{STUDENT_DIR}/student_state.pt", map_location="cuda")
load_report = student.load_state_dict(student_state, strict=False)
del student_state  # drop the extra copy of the checkpoint tensors
# strict=False hides key mismatches; surface them so a failed warm start
# (e.g. surrogate weights not found in the checkpoint) is visible in the log.
if load_report.missing_keys or load_report.unexpected_keys:
    log(f"WARNING: non-strict state load: {len(load_report.missing_keys)} missing / "
        f"{len(load_report.unexpected_keys)} unexpected keys; "
        f"missing[:5]={load_report.missing_keys[:5]} "
        f"unexpected[:5]={load_report.unexpected_keys[:5]}")
student.eval()

# FREEZE everything; train ONLY surrogate modules (cast those to fp32 = master weights).
student.requires_grad_(False)
trainable = []
for i in sel["surrogate_idx"]:
    blk = student.single_transformer_blocks[i].float()  # fp32 master for trained params
    blk.requires_grad_(True)
    trainable += list(blk.parameters())

n_train = sum(p.numel() for p in trainable)
n_total = sum(p.numel() for p in student.parameters())
log(f"trainable: {n_train/1e6:.2f}M / {n_total/1e9:.3f}B ({100*n_train/n_total:.2f}%) "
    f"surrogates={sel['surrogate_idx']}")
# Explicit raise rather than `assert`, so the check survives `python -O`.
if not n_train < 0.05 * n_total:
    raise RuntimeError("expected only surrogates trainable (<5% of model)")

# Image position ids, created once here and reused by every velocity() call.
_, img_ids = pipe.prepare_latents(1, 32, 512, 512, torch.bfloat16, "cuda",
                                  torch.Generator(device="cuda").manual_seed(0))
def velocity(tf, x_t, sigma, prompt_embeds, text_ids):
    """Run transformer `tf` on the noised latents `x_t` and return its prediction.

    `sigma` is passed as the timestep, guidance is disabled (`guidance=None`),
    and the module-level `img_ids` are used as image position ids. The first
    element of the model's output tuple is cropped along dim 1 to the length
    of `x_t` (presumably dropping any non-image positions; confirm against
    the transformer's forward).
    """
    outputs = tf(
        hidden_states=x_t,
        timestep=sigma,
        guidance=None,
        encoder_hidden_states=prompt_embeds,
        txt_ids=text_ids,
        img_ids=img_ids,
        return_dict=False,
    )
    prediction = outputs[0]
    seq_len = x_t.size(1)
    return prediction[:, :seq_len]
# ---- optimizer over the surrogate parameters only (LR is set per step by the training loop) ----
_weight_decay = cfg.train.weight_decay
opt = (
    Muon(trainable, lr=BASE_LR, momentum=0.95, weight_decay=_weight_decay)
    if OPT == "muon"
    else torch.optim.AdamW(trainable, lr=BASE_LR, betas=(0.9, 0.999), weight_decay=_weight_decay)
)
def lr_at(step):
    """Return the learning rate to use at optimizer step `step` (no warmup).

    With SCHED == "constant" the rate is BASE_LR throughout. Otherwise it
    follows a half-cosine from BASE_LR down to a floor of
    MIN_LR_FRAC * BASE_LR, reaching the floor at step == STEPS.
    """
    if SCHED != "constant":
        floor = MIN_LR_FRAC * BASE_LR
        phase = math.pi * step / STEPS
        return floor + (BASE_LR - floor) * 0.5 * (1 + math.cos(phase))
    return BASE_LR
# ---- data: the first EVAL_N samples are reserved for eval, the rest feed training ----
EVAL_N = 16
ds = LatentCaptionDataset()
n_train_samples = len(ds) - EVAL_N
if n_train_samples < MB:
    # With drop_last=True a loader holding fewer than MB samples yields nothing,
    # and loader() below would then spin forever without producing a batch.
    raise RuntimeError(
        f"dataset too small: {len(ds)} samples, need at least EVAL_N + MB = {EVAL_N + MB}"
    )
# The eval batch below is ds.latents[:EVAL_N]; training previously drew from all of
# `ds`, so the "held-out" metric was measured on samples the model was trained on.
# NOTE(review): assumes ds[i] corresponds to ds.latents[i] / ds.captions[i]; confirm
# against LatentCaptionDataset.__getitem__.
train_ds = torch.utils.data.Subset(ds, range(EVAL_N, len(ds)))
dl = DataLoader(train_ds, batch_size=MB, shuffle=True, collate_fn=collate, drop_last=True)


def loader():
    """Yield training batches forever, starting a new pass over `dl` whenever one ends."""
    while True:
        yield from dl


gen_iter = loader()
log(f"dataset: {len(ds)} samples; eff batch {MB*ACCUM}")
log(f"  held out for eval: first {EVAL_N}; training on remaining {n_train_samples}")

# ---- fixed held-out eval batch (objective, apples-to-apples metric) ----
ev_x0 = ds.latents[:EVAL_N].to("cuda", torch.bfloat16)
ev_caps = ds.captions[:EVAL_N]
with torch.no_grad():
    ev_pe, ev_tid = pipe.encode_prompt(ev_caps, device="cuda")
# Seeded generator: noise and noise levels are identical at every eval call.
ev_gen = torch.Generator(device="cuda").manual_seed(1234)
ev_eps = torch.randn(ev_x0.shape, generator=ev_gen, device="cuda", dtype=torch.float32)
ev_sigma = torch.rand(EVAL_N, generator=ev_gen, device="cuda", dtype=torch.float32)
ev_xt = build_x_t(ev_x0.float(), ev_eps, ev_sigma).to(torch.bfloat16)
@torch.no_grad()
def eval_loss():
    """Return the velocity-match loss between student and teacher on the fixed eval batch.

    Puts the student in eval mode, runs both models on the precomputed
    (ev_xt, ev_sigma, ev_pe, ev_tid) inputs under bf16 autocast, and returns
    the loss as a Python float.
    """
    student.eval()
    eval_inputs = (ev_xt, ev_sigma, ev_pe, ev_tid)
    with torch.autocast("cuda", dtype=torch.bfloat16):
        teacher_v = velocity(teacher, *eval_inputs)
        student_v = velocity(student, *eval_inputs)
    # NOTE(review): loss is taken outside the autocast context here; the
    # original indentation was not recoverable, so confirm.
    return float(velocity_match_loss(student_v, teacher_v))
# Fixed prompts rendered at every sampling checkpoint.
SAMPLE_PROMPTS = [
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a serene mountain lake at sunrise reflecting snow-capped peaks, mist over the water",
]


@torch.no_grad()
def sample(tag):
    """Render SAMPLE_PROMPTS with the student and save them as `{OUT}/samples/{tag}_{i}.png`.

    The student is temporarily swapped in as `pipe.transformer`; the previous
    transformer is restored even if generation or saving raises, so a failed
    sampling pass cannot leave the pipeline pointing at the student. The
    student is left in eval mode on return.
    """
    bak = pipe.transformer
    student.eval()
    pipe.transformer = student
    try:
        g = torch.Generator(device="cuda").manual_seed(0)  # same seed every call
        # fp32 master surrogates -> run under autocast so their matmuls execute in bf16.
        with torch.autocast("cuda", dtype=torch.bfloat16):
            imgs = pipe(prompt=SAMPLE_PROMPTS, num_inference_steps=4, guidance_scale=1.0,
                        height=512, width=512, generator=g).images
        for i, im in enumerate(imgs):
            im.save(f"{OUT}/samples/{tag}_{i}.png")
    finally:
        pipe.transformer = bak
# ---- baseline, training loop, final checkpoint ----
LOG_EVERY = 10   # optimizer steps per running-average train-loss line
EVAL_EVERY = 50  # optimizer steps between eval / sample / best-checkpoint passes

try:
    # ---- baseline before any training ----
    ev0 = eval_loss()
    log(f"step 0 eval_vel_loss={ev0:.4f} (baseline, surrogates=warm-start)")
    sample("step0000")

    best = float("inf")
    t0 = time.time()
    run = run_v = run_f = 0.0  # sums over the current logging window
    for step in range(1, STEPS + 1):
        lr = lr_at(step)
        for group in opt.param_groups:
            group["lr"] = lr
        student.train()
        opt.zero_grad(set_to_none=True)

        acc = acc_v = acc_f = 0.0  # per-step totals, already divided by ACCUM
        for _ in range(ACCUM):
            x0, caps = next(gen_iter)
            x0 = x0.to("cuda", torch.bfloat16)
            with torch.no_grad():
                pe, tid = pipe.encode_prompt(caps, device="cuda")
            eps = torch.randn_like(x0)
            sigma = torch.rand(x0.size(0), device="cuda", dtype=torch.float32)
            x_t = build_x_t(x0.float(), eps.float(), sigma).to(torch.bfloat16)
            with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
                vt = velocity(teacher, x_t, sigma, pe, tid)
            with torch.autocast("cuda", dtype=torch.bfloat16):
                vs = velocity(student, x_t, sigma, pe, tid)
            # NOTE(review): losses are computed outside the autocast context here;
            # the original indentation was not recoverable, so confirm.
            lv = velocity_match_loss(vs, vt)
            lf = flow_match_loss(vs, eps, x0)
            loss = (W_VEL * lv + W_FM * lf) / ACCUM
            loss.backward()
            acc += loss.item()
            acc_v += lv.item() / ACCUM
            acc_f += lf.item() / ACCUM

        gnorm = torch.nn.utils.clip_grad_norm_(trainable, CLIP).item()
        if math.isfinite(gnorm):
            opt.step()
        else:
            # A NaN/inf norm makes the clipped gradients non-finite; stepping
            # would write them into the weights and the optimizer state.
            log(f"step {step:4d} non-finite grad norm ({gnorm}); skipping optimizer update")
        run += acc
        run_v += acc_v
        run_f += acc_f

        if step % LOG_EVERY == 0:
            n = LOG_EVERY
            log(f"step {step:4d} loss={run/n:.4f} (vel={run_v/n:.4f} fm={run_f/n:.4f}) "
                f"lr={lr:.2e} gnorm={gnorm:.3f} {step/(time.time()-t0):.2f}it/s")
            run = run_v = run_f = 0.0

        if step % EVAL_EVERY == 0 or step == STEPS:
            ev = eval_loss()
            improved = " *BEST*" if ev < best else ""
            log(f" eval_vel_loss={ev:.4f} (baseline {ev0:.4f}){improved}")
            sample(f"step{step:04d}")
            if ev < best:
                best = ev
                torch.save(student.state_dict(), f"{OUT}/student_best.pt")
            # Divergence guard. NaN compares False against everything, so the
            # threshold test alone never fires on a NaN loss; check finiteness first.
            if not math.isfinite(ev):
                log(f"DIVERGENCE: eval_loss is non-finite ({ev}); stopping early.")
                break
            if ev > 3 * ev0 and step > 50:
                log(f"DIVERGENCE: eval_loss {ev:.3f} > 3x baseline {ev0:.3f}; stopping early.")
                break

    torch.save(student.state_dict(), f"{OUT}/student_final.pt")
    log(f"DONE best_eval_vel_loss={best:.4f} baseline={ev0:.4f} "
        f"improvement={100*(ev0-best)/ev0:.1f}% peakVRAM={torch.cuda.max_memory_allocated()/1e9:.1f}GB")
    log("TRAIN_DONE")
finally:
    # Close the log file even when the run aborts with an exception.
    logf.close()

Xet Storage Details

Size:
9.08 kB
·
Xet hash:
655a28eb9c01f0515611fb07f56039273d6cbcf62f17ea81583086a4f2aabcab

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.