guychuk's picture
Upload folder using huggingface_hub
31e2456 verified
"""PyTorch Dataset over lucky9-cyou/mimic-iv-aligned-ppg-ecg.
For v1 we:
- Keep only segments where ECG lead II is present (93.7% of data)
- Extract lead II ECG and PPG Pleth
- Window: 10 s slices at 5 s stride
- Native rates: ECG 250 Hz, PPG 125 Hz -> ECG window 2500 samples, PPG 1250
Each item returns {ecg: [1, 2500], ppg: [1, 1250], subject_id, ptt_ms}, where
ptt_ms is the per-window PTT estimate in milliseconds read from the cached
index (NaN when the index holds no finite value for that window).
Delta-t is not part of an item: collate_with_dt samples it per batch
(60% log-uniform + 40% from the measured PTT, with a log-uniform fallback
when the PTT is NaN).
"""
from __future__ import annotations
import json
import os
import re
from pathlib import Path
from typing import Iterable
import numpy as np
import torch
from scipy.signal import butter, filtfilt, find_peaks
from torch.utils.data import Dataset
from datasets import load_from_disk
# Native sampling rates of the two modalities, in Hz.
ECG_FS = 250.0
PPG_FS = 125.0

# Sliding-window geometry, in seconds.
WINDOW_SEC = 10.0
STRIDE_SEC = 5.0

