From cc57c1c8707c4183c91cab4de6e5c305759a9ce8 Mon Sep 17 00:00:00 2001 From: Koaha Date: Thu, 28 Nov 2024 15:24:50 +0700 Subject: [PATCH] update docstring and coverage --- docs/source/_static/sidebar.css | 22 +++ docs/source/conf.py | 5 + docs/source/index.rst | 56 ++++---- pyproject.toml | 2 +- setup.py | 2 +- tests/common/test_generate_template.py | 1 + tests/conftest.py | 2 + tests/sqi/test_hrv_sqi.py | 180 +++++++++++++++++++++++-- tests/sqi/test_rpeaks_sqi.py | 26 +++- vital_sqi/sqi/hrv_sqi.py | 22 ++- vital_sqi/sqi/rpeaks_sqi.py | 145 +++++++++++++++----- 11 files changed, 384 insertions(+), 79 deletions(-) create mode 100644 docs/source/_static/sidebar.css diff --git a/docs/source/_static/sidebar.css b/docs/source/_static/sidebar.css new file mode 100644 index 0000000..652670b --- /dev/null +++ b/docs/source/_static/sidebar.css @@ -0,0 +1,22 @@ +.wy-nav-side { + background-color: #f8f9fa; /* Light background */ +} +.wy-nav-content { + padding: 1em; +} +.wy-menu-vertical { + border-right: 1px solid #ddd; +} +.wy-menu-vertical li a { + font-size: 0.95em; + color: #007bff; /* Link color */ +} +.wy-menu-vertical li a:hover { + text-decoration: underline; +} +.wy-menu-vertical li.toctree-l1 > a { + font-weight: bold; /* Top-level items */ +} +.wy-menu-vertical li.toctree-l2 > a { + margin-left: 10px; /* Indent for second-level items */ +} diff --git a/docs/source/conf.py b/docs/source/conf.py index a9f944e..d5136ed 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -60,6 +60,11 @@ # Substitute project name into .rst files when |project_name| is used rst_epilog = '.. |project_name| replace:: %s' % project +# Add CSS files +html_css_files = [ + 'sidebar.css', +] + # -- Extensions configuration ------------------------------------------------ diff --git a/docs/source/index.rst b/docs/source/index.rst index a228222..1fa2fc1 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,52 +1,50 @@ -Welcome to vital_sqi's documentation! +Welcome to vital_sqi's Documentation! ===================================== .. image:: ./_static/imgs/logo.png :alt: Vital_SQI logo + :width: 200px + :align: center .. note:: Vital_SQI is a Python library designed for analyzing physiological signals (like ECG and PPG) and computing Signal Quality Indices (SQI). - The code of the project is on `GitHub `_. -Getting Started ---------------- -.. toctree:: - :maxdepth: 2 - - usage/installation - usage/introduction - usage/contributions - usage/development - usage/quickstart - -Data Manipulation ------------------ +🚀 Getting Started +------------------ .. toctree:: :maxdepth: 2 + :caption: Basics - _examples/notebooks/Data_manipulation_ECG_PPG + usage/installation 💻 Installation Guide + usage/introduction 📖 Introduction + usage/quickstart ⚡ Quickstart Tutorial + usage/contributions 🤝 Contributing to Vital_SQI + usage/development 🔧 Development Guide -Pipeline --------- +🛠️ Tutorials and Examples +------------------------- .. toctree:: :maxdepth: 2 + :caption: Tutorials - _examples/notebooks/SQI_pipeline + _examples/notebooks/Data_manipulation_ECG_PPG 🛠️ Data Manipulation with ECG & PPG + _examples/notebooks/SQI_pipeline 🔄 Building SQI Pipelines -Documentation -------------- +📚 Documentation +---------------- .. toctree:: :maxdepth: 2 + :caption: API Reference - docstring/modules - docstring/vital_sqi.pipeline - docstring/vital_sqi.sqi + docstring/modules 📘 Module Overview + docstring/vital_sqi.pipeline 🔄 Pipeline Module + docstring/vital_sqi.sqi 📊 SQI Module -Indices and tables -================== +🔍 Indices and Tables +===================== -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` \ No newline at end of file +* :ref:`genindex` - General Index +* :ref:`modindex` - Module Index +* :ref:`search` - Search diff --git a/pyproject.toml b/pyproject.toml index 79a9707..2538e92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "vitalsqi_toolkit" -version = "1.0.1" +version = "1.0.2" description = "A toolkit for signal quality analysis of ECG and PPG signals" readme = "README.md" requires-python = ">=3.7" diff --git a/setup.py b/setup.py index d3949e4..02d93b3 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name = 'vitalsqi_toolkit', - version = '1.0.1', + version = '1.0.2', packages = find_packages(include = ["vital_sqi", "vital_sqi.*"]), description = "Signal quality control pipeline for electrocardiogram and " "photoplethysmogram", diff --git a/tests/common/test_generate_template.py b/tests/common/test_generate_template.py index 986666d..4f5a256 100644 --- a/tests/common/test_generate_template.py +++ b/tests/common/test_generate_template.py @@ -1,5 +1,6 @@ from scipy.special import erf + class TestPPGDualDoubleFrequencyTemplate(object): def test_on_ppg_dual_double_frequency_template(self): pass diff --git a/tests/conftest.py b/tests/conftest.py index 0e21c7f..4cddb6e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,10 +5,12 @@ from webdriver_manager.firefox import GeckoDriverManager from dash.testing.composite import DashComposite + def pytest_ignore_collect(path): if "vital_sqi/app" in str(path): return True + def pytest_addoption(parser): """Add a command-line option for selecting the browser.""" parser.addoption( diff --git a/tests/sqi/test_hrv_sqi.py b/tests/sqi/test_hrv_sqi.py index f0c0f70..6a26b50 100644 --- a/tests/sqi/test_hrv_sqi.py +++ b/tests/sqi/test_hrv_sqi.py @@ -55,8 +55,9 @@ def test_sdsd_sqi(self, valid_nn_intervals, short_nn_intervals): assert np.isnan(sdsd_sqi(["invalid_input"])) # # Test with invalid input - # with pytest.raises(ValueError, match="diff requires input that is at least one dimensional"): - # sdsd_sqi("invalid_input") + with pytest.warns(UserWarning): + res = sdsd_sqi("invalid_input") + assert np.isnan(res) # Ensure the function returns NaN for invalid inputs def test_rmssd_sqi(self, valid_nn_intervals, short_nn_intervals): # Test with valid NN intervals @@ -73,14 +74,21 @@ def test_rmssd_sqi(self, valid_nn_intervals, short_nn_intervals): assert np.isnan(rmssd_sqi(["invalid_input"])) # Test with invalid input - # with pytest.raises(ValueError, match="diff requires input that is at least one dimensional"): - # rmssd_sqi("invalid_input") + with pytest.warns(UserWarning): + res = rmssd_sqi("invalid_input") + assert np.isnan(res) # Ensure the function returns NaN for invalid inputs def test_cvsd_sqi(self, valid_nn_intervals): assert cvsd_sqi(valid_nn_intervals) == pytest.approx(0.01994, rel=1e-1) + with pytest.warns(UserWarning): + res = cvsd_sqi("invalid_input") + assert np.isnan(res) # Ensure the function returns NaN for invalid inputs def test_cvnn_sqi(self, valid_nn_intervals): assert cvnn_sqi(valid_nn_intervals) == pytest.approx(0.0182, rel=1e-1) + with pytest.warns(UserWarning): + res = cvnn_sqi("invalid_input") + assert np.isnan(res) # Ensure the function returns NaN for invalid inputs def test_median_nn_sqi(self, valid_nn_intervals, empty_nn_intervals): assert median_nn_sqi(valid_nn_intervals) == pytest.approx(810, rel=1e1) @@ -91,6 +99,9 @@ def test_pnn_sqi(self, valid_nn_intervals, short_nn_intervals): 100.0, rel=1e-2 ) assert np.isnan(pnn_sqi(short_nn_intervals)) + with pytest.warns(UserWarning): + res = pnn_sqi("invalid_input") + assert np.isnan(res) # Ensure the function returns NaN for invalid inputs def test_hr_sqi(self, valid_nn_intervals): assert hr_sqi(valid_nn_intervals, stat="mean") == pytest.approx(74.26, rel=1e-2) @@ -147,6 +158,70 @@ def test_frequency_sqi(self, valid_nn_intervals): ) ) + # Test all metrics + length = 500 + base_rate = 600 + variability = 50 + synthetic_nn_intervals = base_rate + np.random.randint( + -variability, variability, size=length + ) + freq_min = 0.04 + freq_max = 0.15 + result_peak = frequency_sqi( + synthetic_nn_intervals, freq_min=freq_min, freq_max=freq_max, metric="peak" + ) + assert result_peak >= 0 or np.isnan(result_peak) + result_absolute = frequency_sqi( + synthetic_nn_intervals, + freq_min=freq_min, + freq_max=freq_max, + metric="absolute", + ) + assert result_absolute >= 0 or np.isnan(result_absolute) + result_log = frequency_sqi( + synthetic_nn_intervals, freq_min=freq_min, freq_max=freq_max, metric="log" + ) + assert result_log >= 0 or np.isnan(result_log) + result_normalized = frequency_sqi( + synthetic_nn_intervals, + freq_min=freq_min, + freq_max=freq_max, + metric="normalized", + ) + assert result_normalized >= 0 or np.isnan(result_normalized) + result_relative = frequency_sqi( + synthetic_nn_intervals, + freq_min=freq_min, + freq_max=freq_max, + metric="relative", + ) + assert result_relative >= 0 or np.isnan(result_relative) + + def generate_nn_intervals(self, length=500, base_rate=600, variability=50): + """ + Generate synthetic NN intervals mimicking heart rate variability. + + Parameters: + ---------- + length : int + Number of intervals to generate. + base_rate : int + Base NN interval in milliseconds. + variability : int + Maximum variability in NN interval. + + Returns: + ------- + list + A list of synthetic NN intervals. + """ + np.random.seed(42) # For reproducibility + # Generate random variations around the base rate + nn_intervals = base_rate + np.random.randint( + -variability, variability, size=length + ) + return nn_intervals.tolist() + def test_lf_hf_ratio_sqi(self, valid_nn_intervals, empty_nn_intervals): """Test LF/HF ratio calculation.""" # Valid case @@ -181,6 +256,23 @@ def test_lf_hf_ratio_sqi(self, valid_nn_intervals, empty_nn_intervals): ) ) + length = 500 + base_rate = 600 + variability = 50 + synthetic_nn_intervals = base_rate + np.random.randint( + -variability, variability, size=length + ) + ratio = lf_hf_ratio_sqi( + synthetic_nn_intervals, lf_range=(1e-3, 1e3), hf_range=(1e-4, 1e4) + ) + print(ratio) + assert not np.isnan(ratio) + + very_high_hf_ratio = lf_hf_ratio_sqi( + synthetic_nn_intervals, lf_range=(0.5, 1.0), hf_range=(10.0, 20.0) + ) + assert np.isnan(very_high_hf_ratio) + def test_poincare_features_sqi(self, valid_nn_intervals, short_nn_intervals): features = poincare_features_sqi(valid_nn_intervals) assert features["sd1"] >= 0 @@ -190,22 +282,23 @@ def test_poincare_features_sqi(self, valid_nn_intervals, short_nn_intervals): features = poincare_features_sqi(short_nn_intervals) for key in features: assert np.isnan(features[key]) + with pytest.warns(UserWarning): + res = poincare_features_sqi("invalid_input") + assert np.isnan( + res["sd1"] + ) # Ensure the function returns NaN for invalid inputs.res['sd1'] def test_get_all_features_hrva(self): signal = np.sin(np.linspace(0, 2 * np.pi, 1000)) # Simulated signal sample_rate = 100 # Valid case - features = ( - get_all_features_hrva(signal, sample_rate=sample_rate) - ) + features = get_all_features_hrva(signal, sample_rate=sample_rate) assert isinstance(features, dict) # Edge case: Invalid peak detection invalid_signal = [0] * 1000 # Flat signal, no peaks - features = ( - get_all_features_hrva(invalid_signal, sample_rate=sample_rate) - ) + features = get_all_features_hrva(invalid_signal, sample_rate=sample_rate) assert features == {} # Edge case: Invalid wave type @@ -217,3 +310,70 @@ def test_get_all_features_hrva(self): # Edge case: Invalid sample rate with pytest.raises(Exception, match="Sample rate must be a positive number."): get_all_features_hrva(signal, sample_rate=-100) + + with pytest.warns(UserWarning): + res = get_all_features_hrva("invalid_input") + assert len(res) == 0 + + def test_hr_sqi_invalid_stat(self, valid_nn_intervals): + result = hr_sqi(valid_nn_intervals, stat="invalid") + assert np.isnan(result), "Expected NaN for invalid stat input" + + def test_hr_range_sqi_edge_cases(self): + nn_intervals = [800, 810, 820, 830] + assert hr_range_sqi(nn_intervals, range_min=900, range_max=1000) == 100.0 + assert hr_range_sqi(nn_intervals, range_min=500, range_max=700) == 100.0 + + def test_frequency_sqi_empty_band_powers(self): + nn_intervals = [ + 800, + 810, + 820, + 830, + ] # Adjust as necessary to force empty band_powers + result = frequency_sqi( + nn_intervals, freq_min=0.4, freq_max=0.5, metric="absolute" + ) + assert np.isnan(result) + + def test_lf_hf_ratio_sqi_invalid_range(self, valid_nn_intervals): + result = lf_hf_ratio_sqi( + valid_nn_intervals, lf_range=(0.5, 0.4), hf_range=(0.15, 0.1) + ) + assert np.isnan(result) + + def test_poincare_features_sqi_insufficient_data(self): + features = poincare_features_sqi([800]) # Insufficient intervals + for key in features: + assert np.isnan(features[key]) + + def test_get_all_features_hrva_invalid_inputs(self): + signal = np.sin(np.linspace(0, 2 * np.pi, 1000)) + sample_rate = 100 + + # Invalid peak detection method + result = get_all_features_hrva( + signal, sample_rate=sample_rate, rpeak_method=999 + ) + assert ( + result is not None + ), "Expected a default dictionary for invalid peak detection method" + + # Invalid sample rate + with pytest.raises(ValueError, match="Sample rate must be a positive number"): + get_all_features_hrva(signal, sample_rate=-100) + + def test_get_all_features_hrva_rr_interval_failure(self): + invalid_signal = [0] * 1000 # Flat signal, no peaks + features = get_all_features_hrva(invalid_signal, sample_rate=100) + assert features == {} + + def test_get_all_features_hrva_feature_extraction_failure(self): + invalid_signal = np.sin(np.linspace(0, 2 * np.pi, 10)) # Too short for HRV + features = get_all_features_hrva(invalid_signal, sample_rate=100) + assert features == {} + + def test_get_all_features_hrva_final_block(self): + invalid_signal = [0] * 1000 + features = get_all_features_hrva(invalid_signal, sample_rate=100) + assert features == {} diff --git a/tests/sqi/test_rpeaks_sqi.py b/tests/sqi/test_rpeaks_sqi.py index d07aaf0..1f44e0b 100644 --- a/tests/sqi/test_rpeaks_sqi.py +++ b/tests/sqi/test_rpeaks_sqi.py @@ -1,6 +1,7 @@ import pytest import numpy as np from scipy import signal +from vitalDSP.utils.synthesize_data import generate_ecg_signal from vital_sqi.sqi.rpeaks_sqi import ( ectopic_sqi, correlogram_sqi, @@ -70,6 +71,17 @@ def test_ectopic_sqi(self): ): assert np.isnan(ectopic_sqi([], sample_rate=sample_rate)) + sfecg = 256 + N = 100 + Anoise = 0.05 + hrmean = 70 + ecg_signal = generate_ecg_signal(sfecg=sfecg, N=N, Anoise=Anoise, hrmean=hrmean) + for rule_index in range(1, 4): + result = ectopic_sqi( + ecg_signal, sample_rate=sfecg, wave_type="ECG", rule_index=rule_index + ) + assert np.isreal(result) + def test_correlogram_sqi(self): """Test the correlogram_sqi function.""" # Create a simulated signal @@ -88,11 +100,10 @@ def test_correlogram_sqi(self): with pytest.warns( UserWarning, match="Signal length is too short for the specified time lag" ): - assert ( + assert np.isnan( correlogram_sqi( short_signal, sample_rate=sample_rate, time_lag=3, n_selection=3 ) - == [] ) # Test with flat signal (no peaks) @@ -101,13 +112,20 @@ def test_correlogram_sqi(self): UserWarning, # match="No peaks detected in the autocorrelation function." ): - assert ( + assert np.isnan( correlogram_sqi( flat_signal, sample_rate=sample_rate, time_lag=3, n_selection=3 ) - == [] ) + sfecg = 256 + N = 100 + Anoise = 0.05 + hrmean = 70 + ecg_signal = generate_ecg_signal(sfecg=sfecg, N=N, Anoise=Anoise, hrmean=hrmean) + result = correlogram_sqi(ecg_signal, sample_rate=sfecg, wave_type="ECG") + assert result is not None + def test_interpolation_sqi(self, valid_signal): """Test the interpolation_sqi function.""" result = interpolation_sqi(valid_signal) diff --git a/vital_sqi/sqi/hrv_sqi.py b/vital_sqi/sqi/hrv_sqi.py index ad618db..1a50777 100644 --- a/vital_sqi/sqi/hrv_sqi.py +++ b/vital_sqi/sqi/hrv_sqi.py @@ -400,7 +400,7 @@ def lf_hf_ratio_sqi(nn_intervals, lf_range=(0.04, 0.15), hf_range=(0.15, 0.4)): warnings.warn("Invalid input: nn_intervals must be a list or numpy array.") return np.nan - if not nn_intervals or len(nn_intervals) < 3: + if len(nn_intervals) < 3: warnings.warn("Insufficient NN intervals for LF/HF ratio calculation.") return np.nan @@ -414,7 +414,6 @@ def lf_hf_ratio_sqi(nn_intervals, lf_range=(0.04, 0.15), hf_range=(0.15, 0.4)): lf_power = np.sum(powers[(freqs >= lf_range[0]) & (freqs < lf_range[1])]) hf_power = np.sum(powers[(freqs >= hf_range[0]) & (freqs < hf_range[1])]) - if hf_power <= 0: warnings.warn("HF power is zero or negative, cannot compute LF/HF ratio.") return np.nan @@ -484,6 +483,8 @@ def get_all_features_hrva(signal, sample_rate=100, rpeak_method=6, wave_type="EC if wave_type == "PPG" else detector.ecg_detector(signal) ) + if isinstance(peak_list, tuple) and peak_list: + peak_list = peak_list[0] except Exception as e: warnings.warn(f"Error during peak detection: {e}") return {} @@ -519,3 +520,20 @@ def get_all_features_hrva(signal, sample_rate=100, rpeak_method=6, wave_type="EC # return {}, {}, {}, {} # return time_features, freq_features, geometric_features, csi_cvi_features + + +# from vitalDSP.utils.synthesize_data import generate_ecg_signal + +# if __name__ == "__main__": +# sfecg = 256 +# N = 100 +# Anoise = 0.05 +# hrmean = 70 +# ecg_signal = generate_ecg_signal(sfecg=sfecg, N=N, Anoise=Anoise, hrmean=hrmean) +# # result_correlogram = correlogram_sqi(ecg_signal, sample_rate=sfecg, wave_type="ECG") +# # print(result_correlogram) + +# result_ectopic = get_all_features_hrva( +# ecg_signal, sample_rate=sfecg, rpeak_method=6, wave_type="ECG" +# ) +# print(result_ectopic) diff --git a/vital_sqi/sqi/rpeaks_sqi.py b/vital_sqi/sqi/rpeaks_sqi.py index 5438f04..3ef8c81 100644 --- a/vital_sqi/sqi/rpeaks_sqi.py +++ b/vital_sqi/sqi/rpeaks_sqi.py @@ -18,6 +18,7 @@ import warnings from statsmodels.tsa.stattools import acf from vital_sqi.common.rpeak_detection import PeakDetector +from scipy import signal from vitalDSP.transforms.beats_transformation import RRTransformation @@ -82,7 +83,7 @@ def ectopic_sqi( # Remove outliers based on RRI bounds and calculate outlier ratio rr_intervals_cleaned = transformer.remove_invalid_rr_intervals( - rr_intervals, min_rri=low_rri / 1000, max_rri=high_rri / 1000 + rr_intervals, min_rr=low_rri / 1000, max_rr=high_rri / 1000 ) number_outliers = np.isnan(rr_intervals_cleaned).sum() total_rr_intervals = len(rr_intervals_cleaned) @@ -96,7 +97,7 @@ def ectopic_sqi( rr_intervals_cleaned ) selected_rule = rules[rule_index - 1] - nn_intervals = transformer.remove_ectopic_beats( + nn_intervals = remove_ectopic_beats( interpolated_rr_intervals, method=selected_rule ) number_ectopics = np.isnan(nn_intervals).sum() @@ -104,27 +105,95 @@ def ectopic_sqi( return ectopic_ratio - except ValueError as e: - warnings.warn(str(e)) - return np.nan - except Exception as e: warnings.warn(f"Unexpected error in ectopic_sqi: {e}") return np.nan - except ValueError as e: - warnings.warn(f"Error in ectopic_sqi: {e}") - return np.nan - except Exception as e: - warnings.warn(f"Unexpected error in ectopic_sqi: {e}") - return np.nan +def remove_ectopic_beats(rr_intervals, method="adaptive"): + """ + Removes ectopic beats from RR intervals using a specified method. + + Parameters + ---------- + rr_intervals : np.array + The array of RR intervals (in seconds) with or without NaN values. + method : str, optional + The method to detect and remove ectopic beats. Options are 'adaptive', 'linear', or 'spline'. + Default is 'adaptive'. + + Returns + ------- + np.array + The array of RR intervals with ectopic beats marked as NaN. + + Notes + ----- + - Adaptive: Uses local and global trends to detect ectopic beats. + - Linear: Removes beats based on a linear trend. + - Spline: Uses spline fitting for ectopic beat detection. + + Example + ------- + >>> rr_intervals = np.array([0.8, 1.2, 1.0, 2.5, 0.9, 0.85]) + >>> rr_transformation = RRTransformation(signal, fs, "ECG") + >>> clean_rr_intervals = rr_transformation.remove_ectopic_beats(rr_intervals, method="adaptive") + """ + try: + if method not in ["adaptive", "linear", "spline"]: + raise ValueError( + f"Invalid method: {method}. Choose 'adaptive', 'linear', or 'spline'." + ) + + rr_intervals_cleaned = np.copy(rr_intervals) + valid_intervals = rr_intervals_cleaned[~np.isnan(rr_intervals_cleaned)] + + if len(valid_intervals) < 3: + raise ValueError("Not enough valid RR intervals for ectopic detection.") + + if method == "adaptive": + # Detect ectopic beats using a running mean and threshold + running_mean = np.convolve(valid_intervals, np.ones(5) / 5, mode="same") + deviations = np.abs(valid_intervals - running_mean) + threshold = 0.2 * running_mean # 20% deviation considered ectopic + ectopic_mask = deviations > threshold + rr_intervals_cleaned[~np.isnan(rr_intervals_cleaned)] = np.where( + ectopic_mask, np.nan, valid_intervals + ) + + elif method == "linear": + # Use linear interpolation to fit a trend and remove ectopic beats + linear_fit = np.polyval( + np.polyfit(np.arange(len(valid_intervals)), valid_intervals, 1), + np.arange(len(valid_intervals)), + ) + deviations = np.abs(valid_intervals - linear_fit) + threshold = 0.15 * linear_fit # 15% deviation considered ectopic + ectopic_mask = deviations > threshold + rr_intervals_cleaned[~np.isnan(rr_intervals_cleaned)] = np.where( + ectopic_mask, np.nan, valid_intervals + ) + + elif method == "spline": + # Use spline fitting to detect and remove ectopic beats + from scipy.interpolate import UnivariateSpline + + spline = UnivariateSpline( + np.arange(len(valid_intervals)), valid_intervals, s=0.1 + ) + spline_fit = spline(np.arange(len(valid_intervals))) + deviations = np.abs(valid_intervals - spline_fit) + threshold = 0.1 * spline_fit # 10% deviation considered ectopic + ectopic_mask = deviations > threshold + rr_intervals_cleaned[~np.isnan(rr_intervals_cleaned)] = np.where( + ectopic_mask, np.nan, valid_intervals + ) + + return rr_intervals_cleaned except Exception as e: - warnings.warn( - f"No peaks detected in the signal. RR interval computation failed: {e}" - ) - return np.nan + warnings.warn(f"Error in remove_ectopic_beats: {e}") + return rr_intervals def correlogram_sqi(s, sample_rate=100, wave_type="PPG", time_lag=3, n_selection=3): @@ -159,29 +228,25 @@ def correlogram_sqi(s, sample_rate=100, wave_type="PPG", time_lag=3, n_selection corr = acf(s, nlags=nlags, fft=True) # Find peaks in the autocorrelation function - # corr_peaks_idx = signal.find_peaks(corr)[0] - # corr_peaks_value = corr[corr_peaks_idx] - detector = PeakDetector(wave_type=wave_type) - - if wave_type == "PPG": - corr_peaks_idx, _ = detector.ppg_detector(s) - else: - corr_peaks_idx, _ = detector.ecg_detector(s) + corr_peaks_idx = signal.find_peaks(corr)[0] corr_peaks_value = corr[corr_peaks_idx] + if len(corr_peaks_idx) == 0: warnings.warn("No peaks detected in the autocorrelation function.") - return [] + return np.nan - # Select top peaks - n_selection = min(n_selection, len(corr_peaks_value)) - corr_sqi = list(corr_peaks_idx[:n_selection]) + list( - corr_peaks_value[:n_selection] - ) - return corr_sqi + # Select top peaks based on autocorrelation values + top_values = np.sort(corr_peaks_value)[ + -n_selection: + ] # Select top `n_selection` values + + # Compute SQI as the mean of the top peak values + sqi_value = np.mean(top_values) + return sqi_value except Exception as e: warnings.warn(f"Error in correlogram_sqi: {e}") - return [] + return np.nan def interpolation_sqi(s): @@ -254,3 +319,19 @@ def msq_sqi(s, peak_detector_1=7, peak_detector_2=6, wave_type="PPG"): except Exception as e: warnings.warn(f"Error in msq_sqi: {e}") return np.nan + + +# from vitalDSP.utils.synthesize_data import generate_ecg_signal +# if __name__ == "__main__": +# sfecg = 256 +# N = 100 +# Anoise = 0.05 +# hrmean = 70 +# ecg_signal = generate_ecg_signal( +# sfecg=sfecg, N=N, Anoise=Anoise, hrmean=hrmean +# ) +# # result_correlogram = correlogram_sqi(ecg_signal, sample_rate=sfecg, wave_type="ECG") +# # print(result_correlogram) + +# result_ectopic = ectopic_sqi(ecg_signal, sample_rate=sfecg, wave_type="ECG") +# print(result_ectopic)