Buckets:

Mercity/FluxDistill / scripts /11_cache_calib.py
Pranav2748's picture
download
raw
3.47 kB
"""Download a LARGE full-res calibration set from jasperai/monet (from `url`, NOT the
thumbnail `image`), VAE-encode @512, cache (packed latent, caption) for SVDQuant
calibration. Same robust concurrent fetch as 06_cache_data.py, bigger default, separate
output dir so it never clobbers the small training cache.
Usage: python3 scripts/11_cache_calib.py [N=6000] [OUT=data/monet_calib]
"""
import io
import json
import os
import sys
from concurrent.futures import ThreadPoolExecutor
import requests
import torch
from PIL import Image
from datasets import load_dataset
# --- CLI: positional [N] [OUT] ------------------------------------------------
_argv = sys.argv[1:]
N = int(_argv[0]) if _argv else 6000
OUT = _argv[1] if len(_argv) > 1 else "data/monet_calib"
os.makedirs(OUT, exist_ok=True)

# --- fetch / batching knobs ---------------------------------------------------
UA = {"User-Agent": "Mozilla/5.0 (research dataset fetch)"}
WORKERS, CHUNK = 32, 96  # thread-pool size, and number of rows fetched per batch
SAVE_EVERY = 1000  # write the partial cache to OUT this often so a crash doesn't lose everything

# --- model: the pipeline is loaded for VAE encoding ---------------------------
from flux2distill.model_utils import load_pipeline
print("loading pipeline (VAE encode)...", flush=True)
pipe = load_pipeline(device="cuda")
# fixed seed so the VAE encode is reproducible across runs
gen = torch.Generator(device="cuda").manual_seed(0)
def fetch(item):
    """Download one calibration image and pair it with its caption.

    Args:
        item: a ``(url, caption)`` tuple.

    Returns:
        ``(RGB PIL image, caption)`` on success. Returns ``None`` in any of these cases:
        the request fails, the status is not 200, the body is empty, the payload
        cannot be decoded, or the image's shorter side is under 256 px. Every error
        is swallowed on purpose so one bad URL never aborts a whole batch.
    """
    link, caption = item
    try:
        resp = requests.get(link, headers=UA, timeout=8)
        usable = resp.status_code == 200 and bool(resp.content)
        if not usable:
            return None
        picture = Image.open(io.BytesIO(resp.content)).convert("RGB")
        # reject tiny/placeholder thumbnails that slip in via the url column
        width, height = picture.size
        if width < 256 or height < 256:
            return None
        return picture, caption
    except Exception:
        return None
@torch.no_grad()
def encode_batch(imgs):
    """VAE-encode a list of PIL images and return packed latents on the CPU.

    Each image is passed individually through ``pipe.image_processor.preprocess``
    with height=width=512. The results are concatenated along dim 0 and moved to
    CUDA in bfloat16. They are then encoded with ``pipe._encode_vae_image`` using
    the module-level seeded generator ``gen``, and packed with ``pipe._pack_latents``.

    Returns:
        bfloat16 CPU tensor of packed latents. The original author annotates the
        shapes as (B,128,32,32) before packing and (B,1024,128) after.
    """
    pixel_batches = []
    for image in imgs:
        pixel_batches.append(pipe.image_processor.preprocess(image, height=512, width=512))
    pixels = torch.cat(pixel_batches).to("cuda", torch.bfloat16)
    latents = pipe._encode_vae_image(pixels, gen)  # (B,128,32,32)
    return pipe._pack_latents(latents).to("cpu", torch.bfloat16)  # (B,1024,128)
def flush(lat_list, caps, out=None):
    """Checkpoint the cache built so far: concatenated latents plus captions.

    Writes ``latents.pt`` (``torch.cat(lat_list)``) and ``captions.json`` into
    ``out``. Each file is first written to a ``.tmp`` sibling and then moved into
    place with ``os.replace``. A crash mid-save therefore leaves the previous
    complete file rather than a truncated one. The rename is atomic on POSIX
    within one filesystem, and each file is replaced independently.

    Args:
        lat_list: non-empty list of tensors, concatenated along dim 0.
        caps: JSON-serializable list of captions, one per latent row.
        out: output directory; ``None`` (default) means the module-level ``OUT``.

    Returns:
        ``torch.Size`` of the concatenated latents.
    """
    out = OUT if out is None else out
    latents = torch.cat(lat_list)
    lat_path = f"{out}/latents.pt"
    cap_path = f"{out}/captions.json"
    torch.save(latents, lat_path + ".tmp")
    os.replace(lat_path + ".tmp", lat_path)
    # context manager guarantees the handle is closed (and data flushed) before the rename
    with open(cap_path + ".tmp", "w") as fh:
        json.dump(caps, fh)
    os.replace(cap_path + ".tmp", cap_path)
    return latents.shape
def _iter_chunks(rows, size):
    """Group usable dataset rows into lists of at most `size` (url, caption) pairs.

    Rows with no `url`, or with neither a Gemini nor an original caption, are
    skipped. The Gemini caption is preferred when present. The final partial
    chunk is yielded too, so rows at the end of the stream are not silently
    dropped.
    """
    chunk = []
    for row in rows:
        url = row.get("url")
        cap = row.get("caption_gemini-2.5-flash-lite") or row.get("caption_original")
        if not url or not cap:
            continue
        chunk.append((url, cap))
        if len(chunk) >= size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk


print(f"streaming jasperai/monet (full-res via url), target {N} @512 ({WORKERS} workers)...", flush=True)
ds = load_dataset("jasperai/monet", split="train", streaming=True)
lat_list, caps, seen, ok, last_save = [], [], 0, 0, 0
with ThreadPoolExecutor(max_workers=WORKERS) as ex:
    for chunk in _iter_chunks(ds, CHUNK):
        seen += len(chunk)
        # Executor.map yields results in input order, so captions stay aligned with images
        results = [r for r in ex.map(fetch, chunk) if r is not None]
        if results:
            lat_list.append(encode_batch([img for img, _ in results]))
            caps.extend(cap for _, cap in results)
            ok += len(results)
            print(f" cached {ok}/{N} (seen {seen}, hit-rate {ok/seen:.0%})", flush=True)
            if ok - last_save >= SAVE_EVERY:
                shp = flush(lat_list, caps)
                last_save = ok
                print(f" [checkpoint] saved {tuple(shp)}", flush=True)
        # stop pulling more rows from the stream once the target is reached
        if ok >= N:
            break

if not lat_list:
    # torch.cat([]) would fail with an unhelpful error; say what actually happened
    raise SystemExit(f"no images could be cached from {seen} candidate rows; nothing written to {OUT}/")

latents = torch.cat(lat_list)
lat_list.clear()  # release the per-chunk tensors before the (possible) trim copy below
caps = caps[:N]
if latents.shape[0] > N:
    # clone so torch.save writes only N rows: saving a slice view serializes its whole
    # underlying storage, including the overshoot from the last chunk
    latents = latents[:N].clone()
torch.save(latents, f"{OUT}/latents.pt")
with open(f"{OUT}/captions.json", "w") as fh:
    json.dump(caps, fh)
print(f"DONE: cached {len(caps)} (from {seen} rows) -> {OUT}/ latents {tuple(latents.shape)}", flush=True)

Xet Storage Details

Size:
3.47 kB
·
Xet hash:
f877c4a6a73da5e86e19268d4c09fe24fc89b79e05cd9533025fd15c4b8c8da0

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.