| """Download all MIMIC shards on the training host, then build the window index. |
| |
| Run on the RunPod pod right after boot. Saves to /workspace/cache/. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| from pathlib import Path |
|
|
| from dotenv import load_dotenv |
| from huggingface_hub import snapshot_download |
|
|
# Pull variables from a .env file (if one is found) into the process environment.
load_dotenv()

# Mirror HUGGINGFACE_API_KEY into HF_TOKEN unless HF_TOKEN is already defined.
# When neither variable exists, HF_TOKEN ends up as the empty string.
# NOTE(review): this presumably has to run before the physiojepa import that
# follows, so the token is in place at import time -- confirm before reordering.
if "HF_TOKEN" not in os.environ:
    os.environ["HF_TOKEN"] = os.environ.get("HUGGINGFACE_API_KEY", "")
|
|
| from physiojepa.data import MIMICAlignedDataset |
|
|
# Repository id handed to snapshot_download (with repo_type="dataset") in main().
REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"
|
|
|
|
def main(argv: list[str] | None = None) -> None:
    """Download the dataset shards, build the window index, and write a summary.

    Steps:
      1. Download ``shard_00000`` .. ``shard_<n_shards - 1>`` from the ``REPO``
         dataset repository into ``--root``.
      2. Keep every ``shard_*`` directory under the download location that
         contains a ``dataset_info.json`` file.
      3. Construct ``MIMICAlignedDataset`` with ``build_index=True`` and
         ``index_path`` set to ``--index``.
      4. Write a JSON summary (shard, window and subject counts plus the shard
         paths) next to the index, replacing the index file's suffix with
         ``.meta.json``, and print the same summary.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which keeps ``main()`` working
            exactly as before when run as a script.

    Raises:
        SystemExit: With status 2 when ``--n_shards`` is not a positive
            integer (argparse usage error), or with an explanatory message
            when no usable shard directory is found after the download.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", type=str, default="/workspace/cache/mimic")
    ap.add_argument("--index", type=str, default="/workspace/cache/mimic_index.json")
    ap.add_argument("--n_shards", type=int, default=412)
    args = ap.parse_args(argv)

    # A non-positive count would produce an empty pattern list; reject it up
    # front instead of running a download that cannot fetch anything.
    if args.n_shards < 1:
        ap.error(f"--n_shards must be a positive integer, got {args.n_shards}")

    root = Path(args.root)
    index_path = Path(args.index)
    root.mkdir(parents=True, exist_ok=True)
    # The meta file below is written beside the index, so its directory must
    # exist even when --index points outside the default cache location.
    index_path.parent.mkdir(parents=True, exist_ok=True)

    patterns = [f"shard_{i:05d}/*" for i in range(args.n_shards)]
    print(f"[prepare] downloading {len(patterns)} shard patterns to {root}")
    local = snapshot_download(
        REPO,
        repo_type="dataset",
        allow_patterns=patterns,
        local_dir=str(root),
        max_workers=16,
    )

    # The glob sees every shard_* directory present under the download
    # location, not only the ones requested in this run.
    shard_roots = sorted(
        p for p in Path(local).glob("shard_*")
        if (p / "dataset_info.json").exists()
    )
    if not shard_roots:
        # Fail loudly rather than writing an empty index and a meta file that
        # looks like a successful run.
        raise SystemExit(
            f"[prepare] no shard with dataset_info.json found under {local}; "
            "nothing to index"
        )
    if len(shard_roots) < args.n_shards:
        print(
            f"[prepare] WARNING: only {len(shard_roots)} of "
            f"{args.n_shards} requested shards are ready"
        )

    print(f"[prepare] {len(shard_roots)} shards ready; building window index")
    ds = MIMICAlignedDataset(
        shard_roots=shard_roots,
        index_path=index_path,
        build_index=True,
    )
    info = {
        "n_shards": len(shard_roots),
        "n_windows": len(ds),
        "n_subjects": len({r["subject_id"] for r in ds.index}),
        "shard_roots": [str(p) for p in shard_roots],
    }
    # Serialize once and reuse the text for both the file and the log line.
    summary = json.dumps(info, indent=2)
    index_path.with_suffix(".meta.json").write_text(summary, encoding="utf-8")
    print(f"[prepare] index built: {summary}")
|
|
|
|
# Run the download and index build only when this file is executed as a
# script; importing the module does not trigger it.
if __name__ == "__main__":
    main()
|
|