diff --git a/vital_sqi/common/rpeak_detection.py b/vital_sqi/common/rpeak_detection.py index 3bc178a..4f7d4c3 100644 --- a/vital_sqi/common/rpeak_detection.py +++ b/vital_sqi/common/rpeak_detection.py @@ -315,6 +315,68 @@ def detect_peak_trough_adaptive_threshold( return np.array(peak_finalist), np.array(trough_finalist) + def get_ROI(self, s, adaptive_threshold, margin=0.1): + """ + Identify regions of interest (ROIs) in the signal where peaks or troughs are likely to occur. + + Parameters + ---------- + s : array_like + Input signal. + adaptive_threshold : array_like + Adaptive threshold values for the signal. + margin : float, optional + Margin (fraction of the signal range) to include before and after the ROI (default is 0.1). + + Returns + ------- + tuple + Two lists: start_ROIs and end_ROIs, which contain the start and end indices of the ROIs. + + Notes + ----- + - ROIs are defined as contiguous regions where the signal exceeds the adaptive threshold. + - Margins can be added to widen the ROIs for more inclusive peak/trough detection. + + Example + ------- + >>> s = [0, 1, 3, 7, 5, 2, 0, 6, 9, 8, 4, 1, 0] + >>> adaptive_threshold = [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4] + >>> peak_detector = PeakDetection(fs=100) + >>> start_ROIs, end_ROIs = peak_detector.get_ROI(s, adaptive_threshold) + """ + # Identify regions where the signal exceeds the adaptive threshold + above_threshold = s > adaptive_threshold + + # Find the start and end indices of these regions + start_ROIs = [] + end_ROIs = [] + + for i in range(1, len(above_threshold)): + # Transition from below threshold to above threshold: start of an ROI + if above_threshold[i] and not above_threshold[i - 1]: + start_ROIs.append(i) + + # Transition from above threshold to below threshold: end of an ROI + if not above_threshold[i] and above_threshold[i - 1]: + end_ROIs.append(i - 1) + + # Handle case where the signal ends above the threshold + if above_threshold[-1]: + end_ROIs.append(len(s) - 1) + + # Apply margin to widen the ROIs + signal_length = len(s) + start_ROIs = [ + max(0, int(start - margin * signal_length)) for start in start_ROIs + ] + end_ROIs = [ + min(signal_length - 1, int(end + margin * signal_length)) + for end in end_ROIs + ] + + return start_ROIs, end_ROIs + def detect_peak_trough_moving_average_threshold(self, s): """ Detects peaks using a moving average threshold approach.