File size: 4,859 Bytes
"""Validate ECG-PPG alignment using the PPG foot (pulse onset) instead of the systolic peak.
The foot is taken as the minimum of the band-passed PPG between two consecutive
systolic peaks; physiologically, it is the feature that corresponds to the pulse
arrival time. Measuring PTT from the ECG R-peak to the PPG foot is therefore
expected to collapse the bimodal PTT distribution.
"""
from __future__ import annotations
import json
import os
import random
import re
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
from scipy.signal import butter, filtfilt, find_peaks
# Load variables from a local .env file (if present). This runs before the
# `datasets` / `huggingface_hub` imports below — presumably so any settings
# those libraries read from the environment are already in place; confirm.
load_dotenv()
# Let HUGGINGFACE_API_KEY act as a fallback for HF_TOKEN, but only when it is
# actually set and non-empty, so an unset key never exports HF_TOKEN="".
# An explicitly set HF_TOKEN always wins (setdefault).
if os.environ.get("HUGGINGFACE_API_KEY"):
    os.environ.setdefault("HF_TOKEN", os.environ["HUGGINGFACE_API_KEY"])
from datasets import load_from_disk
from huggingface_hub import snapshot_download
# Hugging Face dataset repository that holds the aligned PPG/ECG shards.
REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"
# Results go to <this file's grandparent directory>/docs; figures go to docs/figures.
OUT = Path(__file__).resolve().parent.parent.joinpath("docs")
FIG = OUT.joinpath("figures")
# Created at import time so later savefig calls never hit a missing directory.
FIG.mkdir(exist_ok=True, parents=True)
# Fixed-seed RNG (seed 7) for reproducibility; not referenced by the code shown here.
RNG = random.Random(7)
def bandpass(x, fs, lo, hi, order=3):
    """Apply a zero-phase Butterworth band-pass filter.

    Args:
        x: Signal to filter (filtered along its last axis).
        fs: Sampling rate in Hz.
        lo: Lower pass-band edge in Hz.
        hi: Upper pass-band edge in Hz; capped at 0.99 x Nyquist so the
            normalised edge stays strictly below 1 as ``butter`` requires.
        order: Design order passed to ``scipy.signal.butter``.

    Returns:
        The signal run forward and backward through the filter (``filtfilt``),
        using Gustafsson's method for the initial conditions.
    """
    nyquist = fs / 2.0
    edges = [lo / nyquist, min(hi / nyquist, 0.99)]
    num, den = butter(order, edges, btype="band")
    return filtfilt(num, den, x, method="gust")
def r_peaks(ecg, fs):
    """Detect ECG R-peaks with an energy-envelope detector.

    Steps:
      1. Band-pass the ECG to 5-15 Hz.
      2. Square its first difference.
      3. Smooth with a moving average of ``max(int(0.12*fs), 1)`` samples.
      4. Keep envelope peaks above ``mean + 0.5*std`` that are at least
         ``int(0.3*fs)`` samples apart.
      5. Snap each detection to the maximum of the band-passed signal in
         ``[q - snap, q + snap)``, where ``snap = max(int(0.06*fs), 1)``.

    Args:
        ecg: Single-lead ECG samples.
        fs: ECG sampling rate in Hz.

    Returns:
        Integer array of sample indices into ``ecg``. The dtype is int even
        when nothing is detected.
    """
    x = bandpass(ecg, fs, 5.0, 15.0)
    energy = np.diff(x, prepend=x[:1]) ** 2
    w = max(int(0.12 * fs), 1)
    mwa = np.convolve(energy, np.ones(w) / w, mode="same")
    thr = mwa.mean() + 0.5 * mwa.std()
    # find_peaks rejects distance < 1, which int(0.3 * fs) yields for fs < 10/3 Hz.
    p, _ = find_peaks(mwa, height=thr, distance=max(int(0.3 * fs), 1))
    snap = max(int(0.06 * fs), 1)
    n = len(x)
    snapped = []
    for q in p:
        # The envelope peak is smeared by the moving average; refine to the
        # local maximum of the band-passed signal around it.
        start = max(0, q - snap)
        snapped.append(start + int(np.argmax(x[start : min(n, q + snap)])))
    # dtype=int keeps empty results integer-typed, matching ppg_feet.
    return np.asarray(snapped, dtype=int)
