Buckets:

Pranav2748's picture
download
raw
5.24 kB
"""Basic distillation training: recover the v2 student via velocity matching to the frozen
teacher + real-data flow-matching grounding on cached monet latents. Muon(2D)+AdamW.
Dev-scale: small batch, short run, A100/SDPA, student weights in bf16 (fp32 master weights are a
noted upgrade), bf16 autocast compute.
Trajectory/feature-matching upgrades noted in plan.md; this validates the loop + shows recovery.
"""
import json
import os
import sys
import time
import torch
from torch.utils.data import DataLoader
from flux2distill.config import Config
from flux2distill.data import LatentCaptionDataset, collate
from flux2distill.losses import velocity_match_loss, flow_match_loss, build_x_t
from flux2distill.model_utils import load_pipeline, load_transformer, build_param_groups, param_summary
from flux2distill.optim import Muon
from flux2distill.surgery import attach_surrogates
# ---- run configuration ----
cfg = Config()

# Number of optimizer steps; the first CLI argument overrides the default of 300.
STEPS = 300 if len(sys.argv) <= 1 else int(sys.argv[1])

# Output root for the trained checkpoint, plus a sub-directory for preview images.
OUT = "outputs/train_v2"
for _out_dir in (OUT, f"{OUT}/samples"):
    os.makedirs(_out_dir, exist_ok=True)

MB = 2     # micro-batch size handed to the DataLoader
ACCUM = 4  # micro-batches accumulated per optimizer step
# ---- teacher (frozen) via pipeline; student v2 separate, trainable ----
print("loading teacher pipeline (frozen)...")
pipe = load_pipeline(device="cuda")
teacher = pipe.transformer
# The teacher only provides targets: eval mode, no gradients.
teacher.eval().requires_grad_(False)

print("loading v2 student (bf16; fp32-master is a noted upgrade)...")
# Use a context manager so the selection file handle is closed deterministically.
with open("outputs/student_v2/selection.json", encoding="utf-8") as _sel_file:
    sel = json.load(_sel_file)
student = load_transformer(dtype="bfloat16", device="cuda")
# The surrogate modules must be attached before the state dict is loaded,
# otherwise their saved weights would have nothing to load into.
attach_surrogates(student, sel["surrogate_idx"], rank=sel["rank"], act=sel["act"],
                  device="cuda", dtype=torch.bfloat16)
