From 2288e05054bd5d477e57aeb681bd9f0bd3914393 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Guti=C3=A9rrez=20Hermosillo=20Muriedas=2C=20Juan=20Pedro?=
Date: Tue, 17 Dec 2024 19:26:15 +0100
Subject: [PATCH] feat: hdf5 partially loads datasets based on slice objects

---
 heat/core/io.py            | 22 +++++++++++++++-------
 heat/core/tests/test_io.py | 37 ++++++++++++++++++++++++++++++++++++-
 2 files changed, 51 insertions(+), 8 deletions(-)

diff --git a/heat/core/io.py b/heat/core/io.py
index dbeb84620..c981b9712 100644
--- a/heat/core/io.py
+++ b/heat/core/io.py
@@ -57,8 +57,6 @@ def size_from_slice(size: int, s: slice) -> Tuple[int, int]:
     int
         The start index of the slice object.
     """
-    from hypothesis import note
-
     new_range = range(size)[s]
     return len(new_range), new_range.start if len(new_range) > 0 else 0
 
@@ -591,6 +589,7 @@ def load_hdf5(
     with h5py.File(path, "r") as handle:
         data = handle[dataset]
         gshape = data.shape
+        new_gshape = tuple()
         offsets = [0] * len(gshape)
         if slices is not None:
             if len(slices) != len(gshape):
@@ -598,10 +597,17 @@
                     f"Number of slices ({len(slices)}) does not match the number of dimensions ({len(gshape)})"
                 )
             for i, s in enumerate(slices):
-                if s.step is not None and s.step != 1:
-                    raise ValueError("Slices with step != 1 are not supported")
-                gshape = size_from_slice(gshape[i], s)
-                offsets[i] = s.start if s.start is not None else 0
+                if s:
+                    if s.step is not None and s.step != 1:
+                        raise ValueError("Slices with step != 1 are not supported")
+                    new_axis_size, offset = size_from_slice(gshape[i], s)
+                    new_gshape += (new_axis_size,)
+                    offsets[i] = offset
+                else:
+                    new_gshape += (gshape[i],)
+                    offsets[i] = 0
+
+            gshape = new_gshape
 
         if split is not None:
             gshape = list(gshape)
@@ -612,8 +618,10 @@
         _, _, indices = comm.chunk(gshape, split)
 
         if slices is not None:
+            new_indices = tuple()
             for offset, index in zip(offsets, indices):
-                index.start += offset
+                new_indices += (slice(index.start + offset, index.stop + offset),)
+            indices = new_indices
 
         balanced = True
         if split is None:
diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py
index b6fa907d5..7f993f985 100644
--- a/heat/core/tests/test_io.py
+++ b/heat/core/tests/test_io.py
@@ -11,7 +11,8 @@
 import heat as ht
 from .test_suites.basic_test import TestCase
 
-from hypothesis import given, settings, note
+import pytest
+from hypothesis import given, settings, note, assume
 import hypothesis.strategies as st
 
 
@@ -912,3 +913,37 @@ def test_load_multiple_csv_exception(self):
         ht.MPI_WORLD.Barrier()
         if ht.MPI_WORLD.rank == 0:
             shutil.rmtree(os.path.join(os.getcwd(), "heat/datasets/csv_tests"))
+
+
+@unittest.skipIf(not ht.io.supports_hdf5(), reason="Requires HDF5")
+@pytest.mark.parametrize("axis", [None, 0, 1])
+@pytest.mark.parametrize(
+    "slices",
+    [
+        (slice(0, 50, None), slice(None, None, None)),
+        (slice(0, 50, None), slice(0, 2, None)),
+        (slice(50, 100, None), slice(None, None, None)),
+        (slice(None, None, None), slice(2, 4, None)),
+    ],
+)
+def test_load_partial_hdf5(axis, slices):
+    print("axis: ", axis)
+    HDF5_PATH = os.path.join(os.getcwd(), "heat/datasets/iris.h5")
+    HDF5_DATASET = "data"
+    expect_error = False
+    for s in slices:
+        if s and s.step not in [None, 1]:
+            expect_error = True
+            break
+
+    if expect_error:
+        with pytest.raises(ValueError):
+            sliced_iris = ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis, slices=slices)
+    else:
+        original_iris = ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis)
+        expected_iris = original_iris[slices]
+        sliced_iris = ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis, slices=slices)
+        print("Original shape: " + str(original_iris.shape))
+        print("Sliced shape: " + str(sliced_iris.shape))
+        print("Expected shape: " + str(expected_iris.shape))
+        assert ht.equal(sliced_iris, expected_iris)
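
For reviewers, a minimal usage sketch of the slicing behaviour added above (a sketch only, not part of the patch; it assumes the iris.h5 sample dataset and the "data" dataset name used in the tests, a working directory at the repository root, and that size_from_slice is importable from heat.core.io as shown in the first hunk):

import heat as ht
from heat.core.io import size_from_slice

# size_from_slice maps (axis_length, slice) to the sliced axis length and its
# start offset, e.g. rows 50..99 of a 150-row dataset:
length, offset = size_from_slice(150, slice(50, 100))  # -> (50, 50)

# Load only rows 0..49 and columns 2..3, distributed along axis 0; with the
# offset chunk indices, each rank reads just its part of the selected region.
iris_part = ht.load_hdf5(
    "heat/datasets/iris.h5",
    "data",
    split=0,
    slices=(slice(0, 50), slice(2, 4)),
)
print(iris_part.shape)  # expected global shape: (50, 2)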