Buckets:
| """Generate teacher@4 vs student@4 images across the eval prompt set at matched seeds, | |
| save individual images + side-by-side composites, and an index for the visual review.""" | |
| import json | |
| import os | |
| import sys | |
| import time | |
| import torch | |
| from flux2distill.config import Config | |
| from flux2distill.eval_utils import parse_prompts, load_student, side_by_side | |
| from flux2distill.model_utils import load_pipeline | |
| cfg = Config() | |
| TAG = sys.argv[1] if len(sys.argv) > 1 else "baseline" # e.g. baseline | trained | |
| OUT = f"outputs/eval/{TAG}" | |
| for sub in ("teacher", "student", "compare"): | |
| os.makedirs(f"{OUT}/{sub}", exist_ok=True) | |
| items = parse_prompts(cfg.eval.prompts_path) | |
| print(f"{len(items)} prompts; tag={TAG}") | |
| def gen_all(pipe, which): | |
| paths = [] | |
| for idx, cat, prompt in items: | |
| gen = torch.Generator(device="cuda").manual_seed(1000 + idx) # matched per-prompt seed | |
| im = pipe(prompt=prompt, num_inference_steps=cfg.eval.num_inference_steps, | |
| guidance_scale=cfg.eval.guidance_scale, height=512, width=512, | |
| generator=gen).images[0] | |
| p = f"{OUT}/{which}/{idx:02d}_{cat}.png" | |
| im.save(p) | |
| paths.append(im) | |
| return paths | |
| print("loading pipeline (teacher)...") | |
| pipe = load_pipeline(device="cuda") | |
| t0 = time.time() | |
| teacher_imgs = gen_all(pipe, "teacher") | |
| print(f"teacher: {len(teacher_imgs)} imgs in {time.time()-t0:.1f}s") | |
| print("loading student weights into pipeline...") | |
| pipe, sel = load_student(pipe, "outputs/student/selection.json", "outputs/student/student_state.pt") | |
| torch.cuda.reset_peak_memory_stats() | |
| t0 = time.time() | |
| student_imgs = gen_all(pipe, "student") | |
| dt = time.time() - t0 | |
| print(f"student: {len(student_imgs)} imgs in {dt:.1f}s ({dt/len(items):.2f}s/img)") | |
| # composites | |
| for (idx, cat, prompt), ti, si in zip(items, teacher_imgs, student_imgs): | |
| comp = side_by_side(ti, si, "TEACHER 4B", "STUDENT 2.44B", | |
| f"[{cat}] {prompt}") | |
| comp.save(f"{OUT}/compare/{idx:02d}_{cat}.png") | |
| with open(f"{OUT}/index.json", "w") as f: | |
| json.dump([{"idx": i, "category": c, "prompt": p} for i, c, p in items], f, indent=2) | |
| print(f"saved teacher/, student/, compare/ and index.json to {OUT}/") | |
| print("compare/ holds side-by-side pairs for visual review.") | |
Xet Storage Details
- Size: 2.27 kB
- Xet hash: c69f193db322765615d5172f8b9432f4f1bf7e567fc04450152ad32388a17330

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.