# Cardiac_Monitor_API / ml_src / feature_extractor.py
# Sanuka0523 - "Deploy cardiac monitor FastAPI backend" (commit 11e9a40)
"""
ECG feature extraction for the XGBoost model.

Extracts 25 signal features from a single-lead ECG using NeuroKit2.
This module is shared between ml/ training and backend/ inference.
"""
import logging

import neurokit2 as nk
import numpy as np
from scipy.stats import kurtosis, skew
def _is_valid_index(value) -> bool:
    """Return True if *value* is a real, non-NaN Python or NumPy scalar."""
    # np.integer must be listed explicitly: np.int64 is NOT a subclass of the
    # Python int, so a plain (int, float) check silently rejects it.
    return (
        isinstance(value, (int, float, np.integer, np.floating))
        and not np.isnan(value)
    )


def _mean_interval_ms(starts, ends, sample_rate):
    """Mean |end - start| in ms over positions where both landmarks are valid.

    Landmarks are paired by list position so that a missing value in one list
    cannot shift the pairing onto a different beat. Returns None when no
    position has both landmarks.
    """
    durations = [
        abs(end - start) / sample_rate * 1000
        for start, end in zip(starts, ends)
        if _is_valid_index(start) and _is_valid_index(end)
    ]
    return float(np.mean(durations)) if durations else None


def _mean_amplitude_at(signal, indices):
    """Mean of *signal* at the valid, in-range sample indices, or None."""
    amplitudes = [
        signal[int(idx)]
        for idx in indices
        if _is_valid_index(idx) and int(idx) < len(signal)
    ]
    return float(np.mean(amplitudes)) if amplitudes else None


