Skip to content

Commit

Permalink
represent sampling intervals as Interval objects (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
mazabou authored Oct 24, 2024
1 parent 8f91906 commit 14786d2
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 42 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

### Changed
- Update workflow to use ubuntu-latest instances from github actions. ([#8](httpps://github.com/neuro-galaxy/torch_brain/pull/8))

- Simplify Dataset interface by removing the `include` dictionnary and allowing to directly load selection from a configuration file. ([#10](https://github.com/neuro-galaxy/torch_brain/pull/10))
- Sampling intervals are now represented as `Interval` objects. ([#11](https://github.com/neuro-galaxy/torch_brain/pull/11))
### Fixed
64 changes: 41 additions & 23 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import torch

from temporaldata import Interval

from torch_brain.data.sampler import (
SequentialFixedWindowSampler,
RandomFixedWindowSampler,
Expand All @@ -29,7 +31,9 @@ def samples_in_interval_dict(samples, interval_dict):
sum(
[
(s.start >= start) and (s.end <= end)
for start, end in allowed_intervals
for start, end in zip(
allowed_intervals.start, allowed_intervals.end
)
]
)
== 1
Expand All @@ -42,14 +46,18 @@ def samples_in_interval_dict(samples, interval_dict):
def test_sequential_sampler():
sampler = SequentialFixedWindowSampler(
interval_dict={
"session1": [
(0.0, 2.0),
(3.0, 4.5),
], # 3
"session2": [(0.1, 1.25), (2.5, 5.0), (15.0, 18.7)], # 7
"session3": [
(1000.0, 1002.0),
], # 2
"session1": Interval(
start=np.array([0.0, 3.0]),
end=np.array([2.0, 4.5]),
),
"session2": Interval(
start=np.array([0.1, 2.5, 15.0]),
end=np.array([1.25, 5.0, 18.7]),
),
"session3": Interval(
start=np.array([1000.0]),
end=np.array([1002.0]),
),
},
window_length=1.1,
step=0.75,
Expand Down Expand Up @@ -84,14 +92,18 @@ def test_sequential_sampler():
def test_random_sampler():

interval_dict = {
"session1": [
(0.0, 2.0),
(3.0, 4.5),
], # 3
"session2": [(0.1, 1.25), (2.5, 5.0), (15.0, 18.7)], # 7
"session3": [
(1000.0, 1002.0),
], # 2
"session1": Interval(
start=np.array([0.0, 3.0]),
end=np.array([2.0, 4.5]),
), # 3
"session2": Interval(
start=np.array([0.1, 2.5, 15.0]),
end=np.array([1.25, 5.0, 18.7]),
), # 7
"session3": Interval(
start=np.array([1000.0]),
end=np.array([1002.0]),
), # 2
}

sampler = RandomFixedWindowSampler(
Expand Down Expand Up @@ -135,12 +147,18 @@ def test_random_sampler():

def test_trial_sampler():
interval_dict = {
"session1": [
(0.0, 2.0),
(3.0, 4.5),
],
"session2": [(0.1, 1.25), (2.5, 5.0), (15.0, 18.7)],
"session3": [(1000.0, 1002.0), (1002.0, 1003.0)],
"session1": Interval(
start=np.array([0.0, 3.0]),
end=np.array([2.0, 4.5]),
),
"session2": Interval(
start=np.array([0.1, 2.5, 15.0]),
end=np.array([1.25, 5.0, 18.7]),
),
"session3": Interval(
start=np.array([1000.0, 1002.0]),
end=np.array([1002.0, 1003.0]),
),
}

sampler = TrialSampler(
Expand Down
25 changes: 12 additions & 13 deletions torch_brain/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,21 +315,22 @@ def get_session_data(self, session_id: str):
return data

def get_sampling_intervals(self):
r"""Returns a dictionary of interval-list for each session.
Each interval-list is a list of tuples (start, end) for each interval. This
represents the intervals that can be sampled from each session.
r"""Returns a dictionary of sampling intervals for each session.
This represents the intervals that can be sampled from each session.
Note that these intervals will change depending on the split.
Note that these intervals will change depending on the split. If no split is
provided, the full domain of the data is used.
"""
interval_dict = {}
sampling_intervals_dict = {}
for session_id in self.session_dict.keys():
intervals = getattr(self._data_objects[session_id], f"{self.split}_domain")
sampling_domain = (
f"{self.split}_domain" if self.split is not None else "domain"
)
intervals = getattr(self._data_objects[session_id], sampling_domain)
sampling_intervals_modifier_code = self.session_dict[session_id][
"config"
].get("sampling_intervals_modifier", None)
if sampling_intervals_modifier_code is None:
interval_dict[session_id] = list(zip(intervals.start, intervals.end))
else:
if sampling_intervals_modifier_code is not None:
local_vars = {
"data": copy.deepcopy(self._data_objects[session_id]),
"sampling_intervals": intervals,
Expand All @@ -351,10 +352,8 @@ def get_sampling_intervals(self):
raise type(e)(error_message) from e

sampling_intervals = local_vars.get("sampling_intervals")
interval_dict[session_id] = list(
zip(sampling_intervals.start, sampling_intervals.end)
)
return interval_dict
sampling_intervals_dict[session_id] = sampling_intervals
return sampling_intervals_dict

def get_session_config_dict(self):
r"""Returns configs for each session in the dataset as a dictionary."""
Expand Down
12 changes: 7 additions & 5 deletions torch_brain/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from functools import cached_property

import torch
from temporaldata import Interval


from torch_brain.data.dataset import DatasetIndex

Expand Down Expand Up @@ -33,7 +35,7 @@ class RandomFixedWindowSampler(torch.utils.data.Sampler):
def __init__(
self,
*,
interval_dict: Dict[str, List[Tuple[float, float]]],
interval_dict: Dict[str, Interval],
window_length: float,
generator: Optional[torch.Generator],
drop_short: bool = True,
Expand All @@ -49,7 +51,7 @@ def _estimated_len(self):
total_short_dropped = 0.0

for session_name, sampling_intervals in self.interval_dict.items():
for start, end in sampling_intervals:
for start, end in zip(sampling_intervals.start, sampling_intervals.end):
interval_length = end - start
if interval_length < self.window_length:
if self.drop_short:
Expand Down Expand Up @@ -81,7 +83,7 @@ def __iter__(self):

indices = []
for session_name, sampling_intervals in self.interval_dict.items():
for start, end in sampling_intervals:
for start, end in zip(sampling_intervals.start, sampling_intervals.end):
interval_length = end - start
if interval_length < self.window_length:
if self.drop_short:
Expand Down Expand Up @@ -177,7 +179,7 @@ def _indices(self) -> List[DatasetIndex]:
total_short_dropped = 0.0

for session_name, sampling_intervals in self.interval_dict.items():
for start, end in sampling_intervals:
for start, end in zip(sampling_intervals.start, sampling_intervals.end):
interval_length = end - start
if interval_length < self.window_length:
if self.drop_short:
Expand Down Expand Up @@ -252,7 +254,7 @@ def __iter__(self):
all_intervals = [
(session_id, start, end)
for session_id, intervals in self.interval_dict.items()
for start, end in intervals
for start, end in zip(intervals.start, intervals.end)
]

indices = [
Expand Down

0 comments on commit 14786d2

Please sign in to comment.