def ppg_feet(ppg, fs):
    """Detect PPG pulse feet as the minimum between consecutive systolic peaks.

    The PPG is band-passed to 0.5-8 Hz. Systolic peaks are local maxima that
    satisfy all of:
      * at least ``int(0.3*fs)`` samples apart;
      * higher than ``mean + 0.3*std``;
      * prominence of at least ``0.1*std``.
    For each pair of consecutive peaks, the foot is the arg-min of the filtered
    signal over ``[prev_peak, next_peak)``.

    Args:
        ppg: Single-channel PPG samples.
        fs: PPG sampling rate in Hz.

    Returns:
        Strictly increasing integer sample indices, one per consecutive peak
        pair (``len(peaks) - 1`` feet; empty if fewer than two peaks are found).
    """
    x = bandpass(ppg, fs, 0.5, 8.0)
    sd = x.std()
    # Systolic peaks first; find_peaks rejects distance < 1, hence the max().
    peaks, _ = find_peaks(
        x,
        distance=max(int(0.3 * fs), 1),
        height=x.mean() + 0.3 * sd,
        prominence=0.1 * sd,
    )
    feet = [lo + int(np.argmin(x[lo:hi])) for lo, hi in zip(peaks[:-1], peaks[1:])]
    return np.asarray(feet, dtype=int)
def clean_ptts_via_foot(
    ecg, ecg_fs, ppg, ppg_fs, t0e, t0p, *, min_ptt_s=0.050, max_ptt_s=0.500
):
    """Compute per-beat pulse transit times (ms) from ECG R-peaks to PPG feet.

    R-peak and foot sample indices are converted to times on a shared clock
    using each array's start time (``t0e`` for ECG, ``t0p`` for PPG) and its
    sampling rate.

    A beat is kept only when exactly one foot lies in
    ``[R + min_ptt_s, R + max_ptt_s]``. Beats with no candidate, or with an
    ambiguous multi-candidate window, are dropped.

    Args:
        ecg: Single-lead ECG samples.
        ecg_fs: ECG sampling rate in Hz.
        ppg: Single-channel PPG samples.
        ppg_fs: PPG sampling rate in Hz.
        t0e: Time (s) of the first ECG sample.
        t0p: Time (s) of the first PPG sample.
        min_ptt_s: Lower bound of the acceptance window, in seconds.
        max_ptt_s: Upper bound of the acceptance window, in seconds.

    Returns:
        List of PTTs in milliseconds. Empty when fewer than 3 R-peaks or fewer
        than 3 feet are detected.
    """
    r = r_peaks(ecg, ecg_fs)
    f = ppg_feet(ppg, ppg_fs)
    if len(r) < 3 or len(f) < 3:
        return []
    r_t = t0e + r / ecg_fs
    # ppg_feet already returns increasing indices; sorting guarantees the
    # precondition searchsorted relies on at negligible cost.
    f_t = np.sort(t0p + f / ppg_fs)
    # first[k]: first foot with f_t >= R + min.
    # stop[k]: first foot with f_t > R + max.
    # So stop - first counts the feet inside the closed window.
    first = np.searchsorted(f_t, r_t + min_ptt_s, side="left")
    stop = np.searchsorted(f_t, r_t + max_ptt_s, side="right")
    return [
        (f_t[i] - rt) * 1000.0
        for rt, i, j in zip(r_t, first, stop)
        if j - i == 1
    ]
def _download_shards(shard_ids):
    """Download the given shards of REPO; return (local root, ids present on disk)."""
    root = Path(
        snapshot_download(
            REPO,
            repo_type="dataset",
            allow_patterns=[f"shard_{i:05d}/*" for i in shard_ids],
            max_workers=12,
        )
    )
    # Keep only shards whose saved-dataset metadata file actually arrived.
    present = [
        s for s in shard_ids if (root / f"shard_{s:05d}" / "dataset_info.json").exists()
    ]
    return root, present


