Skip to content

Commit

Permalink
Support == for dims and attrs; copy arrays in attribute values.
Browse files Browse the repository at this point in the history
  • Loading branch information
pp-mo committed Jan 15, 2025
1 parent 7954c30 commit 9dcb65e
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 32 deletions.
26 changes: 25 additions & 1 deletion lib/ncdata/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,14 @@ def copy(self):
"""Copy self."""
return NcDimension(self.name, size=self.size, unlimited=self.unlimited)

def __eq__(self, other):
"""Support simply equality testing."""
return (
self.name == other.name
and self.size == other.size
and self.unlimited == other.unlimited
)


class NcVariable(_AttributeAccessMixin):
"""
Expand Down Expand Up @@ -581,4 +589,20 @@ def copy(self):
Does not duplicate array content.
See :func:`ncdata.utils.ncdata_copy`.
"""
return NcAttribute(self.name, self.value)
return NcAttribute(self.name, self.value.copy())

def __eq__(self, other):
"""Support simple equality testing."""
if not isinstance(other, NcAttribute):
result = NotImplemented
else:
result = self.name == other.name
if result:
v1 = self.value
v2 = other.value
result = (
v1.shape == v2.shape
and v1.dtype == v2.dtype
and np.all(v1 == v2)
)
return result
2 changes: 1 addition & 1 deletion lib/ncdata/utils/_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def ncdata_copy(ncdata: NcData) -> NcData:
Return a copy of the data.
The operation makes fresh copies of all ncdata objects, but does not copy arrays in
either variable data or attribute values.
variable data.
Parameters
----------
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/core/test_AttributeAccessMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Test_AttributeAccesses:
def test_gettattr(self, sample_object):
content = np.array([1, 2])
sample_object.attributes.add(NcAttribute("x", content))
assert sample_object.get_attrval("x") is content
assert np.all(sample_object.get_attrval("x") == content)

def test_getattr_absent(self, sample_object):
# Check that fetching a non-existent attribute returns None.
Expand All @@ -30,15 +30,15 @@ def test_getattr_absent(self, sample_object):
def test_setattr(self, sample_object):
content = np.array([1, 2])
sample_object.set_attrval("x", content)
assert sample_object.attributes["x"].value is content
assert np.all(sample_object.attributes["x"].value == content)

def test_setattr__overwrite(self, sample_object):
content = np.array([1, 2])
sample_object.set_attrval("x", content)
assert sample_object.attributes["x"].value is content
assert np.all(sample_object.attributes["x"].value == content)
sample_object.set_attrval("x", "replaced")
assert list(sample_object.attributes.keys()) == ["x"]
assert sample_object.attributes["x"].value == "replaced"
assert np.all(sample_object.attributes["x"].value == "replaced")

def test_setattr_getattr_none(self, sample_object):
# Check behaviour when an attribute is given a Python value of 'None'.
Expand Down
77 changes: 59 additions & 18 deletions tests/unit/core/test_NcAttribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,29 +150,70 @@ def test_repr_same(self, datatype, structuretype):


class Test_NcAttribute_copy:
@staticmethod
def eq(attr1, attr2):
# Capture the expected equality of an original
# attribute and its copy.
# In the case of its value, if it is a numpy array,
# then it should be the **same identical object**
# -- i.e. not a copy (not even a view).
result = attr1 is not attr2
if result:
result = attr1.name == attr1.name and np.all(
attr1.value == attr2.value
)
if result and hasattr(attr1.value, "dtype"):
result = attr1.value is attr2.value
return result

def test_empty(self):
attr = NcAttribute("x", None)
result = attr.copy()
assert self.eq(result, attr)
assert result == attr

def test_value(self, datatype, structuretype):
value = attrvalue(datatype, structuretype)
attr = NcAttribute("x", value=value)
result = attr.copy()
assert self.eq(result, attr)
assert result == attr
assert result.name == attr.name
assert result.value is not attr.value
assert (
result.value.dtype == attr.value.dtype
and result.value.shape == attr.value.shape
and np.all(result.value == attr.value)
)


class Test_NcAttribute__eq__:
def test_eq(self, datatype, structuretype):
value = attrvalue(datatype, structuretype)
attr1 = NcAttribute("x", value=value)
attr2 = NcAttribute("x", value=value)
assert attr1 == attr2

def test_neq_name(self):
attr1 = NcAttribute("x", value=1)
attr2 = NcAttribute("y", value=1)
assert attr1 != attr2

