"""PyTorch Dataset over lucky9-cyou/mimic-iv-aligned-ppg-ecg.

For v1 we:
  - Keep only segments where ECG lead II is present (93.7% of data)
  - Extract lead II ECG and PPG Pleth
  - Window: 10 s slices at 5 s stride
  - Native rates: ECG 250 Hz, PPG 125 Hz -> ECG window 2500 samples, PPG 1250

Each item returns
    {ecg: [1, 2500], ppg: [1, 1250], subject_id,
     ptt_ms (per-window estimate cached in the index, may be NaN)}.

delta_t is NOT produced by the dataset: ``collate_with_dt`` samples it per
batch (60% log-uniform + 40% from the measured PTT where available).
"""
from __future__ import annotations

import json
import os
import re
from pathlib import Path
from typing import Iterable

import numpy as np
import torch
from scipy.signal import butter, filtfilt, find_peaks
from torch.utils.data import Dataset

try:
    from datasets import load_from_disk
except ImportError as _err:
    # ``datasets`` is only needed to read shards from disk.  Keep the pure
    # signal/collate helpers importable without it and fail loudly on use.
    _DATASETS_IMPORT_ERROR = _err

    def load_from_disk(*args, **kwargs):
        """Placeholder that reports the missing ``datasets`` dependency."""
        raise ImportError(
            "the 'datasets' package is required to load shards from disk"
        ) from _DATASETS_IMPORT_ERROR


ECG_FS = 250.0
PPG_FS = 125.0
WINDOW_SEC = 10.0
STRIDE_SEC = 5.0
ECG_WIN = int(ECG_FS * WINDOW_SEC)  # 2500
PPG_WIN = int(PPG_FS * WINDOW_SEC)  # 1250
ECG_STRIDE = int(ECG_FS * STRIDE_SEC)
PPG_STRIDE = int(PPG_FS * STRIDE_SEC)


def _parse_subject(record_name: str) -> str:
    """Return the subject directory (``pXXXXXXXX``) of a record path.

    Falls back to the first path component when the name does not match the
    ``pNNN/pNNNNNNNN/...`` layout.
    """
    m = re.match(r"p\d+/(p\d+)/", record_name)
    return m.group(1) if m else record_name.split("/")[0]


def _shard_index(shard_path: Path) -> int:
    """Return the integer N encoded in a shard directory named ``shard_N``.

    Raises:
        ValueError: if the directory name has no integer after the first "_".
    """
    try:
        return int(shard_path.name.split("_")[1])
    except (IndexError, ValueError) as err:
        raise ValueError(
            f"shard directory must be named 'shard_<int>', got {shard_path.name!r}"
        ) from err


def _bandpass(x: np.ndarray, fs: float, lo: float, hi: float, order: int = 3) -> np.ndarray:
    """Zero-phase Butterworth band-pass of ``x`` between ``lo`` and ``hi`` Hz."""
    ny = 0.5 * fs
    # The upper edge is clipped just below Nyquist so butter() accepts it.
    b, a = butter(order, [lo / ny, min(hi / ny, 0.99)], btype="band")
    return filtfilt(b, a, x, method="gust").astype(np.float32)