def extract_ecg_features(ecg_signal: np.ndarray, sample_rate: int = 100,
                         heart_rate_sensor: float = None,
                         spo2: float = None) -> dict:
    """
    Extract 25 features from a single-lead ECG signal.

    The features are grouped as: 7 HRV time-domain, 8 morphology, 6 signal
    statistics and 4 device/quality features. Keys match FEATURE_NAMES.

    Args:
        ecg_signal: 1D array of ECG samples.
        sample_rate: Sampling rate in Hz (100 for ESP32, 500 for PTB-XL).
        heart_rate_sensor: HR from MAX30100 (optional). A falsy value
            (None or 0) is treated as missing and replaced by the ECG HR.
        spo2: SpO2 from MAX30100 (optional). A falsy value is replaced by 97.0.

    Returns:
        dict of 25 float features. If fewer than 3 R-peaks are detected or
        processing raises, the result of ``_fallback_features`` is returned
        instead. If only wave delineation fails, the morphology features that
        were not yet computed are set to fixed defaults.

    NOTE(review): compared with earlier revisions, delineation landmarks
    that are NumPy integers are now accepted, Q/S/T landmarks are paired per
    beat, and the SNR is estimated from the raw signal. This changes the
    values of qrs_duration, qt_interval, qtc, t_amplitude, p_wave_ratio and
    snr, so a model trained on the old values should be retrained.
    """
    log = logging.getLogger(__name__)
    features = {}
    fallback_input = ecg_signal
    try:
        raw_signal = np.asarray(ecg_signal, dtype=float)
        fallback_input = raw_signal

        # Clean the ECG signal and detect R-peaks.
        ecg_cleaned = nk.ecg_clean(raw_signal, sampling_rate=sample_rate)
        _, rpeaks = nk.ecg_peaks(ecg_cleaned, sampling_rate=sample_rate)
        r_peak_indices = rpeaks.get("ECG_R_Peaks", np.array([]))
        if len(r_peak_indices) < 3:
            log.warning(
                "Only %d R-peaks detected; returning fallback ECG features",
                len(r_peak_indices),
            )
            return _fallback_features(fallback_input, heart_rate_sensor, spo2)

        # --- HRV Time-Domain Features (7) ---
        # At least 3 R-peaks guarantees >= 2 RR intervals and >= 1 successive
        # difference, so none of the statistics below need an emptiness guard.
        rr_intervals = np.diff(r_peak_indices) / sample_rate * 1000  # ms
        rr_successive = np.diff(rr_intervals)
        abs_successive = np.abs(rr_successive)
        features["mean_rr"] = float(np.mean(rr_intervals))
        features["sdnn"] = float(np.std(rr_intervals, ddof=1))
        features["rmssd"] = float(np.sqrt(np.mean(rr_successive ** 2)))
        features["pnn50"] = float(
            np.sum(abs_successive > 50) / len(abs_successive) * 100
        )
        hr_from_rr = 60000.0 / rr_intervals
        features["mean_hr_ecg"] = float(np.mean(hr_from_rr))
        features["hr_std"] = float(np.std(hr_from_rr))
        features["rr_range"] = float(np.max(rr_intervals) - np.min(rr_intervals))

        # --- ECG Morphology Features (8) ---
        try:
            _, waves = nk.ecg_delineate(
                ecg_cleaned, rpeaks, sampling_rate=sample_rate, method="dwt"
            )
            # Assumes the delineator returns one entry per beat in every list
            # (NaN where a landmark was not found) - TODO confirm against the
            # installed NeuroKit2 version.
            q_peaks = waves.get("ECG_Q_Peaks", [])
            s_peaks = waves.get("ECG_S_Peaks", [])
            t_offsets = waves.get("ECG_T_Offsets", [])

            # QRS duration, approximated as the Q-peak to S-peak distance.
            qrs_ms = _mean_interval_ms(q_peaks, s_peaks, sample_rate)
            features["qrs_duration"] = qrs_ms if qrs_ms is not None else 100.0

            # R amplitude
            r_amplitudes = ecg_cleaned[r_peak_indices.astype(int)]
            features["r_amplitude"] = float(np.mean(r_amplitudes))
            features["r_amplitude_std"] = float(np.std(r_amplitudes))

            # QT interval (Q-peak to T-offset) and Bazett's QTc.
            qt_ms = _mean_interval_ms(q_peaks, t_offsets, sample_rate)
            if qt_ms is None:
                features["qt_interval"] = 400.0
                features["qtc"] = 440.0
            else:
                features["qt_interval"] = qt_ms
                mean_rr_sec = features["mean_rr"] / 1000
                features["qtc"] = (
                    float(qt_ms / np.sqrt(mean_rr_sec)) if mean_rr_sec > 0 else 440.0
                )

            # ST level (amplitude at J-point, ~40ms after R-peak)
            j_offset = int(0.04 * sample_rate)
            st_levels = [
                ecg_cleaned[rp + j_offset]
                for rp in r_peak_indices.astype(int)
                if rp + j_offset < len(ecg_cleaned)
            ]
            features["st_level"] = float(np.mean(st_levels)) if st_levels else 0.0

            # T-wave amplitude
            t_amp = _mean_amplitude_at(ecg_cleaned, waves.get("ECG_T_Peaks", []))
            features["t_amplitude"] = t_amp if t_amp is not None else 0.0

            # P-wave ratio (P amplitude / R amplitude)
            p_amp = _mean_amplitude_at(ecg_cleaned, waves.get("ECG_P_Peaks", []))
            if p_amp is not None and features["r_amplitude"] != 0:
                features["p_wave_ratio"] = float(p_amp / features["r_amplitude"])
            else:
                features["p_wave_ratio"] = 0.1
        except Exception:
            # Best effort: keep whatever was computed before the failure and
            # fill the remaining morphology features with fixed defaults.
            log.debug(
                "ECG delineation failed; using default morphology features",
                exc_info=True,
            )
            features.setdefault("qrs_duration", 100.0)
            features.setdefault(
                "r_amplitude", float(np.max(ecg_cleaned) - np.min(ecg_cleaned))
            )
            features.setdefault("r_amplitude_std", 0.0)
            features.setdefault("qt_interval", 400.0)
            features.setdefault("qtc", 440.0)
            features.setdefault("st_level", 0.0)
            features.setdefault("t_amplitude", 0.0)
            features.setdefault("p_wave_ratio", 0.1)

        # --- Signal Statistics (6) ---
        features["rms"] = float(np.sqrt(np.mean(ecg_cleaned ** 2)))
        features["entropy"] = float(_sample_entropy(ecg_cleaned))
        centered_sign = np.sign(ecg_cleaned - np.mean(ecg_cleaned))
        features["zero_crossing_rate"] = float(
            np.sum(np.diff(centered_sign) != 0) / len(ecg_cleaned)
        )
        features["kurtosis"] = float(kurtosis(ecg_cleaned))
        features["skewness"] = float(skew(ecg_cleaned))
        # _estimate_snr cleans its input itself and treats (input - cleaned)
        # as noise, so it must receive the RAW signal; passing the cleaned
        # signal would only measure the residual of a second filtering pass.
        features["snr"] = float(_estimate_snr(raw_signal, sample_rate))

        # --- Device Sensor Features (4) ---
        features["heart_rate_sensor"] = (
            float(heart_rate_sensor) if heart_rate_sensor else features["mean_hr_ecg"]
        )
        features["spo2"] = float(spo2) if spo2 else 97.0
        features["hr_sensor_ecg_diff"] = float(
            abs(features["heart_rate_sensor"] - features["mean_hr_ecg"])
        )
        # ECG quality score: 1 - coefficient of variation of RR, floored at 0.
        rr_cv = np.std(rr_intervals) / np.mean(rr_intervals)
        features["ecg_quality"] = float(max(0, 1 - rr_cv))
    except Exception:
        log.warning(
            "ECG feature extraction failed; returning fallback features",
            exc_info=True,
        )
        return _fallback_features(fallback_input, heart_rate_sensor, spo2)
    return features
