From fbb982cce22be4b579f6d741b8b7da7132681b05 Mon Sep 17 00:00:00 2001 From: Mehdi Azabou Date: Thu, 24 Oct 2024 16:02:20 -0400 Subject: [PATCH] represent sampling intervals as Interval objects --- CHANGELOG.md | 3 +- tests/test_sampler.py | 64 ++++++++++++++++++++++++------------- torch_brain/data/dataset.py | 25 +++++++-------- torch_brain/data/sampler.py | 12 ++++--- 4 files changed, 62 insertions(+), 42 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ca035a..b1f5602 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,5 +9,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Changed - Update workflow to use ubuntu-latest instances from github actions. ([#8](httpps://github.com/neuro-galaxy/torch_brain/pull/8)) - +- Simplify Dataset interface by removing the `include` dictionnary and allowing to directly load selection from a configuration file. ([#10](https://github.com/neuro-galaxy/torch_brain/pull/10)) +- Sampling intervals are now represented as `Interval` objects. ([#11](https://github.com/neuro-galaxy/torch_brain/pull/11)) ### Fixed diff --git a/tests/test_sampler.py b/tests/test_sampler.py index eef66ef..d12ea16 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -3,6 +3,8 @@ import numpy as np import torch +from temporaldata import Interval + from torch_brain.data.sampler import ( SequentialFixedWindowSampler, RandomFixedWindowSampler, @@ -29,7 +31,9 @@ def samples_in_interval_dict(samples, interval_dict): sum( [ (s.start >= start) and (s.end <= end) - for start, end in allowed_intervals + for start, end in zip( + allowed_intervals.start, allowed_intervals.end + ) ] ) == 1 @@ -42,14 +46,18 @@ def samples_in_interval_dict(samples, interval_dict): def test_sequential_sampler(): sampler = SequentialFixedWindowSampler( interval_dict={ - "session1": [ - (0.0, 2.0), - (3.0, 4.5), - ], # 3 - "session2": [(0.1, 1.25), (2.5, 5.0), (15.0, 18.7)], # 7 - "session3": [ - (1000.0, 1002.0), - ], # 2 + "session1": Interval( + start=np.array([0.0, 3.0]), + end=np.array([2.0, 4.5]), + ), + "session2": Interval( + start=np.array([0.1, 2.5, 15.0]), + end=np.array([1.25, 5.0, 18.7]), + ), + "session3": Interval( + start=np.array([1000.0]), + end=np.array([1002.0]), + ), }, window_length=1.1, step=0.75, @@ -84,14 +92,18 @@ def test_sequential_sampler(): def test_random_sampler(): interval_dict = { - "session1": [ - (0.0, 2.0), - (3.0, 4.5), - ], # 3 - "session2": [(0.1, 1.25), (2.5, 5.0), (15.0, 18.7)], # 7 - "session3": [ - (1000.0, 1002.0), - ], # 2 + "session1": Interval( + start=np.array([0.0, 3.0]), + end=np.array([2.0, 4.5]), + ), # 3 + "session2": Interval( + start=np.array([0.1, 2.5, 15.0]), + end=np.array([1.25, 5.0, 18.7]), + ), # 7 + "session3": Interval( + start=np.array([1000.0]), + end=np.array([1002.0]), + ), # 2 } sampler = RandomFixedWindowSampler( @@ -135,12 +147,18 @@ def test_random_sampler(): def test_trial_sampler(): interval_dict = { - "session1": [ - (0.0, 2.0), - (3.0, 4.5), - ], - "session2": [(0.1, 1.25), (2.5, 5.0), (15.0, 18.7)], - "session3": [(1000.0, 1002.0), (1002.0, 1003.0)], + "session1": Interval( + start=np.array([0.0, 3.0]), + end=np.array([2.0, 4.5]), + ), + "session2": Interval( + start=np.array([0.1, 2.5, 15.0]), + end=np.array([1.25, 5.0, 18.7]), + ), + "session3": Interval( + start=np.array([1000.0, 1002.0]), + end=np.array([1002.0, 1003.0]), + ), } sampler = TrialSampler( diff --git a/torch_brain/data/dataset.py b/torch_brain/data/dataset.py index 660b5c8..964817c 100644 --- a/torch_brain/data/dataset.py +++ b/torch_brain/data/dataset.py @@ -315,21 +315,22 @@ def get_session_data(self, session_id: str): return data def get_sampling_intervals(self): - r"""Returns a dictionary of interval-list for each session. - Each interval-list is a list of tuples (start, end) for each interval. This - represents the intervals that can be sampled from each session. + r"""Returns a dictionary of sampling intervals for each session. + This represents the intervals that can be sampled from each session. - Note that these intervals will change depending on the split. + Note that these intervals will change depending on the split. If no split is + provided, the full domain of the data is used. """ - interval_dict = {} + sampling_intervals_dict = {} for session_id in self.session_dict.keys(): - intervals = getattr(self._data_objects[session_id], f"{self.split}_domain") + sampling_domain = ( + f"{self.split}_domain" if self.split is not None else "domain" + ) + intervals = getattr(self._data_objects[session_id], sampling_domain) sampling_intervals_modifier_code = self.session_dict[session_id][ "config" ].get("sampling_intervals_modifier", None) - if sampling_intervals_modifier_code is None: - interval_dict[session_id] = list(zip(intervals.start, intervals.end)) - else: + if sampling_intervals_modifier_code is not None: local_vars = { "data": copy.deepcopy(self._data_objects[session_id]), "sampling_intervals": intervals, @@ -351,10 +352,8 @@ def get_sampling_intervals(self): raise type(e)(error_message) from e sampling_intervals = local_vars.get("sampling_intervals") - interval_dict[session_id] = list( - zip(sampling_intervals.start, sampling_intervals.end) - ) - return interval_dict + sampling_intervals_dict[session_id] = sampling_intervals + return sampling_intervals_dict def get_session_config_dict(self): r"""Returns configs for each session in the dataset as a dictionary.""" diff --git a/torch_brain/data/sampler.py b/torch_brain/data/sampler.py index 5b2c402..21c5a7a 100644 --- a/torch_brain/data/sampler.py +++ b/torch_brain/data/sampler.py @@ -4,6 +4,8 @@ from functools import cached_property import torch +from temporaldata import Interval + from torch_brain.data.dataset import DatasetIndex @@ -33,7 +35,7 @@ class RandomFixedWindowSampler(torch.utils.data.Sampler): def __init__( self, *, - interval_dict: Dict[str, List[Tuple[float, float]]], + interval_dict: Dict[str, Interval], window_length: float, generator: Optional[torch.Generator], drop_short: bool = True, @@ -49,7 +51,7 @@ def _estimated_len(self): total_short_dropped = 0.0 for session_name, sampling_intervals in self.interval_dict.items(): - for start, end in sampling_intervals: + for start, end in zip(sampling_intervals.start, sampling_intervals.end): interval_length = end - start if interval_length < self.window_length: if self.drop_short: @@ -81,7 +83,7 @@ def __iter__(self): indices = [] for session_name, sampling_intervals in self.interval_dict.items(): - for start, end in sampling_intervals: + for start, end in zip(sampling_intervals.start, sampling_intervals.end): interval_length = end - start if interval_length < self.window_length: if self.drop_short: @@ -177,7 +179,7 @@ def _indices(self) -> List[DatasetIndex]: total_short_dropped = 0.0 for session_name, sampling_intervals in self.interval_dict.items(): - for start, end in sampling_intervals: + for start, end in zip(sampling_intervals.start, sampling_intervals.end): interval_length = end - start if interval_length < self.window_length: if self.drop_short: @@ -252,7 +254,7 @@ def __iter__(self): all_intervals = [ (session_id, start, end) for session_id, intervals in self.interval_dict.items() - for start, end in intervals + for start, end in zip(intervals.start, intervals.end) ] indices = [