Skip to content

Commit

Permalink
Fix schema checks with default dim filters
Browse files Browse the repository at this point in the history
  • Loading branch information
jp-dark committed Dec 22, 2023
1 parent 40711aa commit 0d21a14
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 61 deletions.
151 changes: 105 additions & 46 deletions tests/core/test_dataspace_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,47 +8,6 @@


class TestDataspaceCreatorExample1:
expected_schemas = {
"A1": tiledb.ArraySchema(
domain=tiledb.Domain(
tiledb.Dim(
name="pressure.index", domain=(0, 3), tile=2, dtype=np.uint64
)
),
attrs=[
tiledb.Attr(name="pressure.data", dtype=np.float64),
tiledb.Attr(name="b", dtype=np.float64),
tiledb.Attr(name="c", dtype=np.uint64),
],
),
"A2": tiledb.ArraySchema(
domain=tiledb.Domain(
tiledb.Dim(
name="pressure.index", domain=(0, 3), tile=2, dtype=np.uint64
),
tiledb.Dim(
name="temperature",
domain=(1, 8),
tile=4,
dtype=np.uint64,
filters=tiledb.FilterList(
[
tiledb.ZstdFilter(level=1),
]
),
),
),
sparse=True,
attrs=[tiledb.Attr(name="d", dtype=np.uint64)],
),
"A3": tiledb.ArraySchema(
domain=tiledb.Domain(
tiledb.Dim(name="temperature", domain=(1, 8), tile=8, dtype=np.uint64),
),
attrs=[tiledb.Attr(name="e", dtype=np.float64)],
),
}

@pytest.fixture
def dataspace_creator(self):
creator = DataspaceCreator()
Expand Down Expand Up @@ -123,21 +82,114 @@ def test_get_shared_dims(self, dataspace_creator):
assert shared_dim.name == "temperature"

def test_to_schema(self, dataspace_creator):
expected_schemas = {
"A1": tiledb.ArraySchema(
domain=tiledb.Domain(
tiledb.Dim(
name="pressure.index", domain=(0, 3), tile=2, dtype=np.uint64
)
),
attrs=[
tiledb.Attr(name="pressure.data", dtype=np.float64),
tiledb.Attr(name="b", dtype=np.float64),
tiledb.Attr(name="c", dtype=np.uint64),
],
),
"A2": tiledb.ArraySchema(
domain=tiledb.Domain(
tiledb.Dim(
name="pressure.index", domain=(0, 3), tile=2, dtype=np.uint64
),
tiledb.Dim(
name="temperature",
domain=(1, 8),
tile=4,
dtype=np.uint64,
filters=tiledb.FilterList([tiledb.ZstdFilter(level=1)]),
),
),
sparse=True,
attrs=[tiledb.Attr(name="d", dtype=np.uint64)],
),
"A3": tiledb.ArraySchema(
domain=tiledb.Domain(
tiledb.Dim(
name="temperature",
domain=(1, 8),
tile=8,
dtype=np.uint64,
),
),
attrs=[tiledb.Attr(name="e", dtype=np.float64)],
),
}

group_schema = dataspace_creator.to_schema()
assert isinstance(group_schema, dict)
assert len(group_schema) == 3
assert group_schema["A1"] == self.expected_schemas["A1"]
assert group_schema["A2"] == self.expected_schemas["A2"]
assert group_schema["A3"] == self.expected_schemas["A3"]
assert group_schema["A1"] == expected_schemas["A1"]
assert group_schema["A2"] == expected_schemas["A2"]
assert group_schema["A3"] == expected_schemas["A3"]

def test_create_group(self, dataspace_creator, tmpdir_factory):
expected_schemas = {
"A1": tiledb.ArraySchema(
domain=tiledb.Domain(
tiledb.Dim(
name="pressure.index",
domain=(0, 3),
tile=2,
dtype=np.uint64,
filters=tiledb.FilterList([tiledb.ZstdFilter()]),
)
),
attrs=[
tiledb.Attr(name="pressure.data", dtype=np.float64),
tiledb.Attr(name="b", dtype=np.float64),
tiledb.Attr(name="c", dtype=np.uint64),
],
),
"A2": tiledb.ArraySchema(
domain=tiledb.Domain(
tiledb.Dim(
name="pressure.index",
domain=(0, 3),
tile=2,
dtype=np.uint64,
filters=tiledb.FilterList([tiledb.ZstdFilter()]),
),
tiledb.Dim(
name="temperature",
domain=(1, 8),
tile=4,
dtype=np.uint64,
filters=tiledb.FilterList([tiledb.ZstdFilter(level=1)]),
),
),
sparse=True,
attrs=[tiledb.Attr(name="d", dtype=np.uint64)],
),
"A3": tiledb.ArraySchema(
domain=tiledb.Domain(
tiledb.Dim(
name="temperature",
domain=(1, 8),
tile=8,
dtype=np.uint64,
filters=tiledb.FilterList([tiledb.ZstdFilter()]),
),
),
attrs=[tiledb.Attr(name="e", dtype=np.float64)],
),
}

uri = str(tmpdir_factory.mktemp("output").join("dataspace_example_1"))
dataspace_creator.create_group(uri, append=False)
with tiledb.Group(uri, mode="r") as group:
assert len(group) == 3
for item in group:
result_schema = tiledb.ArraySchema.load(item.uri)
expected_schema = self.expected_schemas[item.name]
expected_schema = expected_schemas[item.name]
assert result_schema == expected_schema


Expand All @@ -155,7 +207,14 @@ def test_create_array(tmpdir):
# Load and check schema
schema = tiledb.ArraySchema.load(array_uri)
assert schema == tiledb.ArraySchema(
tiledb.Domain(tiledb.Dim("d1", domain=(0, 1_000_000), dtype=np.uint32)),
tiledb.Domain(
tiledb.Dim(
"d1",
domain=(0, 1_000_000),
dtype=np.uint32,
filters=tiledb.FilterList([tiledb.ZstdFilter()]),
)
),
attrs=[tiledb.Attr("a1", dtype=np.float64)],
)

Expand Down
16 changes: 14 additions & 2 deletions tests/core/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,20 @@
import tiledb
from tiledb.cf import create_group, open_group_array

_row = tiledb.Dim(name="rows", domain=(1, 4), tile=4, dtype=np.uint64)
_col = tiledb.Dim(name="cols", domain=(1, 4), tile=4, dtype=np.uint64)
_row = tiledb.Dim(
name="rows",
domain=(1, 4),
tile=4,
dtype=np.uint64,
filters=tiledb.FilterList([tiledb.ZstdFilter()]),
)
_col = tiledb.Dim(
name="cols",
domain=(1, 4),
tile=4,
dtype=np.uint64,
filters=tiledb.FilterList([tiledb.ZstdFilter()]),
)


_attr_a = tiledb.Attr(name="a", dtype=np.uint64)
Expand Down
Loading

0 comments on commit 0d21a14

Please sign in to comment.