Skip to content

Commit

Permalink
ExpectHaveCommonSchema
Browse files Browse the repository at this point in the history
So far I have used a one-line error message, consistent with the other `Expect*` functions. If we want to switch to multiline error messages, we should probably do so consistently across all of these functions.

PiperOrigin-RevId: 713654374
Change-Id: I5bd9e9b8703a4de3bf9e208bc742fcab3e4e2e2a
  • Loading branch information
timofey-stepanov authored and copybara-github committed Jan 9, 2025
1 parent ca68448 commit 88d4540
Show file tree
Hide file tree
Showing 10 changed files with 96 additions and 40 deletions.
9 changes: 2 additions & 7 deletions koladata/operators/comparison.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,8 @@ absl::StatusOr<DataSlice> GreaterEqual(const DataSlice& x, const DataSlice& y) {
// Compares `x` and `y` with internal::EqualOp and returns the result as a
// DataSlice with MASK schema and no DataBag attached. Errors from the schema
// compatibility check are tagged with the operator name `kd.comparison.equal`.
absl::StatusOr<DataSlice> Equal(const DataSlice& x, const DataSlice& y) {
// NOTE: Casting is handled internally by EqualOp. The schema compatibility is
// still verified to ensure that e.g. ITEMID and OBJECT are not compared.
// NOTE(review): this listing is a rendered diff without +/- markers. Judging
// by the "2 additions & 7 deletions" summary for this file, the first
// RETURN_IF_ERROR statement (7 lines) is the removed pre-commit check and the
// second one (2 lines) is its replacement - confirm against the repository.
RETURN_IF_ERROR(
schema::CommonSchema(x.GetSchemaImpl(), y.GetSchemaImpl()).status())
.With([&](const absl::Status& status) {
return AssembleErrorMessage(status,
{.db = DataBag::ImmutableEmptyWithFallbacks(
{x.GetBag(), y.GetBag()})});
});
RETURN_IF_ERROR(ExpectHaveCommonSchema({"x", "y"}, x, y))
.With(OpError("kd.comparison.equal"));
return DataSliceOp<internal::EqualOp>()(
x, y, internal::DataItem(schema::kMask), nullptr);
}
Expand Down
6 changes: 3 additions & 3 deletions koladata/operators/masking.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
#include "koladata/internal/op_utils/presence_or.h"
#include "koladata/internal/op_utils/utils.h"
#include "koladata/operators/arolla_bridge.h"
#include "koladata/repr_utils.h"
#include "koladata/schema_utils.h"
#include "arolla/util/status_macros_backport.h"

Expand All @@ -52,9 +51,10 @@ inline absl::StatusOr<DataSlice> ApplyMask(const DataSlice& obj,
// kde.masking.coalesce.
inline absl::StatusOr<DataSlice> Coalesce(const DataSlice& x,
const DataSlice& y) {
RETURN_IF_ERROR(ExpectHaveCommonSchema({"x", "y"}, x, y))
.With(OpError("kd.masking.coalesce"));
auto res_db = DataBag::CommonDataBag({x.GetBag(), y.GetBag()});
ASSIGN_OR_RETURN(auto aligned_slices, AlignSchemas({x, y}),
AssembleErrorMessage(_, {.db = res_db}));
ASSIGN_OR_RETURN(auto aligned_slices, AlignSchemas({x, y}));
return DataSliceOp<internal::PresenceOrOp>()(
aligned_slices.slices[0], aligned_slices.slices[1],
aligned_slices.common_schema, std::move(res_db));
Expand Down
17 changes: 17 additions & 0 deletions koladata/schema_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,23 @@ absl::Status ExpectConsistentStringOrBytesImpl(

} // namespace schema_utils_internal

// Checks that a common schema exists for the schemas of `lhs` and `rhs`.
// `arg_names` supplies the two argument names that are interpolated into the
// error message, in the same order as the slices.
absl::Status ExpectHaveCommonSchema(
    absl::Span<const absl::string_view> arg_names, const DataSlice& lhs,
    const DataSlice& rhs) {
  // One name per slice is required; a wrong count is reported as an internal
  // error rather than as an invalid argument.
  if (arg_names.size() != 2) {
    return absl::InternalError("arg_names must have exactly 2 elements");
  }
  // Only the success bit of the lookup is used; the detailed status produced
  // by CommonSchema is intentionally dropped in favor of the message below.
  const bool has_common_schema =
      schema::CommonSchema(lhs.GetSchemaImpl(), rhs.GetSchemaImpl()).ok();
  if (!has_common_schema) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "arguments `%s` and `%s` must contain values castable to "
        "a common type, got %s and %s",
        arg_names[0], arg_names[1],
        schema_utils_internal::DescribeSliceSchema(lhs),
        schema_utils_internal::DescribeSliceSchema(rhs)));
  }
  return absl::OkStatus();
}

