Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

represent sampling intervals as Interval objects #11

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

### Changed
- Update workflow to use ubuntu-latest instances from github actions. ([#8](httpps://github.com/neuro-galaxy/torch_brain/pull/8))

- Simplify Dataset interface by removing the `include` dictionnary and allowing to directly load selection from a configuration file. ([#10](https://github.com/neuro-galaxy/torch_brain/pull/10))
- Sampling intervals are now represented as `Interval` objects. ([#11](https://github.com/neuro-galaxy/torch_brain/pull/11))
### Fixed
64 changes: 41 additions & 23 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import torch

from temporaldata import Interval

from torch_brain.data.sampler import (
SequentialFixedWindowSampler,
RandomFixedWindowSampler,
Expand All @@ -29,7 +31,9 @@ def samples_in_interval_dict(samples, interval_dict):
sum(
[
(s.start >= start) and (s.end <= end)
for start, end in allowed_intervals
for start, end in zip(
allowed_intervals.start, allowed_intervals.end
)
]
)
== 1
Expand All @@ -42,14 +46,18 @@ def samples_in_interval_dict(samples, interval_dict):
def test_sequential_sampler():
sampler = SequentialFixedWindowSampler(
interval_dict={
"session1": [
(0.0, 2.0),
(3.0, 4.5),
], # 3
"session2": [(0.1, 1.25), (2.5, 5.0), (15.0, 18.7)], # 7
"session3": [
(1000.0, 1002.0),
], # 2
"session1": Interval(
start=np.array([0.0, 3.0]),
end=np.array([2.0, 4.5]),
),
"session2": Interval(
start=np.array([0.1, 2.5, 15.0]),
end=np.array([1.25, 5.0, 18.7]),
),
"session3": Interval(
start=np.array([1000.0]),
end=np.array([1002.0]),
),
},
window_length=1.1,
step=0.75,
Expand Down Expand Up @@ -84,14 +92,18 @@ def test_sequential_sampler():
def test_random_sampler():

interval_dict = {
"session1": [
(0.0, 2.0),
(3.0, 4.5),
], # 3
"session2": [(0.1, 1.25), (2.5, 5.0), (15.0, 18.7)], # 7
"session3": [
(1000.0, 1002.0),
], # 2
"session1": Interval(
start=np.array([0.0, 3.0]),
end=np.array([2.0, 4.5]),
), # 3
"session2": Interval(
start=np.array([0.1, 2.5, 15.0]),
end=np.array([1.25, 5.0, 18.7]),
), # 7
"session3": Interval(
start=np.array([1000.0]),
end=np.array([1002.0]),
), # 2
}

sampler = RandomFixedWindowSampler(
Expand Down Expand Up @@ -135,12 +147,18 @@ def test_random_sampler():

def test_trial_sampler():
interval_dict = {
"session1": [
(0.0, 2.0),
(3.0, 4.5),
],
"session2": [(0.1, 1.25), (2.5, 5.0), (15.0, 18.7)],
"session3": [(1000.0, 1002.0), (1002.0, 1003.0)],
"session1": Interval(
start=np.array([0.0, 3.0]),
end=np.array([2.0, 4.5]),
),
"session2": Interval(
start=np.array([0.1, 2.5, 15.0]),
end=np.array([1.25, 5.0, 18.7]),
),
"session3": Interval(
start=np.array([1000.0, 1002.0]),
end=np.array([1002.0, 1003.0]),
),
}

sampler = TrialSampler(
Expand Down
25 changes: 12 additions & 13 deletions torch_brain/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,21 +315,22 @@ def get_session_data(self, session_id: str):
return data

def get_sampling_intervals(self):
r"""Returns a dictionary of interval-list for each session.
Each interval-list is a list of tuples (start, end) for each interval. This
represents the intervals that can be sampled from each session.
r"""Returns a dictionary of sampling intervals for each session.
This represents the intervals that can be sampled from each session.

Note that these intervals will change depending on the split.
Note that these intervals will change depending on the split. If no split is
provided, the full domain of the data is used.
"""
interval_dict = {}
sampling_intervals_dict = {}
for session_id in self.session_dict.keys():
intervals = getattr(self._data_objects[session_id], f"{self.split}_domain")
sampling_domain = (
f"{self.split}_domain" if self.split is not None else "domain"
)
intervals = getattr(self._data_objects[session_id], sampling_domain)
sampling_intervals_modifier_code = self.session_dict[session_id][
"config"
].get("sampling_intervals_modifier", None)
if sampling_intervals_modifier_code is None:
interval_dict[session_id] = list(zip(intervals.start, intervals.end))
else:
if sampling_intervals_modifier_code is not None:
local_vars = {
"data": copy.deepcopy(self._data_objects[session_id]),
"sampling_intervals": intervals,
Expand All @@ -351,10 +352,8 @@ def get_sampling_intervals(self):
raise type(e)(error_message) from e

sampling_intervals = local_vars.get("sampling_intervals")
interval_dict[session_id] = list(
zip(sampling_intervals.start, sampling_intervals.end)
)
return interval_dict
sampling_intervals_dict[session_id] = sampling_intervals
return sampling_intervals_dict

def get_session_config_dict(self):
r"""Returns configs for each session in the dataset as a dictionary."""
Expand Down
12 changes: 7 additions & 5 deletions torch_brain/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from functools import cached_property

import torch
from temporaldata import Interval


from torch_brain.data.dataset import DatasetIndex

Expand Down Expand Up @@ -33,7 +35,7 @@ class RandomFixedWindowSampler(torch.utils.data.Sampler):
def __init__(
self,
*,
interval_dict: Dict[str, List[Tuple[float, float]]],
interval_dict: Dict[str, Interval],
window_length: float,
generator: Optional[torch.Generator],
drop_short: bool = True,
Expand All @@ -49,7 +51,7 @@ def _estimated_len(self):
total_short_dropped = 0.0

for session_name, sampling_intervals in self.interval_dict.items():
for start, end in sampling_intervals:
for start, end in zip(sampling_intervals.start, sampling_intervals.end):
interval_length = end - start
if interval_length < self.window_length:
if self.drop_short:
Expand Down Expand Up @@ -81,7 +83,7 @@ def __iter__(self):

indices = []
for session_name, sampling_intervals in self.interval_dict.items():
for start, end in sampling_intervals:
for start, end in zip(sampling_intervals.start, sampling_intervals.end):
interval_length = end - start
if interval_length < self.window_length:
if self.drop_short:
Expand Down Expand Up @@ -177,7 +179,7 @@ def _indices(self) -> List[DatasetIndex]:
total_short_dropped = 0.0

for session_name, sampling_intervals in self.interval_dict.items():
for start, end in sampling_intervals:
for start, end in zip(sampling_intervals.start, sampling_intervals.end):
interval_length = end - start
if interval_length < self.window_length:
if self.drop_short:
Expand Down Expand Up @@ -252,7 +254,7 @@ def __iter__(self):
all_intervals = [
(session_id, start, end)
for session_id, intervals in self.interval_dict.items()
for start, end in intervals
for start, end in zip(intervals.start, intervals.end)
]

indices = [
Expand Down
Loading