Skip to content

Commit

Permalink
feat: hdf5 partially loads datasets based on slice objects
Browse files Browse the repository at this point in the history
  • Loading branch information
JuanPedroGHM committed Dec 17, 2024
1 parent 3a397b8 commit 2288e05
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 8 deletions.
22 changes: 15 additions & 7 deletions heat/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ def size_from_slice(size: int, s: slice) -> Tuple[int, int]:
int
The start index of the slice object.
"""
from hypothesis import note

new_range = range(size)[s]
return len(new_range), new_range.start if len(new_range) > 0 else 0

Expand Down Expand Up @@ -591,17 +589,25 @@ def load_hdf5(
with h5py.File(path, "r") as handle:
data = handle[dataset]
gshape = data.shape
new_gshape = tuple()
offsets = [0] * len(gshape)
if slices is not None:
if len(slices) != len(gshape):
raise ValueError(
f"Number of slices ({len(slices)}) does not match the number of dimensions ({len(gshape)})"
)
for i, s in enumerate(slices):
if s.step is not None and s.step != 1:
raise ValueError("Slices with step != 1 are not supported")
gshape = size_from_slice(gshape[i], s)
offsets[i] = s.start if s.start is not None else 0
if s:
if s.step is not None and s.step != 1:
raise ValueError("Slices with step != 1 are not supported")
new_axis_size, offset = size_from_slice(gshape[i], s)
new_gshape += (new_axis_size,)
offsets[i] = offset
else:
new_gshape += (gshape[i],)
offsets[i] = 0

gshape = new_gshape

if split is not None:
gshape = list(gshape)
Expand All @@ -612,8 +618,10 @@ def load_hdf5(
_, _, indices = comm.chunk(gshape, split)

if slices is not None:
new_indices = tuple()
for offset, index in zip(offsets, indices):
index.start += offset
new_indices += (slice(index.start + offset, index.stop + offset),)
indices = new_indices

balanced = True
if split is None:
Expand Down
37 changes: 36 additions & 1 deletion heat/core/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import heat as ht
from .test_suites.basic_test import TestCase

from hypothesis import given, settings, note
import pytest
from hypothesis import given, settings, note, assume
import hypothesis.strategies as st


Expand Down Expand Up @@ -912,3 +913,37 @@ def test_load_multiple_csv_exception(self):
ht.MPI_WORLD.Barrier()
if ht.MPI_WORLD.rank == 0:
shutil.rmtree(os.path.join(os.getcwd(), "heat/datasets/csv_tests"))


@unittest.skipIf(not ht.io.supports_hdf5(), reason="Requires HDF5")
@pytest.mark.parametrize("axis", [None, 0, 1])
@pytest.mark.parametrize(
    "slices",
    [
        (slice(0, 50, None), slice(None, None, None)),
        (slice(0, 50, None), slice(0, 2, None)),
        (slice(50, 100, None), slice(None, None, None)),
        (slice(None, None, None), slice(2, 4, None)),
    ],
)
def test_load_partial_hdf5(axis, slices):
    """Partial HDF5 loading via ``slices`` must match in-memory slicing.

    Loads the iris dataset with ``slices`` applied at read time and compares
    the result against slicing the fully loaded array, for each ``split``
    axis. Slices with a step other than 1 are unsupported by ``load_hdf5``
    and are expected to raise ``ValueError``.
    """
    HDF5_PATH = os.path.join(os.getcwd(), "heat/datasets/iris.h5")
    HDF5_DATASET = "data"
    # A step != 1 anywhere in the slice tuple means load_hdf5 must raise.
    expect_error = any(s is not None and s.step not in (None, 1) for s in slices)

    if expect_error:
        with pytest.raises(ValueError):
            ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis, slices=slices)
    else:
        original_iris = ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis)
        expected_iris = original_iris[slices]
        sliced_iris = ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis, slices=slices)
        # BUG FIX: the original asserted `not ht.equal(...)`, which made the
        # test fail precisely when partial loading produced the correct data.
        assert ht.equal(sliced_iris, expected_iris)

0 comments on commit 2288e05

Please sign in to comment.