absl::Status ExpectHaveCommonPrimitiveSchema(
absl::Span<const absl::string_view> arg_names, const DataSlice& lhs,
const DataSlice& rhs) {
Expand Down
5 changes: 5 additions & 0 deletions koladata/schema_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ inline absl::Status ExpectConsistentStringOrBytes(absl::string_view arg_name,
{&arg});
}

// Returns OK if the DataSlices contain values castable to a common type, i.e.
// if a common schema exists for the schemas of `lhs` and `rhs`. Otherwise
// returns an InvalidArgument error that names both arguments and describes
// their schemas.
// NOTE: arg_names must have exactly 2 elements.
absl::Status ExpectHaveCommonSchema(
absl::Span<const absl::string_view> arg_names, const DataSlice& lhs,
const DataSlice& rhs);

// Returns OK if the DataSlices contain values castable to a common primitive
// type.
// NOTE: arg_names must have exactly 2 elements.
Expand Down
19 changes: 19 additions & 0 deletions koladata/schema_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,25 @@ TEST(SchemaUtilsTest, ExpectConsistentStringOrBytes) {
"slice of OBJECT with items of types BYTES, STRING"));
}

// Covers ExpectHaveCommonSchema: pairs of slices whose schemas have a common
// schema are accepted, an incompatible pair yields the one-line
// InvalidArgument message, and a wrong number of argument names yields an
// Internal error.
TEST(SchemaUtilsTest, ExpectHaveCommonSchema) {
  auto empty_and_unknown = test::DataItem(std::nullopt, schema::kObject);
  auto integer = test::DataSlice<int>({1, 2, std::nullopt});
  auto floating = test::DataSlice<float>({1, 2, std::nullopt});
  auto bytes = test::DataSlice<std::string>({"a", "b", std::nullopt});
  auto bytes_any =
      test::DataSlice<std::string>({"a", "b", std::nullopt}, schema::kAny);
  auto schema = test::DataItem(std::nullopt, schema::kSchema);

  // Compatible pairs.
  EXPECT_THAT(ExpectHaveCommonSchema({"foo", "bar"}, bytes, empty_and_unknown),
              IsOk());
  EXPECT_THAT(ExpectHaveCommonSchema({"foo", "bar"}, bytes, bytes_any), IsOk());
  EXPECT_THAT(ExpectHaveCommonSchema({"foo", "bar"}, integer, bytes), IsOk());
  // Previously `floating` was declared but never exercised.
  EXPECT_THAT(ExpectHaveCommonSchema({"foo", "bar"}, integer, floating),
              IsOk());

  // Incompatible pair: the message names both arguments and their schemas.
  EXPECT_THAT(ExpectHaveCommonSchema({"foo", "bar"}, integer, schema),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "arguments `foo` and `bar` must contain values castable "
                       "to a common type, got INT32 and SCHEMA"));

  // The number of argument names is validated before the schemas are
  // inspected, so any slices can be passed here.
  EXPECT_THAT(ExpectHaveCommonSchema({"foo", "bar", "baz"}, integer, bytes),
              StatusIs(absl::StatusCode::kInternal,
                       "arg_names must have exactly 2 elements"));
}

