From ef5cdb9732b99f7bfdbed89ebe3bd310eb1a679b Mon Sep 17 00:00:00 2001
From: Mathieu Scheltienne <mathieu.scheltienne@fcbg.ch>
Date: Fri, 22 Dec 2023 16:01:55 +0100
Subject: [PATCH] add placeholder for epoch selection

---
 pycrostates/segmentation/segmentation.py | 82 +++++++++++++++++++++++-
 1 file changed, 81 insertions(+), 1 deletion(-)

diff --git a/pycrostates/segmentation/segmentation.py b/pycrostates/segmentation/segmentation.py
index 951bdfe7..e43442b6 100644
--- a/pycrostates/segmentation/segmentation.py
+++ b/pycrostates/segmentation/segmentation.py
@@ -1,6 +1,8 @@
 """Segmentation module for segmented data."""
 
-from typing import Optional, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
 
 from matplotlib.axes import Axes
 from mne import BaseEpochs
@@ -11,6 +13,11 @@
 from ..viz import plot_epoch_segmentation, plot_raw_segmentation
 from ._base import _BaseSegmentation
 
+if TYPE_CHECKING:
+    from typing import Optional, Union
+
+    from pandas import DataFrame
+
 
 @fill_doc
 class RawSegmentation(_BaseSegmentation):
@@ -130,6 +137,74 @@ def __init__(self, *args, **kwargs):
                 f"samples, while the 'labels' has {self._labels.shape[-1]} samples."
             )
 
+    def __getitem__(self, item):
+        """Select epochs in a :class:`~pycrostates.segmentation.EpochsSegmentation`.
+
+        Parameters
+        ----------
+        item : slice, array-like, str, or list
+            See below for use cases.
+
+        Returns
+        -------
+        epochs : instance of EpochsSegmentation
+            Returns a copy of the original instance. See below for use cases.
+
+        Notes
+        -----
+        :class:`~pycrostates.segmentation.EpochsSegmentation` can be accessed as
+        ``segmentation[...]`` in several ways:
+
+        1. **Integer or slice:** ``segmentation[idx]`` will return an
+           :class:`~pycrostates.segmentation.EpochsSegmentation` object with a subset of
+           epochs chosen by index (supports single index and Python-style slicing).
+
+        2. **String:** ``segmentation['name']`` will return an
+           :class:`~pycrostates.segmentation.EpochsSegmentation` object comprising only
+           the epochs labeled ``'name'`` (i.e., epochs created around events with the
+           label ``'name'``).
+
+           If there are no epochs labeled ``'name'`` but there are epochs
+           labeled with /-separated tags (e.g. ``'name/left'``,
+           ``'name/right'``), then ``segmentation['name']`` will select the epochs
+           with labels that contain that tag (e.g., ``segmentation['left']`` selects
+           epochs labeled ``'audio/left'`` and ``'visual/left'``, but not
+           ``'audio_left'``).
+
+           If multiple tags are provided *as a single string* (e.g.,
+           ``segmentation['name_1/name_2']``), this selects epochs containing *all*
+           provided tags. For example, ``segmentation['audio/left']`` selects
+           ``'audio/left'`` and ``'audio/quiet/left'``, but not
+           ``'audio/right'``. Note that tag-based selection is insensitive to
+           order: tags like ``'audio/left'`` and ``'left/audio'`` will be
+           treated the same way when selecting via tag.
+
+        3. **List of strings:** ``segmentation[['name_1', 'name_2', ... ]]`` will
+           return an :class:`~pycrostates.segmentation.EpochsSegmentation` object
+           comprising epochs that match *any* of the provided names (i.e., the list of
+           names is treated as an inclusive-or condition). If *none* of the provided
+           names match any epoch labels, a ``KeyError`` will be raised.
+
+           If epoch labels are /-separated tags, then providing multiple tags
+           *as separate list entries* will likewise act as an inclusive-or
+           filter. For example, ``segmentation[['audio', 'left']]`` would select
+           ``'audio/left'``, ``'audio/right'``, and ``'visual/left'``, but not
+           ``'visual/right'``.
+
+        4. **Pandas query:** ``segmentation['pandas query']`` will return an
+           :class:`~pycrostates.segmentation.EpochsSegmentation` object with a subset of
+           epochs (and matching metadata) selected by the query called with
+           ``self.metadata.eval``, e.g.::
+
+               epochs["col_a > 2 and col_b == 'foo'"]
+
+           would return all epochs whose associated ``col_a`` metadata was
+           greater than two, and whose ``col_b`` metadata was the string 'foo'.
+           Query-based indexing only works if Pandas is installed and
+           ``self.metadata`` is a :class:`pandas.DataFrame`.
+        """
+        inst = self.copy()  # noqa: F841
+
     @fill_doc
     def plot(
         self,
@@ -173,3 +248,8 @@ def plot(
     def epochs(self) -> BaseEpochs:
         """`~mne.Epochs` instance from which the segmentation was computed."""
         return self._inst.copy()
+
+    @property
+    def metadata(self) -> Optional[DataFrame]:
+        """Epochs metadata."""
+        return self._inst.metadata