From 35db8bdc439fb3dcf8bdbe541f1f4820c9260b11 Mon Sep 17 00:00:00 2001
From: Scott Huberty
Date: Mon, 23 Dec 2024 17:00:36 -0800
Subject: [PATCH] RFC, WIP: public function to find and return bad channels

---
 pylossless/pipeline.py | 56 +++++++++++++++++++++++++++++++++++++++---
 1 file changed, 53 insertions(+), 3 deletions(-)

diff --git a/pylossless/pipeline.py b/pylossless/pipeline.py
index 45931cf..d474d7e 100644
--- a/pylossless/pipeline.py
+++ b/pylossless/pipeline.py
@@ -227,6 +227,52 @@ def _detect_outliers(
     return prop_outliers[prop_outliers > flag_crit].coords.to_index().values
 
 
+def find_bads_by_threshold(epochs, threshold=5e-5):
+    """Return channels with a standard deviation consistently above a fixed threshold.
+
+    Parameters
+    ----------
+    epochs : mne.Epochs
+        an instance of mne.Epochs
+    threshold : float
+        the threshold in volts. If the standard deviation of a channel's voltage
+        at a specific epoch is above the threshold, then that channel x epoch
+        will be flagged as an "outlier". Default is 5e-5 (0.00005), i.e.
+        50 microvolts.
+
+    Returns
+    -------
+    list
+        a list of channel names that are considered outliers.
+
+    Notes
+    -----
+    If you are having trouble converting between exponential notation and
+    decimal notation, you can use the following code to convert between the two:
+
+    >>> import numpy as np
+    >>> threshold = 5e-5
+    >>> with np.printoptions(suppress=True):
+    ...     print(np.array([threshold]))
+    [0.00005]
+
+    Examples
+    --------
+    >>> import mne
+    >>> import pylossless as ll
+    >>> fname = mne.datasets.sample.data_path() / "MEG/sample/sample_audvis_raw.fif"
+    >>> raw = mne.io.read_raw(fname, preload=True).pick("eeg")
+    >>> raw.apply_function(lambda x: x * 3, picks=["EEG 001"])  # Make a noisy channel
+    >>> epochs = mne.make_fixed_length_epochs(raw, preload=True)
+    >>> bad_chs = ll.pipeline.find_bads_by_threshold(epochs)
+    """
+    bads = _threshold_volt_std(epochs, flag_dim="ch", threshold=threshold)
+    logger.info(
+        f"Found {len(bads)} channels with high voltage variance: {bads}"
+    )
+    return bads
+
+
 def _threshold_volt_std(epochs, flag_dim, threshold=5e-5):
     """Detect epochs or channels whose voltage std is above threshold.
 
@@ -765,9 +811,13 @@ def _flag_volt_std(self, flag_dim, threshold=5e-5, picks="eeg"):
         on. You may need to assess a more appropriate value for your own data.
         """
         epochs = self.get_epochs(picks=picks)
-        above_threshold = _threshold_volt_std(
-            epochs, flag_dim=flag_dim, threshold=threshold
-        )
+        # So yes this add a few LOC, but IMO it's worth it for readability
+        if flag_dim == "ch":
+            above_threshold = find_bads_by_threshold(epochs, threshold=threshold)
+        else:  # TODO: Implement an annotate_bads_by_threshold for epochs
+            above_threshold = _threshold_volt_std(
+                epochs, flag_dim=flag_dim, threshold=threshold
+            )
         self.flags[flag_dim].add_flag_cat("volt_std", above_threshold, epochs)
 
     def find_outlier_chs(self, epochs=None, picks="eeg"):