def _iter_rows(root, shard_ids, rows_per_shard):
    """Yield the first ``rows_per_shard`` rows of each shard, loading shards lazily."""
    for sidx in shard_ids:
        ds = load_from_disk(str(root / f"shard_{sidx:05d}"))
        for i in range(min(len(ds), rows_per_shard)):
            yield ds[i]


def _segment_ptts(row):
    """Return foot-based PTTs (ms) for one dataset row, or [] if ECG lead II is absent."""
    names = list(row["ecg_names"])
    if "II" not in names:
        return []
    ecg = np.asarray(row["ecg"], dtype=np.float32)
    ppg = np.asarray(row["ppg"], dtype=np.float32)
    return clean_ptts_via_foot(
        ecg[names.index("II")],
        float(row["ecg_fs"]),
        # NOTE(review): assumes ppg is (channels, samples) and channel 0 is
        # the intended PPG trace — confirm against the dataset schema.
        ppg[0],
        float(row["ppg_fs"]),
        float(row["ecg_time_s"][0]),
        float(row["ppg_time_s"][0]),
    )


def _summarise(all_ptts, stds, good):
    """Build the JSON-serialisable summary of pooled PTTs and per-segment spread."""
    return {
        "n_clean_beats": len(all_ptts),
        "n_good_segments": good,
        "ptt_foot_median_ms": float(np.median(all_ptts)),
        "ptt_foot_p5_ms": float(np.percentile(all_ptts, 5)),
        "ptt_foot_p95_ms": float(np.percentile(all_ptts, 95)),
        "within_segment_std_median_ms": float(np.median(stds)),
        "within_segment_std_p90_ms": float(np.percentile(stds, 90)),
    }


def _plot_histogram(all_ptts, good, path):
    """Save a 60-bin histogram of pooled PTTs with 100/400 ms reference lines."""
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.hist(all_ptts, bins=60, color="#36a", edgecolor="black")
    ax.axvline(100, color="red", linestyle="--", alpha=0.5, label="100 ms")
    ax.axvline(400, color="red", linestyle="--", alpha=0.5, label="400 ms")
    ax.set_xlabel("PTT (ECG R-peak → PPG foot) (ms)")
    ax.set_ylabel("count")
    ax.set_title(f"PTT via PPG foot — {len(all_ptts)} beats, {good} segments")
    ax.legend()
    fig.tight_layout()
    fig.savefig(path, dpi=120)
    plt.close(fig)


def main():
    """Estimate the foot-based PTT distribution on a sample of dataset shards.

    Steps:
      1. Download shard ids ``range(0, 412, 20)``.
      2. Inspect up to 30 rows per shard.
      3. Keep segments that yield at least 5 clean beats, stopping after 100
         such segments.
    Writes ``docs/e0_alignment.json`` and ``docs/figures/ptt_histogram_foot.png``,
    then prints the JSON summary.

    Raises:
        RuntimeError: if no segment qualifies; there is nothing to summarise,
            and numpy's median/percentile cannot summarise an empty list.
    """
    max_segments = 100  # stop once this many usable segments are collected
    rows_per_shard = 30  # rows inspected per shard
    min_beats = 5  # a segment needs at least this many clean beats

    root, shards = _download_shards(list(range(0, 412, 20)))
    all_ptts, stds = [], []
    good = 0
    for row in _iter_rows(root, shards, rows_per_shard):
        ptts = _segment_ptts(row)
        if len(ptts) < min_beats:
            continue
        all_ptts.extend(ptts)
        stds.append(float(np.std(ptts)))
        good += 1
        if good >= max_segments:
            break

    if not all_ptts:
        raise RuntimeError(
            f"no segment yielded >= {min_beats} clean beats across "
            f"{len(shards)} downloaded shard(s); nothing to summarise"
        )

    res = _summarise(all_ptts, stds, good)
    _plot_histogram(all_ptts, good, FIG / "ptt_histogram_foot.png")
    text = json.dumps(res, indent=2)
    (OUT / "e0_alignment.json").write_text(text)
    print(text)
if __name__ == "__main__":
    # Run the validation only when executed as a script, not when imported.
    main()
|