"""Validate alignment using the PPG foot (onset) rather than the systolic peak.

Foot = minimum of the band-passed PPG between two consecutive systolic peaks.
This is the feature that physiologically corresponds to the pulse arrival
time; using it should collapse the bimodal PTT distribution.

When run as a script, a subset of shards of the aligned MIMIC-IV PPG/ECG
dataset is fetched, R-peak -> PPG-foot delays are measured per segment, and
the summary is written to ``docs/e0_alignment.json`` together with a
histogram at ``docs/figures/ptt_histogram_foot.png``.

Only numpy/scipy are needed at import time. Environment setup (``.env``,
Hugging Face token), matplotlib and the dataset libraries are loaded inside
``main()``, so the signal-processing helpers can be imported on their own.
"""

from __future__ import annotations

import json
import os
import random
import re
from pathlib import Path

import numpy as np
from scipy.signal import butter, filtfilt, find_peaks

REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"
OUT = Path(__file__).resolve().parent.parent / "docs"
FIG = OUT / "figures"
RNG = random.Random(7)

# Shard ids requested from the dataset repo (every 20th id below 412).
SHARD_IDS = tuple(range(0, 412, 20))
ROWS_PER_SHARD = 30  # at most this many rows are inspected per shard
MAX_GOOD_SEGMENTS = 100  # stop once this many segments contributed beats
MIN_BEATS_PER_SEGMENT = 5  # segments with fewer unambiguous beats are dropped


def _at_least_one_sample(seconds: float, fs: float) -> int:
    """Convert a duration to a sample count, never below 1.

    ``find_peaks`` rejects ``distance < 1`` and a zero-length moving-average
    window is meaningless, so very low sampling rates are clamped.
    """
    return max(int(seconds * fs), 1)


def bandpass(x, fs: float, lo: float, hi: float, order: int = 3) -> np.ndarray:
    """Zero-phase Butterworth band-pass of *x* between *lo* and *hi* Hz.

    Args:
        x: 1-D signal.
        fs: sampling rate in Hz.
        lo, hi: pass-band edges in Hz. *hi* is clamped to 0.99 x Nyquist so a
            value at or above Nyquist still gives a valid design.
        order: Butterworth order of each (forward/backward) pass.

    Returns:
        The filtered signal, same length as *x*. ``filtfilt`` runs with
        Gustafsson's method, which chooses initial conditions instead of
        padding the signal.

    Raises:
        ValueError: from scipy when the normalized band is invalid
            (e.g. ``lo >= fs / 2``).
    """
    nyquist = 0.5 * fs
    b, a = butter(order, [lo / nyquist, min(hi / nyquist, 0.99)], btype="band")
    return filtfilt(b, a, x, method="gust")


def r_peaks(ecg, fs: float) -> np.ndarray:
    """Return sample indices (int array) of ECG R-peaks.

    Simplified Pan-Tompkins-style detector:
      1. band-pass 5-15 Hz;
      2. squared first difference;
      3. 120 ms moving-window average;
      4. peaks of that envelope above ``mean + 0.5 * std``, at least 300 ms apart.

    Each envelope peak is then snapped to the maximum of the band-passed ECG
    within roughly +-60 ms.

    NOTE(review): snapping to the maximum assumes an upright R wave in the
    supplied lead.
    """
    x = bandpass(ecg, fs, 5.0, 15.0)
    slope_energy = np.diff(x, prepend=x[:1]) ** 2
    win = _at_least_one_sample(0.12, fs)
    envelope = np.convolve(slope_energy, np.ones(win) / win, mode="same")
    threshold = envelope.mean() + 0.5 * envelope.std()
    coarse, _ = find_peaks(
        envelope, height=threshold, distance=_at_least_one_sample(0.3, fs)
    )

    snap = _at_least_one_sample(0.06, fs)
    refined = []
    for q in coarse:
        start, stop = max(0, q - snap), min(len(x), q + snap)
        refined.append(start + int(np.argmax(x[start:stop])))
    return np.asarray(refined, dtype=int)


def ppg_feet(ppg, fs: float) -> np.ndarray:
    """Return sample indices (int array) of PPG feet (pulse onsets).

    The PPG is band-passed to 0.5-8 Hz. Systolic peaks are then detected with
    three gates: height above ``mean + 0.3 * std``, prominence of at least
    ``0.1 * std``, and spacing of at least 300 ms.

    Each foot is the minimum of the filtered signal between two consecutive
    systolic peaks. At most ``n_peaks - 1`` feet are returned, and an onset
    before the first detected peak is never reported.
    """
    x = bandpass(ppg, fs, 0.5, 8.0)
    peaks, _ = find_peaks(
        x,
        distance=_at_least_one_sample(0.3, fs),
        height=x.mean() + 0.3 * x.std(),
        prominence=0.1 * x.std(),
    )
    feet = [lo + int(np.argmin(x[lo:hi])) for lo, hi in zip(peaks[:-1], peaks[1:])]
    return np.asarray(feet, dtype=int)


