Buckets:

Mercity/FluxDistill / scripts /04_gen_eval.py
Pranav2748's picture
download
raw
2.27 kB
"""Generate teacher@4 vs student@4 images across the eval prompt set at matched seeds,
save individual images + side-by-side composites, and an index for the visual review."""
import json
import os
import sys
import time
import torch
from flux2distill.config import Config
from flux2distill.eval_utils import parse_prompts, load_student, side_by_side
from flux2distill.model_utils import load_pipeline
cfg = Config()

# Run tag picks the output subfolder; first CLI arg wins, otherwise "baseline"
# (e.g. "baseline" before distillation training, "trained" after).
TAG = (sys.argv[1:] or ["baseline"])[0]
OUT = f"outputs/eval/{TAG}"

# One folder per artifact kind: raw teacher images, raw student images, and
# the side-by-side composites built at the end of the script.
output_subdirs = ("teacher", "student", "compare")
for subdir in output_subdirs:
    os.makedirs(f"{OUT}/{subdir}", exist_ok=True)

# Each item is unpacked elsewhere in this script as (idx, category, prompt).
items = parse_prompts(cfg.eval.prompts_path)
print(f"{len(items)} prompts; tag={TAG}")
def gen_all(pipe, which):
    """Render every eval prompt with ``pipe`` and save each image under ``OUT/which``.

    Each prompt gets a fresh CUDA generator seeded with ``1000 + idx``. Teacher
    and student runs therefore start from identical noise for the same prompt,
    which makes the side-by-side comparison meaningful.

    Args:
        pipe: the pipeline to sample from (teacher weights or student weights).
        which: output subfolder name under ``OUT`` (e.g. "teacher", "student").

    Returns:
        The generated images (the ``.images[0]`` of each pipeline call), in the
        same order as ``items``. Note: these are image objects, not file paths.
    """
    images = []
    for idx, cat, prompt in items:
        # Matched per-prompt seed: same idx -> same initial noise for both models.
        gen = torch.Generator(device="cuda").manual_seed(1000 + idx)
        im = pipe(
            prompt=prompt,
            num_inference_steps=cfg.eval.num_inference_steps,
            guidance_scale=cfg.eval.guidance_scale,
            height=512,
            width=512,
            generator=gen,
        ).images[0]
        im.save(f"{OUT}/{which}/{idx:02d}_{cat}.png")
        images.append(im)
    return images
# Teacher pass: sample the eval set with the pipeline exactly as loaded.
print("loading pipeline (teacher)...")
pipe = load_pipeline(device="cuda")
teacher_start = time.time()
teacher_imgs = gen_all(pipe, "teacher")
teacher_elapsed = time.time() - teacher_start
print(f"teacher: {len(teacher_imgs)} imgs in {teacher_elapsed:.1f}s")
# Student pass: load the student weights into the same pipeline object, then
# sample the identical prompt/seed set so outputs are directly comparable.
print("loading student weights into pipeline...")
pipe, sel = load_student(pipe, "outputs/student/selection.json", "outputs/student/student_state.pt")
# Reset so the peak printed below covers only student generation.
torch.cuda.reset_peak_memory_stats()
t0 = time.time()
student_imgs = gen_all(pipe, "student")
dt = time.time() - t0
# Guard against an empty prompt file instead of raising ZeroDivisionError.
per_img = dt / len(items) if items else 0.0
print(f"student: {len(student_imgs)} imgs in {dt:.1f}s ({per_img:.2f}s/img)")
# Previously the peak stats were reset but never read; report them.
peak_gib = torch.cuda.max_memory_allocated() / 2**30
print(f"student peak CUDA memory allocated: {peak_gib:.2f} GiB")
# Composites: pair each prompt's teacher and student images (same seed) into a
# single labeled side-by-side image for visual review.
image_pairs = zip(teacher_imgs, student_imgs)
for (idx, cat, prompt), (teacher_im, student_im) in zip(items, image_pairs):
    caption = f"[{cat}] {prompt}"
    composite = side_by_side(teacher_im, student_im, "TEACHER 4B", "STUDENT 2.44B", caption)
    composite.save(f"{OUT}/compare/{idx:02d}_{cat}.png")
# Write the prompt index so each saved file (named "{idx:02d}_{category}.png")
# can be traced back to its prompt text.
index_records = [
    {"idx": i, "category": c, "prompt": p}
    for i, c, p in items
]
with open(f"{OUT}/index.json", "w") as f:
    json.dump(index_records, f, indent=2)

print(f"saved teacher/, student/, compare/ and index.json to {OUT}/")
print("compare/ holds side-by-side pairs for visual review.")

Xet Storage Details

Size:
2.27 kB
·
Xet hash:
c69f193db322765615d5172f8b9432f4f1bf7e567fc04450152ad32388a17330

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.