Buckets:
| """Basic distillation training: recover the v2 student via velocity matching to the frozen | |
| teacher + real-data flow-matching grounding on cached monet latents. Muon(2D)+AdamW. | |
| Dev-scale: small batch, short run, A100/SDPA; student weights held in bf16 (an fp32 master copy is a planned upgrade), bf16 autocast compute. | |
| Trajectory/feature-matching upgrades noted in plan.md; this validates the loop + shows recovery. | |
| """ | |
| import json | |
| import os | |
| import sys | |
| import time | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from flux2distill.config import Config | |
| from flux2distill.data import LatentCaptionDataset, collate | |
| from flux2distill.losses import velocity_match_loss, flow_match_loss, build_x_t | |
| from flux2distill.model_utils import load_pipeline, load_transformer, build_param_groups, param_summary | |
| from flux2distill.optim import Muon | |
| from flux2distill.surgery import attach_surrogates | |
cfg = Config()
# Optional first CLI argument: number of optimizer steps (defaults to 300).
STEPS = 300 if len(sys.argv) <= 1 else int(sys.argv[1])
OUT = "outputs/train_v2"
# Output root first, then the per-step sample image directory.
for _out_dir in (OUT, f"{OUT}/samples"):
    os.makedirs(_out_dir, exist_ok=True)
# Micro-batch size and gradient-accumulation count (effective batch = MB * ACCUM).
MB, ACCUM = 2, 4
# ---- teacher (frozen) via pipeline; student v2 separate, trainable ----
print("loading teacher pipeline (frozen)...")
pipe = load_pipeline(device="cuda")
teacher = pipe.transformer
teacher.eval().requires_grad_(False)

print("loading v2 student (bf16; fp32-master is a noted upgrade)...")
# Surrogate selection produced by the student_v2 stage; a context manager
# closes the file handle promptly instead of leaving it to the GC.
with open("outputs/student_v2/selection.json") as _sel_file:
    sel = json.load(_sel_file)
student = load_transformer(dtype="bfloat16", device="cuda")
attach_surrogates(student, sel["surrogate_idx"], rank=sel["rank"], act=sel["act"],
                  device="cuda", dtype=torch.bfloat16)
state = torch.load("outputs/student_v2/student_state.pt", map_location="cuda")
_load_result = student.load_state_dict(state, strict=False)
# strict=False tolerates key mismatches; report them rather than hiding them.
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(f"student state: {len(_load_result.missing_keys)} missing, "
          f"{len(_load_result.unexpected_keys)} unexpected keys")
# load_state_dict copies into the module's own parameters, so `state` is a
# redundant second copy of the weights on the GPU. Module-level names live for
# the whole run, so release it explicitly.
del state
student.train().requires_grad_(True)
print("student:", param_summary(student))

# img_ids for 512x512 (shared across the batch; velocity() reads this global).
# NOTE(review): assumes prepare_latents(batch, channels, h, w, dtype, device,
# generator) returns (latents, img_ids) as for the Flux2 pipeline — confirm.
_, img_ids = pipe.prepare_latents(1, 32, 512, 512, torch.bfloat16, "cuda",
                                  torch.Generator(device="cuda").manual_seed(0))
def velocity(tf, x_t, sigma, prompt_embeds, text_ids):
    """Run transformer ``tf`` and return its prediction for the input tokens only.

    ``sigma`` is passed straight through as the ``timestep`` argument, guidance
    is disabled, and the module-level ``img_ids`` are used as image position ids.
    The first element of the model output is truncated along dim 1 to the
    first ``x_t.size(1)`` tokens.
    """
    call_kwargs = dict(
        hidden_states=x_t,
        timestep=sigma,
        guidance=None,
        encoder_hidden_states=prompt_embeds,
        txt_ids=text_ids,
        img_ids=img_ids,
        return_dict=False,
    )
    prediction = tf(**call_kwargs)[0]
    n_tokens = x_t.size(1)
    return prediction[:, :n_tokens]
# ---- optimizers: Muon on the 2D student weight matrices, AdamW on everything else ----
muon_p, adamw_p, ginfo = build_param_groups(
    student,
    cfg.train.lr_muon,
    cfg.train.lr_adamw,
)
print("param groups:", ginfo)
opt_muon = Muon(
    muon_p,
    lr=cfg.train.lr_muon,
    weight_decay=cfg.train.weight_decay,
)
opt_adamw = torch.optim.AdamW(
    adamw_p,
    lr=cfg.train.lr_adamw,
    weight_decay=cfg.train.weight_decay,
)