def test_neq_dtype(self):
attr1 = NcAttribute("x", value=1)
attr2 = NcAttribute("x", value=np.array(1, dtype=np.int32))
assert attr1 != attr2

def test_neq_shape(self):
attr1 = NcAttribute("x", value=1)
attr2 = NcAttribute("x", value=[1, 2])
assert attr1 != attr2

def test_neq_value_numeric(self):
attr1 = NcAttribute("x", value=1.0)
attr2 = NcAttribute("x", value=1.1)
assert attr1 != attr2

def test_neq_value_string(self):
attr1 = NcAttribute("x", value="ping")
attr2 = NcAttribute("x", value="pong")
assert attr1 != attr2

def test_eq_onechar_arrayofonestring(self):
# NOTE: vector of char is really no different to vector of string,
# but we will get an 'U1' (single char length) dtype
attr1 = NcAttribute("x", value="t")
attr2 = NcAttribute("x", value=np.array("t"))
assert attr1 == attr2
assert attr1.value.dtype == "<U1"

def test_eq_onestring_arrayofonestring(self):
# NOTE: but ... vectors of string don't actually work in netCDF files at present
attr1 = NcAttribute("x", value="this")
attr2 = NcAttribute("x", value=np.array("this"))
assert attr1 == attr2
assert attr1.value.dtype == "<U4"

# NOTE: **not** testing a vector of multiple strings, since this has no function at present
34 changes: 34 additions & 0 deletions tests/unit/core/test_NcDimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,37 @@ def test(self, size, unlim):
assert result.name == sample.name
assert result.size == sample.size
assert result.unlimited == sample.unlimited


class Test_NcDimension_eq:
@pytest.fixture(params=["isunlimited", "notunlimited"])
def unlimited(self, request):
return request.param == "isunlimited"

@pytest.fixture(params=[0, 3])
def size(self, request):
return request.param

@pytest.fixture()
def refdim(self, unlimited, size):
return NcDimension(name="ref_name", size=size, unlimited=unlimited)

def test_eq(self, refdim, size, unlimited):
thisdim = NcDimension("ref_name", size=size, unlimited=unlimited)
assert thisdim == refdim

def test_noneq_name(self, refdim, size, unlimited):
thisdim = NcDimension("other_name", size=size, unlimited=unlimited)
assert thisdim != refdim

def test_noneq_size(self, refdim, size, unlimited):
if unlimited:
pytest.skip("unsupported case")
thisdim = NcDimension("ref_name", size=7)
assert thisdim != refdim

def test_noneq_unlim(self, refdim, size, unlimited):
if size == 0:
pytest.skip("unsupported case")
thisdim = NcDimension("ref_name", size=size, unlimited=not unlimited)
assert thisdim != refdim
32 changes: 24 additions & 8 deletions tests/unit/utils/test_ncdata_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,35 @@ def test_general(self, sample):
result = ncdata_copy(sample)
assert not differences_or_duplicated_objects(sample, result)

def test_sample_data(self, sample):
# Check that data arrays are *not* copied, in both variables and attributes
def test_sample_variable_data(self, sample):
# Check that data arrays are *not* copied
result = ncdata_copy(sample)

data_arr = sample.variables["a"].data
assert result.variables["a"].data is data_arr
assert result.groups["g1"].variables["a"].data is data_arr
assert result.groups["g2"].variables["a"].data is data_arr

def test_sample_attribute_arraydata(self, sample):
# Check that attributes arrays *are* copied
arr1 = np.array([9.1, 7, 4])
sample.set_attrval("extra", arr1)
assert sample.attributes["extra"].value is arr1
sva = sample.variables["a"]
sva.set_attrval("xx2", arr1)

result = ncdata_copy(sample)
rva = result.variables["a"]

assert (
result.attributes["extra"].value
is sample.attributes["extra"].value
is not sample.attributes["extra"].value
) and np.all(
result.attributes["extra"].value
== sample.attributes["extra"].value
)

assert (
rva.attributes["xx2"].value is not sva.attributes["xx2"].value
) and np.all(
rva.attributes["xx2"].value == sva.attributes["xx2"].value
)
data_arr = sample.variables["a"].data
assert result.variables["a"].data is data_arr
assert result.groups["g1"].variables["a"].data is data_arr
assert result.groups["g2"].variables["a"].data is data_arr

0 comments on commit 9dcb65e

Please sign in to comment.