def _sample_entropy(signal, m=2, r_factor=0.2):
    """Approximate sample entropy of a 1D signal.

    Templates of length ``m`` and ``m + 1`` are compared with the Chebyshev
    (max-abs) distance against a tolerance of ``r_factor * std(signal)``.
    For speed, at most 200 reference templates are evaluated. The subset is
    drawn from a fixed-seed local generator so that the same signal always
    yields the same value and the global NumPy RNG state is left untouched.

    Args:
        signal: 1D sequence of samples.
        m: Template (embedding) length.
        r_factor: Tolerance as a fraction of the signal's standard deviation.

    Returns:
        ``-log(count_{m+1} / count_m)`` as a float, or 0.0 when the signal is
        too short, the tolerance is not a positive finite number, no matches
        are found, or any error occurs.
    """
    max_reference_templates = 200
    try:
        data = np.asarray(signal, dtype=float)
        tolerance = r_factor * np.std(data)
        n_samples = len(data)
        # A NaN or non-positive tolerance would make the self-match
        # subtraction below produce negative counts, so bail out early.
        if n_samples < m + 2 or not np.isfinite(tolerance) or tolerance <= 0:
            return 0.0

        templates_m = np.array([data[i:i + m] for i in range(n_samples - m)])
        templates_m1 = np.array(
            [data[i:i + m + 1] for i in range(n_samples - m - 1)]
        )

        n_templates = len(templates_m)
        if n_templates > max_reference_templates:
            rng = np.random.default_rng(0)
            reference_indices = rng.choice(
                n_templates, max_reference_templates, replace=False
            )
        else:
            reference_indices = range(n_templates)

        count_m = 0
        count_m1 = 0
        for i in reference_indices:
            dist_m = np.max(np.abs(templates_m - templates_m[i]), axis=1)
            count_m += np.sum(dist_m < tolerance) - 1  # drop the self-match
            # There is one fewer (m + 1)-length template than m-length ones.
            if i < len(templates_m1):
                dist_m1 = np.max(np.abs(templates_m1 - templates_m1[i]), axis=1)
                count_m1 += np.sum(dist_m1 < tolerance) - 1

        if count_m == 0 or count_m1 == 0:
            return 0.0
        return float(-np.log(count_m1 / count_m))
    except Exception:
        # Best-effort feature: never let entropy estimation break extraction.
        return 0.0
def _estimate_snr(signal, sample_rate):
    """Estimate the signal-to-noise ratio in dB.

    The signal is filtered with ``nk.ecg_clean``; the filtered output is
    treated as the signal and the removed part (input minus filtered) as noise.

    Returns:
        ``10 * log10(signal_power / noise_power)`` as a float, 30.0 when the
        noise power is exactly zero, or 10.0 if any step raises.
    """
    try:
        filtered = nk.ecg_clean(signal, sampling_rate=sample_rate)
        residual = signal - filtered
        power_signal = np.mean(filtered ** 2)
        power_noise = np.mean(residual ** 2)
        if power_noise != 0:
            ratio_db = 10 * np.log10(power_signal / power_noise)
            return float(ratio_db)
        return 30.0
    except Exception:
        return 10.0
