| """PyTorch Dataset over lucky9-cyou/mimic-iv-aligned-ppg-ecg. |
| |
| For v1 we: |
| - Keep only segments where ECG lead II is present (93.7% of data) |
| - Extract lead II ECG and PPG Pleth |
| - Window: 10 s slices at 5 s stride |
| - Native rates: ECG 250 Hz, PPG 125 Hz -> ECG window 2500 samples, PPG 1250 |
| |
| Each item returns {ecg: [1, 2500], ppg: [1, 1250], subject_id, segment_start, |
| measured_ptt_ms (per-window estimate, may be NaN), delta_t_seconds (sampled per |
| step outside the dataset)}. |
| |
| The caller handles delta_t sampling (60% log-uniform + 40% from measured_ptt). |
| """ |
from __future__ import annotations

import json
import re
from pathlib import Path
from typing import Iterable

import numpy as np
import torch
from datasets import load_from_disk
from scipy.signal import butter, filtfilt, find_peaks
from torch.utils.data import Dataset
ECG_FS = 250.0  # Hz
PPG_FS = 125.0  # Hz
WINDOW_SEC = 10.0
STRIDE_SEC = 5.0
ECG_WIN = int(ECG_FS * WINDOW_SEC)     # 2500 samples per ECG window
PPG_WIN = int(PPG_FS * WINDOW_SEC)     # 1250 samples per PPG window
ECG_STRIDE = int(ECG_FS * STRIDE_SEC)  # 1250 samples
PPG_STRIDE = int(PPG_FS * STRIDE_SEC)  # 625 samples
def _parse_subject(record_name: str) -> str:
    """Subject id from a record name of the form "pXXX/pXXXXXXXX/...";
    falls back to the first path component if the pattern does not match."""
    m = re.match(r"p\d+/(p\d+)/", record_name)
    return m.group(1) if m else record_name.split("/")[0]
def _bandpass(x: np.ndarray, fs: float, lo: float, hi: float, order: int = 3) -> np.ndarray:
    """Zero-phase Butterworth band-pass; the upper edge is clipped just below Nyquist."""
    ny = 0.5 * fs
    b, a = butter(order, [lo / ny, min(hi / ny, 0.99)], btype="band")
    return filtfilt(b, a, x, method="gust").astype(np.float32)
def _zscore(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    m = x.mean()
    s = x.std() + eps
    return ((x - m) / s).astype(np.float32)
def _r_peaks(ecg: np.ndarray, fs: float) -> np.ndarray:
    """R-peak candidates via a simplified Pan-Tompkins-style pipeline."""
    x = _bandpass(ecg, fs, 5.0, 15.0)   # QRS-band emphasis
    s = np.diff(x, prepend=x[:1]) ** 2  # squared derivative
    w = max(int(0.12 * fs), 1)          # ~120 ms moving-window integration
    mwa = np.convolve(s, np.ones(w) / w, mode="same")
    thr = mwa.mean() + 0.5 * mwa.std()
    p, _ = find_peaks(mwa, height=thr, distance=int(0.3 * fs))  # ~300 ms refractory
    return p
def _ppg_peaks(ppg: np.ndarray, fs: float) -> np.ndarray:
    """Pulse-peak candidates on the 0.5-8 Hz band-passed PPG."""
    x = _bandpass(ppg, fs, 0.5, 8.0)
    p, _ = find_peaks(x, distance=int(0.3 * fs),
                      height=x.mean() + 0.3 * x.std(),
                      prominence=0.1 * x.std())
    return p
def _window_ptt_ms(ecg_win: np.ndarray, ppg_win: np.ndarray) -> float:
    """Median PTT across beats in one window; NaN if fewer than 3 clean beats.

    A beat counts as clean only when exactly one PPG peak falls 50-500 ms after
    the R-peak; ambiguous beats are skipped.
    """
    r = _r_peaks(ecg_win, ECG_FS)
    p = _ppg_peaks(ppg_win, PPG_FS)
    if len(r) < 3 or len(p) < 3:
        return float("nan")
    r_t = r / ECG_FS
    p_t = p / PPG_FS
    ptts = []
    for rt in r_t:
        cand = p_t[(p_t >= rt + 0.050) & (p_t <= rt + 0.500)]
        if len(cand) == 1:
            ptts.append((cand[0] - rt) * 1000.0)
    if len(ptts) < 3:
        return float("nan")
    return float(np.median(ptts))
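# Worked example for _window_ptt_ms (illustrative numbers, not dataset values):
# with R-peaks at t = 1.00, 2.00, 3.00 s and PPG peaks at t = 1.25, 2.24, 3.26 s,
# each R-peak has exactly one PPG candidate inside its (+50 ms, +500 ms) gate,
# giving per-beat PTTs of 250, 240 and 260 ms and a window estimate of
# median([250, 240, 260]) = 250 ms.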
class MIMICAlignedDataset(Dataset):
    """Indexes 10 s windows across a set of cached shard directories.

    Args:
        shard_roots: list of "<snapshot_root>/shard_XXXXX" paths (pre-downloaded)
        index_path: where to cache the window index (JSON: list[{shard_idx,
            row_idx, win_start_ecg, win_start_ppg, subject_id, ptt_ms}])
        build_index: if True, rescan the shards and overwrite the index; if False,
            reuse index_path (the index is still built when the file is missing)
        normalise: if True, apply band-pass (0.5-40 Hz ECG, 0.5-8 Hz PPG) and a
            per-window z-score
        subjects_allow: optional whitelist of subject ids (for subject-level splits)
        subset_frac: keep a random fraction of windows, seeded by subset_seed
    """

    def __init__(
        self,
        shard_roots: list[Path],
        index_path: Path,
        build_index: bool = True,
        normalise: bool = True,
        subjects_allow: set[str] | None = None,
        subset_frac: float = 1.0,
        subset_seed: int = 0,
    ):
        self.shard_roots = [Path(p) for p in shard_roots]
        self.index_path = Path(index_path)
        self.normalise = normalise
        self.subjects_allow = subjects_allow
        if build_index or not self.index_path.exists():
            self._build_index()
        self.index = json.loads(self.index_path.read_text())
        if subjects_allow is not None:
            self.index = [r for r in self.index if r["subject_id"] in subjects_allow]
        if subset_frac < 1.0:
            rng = np.random.default_rng(subset_seed)
            n_keep = max(1, int(len(self.index) * subset_frac))
            keep = rng.choice(len(self.index), size=n_keep, replace=False)
            self.index = [self.index[i] for i in sorted(keep)]
        self._shard_cache: dict[int, object] = {}
    def _build_index(self) -> None:
        records = []
        for s_path in self.shard_roots:
            sidx = int(s_path.name.split("_")[1])
            ds = load_from_disk(str(s_path))
            for row_idx in range(len(ds)):
                row = ds[row_idx]
                names = list(row["ecg_names"])
                if "II" not in names:
                    continue
                subject_id = _parse_subject(row["record_name"])
                ecg_siglen = int(row["ecg_siglen"])
                ppg_siglen = int(row["ppg_siglen"])

                n_win = min(
                    (ecg_siglen - ECG_WIN) // ECG_STRIDE + 1,
                    (ppg_siglen - PPG_WIN) // PPG_STRIDE + 1,
                )
                if n_win <= 0:
                    continue
                # Precompute the per-window PTT estimate here so __getitem__ and
                # collate_with_dt can use it without re-running peak detection.
                ecg_lead = np.asarray(row["ecg"], dtype=np.float32)[names.index("II")]
                ppg_sig = np.asarray(row["ppg"], dtype=np.float32)[0]
                for w in range(n_win):
                    se, sp = w * ECG_STRIDE, w * PPG_STRIDE
                    ptt = _window_ptt_ms(ecg_lead[se:se + ECG_WIN], ppg_sig[sp:sp + PPG_WIN])
                    records.append({
                        "shard_idx": sidx,
                        "row_idx": row_idx,
                        "subject_id": subject_id,
                        "win_start_ecg": se,
                        "win_start_ppg": sp,
                        "ptt_ms": ptt if np.isfinite(ptt) else None,
                    })
        self.index_path.parent.mkdir(parents=True, exist_ok=True)
        self.index_path.write_text(json.dumps(records))
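    # A single index record, as written to index_path, looks like (values
    # illustrative): {"shard_idx": 3, "row_idx": 17, "subject_id": "p10012345",
    # "win_start_ecg": 2500, "win_start_ppg": 1250, "ptt_ms": 243.0},
    # with ptt_ms set to null when no clean estimate was found.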
    def _load_shard(self, sidx: int):
        # Shards are loaded lazily and memoised per worker.
        if sidx not in self._shard_cache:
            for p in self.shard_roots:
                if int(p.name.split("_")[1]) == sidx:
                    self._shard_cache[sidx] = load_from_disk(str(p))
                    break
            else:
                raise KeyError(f"shard index {sidx} not found among shard_roots")
        return self._shard_cache[sidx]
    def __len__(self) -> int:
        return len(self.index)
    def __getitem__(self, idx: int) -> dict:
        rec = self.index[idx]
        ds = self._load_shard(rec["shard_idx"])
        row = ds[rec["row_idx"]]
        ecg_full = np.asarray(row["ecg"], dtype=np.float32)
        ppg_full = np.asarray(row["ppg"], dtype=np.float32)[0]  # single Pleth channel
        names = list(row["ecg_names"])
        ecg_lead = ecg_full[names.index("II")]
        se = rec["win_start_ecg"]
        sp = rec["win_start_ppg"]
        ecg_win = ecg_lead[se : se + ECG_WIN].copy()
        ppg_win = ppg_full[sp : sp + PPG_WIN].copy()
        if ecg_win.shape[0] != ECG_WIN or ppg_win.shape[0] != PPG_WIN:
            raise RuntimeError(f"bad window at idx {idx}: {ecg_win.shape}, {ppg_win.shape}")

        ptt_raw = rec.get("ptt_ms")
        ptt_ms = float(ptt_raw) if ptt_raw is not None else float("nan")
        if self.normalise:
            ecg_win = _zscore(_bandpass(ecg_win, ECG_FS, 0.5, 40.0))
            ppg_win = _zscore(_bandpass(ppg_win, PPG_FS, 0.5, 8.0))
        return {
            "ecg": torch.from_numpy(ecg_win).unsqueeze(0),
            "ppg": torch.from_numpy(ppg_win).unsqueeze(0),
            "subject_id": rec["subject_id"],
            "ptt_ms": ptt_ms,
        }
def split_by_subject(
    subjects: Iterable[str], frac: float = 0.9, seed: int = 0
) -> tuple[set[str], set[str]]:
    """Deterministic subject-level split into disjoint (train, test) id sets."""
    subjects = sorted(set(subjects))
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(subjects))
    cut = int(len(subjects) * frac)
    train = {subjects[i] for i in perm[:cut]}
    test = {subjects[i] for i in perm[cut:]}
    return train, test
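# Sketch of wiring the split back into a dataset view (assumes an index JSON
# already exists at some index_path; not a prescribed workflow):
#   subjects = {r["subject_id"] for r in json.loads(index_path.read_text())}
#   train_subj, test_subj = split_by_subject(subjects, frac=0.9, seed=0)
#   train_ds = MIMICAlignedDataset(shard_roots, index_path, build_index=False,
#                                  subjects_allow=train_subj)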
def collate_with_dt(
    items: list[dict],
    log_uniform_frac: float = 0.6,
    dt_min_ms: float = 50.0,
    dt_max_ms: float = 500.0,
    rng: np.random.Generator | None = None,
) -> dict:
    """Stack a batch and sample delta_t: a `log_uniform_frac` share of the batch
    is drawn log-uniformly from [dt_min_ms, dt_max_ms]; the rest uses the measured
    window PTT, falling back to log-uniform where no estimate is available."""
    rng = rng if rng is not None else np.random.default_rng()
    ecg = torch.stack([b["ecg"] for b in items])
    ppg = torch.stack([b["ppg"] for b in items])
    ptts = np.array([b["ptt_ms"] for b in items], dtype=np.float32)
    n = len(items)
    dt_ms = np.empty(n, dtype=np.float32)
    use_log = rng.random(n) < log_uniform_frac
    log_lo, log_hi = np.log(dt_min_ms), np.log(dt_max_ms)
    dt_ms[use_log] = np.exp(rng.uniform(log_lo, log_hi, size=int(use_log.sum())))

    for i in np.nonzero(~use_log)[0]:
        if np.isfinite(ptts[i]):
            dt_ms[i] = ptts[i]
        else:
            dt_ms[i] = np.exp(rng.uniform(log_lo, log_hi))
    return {
        "ecg": ecg,
        "ppg": ppg,
        "dt_seconds": torch.from_numpy(dt_ms / 1000.0),
        "ptt_ms": torch.from_numpy(ptts),
        "subject_id": [b["subject_id"] for b in items],
    }
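

# Hedged usage sketch showing how the pieces fit together (dataset -> subject
# split -> DataLoader with the delta_t-sampling collate). The paths below are
# placeholders for a pre-downloaded snapshot, not part of this repository.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    shard_roots = sorted(Path("data/mimic_aligned").glob("shard_*"))
    index_path = Path("data/mimic_aligned/window_index.json")
    ds = MIMICAlignedDataset(shard_roots, index_path,
                             build_index=not index_path.exists())

    train_subj, _ = split_by_subject({r["subject_id"] for r in ds.index})
    train_ds = MIMICAlignedDataset(shard_roots, index_path, build_index=False,
                                   subjects_allow=train_subj)

    loader = DataLoader(train_ds, batch_size=8, shuffle=True,
                        collate_fn=collate_with_dt)
    batch = next(iter(loader))
    print(batch["ecg"].shape, batch["ppg"].shape, batch["dt_seconds"][:4])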