TEST(SchemaUtilsTest, ExpectHaveCommonPrimitiveSchema) {
auto empty_and_unknown = test::DataItem(std::nullopt, schema::kObject);
auto integer = test::DataSlice<int>({1, 2, std::nullopt});
Expand Down
29 changes: 18 additions & 11 deletions py/koladata/operators/tests/comparison_equal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for kde.comparison.equal."""
import re

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -115,32 +115,39 @@ def test_qtype_signatures(self):
def test_raises_on_incompatible_schemas(self):
with self.assertRaisesRegex(
exceptions.KodaError,
r"""cannot find a common schema for provided schemas
the common schema\(s\) INT32: INT32
the first conflicting schema [0-9a-f]{32}:0: SCHEMA\(\)""",
re.escape(
'kd.comparison.equal: arguments `x` and `y` must contain values'
' castable to a common type, got SCHEMA() and INT32'
),
):
expr_eval.eval(kde.comparison.equal(ENTITY_1, ds(1)))

db = data_bag.DataBag.empty()
with self.assertRaisesRegex(
exceptions.KodaError,
r"""cannot find a common schema for provided schemas
the common schema\(s\) [0-9a-f]{32}:0: SCHEMA\(x=INT32\)
the first conflicting schema [0-9a-f]{32}:0: SCHEMA\(\)""",
re.escape(
'kd.comparison.equal: arguments `x` and `y` must contain values'
' castable to a common type, got SCHEMA(x=INT32) and SCHEMA()'
),
):
expr_eval.eval(kde.comparison.equal(db.new(x=1), db.new()))

with self.assertRaisesRegex(
exceptions.KodaError,
'cannot find a common schema for provided schemas',
re.escape(
'kd.comparison.equal: arguments `x` and `y` must contain values'
' castable to a common type, got SCHEMA(x=INT32) and OBJECT with an'
' item of type ITEMID'
),
):
expr_eval.eval(kde.comparison.equal(db.new(x=1), db.obj()))

with self.assertRaisesRegex(
exceptions.KodaError,
'cannot find a common schema for provided schemas',
re.escape(
'kd.comparison.equal: arguments `x` and `y` must contain values'
' castable to a common type, got SCHEMA(x=INT32) and ITEMID'
),
):
expr_eval.eval(
kde.comparison.equal(
Expand Down
29 changes: 19 additions & 10 deletions py/koladata/operators/tests/comparison_full_equal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from absl.testing import absltest
from absl.testing import parameterized
from arolla import arolla
Expand Down Expand Up @@ -109,32 +111,39 @@ def test_qtype_signatures(self):
def test_raises_on_incompatible_schemas(self):
with self.assertRaisesRegex(
exceptions.KodaError,
r"""cannot find a common schema for provided schemas
the common schema\(s\) INT32: INT32
the first conflicting schema [0-9a-f]{32}:0: SCHEMA\(\)""",
re.escape(
'kd.comparison.equal: arguments `x` and `y` must contain values'
' castable to a common type, got SCHEMA() and INT32'
),
):
expr_eval.eval(kde.comparison.full_equal(ENTITY_1, ds(1)))

db = data_bag.DataBag.empty()
with self.assertRaisesRegex(
exceptions.KodaError,
r"""cannot find a common schema for provided schemas
the common schema\(s\) [0-9a-f]{32}:0: SCHEMA\(x=INT32\)
the first conflicting schema [0-9a-f]{32}:0: SCHEMA\(\)""",
re.escape(
'kd.comparison.equal: arguments `x` and `y` must contain values'
' castable to a common type, got SCHEMA(x=INT32) and SCHEMA()'
),
):
expr_eval.eval(kde.comparison.full_equal(db.new(x=1), db.new()))

with self.assertRaisesRegex(
exceptions.KodaError,
'cannot find a common schema for provided schemas',
re.escape(
'kd.comparison.equal: arguments `x` and `y` must contain values'
' castable to a common type, got SCHEMA(x=INT32) and OBJECT with an'
' item of type ITEMID'
),
):
expr_eval.eval(kde.comparison.full_equal(db.new(x=1), db.obj()))

with self.assertRaisesRegex(
exceptions.KodaError,
'cannot find a common schema for provided schemas',
re.escape(
'kd.comparison.equal: arguments `x` and `y` must contain values'
' castable to a common type, got SCHEMA(x=INT32) and ITEMID'
),
):
expr_eval.eval(
kde.comparison.full_equal(
Expand Down
6 changes: 2 additions & 4 deletions py/koladata/operators/tests/masking_coalesce_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,8 @@ def test_incompatible_schema_error(self):
y = data_bag.DataBag.empty().new()
with self.assertRaisesRegex(
exceptions.KodaError,
r"""cannot find a common schema for provided schemas
the common schema\(s\) INT32: INT32
the first conflicting schema [0-9a-f]{32}:0: SCHEMA\(\)""",
'kd.masking.coalesce: arguments `x` and `y` must contain values'
' castable to a common type, got INT32 and SCHEMA()',
):
expr_eval.eval(kde.masking.coalesce(x, y))

Expand Down
8 changes: 7 additions & 1 deletion py/koladata/operators/tests/masking_cond_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from absl.testing import absltest
from absl.testing import parameterized
from arolla import arolla
Expand Down Expand Up @@ -184,7 +186,11 @@ def test_incompatible_schema_error(self):
x = ds([1, None])
y = data_bag.DataBag.empty().new()
with self.assertRaisesRegex(
exceptions.KodaError, 'cannot find a common schema for provided schemas'
exceptions.KodaError,
re.escape(
'kd.masking.coalesce: arguments `x` and `y` must contain values'
' castable to a common type, got INT32 and SCHEMA()'
),
):
expr_eval.eval(kde.masking.cond(ds(arolla.present()), x, y))

Expand Down
8 changes: 4 additions & 4 deletions py/koladata/operators/tests/masking_disjoint_coalesce_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,10 @@ def test_incompatible_schema_error(self):
y = data_bag.DataBag.empty().new() & ds(arolla.missing())
with self.assertRaisesRegex(
exceptions.KodaError,
r"""cannot find a common schema for provided schemas
the common schema\(s\) INT32: INT32
the first conflicting schema [0-9a-f]{32}:0: SCHEMA\(\)""",
re.escape(
'kd.masking.coalesce: arguments `x` and `y` must contain values'
' castable to a common type, got INT32 and SCHEMA()'
),
):
expr_eval.eval(kde.masking.disjoint_coalesce(x, y))

Expand Down

0 comments on commit 88d4540

Please sign in to comment.