# NOTE(security): torch.load unpickles the file; only load checkpoints you trust
# (consider weights_only=True if the checkpoint contains tensors only).
state = torch.load("outputs/student_v2/student_state.pt", map_location="cuda")
# strict=False tolerates key mismatches; report them instead of hiding them so a
# partially-initialised student is noticed before any training time is spent.
_load_result = student.load_state_dict(state, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(f"warning: student state dict mismatch: "
          f"{len(_load_result.missing_keys)} missing, "
          f"{len(_load_result.unexpected_keys)} unexpected keys")
student.train().requires_grad_(True)
print("student:", param_summary(student))

# img_ids for 512 (shared across batch); the latents returned alongside are discarded.
_, img_ids = pipe.prepare_latents(1, 32, 512, 512, torch.bfloat16, "cuda",
                                  torch.Generator(device="cuda").manual_seed(0))
def velocity(tf, x_t, sigma, prompt_embeds, text_ids):
    """Run transformer ``tf`` on ``x_t`` and return its first output, cropped on dim 1.

    ``sigma`` is passed as the ``timestep`` argument and guidance is disabled
    (``guidance=None``). The positional ids for the image come from the
    module-level ``img_ids``. The result is sliced to the first ``x_t.size(1)``
    entries along dimension 1 (presumably the token axis -- TODO confirm).
    """
    outputs = tf(
        hidden_states=x_t,
        timestep=sigma,
        guidance=None,
        encoder_hidden_states=prompt_embeds,
        txt_ids=text_ids,
        img_ids=img_ids,
        return_dict=False,
    )
    prediction = outputs[0]
    keep = x_t.size(1)
    return prediction[:, :keep]
# ---- optimizers: Muon on 2D student weights, AdamW on the rest ----
# build_param_groups performs the split; ginfo is whatever summary it returns.
train_cfg = cfg.train
muon_p, adamw_p, ginfo = build_param_groups(
    student,
    train_cfg.lr_muon,
    train_cfg.lr_adamw,
)
print("param groups:", ginfo)
opt_muon = Muon(
    muon_p,
    lr=train_cfg.lr_muon,
    weight_decay=train_cfg.weight_decay,
)
opt_adamw = torch.optim.AdamW(
    adamw_p,
    lr=train_cfg.lr_adamw,
    weight_decay=train_cfg.weight_decay,
)

# ---- data: shuffled micro-batches of size MB, incomplete final batch dropped ----
ds = LatentCaptionDataset()
dl = DataLoader(
    ds,
    batch_size=MB,
    shuffle=True,
    collate_fn=collate,
    drop_last=True,
)
print(f"dataset: {len(ds)} samples; {STEPS} steps, micro-batch {MB} x accum {ACCUM} = eff {MB*ACCUM}")

# Fixed prompts rendered by sample() so previews are comparable across steps.
SAMPLE_PROMPTS = [
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a serene mountain lake at sunrise reflecting snow-capped peaks, mist over the water",
]
@torch.no_grad()
def sample(tag):
    """Render SAMPLE_PROMPTS with the student and save them as PNGs.

    The student is temporarily swapped into ``pipe`` in place of the current
    transformer and put in eval mode. Images are written to
    ``{OUT}/samples/{tag}_{i}.png``, one per prompt. A generator seeded with 0
    is created on every call.

    The original transformer and the student's train mode are restored in a
    ``finally`` block, so a failure during generation or saving cannot leave
    the pipeline pointing at the student or the student stuck in eval mode.
    """
    teacher_bak = pipe.transformer
    student.eval()
    pipe.transformer = student
    try:
        gen = torch.Generator(device="cuda").manual_seed(0)
        imgs = pipe(prompt=SAMPLE_PROMPTS, num_inference_steps=4, guidance_scale=1.0,
                    height=512, width=512, generator=gen).images
        for i, im in enumerate(imgs):
            im.save(f"{OUT}/samples/{tag}_{i}.png")
    finally:
        pipe.transformer = teacher_bak
        student.train()
def loader():
    """Yield batches from the module-level ``dl`` forever.

    Each time ``dl`` is exhausted it is iterated again from the start, so the
    training loop can draw batches by step count rather than by epoch.
    """
    while True:
        yield from dl
sample("step0000")  # baseline before training

# Optimizer steps between preview renders during training.
SAMPLE_EVERY = 100

gen_iter = loader()
# perf_counter is monotonic, so the throughput figure cannot be distorted by
# wall-clock adjustments the way a time.time() difference can.
t0 = time.perf_counter()
running = 0.0  # sum of per-step losses since the last log line
for step in range(1, STEPS + 1):
    opt_muon.zero_grad(set_to_none=True)
    opt_adamw.zero_grad(set_to_none=True)
    loss_acc = 0.0
    for _ in range(ACCUM):
        x0, caps = next(gen_iter)
        x0 = x0.to("cuda", torch.bfloat16)
        # Text conditioning is computed without gradients.
        with torch.no_grad():
            prompt_embeds, text_ids = pipe.encode_prompt(caps, device="cuda")
        eps = torch.randn_like(x0)
        # One sigma per sample, drawn uniformly from [0, 1).
        sigma = torch.rand(x0.size(0), device="cuda", dtype=torch.float32)
        # The noisy input is built in fp32 and then cast back to bf16.
        x_t = build_x_t(x0.float(), eps.float(), sigma).to(torch.bfloat16)
        # Teacher target: no gradients, bf16 autocast.
        with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
            v_t = velocity(teacher, x_t, sigma, prompt_embeds, text_ids)
        # NOTE(review): the source's indentation was lost; the loss is placed
        # inside the autocast region here -- confirm against the original file.
        with torch.autocast("cuda", dtype=torch.bfloat16):
            v_s = velocity(student, x_t, sigma, prompt_embeds, text_ids)
            # Dividing by ACCUM makes the accumulated gradient the mean over
            # the micro-batches instead of their sum.
            loss = (cfg.train.w_velocity * velocity_match_loss(v_s, v_t)
                    + cfg.train.w_flow * flow_match_loss(v_s, eps, x0)) / ACCUM
        loss.backward()
        loss_acc += loss.item()
    # One update per optimizer after all ACCUM micro-batches have contributed.
    opt_muon.step()
    opt_adamw.step()
    running += loss_acc
    if step % cfg.train.log_every == 0:
        avg = running / cfg.train.log_every
        running = 0.0
        # Average steps/second since t0; elapsed time includes periodic sampling.
        sps = step / (time.perf_counter() - t0)
        print(f" step {step:4d}/{STEPS} loss {avg:.4f} {sps:.2f} it/s")
    if step % SAMPLE_EVERY == 0:
        sample(f"step{step:04d}")

torch.save(student.state_dict(), f"{OUT}/student_trained.pt")
print(f"saved trained student to {OUT}/student_trained.pt; samples in {OUT}/samples/")
print(f"peak VRAM: {torch.cuda.max_memory_allocated()/1e9:.1f}GB")

Xet Storage Details

Size:
5.24 kB
·
Xet hash:
32fdd7f581414a8da5f7bd7ed804352d84a5a2cb083f8a71af4fced23b33ac7b

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.