diff --git a/pycrostates/segmentation/segmentation.py b/pycrostates/segmentation/segmentation.py index 951bdfe7..e43442b6 100644 --- a/pycrostates/segmentation/segmentation.py +++ b/pycrostates/segmentation/segmentation.py @@ -1,6 +1,8 @@ """Segmentation module for segmented data.""" -from typing import Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING from matplotlib.axes import Axes from mne import BaseEpochs @@ -11,6 +13,11 @@ from ..viz import plot_epoch_segmentation, plot_raw_segmentation from ._base import _BaseSegmentation +if TYPE_CHECKING: + from typing import Optional, Union + + from pandas import DataFrame + @fill_doc class RawSegmentation(_BaseSegmentation): @@ -130,6 +137,74 @@ def __init__(self, *args, **kwargs): f"samples, while the 'labels' has {self._labels.shape[-1]} samples." ) + def __getitem__(self, item): + """Select epochs in a :class:`~pycrostates.segmentation.EpochsSegmentation`. + + Parameters + ---------- + item : slice, array-like, str, or list + See below for use cases. + + Returns + ------- + epochs : instance of EpochsSegmentation + Returns a copy of the original instance. See below for use cases. + + Notes + ----- + :class:`~pycrostates.segmentation.EpochsSegmentation` can be accessed as + ``segmentation[...]`` in several ways: + + 1. **Integer or slice:** ``segmentation[idx]`` will return an + :class:`~pycrostates.segmentation.EpochsSegmentation` object with a subset of + epochs chosen by index (supports single index and Python-style slicing). + + 2. **String:** ``segmentation['name']`` will return an + :class:`~pycrostates.segmentation.EpochsSegmentation` object comprising only + the epochs labeled ``'name'`` (i.e., epochs created around events with the + label ``'name'``). + + If there are no epochs labeled ``'name'`` but there are epochs + labeled with /-separated tags (e.g. ``'name/left'``, + ``'name/right'``), then ``segmentation['name']`` will select the epochs + with labels that contain that tag (e.g., ``segmentation['left']`` selects + epochs labeled ``'audio/left'`` and ``'visual/left'``, but not + ``'audio_left'``). + + If multiple tags are provided *as a single string* (e.g., + ``segmentation['name_1/name_2']``), this selects epochs containing *all* + provided tags. For example, ``segmentation['audio/left']`` selects + ``'audio/left'`` and ``'audio/quiet/left'``, but not + ``'audio/right'``. Note that tag-based selection is insensitive to + order: tags like ``'audio/left'`` and ``'left/audio'`` will be + treated the same way when selecting via tag. + + 3. **List of strings:** ``segmentation[['name_1', 'name_2', ... ]]`` will + return an :class:`~pycrostates.segmentation.EpochsSegmentation` object + comprising epochs that match *any* of the provided names (i.e., the list of + names is treated as an inclusive-or condition). If *none* of the provided + names match any epoch labels, a ``KeyError`` will be raised. + + If epoch labels are /-separated tags, then providing multiple tags + *as separate list entries* will likewise act as an inclusive-or + filter. For example, ``segmentation[['audio', 'left']]`` would select + ``'audio/left'``, ``'audio/right'``, and ``'visual/left'``, but not + ``'visual/right'``. + + 4. **Pandas query:** ``segmentation['pandas query']`` will return an + :class:`~pycrostates.segmentation.EpochsSegmentation` object with a subset of + epochs (and matching metadata) selected by the query called with + ``self.metadata.eval``, e.g.:: + + epochs["col_a > 2 and col_b == 'foo'"] + + would return all epochs whose associated ``col_a`` metadata was + greater than two, and whose ``col_b`` metadata was the string 'foo'. + Query-based indexing only works if Pandas is installed and + ``self.metadata`` is a :class:`pandas.DataFrame`. + """ + inst = self.copy() # noqa: F841 + @fill_doc def plot( self, @@ -173,3 +248,8 @@ def plot( def epochs(self) -> BaseEpochs: """`~mne.Epochs` instance from which the segmentation was computed.""" return self._inst.copy() + + @property + def metadata(self) -> Optional[DataFrame]: + """Epochs metadata.""" + return self._inst.metadata