# ---- data: shuffled micro-batches; a trailing partial batch is dropped ----
ds = LatentCaptionDataset()
dl = DataLoader(
    ds,
    batch_size=MB,
    shuffle=True,
    collate_fn=collate,
    drop_last=True,
)
print(
    f"dataset: {len(ds)} samples; {STEPS} steps, "
    f"micro-batch {MB} x accum {ACCUM} = eff {MB*ACCUM}"
)
# Fixed prompts rendered at each sampling checkpoint: one text-rendering case
# and one landscape case.
SAMPLE_PROMPTS = [
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    ("a serene mountain lake at sunrise reflecting snow-capped peaks, "
     "mist over the water"),
]
def sample(tag):
    """Render SAMPLE_PROMPTS with the student and save them as PNGs.

    Temporarily installs the student as the shared pipeline's transformer
    (replacing the teacher). It then generates 512x512 images with 4
    inference steps, guidance_scale=1.0 and a fixed seed (0), so samples
    are comparable across checkpoints.

    The pipeline's original transformer and the student's train/eval mode
    are restored in a ``finally`` block, so an error during generation or
    saving cannot leave the pipeline pointing at the student.

    Args:
        tag: filename prefix; image i is written to f"{OUT}/samples/{tag}_{i}.png".
    """
    teacher_bak = pipe.transformer
    was_training = student.training
    student.eval()
    pipe.transformer = student
    try:
        gen = torch.Generator(device="cuda").manual_seed(0)
        imgs = pipe(prompt=SAMPLE_PROMPTS, num_inference_steps=4, guidance_scale=1.0,
                    height=512, width=512, generator=gen).images
        for i, im in enumerate(imgs):
            im.save(f"{OUT}/samples/{tag}_{i}.png")
    finally:
        pipe.transformer = teacher_bak
        student.train(was_training)
def loader(data=None):
    """Yield batches forever by re-iterating ``data`` (the global ``dl`` by default).

    Each pass over ``data`` is one epoch. ``data`` must therefore be
    re-iterable (a DataLoader, list, ...), not a one-shot iterator.

    Raises:
        RuntimeError: if a full pass over ``data`` yields nothing. For example,
            a dataset smaller than the micro-batch with ``drop_last=True``.
            Without this check the loop would spin forever without yielding.
    """
    source = dl if data is None else data
    while True:
        produced = False
        for batch in source:
            produced = True
            yield batch
        if not produced:
            raise RuntimeError(
                "data loader produced no batches; is the dataset smaller than "
                "the micro-batch size with drop_last=True?"
            )
# ---- training: velocity matching to the frozen teacher + flow-matching on real latents ----
SAMPLE_EVERY = 100  # steps between student sample renders

sample("step0000")  # baseline before training
gen_iter = loader()
t0 = time.time()
sample_time = 0.0  # wall time spent in periodic sample(); excluded from it/s
running = 0.0
for step in range(1, STEPS + 1):
    opt_muon.zero_grad(set_to_none=True)
    opt_adamw.zero_grad(set_to_none=True)
    loss_acc = 0.0
    for _ in range(ACCUM):
        x0, caps = next(gen_iter)
        x0 = x0.to("cuda", torch.bfloat16)
        with torch.no_grad():
            prompt_embeds, text_ids = pipe.encode_prompt(caps, device="cuda")
        eps = torch.randn_like(x0)
        # One noise level per sample, uniform in [0, 1).
        sigma = torch.rand(x0.size(0), device="cuda", dtype=torch.float32)
        x_t = build_x_t(x0.float(), eps.float(), sigma).to(torch.bfloat16)
        with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
            v_t = velocity(teacher, x_t, sigma, prompt_embeds, text_ids)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            v_s = velocity(student, x_t, sigma, prompt_embeds, text_ids)
        # Divide by ACCUM so accumulated gradients equal the mean over micro-batches.
        loss = (cfg.train.w_velocity * velocity_match_loss(v_s, v_t)
                + cfg.train.w_flow * flow_match_loss(v_s, eps, x0)) / ACCUM
        loss.backward()
        # Accumulate on-device; loss.item() here would force a host sync every micro-step.
        loss_acc += loss.detach()
    opt_muon.step()
    opt_adamw.step()
    running += loss_acc
    if step % cfg.train.log_every == 0:
        avg = float(running) / cfg.train.log_every  # single host sync per log interval
        running = 0.0
        sps = step / (time.time() - t0 - sample_time)
        print(f" step {step:4d}/{STEPS} loss {avg:.4f} {sps:.2f} it/s")
    if step % SAMPLE_EVERY == 0:
        _ts = time.time()
        sample(f"step{step:04d}")
        sample_time += time.time() - _ts

# Save before any final sampling so a sampling failure cannot lose the weights.
torch.save(student.state_dict(), f"{OUT}/student_trained.pt")
# Make sure the final weights are rendered even when STEPS is not a multiple of
# SAMPLE_EVERY (those steps were already sampled inside the loop).
if STEPS > 0 and STEPS % SAMPLE_EVERY:
    sample(f"step{STEPS:04d}")
print(f"saved trained student to {OUT}/student_trained.pt; samples in {OUT}/samples/")
print(f"peak VRAM: {torch.cuda.max_memory_allocated()/1e9:.1f}GB")
Xet Storage Details
- Size:
- 5.24 kB
- Xet hash:
- 32fdd7f581414a8da5f7bd7ed804352d84a5a2cb083f8a71af4fced23b33ac7b
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.