# PhysioJEPA/scripts/precompute_windows.py
# Uploaded via huggingface_hub by guychuk (commit 31e2456, verified).
"""Precompute all ECG/PPG windows into a single memory-mapped tensor file.
Reads the MIMIC shard index, applies bandpass + zscore per window, and writes
a flat binary file with a companion metadata JSON. At runtime, __getitem__
is a single mmap read (~0.1 ms) instead of load_from_disk + filter (~20 ms).
Output:
/workspace/cache/windows_ecg.bin (float32, [N, 2500])
/workspace/cache/windows_ppg.bin (float32, [N, 1250])
/workspace/cache/windows_meta.json (subject_id per window, N total)
"""
from __future__ import annotations
import argparse
import json
import os
import struct
from pathlib import Path
import numpy as np
from scipy.signal import butter, filtfilt
from tqdm import tqdm
from dotenv import load_dotenv
load_dotenv()  # load a local .env so HUGGINGFACE_API_KEY is visible before importing `datasets`
# Mirror HUGGINGFACE_API_KEY into HF_TOKEN so the `datasets` library can authenticate.
os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", ""))
from datasets import load_from_disk
# Sampling rates (Hz) and per-window sample counts: 2500/250 and 1250/125,
# i.e. 10-second windows for both ECG and PPG.
ECG_FS = 250.0
PPG_FS = 125.0
ECG_WIN = 2500
PPG_WIN = 1250
def _bandpass(x, fs, lo, hi, order=3):
ny = 0.5 * fs
b, a = butter(order, [lo / ny, min(hi / ny, 0.99)], btype="band")
return filtfilt(b, a, x, method="gust").astype(np.float32)
def _zscore(x, eps=1e-6):
return ((x - x.mean()) / (x.std() + eps)).astype(np.float32)
def main():
    """Stream every indexed window through bandpass+zscore into flat .bin files.

    CLI:
        --index    JSON list of window records; each record carries
                   shard_idx, row_idx, win_start_ecg, win_start_ppg, subject_id
        --out_dir  destination for windows_{ecg,ppg}.bin and windows_meta.json
        --workers  accepted for CLI compatibility; currently unused
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--index", required=True)
    ap.add_argument("--out_dir", default="/workspace/cache")
    ap.add_argument("--workers", type=int, default=1)  # NOTE: not used yet
    args = ap.parse_args()
    index = json.loads(Path(args.index).read_text())
    out = Path(args.out_dir)
    out.mkdir(parents=True, exist_ok=True)
    ecg_path = out / "windows_ecg.bin"
    ppg_path = out / "windows_ppg.bin"
    meta_path = out / "windows_meta.json"
    # Resume check. Compare against the index size recorded at write time
    # ("n_index"): comparing n_windows to len(index) would never match when
    # any row was skipped, forcing a full recompute on every run. Older meta
    # files without "n_index" fall back to the old n_windows comparison.
    if ecg_path.exists() and ppg_path.exists() and meta_path.exists():
        existing = json.loads(meta_path.read_text())
        done_for = existing.get("n_index", existing.get("n_windows"))
        if done_for == len(index):
            print(f"[precompute] already done: {existing['n_windows']} windows")
            return
    print(f"[precompute] {len(index)} windows to process")
    # Locate the MIMIC shard root among the known candidate directories.
    mimic_root = None
    for candidate in [Path(args.out_dir) / "mimic", Path(args.out_dir).parent / "mimic",
                      Path("/workspace/cache/mimic")]:
        if candidate.exists():
            mimic_root = candidate
            break
    if mimic_root is None:
        # assert would be stripped under `python -O`; fail loudly instead.
        raise RuntimeError("mimic shard root not found")
    # Lazy per-shard cache so each HF dataset shard is loaded at most once.
    # (The earlier glob-based load_shard helper was dead code and is removed.)
    shard_cache = {}

    def load_shard(sidx):
        # Returns the cached dataset for shard *sidx*, or None if not on disk.
        if sidx not in shard_cache:
            p = mimic_root / f"shard_{sidx:05d}"
            if (p / "dataset_info.json").exists():
                shard_cache[sidx] = load_from_disk(str(p))
        return shard_cache.get(sidx)

    subjects = []
    n_written = 0
    with open(ecg_path, "wb") as f_ecg, open(ppg_path, "wb") as f_ppg:
        for rec in tqdm(index, desc="precompute"):
            ds = load_shard(rec["shard_idx"])
            if ds is None:
                continue  # shard not present on disk; skip its windows
            row = ds[rec["row_idx"]]
            ecg_full = np.asarray(row["ecg"], dtype=np.float32)
            # assumes row["ppg"] is (1, T) with a single PPG channel -- TODO confirm
            ppg_full = np.asarray(row["ppg"], dtype=np.float32)[0]
            names = list(row["ecg_names"])
            if "II" not in names:
                continue  # lead II is required
            ecg_lead = ecg_full[names.index("II")]
            se = rec["win_start_ecg"]
            sp = rec["win_start_ppg"]
            ecg_win = ecg_lead[se : se + ECG_WIN]
            ppg_win = ppg_full[sp : sp + PPG_WIN]
            if ecg_win.shape[0] != ECG_WIN or ppg_win.shape[0] != PPG_WIN:
                continue  # window would run past the end of the record
            ecg_win = _zscore(_bandpass(ecg_win, ECG_FS, 0.5, 40.0))
            ppg_win = _zscore(_bandpass(ppg_win, PPG_FS, 0.5, 8.0))
            f_ecg.write(ecg_win.tobytes())
            f_ppg.write(ppg_win.tobytes())
            subjects.append(rec["subject_id"])
            n_written += 1
    meta = {
        "n_windows": n_written,
        "n_index": len(index),  # enables the resume check above
        "ecg_win": ECG_WIN,
        "ppg_win": PPG_WIN,
        "dtype": "float32",
        "subjects": subjects,
    }
    meta_path.write_text(json.dumps(meta))
    ecg_gb = ecg_path.stat().st_size / 1e9
    ppg_gb = ppg_path.stat().st_size / 1e9
    print(f"[precompute] wrote {n_written} windows: ecg={ecg_gb:.2f}GB ppg={ppg_gb:.2f}GB")
if __name__ == "__main__":  # script entry point
    main()