File size: 4,859 Bytes
31e2456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Validate alignment using PPG foot (onset) rather than systolic peak.

Foot = minimum between two consecutive systolic peaks. This is the feature
that physiologically corresponds to the pulse arrival time. Using it should
collapse the bimodal PTT distribution.
"""
from __future__ import annotations

import json
import os
import random
import re
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from dotenv import load_dotenv
from scipy.signal import butter, filtfilt, find_peaks

# Pull variables from a .env file (if one is found) into the process environment.
load_dotenv()
# Populate HF_TOKEN from HUGGINGFACE_API_KEY unless HF_TOKEN is already set;
# when neither variable exists, HF_TOKEN ends up as the empty string.
# NOTE(review): this runs before the Hugging Face imports below, presumably so
# those libraries see the token when they are imported — confirm.
os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", ""))
from datasets import load_from_disk
from huggingface_hub import snapshot_download

# Dataset repository id handed to snapshot_download() in main().
REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"
# Report directory: "docs" next to the directory that contains this script.
OUT = Path(__file__).resolve().parent.parent / "docs"
# Figure directory inside OUT; created (with parents) as an import-time side effect.
FIG = OUT / "figures"
FIG.mkdir(parents=True, exist_ok=True)
# Seeded random generator. NOTE(review): not referenced by any function
# visible in this file — confirm whether it is still needed.
RNG = random.Random(7)


def bandpass(x, fs, lo, hi, order=3):
    """Apply a zero-phase Butterworth band-pass filter to ``x``.

    Args:
        x: Signal samples to filter.
        fs: Sampling rate of ``x``.
        lo: Lower cut-off frequency, in the same units as ``fs``.
        hi: Upper cut-off frequency, in the same units as ``fs``. The
            normalised upper edge is capped at 0.99 of Nyquist.
        order: Order passed to ``butter`` (default 3).

    Returns:
        The forward-backward filtered signal, with edges handled by
        Gustafsson's method (``method="gust"``).
    """
    nyquist = 0.5 * fs
    # Band edges as fractions of Nyquist; keep the upper edge strictly below
    # 1 so the digital filter design stays valid.
    band = [lo / nyquist, min(hi / nyquist, 0.99)]
    coeffs = butter(order, band, btype="band")
    return filtfilt(*coeffs, x, method="gust")


def r_peaks(ecg, fs):
    """Locate R-peak sample indices in a single-lead ECG trace.

    The trace is band-passed to 5-15 Hz, differentiated and squared, then
    smoothed with a 0.12 s moving average. Maxima of that envelope that lie
    above ``mean + 0.5 * std`` and are at least 0.3 s apart are taken as beat
    candidates. Each candidate is finally moved to the largest band-passed
    sample within roughly 0.06 s on either side.

    Args:
        ecg: 1-D array of ECG samples.
        fs: Sampling rate in Hz.

    Returns:
        Integer array of sample indices into ``ecg``. The array is
        integer-typed even when no beat is found, so it can always be used
        for indexing and matches what ``ppg_feet`` returns.
    """
    x = bandpass(ecg, fs, 5.0, 15.0)
    energy = np.diff(x, prepend=x[:1]) ** 2
    w = max(int(0.12 * fs), 1)
    envelope = np.convolve(energy, np.ones(w) / w, mode="same")
    thr = envelope.mean() + 0.5 * envelope.std()
    candidates, _ = find_peaks(envelope, height=thr, distance=int(0.3 * fs))
    snap = max(int(0.06 * fs), 1)
    refined = []
    for q in candidates:
        # Clip the search window to the signal so candidates near the edges
        # still resolve to a valid index.
        start = max(0, q - snap)
        stop = min(len(x), q + snap)
        refined.append(start + int(np.argmax(x[start:stop])))
    # An explicit dtype keeps the empty case from defaulting to float64.
    return np.asarray(refined, dtype=int)


def ppg_feet(ppg, fs):
    """Locate PPG pulse feet as the minimum between consecutive systolic peaks.

    The signal is band-passed to 0.5-8 Hz and systolic peaks are detected on
    the filtered trace (at least 0.3 s apart, height above
    ``mean + 0.3 * std``, prominence of at least ``0.1 * std``). For every
    pair of neighbouring peaks, the foot is the index of the smallest
    filtered sample from the earlier peak up to, but not including, the
    later one.

    Args:
        ppg: PPG samples.
        fs: Sampling rate in Hz.

    Returns:
        Integer array of foot sample indices, one per pair of consecutive
        peaks; empty when fewer than two peaks are found.
    """
    filtered = bandpass(ppg, fs, 0.5, 8.0)
    spread = filtered.std()
    systolic, _ = find_peaks(
        filtered,
        distance=int(0.3 * fs),
        height=filtered.mean() + 0.3 * spread,
        prominence=0.1 * spread,
    )
    feet = [
        start + int(np.argmin(filtered[start:stop]))
        for start, stop in zip(systolic[:-1], systolic[1:])
    ]
    return np.asarray(feet, dtype=int)


def clean_ptts_via_foot(ecg, ecg_fs, ppg, ppg_fs, t0e, t0p):
    """Measure unambiguous R-peak to PPG-foot delays for one segment.

    Args:
        ecg: ECG samples passed to ``r_peaks``.
        ecg_fs: ECG sampling rate in Hz.
        ppg: PPG samples passed to ``ppg_feet``.
        ppg_fs: PPG sampling rate in Hz.
        t0e: Time in seconds of the first ECG sample.
        t0p: Time in seconds of the first PPG sample.

    Returns:
        List of delays in milliseconds, one for each R-peak that has exactly
        one PPG foot between 50 ms and 500 ms after it. Empty when fewer
        than three R-peaks or fewer than three feet are detected.
    """
    beat_idx = r_peaks(ecg, ecg_fs)
    foot_idx = ppg_feet(ppg, ppg_fs)
    if min(len(beat_idx), len(foot_idx)) < 3:
        return []

    # Put both event trains on the shared time axis, in seconds.
    beat_times = t0e + beat_idx / ecg_fs
    foot_times = t0p + foot_idx / ppg_fs

    delays_ms = []
    for beat_time in beat_times:
        in_window = (foot_times >= beat_time + 0.050) & (foot_times <= beat_time + 0.500)
        matches = foot_times[in_window]
        # Skip beats with no foot or several feet in the window: the pairing
        # would be ambiguous.
        if len(matches) != 1:
            continue
        delays_ms.append((matches[0] - beat_time) * 1000.0)
    return delays_ms


def main():
    """Sample the dataset, measure R-peak to PPG-foot delays, and report them.

    Downloads every 20th shard index in ``range(0, 412)`` from ``REPO`` and
    keeps the shards that contain a ``dataset_info.json``. Up to 30 rows per
    shard are inspected. A row counts as a good segment when it has ECG lead
    "II" and ``clean_ptts_via_foot`` returns at least 5 delays for that lead
    and the first PPG row. Collection stops after 100 good segments.

    Writes ``ptt_histogram_foot.png`` under ``FIG`` and
    ``e0_alignment.json`` under ``OUT``, and prints the JSON summary.

    Raises:
        SystemExit: If no segment qualifies. Without this guard the summary
            statistics would be computed on empty lists.
    """
    max_segments = 100  # stop once this many segments produced usable delays
    rows_per_shard = 30  # inspect at most this many rows of each shard
    min_beats = 5  # fewest clean beats for a segment to count

    want = list(range(0, 412, 20))
    root = Path(
        snapshot_download(
            REPO,
            repo_type="dataset",
            allow_patterns=[f"shard_{i:05d}/*" for i in want],
            max_workers=12,
        )
    )
    # Only shards that were actually downloaded can be loaded.
    shards = [s for s in want if (root / f"shard_{s:05d}" / "dataset_info.json").exists()]
    all_ptts = []
    stds = []
    good = 0
    for sidx in shards:
        if good >= max_segments:
            break
        ds = load_from_disk(str(root / f"shard_{sidx:05d}"))
        for i in range(min(len(ds), rows_per_shard)):
            if good >= max_segments:
                break
            row = ds[i]
            ecg = np.asarray(row["ecg"], dtype=np.float32)
            ppg = np.asarray(row["ppg"], dtype=np.float32)
            names = list(row["ecg_names"])
            if "II" not in names:
                continue
            lead = ecg[names.index("II")]
            ptts = clean_ptts_via_foot(
                lead,
                float(row["ecg_fs"]),
                ppg[0],
                float(row["ppg_fs"]),
                float(row["ecg_time_s"][0]),
                float(row["ppg_time_s"][0]),
            )
            if len(ptts) >= min_beats:
                all_ptts.extend(ptts)
                stds.append(float(np.std(ptts)))
                good += 1

    if not all_ptts:
        # Fail with a clear message instead of letting the statistics below
        # run on empty input and emit NaN or an opaque numpy error.
        raise SystemExit(
            f"no segment with at least {min_beats} clean beats found in "
            f"{len(shards)} shard(s); nothing to report"
        )

    res = {
        "n_clean_beats": len(all_ptts),
        "n_good_segments": good,
        "ptt_foot_median_ms": float(np.median(all_ptts)),
        "ptt_foot_p5_ms": float(np.percentile(all_ptts, 5)),
        "ptt_foot_p95_ms": float(np.percentile(all_ptts, 95)),
        "within_segment_std_median_ms": float(np.median(stds)),
        "within_segment_std_p90_ms": float(np.percentile(stds, 90)),
    }
    plt.figure(figsize=(7, 4))
    plt.hist(all_ptts, bins=60, color="#36a", edgecolor="black")
    plt.axvline(100, color="red", linestyle="--", alpha=0.5, label="100 ms")
    plt.axvline(400, color="red", linestyle="--", alpha=0.5, label="400 ms")
    plt.xlabel("PTT (ECG R-peak → PPG foot) (ms)")
    plt.ylabel("count")
    plt.title(f"PTT via PPG foot — {len(all_ptts)} beats, {good} segments")
    plt.legend()
    plt.tight_layout()
    plt.savefig(FIG / "ptt_histogram_foot.png", dpi=120)
    plt.close()
    (OUT / "e0_alignment.json").write_text(json.dumps(res, indent=2))
    print(json.dumps(res, indent=2))

# Run the validation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()