Buckets:
| """Download a LARGE full-res calibration set from jasperai/monet (from `url`, NOT the | |
| thumbnail `image`), VAE-encode @512, cache (packed latent, caption) for SVDQuant | |
| calibration. Same robust concurrent fetch as 06_cache_data.py, bigger default, separate | |
| output dir so it never clobbers the small training cache. | |
| Usage: python3 scripts/11_cache_calib.py [N=6000] [OUT=data/monet_calib] | |
| """ | |
| import io | |
| import json | |
| import os | |
| import sys | |
| from concurrent.futures import ThreadPoolExecutor | |
| import requests | |
| import torch | |
| from PIL import Image | |
| from datasets import load_dataset | |
# Positional CLI arguments: [N] [OUT] (see the usage line in the module docstring).
_cli = sys.argv[1:]
N = int(_cli[0]) if _cli else 6000
OUT = _cli[1] if len(_cli) >= 2 else "data/monet_calib"
os.makedirs(OUT, exist_ok=True)

# Download / batching knobs.
UA = {"User-Agent": "Mozilla/5.0 (research dataset fetch)"}
WORKERS = 32  # concurrent HTTP fetches
CHUNK = 96  # candidate rows fetched (and then VAE-encoded) per batch
SAVE_EVERY = 1000  # checkpoint the growing cache so a long download is resumable-ish

from flux2distill.model_utils import load_pipeline

print("loading pipeline (VAE encode)...", flush=True)
pipe = load_pipeline(device="cuda")
# Fixed-seed generator handed to the VAE encode so reruns are reproducible.
gen = torch.Generator(device="cuda").manual_seed(0)
def fetch(item):
    """Download one (url, caption) pair and decode it as an RGB image.

    Best-effort by design: any network, HTTP, or decode failure yields ``None``
    so that one bad URL never aborts the crawl (the caller filters ``None`` out).

    Args:
        item: a ``(url, caption)`` tuple.

    Returns:
        ``(PIL RGB image, caption)``, or ``None`` if the request failed, the
        status was not 200, the body was empty, the bytes did not decode, or
        the image's shorter side is under 256 px.
    """
    url, cap = item
    try:
        r = requests.get(url, headers=UA, timeout=8)
        if r.status_code != 200 or not r.content:
            return None
        with Image.open(io.BytesIO(r.content)) as src:
            # Image.open is lazy (header only), so reject tiny/placeholder
            # thumbnails here, before paying for a full-resolution decode.
            if min(src.size) < 256:
                return None
            img = src.convert("RGB")
        return (img, cap)
    except Exception:  # deliberate best-effort: skip this URL on any failure
        return None
def encode_batch(imgs):
    """VAE-encode a list of PIL images at 512x512 and return packed latents on CPU.

    Uses the module-level ``pipe`` (its image processor, VAE encode and latent
    packing helpers) and the seeded ``gen`` generator.

    Args:
        imgs: list of PIL images.

    Returns:
        bfloat16 CPU tensor of packed latents, one row per input image.
    """
    # Pure inference. Without no_grad, if the VAE parameters require grad, the
    # returned CPU tensor carries a grad_fn that keeps the GPU activation graph
    # alive for as long as the caller holds it (every chunk in lat_list).
    with torch.no_grad():
        px = torch.cat([pipe.image_processor.preprocess(im, height=512, width=512) for im in imgs])
        px = px.to("cuda", torch.bfloat16)
        lat = pipe._encode_vae_image(px, gen)  # (B,128,32,32)
        packed = pipe._pack_latents(lat)  # (B,1024,128)
        return packed.to("cpu", torch.bfloat16)
def flush(lat_list, caps, out_dir=None):
    """Write the cache accumulated so far to disk and return the latent shape.

    Writes ``latents.pt`` (``torch.cat`` of ``lat_list`` along dim 0) and
    ``captions.json`` (``caps`` as a JSON list). Each file is first written to
    a ``.tmp`` sibling and then moved into place with ``os.replace``. A crash
    mid-write therefore leaves the previous checkpoint intact instead of
    truncated.

    Args:
        lat_list: list of latent tensors to concatenate.
        caps: list of captions, one per latent row.
        out_dir: target directory; defaults to the module-level ``OUT``.

    Returns:
        ``torch.Size`` of the concatenated latents.
    """
    out = OUT if out_dir is None else out_dir
    latents = torch.cat(lat_list)
    lat_path = f"{out}/latents.pt"
    cap_path = f"{out}/captions.json"
    torch.save(latents, lat_path + ".tmp")
    with open(cap_path + ".tmp", "w", encoding="utf-8") as f:
        json.dump(caps, f)
    # Swap both files in only after both temp writes succeeded, keeping the
    # window where latents and captions disagree as small as possible.
    os.replace(lat_path + ".tmp", lat_path)
    os.replace(cap_path + ".tmp", cap_path)
    return latents.shape
def _candidate_chunks(rows, size):
    """Yield lists of up to ``size`` (url, caption) pairs from the dataset stream.

    Rows without a ``url`` or without any caption are skipped. The caption is
    ``caption_gemini-2.5-flash-lite`` when present, else ``caption_original``.
    The trailing partial list is yielded too, so a stream that ends mid-chunk
    does not silently drop its last candidates.
    """
    buf = []
    for row in rows:
        url = row.get("url")
        cap = row.get("caption_gemini-2.5-flash-lite") or row.get("caption_original")
        if not url or not cap:
            continue
        buf.append((url, cap))
        if len(buf) >= size:
            yield buf
            buf = []
    if buf:
        yield buf


print(f"streaming jasperai/monet (full-res via url), target {N} @512 ({WORKERS} workers)...", flush=True)
ds = load_dataset("jasperai/monet", split="train", streaming=True)
lat_list, caps, seen, ok, last_save = [], [], 0, 0, 0
with ThreadPoolExecutor(max_workers=WORKERS) as ex:
    chunks = _candidate_chunks(ds, CHUNK)
    # Check the target before pulling the next chunk, so no extra rows are
    # streamed once N is reached (and none at all when N <= 0).
    while ok < N:
        chunk = next(chunks, None)
        if chunk is None:  # stream exhausted before reaching N
            break
        seen += len(chunk)
        # Downloads run concurrently; failed fetches come back as None.
        results = [r for r in ex.map(fetch, chunk) if r is not None]
        if not results:
            continue
        lat_list.append(encode_batch([img for img, _ in results]))
        caps.extend(cap for _, cap in results)
        ok += len(results)
        print(f" cached {ok}/{N} (seen {seen}, hit-rate {ok/seen:.0%})", flush=True)
        if ok - last_save >= SAVE_EVERY:
            shp = flush(lat_list, caps)
            last_save = ok
            print(f" [checkpoint] saved {tuple(shp)}", flush=True)

if not lat_list:
    # torch.cat([]) would otherwise fail with an opaque RuntimeError.
    sys.exit(f"ERROR: no images cached from {seen} candidate rows; nothing written to {OUT}/")

# The last chunk can overshoot N, so trim to exactly N. Rebinding lat_list
# drops the per-chunk tensors. flush() re-concatenates, which writes a compact
# tensor: torch.save of a bare [:N] slice would serialize its whole storage.
lat_list = [torch.cat(lat_list)[:N]]
caps = caps[:N]
shp = flush(lat_list, caps)
print(f"DONE: cached {len(caps)} (from {seen} rows) -> {OUT}/ latents {tuple(shp)}", flush=True)
Xet Storage Details
- Size:
- 3.47 kB
- Xet hash:
- f877c4a6a73da5e86e19268d4c09fe24fc89b79e05cd9533025fd15c4b8c8da0
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.