fix tests, update doc
mazabou committed Oct 28, 2024
1 parent 5851d0d commit 49da362
Showing 3 changed files with 47 additions and 44 deletions.
15 changes: 7 additions & 8 deletions tests/test_dataset_sim.py
@@ -144,8 +144,8 @@ def test_dataset_selection(dummy_data):
         split=None,
         config=temp_config_file.name,
     )
-    assert len(ds.session_dict) == 2
-    assert ds.session_dict["allen_neuropixels_mock/20100102_1"]["filename"] == (
+    assert len(ds.recording_dict) == 2
+    assert ds.recording_dict["allen_neuropixels_mock/20100102_1"]["filename"] == (
         dummy_data / "allen_neuropixels_mock" / "20100102_1.h5"
     )

@@ -159,7 +159,7 @@ def test_dataset_selection(dummy_data):
         split=None,
         config=temp_config_file.name,
     )
-    assert len(ds.session_dict) == 1
+    assert len(ds.recording_dict) == 1

    with tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") as temp_config_file:
        yaml.dump(
@@ -171,18 +171,17 @@ def test_dataset_selection(dummy_data):
         split=None,
         config=temp_config_file.name,
     )
-    assert len(ds.session_dict) == 1
+    assert len(ds.recording_dict) == 1


-def test_get_session_data(dummy_data):
+def test_get_recording_data(dummy_data):
     ds = Dataset(
         dummy_data,
         split=None,
-        brainset="allen_neuropixels_mock",
-        session="20100102_1",
+        recording_id="allen_neuropixels_mock/20100102_1",
     )

-    data = ds.get_session_data("allen_neuropixels_mock/20100102_1")
+    data = ds.get_recording_data("allen_neuropixels_mock/20100102_1")

     assert len(data.spikes) == 1000
     assert len(data.gabors) == 1000
6 changes: 3 additions & 3 deletions tests/test_sampler.py
@@ -16,7 +16,7 @@
 # helper
 def compare_slice_indices(a, b):
     return (
-        (a.session_id == b.session_id)
+        (a.recording_id == b.recording_id)
         and np.isclose(a.start, b.start)
         and np.isclose(a.end, b.end)
     )
@@ -25,8 +25,8 @@ def compare_slice_indices(a, b):
 # helper
 def samples_in_interval_dict(samples, interval_dict):
     for s in samples:
-        assert s.session_id in interval_dict
-        allowed_intervals = interval_dict[s.session_id]
+        assert s.recording_id in interval_dict
+        allowed_intervals = interval_dict[s.recording_id]
         if not (
             sum(
                 [
70 changes: 37 additions & 33 deletions torch_brain/data/dataset.py
@@ -16,33 +16,43 @@

 @dataclass
 class DatasetIndex:
-    """Accessing the dataset is done by specifying a recording id and a time interval."""
+    r"""The dataset can be indexed by specifying a recording id and a start and end time."""

     recording_id: str
     start: float
     end: float


 class Dataset(torch.utils.data.Dataset):
-    r"""This class abstracts a collection of lazily-loaded Data objects. Each of these
-    Data objects corresponds to a session and lives on the disk until it is requested.
-    To request a piece of an included session's data, you can use the `get` method,
-    or index the Dataset with a `DatasetIndex` object (see `__getitem__`).
-    This definition is a deviation from the standard PyTorch Dataset definition, which
-    generally presents the dataset directly as samples. In this case, the Dataset
-    by itself does not provide you with samples, but rather the means to flexibly work
-    with and access complete sessions.
-    Within this framework, it is the job of the sampler to provide the
-    DatasetIndex indices to slice the dataset into samples (see `kirby.data.sampler`).
-    Files will be opened, and only closed when the Dataset object is deleted.
+    r"""This class abstracts a collection of lazily-loaded Data objects. Each data object
+    corresponds to a full recording. It is never fully loaded into memory, but rather
+    lazy-loaded on the fly from disk.
+    The dataset can be indexed by a recording id and a start and end time using the `get`
+    method, or by a DatasetIndex object. This definition is a deviation from the standard
+    PyTorch Dataset definition, which generally presents the dataset directly as samples.
+    In this case, the Dataset by itself does not provide you with samples, but rather the
+    means to flexibly work with and access complete recordings.
+    Within this framework, it is the job of the sampler to provide a list of
+    DatasetIndex objects that are used to slice the dataset into samples (see
+    `torch_brain.data.sampler`).
+    The lazy loading is done both in:
+    - time: only the requested time interval is loaded, without having to load the entire
+      recording into memory, and
+    - attributes: attributes are not loaded until they are requested; this is useful when
+      only a small subset of the attributes is actually needed.
+    References to the underlying hdf5 files will be opened, and will only be closed when
+    the Dataset object is destroyed.

     Args:
         root: The root directory of the dataset.
         config: The configuration file specifying the sessions to include.
-        brainset: The brainset to include. This is used to specify a single brainset, and can only be used if config is not provided.
-        session: The session to include. This is used to specify a single session, and can only be used if config is not provided.
+        brainset: The brainset to include. This is used to specify a single brainset,
+            and can only be used if config is not provided.
+        session: The session to include. This is used to specify a single session, and
+            can only be used if config is not provided.
         split: The split of the dataset. This is used to determine the sampling intervals
             for each session. The split is optional, and is used to load a subset of the data
             in a session based on a predefined split.
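
The docstring above describes two ways to pull data out of a Dataset: the `get` method, or indexing with a `DatasetIndex`. A minimal sketch of the indexing path, assuming the mock recording used in the tests; the root directory, config file name, and time window are hypothetical:

```python
from torch_brain.data.dataset import Dataset, DatasetIndex

# Hypothetical root directory and config file, for illustration only.
ds = Dataset("./processed", config="dataset_config.yaml", split="train")

# A DatasetIndex names a recording and a start/end time window;
# indexing lazily loads just that slice from disk.
index = DatasetIndex(
    recording_id="allen_neuropixels_mock/20100102_1",
    start=0.0,
    end=1.0,
)
sample = ds[index]
```

In training, these DatasetIndex objects would normally come from a sampler (see `torch_brain.data.sampler`) rather than being built by hand.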
@@ -59,8 +69,7 @@ def __init__(
         root: str,
         *,
         config: str = None,
-        brainset: str = None,
-        session: str = None,
+        recording_id: str = None,
         split: str = None,
         transform=None,
     ):
@@ -72,30 +81,25 @@ def __init__(

         if config is not None:
             assert (
-                brainset is None and session is None
-            ), "Cannot specify brainset or session when using config."
+                recording_id is None
+            ), "Cannot specify recording_id when using config."

             if Path(config).is_file():
                 config = omegaconf.OmegaConf.load(config)
             else:
-                raise ValueError(f"Config source '{config}' not found.")
+                raise ValueError(f"Could not open configuration file: '{config}'")

             self.recording_dict = self._look_for_files(config)

-        elif brainset is not None or session is not None:
-            assert (
-                brainset is not None and session is not None
-            ), "Please specify both brainset and session."
+        elif recording_id is not None:
             self.recording_dict = {
-                f"{brainset}/{session}": {
-                    "filename": Path(self.root) / brainset / (session + ".h5"),
+                recording_id: {
+                    "filename": Path(self.root) / (recording_id + ".h5"),
                     "config": {},
                 }
             }
         else:
-            raise ValueError(
-                "Please either specify a config file or a brainset and session."
-            )
+            raise ValueError("Please either specify a config file or a recording_id.")

         self._open_files = {
             recording_id: h5py.File(recording_info["filename"], "r")
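
With the new signature, a single recording is addressed by one `recording_id` of the form `<brainset>/<session>`, replacing the separate brainset and session arguments. A sketch of both construction paths; the root path and config file name are hypothetical:

```python
from torch_brain.data.dataset import Dataset

# Option 1: a config file that lists the recordings to include.
ds = Dataset("./processed", config="dataset_config.yaml")

# Option 2: a single recording. Since recording_id contains a "/",
# Path(root) / (recording_id + ".h5") resolves to the same file as the
# old Path(root) / brainset / (session + ".h5").
ds = Dataset(
    "./processed",
    recording_id="allen_neuropixels_mock/20100102_1",
)
```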
@@ -246,7 +250,7 @@ def _look_for_files(self, config: omegaconf.DictConfig) -> Dict[str, Dict]:
             for session_id in session_ids:
                 recording_id = subselection["brainset"] + "/" + session_id

-                if session_id in recording_dict:
+                if recording_id in recording_dict:
                     raise ValueError(
                         f"Recording {recording_id} is already included in the dataset. "
                         "Please verify that it is only selected once."
@@ -260,8 +264,8 @@ def _look_for_files(self, config: omegaconf.DictConfig) -> Dict[str, Dict]:
         return recording_dict

     def get(self, recording_id: str, start: float, end: float):
-        r"""This is the main method to extract a slice from a session. It returns a
-        Data object that contains all data for session :obj:`recording_id` between
+        r"""This is the main method to extract a slice from a recording. It returns a
+        Data object that contains all data for recording :obj:`recording_id` between
         times :obj:`start` and :obj:`end`.

         Args:
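For completeness, a sketch of slicing a recording through `get`, mirroring the updated tests; the time window is hypothetical, and the unit of `start`/`end` is whatever the recording's timestamps use:

```python
from torch_brain.data.dataset import Dataset

ds = Dataset(
    "./processed",
    recording_id="allen_neuropixels_mock/20100102_1",
)

# Returns a Data object restricted to the requested time window;
# only that window is read from the underlying hdf5 file.
data = ds.get("allen_neuropixels_mock/20100102_1", start=0.0, end=1.0)
```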