# The same geometry expressed in samples, per modality.
ECG_WIN, PPG_WIN = int(WINDOW_SEC * ECG_FS), int(WINDOW_SEC * PPG_FS)  # 2500, 1250
ECG_STRIDE, PPG_STRIDE = int(STRIDE_SEC * ECG_FS), int(STRIDE_SEC * PPG_FS)  # 1250, 625
def _parse_subject(record_name: str) -> str:
m = re.match(r"p\d+/(p\d+)/", record_name)
return m.group(1) if m else record_name.split("/")[0]
def _bandpass(x: np.ndarray, fs: float, lo: float, hi: float, order: int = 3) -> np.ndarray:
ny = 0.5 * fs
b, a = butter(order, [lo / ny, min(hi / ny, 0.99)], btype="band")
return filtfilt(b, a, x, method="gust").astype(np.float32)
def _zscore(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
m = x.mean()
s = x.std() + eps
return ((x - m) / s).astype(np.float32)
def _r_peaks(ecg: np.ndarray, fs: float) -> np.ndarray:
    """Return sample indices of candidate R-peaks in ``ecg``.

    Pipeline: 5-15 Hz band-pass, squared first difference, moving average
    over 0.12 s, then peak picking on that envelope with a threshold of
    mean + 0.5 * std and a minimum spacing of 0.3 s between peaks.
    """
    band = _bandpass(ecg, fs, 5.0, 15.0)
    # prepend the first sample so the difference keeps the input length
    slope_energy = np.square(np.diff(band, prepend=band[:1]))
    width = max(int(0.12 * fs), 1)
    kernel = np.ones(width) / width
    envelope = np.convolve(slope_energy, kernel, mode="same")
    threshold = envelope.mean() + 0.5 * envelope.std()
    peaks, _props = find_peaks(envelope, height=threshold, distance=int(0.3 * fs))
    return peaks
def _ppg_peaks(ppg: np.ndarray, fs: float) -> np.ndarray:
    """Return sample indices of candidate pulse peaks in ``ppg``.

    The trace is band-passed to 0.5-8 Hz; peaks must be at least 0.3 s apart,
    reach mean + 0.3 * std in height and have a prominence of 0.1 * std.
    """
    band = _bandpass(ppg, fs, 0.5, 8.0)
    spread = band.std()
    peaks, _props = find_peaks(
        band,
        distance=int(0.3 * fs),
        height=band.mean() + 0.3 * spread,
        prominence=0.1 * spread,
    )
    return peaks
def _window_ptt_ms(ecg_win: np.ndarray, ppg_win: np.ndarray) -> float:
    """Median PTT (ms) across beats in one window; NaN if <3 clean beats.

    A beat is "clean" when exactly one PPG peak falls between 50 ms and
    500 ms after its R-peak; the PTT of that beat is the delay to that peak.
    NaN is also returned when either detector finds fewer than 3 peaks.
    """
    r_idx = _r_peaks(ecg_win, ECG_FS)
    p_idx = _ppg_peaks(ppg_win, PPG_FS)
    if min(len(r_idx), len(p_idx)) < 3:
        return float("nan")

    r_t = r_idx / ECG_FS
    p_t = p_idx / PPG_FS

    # hits[i, j] is True when PPG peak j lies in the search window of R-peak i
    earliest = (r_t + 0.050)[:, None]
    latest = (r_t + 0.500)[:, None]
    hits = (p_t[None, :] >= earliest) & (p_t[None, :] <= latest)

    unambiguous = hits.sum(axis=1) == 1
    matched = hits.argmax(axis=1)[unambiguous]
    ptts = (p_t[matched] - r_t[unambiguous]) * 1000.0
    if ptts.size < 3:
        return float("nan")
    return float(np.median(ptts))
class MIMICAlignedDataset(Dataset):
"""Indexes windows across a set of cached shard directories.
Args:
shard_roots: list of "<snapshot_root>/shard_XXXXX" paths (pre-downloaded)
build_index: if True, scan and build/save the window index; if False,
load existing index_path
index_path: where to cache the index (JSON: list[{shard_idx, row_idx,
win_start_ecg, win_start_ppg, subject_id, ptt_ms}])
normalise: if True, apply bandpass + zscore per window
"""
def __init__(
self,
shard_roots: list[Path],
index_path: Path,
build_index: bool = True,
normalise: bool = True,
subjects_allow: set[str] | None = None,
subset_frac: float = 1.0,
subset_seed: int = 0,
):
self.shard_roots = [Path(p) for p in shard_roots]
self.index_path = Path(index_path)
self.normalise = normalise
self.subjects_allow = subjects_allow
if build_index or not self.index_path.exists():
self._build_index()
self.index = json.loads(self.index_path.read_text())
if subjects_allow is not None:
self.index = [r for r in self.index if r["subject_id"] in subjects_allow]
if subset_frac < 1.0:
rng = np.random.default_rng(subset_seed)
n_keep = max(1, int(len(self.index) * subset_frac))
keep = rng.choice(len(self.index), size=n_keep, replace=False)
self.index = [self.index[i] for i in sorted(keep)]
self._shard_cache: dict[int, object] = {}
def _build_index(self) -> None:
records = []
for s_path in self.shard_roots:
sidx = int(s_path.name.split("_")[1])
ds = load_from_disk(str(s_path))
for row_idx in range(len(ds)):
row = ds[row_idx]
names = list(row["ecg_names"])
if "II" not in names:
continue
subject_id = _parse_subject(row["record_name"])
ecg_siglen = int(row["ecg_siglen"])
ppg_siglen = int(row["ppg_siglen"])
# require full windows only
n_win = min(
(ecg_siglen - ECG_WIN) // ECG_STRIDE + 1,
(ppg_siglen - PPG_WIN) // PPG_STRIDE + 1,
)
if n_win <= 0:
continue
for w in range(n_win):
records.append({
"shard_idx": sidx,
"row_idx": row_idx,
"subject_id": subject_id,
"win_start_ecg": w * ECG_STRIDE,
"win_start_ppg": w * PPG_STRIDE,
})
self.index_path.parent.mkdir(parents=True, exist_ok=True)
self.index_path.write_text(json.dumps(records))
def _load_shard(self, sidx: int):
if sidx not in self._shard_cache:
for p in self.shard_roots:
if int(p.name.split("_")[1]) == sidx:
self._shard_cache[sidx] = load_from_disk(str(p))
break
return self._shard_cache[sidx]
def __len__(self) -> int:
return len(self.index)
def __getitem__(self, idx: int) -> dict:
rec = self.index[idx]
ds = self._load_shard(rec["shard_idx"])
row = ds[rec["row_idx"]]
ecg_full = np.asarray(row["ecg"], dtype=np.float32)
ppg_full = np.asarray(row["ppg"], dtype=np.float32)[0]
names = list(row["ecg_names"])
ecg_lead = ecg_full[names.index("II")]
se = rec["win_start_ecg"]
sp = rec["win_start_ppg"]
ecg_win = ecg_lead[se : se + ECG_WIN].copy()
ppg_win = ppg_full[sp : sp + PPG_WIN].copy()
if ecg_win.shape[0] != ECG_WIN or ppg_win.shape[0] != PPG_WIN:
raise RuntimeError(f"bad window at idx {idx}: {ecg_win.shape}, {ppg_win.shape}")
# PTT is computed ONLY at index-build time (cached in the index dict).
# __getitem__ stays cheap so the GPU isn't waiting on peak detection.
ptt_ms = float(rec.get("ptt_ms", float("nan")))
if self.normalise:
ecg_win = _zscore(_bandpass(ecg_win, ECG_FS, 0.5, 40.0))
ppg_win = _zscore(_bandpass(ppg_win, PPG_FS, 0.5, 8.0))
return {
"ecg": torch.from_numpy(ecg_win).unsqueeze(0), # [1, 2500]
"ppg": torch.from_numpy(ppg_win).unsqueeze(0), # [1, 1250]
"subject_id": rec["subject_id"],
"ptt_ms": float(ptt_ms) if np.isfinite(ptt_ms) else float("nan"),
}
def split_by_subject(
    subjects: Iterable[str], frac: float = 0.9, seed: int = 0
) -> tuple[set[str], set[str]]:
    """Split subject ids into disjoint (train, test) sets.

    Ids are de-duplicated and sorted, then shuffled with a generator seeded
    by ``seed``. The first ``int(n * frac)`` shuffled ids form the train set
    and the remainder the test set.
    """
    unique = sorted(set(subjects))
    order = np.random.default_rng(seed).permutation(len(unique))
    shuffled = [unique[i] for i in order]
    n_train = int(len(unique) * frac)
    return set(shuffled[:n_train]), set(shuffled[n_train:])
def collate_with_dt(
items: list[dict],
log_uniform_frac: float = 0.6,
dt_min_ms: float = 50.0,
dt_max_ms: float = 500.0,
rng: np.random.Generator | None = None,
) -> dict:
"""Stack a batch and sample Δt. 60% log-uniform, 40% measured PTT where available."""
rng = rng if rng is not None else np.random.default_rng()
ecg = torch.stack([b["ecg"] for b in items])
ppg = torch.stack([b["ppg"] for b in items])
ptts = np.array([b["ptt_ms"] for b in items], dtype=np.float32)
b = len(items)
dt_ms = np.empty(b, dtype=np.float32)
use_log = rng.random(b) < log_uniform_frac
log_lo, log_hi = np.log(dt_min_ms), np.log(dt_max_ms)
dt_ms[use_log] = np.exp(rng.uniform(log_lo, log_hi, size=int(use_log.sum())))
# for the 40% branch: measured PTT when finite, else fallback to log-uniform
rest = ~use_log
for i in np.nonzero(rest)[0]:
if np.isfinite(ptts[i]):
dt_ms[i] = ptts[i]
else:
dt_ms[i] = np.exp(rng.uniform(log_lo, log_hi))
return {
"ecg": ecg,
"ppg": ppg,
"dt_seconds": torch.from_numpy(dt_ms / 1000.0),
"ptt_ms": torch.from_numpy(ptts),
"subject_id": [b["subject_id"] for b in items],
}