def _fallback_features(ecg_signal, heart_rate_sensor=None, spo2=None) -> dict:
    """Return default features when ECG processing fails.

    All values are fixed constants except ``rms`` (computed from the signal
    when possible), ``heart_rate_sensor`` and ``spo2`` (taken from the
    arguments when truthy).

    Because this is the last-resort path, it must not raise: the signal is
    coerced to a float array, and an empty, non-numeric, ``None`` or
    non-finite signal falls back to the default RMS of 0.5.

    Args:
        ecg_signal: ECG samples (array, list, or anything array-like).
        heart_rate_sensor: Sensor HR; a falsy value is replaced by 75.0.
        spo2: Sensor SpO2; a falsy value is replaced by 97.0.

    Returns:
        dict with the same 25 keys that ``extract_ecg_features`` produces.
    """
    rms = 0.5
    try:
        samples = np.asarray(ecg_signal, dtype=float).ravel()
        if samples.size > 0:
            candidate = float(np.sqrt(np.mean(samples ** 2)))
            if np.isfinite(candidate):
                rms = candidate
    except (TypeError, ValueError):
        # Signal is not numeric at all; keep the default RMS.
        pass

    return {
        "mean_rr": 800.0, "sdnn": 50.0, "rmssd": 30.0, "pnn50": 10.0,
        "mean_hr_ecg": 75.0, "hr_std": 5.0, "rr_range": 200.0,
        "qrs_duration": 100.0, "r_amplitude": 1.0, "r_amplitude_std": 0.1,
        "qt_interval": 400.0, "qtc": 440.0, "st_level": 0.0,
        "t_amplitude": 0.3, "p_wave_ratio": 0.1,
        "rms": rms,
        "entropy": 0.5, "zero_crossing_rate": 0.1,
        "kurtosis": 0.0, "skewness": 0.0, "snr": 10.0,
        "heart_rate_sensor": float(heart_rate_sensor) if heart_rate_sensor else 75.0,
        "spo2": float(spo2) if spo2 else 97.0,
        "hr_sensor_ecg_diff": 0.0, "ecg_quality": 0.5,
    }
# Ordered ECG feature names for XGBoost. The order defines the column order of
# the model input and must match the order used at training time.
FEATURE_NAMES = [
    # HRV time-domain
    "mean_rr",
    "sdnn",
    "rmssd",
    "pnn50",
    "mean_hr_ecg",
    "hr_std",
    "rr_range",
    # ECG morphology
    "qrs_duration",
    "r_amplitude",
    "r_amplitude_std",
    "qt_interval",
    "qtc",
    "st_level",
    "t_amplitude",
    "p_wave_ratio",
    # Signal statistics
    "rms",
    "entropy",
    "zero_crossing_rate",
    "kurtosis",
    "skewness",
    "snr",
    # Device sensor / quality
    "heart_rate_sensor",
    "spo2",
    "hr_sensor_ecg_diff",
    "ecg_quality",
]

# User profile feature names (appended after the ECG features).
PROFILE_FEATURE_NAMES = [
    "age",
    "sex",
    "bmi",
    "is_diabetic",
    "is_hypertensive",
    "is_smoker",
    "family_history",
]

# Historical baseline feature names (appended after the profile features).
HISTORY_FEATURE_NAMES = [
    "hr_baseline_24h",
    "hr_baseline_7d",
    "spo2_baseline_24h",
    "hr_deviation",
    "spo2_deviation",
    "resting_hr_trend",
    "readings_count_24h",
]

# Full column order: ECG features, then profile, then history.
ALL_FEATURE_NAMES = [
    *FEATURE_NAMES,
    *PROFILE_FEATURE_NAMES,
    *HISTORY_FEATURE_NAMES,
]
def features_to_array(features: dict, include_profile: bool = False) -> np.ndarray:
    """Convert a feature dict to a float32 vector in model column order.

    Args:
        features: Mapping from feature name to value.
        include_profile: When True, use ALL_FEATURE_NAMES (ECG features
            followed by the profile AND history features); otherwise use
            FEATURE_NAMES only.

    Returns:
        1D float32 array with one entry per selected name. Names missing
        from ``features`` are filled with 0.0.
    """
    column_order = FEATURE_NAMES
    if include_profile:
        column_order = ALL_FEATURE_NAMES
    values = [features.get(column, 0.0) for column in column_order]
    return np.array(values, dtype=np.float32)