Buckets:
| """Download monet images from URL @512 (concurrent), VAE-encode in batches, cache | |
| (packed latent, caption). Robust to dead links. -> data/monet_cache/{latents.pt,captions.json}.""" | |
| import io | |
| import json | |
| import os | |
| import sys | |
| from concurrent.futures import ThreadPoolExecutor | |
| import requests | |
| import torch | |
| from PIL import Image | |
| from datasets import load_dataset | |
# How many captioned images to cache: first CLI argument, default 200.
N = 200 if len(sys.argv) <= 1 else int(sys.argv[1])
OUT = "data/monet_cache"
WORKERS = 32  # size of the download thread pool
CHUNK = 96  # rows downloaded (and then VAE-encoded) per batch
UA = {"User-Agent": "Mozilla/5.0 (research dataset fetch)"}
os.makedirs(OUT, exist_ok=True)

from flux2distill.model_utils import load_pipeline

print("loading pipeline (VAE encode)...", flush=True)
pipe = load_pipeline(device="cuda")
# Seeded CUDA generator handed to the VAE encode call; presumably used for
# latent sampling inside the pipeline -- TODO confirm in flux2distill/diffusers.
gen = torch.Generator(device="cuda").manual_seed(0)
def fetch(item):
    """Download one image and pair it with its caption.

    item: a (url, caption) tuple.
    Returns (PIL image converted to RGB, caption), or None when the request
    fails, the status is not 200, the body is empty, or decoding fails.
    """
    link, caption = item
    try:
        resp = requests.get(link, headers=UA, timeout=6)
        usable = resp.status_code == 200 and bool(resp.content)
        if not usable:
            return None
        picture = Image.open(io.BytesIO(resp.content)).convert("RGB")
    except Exception:  # best-effort: a dead or broken link is simply skipped
        return None
    return picture, caption
def encode_batch(imgs):
    """VAE-encode a batch of PIL images at 512x512 and return packed latents on CPU.

    imgs: non-empty list of PIL images (torch.cat raises on an empty list).
    Returns the packed latents as a bfloat16 CPU tensor; per the shape notes
    below, (B, 1024, 128) for 512x512 inputs.
    """
    # no_grad: the private VAE helpers are called directly here rather than
    # through the pipeline's call path, so nothing in this block disables
    # autograd. Unless load_pipeline already froze the VAE (not visible here),
    # the result would carry a grad_fn that keeps each batch's GPU activations
    # alive for as long as the caller holds the tensor.
    with torch.no_grad():
        px = torch.cat([pipe.image_processor.preprocess(im, height=512, width=512) for im in imgs])
        px = px.to("cuda", torch.bfloat16)
        lat = pipe._encode_vae_image(px, gen)  # (B,128,32,32)
        packed = pipe._pack_latents(lat)  # (B,1024,128)
        return packed.to("cpu", torch.bfloat16)
import itertools


def _captioned_urls(rows):
    """Yield (url, caption) for dataset rows that have both a URL and a caption.

    Prefers the "caption_gemini-2.5-flash-lite" field, falling back to
    "caption_original"; rows missing either piece are skipped.
    """
    for row in rows:
        url = row.get("url")
        cap = row.get("caption_gemini-2.5-flash-lite") or row.get("caption_original")
        if url and cap:
            yield url, cap


print(f"streaming jasperai/monet, target {N} @512 ({WORKERS} workers)...", flush=True)
ds = load_dataset("jasperai/monet", split="train", streaming=True)
lat_list, caps, seen, ok = [], [], 0, 0
candidates = _captioned_urls(ds)
with ThreadPoolExecutor(max_workers=WORKERS) as ex:
    while ok < N:
        # Pull the next chunk of candidates. The last chunk may be shorter than
        # CHUNK when the stream ends; it is still fetched and encoded instead of
        # being silently dropped.
        buf = list(itertools.islice(candidates, CHUNK))
        if not buf:
            break  # stream exhausted before reaching N
        seen += len(buf)
        results = [r for r in ex.map(fetch, buf) if r is not None]
        if results:
            lat_list.append(encode_batch([img for img, _ in results]))
            caps.extend(cap for _, cap in results)
            ok += len(results)
        print(f" cached {ok}/{N} (seen {seen}, hit-rate {ok/seen:.0%})", flush=True)

# torch.cat raises an opaque error on an empty list (N <= 0, or every link
# dead); fail with an explicit message and write nothing instead.
if not lat_list:
    sys.exit(f"no images cached (saw {seen} captioned rows); nothing written to {OUT}/")
# The final chunk may overshoot N; latents and captions are appended in
# lockstep, so truncating both to N keeps them aligned.
latents = torch.cat(lat_list)[:N]
caps = caps[:N]
torch.save(latents, f"{OUT}/latents.pt")
with open(f"{OUT}/captions.json", "w") as f:
    json.dump(caps, f)
print(f"DONE: cached {len(caps)} (from {seen} rows) -> {OUT}/ latents {tuple(latents.shape)}", flush=True)
Xet Storage Details
- Size:
- 2.6 kB
- Xet hash:
- f4957e10190df258d352697427f90802eb8137da162716a67262861fd90185c6
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.