def clean_ptts_via_foot(
    ecg,
    ecg_fs,
    ppg,
    ppg_fs,
    t0e,
    t0p,
    *,
    min_ptt_s: float = 0.050,
    max_ptt_s: float = 0.500,
) -> list[float]:
    """Return unambiguous per-beat PTTs (ms) from ECG R-peak to PPG foot.

    Args:
        ecg, ecg_fs: one ECG lead and its sampling rate (Hz).
        ppg, ppg_fs: one PPG channel and its sampling rate (Hz).
        t0e, t0p: time (s) of the first ECG / PPG sample. They place both sets
            of fiducials on a common time axis.
        min_ptt_s, max_ptt_s: inclusive window after each R-peak in which a
            foot is accepted.

    Returns:
        PTTs in milliseconds, one per R-peak that has exactly one PPG foot in
        the window. Beats with zero or several candidates are skipped. The
        list is empty when fewer than 3 R-peaks or 3 feet are detected.
    """
    r = r_peaks(ecg, ecg_fs)
    f = ppg_feet(ppg, ppg_fs)
    if len(r) < 3 or len(f) < 3:
        return []
    r_t = t0e + r / ecg_fs
    f_t = t0p + f / ppg_fs

    out = []
    for rt in r_t:
        cand = f_t[(f_t >= rt + min_ptt_s) & (f_t <= rt + max_ptt_s)]
        if len(cand) == 1:  # ambiguous / missing matches are discarded
            out.append(float((cand[0] - rt) * 1000.0))
    return out


def _configure_env() -> None:
    """Load ``.env`` and expose HUGGINGFACE_API_KEY as HF_TOKEN if unset."""
    from dotenv import load_dotenv

    load_dotenv()
    api_key = os.environ.get("HUGGINGFACE_API_KEY")
    if api_key:  # never install an empty token
        os.environ.setdefault("HF_TOKEN", api_key)


def _download_shards(shard_ids) -> list[Path]:
    """Fetch the requested shards and return the directories that hold a dataset."""
    from huggingface_hub import snapshot_download

    root = Path(
        snapshot_download(
            REPO,
            repo_type="dataset",
            allow_patterns=[f"shard_{i:05d}/*" for i in shard_ids],
            max_workers=12,
        )
    )
    shard_dirs = (root / f"shard_{i:05d}" for i in shard_ids)
    # A shard is usable only if the snapshot contains a saved dataset.
    return [d for d in shard_dirs if (d / "dataset_info.json").exists()]


def _iter_rows(shard_dirs, rows_per_shard: int):
    """Yield up to *rows_per_shard* rows per shard, loading shards lazily."""
    from datasets import load_from_disk

    for shard_dir in shard_dirs:
        ds = load_from_disk(str(shard_dir))
        for i in range(min(len(ds), rows_per_shard)):
            yield ds[i]


def _segment_ptts(row) -> list[float]:
    """PTTs for one dataset row using ECG lead II and the first PPG channel.

    Returns an empty list when the row has no lead II.
    """
    names = list(row["ecg_names"])
    if "II" not in names:
        return []
    ecg = np.asarray(row["ecg"], dtype=np.float32)
    ppg = np.asarray(row["ppg"], dtype=np.float32)
    return clean_ptts_via_foot(
        ecg[names.index("II")],
        float(row["ecg_fs"]),
        ppg[0],
        float(row["ppg_fs"]),
        float(row["ecg_time_s"][0]),
        float(row["ppg_time_s"][0]),
    )


def _summarize(all_ptts, stds, good: int) -> dict:
    """Aggregate per-beat PTTs and per-segment stds into the JSON summary.

    Statistics over an empty sample are reported as ``None`` (JSON ``null``).
    Without this guard, ``np.median`` of an empty array would give NaN, which
    ``json.dumps`` writes as the non-standard token ``NaN``, and
    ``np.percentile`` would error out on the empty input.
    """

    def stat(values, q=None):
        if len(values) == 0:
            return None
        return float(np.median(values) if q is None else np.percentile(values, q))

    return {
        "n_clean_beats": len(all_ptts),
        "n_good_segments": good,
        "ptt_foot_median_ms": stat(all_ptts),
        "ptt_foot_p5_ms": stat(all_ptts, 5),
        "ptt_foot_p95_ms": stat(all_ptts, 95),
        "within_segment_std_median_ms": stat(stds),
        "within_segment_std_p90_ms": stat(stds, 90),
    }


def _plot_histogram(all_ptts, good: int, path: Path) -> None:
    """Save a histogram of all PTTs with 100/400 ms reference lines to *path*."""
    import matplotlib

    matplotlib.use("Agg")  # headless backend, selected before pyplot is used
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(7, 4))
    try:
        if all_ptts:  # nothing to bin for an empty sample
            ax.hist(all_ptts, bins=60, color="#36a", edgecolor="black")
        ax.axvline(100, color="red", linestyle="--", alpha=0.5, label="100 ms")
        ax.axvline(400, color="red", linestyle="--", alpha=0.5, label="400 ms")
        ax.set_xlabel("PTT (ECG R-peak → PPG foot) (ms)")
        ax.set_ylabel("count")
        ax.set_title(f"PTT via PPG foot — {len(all_ptts)} beats, {good} segments")
        ax.legend()
        fig.tight_layout()
        fig.savefig(path, dpi=120)
    finally:
        plt.close(fig)  # release the figure even if saving fails


def main():
    """Measure foot-based PTTs on the sampled shards and write JSON + histogram."""
    _configure_env()
    FIG.mkdir(parents=True, exist_ok=True)

    shard_dirs = _download_shards(SHARD_IDS)
    all_ptts: list[float] = []
    stds: list[float] = []
    good = 0
    for row in _iter_rows(shard_dirs, ROWS_PER_SHARD):
        ptts = _segment_ptts(row)
        if len(ptts) < MIN_BEATS_PER_SEGMENT:
            continue
        all_ptts.extend(ptts)
        stds.append(float(np.std(ptts)))
        good += 1
        if good >= MAX_GOOD_SEGMENTS:
            break  # stops the lazy generator: no further rows/shards are loaded

    res = _summarize(all_ptts, stds, good)
    _plot_histogram(all_ptts, good, FIG / "ptt_histogram_foot.png")
    (OUT / "e0_alignment.json").write_text(json.dumps(res, indent=2), encoding="utf-8")
    print(json.dumps(res, indent=2))


if __name__ == "__main__":
    main()