Buckets:

Pranav2748's picture
download
raw
1.83 kB
"""Benchmark inference latency: teacher vs students (per-token drop6, linear-attn drop6).
Reports s/img @512/4steps and speedup vs teacher."""
import gc
import time

import torch

from flux2distill.eval_utils import load_student
from flux2distill.model_utils import load_pipeline, param_summary
# Fixed prompt set: every pipeline below is timed on exactly these six prompts,
# one generation each, so per-image latencies are directly comparable.
PROMPTS = [
    "a serene mountain lake at sunrise",
    "a vintage bookshop storefront",
    "a close-up portrait of an elderly fisherman",
    "a futuristic city skyline at night",
    "a macro photo of dew on a spider web",
    "an oil painting of a bowl of fruit",
]
def bench(pipe, label, *, num_inference_steps=4, height=512, width=512, warmup=1):
    """Time image generation with ``pipe`` over ``PROMPTS`` and print a summary row.

    Runs ``warmup`` untimed generations on the first prompt. It then times one
    generation per prompt in ``PROMPTS`` under bf16 autocast with
    ``guidance_scale=1.0`` and a fixed seed of 0. It prints the mean seconds per
    image next to the transformer's total parameter count.

    Args:
        pipe: Pipeline object. It is called with ``prompt``,
            ``num_inference_steps``, ``guidance_scale``, ``height``, ``width``
            and ``generator`` keyword arguments, and must expose
            ``.transformer`` (passed to ``param_summary``).
        label: Name printed at the start of the result row.
        num_inference_steps: Denoising steps per image (default 4).
        height: Output image height in pixels (default 512).
        width: Output image width in pixels (default 512).
        warmup: Number of untimed warm-up generations (default 1).

    Returns:
        Mean wall-clock seconds per image over ``PROMPTS``.
    """

    def _generate(prompt):
        # Same seed for every call so all pipelines start from identical noise.
        with torch.autocast("cuda", dtype=torch.bfloat16):
            pipe(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=1.0,
                height=height,
                width=width,
                generator=torch.Generator(device="cuda").manual_seed(0),
            )

    torch.cuda.synchronize()
    # Warm-up keeps one-off first-call costs out of the timed loop.
    for _ in range(warmup):
        _generate(PROMPTS[0])

    # CUDA work is launched asynchronously. Synchronize on both sides of the
    # timed region so the interval covers the GPU work, not just the enqueueing.
    torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution; time.time() is not meant
    # for measuring intervals.
    t0 = time.perf_counter()
    for p in PROMPTS:
        _generate(p)
    torch.cuda.synchronize()
    dt = (time.perf_counter() - t0) / len(PROMPTS)

    print(f"{label:28s} {dt:.3f} s/img params={param_summary(pipe.transformer)['total_B']:.3f}B")
    return dt
# Teacher first. Then load the per-token student into the same pipeline object;
# the unchanged `pipe` is benchmarked again, so load_student is presumably
# in-place (its return value is unused).
pipe = load_pipeline(device="cuda")
t_teacher = bench(pipe, "teacher 4B")
load_student(pipe, "outputs/student_v2/selection.json", "outputs/student_v2/student_state.pt")
t_v2 = bench(pipe, "per-token drop6 (v2)")

# `pipe` is not used again. Drop it and return its cached CUDA memory before
# loading a second full pipeline, so two pipelines are never resident on the
# GPU at the same time. This lowers peak memory; bench()'s warm-up absorbs
# any re-allocation cost, so timings should be unaffected.
del pipe
gc.collect()
torch.cuda.empty_cache()

pipe2 = load_pipeline(device="cuda")
load_student(pipe2, "outputs/student_linattn/selection.json", "outputs/student_linattn/student_state.pt")
t_la = bench(pipe2, "linear-attn drop6")

# Speedup = teacher seconds-per-image / student seconds-per-image.
print(f"\nspeedup vs teacher: per-token drop6 = {t_teacher/t_v2:.2f}x linear-attn drop6 = {t_teacher/t_la:.2f}x")

Xet Storage Details

Size:
1.83 kB
·
Xet hash:
55e78de01191a95cbbac9ac2361c105f45c11ccde41decd1c74ee67e008c34f0

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.