Buckets:

Mercity/FluxDistill / scripts /06_cache_data.py
Pranav2748's picture
download
raw
2.6 kB
"""Download monet images from URL @512 (concurrent), VAE-encode in batches, cache
(packed latent, caption). Robust to dead links. -> data/monet_cache/{latents.pt,captions.json}."""
import io
import itertools
import json
import os
import sys
from concurrent.futures import ThreadPoolExecutor

import requests
import torch
from datasets import load_dataset
from PIL import Image
# Target number of successfully cached images: first CLI argument, default 200.
N = int(sys.argv[1]) if sys.argv[1:] else 200

# Output directory for latents.pt / captions.json, created up front.
OUT = "data/monet_cache"
os.makedirs(OUT, exist_ok=True)

# Headers sent with every image download request.
UA = {"User-Agent": "Mozilla/5.0 (research dataset fetch)"}

# Download thread-pool size, and how many usable rows are grouped per batch.
WORKERS = 32
CHUNK = 96

from flux2distill.model_utils import load_pipeline

print("loading pipeline (VAE encode)...", flush=True)
pipe = load_pipeline(device="cuda")

# CUDA RNG seeded with 0; passed to the VAE encode call in encode_batch.
gen = torch.Generator(device="cuda")
gen.manual_seed(0)
def fetch(item):
    """Download one ``(url, caption)`` pair.

    Returns ``(RGB PIL image, caption)`` on success, or ``None`` when the
    response is not HTTP 200, the body is empty, or any error occurs while
    downloading or decoding, so dead links are skipped instead of aborting.
    """
    link, caption = item
    try:
        resp = requests.get(link, headers=UA, timeout=6)
        usable = resp.status_code == 200 and resp.content
        if not usable:
            return None
        # convert() forces PIL to actually decode, surfacing corrupt payloads here.
        picture = Image.open(io.BytesIO(resp.content)).convert("RGB")
    except Exception:  # deliberate best-effort: any failure just skips this link
        return None
    return picture, caption
@torch.no_grad()
def encode_batch(imgs):
    """VAE-encode a list of PIL images and return their packed latents.

    Each image is preprocessed at 512x512. The stacked batch is moved to CUDA
    in bfloat16 and encoded with the pipeline's VAE, using the module-level
    ``gen``. It is then packed and returned on the CPU in bfloat16.
    """
    batch = [
        pipe.image_processor.preprocess(image, height=512, width=512)
        for image in imgs
    ]
    pixels = torch.cat(batch).to("cuda", torch.bfloat16)
    # Author's shape notes (not verified here): (B,128,32,32) -> packed (B,1024,128).
    latents = pipe._encode_vae_image(pixels, gen)
    return pipe._pack_latents(latents).to("cpu", torch.bfloat16)
def _usable_pairs(dataset):
    """Yield ``(url, caption)`` for each streamed row that has both.

    Prefers the Gemini caption and falls back to the original caption. Rows
    missing a URL or any caption are skipped.
    """
    for row in dataset:
        url = row.get("url")
        cap = row.get("caption_gemini-2.5-flash-lite") or row.get("caption_original")
        if url and cap:
            yield url, cap


print(f"streaming jasperai/monet, target {N} @512 ({WORKERS} workers)...", flush=True)
ds = load_dataset("jasperai/monet", split="train", streaming=True)
lat_list, caps, seen, ok = [], [], 0, 0
buf = []
pairs = _usable_pairs(ds)
with ThreadPoolExecutor(max_workers=WORKERS) as ex:
    while ok < N:
        # Pull up to CHUNK usable rows. The last chunk may be short when the
        # stream runs out; it is still fetched and encoded rather than dropped.
        buf = list(itertools.islice(pairs, CHUNK))
        if not buf:
            break  # stream exhausted
        seen += len(buf)
        results = [r for r in ex.map(fetch, buf) if r is not None]
        if not results:
            continue  # every link in this chunk was dead
        packed = encode_batch([img for img, _ in results])
        lat_list.append(packed)
        caps.extend(cap for _, cap in results)
        ok += len(results)
        print(f" cached {ok}/{N} (seen {seen}, hit-rate {ok/seen:.0%})", flush=True)

# torch.cat on an empty list raises an opaque RuntimeError; fail with context instead.
if not lat_list:
    raise SystemExit(f"no images cached (seen {seen} usable rows); nothing written to {OUT}/")

# The last chunk can overshoot N, so trim latents and captions together.
latents = torch.cat(lat_list)[:N]
caps = caps[:N]
torch.save(latents, f"{OUT}/latents.pt")
with open(f"{OUT}/captions.json", "w", encoding="utf-8") as fh:
    json.dump(caps, fh)
print(f"DONE: cached {len(caps)} (from {seen} rows) -> {OUT}/ latents {tuple(latents.shape)}", flush=True)

Xet Storage Details

Size:
2.6 kB
·
Xet hash:
f4957e10190df258d352697427f90802eb8137da162716a67262861fd90185c6

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.