# PhysioJEPA / scripts/prepare_data.py
# Uploaded by guychuk via huggingface_hub (commit 31e2456, verified).
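"""Download MIMIC shards onto the training host, then build the window index.

Run on the RunPod pod right after boot. By default the first 412 shards are
saved under /workspace/cache/mimic and the index is written to
/workspace/cache/mimic_index.json (with a .meta.json summary next to it);
use --root, --index and --n_shards to override these.
"""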
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
load_dotenv()
# huggingface_hub reads its auth token from HF_TOKEN. Mirror the project's
# HUGGINGFACE_API_KEY (possibly loaded from .env by load_dotenv above) into it,
# but only when the key is actually set and HF_TOKEN is missing or empty, so an
# empty string is never exported as a "token".
_hf_api_key = os.environ.get("HUGGINGFACE_API_KEY")
if _hf_api_key and not os.environ.get("HF_TOKEN"):
    os.environ["HF_TOKEN"] = _hf_api_key
from physiojepa.data import MIMICAlignedDataset
# Hugging Face Hub dataset repo that main() downloads shards from
# (repo_type="dataset"); the name suggests MIMIC-IV aligned PPG/ECG data.
REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"
def _shard_dir_name(i: int) -> str:
    """Return the repo/on-disk directory name for shard *i* (e.g. ``shard_00007``)."""
    return f"shard_{i:05d}"


def _ready_shard_roots(local_root: Path, n_shards: int) -> list[Path]:
    """Return the requested shard directories under *local_root* that are ready to index.

    Only shards ``0 .. n_shards - 1`` are considered, so shard directories left
    in *local_root* by an earlier, larger run are not silently picked up. A
    shard counts as ready once its ``dataset_info.json`` exists (presumably the
    marker written when a shard was saved completely -- verify against the data).
    """
    candidates = (local_root / _shard_dir_name(i) for i in range(n_shards))
    return sorted(p for p in candidates if (p / "dataset_info.json").exists())


def main(argv: list[str] | None = None) -> None:
    """Download the requested MIMIC shards, then build and summarise the window index.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            so running the script directly behaves exactly as before.

    Command-line options:
        --root: directory ``snapshot_download`` writes the shards into.
        --index: path passed to ``MIMICAlignedDataset`` as ``index_path`` (with
            ``build_index=True``); a ``.meta.json`` summary is written beside it.
        --n_shards: how many shards (``shard_00000`` onwards) to download and index.

    Raises:
        SystemExit: on invalid arguments (argparse, exit code 2), or when none
            of the requested shards has a ``dataset_info.json`` after download.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", type=str, default="/workspace/cache/mimic")
    ap.add_argument("--index", type=str, default="/workspace/cache/mimic_index.json")
    ap.add_argument("--n_shards", type=int, default=412)
    args = ap.parse_args(argv)
    if args.n_shards < 1:
        ap.error(f"--n_shards must be >= 1 (got {args.n_shards})")

    root = Path(args.root)
    index_path = Path(args.index)
    root.mkdir(parents=True, exist_ok=True)
    # The meta summary is written next to the index; ensure that directory
    # exists even when --index points outside --root.
    index_path.parent.mkdir(parents=True, exist_ok=True)

    # One glob per shard directory so only the requested shards are fetched.
    patterns = [f"{_shard_dir_name(i)}/*" for i in range(args.n_shards)]
    print(f"[prepare] downloading {len(patterns)} shard patterns to {root}")
    local = snapshot_download(REPO, repo_type="dataset", allow_patterns=patterns,
                              local_dir=str(root), max_workers=16)

    shard_roots = _ready_shard_roots(Path(local), args.n_shards)
    if not shard_roots:
        # Building an index over zero shards would silently produce empty output.
        raise SystemExit(f"[prepare] no requested shard under {local} has "
                         "dataset_info.json; nothing to index")
    if len(shard_roots) < args.n_shards:
        print(f"[prepare] WARNING: only {len(shard_roots)}/{args.n_shards} "
              "requested shards are complete")
    print(f"[prepare] {len(shard_roots)} shards ready; building window index")
    ds = MIMICAlignedDataset(shard_roots=shard_roots, index_path=index_path,
                             build_index=True)

    # NOTE(review): assumes ds.index is an iterable of dict-like records with a
    # "subject_id" key (defined in physiojepa.data) -- confirm there.
    info = {
        "n_shards": len(shard_roots),
        "n_windows": len(ds),
        "n_subjects": len({r["subject_id"] for r in ds.index}),
        "shard_roots": [str(p) for p in shard_roots],
    }
    info_json = json.dumps(info, indent=2)
    index_path.with_suffix(".meta.json").write_text(info_json, encoding="utf-8")
    print(f"[prepare] index built: {info_json}")
if __name__ == "__main__":
    # Run the download + index build only when executed as a script,
    # not when this module is imported.
    main()