def _zscore(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Standardise ``x`` to zero mean / unit std (``eps`` guards flat input)."""
    m = x.mean()
    s = x.std() + eps
    return ((x - m) / s).astype(np.float32)


def _r_peaks(ecg: np.ndarray, fs: float) -> np.ndarray:
    """Return sample indices of R peaks: 5-15 Hz band-pass, squared
    derivative, 120 ms moving average, then thresholded peak picking."""
    x = _bandpass(ecg, fs, 5.0, 15.0)
    s = np.diff(x, prepend=x[:1]) ** 2
    w = max(int(0.12 * fs), 1)
    mwa = np.convolve(s, np.ones(w) / w, mode="same")
    thr = mwa.mean() + 0.5 * mwa.std()
    # 300 ms minimum spacing between accepted peaks.
    p, _ = find_peaks(mwa, height=thr, distance=int(0.3 * fs))
    return p


def _ppg_peaks(ppg: np.ndarray, fs: float) -> np.ndarray:
    """Return sample indices of PPG peaks after a 0.5-8 Hz band-pass."""
    x = _bandpass(ppg, fs, 0.5, 8.0)
    p, _ = find_peaks(
        x,
        distance=int(0.3 * fs),
        height=x.mean() + 0.3 * x.std(),
        prominence=0.1 * x.std(),
    )
    return p


def _window_ptt_ms(ecg_win: np.ndarray, ppg_win: np.ndarray) -> float:
    """Median PTT across beats in one window; np.nan if <3 clean beats.

    A beat counts as clean when exactly one PPG peak falls 50-500 ms after
    its R peak.
    """
    # Non-finite samples would turn the whole filtered window into NaN, so
    # skip the (relatively expensive) filtering up front.
    if not (np.isfinite(ecg_win).all() and np.isfinite(ppg_win).all()):
        return float("nan")
    r = _r_peaks(ecg_win, ECG_FS)
    p = _ppg_peaks(ppg_win, PPG_FS)
    if len(r) < 3 or len(p) < 3:
        return float("nan")
    r_t = r / ECG_FS
    p_t = p / PPG_FS
    ptts = []
    for rt in r_t:
        cand = p_t[(p_t >= rt + 0.050) & (p_t <= rt + 0.500)]
        if len(cand) == 1:
            ptts.append((cand[0] - rt) * 1000.0)
    if len(ptts) < 3:
        return float("nan")
    return float(np.median(ptts))


class MIMICAlignedDataset(Dataset):
    """Indexes windows across a set of cached shard directories.

    Args:
        shard_roots: list of "/shard_XXXXX" paths (pre-downloaded)
        index_path: where to cache the index (JSON:
            list[{shard_idx, row_idx, win_start_ecg, win_start_ppg,
            subject_id, ptt_ms}]); ``ptt_ms`` is null when no estimate
            could be made and absent when ``compute_ptt`` is False
        build_index: if True, scan and build/save the window index;
            if False, load existing index_path (built anyway if missing)
        normalise: if True, apply bandpass + zscore per window
        subjects_allow: if given, keep only windows of these subjects
        subset_frac: if < 1.0, keep a random fraction of the windows
        subset_seed: seed for the ``subset_frac`` draw
        compute_ptt: if True, estimate the per-window PTT while building the
            index (slower build, but ``__getitem__`` stays cheap)
    """

    def __init__(
        self,
        shard_roots: list[Path],
        index_path: Path,
        build_index: bool = True,
        normalise: bool = True,
        subjects_allow: set[str] | None = None,
        subset_frac: float = 1.0,
        subset_seed: int = 0,
        compute_ptt: bool = True,
    ):
        super().__init__()
        self.shard_roots = [Path(p) for p in shard_roots]
        self.index_path = Path(index_path)
        self.normalise = normalise
        self.subjects_allow = subjects_allow
        self.compute_ptt = compute_ptt
        if build_index or not self.index_path.exists():
            self._build_index()
        self.index = json.loads(self.index_path.read_text())
        if subjects_allow is not None:
            self.index = [r for r in self.index if r["subject_id"] in subjects_allow]
        # An empty index has nothing to subsample; rng.choice would raise on it.
        if subset_frac < 1.0 and self.index:
            rng = np.random.default_rng(subset_seed)
            n_keep = max(1, int(len(self.index) * subset_frac))
            keep = rng.choice(len(self.index), size=n_keep, replace=False)
            self.index = [self.index[i] for i in sorted(keep)]
        self._shard_cache: dict[int, object] = {}

    @staticmethod
    def _records_for_row(sidx: int, row_idx: int, row: dict, compute_ptt: bool = True) -> list[dict]:
        """Return the index records (one per full window) for one dataset row.

        Rows without ECG lead II, or too short for a single full window,
        yield an empty list.
        """
        names = list(row["ecg_names"])
        if "II" not in names:
            return []
        # require full windows only
        n_win = min(
            (int(row["ecg_siglen"]) - ECG_WIN) // ECG_STRIDE + 1,
            (int(row["ppg_siglen"]) - PPG_WIN) // PPG_STRIDE + 1,
        )
        if n_win <= 0:
            return []
        subject_id = _parse_subject(row["record_name"])
        ecg_lead = ppg = None
        if compute_ptt:
            # Convert only the two channels that are actually used.
            ecg_lead = np.asarray(row["ecg"][names.index("II")], dtype=np.float32)
            ppg = np.asarray(row["ppg"][0], dtype=np.float32)
        records = []
        for w in range(n_win):
            se = w * ECG_STRIDE
            sp = w * PPG_STRIDE
            rec = {
                "shard_idx": sidx,
                "row_idx": row_idx,
                "subject_id": subject_id,
                "win_start_ecg": se,
                "win_start_ppg": sp,
            }
            if compute_ptt:
                ecg_win = ecg_lead[se : se + ECG_WIN]
                ppg_win = ppg[sp : sp + PPG_WIN]
                ptt = float("nan")
                if ecg_win.shape[0] == ECG_WIN and ppg_win.shape[0] == PPG_WIN:
                    ptt = _window_ptt_ms(ecg_win, ppg_win)
                # Store null rather than NaN so the index stays valid JSON.
                rec["ptt_ms"] = ptt if np.isfinite(ptt) else None
            records.append(rec)
        return records

    def _build_index(self) -> None:
        """Scan every shard and write the window index to ``index_path``."""
        records = []
        for s_path in self.shard_roots:
            sidx = _shard_index(s_path)
            ds = load_from_disk(str(s_path))
            for row_idx, row in enumerate(ds):
                records.extend(self._records_for_row(sidx, row_idx, row, self.compute_ptt))
        self.index_path.parent.mkdir(parents=True, exist_ok=True)
        self.index_path.write_text(json.dumps(records))

    def _load_shard(self, sidx: int):
        """Return the (lazily loaded, cached) shard numbered ``sidx``.

        Raises:
            KeyError: if no directory in ``shard_roots`` carries that number.
        """
        if sidx not in self._shard_cache:
            for p in self.shard_roots:
                if _shard_index(p) == sidx:
                    self._shard_cache[sidx] = load_from_disk(str(p))
                    break
            else:
                raise KeyError(
                    f"index references shard {sidx}, which is not in shard_roots"
                )
        return self._shard_cache[sidx]

    def __len__(self) -> int:
        return len(self.index)

    def __getitem__(self, idx: int) -> dict:
        """Return one window as ``{ecg, ppg, subject_id, ptt_ms}``.

        Raises:
            RuntimeError: if the stored signals are too short for the window.
        """
        rec = self.index[idx]
        ds = self._load_shard(rec["shard_idx"])
        row = ds[rec["row_idx"]]
        names = list(row["ecg_names"])
        se = rec["win_start_ecg"]
        sp = rec["win_start_ppg"]
        # Slice first, convert second: only the 10 s window becomes an array
        # (np.array copies, so the tensors never alias the cached row).
        ecg_win = np.array(row["ecg"][names.index("II")][se : se + ECG_WIN], dtype=np.float32)
        ppg_win = np.array(row["ppg"][0][sp : sp + PPG_WIN], dtype=np.float32)
        if ecg_win.shape[0] != ECG_WIN or ppg_win.shape[0] != PPG_WIN:
            raise RuntimeError(f"bad window at idx {idx}: {ecg_win.shape}, {ppg_win.shape}")

        # PTT is computed ONLY at index-build time (cached in the index dict).
        # __getitem__ stays cheap so the GPU isn't waiting on peak detection.
        cached_ptt = rec.get("ptt_ms")
        ptt_ms = float("nan") if cached_ptt is None else float(cached_ptt)

        if self.normalise:
            ecg_win = _zscore(_bandpass(ecg_win, ECG_FS, 0.5, 40.0))
            ppg_win = _zscore(_bandpass(ppg_win, PPG_FS, 0.5, 8.0))

        return {
            "ecg": torch.from_numpy(ecg_win).unsqueeze(0),  # [1, 2500]
            "ppg": torch.from_numpy(ppg_win).unsqueeze(0),  # [1, 1250]
            "subject_id": rec["subject_id"],
            "ptt_ms": ptt_ms if np.isfinite(ptt_ms) else float("nan"),
        }


def split_by_subject(
    subjects: Iterable[str], frac: float = 0.9, seed: int = 0
) -> tuple[set[str], set[str]]:
    """Split the unique subjects into disjoint (train, test) sets.

    The first ``int(n * frac)`` subjects of a seeded permutation go to train,
    the rest to test.
    """
    unique = sorted(set(subjects))
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(unique))
    cut = int(len(unique) * frac)
    train = {unique[i] for i in perm[:cut]}
    test = {unique[i] for i in perm[cut:]}
    return train, test


def collate_with_dt(
    items: list[dict],
    log_uniform_frac: float = 0.6,
    dt_min_ms: float = 50.0,
    dt_max_ms: float = 500.0,
    rng: np.random.Generator | None = None,
) -> dict:
    """Stack a batch and sample Δt. 60% log-uniform, 40% measured PTT where available.

    Returns a dict with ``ecg``, ``ppg`` (stacked tensors), ``dt_seconds``,
    ``ptt_ms`` (float32 tensors) and ``subject_id`` (list).
    """
    rng = rng if rng is not None else np.random.default_rng()
    ecg = torch.stack([item["ecg"] for item in items])
    ppg = torch.stack([item["ppg"] for item in items])
    ptts = np.array([item["ptt_ms"] for item in items], dtype=np.float32)

    n_items = len(items)
    dt_ms = np.empty(n_items, dtype=np.float32)
    use_log = rng.random(n_items) < log_uniform_frac
    log_lo, log_hi = np.log(dt_min_ms), np.log(dt_max_ms)
    dt_ms[use_log] = np.exp(rng.uniform(log_lo, log_hi, size=int(use_log.sum())))

    # for the 40% branch: measured PTT when finite, else fallback to log-uniform
    has_ptt = np.isfinite(ptts)
    measured = ~use_log & has_ptt
    fallback = ~use_log & ~has_ptt
    dt_ms[measured] = ptts[measured]
    dt_ms[fallback] = np.exp(rng.uniform(log_lo, log_hi, size=int(fallback.sum())))

    return {
        "ecg": ecg,
        "ppg": ppg,
        "dt_seconds": torch.from_numpy(dt_ms / 1000.0),
        "ptt_ms": torch.from_numpy(ptts),
        "subject_id": [item["subject_id"] for item in items],
    }