Skip to content

Commit

Permalink
FIX, DOC: Pipeline methods shouldn't accept just any instance of raw (#…
Browse files Browse the repository at this point in the history
…213)

* FIX, DOC: Pipeline methods shouldn't accept just any instance of raw

Following up from #212. I realized that if one did do `my_pipeline.find_outlier_chs(my_raw), then under the hood
this method would actually create an epochs instance from my_pipeline.raw ...

this would lead to unexpected results if my_raw and my_pipeline.raw are not the same object of memory (i.e. they are different raw objects).

Since in our codebase we never do pipeline.find_outlier_chs(raw), I dont think we should support this.

Instead, we should either always expect an instanc of mne.Epochs

OR

We should change the method signature to be def find_outlier_chs(epochs=None) , where if it is None, the method creates epochs from pipeline.raw under the hood.

* Apply suggestions from Christian O'Reilly code review

- Change API form `inst` to `epochs | None`
- if `None`, then make epochs from self.raw
- No need check and raise type error now.

Co-authored-by: Christian O'Reilly <[email protected]>

* FIX: cruft

* DOC: improve docstring a little

* DOC, FIX: add example to docstring and explicitly call pick

* TST: OK codecov here is your test :)

* STY: 2 lines bt functions

* API: make epochs=None the default find_outlier_chs(epochs=None)

---------

Co-authored-by: Christian O'Reilly <[email protected]>
  • Loading branch information
scott-huberty and christian-oreilly authored Dec 23, 2024
1 parent 1297027 commit c985968
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 10 deletions.
51 changes: 41 additions & 10 deletions pylossless/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,18 +770,49 @@ def _flag_volt_std(self, flag_dim, threshold=5e-5, picks="eeg"):
)
self.flags[flag_dim].add_flag_cat("volt_std", above_threshold, epochs)

def find_outlier_chs(self, inst, picks="eeg"):
"""Detect outlier Channels to leave out of rereference."""
def find_outlier_chs(self, epochs=None, picks="eeg"):
"""Detect outlier Channels to leave out of rereference.
Parameters
----------
epochs : mne.Epochs | None
An instance of :class:`mne.Epochs`, or ``None``. If ``None``, then
:attr:`pylossless.LosslessPipeline.raw` should be set, and this
method will call :meth:`pylossless.LosslessPipeline.get_epochs`
to create epochs to use for outlier detection.
picks : str (default "eeg")
Channels to include in the outlier detection process. You can pass any
argument that is valid for the :meth:`~mne.Epochs.pick` method, but
you should avoid passing a mix of channel types with differing units of
measurement (e.g. EEG and MEG), as this would likely lead to incorrect
outlier detection (e.g. all EEG channels would be flagged as outliers).
Returns
-------
list
a list of channel names that are considered outliers.
Notes
-----
- This method is used to detect channels that are so noisy that they
should be left out of the robust average rereference process.
Examples
--------
>>> import mne
>>> import pylossless as ll
>>> config = ll.Config().load_default()
>>> pipeline = ll.LosslessPipeline(config=config)
>>> fname = mne.datasets.sample.data_path() / "MEG/sample/sample_audvis_raw.fif"
>>> raw = mne.io.read_raw(fname)
>>> epochs = mne.make_fixed_length_epochs(raw, preload=True)
>>> chs_to_leave_out = pipeline.find_outlier_chs(epochs=epochs)
"""
# TODO: Reuse _detect_outliers here.
logger.info("🔍 Detecting channels to leave out of reference.")
if isinstance(inst, mne.Epochs):
epochs = inst
elif isinstance(inst, mne.io.Raw):
epochs = self.get_epochs(rereference=False, picks=picks)
else:
raise TypeError(
"inst must be an MNE Raw or Epochs object," f" but got {type(inst)}."
)
if epochs is None:
epochs = self.get_epochs(rereference=False)
epochs = epochs.copy().pick(picks=picks)
epochs_xr = epochs_to_xr(epochs, kind="ch")

# Determines comically bad channels,
Expand Down
12 changes: 12 additions & 0 deletions pylossless/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ def test_find_breaks(logging):
Path(config_fname).unlink() # delete config file


def test_find_outliers():
"""Test the find_outliers method for the case that epochs is None."""
fname = mne.datasets.sample.data_path() / 'MEG' / 'sample' / 'sample_audvis_raw.fif'
raw = mne.io.read_raw_fif(fname, preload=True)
raw.apply_function(lambda x: x * 10, picks="EEG 001") # create an outlier
config = ll.config.Config().load_default()
pipeline = ll.LosslessPipeline(config=config)
pipeline.raw = raw
chs_to_leave_out = pipeline.find_outlier_chs()
assert chs_to_leave_out == ['EEG 001']


def test_deprecation():
"""Test the config_name property added for deprecation."""
config = ll.config.Config()
Expand Down

0 comments on commit c985968

Please sign in to comment.