diff --git a/docs/cheatsheet.md b/docs/cheatsheet.md index dda6beae..2102b172 100644 --- a/docs/cheatsheet.md +++ b/docs/cheatsheet.md @@ -143,7 +143,7 @@ ds.display() ### Entities -Entities can be thought as instances of protos or C++ structs. That is, they +Entities can be thought of as instances of protos or C++ structs. That is, they don't directly store their own schema. Instead, their schema is stored at DataSlice level and all entities in a DataSlice share the same schema. @@ -157,6 +157,8 @@ es = kd.new(x=kd.slice([1, 2, None]), e.get_schema() assert e.get_schema() == es.get_schema() +assert e.is_entity() + # Use an existing schema s = kd.named_schema('Point', x=kd.INT32, y=kd.INT32) e = kd.new(x=1, y=2, schema=s) @@ -222,82 +224,6 @@ nested = nested.updated(kd.attrs(nested.a.c, e=4),
-### Objects - -Objects can be thought as Python objects. They directly store their own schema -as **schema** attribute similar to how Python objects store **class** attribute. -This allows objects in a DataSlice to have different schemas. - -```py -o = kd.obj(x=1, y=2) -os = kd.obj(x=kd.slice([1, 2, None]), - y=kd.slice([4, None, 6])) - -os = kd.slice([kd.obj(x=1), - kd.obj(y=2.0), - kd.obj(x=1.0, y='a')]) - -os.get_schema() # kd.OBJECT -os.get_obj_schema() -# [IMPLICIT_SCHEMA(x=INT32), -# IMPLICIT_SCHEMA(y=FLOAT32), -# IMPLICIT_SCHEMA(x=INT32, y=STRING)] - -# Use provided itemids -itemid = kd.new_itemid() -o1 = kd.obj(x=1, y=2, itemid=itemid) -o2 = kd.obj(x=1, y=2, itemid=itemid) -assert o1.get_itemid() == o2.get_itemid() - -# Get available attributes -os1 = kd.slice([kd.obj(x=1), kd.obj(x=1.0, y='a')]) -# Attributes present in all objects -kd.dir(os1) # ['x'] -# Or -os1.get_attr_names(intersection=True) # ['x'] -os1.get_attr_names(intersection=False) # ['x', 'y'] - -# Access attribute -o.x # 1 -o.get_attr('y') # 2 -o.maybe('z') # None -o.get_attr('z', default=0) # 0 -os.get_attr('x', default=0) # [1, 0, 'a'] - -# Objects are immutable by default, modification is done -# by creating a new object with updated attributes -o = kd.obj(x=1, y=2) - -# Update a single attribute -o1 = o.with_attr('x', 3) -o1 = o.with_attr('z', 4) -# Also override schema -# no update_schema=True is needed -o1 = o.with_attr('y', 'a') - -# Update multiple attributes -o2 = o.with_attrs(z=4, x=3) -# Also override schema for 'y' -o2 = o.with_attrs(z=4, y='a') - -# Create an update and apply it separately -upd = kd.attrs(o, z=4, y=10) -o3 = o.updated(upd) - -# Allows mixing multiple updates -o4 = o.updated(kd.attrs(o, z=4), kd.attrs(o, y=10)) - -# Update nested attributes -nested = kd.obj(a=kd.obj(c=kd.obj(e=1), d=2), b=3) -nested = nested.updated(kd.attrs(nested.a.c, e=4), - kd.attrs(nested.a, d=5), - kd.attrs(nested, b=6)) -``` - -
- -
- ### Lists ```py @@ -408,6 +334,111 @@ d7 = d1.updated(d1.dict_update('c', 5),
+### Objects + +Objects can be thought of as Python objects. They directly store their own schema +as **schema** attribute similar to how Python objects store **class** attribute. +This allows objects in a DataSlice to have different schemas. Entities, Lists, +Dicts and primitives can be objects. Entities, Lists and Dicts store their own +schema as an internal `__schema__` attribute while primitives' schema is +determined by the type of their value. + +```py +# Entity objects +o = kd.obj(x=1, y=2) +os = kd.obj(x=kd.slice([1, 2, None]), + y=kd.slice([4, None, 6])) + +os = kd.slice([kd.obj(x=1), + kd.obj(y=2.0), + kd.obj(x=1.0, y='a')]) + +os.get_schema() # kd.OBJECT +os.get_obj_schema() +# [IMPLICIT_SCHEMA(x=INT32), +# IMPLICIT_SCHEMA(y=FLOAT32), +# IMPLICIT_SCHEMA(x=INT32, y=STRING)] + +# Use provided itemids +itemid = kd.new_itemid() +o1 = kd.obj(x=1, y=2, itemid=itemid) +o2 = kd.obj(x=1, y=2, itemid=itemid) +assert o1.get_itemid() == o2.get_itemid() + +# Get available attributes +os1 = kd.slice([kd.obj(x=1), kd.obj(x=1.0, y='a')]) +# Attributes present in all objects +kd.dir(os1) # ['x'] +# Or +os1.get_attr_names(intersection=True) # ['x'] +os1.get_attr_names(intersection=False) # ['x', 'y'] + +# Access attribute +o.x # 1 +o.get_attr('y') # 2 +o.maybe('z') # None +o.get_attr('z', default=0) # 0 +os.get_attr('x', default=0) # [1, 0, 'a'] + +# Objects are immutable by default, modification is done +# by creating a new object with updated attributes +o = kd.obj(x=1, y=2) + +# Update a single attribute +o1 = o.with_attr('x', 3) +o1 = o.with_attr('z', 4) +# Also override schema +# no update_schema=True is needed +o1 = o.with_attr('y', 'a') + +# Update multiple attributes +o2 = o.with_attrs(z=4, x=3) +# Also override schema for 'y' +o2 = o.with_attrs(z=4, y='a') + +# Create an update and apply it separately +upd = kd.attrs(o, z=4, y=10) +o3 = o.updated(upd) + +# Allows mixing multiple updates +o4 = o.updated(kd.attrs(o, z=4), kd.attrs(o, y=10)) + +# Update nested attributes +nested = kd.obj(a=kd.obj(c=kd.obj(e=1), d=2), b=3) +nested = nested.updated(kd.attrs(nested.a.c, e=4), + kd.attrs(nested.a, d=5), + kd.attrs(nested, b=6)) + +# List and dict can be objects too +# To convert a list/dict to an object, +# use kd.obj() +l = kd.list([1, 2, 3]) +l_obj = kd.obj(l) +l_obj[:] # [1, 2, 3] + +d = kd.dict({'a': 1, 'b': 2}) +d_obj = kd.obj(d) +d_obj.get_keys() # ['a', 'b'] +d_obj['a'] # 1 + +# Convert an entity to an object +e = kd.new(x=1, y=2) +e_obj = kd.obj(e) + +# Actually, we can pass primitive to kd.obj() +p_obj = kd.obj(1) +p_obj = kd.obj('a') + +# An OBJECT Dataslice with entity, list, +# dict and primitive items +kd.slice([kd.obj(a=1), 1, kd.obj(kd.list([1, 2])), + kd.obj(kd.dict({'a': 1}))]) +``` + +
+ +
+ ### Subslicing DataSlices Subslicing is an operation of getting part of the items in a DataSlice. @@ -821,6 +852,8 @@ kd.has_not(a) # [missing, present, missing] b = kd.slice([kd.obj(), kd.obj(kd.list()), kd.obj(kd.dict()), None, 1]) +kd.has_entity(b) +# -> [present, missing, missing, missing, missing] kd.has_list(b) # -> [missing, present, missing, missing, missing] kd.has_dict(b) @@ -1576,6 +1609,10 @@ Line = kd.named_schema('Line', start=Point, end=kd.ANY) # Get the attribute start's schema Line.start +# Check if it is an Entity schema +assert Point.is_entity_schema() +assert Line.is_entity_schema() + # List schema ls1 = kd.list_schema(kd.INT64) @@ -1611,6 +1648,9 @@ uus1 = kd.uu_schema(x=kd.INT32, y=kd.FLOAT64) uus2 = kd.uu_schema(x=kd.INT32, y=kd.FLOAT64) assert uus1 == uus2 +# It is also an Entity schema +assert uus1.is_entity_schema() + # In fact, named, list and dict schemas are also # UU schemas Point1 = kd.named_schema('Point', x=kd.INT32, y=kd.FLOAT64) diff --git a/koladata/operators/BUILD b/koladata/operators/BUILD index 56e92b17..93cbf867 100644 --- a/koladata/operators/BUILD +++ b/koladata/operators/BUILD @@ -161,6 +161,7 @@ cc_library( "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_arolla//arolla/dense_array", + "@com_google_arolla//arolla/dense_array/ops", "@com_google_arolla//arolla/dense_array/qtype", "@com_google_arolla//arolla/expr", "@com_google_arolla//arolla/jagged_shape/dense_array/qtype", diff --git a/koladata/operators/operators.cc b/koladata/operators/operators.cc index 75a68193..91c9c75a 100644 --- a/koladata/operators/operators.cc +++ b/koladata/operators/operators.cc @@ -97,7 +97,9 @@ OPERATOR_FAMILY("kde.core.enriched", OPERATOR("kde.core.follow", Follow); OPERATOR("kde.core.freeze_bag", Freeze); OPERATOR("kde.core.get_bag", GetBag); +OPERATOR("kde.core.has_entity", HasEntity); OPERATOR("kde.core.has_primitive", HasPrimitive); +OPERATOR("kde.core.is_entity", IsEntity); OPERATOR("kde.core.is_primitive", IsPrimitive); OPERATOR("kde.core.no_bag", NoBag); OPERATOR("kde.core.nofollow", NoFollow); diff --git a/koladata/operators/predicates.cc b/koladata/operators/predicates.cc index 2c8c42f9..bff299a1 100644 --- a/koladata/operators/predicates.cc +++ b/koladata/operators/predicates.cc @@ -27,6 +27,8 @@ #include "koladata/operators/masking.h" #include "koladata/operators/utils.h" #include "arolla/dense_array/dense_array.h" +#include "arolla/dense_array/ops/dense_ops.h" +#include "arolla/memory/optional_value.h" #include "arolla/util/unit.h" #include "arolla/util/view_types.h" #include "arolla/util/status_macros_backport.h" @@ -60,6 +62,30 @@ absl::StatusOr HasPrimitiveImpl( return std::move(builder).Build(); } +absl::StatusOr HasEntityImpl( + const internal::DataItem& item) { + if (item.is_entity()) { + return internal::DataItem(arolla::Unit()); + } else { + return internal::DataItem(); + } +} + +absl::StatusOr HasEntityImpl( + const internal::DataSliceImpl& slice) { + auto result = arolla::CreateEmptyDenseArray(slice.size()); + slice.VisitValues([&](const arolla::DenseArray& values) { + if constexpr (std::is_same_v) { + result = arolla::CreateDenseOp( + [](arolla::view_type_t value) + -> arolla::OptionalValue { + return arolla::OptionalUnit(value.IsEntity()); + })(values); + } + }); + return internal::DataSliceImpl::Create(std::move(result)); +} + absl::StatusOr HasListImpl(const internal::DataItem& item) { if (item.is_list()) { return internal::DataItem(arolla::Unit()); @@ -70,19 +96,17 @@ absl::StatusOr HasListImpl(const internal::DataItem& item) { absl::StatusOr HasListImpl( const internal::DataSliceImpl& slice) { - internal::SliceBuilder builder(slice.size()); - auto typed_builder = builder.typed(); + auto result = arolla::CreateEmptyDenseArray(slice.size()); slice.VisitValues([&](const arolla::DenseArray& values) { if constexpr (std::is_same_v) { - values.ForEachPresent( - [&](int64_t id, arolla::view_type_t value) { - if (value.IsList()) { - typed_builder.InsertIfNotSet(id, arolla::Unit()); - } - }); + result = arolla::CreateDenseOp( + [](arolla::view_type_t value) + -> arolla::OptionalValue { + return arolla::OptionalUnit(value.IsList()); + })(values); } }); - return std::move(builder).Build(); + return internal::DataSliceImpl::Create(std::move(result)); } absl::StatusOr HasDictImpl(const internal::DataItem& item) { @@ -95,19 +119,17 @@ absl::StatusOr HasDictImpl(const internal::DataItem& item) { absl::StatusOr HasDictImpl( const internal::DataSliceImpl& slice) { - internal::SliceBuilder builder(slice.size()); - auto typed_builder = builder.typed(); + auto result = arolla::CreateEmptyDenseArray(slice.size()); slice.VisitValues([&](const arolla::DenseArray& values) { if constexpr (std::is_same_v) { - values.ForEachPresent( - [&](int64_t id, arolla::view_type_t value) { - if (value.IsDict()) { - typed_builder.InsertIfNotSet(id, arolla::Unit()); - } - }); + result = arolla::CreateDenseOp( + [](arolla::view_type_t value) + -> arolla::OptionalValue { + return arolla::OptionalUnit(value.IsDict()); + })(values); } }); - return std::move(builder).Build(); + return internal::DataSliceImpl::Create(std::move(result)); } } // namespace @@ -134,6 +156,25 @@ absl::StatusOr HasPrimitive(const DataSlice& x) { x.GetShape(), internal::DataItem(schema::kMask), nullptr); } +absl::StatusOr HasEntity(const DataSlice& x) { + auto schema = x.GetSchemaImpl(); + // Trust the schema if it is a Entity schema. + if (x.GetSchema().IsEntitySchema()) { + return Has(x); + } + // Derive from the data for OBJECT and ANY schemas. + if (schema.is_any_schema() || schema.is_object_schema()) { + return x.VisitImpl([&](const auto& impl) -> absl::StatusOr { + ASSIGN_OR_RETURN(auto res, HasEntityImpl(impl)); + return DataSlice::Create(std::move(res), x.GetShape(), + internal::DataItem(schema::kMask), nullptr); + }); + } + return DataSlice::Create( + internal::DataSliceImpl::CreateEmptyAndUnknownType(x.size()), + x.GetShape(), internal::DataItem(schema::kMask), nullptr); +} + absl::StatusOr HasList(const DataSlice& x) { auto schema = x.GetSchemaImpl(); // Trust the schema if it is a List schema. @@ -202,6 +243,10 @@ absl::StatusOr IsPrimitive(const DataSlice& x) { return AsMask(contains_only_primitives); } +absl::StatusOr IsEntity(const DataSlice& x) { + return AsMask(x.IsEntity()); +} + absl::StatusOr IsList(const DataSlice& x) { return AsMask(x.IsList()); } diff --git a/koladata/operators/predicates.h b/koladata/operators/predicates.h index 0c4baf87..0d7f4a66 100644 --- a/koladata/operators/predicates.h +++ b/koladata/operators/predicates.h @@ -27,18 +27,25 @@ absl::StatusOr IsPrimitive(const DataSlice& x); // Returns a MASK DataSlice with present for each item in `x` that is primitive. absl::StatusOr HasPrimitive(const DataSlice& x); +// Returns true if the DataSlice has an Entity schema or only contains entities +// if the schema is OBJECT or ANY. +absl::StatusOr IsEntity(const DataSlice& x); + +// Returns a MASK DataSlice with present for each item in `x` that is an Entity. +absl::StatusOr HasEntity(const DataSlice& x); + // Returns true if the DataSlice has a List schema or only contains lists if the // schema is OBJECT or ANY. absl::StatusOr IsList(const DataSlice& x); -// Returns a MASK DataSlice with present for each item in `x` that is List. +// Returns a MASK DataSlice with present for each item in `x` that is a List. absl::StatusOr HasList(const DataSlice& x); // Returns true if the DataSlice has a Dict schema or only contains dicts if the // schema is OBJECT or ANY. absl::StatusOr IsDict(const DataSlice& x); -// Returns a MASK DataSlice with present for each item in `x` that is Dict. +// Returns a MASK DataSlice with present for each item in `x` that is a Dict. absl::StatusOr HasDict(const DataSlice& x); } // namespace koladata::ops diff --git a/py/koladata/expr/view.py b/py/koladata/expr/view.py index 4682a878..24d4725f 100644 --- a/py/koladata/expr/view.py +++ b/py/koladata/expr/view.py @@ -406,6 +406,9 @@ def maybe(self, attr_name: Any) -> arolla.Expr: def is_empty(self) -> arolla.Expr: return arolla.abc.aux_bind_op('kde.is_empty', self) + def is_entity(self) -> arolla.Expr: + return arolla.abc.aux_bind_op('kde.is_entity', self) + def is_list(self) -> arolla.Expr: return arolla.abc.aux_bind_op('kde.is_list', self) diff --git a/py/koladata/expr/view_test.py b/py/koladata/expr/view_test.py index 5855ab2a..f4a6a76d 100644 --- a/py/koladata/expr/view_test.py +++ b/py/koladata/expr/view_test.py @@ -459,6 +459,15 @@ def test_updated(self): def test_get_present_count(self): testing.assert_equal(C.x.get_present_count(), kde.count(C.x)) + def test_is_entity(self): + testing.assert_equal(C.x.is_entity(), kde.is_entity(C.x)) + + def test_is_list(self): + testing.assert_equal(C.x.is_list(), kde.is_list(C.x)) + + def test_is_dict(self): + testing.assert_equal(C.x.is_dict(), kde.is_dict(C.x)) + def test_is_dict_schema(self): testing.assert_equal(C.x.is_dict_schema(), kde.schema.is_dict_schema(C.x)) diff --git a/py/koladata/operators/core.py b/py/koladata/operators/core.py index df03b0de..1a481b2d 100644 --- a/py/koladata/operators/core.py +++ b/py/koladata/operators/core.py @@ -226,6 +226,68 @@ def is_primitive(x): # pylint: disable=unused-argument raise NotImplementedError('implemented in the backend') +@optools.add_to_registry(aliases=['kde.has_entity']) +@optools.as_backend_operator( + 'kde.core.has_entity', + qtype_constraints=[ + qtype_utils.expect_data_slice(P.x), + ], +) +def has_entity(x): # pylint: disable=unused-argument + """Returns present for each item in `x` that is an Entity. + + Note that this is a pointwise operation. + + Also see `kd.is_entity` for checking if `x` is an Entity DataSlice. But + note that `kd.all(kd.has_entity(x))` is not always equivalent to + `kd.is_entity(x)`. For example, + + kd.is_entity(kd.item(None, kd.OBJECT)) -> kd.present + kd.all(kd.has_entity(kd.item(None, kd.OBJECT))) -> invalid for kd.all + kd.is_entity(kd.item([None], kd.OBJECT)) -> kd.present + kd.all(kd.has_entity(kd.item([None], kd.OBJECT))) -> kd.missing + + Args: + x: DataSlice to check. + + Returns: + A MASK DataSlice with the same shape as `x`. + """ + raise NotImplementedError('implemented in the backend') + + +@optools.add_to_registry(aliases=['kde.is_entity']) +@optools.as_backend_operator( + 'kde.core.is_entity', + qtype_constraints=[ + qtype_utils.expect_data_slice(P.x), + ], +) +def is_entity(x): # pylint: disable=unused-argument + """Returns whether x is an Entity DataSlice. + + `x` is an Entity DataSlice if it meets one of the following conditions: + 1) it has an Entity schema + 2) it has OBJECT/ANY schema and only has Entity items + + Also see `kd.has_entity` for a pointwise version. But note that + `kd.all(kd.has_entity(x))` is not always equivalent to + `kd.is_entity(x)`. For example, + + kd.is_entity(kd.item(None, kd.OBJECT)) -> kd.present + kd.all(kd.has_entity(kd.item(None, kd.OBJECT))) -> invalid for kd.all + kd.is_entity(kd.item([None], kd.OBJECT)) -> kd.present + kd.all(kd.has_entity(kd.item([None], kd.OBJECT))) -> kd.missing + + Args: + x: DataSlice to check. + + Returns: + A MASK DataItem. + """ + raise NotImplementedError('implemented in the backend') + + @optools.add_to_registry(aliases=['kde.stub']) @optools.as_backend_operator( 'kde.core.stub', diff --git a/py/koladata/operators/tests/BUILD b/py/koladata/operators/tests/BUILD index 31297bd2..4fbe66a7 100644 --- a/py/koladata/operators/tests/BUILD +++ b/py/koladata/operators/tests/BUILD @@ -457,6 +457,49 @@ py_test( ], ) +py_test( + name = "core_is_entity_test", + srcs = ["core_is_entity_test.py"], + deps = [ + "//py/koladata/expr:expr_eval", + "//py/koladata/expr:input_container", + "//py/koladata/expr:view", + "//py/koladata/operators:kde_operators", + "//py/koladata/operators:optools", + "//py/koladata/operators/tests/util:qtypes", + "//py/koladata/types:data_bag", + "//py/koladata/types:data_slice", + "//py/koladata/types:dict_item", + "//py/koladata/types:mask_constants", + "//py/koladata/types:qtypes", + "//py/koladata/types:schema_constants", + "@com_google_absl_py//absl/testing:absltest", + "@com_google_absl_py//absl/testing:parameterized", + "@com_google_arolla//py/arolla", + ], +) + +py_test( + name = "core_has_entity_test", + srcs = ["core_has_entity_test.py"], + deps = [ + "//py/koladata/expr:expr_eval", + "//py/koladata/expr:input_container", + "//py/koladata/expr:view", + "//py/koladata/operators:kde_operators", + "//py/koladata/operators:optools", + "//py/koladata/testing", + "//py/koladata/types:data_bag", + "//py/koladata/types:data_slice", + "//py/koladata/types:mask_constants", + "//py/koladata/types:qtypes", + "//py/koladata/types:schema_constants", + "@com_google_absl_py//absl/testing:absltest", + "@com_google_absl_py//absl/testing:parameterized", + "@com_google_arolla//py/arolla", + ], +) + py_test( name = "lists_is_list_test", srcs = ["lists_is_list_test.py"], diff --git a/py/koladata/operators/tests/core_has_entity_test.py b/py/koladata/operators/tests/core_has_entity_test.py new file mode 100644 index 00000000..40aca245 --- /dev/null +++ b/py/koladata/operators/tests/core_has_entity_test.py @@ -0,0 +1,86 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from absl.testing import parameterized +from arolla import arolla +from koladata.expr import expr_eval +from koladata.expr import input_container +from koladata.expr import view +from koladata.operators import kde_operators +from koladata.operators import optools +from koladata.testing import testing +from koladata.types import data_bag +from koladata.types import data_slice +from koladata.types import mask_constants +from koladata.types import qtypes +from koladata.types import schema_constants + +I = input_container.InputContainer('I') +M = arolla.M +bag = data_bag.DataBag.empty +ds = data_slice.DataSlice.from_vals +DATA_SLICE = qtypes.DATA_SLICE +kde = kde_operators.kde + +present = mask_constants.present +missing = mask_constants.missing + + +class KodaHasEntityTest(parameterized.TestCase): + + @parameterized.parameters( + # DataItem + (ds(None), missing), + (bag().new() & None, missing), + (bag().new(), present), + (bag().new(a=1), present), + (bag().obj(a=1), present), + (bag().new(a=1).as_any(), present), + (ds('hello'), missing), + (bag().dict(), missing), + (bag().dict().embed_schema(), missing), + (bag().list(), missing), + (bag().new_schema(), missing), + # DataSlice + ( + ds([ + bag().new(a=1, schema='test'), + None, + bag().new(a=2, schema='test'), + ]), + ds([present, missing, present]), + ), + (ds([None, None]), ds([missing, missing])), + (ds([None, None], schema_constants.INT32), ds([missing, missing])), + (ds([None, None], schema_constants.OBJECT), ds([missing, missing])), + (ds([None, None], schema_constants.ANY), ds([missing, missing])), + # Mixed types. + ( + ds([bag().obj(a=1), None, 'world', bag().dict().embed_schema()]), + ds([present, missing, missing, missing]), + ), + ) + def test_eval(self, x, expected): + testing.assert_equal(expr_eval.eval(kde.core.has_entity(x)), expected) + + def test_view(self): + self.assertTrue(view.has_koda_view(kde.core.has_entity(I.x))) + + def test_alias(self): + self.assertTrue(optools.equiv_to_op(kde.core.has_entity, kde.has_entity)) + + +if __name__ == '__main__': + absltest.main() diff --git a/py/koladata/operators/tests/core_is_entity_test.py b/py/koladata/operators/tests/core_is_entity_test.py new file mode 100644 index 00000000..3b7b44ab --- /dev/null +++ b/py/koladata/operators/tests/core_is_entity_test.py @@ -0,0 +1,118 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from absl.testing import parameterized +from arolla import arolla +from koladata.expr import expr_eval +from koladata.expr import input_container +from koladata.expr import view +from koladata.operators import kde_operators +from koladata.operators import optools +from koladata.operators.tests.util import qtypes as test_qtypes +from koladata.types import data_bag +from koladata.types import data_slice +from koladata.types import dict_item as _ # pylint: disable=unused-import +from koladata.types import mask_constants +from koladata.types import qtypes +from koladata.types import schema_constants + +I = input_container.InputContainer('I') +kde = kde_operators.kde +ds = data_slice.DataSlice.from_vals +bag = data_bag.DataBag.empty +DATA_SLICE = qtypes.DATA_SLICE + +present = mask_constants.present +missing = mask_constants.missing + + +QTYPES = frozenset([ + (DATA_SLICE, DATA_SLICE), +]) + + +class DictsIsEntityTest(parameterized.TestCase): + + @parameterized.parameters( + # Entity + (bag().new(),), + (bag().new(a=1),), + ( + ds([ + bag().new(a=1, schema='test'), + None, + bag().new(a=2, schema='test'), + ]), + ), + # OBJECT + (ds([bag().obj(a=1), None, bag().obj(a=2)]),), + # ANY + ( + ds([ + bag().new(a=1, schema='test'), + None, + bag().new(a=2, schema='test'), + ]).as_any(), + ), + # Missing + (bag().new() & None,), + (ds(None, schema_constants.OBJECT),), + (ds(None, schema_constants.ANY),), + (bag().obj(a=1) & None,), + ) + def test_is_entity(self, x): + self.assertTrue(expr_eval.eval(kde.core.is_entity(x))) + + @parameterized.parameters( + # Primitive + (ds(1),), + (ds([1, 2]),), + # List/Object/Dict + (bag().list([1, 2]).embed_schema(),), + (bag().list([1, 2]),), + (bag().dict({1: 2}),), + # ItemId + (bag().new().get_itemid(),), + # Mixed + (ds([bag().list([1, 2]).embed_schema(), None, 1]),), + # Missing + (ds(None),), + (ds(None, schema_constants.INT32),), + (ds([None, None]),), + (ds([None, None], schema_constants.INT32),), + (bag().dict({1: 2}) & None,), + (bag().list([1, 2]) & None,), + ) + def test_is_not_entity(self, x): + self.assertFalse(expr_eval.eval(kde.core.is_entity(x))) + + def test_qtype_signatures(self): + self.assertCountEqual( + arolla.testing.detect_qtype_signatures( + kde.core.is_entity, + possible_qtypes=test_qtypes.DETECT_SIGNATURES_QTYPES, + ), + QTYPES, + ) + + def test_view(self): + self.assertTrue(view.has_koda_view(kde.core.is_entity(I.x))) + + def test_alias(self): + self.assertTrue(optools.equiv_to_op(kde.core.is_entity, kde.is_entity)) + + +if __name__ == '__main__': + absltest.main() diff --git a/py/koladata/types/data_slice.cc b/py/koladata/types/data_slice.cc index b10c9257..3411ab41 100644 --- a/py/koladata/types/data_slice.cc +++ b/py/koladata/types/data_slice.cc @@ -808,6 +808,12 @@ absl::Nullable PyDataSlice_is_list(PyObject* self, PyObject*) { return WrapPyDataSlice(AsMask(ds.IsList())); } +absl::Nullable PyDataSlice_is_entity(PyObject* self, PyObject*) { + arolla::python::DCheckPyGIL(); + const auto& ds = UnsafeDataSliceRef(self); + return WrapPyDataSlice(AsMask(ds.IsEntity())); +} + absl::Nullable PyDataSlice_is_primitive_schema(PyObject* self, PyObject*) { arolla::python::DCheckPyGIL(); @@ -1085,11 +1091,18 @@ Note that the Entity schema includes Entity, List and Dict schemas. {"is_dict", PyDataSlice_is_dict, METH_NOARGS, "is_dict()\n" "--\n\n" - "Returns present iff this DataSlice contains only dicts."}, + "Returns present iff this DataSlice has Dict schema or contains only " + "dicts."}, {"is_list", PyDataSlice_is_list, METH_NOARGS, "is_list()\n" "--\n\n" - "Returns present iff this DataSlice contains only lists."}, + "Returns present iff this DataSlice has List schema or contains only " + "lists."}, + {"is_entity", PyDataSlice_is_entity, METH_NOARGS, + "is_entity()\n" + "--\n\n" + "Returns present iff this DataSlice has Entity schema or contains only " + "entities."}, {"is_dict_schema", PyDataSlice_is_dict_schema, METH_NOARGS, "is_dict_schema()\n" "--\n\n" diff --git a/py/koladata/types/data_slice_test.py b/py/koladata/types/data_slice_test.py index 9a89b8a2..b291008a 100644 --- a/py/koladata/types/data_slice_test.py +++ b/py/koladata/types/data_slice_test.py @@ -2546,6 +2546,26 @@ def test_is_dict(self): self.assertFalse(x.as_any().is_dict()) self.assertFalse(db.obj(x).is_dict()) + def test_is_entity(self): + db = bag() + x = db.new(a=ds([1, 2])) + self.assertTrue(x.is_entity()) + self.assertTrue(x.as_any().is_entity()) + self.assertTrue(db.obj(x).is_entity()) + self.assertFalse(ds([db.obj(a=1), db.obj(db.dict())]).is_entity()) + x = ds([db.dict({1: 2}), db.dict({3: 4})]) + self.assertFalse(x.is_entity()) + self.assertFalse(x.as_any().is_entity()) + self.assertFalse(db.obj(x).is_entity()) + x = ds([1.0, 2.0]) + self.assertFalse(x.is_entity()) + self.assertFalse(x.as_any().is_entity()) + self.assertFalse(db.obj(x).is_entity()) + x = ds([db.obj(a=1), 1.0]) + self.assertFalse(x.is_entity()) + self.assertFalse(x.as_any().is_entity()) + self.assertFalse(db.obj(x).is_entity()) + def test_empty_subscript_method_slice(self): db = bag() testing.assert_equal(