From 1f9dbaf0a682b35f3bbcc78bcfd423e8b0c851ef Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Fri, 27 Sep 2024 11:39:08 +0200 Subject: [PATCH 01/25] add round in polars --- src/pydiverse/transform/backend/polars.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/pydiverse/transform/backend/polars.py b/src/pydiverse/transform/backend/polars.py index c2ed6e8a..af95acee 100644 --- a/src/pydiverse/transform/backend/polars.py +++ b/src/pydiverse/transform/backend/polars.py @@ -709,3 +709,10 @@ def _greatest(*x): @op.auto def _least(*x): return pl.min_horizontal(*x) + + +with PolarsImpl.op(ops.Round()) as op: + + @op.auto + def _round(x, digits=0): + return x.round(digits) From 4992c698cd6cc23c85a780f4084d60a62b6c88a7 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Fri, 27 Sep 2024 15:24:30 +0200 Subject: [PATCH 02/25] add some support for floating point log --- src/pydiverse/transform/backend/mssql.py | 18 ++++++ src/pydiverse/transform/backend/polars.py | 45 +++++++-------- src/pydiverse/transform/backend/sql.py | 39 +++++++------ src/pydiverse/transform/ops/numeric.py | 12 ++++ src/pydiverse/transform/tree/dtypes.py | 14 +++-- tests/test_backend_equivalence/conftest.py | 9 +++ .../test_ops/test_ops_numerical.py | 57 +++++++++++++++++++ tests/test_operator_registry.py | 2 +- tests/test_polars_table.py | 2 +- 9 files changed, 149 insertions(+), 49 deletions(-) create mode 100644 tests/test_backend_equivalence/test_ops/test_ops_numerical.py diff --git a/src/pydiverse/transform/backend/mssql.py b/src/pydiverse/transform/backend/mssql.py index e00a6309..1c8860fd 100644 --- a/src/pydiverse/transform/backend/mssql.py +++ b/src/pydiverse/transform/backend/mssql.py @@ -21,6 +21,11 @@ ) from pydiverse.transform.util.warnings import warn_non_standard +MSSQL_INF = sqa.cast(sqa.literal("1.0"), type_=sqa.Float()) / sqa.literal( + "0.0", type_=sqa.Float() +) +MSSQL_NAN = MSSQL_INF + (-MSSQL_INF) + class MsSqlImpl(SqlImpl): dialect_name = "mssql" @@ -289,3 +294,16 @@ def _day_of_week(x): @op.auto def _mean(x): return sqa.func.AVG(sqa.cast(x, sqa.Double()), type_=sqa.Double()) + + +with MsSqlImpl.op(ops.Log()) as op: + + @op.auto + def _log(x): + # TODO: we still need to handle inf / -inf / nan + return sqa.case( + (x > 0, sqa.func.log(x)), + (x < 0, MSSQL_NAN), + (x.is_(sqa.null()), None), + else_=-MSSQL_INF, + ) diff --git a/src/pydiverse/transform/backend/polars.py b/src/pydiverse/transform/backend/polars.py index af95acee..06b8e47b 100644 --- a/src/pydiverse/transform/backend/polars.py +++ b/src/pydiverse/transform/backend/polars.py @@ -1,7 +1,5 @@ from __future__ import annotations -import datetime -from types import NoneType from typing import Any from uuid import UUID @@ -340,7 +338,9 @@ def has_path_to_leaf_without_agg(expr: ColExpr): def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype: if t.is_float(): - return dtypes.Float() + return dtypes.Float64() + elif t.is_decimal(): + return dtypes.Decimal() elif t.is_integer(): return dtypes.Int() elif isinstance(t, pl.Boolean): @@ -360,8 +360,10 @@ def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype: def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType: - if isinstance(t, dtypes.Float): + if isinstance(t, dtypes.Float64): return pl.Float64() + elif isinstance(t, dtypes.Decimal): + return pl.Decimal() elif isinstance(t, dtypes.Int): return pl.Int64() elif isinstance(t, dtypes.Bool): @@ -380,27 +382,6 @@ def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType: raise AssertionError -def python_type_to_polars(t: type) -> pl.DataType: - if t is int: - return pl.Int64() - elif t is float: - return pl.Float64() - elif t is bool: - return pl.Boolean() - elif t is str: - return pl.String() - elif t is datetime.datetime: - return pl.Datetime() - elif t is datetime.date: - return pl.Date() - elif t is datetime.timedelta: - return pl.Duration() - elif t is NoneType: - return pl.Null() - - raise TypeError(f"python builtin type {t} is not supported by pydiverse.transform") - - with PolarsImpl.op(ops.Mean()) as op: @op.auto @@ -716,3 +697,17 @@ def _least(*x): @op.auto def _round(x, digits=0): return x.round(digits) + + +with PolarsImpl.op(ops.Exp()) as op: + + @op.auto + def _exp(x): + return x.exp() + + +with PolarsImpl.op(ops.Log()) as op: + + @op.auto + def _log(x): + return x.log() diff --git a/src/pydiverse/transform/backend/sql.py b/src/pydiverse/transform/backend/sql.py index 05be684c..e85933b0 100644 --- a/src/pydiverse/transform/backend/sql.py +++ b/src/pydiverse/transform/backend/sql.py @@ -550,7 +550,7 @@ def sqa_type_to_pdt(t: sqa.types.TypeEngine) -> Dtype: if isinstance(t, sqa.Integer): return dtypes.Int() elif isinstance(t, sqa.Numeric): - return dtypes.Float() + return dtypes.Float64() elif isinstance(t, sqa.String): return dtypes.String() elif isinstance(t, sqa.Boolean): @@ -570,7 +570,7 @@ def sqa_type_to_pdt(t: sqa.types.TypeEngine) -> Dtype: def pdt_type_to_sqa(t: Dtype) -> sqa.types.TypeEngine: if isinstance(t, dtypes.Int): return sqa.Integer() - elif isinstance(t, dtypes.Float): + elif isinstance(t, dtypes.Float64): return sqa.Numeric() elif isinstance(t, dtypes.String): return sqa.String() @@ -686,9 +686,6 @@ def _is_not_null(x): return x.is_not(sqa.null()) -#### String Functions #### - - with SqlImpl.op(ops.StrStrip()) as op: @op.auto @@ -753,9 +750,6 @@ def _str_slice(x, offset, length): return sqa.func.SUBSTR(x, offset + 1, length) -#### Datetime Functions #### - - with SqlImpl.op(ops.DtYear()) as op: @op.auto @@ -819,9 +813,6 @@ def _day_of_year(x): return sqa.extract("doy", x) -#### Generic Functions #### - - with SqlImpl.op(ops.Greatest()) as op: @op.auto @@ -838,9 +829,6 @@ def _least(*x): return sqa.func.LEAST(*x) -#### Summarising Functions #### - - with SqlImpl.op(ops.Mean()) as op: @op.auto @@ -919,9 +907,6 @@ def _count(x=None): return sqa.func.count(x) -#### Window Functions #### - - with SqlImpl.op(ops.Shift()) as op: @op.auto @@ -968,3 +953,23 @@ def _rank(): @op.auto def _dense_rank(): return sqa.func.dense_rank() + + +with SqlImpl.op(ops.Exp()) as op: + + @op.auto + def _exp(x): + return sqa.func.exp(x) + + +with SqlImpl.op(ops.Log()) as op: + + @op.auto + def _log(x): + # TODO: we still need to handle inf / -inf / nan + return sqa.case( + (x > 0, sqa.func.ln(x)), + (x < 0, sqa.literal("nan")), + (x.is_(sqa.null()), None), + else_=sqa.literal("-inf"), + ) diff --git a/src/pydiverse/transform/ops/numeric.py b/src/pydiverse/transform/ops/numeric.py index 4c0c4350..1482f690 100644 --- a/src/pydiverse/transform/ops/numeric.py +++ b/src/pydiverse/transform/ops/numeric.py @@ -21,6 +21,8 @@ "Pos", "Abs", "Round", + "Exp", + "Log", ] @@ -144,3 +146,13 @@ class Round(ElementWise): "float -> float", "float, const int -> float", ] + + +class Log(ElementWise): + name = "log" + signatures = ["float -> float"] + + +class Exp(Log): + name = "exp" + signatures = ["float -> float"] diff --git a/src/pydiverse/transform/tree/dtypes.py b/src/pydiverse/transform/tree/dtypes.py index be06ec35..7ddb0306 100644 --- a/src/pydiverse/transform/tree/dtypes.py +++ b/src/pydiverse/transform/tree/dtypes.py @@ -76,7 +76,7 @@ def can_promote_to(self, other: Dtype) -> bool: return True # int can be promoted to float - if Float().same_kind(other): + if Float64().same_kind(other): if other.const and not self.const: return False @@ -85,8 +85,12 @@ def can_promote_to(self, other: Dtype) -> bool: return False -class Float(Dtype): - name = "float" +class Float64(Dtype): + name = "float64" + + +class Decimal(Dtype): + name = "decimal" class String(Dtype): @@ -144,7 +148,7 @@ def python_type_to_pdt(t: type) -> Dtype: if t is int: return Int() elif t is float: - return Float() + return Float64() elif t is bool: return Bool() elif t is str: @@ -193,7 +197,7 @@ def dtype_from_string(t: str) -> Dtype: if base_type == "int": return Int(const=is_const, vararg=is_vararg) if base_type == "float": - return Float(const=is_const, vararg=is_vararg) + return Float64(const=is_const, vararg=is_vararg) if base_type == "str": return String(const=is_const, vararg=is_vararg) if base_type == "bool": diff --git a/tests/test_backend_equivalence/conftest.py b/tests/test_backend_equivalence/conftest.py index da619439..50600824 100644 --- a/tests/test_backend_equivalence/conftest.py +++ b/tests/test_backend_equivalence/conftest.py @@ -119,6 +119,15 @@ # ], } ), + "df_num": pl.DataFrame( + { + "a": [0.4, -1.1, -0.0, 0.0, 9.0, 2.0, 2.3, -1000.0], + "b": [None, 2, 0, -11, 4, 19, -5190, 2000000], + "c": [0.0, None, None, 2.2, -0.0, 10.0, -10.0, 3.1415926535], + "d": [None, 2.71828, 0.577, 901234, -6.0, 4.0, None, -99.0], + "e": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 500.0], + } + ), } # compare one dataframe and one SQL backend to all others diff --git a/tests/test_backend_equivalence/test_ops/test_ops_numerical.py b/tests/test_backend_equivalence/test_ops/test_ops_numerical.py new file mode 100644 index 00000000..155150ce --- /dev/null +++ b/tests/test_backend_equivalence/test_ops/test_ops_numerical.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from pydiverse.transform.pipe.verbs import mutate +from tests.util.assertion import assert_result_equal + + +def test_exp(df_num): + assert_result_equal( + df_num, + lambda t: t + >> mutate( + exp_a=t.a.exp(), + exp_b=t.b.exp(), + exp_c=t.c.exp(), + exp_d=t.d.exp(), + ), + ) + + +def test_log(df_num): + assert_result_equal( + df_num, + lambda t: t + >> mutate( + log_a=t.a.log(), + log_b=t.b.log(), + log_c=t.c.log(), + log_d=t.d.log(), + log_e=t.e.exp(), + ), + ) + + +def test_abs(df_num): + assert_result_equal( + df_num, + lambda t: t + >> mutate( + abs_a=abs(t.a), + abs_b=abs(t.b), + abs_c=abs(t.c), + abs_d=abs(t.d), + ), + ) + + +def test_round(df_num): + assert_result_equal( + df_num, + lambda t: t + >> mutate( + round_a=round(t.a), + round_b=round(t.b), + round_c=round(t.c), + round_d=round(t.d), + ), + ) diff --git a/tests/test_operator_registry.py b/tests/test_operator_registry.py index 7014e7f7..0c05ad87 100644 --- a/tests/test_operator_registry.py +++ b/tests/test_operator_registry.py @@ -195,7 +195,7 @@ def test_template(self): ) assert isinstance( reg.get_impl("op3", parse_dtypes("str", "int", "float")).return_type, - dtypes.Float, + dtypes.Float64, ) def test_vararg(self): diff --git a/tests/test_polars_table.py b/tests/test_polars_table.py index d3b53cc7..9571ff50 100644 --- a/tests/test_polars_table.py +++ b/tests/test_polars_table.py @@ -121,7 +121,7 @@ def test_dtype(self, tbl1, tbl2): assert isinstance(tbl2.col1.dtype(), dtypes.Int) assert isinstance(tbl2.col2.dtype(), dtypes.Int) - assert isinstance(tbl2.col3.dtype(), dtypes.Float) + assert isinstance(tbl2.col3.dtype(), dtypes.Float64) # test that column expression type errors are checked immediately with pytest.raises(TypeError): From d3f63b7c735801dc30e318aaed50453d20966146 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Fri, 27 Sep 2024 16:07:39 +0200 Subject: [PATCH 03/25] add some tests for decimals --- src/pydiverse/transform/backend/sql.py | 8 +++- src/pydiverse/transform/ops/aggregate.py | 10 ++--- src/pydiverse/transform/ops/logical.py | 2 +- src/pydiverse/transform/ops/numeric.py | 38 +++++++++---------- src/pydiverse/transform/tree/dtypes.py | 6 ++- src/pydiverse/transform/tree/registry.py | 4 +- tests/test_backend_equivalence/conftest.py | 5 ++- .../test_ops/test_ops_numerical.py | 8 ++++ tests/test_operator_registry.py | 12 +++--- tests/util/backend.py | 9 ++--- 10 files changed, 57 insertions(+), 45 deletions(-) diff --git a/src/pydiverse/transform/backend/sql.py b/src/pydiverse/transform/backend/sql.py index e85933b0..15b1c245 100644 --- a/src/pydiverse/transform/backend/sql.py +++ b/src/pydiverse/transform/backend/sql.py @@ -549,8 +549,10 @@ def get_engine(nd: AstNode) -> sqa.Engine: def sqa_type_to_pdt(t: sqa.types.TypeEngine) -> Dtype: if isinstance(t, sqa.Integer): return dtypes.Int() - elif isinstance(t, sqa.Numeric): + elif isinstance(t, sqa.Float): return dtypes.Float64() + elif isinstance(t, (sqa.DECIMAL, sqa.NUMERIC)): + return dtypes.Decimal() elif isinstance(t, sqa.String): return dtypes.String() elif isinstance(t, sqa.Boolean): @@ -571,7 +573,9 @@ def pdt_type_to_sqa(t: Dtype) -> sqa.types.TypeEngine: if isinstance(t, dtypes.Int): return sqa.Integer() elif isinstance(t, dtypes.Float64): - return sqa.Numeric() + return sqa.Float() + elif isinstance(t, dtypes.Decimal): + return sqa.DECIMAL() elif isinstance(t, dtypes.String): return sqa.String() elif isinstance(t, dtypes.Bool): diff --git a/src/pydiverse/transform/ops/aggregate.py b/src/pydiverse/transform/ops/aggregate.py index 2261142e..9f536c94 100644 --- a/src/pydiverse/transform/ops/aggregate.py +++ b/src/pydiverse/transform/ops/aggregate.py @@ -17,7 +17,7 @@ class Min(Aggregate, Unary): name = "min" signatures = [ "int -> int", - "float -> float", + "float64 -> float64", "str -> str", "datetime -> datetime", "date -> date", @@ -28,7 +28,7 @@ class Max(Aggregate, Unary): name = "max" signatures = [ "int -> int", - "float -> float", + "float64 -> float64", "str -> str", "datetime -> datetime", "date -> date", @@ -38,8 +38,8 @@ class Max(Aggregate, Unary): class Mean(Aggregate, Unary): name = "mean" signatures = [ - "int -> float", - "float -> float", + "int -> float64", + "float64 -> float64", ] @@ -47,7 +47,7 @@ class Sum(Aggregate, Unary): name = "sum" signatures = [ "int -> int", - "float -> float", + "float64 -> float64", ] diff --git a/src/pydiverse/transform/ops/logical.py b/src/pydiverse/transform/ops/logical.py index f4c1ef9c..1f36d0d2 100644 --- a/src/pydiverse/transform/ops/logical.py +++ b/src/pydiverse/transform/ops/logical.py @@ -36,7 +36,7 @@ class Logical(Operator): class Comparison(ElementWise, Binary, Logical): signatures = [ "int, int -> bool", - "float, float -> bool", + "float64, float64 -> bool", "str, str -> bool", "bool, bool -> bool", "datetime, datetime -> bool", diff --git a/src/pydiverse/transform/ops/numeric.py b/src/pydiverse/transform/ops/numeric.py index 1482f690..6b6a29e3 100644 --- a/src/pydiverse/transform/ops/numeric.py +++ b/src/pydiverse/transform/ops/numeric.py @@ -30,9 +30,8 @@ class Add(ElementWise, Binary): name = "__add__" signatures = [ "int, int -> int", - "int, float -> float", - "float, int -> float", - "float, float -> float", + "float64, float64 -> float64", + "decimal, decimal -> decimal", ] @@ -44,9 +43,8 @@ class Sub(ElementWise, Binary): name = "__sub__" signatures = [ "int, int -> int", - "int, float -> float", - "float, int -> float", - "float, float -> float", + "float64, float64 -> float64", + "decimal, decimal -> decimal", ] @@ -58,9 +56,8 @@ class Mul(ElementWise, Binary): name = "__mul__" signatures = [ "int, int -> int", - "int, float -> float", - "float, int -> float", - "float, float -> float", + "float64, float64 -> float64", + "decimal, decimal -> decimal", ] @@ -71,8 +68,9 @@ class RMul(Mul): class TrueDiv(ElementWise, Binary): name = "__truediv__" signatures = [ - "int, int -> float", - "float, float -> float", + "int, int -> float64", + "float64, float64 -> float64", + "decimal, decimal -> decimal", ] @@ -94,8 +92,8 @@ class RFloorDiv(FloorDiv): class Pow(ElementWise, Binary): name = "__pow__" signatures = [ - "int, int -> float", - "float, float -> float", + "int, int -> float64", + "float64, float64 -> float64", ] @@ -118,7 +116,7 @@ class Neg(ElementWise, Unary): name = "__neg__" signatures = [ "int -> int", - "float -> float", + "float64 -> float64", ] @@ -126,7 +124,7 @@ class Pos(ElementWise, Unary): name = "__pos__" signatures = [ "int -> int", - "float -> float", + "float64 -> float64", ] @@ -134,7 +132,7 @@ class Abs(ElementWise, Unary): name = "__abs__" signatures = [ "int -> int", - "float -> float", + "float64 -> float64", ] @@ -143,16 +141,16 @@ class Round(ElementWise): signatures = [ "int -> int", "int, const int -> int", - "float -> float", - "float, const int -> float", + "float64 -> float64", + "float64, const int -> float64", ] class Log(ElementWise): name = "log" - signatures = ["float -> float"] + signatures = ["float64 -> float64"] class Exp(Log): name = "exp" - signatures = ["float -> float"] + signatures = ["float64 -> float64"] diff --git a/src/pydiverse/transform/tree/dtypes.py b/src/pydiverse/transform/tree/dtypes.py index 7ddb0306..0b532d60 100644 --- a/src/pydiverse/transform/tree/dtypes.py +++ b/src/pydiverse/transform/tree/dtypes.py @@ -75,7 +75,7 @@ def can_promote_to(self, other: Dtype) -> bool: if super().can_promote_to(other): return True - # int can be promoted to float + # int can be promoted to float64 if Float64().same_kind(other): if other.const and not self.const: return False @@ -196,8 +196,10 @@ def dtype_from_string(t: str) -> Dtype: if base_type == "int": return Int(const=is_const, vararg=is_vararg) - if base_type == "float": + if base_type == "float64": return Float64(const=is_const, vararg=is_vararg) + if base_type == "decimal": + return Decimal(const=is_const, vararg=is_vararg) if base_type == "str": return String(const=is_const, vararg=is_vararg) if base_type == "bool": diff --git a/src/pydiverse/transform/tree/registry.py b/src/pydiverse/transform/tree/registry.py index 023d9ed5..38a05e98 100644 --- a/src/pydiverse/transform/tree/registry.py +++ b/src/pydiverse/transform/tree/registry.py @@ -296,7 +296,7 @@ class OperatorSignature: terminal_arg ::= modifiers (dtype | vararg) vararg ::= dtype "..." rtype ::= dtype - dtype ::= template | "int" | "float" | "str" | "bool" | and others... + dtype ::= template | "int" | "float64" | "str" | "bool" | and others... modifiers ::= "const"? template ::= single uppercase character @@ -523,7 +523,7 @@ def does_match( elif not node.value.same_kind(dtype): # Needs type promotion # This only works when types can be promoted once - # -> (uint > int) wouldn't be preferred over (uint > int > float) + # -> (uint > int) wouldn't be preferred over (uint > int > float64) type_promotion_indices = (*type_promotion_indices, s_i) if s_i + 1 == len(signature): diff --git a/tests/test_backend_equivalence/conftest.py b/tests/test_backend_equivalence/conftest.py index 50600824..b1265da0 100644 --- a/tests/test_backend_equivalence/conftest.py +++ b/tests/test_backend_equivalence/conftest.py @@ -126,7 +126,10 @@ "c": [0.0, None, None, 2.2, -0.0, 10.0, -10.0, 3.1415926535], "d": [None, 2.71828, 0.577, 901234, -6.0, 4.0, None, -99.0], "e": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 500.0], - } + "f": [3.0, None, 0.0, 4.3, 10.0, -1.2, -9999.1, -34.1], + "g": [-5.5, None, None, 1.100212, -3.412351, 1000.4252, 0.0, -1.0], + }, + schema_overrides={"f": pl.Decimal(), "g": pl.Decimal()}, ), } diff --git a/tests/test_backend_equivalence/test_ops/test_ops_numerical.py b/tests/test_backend_equivalence/test_ops/test_ops_numerical.py index 155150ce..379d7e19 100644 --- a/tests/test_backend_equivalence/test_ops/test_ops_numerical.py +++ b/tests/test_backend_equivalence/test_ops/test_ops_numerical.py @@ -55,3 +55,11 @@ def test_round(df_num): round_d=round(t.d), ), ) + + +def test_div(df_num): + assert_result_equal(df_num, lambda t: t >> mutate(u=t.a / 2, v=t.b / 3.1)) + + +def test_decimal(df_num): + assert_result_equal(df_num, lambda t: t >> mutate(u=t.f + t.g, z=t.f * t.g)) diff --git a/tests/test_operator_registry.py b/tests/test_operator_registry.py index 0c05ad87..addae7a6 100644 --- a/tests/test_operator_registry.py +++ b/tests/test_operator_registry.py @@ -152,8 +152,8 @@ def test_template(self): assert reg.get_impl("op1", parse_dtypes("int", "int"))() == 1 assert reg.get_impl("op1", parse_dtypes("int", "str"))() == 2 # int can be promoted to float; results in "float, float -> bool" signature - assert reg.get_impl("op1", parse_dtypes("int", "float"))() == 1 - assert reg.get_impl("op1", parse_dtypes("float", "int"))() == 1 + assert reg.get_impl("op1", parse_dtypes("int", "float64"))() == 1 + assert reg.get_impl("op1", parse_dtypes("float64", "int"))() == 1 # More template matching... Also check matching precedence reg.add_impl(op2, lambda: 1, "int, int, int -> int") @@ -168,10 +168,10 @@ def test_template(self): assert reg.get_impl("op2", parse_dtypes("int", "int", "str"))() == 3 assert reg.get_impl("op2", parse_dtypes("int", "bool", "bool"))() == 4 assert reg.get_impl("op2", parse_dtypes("str", "str", "str"))() == 5 - assert reg.get_impl("op2", parse_dtypes("float", "str", "str"))() == 6 + assert reg.get_impl("op2", parse_dtypes("float64", "str", "str"))() == 6 with pytest.raises(ValueError): - reg.get_impl("op2", parse_dtypes("int", "bool", "float")) + reg.get_impl("op2", parse_dtypes("int", "bool", "float64")) # Return type reg.add_impl(op3, lambda: 1, "T -> T") @@ -190,11 +190,11 @@ def test_template(self): dtypes.Int, ) assert isinstance( - reg.get_impl("op3", parse_dtypes("int", "int", "float")).return_type, + reg.get_impl("op3", parse_dtypes("int", "int", "float64")).return_type, dtypes.Int, ) assert isinstance( - reg.get_impl("op3", parse_dtypes("str", "int", "float")).return_type, + reg.get_impl("op3", parse_dtypes("str", "int", "float64")).return_type, dtypes.Float64, ) diff --git a/tests/util/backend.py b/tests/util/backend.py index 957cad28..9f23e448 100644 --- a/tests/util/backend.py +++ b/tests/util/backend.py @@ -38,6 +38,7 @@ def sql_table(df: pl.DataFrame, name: str, url: str, dtypes_map: dict = None): global _sql_engine_cache dtypes_map = dtypes_map or {} + dtypes_map[pl.Decimal()] = sqa.DECIMAL() if url in _sql_engine_cache: engine = _sql_engine_cache[url] @@ -69,6 +70,7 @@ def duckdb_table(df: pl.DataFrame, name: str): @_cached_table def postgres_table(df: pl.DataFrame, name: str): url = "postgresql://sa:Pydiverse23@127.0.0.1:6543" + return sql_table(df, name, url) @@ -80,12 +82,7 @@ def mssql_table(df: pl.DataFrame, name: str): "mssql+pyodbc://sa:PydiQuant27@127.0.0.1:1433" "/master?driver=ODBC+Driver+18+for+SQL+Server&encrypt=no" ) - return sql_table( - df, - name, - url, - dtypes_map={pl.Datetime(): DATETIME2()}, - ) + return sql_table(df, name, url, dtypes_map={pl.Datetime(): DATETIME2()}) BACKEND_TABLES = { From 70e94f88ad17d4943e6cb6ec7305e31ae9c32076 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Fri, 27 Sep 2024 16:43:49 +0200 Subject: [PATCH 04/25] add common numeric ops for decimal --- src/pydiverse/transform/ops/numeric.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/pydiverse/transform/ops/numeric.py b/src/pydiverse/transform/ops/numeric.py index 6b6a29e3..81efad4e 100644 --- a/src/pydiverse/transform/ops/numeric.py +++ b/src/pydiverse/transform/ops/numeric.py @@ -94,6 +94,7 @@ class Pow(ElementWise, Binary): signatures = [ "int, int -> float64", "float64, float64 -> float64", + "decimal, decimal -> decimal", ] @@ -117,6 +118,7 @@ class Neg(ElementWise, Unary): signatures = [ "int -> int", "float64 -> float64", + "decimal -> decimal", ] @@ -125,6 +127,7 @@ class Pos(ElementWise, Unary): signatures = [ "int -> int", "float64 -> float64", + "decimal -> decimal", ] @@ -133,6 +136,7 @@ class Abs(ElementWise, Unary): signatures = [ "int -> int", "float64 -> float64", + "decimal -> decimal", ] @@ -143,6 +147,8 @@ class Round(ElementWise): "int, const int -> int", "float64 -> float64", "float64, const int -> float64", + "decimal -> decimal", + "decimal, const int -> decimal", ] From 34e2ab79025711db43b4240af2939503c7b27aff Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Sat, 28 Sep 2024 11:06:11 +0200 Subject: [PATCH 05/25] add cast --- src/pydiverse/transform/backend/polars.py | 14 +++--- src/pydiverse/transform/backend/sql.py | 6 +++ src/pydiverse/transform/tree/col_expr.py | 48 +++++++++++++++++++ tests/test_backend_equivalence/conftest.py | 26 ++++++++++ .../test_ops/test_cast.py | 9 ++++ 5 files changed, 97 insertions(+), 6 deletions(-) create mode 100644 tests/test_backend_equivalence/test_ops/test_cast.py diff --git a/src/pydiverse/transform/backend/polars.py b/src/pydiverse/transform/backend/polars.py index 06b8e47b..28da3d41 100644 --- a/src/pydiverse/transform/backend/polars.py +++ b/src/pydiverse/transform/backend/polars.py @@ -13,6 +13,7 @@ from pydiverse.transform.tree.ast import AstNode from pydiverse.transform.tree.col_expr import ( CaseExpr, + Cast, Col, ColExpr, ColFn, @@ -157,7 +158,7 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr: # the function was executed on the ordered arguments. here we # restore the original order of the table. - inv_permutation = pl.int_range(0, pl.len(), dtype=pl.Int64).sort_by( + inv_permutation = pl.int_range(0, pl.len(), dtype=pl.Int64()).sort_by( by=order_by, descending=descending, nulls_last=nulls_last, @@ -184,6 +185,11 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr: return pl.lit(expr.val) # polars interprets strings as column names return expr.val + elif isinstance(expr, Cast): + return compile_col_expr(expr.val, name_in_df).cast( + pdt_type_to_polars(expr.target_type) + ) + else: raise AssertionError @@ -339,8 +345,6 @@ def has_path_to_leaf_without_agg(expr: ColExpr): def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype: if t.is_float(): return dtypes.Float64() - elif t.is_decimal(): - return dtypes.Decimal() elif t.is_integer(): return dtypes.Int() elif isinstance(t, pl.Boolean): @@ -360,10 +364,8 @@ def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype: def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType: - if isinstance(t, dtypes.Float64): + if isinstance(t, (dtypes.Float64, dtypes.Decimal)): return pl.Float64() - elif isinstance(t, dtypes.Decimal): - return pl.Decimal() elif isinstance(t, dtypes.Int): return pl.Int64() elif isinstance(t, dtypes.Bool): diff --git a/src/pydiverse/transform/backend/sql.py b/src/pydiverse/transform/backend/sql.py index 15b1c245..758013a5 100644 --- a/src/pydiverse/transform/backend/sql.py +++ b/src/pydiverse/transform/backend/sql.py @@ -21,6 +21,7 @@ from pydiverse.transform.tree.ast import AstNode from pydiverse.transform.tree.col_expr import ( CaseExpr, + Cast, Col, ColExpr, ColFn, @@ -207,6 +208,11 @@ def compile_col_expr( elif isinstance(expr, LiteralCol): return expr.val + elif isinstance(expr, Cast): + return cls.compile_col_expr(expr.val, sqa_col).cast( + pdt_type_to_sqa(expr.target_type) + ) + raise AssertionError @classmethod diff --git a/src/pydiverse/transform/tree/col_expr.py b/src/pydiverse/transform/tree/col_expr.py index 5ad49d05..09b1235d 100644 --- a/src/pydiverse/transform/tree/col_expr.py +++ b/src/pydiverse/transform/tree/col_expr.py @@ -296,6 +296,11 @@ def __getattr__(self, name) -> FnAttr: return FnAttr(f"{self.name}.{name}", self.arg) def __call__(self, *args, **kwargs) -> ColExpr: + if self.name == "cast": + if len(kwargs) > 0: + raise ValueError("`cast` does not take any keyword arguments") + return Cast(self.arg, *args) + return ColFn( self.name, wrap_literal(self.arg), @@ -438,6 +443,49 @@ def otherwise(self, value: ColExpr) -> CaseExpr: return CaseExpr(self.cases, wrap_literal(value)) +class Cast(ColExpr): + __slots__ = ["val", "target_type"] + + def __init__(self, val: ColExpr, target_type: Dtype): + self.val = val + self.target_type = target_type + super().__init__(target_type) + self.dtype() + + def dtype(self) -> Dtype: + # Since `ColExpr.dtype` is also responsible for type checking, we may not set + # `_dtype` until we are able to retrieve the type of `val`. + if self.val.dtype() is None: + return None + + if not self.val.dtype().can_promote_to(self.target_type): + valid_casts = { + (dtypes.String, dtypes.Int), + (dtypes.String, dtypes.Float64), + (dtypes.String, dtypes.Bool), + (dtypes.Float64, dtypes.Int), + } + + if ( + self.val.dtype().__class__, + self.target_type.__class__, + ) not in valid_casts: + raise TypeError( + f"cannot cast type `{self.val.dtype()}` to `{self.target_type}`" + ) + + return self._dtype + + def ftype(self, *, agg_is_window: bool) -> Ftype: + return self.val.ftype(agg_is_window=agg_is_window) + + def iter_children(self) -> Iterable[ColExpr]: + yield self.val + + def map_subtree(self, g: Callable[[ColExpr], ColExpr]) -> ColExpr: + return g(Cast(g(self.val), self.target_type)) + + @dataclasses.dataclass(slots=True) class Order: order_by: ColExpr diff --git a/tests/test_backend_equivalence/conftest.py b/tests/test_backend_equivalence/conftest.py index b1265da0..808d1ae2 100644 --- a/tests/test_backend_equivalence/conftest.py +++ b/tests/test_backend_equivalence/conftest.py @@ -69,6 +69,32 @@ None, "% _.AbAbAb", ], + "c": [ + "4352.0", + "-21", + "-421", + "3.313", + None, + "-inf", + "inf", + "nan", + "-0.000", + "-0.0", + "0.0", + ], + "d": [ + None, + "-123124", + "21241", + "010101", + "0", + "1", + "-12", + "42", + "197", + "1729", + "-100110", + ], } ), "df_datetime": pl.DataFrame( diff --git a/tests/test_backend_equivalence/test_ops/test_cast.py b/tests/test_backend_equivalence/test_ops/test_cast.py new file mode 100644 index 00000000..4cfc8782 --- /dev/null +++ b/tests/test_backend_equivalence/test_ops/test_cast.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +import pydiverse.transform as pdt +from pydiverse.transform.pipe.verbs import mutate +from tests.util.assertion import assert_result_equal + + +def test_string_to_float(df_strings): + assert_result_equal(df_strings, lambda t: t >> mutate(u=t.c.cast(pdt.Float64()))) From 7ac1a25495e91059bfe604249dd479328435d396 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Sat, 28 Sep 2024 11:40:26 +0200 Subject: [PATCH 06/25] make string to float casts work --- src/pydiverse/transform/backend/mssql.py | 27 ++++++++++++++++++++++- src/pydiverse/transform/backend/sql.py | 10 ++++++--- src/pydiverse/transform/backend/sqlite.py | 20 ++++++++++++++++- 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/src/pydiverse/transform/backend/mssql.py b/src/pydiverse/transform/backend/mssql.py index 1c8860fd..65928f52 100644 --- a/src/pydiverse/transform/backend/mssql.py +++ b/src/pydiverse/transform/backend/mssql.py @@ -8,11 +8,12 @@ from pydiverse.transform import ops from pydiverse.transform.backend import sql -from pydiverse.transform.backend.sql import SqlImpl +from pydiverse.transform.backend.sql import SqlImpl, pdt_type_to_sqa from pydiverse.transform.tree import dtypes, verbs from pydiverse.transform.tree.ast import AstNode from pydiverse.transform.tree.col_expr import ( CaseExpr, + Cast, Col, ColExpr, ColFn, @@ -30,6 +31,22 @@ class MsSqlImpl(SqlImpl): dialect_name = "mssql" + @classmethod + def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: + compiled_val = cls.compile_col_expr(cast.val, sqa_col) + if isinstance(cast.val.dtype(), dtypes.String) and isinstance( + cast.target_type, dtypes.Float64 + ): + return sqa.case( + (compiled_val == "inf", MSSQL_INF), + (compiled_val == "-inf", -MSSQL_INF), + (compiled_val == "nan", MSSQL_NAN), + else_=sqa.cast( + compiled_val, + pdt_type_to_sqa(cast.target_type), + ), + ) + @classmethod def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any: # boolean / bit conversion @@ -151,6 +168,14 @@ def convert_bool_bit(expr: ColExpr | Order, wants_bool_as_bit: bool) -> ColExpr elif isinstance(expr, LiteralCol): return expr + elif isinstance(expr, Cast): + # TODO: does this really work for casting onto / from booleans? we probably have + # to use wants_bool_as_bit in some way when casting to bool + return Cast( + convert_bool_bit(expr.val, wants_bool_as_bit=wants_bool_as_bit), + expr.target_type, + ) + raise AssertionError diff --git a/src/pydiverse/transform/backend/sql.py b/src/pydiverse/transform/backend/sql.py index 758013a5..668596cc 100644 --- a/src/pydiverse/transform/backend/sql.py +++ b/src/pydiverse/transform/backend/sql.py @@ -141,6 +141,12 @@ def compile_order( ) return order_expr + @classmethod + def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: + return cls.compile_col_expr(cast.val, sqa_col).cast( + pdt_type_to_sqa(cast.target_type) + ) + @classmethod def compile_col_expr( cls, expr: ColExpr, sqa_col: dict[str, sqa.Label] @@ -209,9 +215,7 @@ def compile_col_expr( return expr.val elif isinstance(expr, Cast): - return cls.compile_col_expr(expr.val, sqa_col).cast( - pdt_type_to_sqa(expr.target_type) - ) + return cls.compile_cast(expr, sqa_col) raise AssertionError diff --git a/src/pydiverse/transform/backend/sqlite.py b/src/pydiverse/transform/backend/sqlite.py index 1dc14f07..61ab4af8 100644 --- a/src/pydiverse/transform/backend/sqlite.py +++ b/src/pydiverse/transform/backend/sqlite.py @@ -3,13 +3,31 @@ import sqlalchemy as sqa from pydiverse.transform import ops -from pydiverse.transform.backend.sql import SqlImpl +from pydiverse.transform.backend.sql import SqlImpl, pdt_type_to_sqa +from pydiverse.transform.tree import dtypes +from pydiverse.transform.tree.col_expr import Cast from pydiverse.transform.util.warnings import warn_non_standard class SqliteImpl(SqlImpl): dialect_name = "sqlite" + @classmethod + def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: + compiled_val = cls.compile_col_expr(cast.val, sqa_col) + if isinstance(cast.val.dtype(), dtypes.String) and isinstance( + cast.target_type, dtypes.Float64 + ): + return sqa.case( + (compiled_val == "inf", sqa.literal("inf")), + (compiled_val == "-inf", sqa.literal("-inf")), + (compiled_val == "nan", sqa.literal("nan")), + else_=sqa.cast( + compiled_val, + pdt_type_to_sqa(cast.target_type), + ), + ) + with SqliteImpl.op(ops.Round()) as op: From fe0ee846433284639527616ec8ac585f3090a572 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Sat, 28 Sep 2024 11:54:13 +0200 Subject: [PATCH 07/25] test string to int cast --- src/pydiverse/transform/backend/mssql.py | 2 ++ src/pydiverse/transform/backend/sqlite.py | 3 +++ tests/test_backend_equivalence/test_ops/test_cast.py | 4 ++++ 3 files changed, 9 insertions(+) diff --git a/src/pydiverse/transform/backend/mssql.py b/src/pydiverse/transform/backend/mssql.py index 65928f52..e3066d51 100644 --- a/src/pydiverse/transform/backend/mssql.py +++ b/src/pydiverse/transform/backend/mssql.py @@ -47,6 +47,8 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: ), ) + return sqa.cast(compiled_val, pdt_type_to_sqa(cast.target_type)) + @classmethod def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any: # boolean / bit conversion diff --git a/src/pydiverse/transform/backend/sqlite.py b/src/pydiverse/transform/backend/sqlite.py index 61ab4af8..bb18fd35 100644 --- a/src/pydiverse/transform/backend/sqlite.py +++ b/src/pydiverse/transform/backend/sqlite.py @@ -15,6 +15,7 @@ class SqliteImpl(SqlImpl): @classmethod def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: compiled_val = cls.compile_col_expr(cast.val, sqa_col) + if isinstance(cast.val.dtype(), dtypes.String) and isinstance( cast.target_type, dtypes.Float64 ): @@ -28,6 +29,8 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: ), ) + return sqa.cast(compiled_val, pdt_type_to_sqa(cast.target_type)) + with SqliteImpl.op(ops.Round()) as op: diff --git a/tests/test_backend_equivalence/test_ops/test_cast.py b/tests/test_backend_equivalence/test_ops/test_cast.py index 4cfc8782..00b5aa90 100644 --- a/tests/test_backend_equivalence/test_ops/test_cast.py +++ b/tests/test_backend_equivalence/test_ops/test_cast.py @@ -7,3 +7,7 @@ def test_string_to_float(df_strings): assert_result_equal(df_strings, lambda t: t >> mutate(u=t.c.cast(pdt.Float64()))) + + +def test_string_to_int(df_strings): + assert_result_equal(df_strings, lambda t: t >> mutate(u=t.d.cast(pdt.Int()))) From b54438fd3cd1e0511de01b4ac10dac4425fd0c42 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Sat, 28 Sep 2024 13:58:09 +0200 Subject: [PATCH 08/25] add tests for float to in casts disallow string to bool casts --- src/pydiverse/transform/backend/postgres.py | 10 +++++++ src/pydiverse/transform/tree/col_expr.py | 1 - tests/test_backend_equivalence/conftest.py | 9 +++---- .../test_ops/test_cast.py | 26 +++++++++++++++++-- 4 files changed, 38 insertions(+), 8 deletions(-) diff --git a/src/pydiverse/transform/backend/postgres.py b/src/pydiverse/transform/backend/postgres.py index b83a32db..7c7d170a 100644 --- a/src/pydiverse/transform/backend/postgres.py +++ b/src/pydiverse/transform/backend/postgres.py @@ -4,11 +4,21 @@ from pydiverse.transform import ops from pydiverse.transform.backend.sql import SqlImpl +from pydiverse.transform.tree import dtypes +from pydiverse.transform.tree.col_expr import Cast class PostgresImpl(SqlImpl): dialect_name = "postgresql" + @classmethod + def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: + if isinstance(cast.dtype(), dtypes.Float64) and isinstance( + cast.target_type, dtypes.Int + ): + return ... + return super().compile_cast(cast, sqa_col) + with PostgresImpl.op(ops.Less()) as op: diff --git a/src/pydiverse/transform/tree/col_expr.py b/src/pydiverse/transform/tree/col_expr.py index 09b1235d..2c7d6709 100644 --- a/src/pydiverse/transform/tree/col_expr.py +++ b/src/pydiverse/transform/tree/col_expr.py @@ -462,7 +462,6 @@ def dtype(self) -> Dtype: valid_casts = { (dtypes.String, dtypes.Int), (dtypes.String, dtypes.Float64), - (dtypes.String, dtypes.Bool), (dtypes.Float64, dtypes.Int), } diff --git a/tests/test_backend_equivalence/conftest.py b/tests/test_backend_equivalence/conftest.py index 808d1ae2..1090b46e 100644 --- a/tests/test_backend_equivalence/conftest.py +++ b/tests/test_backend_equivalence/conftest.py @@ -149,13 +149,12 @@ { "a": [0.4, -1.1, -0.0, 0.0, 9.0, 2.0, 2.3, -1000.0], "b": [None, 2, 0, -11, 4, 19, -5190, 2000000], - "c": [0.0, None, None, 2.2, -0.0, 10.0, -10.0, 3.1415926535], + "c": [0.0, None, None, 2.9, -0.0, 10.0, -10.0, 3.1415926535], "d": [None, 2.71828, 0.577, 901234, -6.0, 4.0, None, -99.0], - "e": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 500.0], + "e": [1.0, 2.0, 3.0, 4.99, 5.0, 6.0, 7.0, 500.0], "f": [3.0, None, 0.0, 4.3, 10.0, -1.2, -9999.1, -34.1], - "g": [-5.5, None, None, 1.100212, -3.412351, 1000.4252, 0.0, -1.0], - }, - schema_overrides={"f": pl.Decimal(), "g": pl.Decimal()}, + "g": [-5.5, None, None, 1.100212, -3.412351, 1000.4252, 0.0, -1.6], + } ), } diff --git a/tests/test_backend_equivalence/test_ops/test_cast.py b/tests/test_backend_equivalence/test_ops/test_cast.py index 00b5aa90..ca77e9d7 100644 --- a/tests/test_backend_equivalence/test_ops/test_cast.py +++ b/tests/test_backend_equivalence/test_ops/test_cast.py @@ -6,8 +6,30 @@ def test_string_to_float(df_strings): - assert_result_equal(df_strings, lambda t: t >> mutate(u=t.c.cast(pdt.Float64()))) + assert_result_equal( + df_strings, + lambda t: t >> mutate(u=t.c.cast(pdt.Float64())), + ) def test_string_to_int(df_strings): - assert_result_equal(df_strings, lambda t: t >> mutate(u=t.d.cast(pdt.Int()))) + assert_result_equal( + df_strings, + lambda t: t >> mutate(u=t.d.cast(pdt.Int())), + ) + + +def test_float_to_int(df_num): + assert_result_equal( + df_num, + lambda t: t + >> mutate( + u=t.a.cast(pdt.Int()), + v=t.b.cast(pdt.Int()), + w=t.f.cast(pdt.Int()), + x=t.d.cast(pdt.Int()), + y=t.e.cast(pdt.Int()), + z=t.f.cast(pdt.Int()), + q=t.g.cast(pdt.Int()), + ), + ) From b4d1b440e454e3b003737beeb9013c2dd50a5f07 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Sat, 28 Sep 2024 14:04:11 +0200 Subject: [PATCH 09/25] add floor / ceil --- src/pydiverse/transform/backend/mssql.py | 7 ++++ src/pydiverse/transform/backend/polars.py | 14 ++++++++ src/pydiverse/transform/backend/sql.py | 14 ++++++++ src/pydiverse/transform/ops/numeric.py | 14 ++++++++ .../test_ops/test_ops_numerical.py | 32 +++++++++++++++++++ 5 files changed, 81 insertions(+) diff --git a/src/pydiverse/transform/backend/mssql.py b/src/pydiverse/transform/backend/mssql.py index e3066d51..a47da1db 100644 --- a/src/pydiverse/transform/backend/mssql.py +++ b/src/pydiverse/transform/backend/mssql.py @@ -334,3 +334,10 @@ def _log(x): (x.is_(sqa.null()), None), else_=-MSSQL_INF, ) + + +with MsSqlImpl.op(ops.Ceil()) as op: + + @op.auto + def _ceil(x): + return sqa.func.ceiling(x) diff --git a/src/pydiverse/transform/backend/polars.py b/src/pydiverse/transform/backend/polars.py index 28da3d41..a9f11d5a 100644 --- a/src/pydiverse/transform/backend/polars.py +++ b/src/pydiverse/transform/backend/polars.py @@ -713,3 +713,17 @@ def _exp(x): @op.auto def _log(x): return x.log() + + +with PolarsImpl.op(ops.Floor()) as op: + + @op.auto + def _floor(x): + return x.floor() + + +with PolarsImpl.op(ops.Ceil()) as op: + + @op.auto + def _ceil(x): + return x.ceil() diff --git a/src/pydiverse/transform/backend/sql.py b/src/pydiverse/transform/backend/sql.py index 668596cc..d344c7e9 100644 --- a/src/pydiverse/transform/backend/sql.py +++ b/src/pydiverse/transform/backend/sql.py @@ -987,3 +987,17 @@ def _log(x): (x.is_(sqa.null()), None), else_=sqa.literal("-inf"), ) + + +with SqlImpl.op(ops.Floor()) as op: + + @op.auto + def _floor(x): + return sqa.func.floor(x) + + +with SqlImpl.op(ops.Ceil()) as op: + + @op.auto + def _ceil(x): + return sqa.func.ceil(x) diff --git a/src/pydiverse/transform/ops/numeric.py b/src/pydiverse/transform/ops/numeric.py index 81efad4e..4c6e77e7 100644 --- a/src/pydiverse/transform/ops/numeric.py +++ b/src/pydiverse/transform/ops/numeric.py @@ -21,6 +21,8 @@ "Pos", "Abs", "Round", + "Floor", + "Ceil", "Exp", "Log", ] @@ -152,6 +154,18 @@ class Round(ElementWise): ] +class Floor(ElementWise): + name = "floor" + signatures = [ + "float64 -> float64", + "decimal -> decimal", + ] + + +class Ceil(Floor): + name = "ceil" + + class Log(ElementWise): name = "log" signatures = ["float64 -> float64"] diff --git a/tests/test_backend_equivalence/test_ops/test_ops_numerical.py b/tests/test_backend_equivalence/test_ops/test_ops_numerical.py index 379d7e19..36afde05 100644 --- a/tests/test_backend_equivalence/test_ops/test_ops_numerical.py +++ b/tests/test_backend_equivalence/test_ops/test_ops_numerical.py @@ -63,3 +63,35 @@ def test_div(df_num): def test_decimal(df_num): assert_result_equal(df_num, lambda t: t >> mutate(u=t.f + t.g, z=t.f * t.g)) + + +def test_floor(df_num): + assert_result_equal( + df_num, + lambda t: t + >> mutate( + u=t.a.floor(), + v=t.b.floor(), + w=t.f.floor(), + x=t.d.floor(), + y=t.e.floor(), + z=t.f.floor(), + q=t.g.floor(), + ), + ) + + +def test_ceil(df_num): + assert_result_equal( + df_num, + lambda t: t + >> mutate( + u=t.a.ceil(), + v=t.b.ceil(), + w=t.f.ceil(), + x=t.d.ceil(), + y=t.e.ceil(), + z=t.f.ceil(), + q=t.g.ceil(), + ), + ) From 56028301228b885aaa0aba50b8bb87027502c875 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Sat, 28 Sep 2024 14:10:07 +0200 Subject: [PATCH 10/25] make float to int casts consistent we always truncate, like in C --- src/pydiverse/transform/backend/duckdb.py | 14 +++++++++++++- src/pydiverse/transform/backend/postgres.py | 6 ++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/pydiverse/transform/backend/duckdb.py b/src/pydiverse/transform/backend/duckdb.py index 5937e3e5..b7e34057 100644 --- a/src/pydiverse/transform/backend/duckdb.py +++ b/src/pydiverse/transform/backend/duckdb.py @@ -1,12 +1,14 @@ from __future__ import annotations import polars as pl +import sqlalchemy as sqa from pydiverse.transform.backend import sql from pydiverse.transform.backend.sql import SqlImpl from pydiverse.transform.backend.targets import Polars, Target +from pydiverse.transform.tree import dtypes from pydiverse.transform.tree.ast import AstNode -from pydiverse.transform.tree.col_expr import Col +from pydiverse.transform.tree.col_expr import Cast, Col class DuckDbImpl(SqlImpl): @@ -21,3 +23,13 @@ def export(cls, nd: AstNode, target: Target, final_select: list[Col]): DuckDbImpl.build_query(nd, final_select), connection=conn ) return SqlImpl.export(nd, target, final_select) + + @classmethod + def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: + if isinstance(cast.val.dtype(), dtypes.Float64) and isinstance( + cast.target_type, dtypes.Int + ): + return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast( + sqa.Integer() + ) + return super().compile_cast(cast, sqa_col) diff --git a/src/pydiverse/transform/backend/postgres.py b/src/pydiverse/transform/backend/postgres.py index 7c7d170a..8a408b38 100644 --- a/src/pydiverse/transform/backend/postgres.py +++ b/src/pydiverse/transform/backend/postgres.py @@ -13,10 +13,12 @@ class PostgresImpl(SqlImpl): @classmethod def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: - if isinstance(cast.dtype(), dtypes.Float64) and isinstance( + if isinstance(cast.val.dtype(), dtypes.Float64) and isinstance( cast.target_type, dtypes.Int ): - return ... + return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast( + sqa.Integer() + ) return super().compile_cast(cast, sqa_col) From bb736720683479ff02bd62a6e06ac18beded5727 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Sun, 29 Sep 2024 12:16:37 +0200 Subject: [PATCH 11/25] also catch -nan in sqlite / mssql str -> float --- src/pydiverse/transform/backend/mssql.py | 2 +- src/pydiverse/transform/backend/sqlite.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pydiverse/transform/backend/mssql.py b/src/pydiverse/transform/backend/mssql.py index a47da1db..673a7f02 100644 --- a/src/pydiverse/transform/backend/mssql.py +++ b/src/pydiverse/transform/backend/mssql.py @@ -40,7 +40,7 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: return sqa.case( (compiled_val == "inf", MSSQL_INF), (compiled_val == "-inf", -MSSQL_INF), - (compiled_val == "nan", MSSQL_NAN), + (compiled_val.in_(("nan", "-nan")), MSSQL_NAN), else_=sqa.cast( compiled_val, pdt_type_to_sqa(cast.target_type), diff --git a/src/pydiverse/transform/backend/sqlite.py b/src/pydiverse/transform/backend/sqlite.py index bb18fd35..649a760e 100644 --- a/src/pydiverse/transform/backend/sqlite.py +++ b/src/pydiverse/transform/backend/sqlite.py @@ -22,7 +22,7 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: return sqa.case( (compiled_val == "inf", sqa.literal("inf")), (compiled_val == "-inf", sqa.literal("-inf")), - (compiled_val == "nan", sqa.literal("nan")), + (compiled_val.in_(("nan", "-nan")), sqa.literal("nan")), else_=sqa.cast( compiled_val, pdt_type_to_sqa(cast.target_type), From 2c9a5891625c20882888107481399da4db239186 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Sun, 29 Sep 2024 12:17:48 +0200 Subject: [PATCH 12/25] add stronger test cases --- src/pydiverse/transform/pipe/table.py | 15 ++-- tests/test_backend_equivalence/conftest.py | 33 ++++++- .../test_ops/test_cast.py | 11 +-- .../test_ops/test_ops_numerical.py | 89 ++++++++----------- 4 files changed, 75 insertions(+), 73 deletions(-) diff --git a/src/pydiverse/transform/pipe/table.py b/src/pydiverse/transform/pipe/table.py index 503d1dbb..e32c7e98 100644 --- a/src/pydiverse/transform/pipe/table.py +++ b/src/pydiverse/transform/pipe/table.py @@ -9,15 +9,9 @@ from pydiverse.transform.backend.table_impl import TableImpl from pydiverse.transform.tree.ast import AstNode -from pydiverse.transform.tree.col_expr import ( - Col, - ColExpr, -) +from pydiverse.transform.tree.col_expr import Col -# TODO: if we decide that select controls the C-space, the columns in _select will -# always be the same as those that we have to keep in _schema. However, we still need -# _select for the order. class Table: __slots__ = ["_ast", "_cache"] @@ -26,7 +20,8 @@ class Table: which is a reference to the underlying abstract syntax tree. """ - # TODO: define exactly what can be given for the two + # TODO: define exactly what can be given for the two and do type checks + # maybe call the second one execution_engine or similar? def __init__(self, resource, backend=None, *, name: str | None = None): import polars as pl @@ -40,6 +35,8 @@ def __init__(self, resource, backend=None, *, name: str | None = None): self._ast: AstNode = resource elif isinstance(resource, (pl.DataFrame, pl.LazyFrame)): if name is None: + # TODO: we could look whether the df has a name attr set (which is the + # case if it was previously exported) name = "?" self._ast = PolarsImpl(name, resource) elif isinstance(resource, (str, sqa.Table)): @@ -77,7 +74,7 @@ def __setstate__(self, d): # to avoid very annoying AttributeErrors for slot, val in d[1].items(): setattr(self, slot, val) - def __iter__(self) -> Iterable[ColExpr]: + def __iter__(self) -> Iterable[Col]: cols = copy.copy(self._cache.select) yield from cols diff --git a/tests/test_backend_equivalence/conftest.py b/tests/test_backend_equivalence/conftest.py index 1090b46e..86df0ea3 100644 --- a/tests/test_backend_equivalence/conftest.py +++ b/tests/test_backend_equivalence/conftest.py @@ -72,7 +72,7 @@ "c": [ "4352.0", "-21", - "-421", + "-nan", "3.313", None, "-inf", @@ -147,13 +147,38 @@ ), "df_num": pl.DataFrame( { - "a": [0.4, -1.1, -0.0, 0.0, 9.0, 2.0, 2.3, -1000.0], + "a": [0.4, -1.1, -0.0, 0.0, 9.0, 2.0, float("inf"), -1000.0], "b": [None, 2, 0, -11, 4, 19, -5190, 2000000], "c": [0.0, None, None, 2.9, -0.0, 10.0, -10.0, 3.1415926535], - "d": [None, 2.71828, 0.577, 901234, -6.0, 4.0, None, -99.0], - "e": [1.0, 2.0, 3.0, 4.99, 5.0, 6.0, 7.0, 500.0], + "d": [None, float("-nan"), 0.577, 901234, -6.0, 4.0, None, -99.0], + "e": [1.0, 2.0, 3.0, 4.99, float("nan"), 6.0, 7.0, 500.0], "f": [3.0, None, 0.0, 4.3, 10.0, -1.2, -9999.1, -34.1], "g": [-5.5, None, None, 1.100212, -3.412351, 1000.4252, 0.0, -1.6], + "nan": [float("nan"), float("nan"), float("-nan"), float("-nan")] * 2, + "inf": [float("inf")] * 8, + "-inf": [float("-inf")] * 8, + "zero": [0.0, -0.0] * 4, + "pos": [ + 1.123, + 1297.324, + 9192.9793, + 7.5, + 912.097, + 32.9, + 2.712834, + 5002352.434, + ], + "neg": [ + -9623.1, + -0.1, + -1.0, + -923737552.5, + -5.5, + -0.12083, + -93.4, + -6699917733.1242, + ], + "null:": [0.0, None, None, None, None, None, None, None], } ), } diff --git a/tests/test_backend_equivalence/test_ops/test_cast.py b/tests/test_backend_equivalence/test_ops/test_cast.py index ca77e9d7..9dc06ed9 100644 --- a/tests/test_backend_equivalence/test_ops/test_cast.py +++ b/tests/test_backend_equivalence/test_ops/test_cast.py @@ -22,14 +22,5 @@ def test_string_to_int(df_strings): def test_float_to_int(df_num): assert_result_equal( df_num, - lambda t: t - >> mutate( - u=t.a.cast(pdt.Int()), - v=t.b.cast(pdt.Int()), - w=t.f.cast(pdt.Int()), - x=t.d.cast(pdt.Int()), - y=t.e.cast(pdt.Int()), - z=t.f.cast(pdt.Int()), - q=t.g.cast(pdt.Int()), - ), + lambda t: t >> mutate(**{c.name: c.cast(pdt.Int()) for c in t}), ) diff --git a/tests/test_backend_equivalence/test_ops/test_ops_numerical.py b/tests/test_backend_equivalence/test_ops/test_ops_numerical.py index 36afde05..df8e5cbd 100644 --- a/tests/test_backend_equivalence/test_ops/test_ops_numerical.py +++ b/tests/test_backend_equivalence/test_ops/test_ops_numerical.py @@ -7,91 +7,80 @@ def test_exp(df_num): assert_result_equal( df_num, - lambda t: t - >> mutate( - exp_a=t.a.exp(), - exp_b=t.b.exp(), - exp_c=t.c.exp(), - exp_d=t.d.exp(), - ), + lambda t: t >> mutate(**{c.name: c.exp() for c in t}), ) def test_log(df_num): assert_result_equal( df_num, - lambda t: t - >> mutate( - log_a=t.a.log(), - log_b=t.b.log(), - log_c=t.c.log(), - log_d=t.d.log(), - log_e=t.e.exp(), - ), + lambda t: t >> mutate(**{c.name: c.log() for c in t}), ) def test_abs(df_num): assert_result_equal( df_num, - lambda t: t - >> mutate( - abs_a=abs(t.a), - abs_b=abs(t.b), - abs_c=abs(t.c), - abs_d=abs(t.d), - ), + lambda t: t >> mutate(**{c.name: abs(c) for c in t}), ) def test_round(df_num): assert_result_equal( df_num, - lambda t: t - >> mutate( - round_a=round(t.a), - round_b=round(t.b), - round_c=round(t.c), - round_d=round(t.d), - ), + lambda t: t >> mutate(**{c.name: round(c) for c in t}), + ) + + +def test_add(df_num): + assert_result_equal( + df_num, + lambda t: t >> mutate(**{f"{c.name}+{d.name}": c + d for d in t for c in t}), + ) + + +def test_sub(df_num): + assert_result_equal( + df_num, + lambda t: t >> mutate(**{f"{c.name}-{d.name}": c - d for d in t for c in t}), + ) + + +def test_neg(df_num): + assert_result_equal( + df_num, + lambda t: t >> mutate(**{c.name: -c for c in t}), + ) + + +def test_mul(df_num): + assert_result_equal( + df_num, + lambda t: t >> mutate(**{f"{c.name}*{d.name}": c * d for d in t for c in t}), ) def test_div(df_num): - assert_result_equal(df_num, lambda t: t >> mutate(u=t.a / 2, v=t.b / 3.1)) + assert_result_equal( + df_num, + lambda t: t >> mutate(**{f"{c.name}/{d.name}": c / d for d in t for c in t}), + ) def test_decimal(df_num): + # TODO: test the decimal here assert_result_equal(df_num, lambda t: t >> mutate(u=t.f + t.g, z=t.f * t.g)) def test_floor(df_num): assert_result_equal( df_num, - lambda t: t - >> mutate( - u=t.a.floor(), - v=t.b.floor(), - w=t.f.floor(), - x=t.d.floor(), - y=t.e.floor(), - z=t.f.floor(), - q=t.g.floor(), - ), + lambda t: t >> mutate(**{c.name: c.floor() for c in t}), ) def test_ceil(df_num): assert_result_equal( df_num, - lambda t: t - >> mutate( - u=t.a.ceil(), - v=t.b.ceil(), - w=t.f.ceil(), - x=t.d.ceil(), - y=t.e.ceil(), - z=t.f.ceil(), - q=t.g.ceil(), - ), + lambda t: t >> mutate(**{c.name: c.ceil() for c in t}), ) From 2d00ebef6058b83c95d0e90be10d2b6ac1d084c2 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Mon, 30 Sep 2024 11:02:50 +0200 Subject: [PATCH 13/25] allow Callables in the pipe --- src/pydiverse/transform/pipe/table.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/pydiverse/transform/pipe/table.py b/src/pydiverse/transform/pipe/table.py index e32c7e98..1b58c7c3 100644 --- a/src/pydiverse/transform/pipe/table.py +++ b/src/pydiverse/transform/pipe/table.py @@ -2,12 +2,14 @@ import copy import dataclasses -from collections.abc import Iterable +import inspect +from collections.abc import Callable, Iterable from html import escape import sqlalchemy as sqa from pydiverse.transform.backend.table_impl import TableImpl +from pydiverse.transform.pipe.pipeable import Pipeable from pydiverse.transform.tree.ast import AstNode from pydiverse.transform.tree.col_expr import Col @@ -81,6 +83,25 @@ def __iter__(self) -> Iterable[Col]: def __len__(self) -> int: return len(self._cache.select) + def __rshift__(self, rhs): + if isinstance(rhs, Pipeable): + return rhs.__rrshift__(self) + if isinstance(rhs, Callable): + num_params = len(inspect.signature(rhs).parameters) + if num_params != 1: + raise TypeError( + "only functions with one parameter can be used in a pipe, got " + f"function with {num_params} parameters." + ) + return rhs(self) + + raise TypeError( + f"found instance of invalid type `{type(rhs)}` in the pipe. \n" + "hint: You can use a `Table` or a Callable taking a single argument in a " + "pipe. If you use a Callable, it will receive the current table as an " + "and must return a `Table`." + ) + def __str__(self): try: from pydiverse.transform.backend.targets import Polars From c35a98080e1b400e4cb28fda575b18a93da5f39c Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Mon, 30 Sep 2024 11:11:28 +0200 Subject: [PATCH 14/25] add nan / inf cols manually polars does not write them properly to the database --- tests/test_backend_equivalence/conftest.py | 11 +++--- .../test_ops/test_cast.py | 17 +++++++++- .../test_ops/test_ops_numerical.py | 34 ++++++++++++++++--- 3 files changed, 50 insertions(+), 12 deletions(-) diff --git a/tests/test_backend_equivalence/conftest.py b/tests/test_backend_equivalence/conftest.py index 86df0ea3..a8e9c24f 100644 --- a/tests/test_backend_equivalence/conftest.py +++ b/tests/test_backend_equivalence/conftest.py @@ -147,16 +147,13 @@ ), "df_num": pl.DataFrame( { - "a": [0.4, -1.1, -0.0, 0.0, 9.0, 2.0, float("inf"), -1000.0], + "a": [0.4, -1.1, -0.0, 0.0, 9.0, 2.0, -344.0053, -1000.0], "b": [None, 2, 0, -11, 4, 19, -5190, 2000000], "c": [0.0, None, None, 2.9, -0.0, 10.0, -10.0, 3.1415926535], - "d": [None, float("-nan"), 0.577, 901234, -6.0, 4.0, None, -99.0], - "e": [1.0, 2.0, 3.0, 4.99, float("nan"), 6.0, 7.0, 500.0], + "d": [None, 2352.0230, 0.577, 901234, -6.0, 4.0, None, -99.0], + "e": [1.0, 2.0, 3.0, 4.99, -442.0, 6.0, 7.0, 500.0], "f": [3.0, None, 0.0, 4.3, 10.0, -1.2, -9999.1, -34.1], "g": [-5.5, None, None, 1.100212, -3.412351, 1000.4252, 0.0, -1.6], - "nan": [float("nan"), float("nan"), float("-nan"), float("-nan")] * 2, - "inf": [float("inf")] * 8, - "-inf": [float("-inf")] * 8, "zero": [0.0, -0.0] * 4, "pos": [ 1.123, @@ -178,7 +175,7 @@ -93.4, -6699917733.1242, ], - "null:": [0.0, None, None, None, None, None, None, None], + "null_s": [0.0, None, None, None, None, None, None, None], } ), } diff --git a/tests/test_backend_equivalence/test_ops/test_cast.py b/tests/test_backend_equivalence/test_ops/test_cast.py index 9dc06ed9..713764da 100644 --- a/tests/test_backend_equivalence/test_ops/test_cast.py +++ b/tests/test_backend_equivalence/test_ops/test_cast.py @@ -1,7 +1,9 @@ from __future__ import annotations import pydiverse.transform as pdt +from pydiverse.transform.pipe.c import C from pydiverse.transform.pipe.verbs import mutate +from tests.test_backend_equivalence.test_ops.test_ops_numerical import add_nan_inf_cols from tests.util.assertion import assert_result_equal @@ -22,5 +24,18 @@ def test_string_to_int(df_strings): def test_float_to_int(df_num): assert_result_equal( df_num, - lambda t: t >> mutate(**{c.name: c.cast(pdt.Int()) for c in t}), + lambda t: t >> mutate(**{col.name: col.cast(pdt.Int()) for col in t}), + ) + + assert_result_equal( + df_num, + lambda t: t >> add_nan_inf_cols() >> mutate(u=C.inf.cast(pdt.Int())), + exception=Exception, + may_throw=True, + ) + assert_result_equal( + df_num, + lambda t: t >> add_nan_inf_cols() >> mutate(u=C.nan.cast(pdt.Int())), + exception=Exception, + may_throw=True, ) diff --git a/tests/test_backend_equivalence/test_ops/test_ops_numerical.py b/tests/test_backend_equivalence/test_ops/test_ops_numerical.py index df8e5cbd..f2a57400 100644 --- a/tests/test_backend_equivalence/test_ops/test_ops_numerical.py +++ b/tests/test_backend_equivalence/test_ops/test_ops_numerical.py @@ -1,13 +1,27 @@ from __future__ import annotations +import pydiverse.transform as pdt +from pydiverse.transform.pipe.pipeable import verb from pydiverse.transform.pipe.verbs import mutate from tests.util.assertion import assert_result_equal +@verb +def add_nan_inf_cols(table: pdt.Table) -> pdt.Table: + return table >> mutate( + **{ + "nan": float("nan"), + "negnan": float("-nan"), + "inf": float("inf"), + "neginf": float("-inf"), + } + ) + + def test_exp(df_num): assert_result_equal( df_num, - lambda t: t >> mutate(**{c.name: c.exp() for c in t}), + lambda t: t >> add_nan_inf_cols() >> mutate(**{c.name: c.exp() for c in t}), ) @@ -35,21 +49,33 @@ def test_round(df_num): def test_add(df_num): assert_result_equal( df_num, - lambda t: t >> mutate(**{f"{c.name}+{d.name}": c + d for d in t for c in t}), + lambda t: t + >> add_nan_inf_cols() + >> ( + lambda s: s + >> mutate(**{f"add_{c.name}_{d.name}": c + d for d in s for c in s}) + ), ) def test_sub(df_num): assert_result_equal( df_num, - lambda t: t >> mutate(**{f"{c.name}-{d.name}": c - d for d in t for c in t}), + lambda t: t + >> add_nan_inf_cols() + >> ( + lambda s: s + >> mutate(**{f"sub_{c.name}_{d.name}": c - d for d in s for c in s}) + ), ) def test_neg(df_num): assert_result_equal( df_num, - lambda t: t >> mutate(**{c.name: -c for c in t}), + lambda t: t + >> add_nan_inf_cols() + >> (lambda s: s >> mutate(**{f"neg_{c.name}": -c for c in s})), ) From e7383d49653811943401e36cd2d244ef476012be Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Mon, 30 Sep 2024 11:12:31 +0200 Subject: [PATCH 15/25] sqlite: fix inf, don't distinguish nan and null --- src/pydiverse/transform/backend/mssql.py | 21 +++++++++++---------- src/pydiverse/transform/backend/sql.py | 13 ++++++++++++- src/pydiverse/transform/backend/sqlite.py | 10 +++++++--- tests/util/assertion.py | 7 ++++++- 4 files changed, 36 insertions(+), 15 deletions(-) diff --git a/src/pydiverse/transform/backend/mssql.py b/src/pydiverse/transform/backend/mssql.py index 673a7f02..3cbd97ba 100644 --- a/src/pydiverse/transform/backend/mssql.py +++ b/src/pydiverse/transform/backend/mssql.py @@ -22,15 +22,16 @@ ) from pydiverse.transform.util.warnings import warn_non_standard -MSSQL_INF = sqa.cast(sqa.literal("1.0"), type_=sqa.Float()) / sqa.literal( - "0.0", type_=sqa.Float() -) -MSSQL_NAN = MSSQL_INF + (-MSSQL_INF) - class MsSqlImpl(SqlImpl): dialect_name = "mssql" + INF = sqa.cast(sqa.literal("1.0"), type_=sqa.Float()) / sqa.literal( + "0.0", type_=sqa.Float() + ) + NEG_INF = -INF + NAN = INF + NEG_INF + @classmethod def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: compiled_val = cls.compile_col_expr(cast.val, sqa_col) @@ -38,9 +39,9 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: cast.target_type, dtypes.Float64 ): return sqa.case( - (compiled_val == "inf", MSSQL_INF), - (compiled_val == "-inf", -MSSQL_INF), - (compiled_val.in_(("nan", "-nan")), MSSQL_NAN), + (compiled_val == "inf", cls.INF), + (compiled_val == "-inf", -cls.INF), + (compiled_val.in_(("nan", "-nan")), cls.NAN), else_=sqa.cast( compiled_val, pdt_type_to_sqa(cast.target_type), @@ -330,9 +331,9 @@ def _log(x): # TODO: we still need to handle inf / -inf / nan return sqa.case( (x > 0, sqa.func.log(x)), - (x < 0, MSSQL_NAN), + (x < 0, MsSqlImpl.NAN), (x.is_(sqa.null()), None), - else_=-MSSQL_INF, + else_=-MsSqlImpl.INF, ) diff --git a/src/pydiverse/transform/backend/sql.py b/src/pydiverse/transform/backend/sql.py index d344c7e9..9432a6e4 100644 --- a/src/pydiverse/transform/backend/sql.py +++ b/src/pydiverse/transform/backend/sql.py @@ -4,6 +4,7 @@ import functools import inspect import itertools +import math import operator from collections.abc import Iterable from typing import Any @@ -34,6 +35,10 @@ class SqlImpl(TableImpl): Dialects: dict[str, type[TableImpl]] = {} + INF = sqa.cast(sqa.literal("inf"), sqa.Float()) + NEG_INF = sqa.cast(sqa.literal("-inf"), sqa.Float()) + NAN = sqa.cast(sqa.literal("nan"), sqa.Float()) + def __new__(cls, *args, **kwargs) -> SqlImpl: engine: str | sqa.Engine = ( inspect.signature(cls.__init__) @@ -113,6 +118,7 @@ def export(cls, nd: AstNode, target: Target, final_select: list[Col]) -> Any: sql_col.name: pdt_type_to_polars(col.dtype()) for sql_col, col in zip(sel.columns.values(), final_select) }, + infer_schema_length=0, ) df.name = nd.name return df @@ -212,6 +218,11 @@ def compile_col_expr( ) elif isinstance(expr, LiteralCol): + if isinstance(expr.val, float): + if math.isnan(expr.val): + return cls.NAN + elif math.isinf(expr.val): + return cls.INF if expr.val > 0 else cls.NEG_INF return expr.val elif isinstance(expr, Cast): @@ -581,7 +592,7 @@ def sqa_type_to_pdt(t: sqa.types.TypeEngine) -> Dtype: def pdt_type_to_sqa(t: Dtype) -> sqa.types.TypeEngine: if isinstance(t, dtypes.Int): - return sqa.Integer() + return sqa.BigInteger() elif isinstance(t, dtypes.Float64): return sqa.Float() elif isinstance(t, dtypes.Decimal): diff --git a/src/pydiverse/transform/backend/sqlite.py b/src/pydiverse/transform/backend/sqlite.py index 649a760e..094113e2 100644 --- a/src/pydiverse/transform/backend/sqlite.py +++ b/src/pydiverse/transform/backend/sqlite.py @@ -12,6 +12,10 @@ class SqliteImpl(SqlImpl): dialect_name = "sqlite" + INF = sqa.cast(sqa.literal("1e314"), sqa.Float) + NEG_INF = -INF + NAN = sqa.null() + @classmethod def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: compiled_val = cls.compile_col_expr(cast.val, sqa_col) @@ -20,9 +24,9 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: cast.target_type, dtypes.Float64 ): return sqa.case( - (compiled_val == "inf", sqa.literal("inf")), - (compiled_val == "-inf", sqa.literal("-inf")), - (compiled_val.in_(("nan", "-nan")), sqa.literal("nan")), + (compiled_val == "inf", cls.INF), + (compiled_val == "-inf", cls.NEG_INF), + (compiled_val.in_(("nan", "-nan")), cls.NAN), else_=sqa.cast( compiled_val, pdt_type_to_sqa(cast.target_type), diff --git a/tests/util/assertion.py b/tests/util/assertion.py index 0fec9339..479d957f 100644 --- a/tests/util/assertion.py +++ b/tests/util/assertion.py @@ -8,9 +8,10 @@ from polars.testing import assert_frame_equal from pydiverse.transform import Table +from pydiverse.transform.backend.sqlite import SqliteImpl from pydiverse.transform.backend.targets import Polars from pydiverse.transform.errors import NonStandardBehaviourWarning -from pydiverse.transform.pipe.verbs import export, show_query +from pydiverse.transform.pipe.verbs import export, get_backend, show_query def assert_equal(left, right, check_dtypes=False, check_row_order=True): @@ -85,6 +86,10 @@ def assert_result_equal( pl.col(pl.Decimal(scale=10)).cast(pl.Float64) ) + # sqlite does not know NaN + if get_backend(query_y._ast) is SqliteImpl: + dfx = dfx.fill_nan(None) + # after a join, cols containing only null values get type Null on SQLite and # Postgres. maybe we can fix this but for now we just ignore such cols assert dfx.columns == dfy.columns From ec306c9949a7cc3f7586c2e7c6866fb5ff5b7855 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Mon, 30 Sep 2024 11:20:18 +0200 Subject: [PATCH 16/25] fix duckdb / postgres cast float -> int --- src/pydiverse/transform/backend/duckdb.py | 2 +- src/pydiverse/transform/backend/postgres.py | 2 +- tests/util/backend.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/pydiverse/transform/backend/duckdb.py b/src/pydiverse/transform/backend/duckdb.py index b7e34057..2b4d68f7 100644 --- a/src/pydiverse/transform/backend/duckdb.py +++ b/src/pydiverse/transform/backend/duckdb.py @@ -30,6 +30,6 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: cast.target_type, dtypes.Int ): return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast( - sqa.Integer() + sqa.BigInteger() ) return super().compile_cast(cast, sqa_col) diff --git a/src/pydiverse/transform/backend/postgres.py b/src/pydiverse/transform/backend/postgres.py index 8a408b38..b21f82ac 100644 --- a/src/pydiverse/transform/backend/postgres.py +++ b/src/pydiverse/transform/backend/postgres.py @@ -17,7 +17,7 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: cast.target_type, dtypes.Int ): return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast( - sqa.Integer() + sqa.BigInteger() ) return super().compile_cast(cast, sqa_col) diff --git a/tests/util/backend.py b/tests/util/backend.py index 9f23e448..844b3cb8 100644 --- a/tests/util/backend.py +++ b/tests/util/backend.py @@ -38,7 +38,6 @@ def sql_table(df: pl.DataFrame, name: str, url: str, dtypes_map: dict = None): global _sql_engine_cache dtypes_map = dtypes_map or {} - dtypes_map[pl.Decimal()] = sqa.DECIMAL() if url in _sql_engine_cache: engine = _sql_engine_cache[url] From 611c3b0e93e7b00a0934ec8daae6ecc835a1f138 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Mon, 30 Sep 2024 12:36:08 +0200 Subject: [PATCH 17/25] add datetime to date cast --- src/pydiverse/transform/backend/sql.py | 2 +- src/pydiverse/transform/backend/sqlite.py | 5 +++++ src/pydiverse/transform/tree/col_expr.py | 1 + tests/test_backend_equivalence/test_ops/test_cast.py | 7 +++++++ tests/util/backend.py | 2 +- 5 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/pydiverse/transform/backend/sql.py b/src/pydiverse/transform/backend/sql.py index 9432a6e4..2dc35385 100644 --- a/src/pydiverse/transform/backend/sql.py +++ b/src/pydiverse/transform/backend/sql.py @@ -594,7 +594,7 @@ def pdt_type_to_sqa(t: Dtype) -> sqa.types.TypeEngine: if isinstance(t, dtypes.Int): return sqa.BigInteger() elif isinstance(t, dtypes.Float64): - return sqa.Float() + return sqa.Double() elif isinstance(t, dtypes.Decimal): return sqa.DECIMAL() elif isinstance(t, dtypes.String): diff --git a/src/pydiverse/transform/backend/sqlite.py b/src/pydiverse/transform/backend/sqlite.py index 094113e2..9361ddc8 100644 --- a/src/pydiverse/transform/backend/sqlite.py +++ b/src/pydiverse/transform/backend/sqlite.py @@ -33,6 +33,11 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: ), ) + elif isinstance(cast.val.dtype(), dtypes.DateTime) and isinstance( + cast.target_type, dtypes.Date + ): + return sqa.type_coerce(sqa.func.date(compiled_val), sqa.DATE()) + return sqa.cast(compiled_val, pdt_type_to_sqa(cast.target_type)) diff --git a/src/pydiverse/transform/tree/col_expr.py b/src/pydiverse/transform/tree/col_expr.py index 2c7d6709..af94910f 100644 --- a/src/pydiverse/transform/tree/col_expr.py +++ b/src/pydiverse/transform/tree/col_expr.py @@ -463,6 +463,7 @@ def dtype(self) -> Dtype: (dtypes.String, dtypes.Int), (dtypes.String, dtypes.Float64), (dtypes.Float64, dtypes.Int), + (dtypes.DateTime, dtypes.Date), } if ( diff --git a/tests/test_backend_equivalence/test_ops/test_cast.py b/tests/test_backend_equivalence/test_ops/test_cast.py index 713764da..aa6ac8fc 100644 --- a/tests/test_backend_equivalence/test_ops/test_cast.py +++ b/tests/test_backend_equivalence/test_ops/test_cast.py @@ -39,3 +39,10 @@ def test_float_to_int(df_num): exception=Exception, may_throw=True, ) + + +def test_datetime_to_date(df_datetime): + assert_result_equal( + df_datetime, + lambda t: t >> mutate(u=t.col1.cast(pdt.Date()), v=t.col2.cast(pdt.Date())), + ) diff --git a/tests/util/backend.py b/tests/util/backend.py index 844b3cb8..006181e7 100644 --- a/tests/util/backend.py +++ b/tests/util/backend.py @@ -32,7 +32,7 @@ def polars_table(df: pl.DataFrame, name: str): _sql_engine_cache = {} -def sql_table(df: pl.DataFrame, name: str, url: str, dtypes_map: dict = None): +def sql_table(df: pl.DataFrame, name: str, url: str, dtypes_map: dict | None = None): import sqlalchemy as sqa global _sql_engine_cache From 5aa690b6ad3cd1759ffb2b4d64a923e8fcc32270 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Mon, 30 Sep 2024 13:29:22 +0200 Subject: [PATCH 18/25] implement int to string cast --- src/pydiverse/transform/tree/col_expr.py | 2 ++ tests/test_backend_equivalence/conftest.py | 7 +++++++ .../test_ops/test_cast.py | 15 +++++++++++++++ 3 files changed, 24 insertions(+) diff --git a/src/pydiverse/transform/tree/col_expr.py b/src/pydiverse/transform/tree/col_expr.py index af94910f..53f519c7 100644 --- a/src/pydiverse/transform/tree/col_expr.py +++ b/src/pydiverse/transform/tree/col_expr.py @@ -464,6 +464,8 @@ def dtype(self) -> Dtype: (dtypes.String, dtypes.Float64), (dtypes.Float64, dtypes.Int), (dtypes.DateTime, dtypes.Date), + (dtypes.Int, dtypes.String), + (dtypes.Float64, dtypes.String), } if ( diff --git a/tests/test_backend_equivalence/conftest.py b/tests/test_backend_equivalence/conftest.py index a8e9c24f..59673f4b 100644 --- a/tests/test_backend_equivalence/conftest.py +++ b/tests/test_backend_equivalence/conftest.py @@ -178,6 +178,13 @@ "null_s": [0.0, None, None, None, None, None, None, None], } ), + "df_int": pl.DataFrame( + { + "a": [3, 1, 0, -12, 4, 5, 1 << 20, 5], + "b": [-23, 18282, -42, 1729, None, -2323, 11, 1], + "null_s": [0] + [None] * 7, + } + ), } # compare one dataframe and one SQL backend to all others diff --git a/tests/test_backend_equivalence/test_ops/test_cast.py b/tests/test_backend_equivalence/test_ops/test_cast.py index aa6ac8fc..80349bef 100644 --- a/tests/test_backend_equivalence/test_ops/test_cast.py +++ b/tests/test_backend_equivalence/test_ops/test_cast.py @@ -46,3 +46,18 @@ def test_datetime_to_date(df_datetime): df_datetime, lambda t: t >> mutate(u=t.col1.cast(pdt.Date()), v=t.col2.cast(pdt.Date())), ) + + +def test_int_to_string(df_int): + assert_result_equal( + df_int, lambda t: t >> mutate(**{c.name: c.cast(pdt.String()) for c in t}) + ) + + +def test_float_to_string(df_num): + assert_result_equal( + df_num, + lambda t: t + >> add_nan_inf_cols() + >> (lambda s: s >> mutate(**{c.name: c.cast(pdt.String()) for c in s})), + ) From 9c79efb3fae16658e32c8576ba80f86031402b45 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Mon, 30 Sep 2024 13:58:49 +0200 Subject: [PATCH 19/25] partially implement cast of float to string we do not guarantte equivalence here (e.g. 0.0 vs 0), but we guarantee that a sequence of casts Float64 -> String -> Float64 is the identity (up to some epsilon). --- src/pydiverse/transform/backend/mssql.py | 11 ++++++++++ src/pydiverse/transform/backend/polars.py | 8 ++++++- src/pydiverse/transform/backend/postgres.py | 21 +++++++++++++------ src/pydiverse/transform/backend/sqlite.py | 9 ++++++++ .../test_ops/test_cast.py | 3 ++- 5 files changed, 44 insertions(+), 8 deletions(-) diff --git a/src/pydiverse/transform/backend/mssql.py b/src/pydiverse/transform/backend/mssql.py index 3cbd97ba..c6c7da89 100644 --- a/src/pydiverse/transform/backend/mssql.py +++ b/src/pydiverse/transform/backend/mssql.py @@ -48,6 +48,17 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: ), ) + if isinstance(cast.val.dtype(), dtypes.Float64) and isinstance( + cast.target_type, dtypes.String + ): + compiled = super().compile_cast(cast, sqa_col) + return sqa.case( + (compiled == "1.#QNAN", "nan"), + (compiled == "1.#INF", "inf"), + (compiled == "-1.#INF", "-inf"), + else_=compiled, + ) + return sqa.cast(compiled_val, pdt_type_to_sqa(cast.target_type)) @classmethod diff --git a/src/pydiverse/transform/backend/polars.py b/src/pydiverse/transform/backend/polars.py index a9f11d5a..1fee7844 100644 --- a/src/pydiverse/transform/backend/polars.py +++ b/src/pydiverse/transform/backend/polars.py @@ -186,9 +186,15 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr: return expr.val elif isinstance(expr, Cast): - return compile_col_expr(expr.val, name_in_df).cast( + compiled = compile_col_expr(expr.val, name_in_df).cast( pdt_type_to_polars(expr.target_type) ) + if isinstance(expr.val.dtype(), dtypes.Float64) and isinstance( + expr.target_type, dtypes.String + ): + compiled = compiled.replace("NaN", "nan") + + return compiled else: raise AssertionError diff --git a/src/pydiverse/transform/backend/postgres.py b/src/pydiverse/transform/backend/postgres.py index b21f82ac..3b6c2b94 100644 --- a/src/pydiverse/transform/backend/postgres.py +++ b/src/pydiverse/transform/backend/postgres.py @@ -13,12 +13,21 @@ class PostgresImpl(SqlImpl): @classmethod def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: - if isinstance(cast.val.dtype(), dtypes.Float64) and isinstance( - cast.target_type, dtypes.Int - ): - return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast( - sqa.BigInteger() - ) + if isinstance(cast.val.dtype(), dtypes.Float64): + if isinstance(cast.target_type, dtypes.Int): + return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast( + sqa.BigInteger() + ) + + if isinstance(cast.target_type, dtypes.String): + compiled = super().compile_cast(cast, sqa_col) + return sqa.case( + (compiled == "NaN", "nan"), + (compiled == "Infinity", "inf"), + (compiled == "-Infinity", "-inf"), + else_=compiled, + ) + return super().compile_cast(cast, sqa_col) diff --git a/src/pydiverse/transform/backend/sqlite.py b/src/pydiverse/transform/backend/sqlite.py index 9361ddc8..88d21e4b 100644 --- a/src/pydiverse/transform/backend/sqlite.py +++ b/src/pydiverse/transform/backend/sqlite.py @@ -38,6 +38,15 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: ): return sqa.type_coerce(sqa.func.date(compiled_val), sqa.DATE()) + elif isinstance(cast.val.dtype(), dtypes.Float64) and isinstance( + cast.target_type, dtypes.String + ): + return sqa.case( + (compiled_val == cls.INF, "inf"), + (compiled_val == cls.NEG_INF, "-inf"), + else_=sqa.cast(compiled_val, sqa.String), + ) + return sqa.cast(compiled_val, pdt_type_to_sqa(cast.target_type)) diff --git a/tests/test_backend_equivalence/test_ops/test_cast.py b/tests/test_backend_equivalence/test_ops/test_cast.py index 80349bef..d72ace8b 100644 --- a/tests/test_backend_equivalence/test_ops/test_cast.py +++ b/tests/test_backend_equivalence/test_ops/test_cast.py @@ -59,5 +59,6 @@ def test_float_to_string(df_num): df_num, lambda t: t >> add_nan_inf_cols() - >> (lambda s: s >> mutate(**{c.name: c.cast(pdt.String()) for c in s})), + >> (lambda s: s >> mutate(**{c.name: c.cast(pdt.String()) for c in s})) + >> (lambda s: s >> mutate(**{c.name: c.cast(pdt.Float64()) for c in s})), ) From 193905a3a6f546f69ddebe505acfb93053b1c937 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Mon, 30 Sep 2024 14:28:43 +0200 Subject: [PATCH 20/25] allow Date to String and Datetime to String cast --- src/pydiverse/transform/tree/col_expr.py | 4 +++- .../test_ops/test_cast.py | 23 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/pydiverse/transform/tree/col_expr.py b/src/pydiverse/transform/tree/col_expr.py index 53f519c7..2a98f966 100644 --- a/src/pydiverse/transform/tree/col_expr.py +++ b/src/pydiverse/transform/tree/col_expr.py @@ -466,6 +466,8 @@ def dtype(self) -> Dtype: (dtypes.DateTime, dtypes.Date), (dtypes.Int, dtypes.String), (dtypes.Float64, dtypes.String), + (dtypes.DateTime, dtypes.String), + (dtypes.Date, dtypes.String), } if ( @@ -485,7 +487,7 @@ def iter_children(self) -> Iterable[ColExpr]: yield self.val def map_subtree(self, g: Callable[[ColExpr], ColExpr]) -> ColExpr: - return g(Cast(g(self.val), self.target_type)) + return g(Cast(self.val.map_subtree(g), self.target_type)) @dataclasses.dataclass(slots=True) diff --git a/tests/test_backend_equivalence/test_ops/test_cast.py b/tests/test_backend_equivalence/test_ops/test_cast.py index d72ace8b..a9b23bb2 100644 --- a/tests/test_backend_equivalence/test_ops/test_cast.py +++ b/tests/test_backend_equivalence/test_ops/test_cast.py @@ -62,3 +62,26 @@ def test_float_to_string(df_num): >> (lambda s: s >> mutate(**{c.name: c.cast(pdt.String()) for c in s})) >> (lambda s: s >> mutate(**{c.name: c.cast(pdt.Float64()) for c in s})), ) + + +def test_datetime_to_string(df_datetime): + assert_result_equal( + df_datetime, + lambda t: t + >> mutate( + x=t.col1.cast(pdt.String()), + y=t.col2.cast(pdt.String()), + ), + ) + + +def test_date_to_string(df_datetime): + assert_result_equal( + df_datetime, + lambda t: t + >> mutate( + x=t.col1.cast(pdt.Date()).cast(pdt.String()), + y=t.col2.cast(pdt.Date()).cast(pdt.String()), + z=t.cdate.cast(pdt.String()), + ), + ) From b5fafe9010826558d47620aed9f97ce29e082ce6 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Mon, 30 Sep 2024 14:59:16 +0200 Subject: [PATCH 21/25] make sqa_type and pdt_type in sql classmethods so that e.g. mssql can return DATETIME2 instead of DATETIME --- src/pydiverse/transform/backend/mssql.py | 17 +++- src/pydiverse/transform/backend/polars.py | 1 + src/pydiverse/transform/backend/postgres.py | 10 +-- src/pydiverse/transform/backend/sql.py | 98 ++++++++++----------- src/pydiverse/transform/backend/sqlite.py | 8 +- 5 files changed, 72 insertions(+), 62 deletions(-) diff --git a/src/pydiverse/transform/backend/mssql.py b/src/pydiverse/transform/backend/mssql.py index c6c7da89..705ff4e7 100644 --- a/src/pydiverse/transform/backend/mssql.py +++ b/src/pydiverse/transform/backend/mssql.py @@ -8,7 +8,7 @@ from pydiverse.transform import ops from pydiverse.transform.backend import sql -from pydiverse.transform.backend.sql import SqlImpl, pdt_type_to_sqa +from pydiverse.transform.backend.sql import SqlImpl from pydiverse.transform.tree import dtypes, verbs from pydiverse.transform.tree.ast import AstNode from pydiverse.transform.tree.col_expr import ( @@ -44,14 +44,14 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: (compiled_val.in_(("nan", "-nan")), cls.NAN), else_=sqa.cast( compiled_val, - pdt_type_to_sqa(cast.target_type), + cls.sqa_type(cast.target_type), ), ) if isinstance(cast.val.dtype(), dtypes.Float64) and isinstance( cast.target_type, dtypes.String ): - compiled = super().compile_cast(cast, sqa_col) + compiled = sqa.cast(cls.compile_col_expr(cast.val, sqa_col), sqa.String) return sqa.case( (compiled == "1.#QNAN", "nan"), (compiled == "1.#INF", "inf"), @@ -59,7 +59,7 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: else_=compiled, ) - return sqa.cast(compiled_val, pdt_type_to_sqa(cast.target_type)) + return sqa.cast(compiled_val, cls.sqa_type(cast.target_type)) @classmethod def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any: @@ -90,6 +90,15 @@ def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any: table, query, _ = cls.compile_ast(nd, {col._uuid: 1 for col in final_select}) return cls.compile_query(table, query) + @classmethod + def sqa_type(cls, t: dtypes.Dtype): + if isinstance(t, dtypes.DateTime): + from sqlalchemy.dialects.mssql import DATETIME2 + + return DATETIME2() + + return super().sqa_type(t) + def convert_order_list(order_list: list[Order]) -> list[Order]: new_list: list[Order] = [] diff --git a/src/pydiverse/transform/backend/polars.py b/src/pydiverse/transform/backend/polars.py index 1fee7844..9b032e95 100644 --- a/src/pydiverse/transform/backend/polars.py +++ b/src/pydiverse/transform/backend/polars.py @@ -189,6 +189,7 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr: compiled = compile_col_expr(expr.val, name_in_df).cast( pdt_type_to_polars(expr.target_type) ) + if isinstance(expr.val.dtype(), dtypes.Float64) and isinstance( expr.target_type, dtypes.String ): diff --git a/src/pydiverse/transform/backend/postgres.py b/src/pydiverse/transform/backend/postgres.py index 3b6c2b94..28b84df1 100644 --- a/src/pydiverse/transform/backend/postgres.py +++ b/src/pydiverse/transform/backend/postgres.py @@ -13,14 +13,14 @@ class PostgresImpl(SqlImpl): @classmethod def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: + compiled_val = cls.compile_col_expr(cast.val, sqa_col) + if isinstance(cast.val.dtype(), dtypes.Float64): if isinstance(cast.target_type, dtypes.Int): - return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast( - sqa.BigInteger() - ) + return sqa.func.trunc(compiled_val).cast(sqa.BigInteger()) if isinstance(cast.target_type, dtypes.String): - compiled = super().compile_cast(cast, sqa_col) + compiled = sqa.cast(compiled_val, sqa.String) return sqa.case( (compiled == "NaN", "nan"), (compiled == "Infinity", "inf"), @@ -28,7 +28,7 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: else_=compiled, ) - return super().compile_cast(cast, sqa_col) + return sqa.cast(compiled_val, cls.sqa_type(cast.target_type)) with PostgresImpl.op(ops.Less()) as op: diff --git a/src/pydiverse/transform/backend/sql.py b/src/pydiverse/transform/backend/sql.py index 2dc35385..42deaa4f 100644 --- a/src/pydiverse/transform/backend/sql.py +++ b/src/pydiverse/transform/backend/sql.py @@ -75,7 +75,7 @@ def __init__(self, table: str | sqa.Table, conf: SqlAlchemy, name: str | None): super().__init__( name, - {col.name: sqa_type_to_pdt(col.type) for col in self.table.columns}, + {col.name: self.pdt_type(col.type) for col in self.table.columns}, ) def __init_subclass__(cls, **kwargs): @@ -86,7 +86,7 @@ def col_names(self) -> list[str]: return [col.name for col in self.table.columns] def schema(self) -> dict[str, Dtype]: - return {col.name: sqa_type_to_pdt(col.type) for col in self.table.columns} + return {col.name: self.pdt_type(col.type) for col in self.table.columns} def _clone(self) -> tuple[SqlImpl, dict[AstNode, AstNode], dict[UUID, UUID]]: cloned = self.__class__(self.table, SqlAlchemy(self.engine), self.name) @@ -150,7 +150,7 @@ def compile_order( @classmethod def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: return cls.compile_col_expr(cast.val, sqa_col).cast( - pdt_type_to_sqa(cast.target_type) + cls.sqa_type(cast.target_type) ) @classmethod @@ -488,6 +488,52 @@ def compile_ast( return table, query, sqa_col + @classmethod + def sqa_type(cls, t: Dtype) -> sqa.types.TypeEngine: + if isinstance(t, dtypes.Int): + return sqa.BigInteger() + elif isinstance(t, dtypes.Float64): + return sqa.Double() + elif isinstance(t, dtypes.Decimal): + return sqa.DECIMAL() + elif isinstance(t, dtypes.String): + return sqa.String() + elif isinstance(t, dtypes.Bool): + return sqa.Boolean() + elif isinstance(t, dtypes.DateTime): + return sqa.DateTime() + elif isinstance(t, dtypes.Date): + return sqa.Date() + elif isinstance(t, dtypes.Duration): + return sqa.Interval() + elif isinstance(t, dtypes.NoneDtype): + return sqa.types.NullType() + + raise AssertionError + + @classmethod + def pdt_type(cls, t: sqa.types.TypeEngine) -> Dtype: + if isinstance(t, sqa.Integer): + return dtypes.Int() + elif isinstance(t, sqa.Float): + return dtypes.Float64() + elif isinstance(t, (sqa.DECIMAL, sqa.NUMERIC)): + return dtypes.Decimal() + elif isinstance(t, sqa.String): + return dtypes.String() + elif isinstance(t, sqa.Boolean): + return dtypes.Bool() + elif isinstance(t, sqa.DateTime): + return dtypes.DateTime() + elif isinstance(t, sqa.Date): + return dtypes.Date() + elif isinstance(t, sqa.Interval): + return dtypes.Duration() + elif isinstance(t, sqa.Null): + return dtypes.NoneDtype() + + raise TypeError(f"SQLAlchemy type {t} not supported by pydiverse.transform") + @dataclasses.dataclass(slots=True) class Query: @@ -567,52 +613,6 @@ def get_engine(nd: AstNode) -> sqa.Engine: return engine -def sqa_type_to_pdt(t: sqa.types.TypeEngine) -> Dtype: - if isinstance(t, sqa.Integer): - return dtypes.Int() - elif isinstance(t, sqa.Float): - return dtypes.Float64() - elif isinstance(t, (sqa.DECIMAL, sqa.NUMERIC)): - return dtypes.Decimal() - elif isinstance(t, sqa.String): - return dtypes.String() - elif isinstance(t, sqa.Boolean): - return dtypes.Bool() - elif isinstance(t, sqa.DateTime): - return dtypes.DateTime() - elif isinstance(t, sqa.Date): - return dtypes.Date() - elif isinstance(t, sqa.Interval): - return dtypes.Duration() - elif isinstance(t, sqa.Null): - return dtypes.NoneDtype() - - raise TypeError(f"SQLAlchemy type {t} not supported by pydiverse.transform") - - -def pdt_type_to_sqa(t: Dtype) -> sqa.types.TypeEngine: - if isinstance(t, dtypes.Int): - return sqa.BigInteger() - elif isinstance(t, dtypes.Float64): - return sqa.Double() - elif isinstance(t, dtypes.Decimal): - return sqa.DECIMAL() - elif isinstance(t, dtypes.String): - return sqa.String() - elif isinstance(t, dtypes.Bool): - return sqa.Boolean() - elif isinstance(t, dtypes.DateTime): - return sqa.DateTime() - elif isinstance(t, dtypes.Date): - return sqa.Date() - elif isinstance(t, dtypes.Duration): - return sqa.Interval() - elif isinstance(t, dtypes.NoneDtype): - return sqa.types.NullType() - - raise AssertionError - - with SqlImpl.op(ops.FloorDiv(), check_super=False) as op: if sqa.__version__ < "2": diff --git a/src/pydiverse/transform/backend/sqlite.py b/src/pydiverse/transform/backend/sqlite.py index 88d21e4b..c611a9ba 100644 --- a/src/pydiverse/transform/backend/sqlite.py +++ b/src/pydiverse/transform/backend/sqlite.py @@ -3,7 +3,7 @@ import sqlalchemy as sqa from pydiverse.transform import ops -from pydiverse.transform.backend.sql import SqlImpl, pdt_type_to_sqa +from pydiverse.transform.backend.sql import SqlImpl from pydiverse.transform.tree import dtypes from pydiverse.transform.tree.col_expr import Cast from pydiverse.transform.util.warnings import warn_non_standard @@ -29,14 +29,14 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: (compiled_val.in_(("nan", "-nan")), cls.NAN), else_=sqa.cast( compiled_val, - pdt_type_to_sqa(cast.target_type), + cls.sqa_type(cast.target_type), ), ) elif isinstance(cast.val.dtype(), dtypes.DateTime) and isinstance( cast.target_type, dtypes.Date ): - return sqa.type_coerce(sqa.func.date(compiled_val), sqa.DATE()) + return sqa.type_coerce(sqa.func.date(compiled_val), sqa.Date()) elif isinstance(cast.val.dtype(), dtypes.Float64) and isinstance( cast.target_type, dtypes.String @@ -47,7 +47,7 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: else_=sqa.cast(compiled_val, sqa.String), ) - return sqa.cast(compiled_val, pdt_type_to_sqa(cast.target_type)) + return sqa.cast(compiled_val, cls.sqa_type(cast.target_type)) with SqliteImpl.op(ops.Round()) as op: From ed5fc1b21efe29c20eb0f2270e42cb218bf53d9d Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Mon, 30 Sep 2024 15:20:27 +0200 Subject: [PATCH 22/25] implement str.to_datetime we currently only guarantee this to work for datetimes of the form YYYY-MM-DD HH:MM:SS.MILLIS (yes, exactly six digits after the comma, which makes sense since we guarantee all operations on datetimes on millisecond precision) --- src/pydiverse/transform/backend/mssql.py | 10 ++++++++-- src/pydiverse/transform/backend/polars.py | 14 ++++++++++++++ src/pydiverse/transform/backend/sql.py | 14 ++++++++++++++ src/pydiverse/transform/backend/sqlite.py | 15 +++++++++++++++ src/pydiverse/transform/ops/string.py | 12 ++++++++++++ src/pydiverse/transform/tree/col_expr.py | 13 ++++++++++++- src/pydiverse/transform/tree/dtypes.py | 2 ++ .../test_ops/test_cast.py | 4 ++++ 8 files changed, 81 insertions(+), 3 deletions(-) diff --git a/src/pydiverse/transform/backend/mssql.py b/src/pydiverse/transform/backend/mssql.py index 705ff4e7..8606c4d5 100644 --- a/src/pydiverse/transform/backend/mssql.py +++ b/src/pydiverse/transform/backend/mssql.py @@ -5,6 +5,7 @@ from typing import Any import sqlalchemy as sqa +from sqlalchemy.dialects.mssql import DATETIME2 from pydiverse.transform import ops from pydiverse.transform.backend import sql @@ -93,8 +94,6 @@ def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any: @classmethod def sqa_type(cls, t: dtypes.Dtype): if isinstance(t, dtypes.DateTime): - from sqlalchemy.dialects.mssql import DATETIME2 - return DATETIME2() return super().sqa_type(t) @@ -362,3 +361,10 @@ def _log(x): @op.auto def _ceil(x): return sqa.func.ceiling(x) + + +with MsSqlImpl.op(ops.StrToDateTime()) as op: + + @op.auto + def _str_to_datetime(x): + return sqa.cast(x, DATETIME2) diff --git a/src/pydiverse/transform/backend/polars.py b/src/pydiverse/transform/backend/polars.py index 9b032e95..142dfcea 100644 --- a/src/pydiverse/transform/backend/polars.py +++ b/src/pydiverse/transform/backend/polars.py @@ -734,3 +734,17 @@ def _floor(x): @op.auto def _ceil(x): return x.ceil() + + +with PolarsImpl.op(ops.StrToDateTime()) as op: + + @op.auto + def _str_to_datetime(x): + return x.str.to_datetime() + + +with PolarsImpl.op(ops.StrToDate()) as op: + + @op.auto + def _str_to_date(x): + return x.str.to_date() diff --git a/src/pydiverse/transform/backend/sql.py b/src/pydiverse/transform/backend/sql.py index 42deaa4f..1e2d1d57 100644 --- a/src/pydiverse/transform/backend/sql.py +++ b/src/pydiverse/transform/backend/sql.py @@ -1012,3 +1012,17 @@ def _floor(x): @op.auto def _ceil(x): return sqa.func.ceil(x) + + +with SqlImpl.op(ops.StrToDateTime()) as op: + + @op.auto + def _str_to_datetime(x): + return sqa.cast(x, sqa.DateTime) + + +with SqlImpl.op(ops.StrToDate()) as op: + + @op.auto + def _str_to_datetime(x): + return sqa.cast(x, sqa.Date) diff --git a/src/pydiverse/transform/backend/sqlite.py b/src/pydiverse/transform/backend/sqlite.py index c611a9ba..79dd369e 100644 --- a/src/pydiverse/transform/backend/sqlite.py +++ b/src/pydiverse/transform/backend/sqlite.py @@ -140,3 +140,18 @@ def _least(*x): # TODO: Determine return type return sqa.func.coalesce(sqa.func.MIN(left, right), left, right) + + +# TODO: we need to get the string in the right format here (so sqlite can work with it) +with SqliteImpl.op(ops.StrToDateTime()) as op: + + @op.auto + def _str_to_datetime(x): + return sqa.type_coerce(x, sqa.DateTime) + + +with SqliteImpl.op(ops.StrToDate()) as op: + + @op.auto + def _str_to_datetime(x): + return sqa.type_coerce(x, sqa.Date) diff --git a/src/pydiverse/transform/ops/string.py b/src/pydiverse/transform/ops/string.py index fdd4ef5b..696bebdc 100644 --- a/src/pydiverse/transform/ops/string.py +++ b/src/pydiverse/transform/ops/string.py @@ -16,6 +16,8 @@ "StrEndsWith", "StrContains", "StrSlice", + "StrToDateTime", + "StrToDate", ] @@ -92,3 +94,13 @@ class StrContains(ElementWise, Logical): class StrSlice(ElementWise): name = "str.slice" signatures = ["str, int, int -> str"] + + +class StrToDateTime(ElementWise): + name = "str.to_datetime" + signatures = ["str -> datetime"] + + +class StrToDate(ElementWise): + name = "str.to_date" + signatures = ["str -> date"] diff --git a/src/pydiverse/transform/tree/col_expr.py b/src/pydiverse/transform/tree/col_expr.py index 2a98f966..db1e369a 100644 --- a/src/pydiverse/transform/tree/col_expr.py +++ b/src/pydiverse/transform/tree/col_expr.py @@ -474,8 +474,19 @@ def dtype(self) -> Dtype: self.val.dtype().__class__, self.target_type.__class__, ) not in valid_casts: + hint = "" + if self.val.dtype() == dtypes.String and ( + (self.target_type == dtypes.DateTime) + or (self.target_type == dtypes.Date) + ): + hint = ( + "\nhint: to convert a str to datetime, call " + f"`.str.to_{self.target_type.name}()` on the expression." + ) + raise TypeError( - f"cannot cast type `{self.val.dtype()}` to `{self.target_type}`" + f"cannot cast type {self.val.dtype()} to {self.target_type}." + f"{hint}" ) return self._dtype diff --git a/src/pydiverse/transform/tree/dtypes.py b/src/pydiverse/transform/tree/dtypes.py index 0b532d60..5da753c2 100644 --- a/src/pydiverse/transform/tree/dtypes.py +++ b/src/pydiverse/transform/tree/dtypes.py @@ -14,6 +14,8 @@ def __init__(self, *, const: bool = False, vararg: bool = False): self.vararg = vararg def __eq__(self, other): + if type(self) is other: + return True if type(self) is not type(other): return False if self.const != other.const: diff --git a/tests/test_backend_equivalence/test_ops/test_cast.py b/tests/test_backend_equivalence/test_ops/test_cast.py index a9b23bb2..77a6f23d 100644 --- a/tests/test_backend_equivalence/test_ops/test_cast.py +++ b/tests/test_backend_equivalence/test_ops/test_cast.py @@ -71,6 +71,10 @@ def test_datetime_to_string(df_datetime): >> mutate( x=t.col1.cast(pdt.String()), y=t.col2.cast(pdt.String()), + ) + >> mutate( + x=C.x.str.to_datetime(), + y=C.y.str.to_datetime(), ), ) From 289d57fb03dd685e8a35bbfe3c35789ca7be9993 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Mon, 30 Sep 2024 15:50:01 +0200 Subject: [PATCH 23/25] make dtype comparisons more readable we allow __eq__ between a dtype instance and a dtype class. When comparing with a class, const and vararg are ignored. --- src/pydiverse/transform/backend/duckdb.py | 4 +--- src/pydiverse/transform/backend/mssql.py | 10 +++------- src/pydiverse/transform/backend/polars.py | 6 ++---- src/pydiverse/transform/backend/sql.py | 1 - src/pydiverse/transform/backend/sqlite.py | 12 +++--------- src/pydiverse/transform/pipe/functions.py | 2 +- src/pydiverse/transform/pipe/verbs.py | 2 +- src/pydiverse/transform/tree/col_expr.py | 10 +++++----- src/pydiverse/transform/tree/dtypes.py | 15 +++++++++------ 9 files changed, 25 insertions(+), 37 deletions(-) diff --git a/src/pydiverse/transform/backend/duckdb.py b/src/pydiverse/transform/backend/duckdb.py index 2b4d68f7..d14b0363 100644 --- a/src/pydiverse/transform/backend/duckdb.py +++ b/src/pydiverse/transform/backend/duckdb.py @@ -26,9 +26,7 @@ def export(cls, nd: AstNode, target: Target, final_select: list[Col]): @classmethod def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: - if isinstance(cast.val.dtype(), dtypes.Float64) and isinstance( - cast.target_type, dtypes.Int - ): + if cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.Int: return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast( sqa.BigInteger() ) diff --git a/src/pydiverse/transform/backend/mssql.py b/src/pydiverse/transform/backend/mssql.py index 8606c4d5..b098bf5f 100644 --- a/src/pydiverse/transform/backend/mssql.py +++ b/src/pydiverse/transform/backend/mssql.py @@ -36,9 +36,7 @@ class MsSqlImpl(SqlImpl): @classmethod def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: compiled_val = cls.compile_col_expr(cast.val, sqa_col) - if isinstance(cast.val.dtype(), dtypes.String) and isinstance( - cast.target_type, dtypes.Float64 - ): + if cast.val.dtype() == dtypes.String and cast.target_type == dtypes.Float64: return sqa.case( (compiled_val == "inf", cls.INF), (compiled_val == "-inf", -cls.INF), @@ -49,9 +47,7 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: ), ) - if isinstance(cast.val.dtype(), dtypes.Float64) and isinstance( - cast.target_type, dtypes.String - ): + if cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.String: compiled = sqa.cast(cls.compile_col_expr(cast.val, sqa_col), sqa.String) return sqa.case( (compiled == "1.#QNAN", "nan"), @@ -137,7 +133,7 @@ def convert_bool_bit(expr: ColExpr | Order, wants_bool_as_bit: bool) -> ColExpr ) elif isinstance(expr, Col): - if not wants_bool_as_bit and isinstance(expr.dtype(), dtypes.Bool): + if not wants_bool_as_bit and expr.dtype() == dtypes.Bool: return ColFn("__eq__", expr, LiteralCol(True)) return expr diff --git a/src/pydiverse/transform/backend/polars.py b/src/pydiverse/transform/backend/polars.py index 142dfcea..2c7494ef 100644 --- a/src/pydiverse/transform/backend/polars.py +++ b/src/pydiverse/transform/backend/polars.py @@ -181,7 +181,7 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr: return compiled elif isinstance(expr, LiteralCol): - if isinstance(expr.dtype(), dtypes.String): + if expr.dtype() == dtypes.String: return pl.lit(expr.val) # polars interprets strings as column names return expr.val @@ -190,9 +190,7 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr: pdt_type_to_polars(expr.target_type) ) - if isinstance(expr.val.dtype(), dtypes.Float64) and isinstance( - expr.target_type, dtypes.String - ): + if expr.val.dtype() == dtypes.Float64 and expr.target_type == dtypes.String: compiled = compiled.replace("NaN", "nan") return compiled diff --git a/src/pydiverse/transform/backend/sql.py b/src/pydiverse/transform/backend/sql.py index 1e2d1d57..b04aba0b 100644 --- a/src/pydiverse/transform/backend/sql.py +++ b/src/pydiverse/transform/backend/sql.py @@ -118,7 +118,6 @@ def export(cls, nd: AstNode, target: Target, final_select: list[Col]) -> Any: sql_col.name: pdt_type_to_polars(col.dtype()) for sql_col, col in zip(sel.columns.values(), final_select) }, - infer_schema_length=0, ) df.name = nd.name return df diff --git a/src/pydiverse/transform/backend/sqlite.py b/src/pydiverse/transform/backend/sqlite.py index 79dd369e..a39297ed 100644 --- a/src/pydiverse/transform/backend/sqlite.py +++ b/src/pydiverse/transform/backend/sqlite.py @@ -20,9 +20,7 @@ class SqliteImpl(SqlImpl): def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: compiled_val = cls.compile_col_expr(cast.val, sqa_col) - if isinstance(cast.val.dtype(), dtypes.String) and isinstance( - cast.target_type, dtypes.Float64 - ): + if cast.val.dtype() == dtypes.String and cast.target_type == dtypes.Float64: return sqa.case( (compiled_val == "inf", cls.INF), (compiled_val == "-inf", cls.NEG_INF), @@ -33,14 +31,10 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: ), ) - elif isinstance(cast.val.dtype(), dtypes.DateTime) and isinstance( - cast.target_type, dtypes.Date - ): + elif cast.val.dtype() == dtypes.DateTime and cast.target_type == dtypes.Date: return sqa.type_coerce(sqa.func.date(compiled_val), sqa.Date()) - elif isinstance(cast.val.dtype(), dtypes.Float64) and isinstance( - cast.target_type, dtypes.String - ): + elif cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.String: return sqa.case( (compiled_val == cls.INF, "inf"), (compiled_val == cls.NEG_INF, "-inf"), diff --git a/src/pydiverse/transform/pipe/functions.py b/src/pydiverse/transform/pipe/functions.py index 6837ae0d..2fa28b8c 100644 --- a/src/pydiverse/transform/pipe/functions.py +++ b/src/pydiverse/transform/pipe/functions.py @@ -18,7 +18,7 @@ def clean_kwargs(**kwargs) -> dict[str, list[ColExpr]]: def when(condition: ColExpr) -> WhenClause: - if condition.dtype() is not None and not isinstance(condition.dtype(), dtypes.Bool): + if condition.dtype() is not None and condition.dtype() != dtypes.Bool: raise TypeError( "argument for `when` must be of boolean type, but has type " f"`{condition.dtype()}`" diff --git a/src/pydiverse/transform/pipe/verbs.py b/src/pydiverse/transform/pipe/verbs.py index 8348645e..9500a6e1 100644 --- a/src/pydiverse/transform/pipe/verbs.py +++ b/src/pydiverse/transform/pipe/verbs.py @@ -198,7 +198,7 @@ def filter(table: Table, *predicates: ColExpr): new._ast = Filter(table._ast, preprocess_arg(predicates, table)) for cond in new._ast.filters: - if not isinstance(cond.dtype(), dtypes.Bool): + if cond.dtype() != dtypes.Bool: raise TypeError( "predicates given to `filter` must be of boolean type.\n" f"hint: {cond} is of type {cond.dtype()} instead." diff --git a/src/pydiverse/transform/tree/col_expr.py b/src/pydiverse/transform/tree/col_expr.py index db1e369a..6b098e59 100644 --- a/src/pydiverse/transform/tree/col_expr.py +++ b/src/pydiverse/transform/tree/col_expr.py @@ -115,7 +115,7 @@ def __str__(self) -> str: ) def __hash__(self) -> int: - return hash(self.uuid) + return hash(self._uuid) class ColName(ColExpr): @@ -384,7 +384,7 @@ def dtype(self): raise TypeError(f"invalid case expression: {e}") from e for cond, _ in self.cases: - if cond.dtype() is not None and not isinstance(cond.dtype(), dtypes.Bool): + if cond.dtype() is not None and cond.dtype() != dtypes.Bool: raise TypeError( f"argument `{cond}` for `when` must be of boolean type, but has " f"type `{cond.dtype()}`" @@ -475,9 +475,9 @@ def dtype(self) -> Dtype: self.target_type.__class__, ) not in valid_casts: hint = "" - if self.val.dtype() == dtypes.String and ( - (self.target_type == dtypes.DateTime) - or (self.target_type == dtypes.Date) + if self.val.dtype() == dtypes.String and self.target_type in ( + dtypes.DateTime, + dtypes.Date, ): hint = ( "\nhint: to convert a str to datetime, call " diff --git a/src/pydiverse/transform/tree/dtypes.py b/src/pydiverse/transform/tree/dtypes.py index 5da753c2..28c71716 100644 --- a/src/pydiverse/transform/tree/dtypes.py +++ b/src/pydiverse/transform/tree/dtypes.py @@ -13,19 +13,22 @@ def __init__(self, *, const: bool = False, vararg: bool = False): self.const = const self.vararg = vararg - def __eq__(self, other): - if type(self) is other: + def __eq__(self, rhs): + if type(self) is rhs: return True - if type(self) is not type(other): + if type(self) is not type(rhs): return False - if self.const != other.const: + if self.const != rhs.const: return False - if self.vararg != other.vararg: + if self.vararg != rhs.vararg: return False - if self.name != other.name: + if self.name != rhs.name: return False return True + def __ne__(self, rhs: object) -> bool: + return not self.__eq__(rhs) + def __hash__(self): return hash((self.name, self.const, self.vararg, type(self).__qualname__)) From 081712cb524b26d46a78c2ec3edabd7da2b0c1fe Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Mon, 30 Sep 2024 16:02:23 +0200 Subject: [PATCH 24/25] rename Int to Int64 for consistency with Float64 --- src/pydiverse/transform/backend/duckdb.py | 2 +- src/pydiverse/transform/backend/polars.py | 4 ++-- src/pydiverse/transform/backend/postgres.py | 2 +- src/pydiverse/transform/backend/sql.py | 4 ++-- src/pydiverse/transform/tree/col_expr.py | 6 +++--- src/pydiverse/transform/tree/dtypes.py | 8 ++++---- tests/test_backend_equivalence/test_ops/test_cast.py | 8 ++++---- tests/test_operator_registry.py | 10 +++++----- tests/test_polars_table.py | 6 +++--- 9 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/pydiverse/transform/backend/duckdb.py b/src/pydiverse/transform/backend/duckdb.py index d14b0363..33f1fa78 100644 --- a/src/pydiverse/transform/backend/duckdb.py +++ b/src/pydiverse/transform/backend/duckdb.py @@ -26,7 +26,7 @@ def export(cls, nd: AstNode, target: Target, final_select: list[Col]): @classmethod def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: - if cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.Int: + if cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.Int64: return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast( sqa.BigInteger() ) diff --git a/src/pydiverse/transform/backend/polars.py b/src/pydiverse/transform/backend/polars.py index 2c7494ef..783e62f0 100644 --- a/src/pydiverse/transform/backend/polars.py +++ b/src/pydiverse/transform/backend/polars.py @@ -351,7 +351,7 @@ def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype: if t.is_float(): return dtypes.Float64() elif t.is_integer(): - return dtypes.Int() + return dtypes.Int64() elif isinstance(t, pl.Boolean): return dtypes.Bool() elif isinstance(t, pl.String): @@ -371,7 +371,7 @@ def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype: def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType: if isinstance(t, (dtypes.Float64, dtypes.Decimal)): return pl.Float64() - elif isinstance(t, dtypes.Int): + elif isinstance(t, dtypes.Int64): return pl.Int64() elif isinstance(t, dtypes.Bool): return pl.Boolean() diff --git a/src/pydiverse/transform/backend/postgres.py b/src/pydiverse/transform/backend/postgres.py index 28b84df1..183ad3e0 100644 --- a/src/pydiverse/transform/backend/postgres.py +++ b/src/pydiverse/transform/backend/postgres.py @@ -16,7 +16,7 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: compiled_val = cls.compile_col_expr(cast.val, sqa_col) if isinstance(cast.val.dtype(), dtypes.Float64): - if isinstance(cast.target_type, dtypes.Int): + if isinstance(cast.target_type, dtypes.Int64): return sqa.func.trunc(compiled_val).cast(sqa.BigInteger()) if isinstance(cast.target_type, dtypes.String): diff --git a/src/pydiverse/transform/backend/sql.py b/src/pydiverse/transform/backend/sql.py index b04aba0b..8bae93c0 100644 --- a/src/pydiverse/transform/backend/sql.py +++ b/src/pydiverse/transform/backend/sql.py @@ -489,7 +489,7 @@ def compile_ast( @classmethod def sqa_type(cls, t: Dtype) -> sqa.types.TypeEngine: - if isinstance(t, dtypes.Int): + if isinstance(t, dtypes.Int64): return sqa.BigInteger() elif isinstance(t, dtypes.Float64): return sqa.Double() @@ -513,7 +513,7 @@ def sqa_type(cls, t: Dtype) -> sqa.types.TypeEngine: @classmethod def pdt_type(cls, t: sqa.types.TypeEngine) -> Dtype: if isinstance(t, sqa.Integer): - return dtypes.Int() + return dtypes.Int64() elif isinstance(t, sqa.Float): return dtypes.Float64() elif isinstance(t, (sqa.DECIMAL, sqa.NUMERIC)): diff --git a/src/pydiverse/transform/tree/col_expr.py b/src/pydiverse/transform/tree/col_expr.py index 6b098e59..c56a0f40 100644 --- a/src/pydiverse/transform/tree/col_expr.py +++ b/src/pydiverse/transform/tree/col_expr.py @@ -460,11 +460,11 @@ def dtype(self) -> Dtype: if not self.val.dtype().can_promote_to(self.target_type): valid_casts = { - (dtypes.String, dtypes.Int), + (dtypes.String, dtypes.Int64), (dtypes.String, dtypes.Float64), - (dtypes.Float64, dtypes.Int), + (dtypes.Float64, dtypes.Int64), (dtypes.DateTime, dtypes.Date), - (dtypes.Int, dtypes.String), + (dtypes.Int64, dtypes.String), (dtypes.Float64, dtypes.String), (dtypes.DateTime, dtypes.String), (dtypes.Date, dtypes.String), diff --git a/src/pydiverse/transform/tree/dtypes.py b/src/pydiverse/transform/tree/dtypes.py index 28c71716..b81fc8ba 100644 --- a/src/pydiverse/transform/tree/dtypes.py +++ b/src/pydiverse/transform/tree/dtypes.py @@ -73,8 +73,8 @@ def can_promote_to(self, other: Dtype) -> bool: return other.same_kind(self) -class Int(Dtype): - name = "int" +class Int64(Dtype): + name = "int64" def can_promote_to(self, other: Dtype) -> bool: if super().can_promote_to(other): @@ -151,7 +151,7 @@ class NoneDtype(Dtype): def python_type_to_pdt(t: type) -> Dtype: if t is int: - return Int() + return Int64() elif t is float: return Float64() elif t is bool: @@ -200,7 +200,7 @@ def dtype_from_string(t: str) -> Dtype: return Template(base_type, const=is_const, vararg=is_vararg) if base_type == "int": - return Int(const=is_const, vararg=is_vararg) + return Int64(const=is_const, vararg=is_vararg) if base_type == "float64": return Float64(const=is_const, vararg=is_vararg) if base_type == "decimal": diff --git a/tests/test_backend_equivalence/test_ops/test_cast.py b/tests/test_backend_equivalence/test_ops/test_cast.py index 77a6f23d..019edc86 100644 --- a/tests/test_backend_equivalence/test_ops/test_cast.py +++ b/tests/test_backend_equivalence/test_ops/test_cast.py @@ -17,25 +17,25 @@ def test_string_to_float(df_strings): def test_string_to_int(df_strings): assert_result_equal( df_strings, - lambda t: t >> mutate(u=t.d.cast(pdt.Int())), + lambda t: t >> mutate(u=t.d.cast(pdt.Int64())), ) def test_float_to_int(df_num): assert_result_equal( df_num, - lambda t: t >> mutate(**{col.name: col.cast(pdt.Int()) for col in t}), + lambda t: t >> mutate(**{col.name: col.cast(pdt.Int64()) for col in t}), ) assert_result_equal( df_num, - lambda t: t >> add_nan_inf_cols() >> mutate(u=C.inf.cast(pdt.Int())), + lambda t: t >> add_nan_inf_cols() >> mutate(u=C.inf.cast(pdt.Int64())), exception=Exception, may_throw=True, ) assert_result_equal( df_num, - lambda t: t >> add_nan_inf_cols() >> mutate(u=C.nan.cast(pdt.Int())), + lambda t: t >> add_nan_inf_cols() >> mutate(u=C.nan.cast(pdt.Int64())), exception=Exception, may_throw=True, ) diff --git a/tests/test_operator_registry.py b/tests/test_operator_registry.py index addae7a6..558e2659 100644 --- a/tests/test_operator_registry.py +++ b/tests/test_operator_registry.py @@ -24,13 +24,13 @@ def assert_signature( class TestOperatorSignature: def test_parse_simple(self): s = OperatorSignature.parse("int, int -> int") - assert_signature(s, [dtypes.Int(), dtypes.Int()], dtypes.Int()) + assert_signature(s, [dtypes.Int64(), dtypes.Int64()], dtypes.Int64()) s = OperatorSignature.parse("bool->bool ") assert_signature(s, [dtypes.Bool()], dtypes.Bool()) s = OperatorSignature.parse("-> int") - assert_signature(s, [], dtypes.Int()) + assert_signature(s, [], dtypes.Int64()) with pytest.raises(ValueError): OperatorSignature.parse("int, int -> ") @@ -112,7 +112,7 @@ def test_simple(self): assert reg.get_impl("op1", parse_dtypes("int", "int"))() == 1 assert isinstance( reg.get_impl("op1", parse_dtypes("int", "int")).return_type, - dtypes.Int, + dtypes.Int64, ) assert reg.get_impl("op2", parse_dtypes("int", "int"))() == 10 @@ -187,11 +187,11 @@ def test_template(self): ) assert isinstance( reg.get_impl("op3", parse_dtypes("int")).return_type, - dtypes.Int, + dtypes.Int64, ) assert isinstance( reg.get_impl("op3", parse_dtypes("int", "int", "float64")).return_type, - dtypes.Int, + dtypes.Int64, ) assert isinstance( reg.get_impl("op3", parse_dtypes("str", "int", "float64")).return_type, diff --git a/tests/test_polars_table.py b/tests/test_polars_table.py index 9571ff50..af2a48e5 100644 --- a/tests/test_polars_table.py +++ b/tests/test_polars_table.py @@ -116,11 +116,11 @@ def tbl_dt(): class TestPolarsLazyImpl: def test_dtype(self, tbl1, tbl2): - assert isinstance(tbl1.col1.dtype(), dtypes.Int) + assert isinstance(tbl1.col1.dtype(), dtypes.Int64) assert isinstance(tbl1.col2.dtype(), dtypes.String) - assert isinstance(tbl2.col1.dtype(), dtypes.Int) - assert isinstance(tbl2.col2.dtype(), dtypes.Int) + assert isinstance(tbl2.col1.dtype(), dtypes.Int64) + assert isinstance(tbl2.col2.dtype(), dtypes.Int64) assert isinstance(tbl2.col3.dtype(), dtypes.Float64) # test that column expression type errors are checked immediately From a3933e51ee4ed3baf00a7bb00dde36dd6d129695 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Mon, 30 Sep 2024 16:14:11 +0200 Subject: [PATCH 25/25] fix operator registry tests --- src/pydiverse/transform/ops/aggregate.py | 12 +-- src/pydiverse/transform/ops/datetime.py | 6 +- src/pydiverse/transform/ops/logical.py | 2 +- src/pydiverse/transform/ops/numeric.py | 28 +++--- src/pydiverse/transform/ops/string.py | 4 +- src/pydiverse/transform/ops/window.py | 10 +- src/pydiverse/transform/tree/dtypes.py | 4 +- src/pydiverse/transform/tree/registry.py | 8 +- tests/test_operator_registry.py | 118 +++++++++++------------ 9 files changed, 95 insertions(+), 97 deletions(-) diff --git a/src/pydiverse/transform/ops/aggregate.py b/src/pydiverse/transform/ops/aggregate.py index 9f536c94..7da557f6 100644 --- a/src/pydiverse/transform/ops/aggregate.py +++ b/src/pydiverse/transform/ops/aggregate.py @@ -16,7 +16,7 @@ class Min(Aggregate, Unary): name = "min" signatures = [ - "int -> int", + "int64 -> int64", "float64 -> float64", "str -> str", "datetime -> datetime", @@ -27,7 +27,7 @@ class Min(Aggregate, Unary): class Max(Aggregate, Unary): name = "max" signatures = [ - "int -> int", + "int64 -> int64", "float64 -> float64", "str -> str", "datetime -> datetime", @@ -38,7 +38,7 @@ class Max(Aggregate, Unary): class Mean(Aggregate, Unary): name = "mean" signatures = [ - "int -> float64", + "int64 -> float64", "float64 -> float64", ] @@ -46,7 +46,7 @@ class Mean(Aggregate, Unary): class Sum(Aggregate, Unary): name = "sum" signatures = [ - "int -> int", + "int64 -> int64", "float64 -> float64", ] @@ -68,6 +68,6 @@ class All(Aggregate, Unary): class Count(Aggregate): name = "count" signatures = [ - "-> int", - "T -> int", + "-> int64", + "T -> int64", ] diff --git a/src/pydiverse/transform/ops/datetime.py b/src/pydiverse/transform/ops/datetime.py index e7788791..4ee9c294 100644 --- a/src/pydiverse/transform/ops/datetime.py +++ b/src/pydiverse/transform/ops/datetime.py @@ -26,11 +26,11 @@ class DtExtract(ElementWise, Unary): - signatures = ["datetime -> int"] + signatures = ["datetime -> int64"] class DateExtract(ElementWise, Unary): - signatures = ["datetime -> int", "date -> int"] + signatures = ["datetime -> int64", "date -> int64"] class DtYear(DateExtract): @@ -70,7 +70,7 @@ class DtDayOfYear(DateExtract): class DurationToUnit(ElementWise, Unary): - signatures = ["duration -> int"] + signatures = ["duration -> int64"] class DtDays(DurationToUnit): diff --git a/src/pydiverse/transform/ops/logical.py b/src/pydiverse/transform/ops/logical.py index 1f36d0d2..11a5c814 100644 --- a/src/pydiverse/transform/ops/logical.py +++ b/src/pydiverse/transform/ops/logical.py @@ -35,7 +35,7 @@ class Logical(Operator): class Comparison(ElementWise, Binary, Logical): signatures = [ - "int, int -> bool", + "int64, int64 -> bool", "float64, float64 -> bool", "str, str -> bool", "bool, bool -> bool", diff --git a/src/pydiverse/transform/ops/numeric.py b/src/pydiverse/transform/ops/numeric.py index 4c6e77e7..fc12a529 100644 --- a/src/pydiverse/transform/ops/numeric.py +++ b/src/pydiverse/transform/ops/numeric.py @@ -31,7 +31,7 @@ class Add(ElementWise, Binary): name = "__add__" signatures = [ - "int, int -> int", + "int64, int64 -> int64", "float64, float64 -> float64", "decimal, decimal -> decimal", ] @@ -44,7 +44,7 @@ class RAdd(Add): class Sub(ElementWise, Binary): name = "__sub__" signatures = [ - "int, int -> int", + "int64, int64 -> int64", "float64, float64 -> float64", "decimal, decimal -> decimal", ] @@ -57,7 +57,7 @@ class RSub(Sub): class Mul(ElementWise, Binary): name = "__mul__" signatures = [ - "int, int -> int", + "int64, int64 -> int64", "float64, float64 -> float64", "decimal, decimal -> decimal", ] @@ -70,7 +70,7 @@ class RMul(Mul): class TrueDiv(ElementWise, Binary): name = "__truediv__" signatures = [ - "int, int -> float64", + "int64, int64 -> float64", "float64, float64 -> float64", "decimal, decimal -> decimal", ] @@ -83,7 +83,7 @@ class RTrueDiv(TrueDiv): class FloorDiv(ElementWise, Binary): name = "__floordiv__" signatures = [ - "int, int -> int", + "int64, int64 -> int64", ] @@ -94,7 +94,7 @@ class RFloorDiv(FloorDiv): class Pow(ElementWise, Binary): name = "__pow__" signatures = [ - "int, int -> float64", + "int64, int64 -> float64", "float64, float64 -> float64", "decimal, decimal -> decimal", ] @@ -107,7 +107,7 @@ class RPow(Pow): class Mod(ElementWise, Binary): name = "__mod__" signatures = [ - "int, int -> int", + "int64, int64 -> int64", ] @@ -118,7 +118,7 @@ class RMod(Mod): class Neg(ElementWise, Unary): name = "__neg__" signatures = [ - "int -> int", + "int64 -> int64", "float64 -> float64", "decimal -> decimal", ] @@ -127,7 +127,7 @@ class Neg(ElementWise, Unary): class Pos(ElementWise, Unary): name = "__pos__" signatures = [ - "int -> int", + "int64 -> int64", "float64 -> float64", "decimal -> decimal", ] @@ -136,7 +136,7 @@ class Pos(ElementWise, Unary): class Abs(ElementWise, Unary): name = "__abs__" signatures = [ - "int -> int", + "int64 -> int64", "float64 -> float64", "decimal -> decimal", ] @@ -145,12 +145,12 @@ class Abs(ElementWise, Unary): class Round(ElementWise): name = "__round__" signatures = [ - "int -> int", - "int, const int -> int", + "int64 -> int64", + "int64, const int64 -> int64", "float64 -> float64", - "float64, const int -> float64", + "float64, const int64 -> float64", "decimal -> decimal", - "decimal, const int -> decimal", + "decimal, const int64 -> decimal", ] diff --git a/src/pydiverse/transform/ops/string.py b/src/pydiverse/transform/ops/string.py index 696bebdc..7fecf1bb 100644 --- a/src/pydiverse/transform/ops/string.py +++ b/src/pydiverse/transform/ops/string.py @@ -51,7 +51,7 @@ class StrStrip(StrUnary): class StrLen(StrUnary): name = "str.len" signatures = [ - "str -> int", + "str -> int64", ] @@ -93,7 +93,7 @@ class StrContains(ElementWise, Logical): class StrSlice(ElementWise): name = "str.slice" - signatures = ["str, int, int -> str"] + signatures = ["str, int64, int64 -> str"] class StrToDateTime(ElementWise): diff --git a/src/pydiverse/transform/ops/window.py b/src/pydiverse/transform/ops/window.py index 65ccf052..cf3a1af2 100644 --- a/src/pydiverse/transform/ops/window.py +++ b/src/pydiverse/transform/ops/window.py @@ -13,27 +13,27 @@ class Shift(Window): name = "shift" signatures = [ - "T, const int -> T", - "T, const int, const T -> T", + "T, const int64 -> T", + "T, const int64, const T -> T", ] class RowNumber(Window, Nullary): name = "row_number" signatures = [ - "-> int", + "-> int64", ] class Rank(Window, Nullary): name = "rank" signatures = [ - "-> int", + "-> int64", ] class DenseRank(Window, Nullary): name = "dense_rank" signatures = [ - "-> int", + "-> int64", ] diff --git a/src/pydiverse/transform/tree/dtypes.py b/src/pydiverse/transform/tree/dtypes.py index b81fc8ba..9f5152a3 100644 --- a/src/pydiverse/transform/tree/dtypes.py +++ b/src/pydiverse/transform/tree/dtypes.py @@ -80,7 +80,7 @@ def can_promote_to(self, other: Dtype) -> bool: if super().can_promote_to(other): return True - # int can be promoted to float64 + # int64 can be promoted to float64 if Float64().same_kind(other): if other.const and not self.const: return False @@ -199,7 +199,7 @@ def dtype_from_string(t: str) -> Dtype: if is_template: return Template(base_type, const=is_const, vararg=is_vararg) - if base_type == "int": + if base_type == "int64": return Int64(const=is_const, vararg=is_vararg) if base_type == "float64": return Float64(const=is_const, vararg=is_vararg) diff --git a/src/pydiverse/transform/tree/registry.py b/src/pydiverse/transform/tree/registry.py index 38a05e98..70637f6e 100644 --- a/src/pydiverse/transform/tree/registry.py +++ b/src/pydiverse/transform/tree/registry.py @@ -296,21 +296,21 @@ class OperatorSignature: terminal_arg ::= modifiers (dtype | vararg) vararg ::= dtype "..." rtype ::= dtype - dtype ::= template | "int" | "float64" | "str" | "bool" | and others... + dtype ::= template | "int64" | "float64" | "str" | "bool" | and others... modifiers ::= "const"? template ::= single uppercase character Examples: Function that takes two integers and returns an integer: - int, int -> int + int64, int64 -> int64 Templated argument (templates consist of single uppercase characters): T, T -> T T, U -> bool Variable number of arguments: - int... -> int + int64... -> int64 """ @@ -523,7 +523,7 @@ def does_match( elif not node.value.same_kind(dtype): # Needs type promotion # This only works when types can be promoted once - # -> (uint > int) wouldn't be preferred over (uint > int > float64) + # -> (uint > int64) wouldn't be preferred over (uint > int64 > float64) type_promotion_indices = (*type_promotion_indices, s_i) if s_i + 1 == len(signature): diff --git a/tests/test_operator_registry.py b/tests/test_operator_registry.py index 558e2659..edcde881 100644 --- a/tests/test_operator_registry.py +++ b/tests/test_operator_registry.py @@ -23,27 +23,27 @@ def assert_signature( class TestOperatorSignature: def test_parse_simple(self): - s = OperatorSignature.parse("int, int -> int") + s = OperatorSignature.parse("int64, int64 -> int64") assert_signature(s, [dtypes.Int64(), dtypes.Int64()], dtypes.Int64()) s = OperatorSignature.parse("bool->bool ") assert_signature(s, [dtypes.Bool()], dtypes.Bool()) - s = OperatorSignature.parse("-> int") + s = OperatorSignature.parse("-> int64") assert_signature(s, [], dtypes.Int64()) with pytest.raises(ValueError): - OperatorSignature.parse("int, int -> ") + OperatorSignature.parse("int64, int64 -> ") with pytest.raises(ValueError): - OperatorSignature.parse("int, int -> int, int") + OperatorSignature.parse("int64, int64 -> int64, int64") with pytest.raises(ValueError): - OperatorSignature.parse("o#r -> int") + OperatorSignature.parse("o#r -> int64") with pytest.raises(ValueError): - OperatorSignature.parse("int -> a#") + OperatorSignature.parse("int64 -> a#") def test_parse_template(self): - s = OperatorSignature.parse("T, int -> int") + s = OperatorSignature.parse("T, int64 -> int64") assert isinstance(s.args[0], dtypes.Template) s = OperatorSignature.parse("T -> T") @@ -54,26 +54,26 @@ def test_parse_template(self): OperatorSignature.parse("T, T -> U") def test_parse_varargs(self): - s = OperatorSignature.parse("int, str... -> int") + s = OperatorSignature.parse("int64, str... -> int64") assert not s.args[0].vararg assert s.args[1].vararg - s = OperatorSignature.parse("int... -> bool") + s = OperatorSignature.parse("int64... -> bool") assert s.args[0].vararg with pytest.raises(ValueError): - OperatorSignature.parse("int..., str -> int") + OperatorSignature.parse("int64..., str -> int64") with pytest.raises(ValueError): - OperatorSignature.parse("int, str -> int...") + OperatorSignature.parse("int64, str -> int64...") - s0 = OperatorSignature.parse("int, str -> int") - s1 = OperatorSignature.parse("int, str... -> int") + s0 = OperatorSignature.parse("int64, str -> int64") + s1 = OperatorSignature.parse("int64, str... -> int64") assert not s0.is_vararg assert s1.is_vararg def test_parse_const(self): - s = OperatorSignature.parse("const int -> int") + s = OperatorSignature.parse("const int64 -> int64") assert s.args[0].const assert not s.rtype.const @@ -103,33 +103,31 @@ def test_simple(self): reg.register_op(op1) reg.register_op(op2) - reg.add_impl(op1, lambda: 1, "int, int -> int") + reg.add_impl(op1, lambda: 1, "int64, int64 -> int64") reg.add_impl(op1, lambda: 2, "str, str -> str") - reg.add_impl(op2, lambda: 10, "int, int -> int") + reg.add_impl(op2, lambda: 10, "int64, int64 -> int64") reg.add_impl(op2, lambda: 20, "str, str -> str") - assert reg.get_impl("op1", parse_dtypes("int", "int"))() == 1 + assert reg.get_impl("op1", parse_dtypes("int64", "int64"))() == 1 assert isinstance( - reg.get_impl("op1", parse_dtypes("int", "int")).return_type, + reg.get_impl("op1", parse_dtypes("int64", "int64")).return_type, dtypes.Int64, ) - assert reg.get_impl("op2", parse_dtypes("int", "int"))() == 10 + assert reg.get_impl("op2", parse_dtypes("int64", "int64"))() == 10 assert reg.get_impl("op1", parse_dtypes("str", "str"))() == 2 assert reg.get_impl("op2", parse_dtypes("str", "str"))() == 20 - with pytest.raises(ValueError): - reg.get_impl("op1", parse_dtypes("int", "str")) + with pytest.raises(TypeError): + reg.get_impl("op1", parse_dtypes("int64", "str")) with pytest.raises(ValueError): reg.get_impl( "not_implemented", - parse_dtypes( - "int", - ), + parse_dtypes("int64"), ) - reg.add_impl(op1, lambda: 100, "-> int") + reg.add_impl(op1, lambda: 100, "-> int64") assert reg.get_impl("op1", tuple())() == 100 def test_template(self): @@ -149,52 +147,52 @@ def test_template(self): with pytest.raises(ValueError, match="already defined"): reg.add_impl(op1, lambda: 3, "T, U -> U") - assert reg.get_impl("op1", parse_dtypes("int", "int"))() == 1 - assert reg.get_impl("op1", parse_dtypes("int", "str"))() == 2 - # int can be promoted to float; results in "float, float -> bool" signature - assert reg.get_impl("op1", parse_dtypes("int", "float64"))() == 1 - assert reg.get_impl("op1", parse_dtypes("float64", "int"))() == 1 + assert reg.get_impl("op1", parse_dtypes("int64", "int64"))() == 1 + assert reg.get_impl("op1", parse_dtypes("int64", "str"))() == 2 + # int64 can be promoted to float; results in "float, float -> bool" signature + assert reg.get_impl("op1", parse_dtypes("int64", "float64"))() == 1 + assert reg.get_impl("op1", parse_dtypes("float64", "int64"))() == 1 # More template matching... Also check matching precedence - reg.add_impl(op2, lambda: 1, "int, int, int -> int") - reg.add_impl(op2, lambda: 2, "int, str, T -> int") - reg.add_impl(op2, lambda: 3, "int, T, str -> int") - reg.add_impl(op2, lambda: 4, "int, T, T -> int") - reg.add_impl(op2, lambda: 5, "T, T, T -> int") - reg.add_impl(op2, lambda: 6, "A, T, T -> int") - - assert reg.get_impl("op2", parse_dtypes("int", "int", "int"))() == 1 - assert reg.get_impl("op2", parse_dtypes("int", "str", "str"))() == 2 - assert reg.get_impl("op2", parse_dtypes("int", "int", "str"))() == 3 - assert reg.get_impl("op2", parse_dtypes("int", "bool", "bool"))() == 4 + reg.add_impl(op2, lambda: 1, "int64, int64, int64 -> int64") + reg.add_impl(op2, lambda: 2, "int64, str, T -> int64") + reg.add_impl(op2, lambda: 3, "int64, T, str -> int64") + reg.add_impl(op2, lambda: 4, "int64, T, T -> int64") + reg.add_impl(op2, lambda: 5, "T, T, T -> int64") + reg.add_impl(op2, lambda: 6, "A, T, T -> int64") + + assert reg.get_impl("op2", parse_dtypes("int64", "int64", "int64"))() == 1 + assert reg.get_impl("op2", parse_dtypes("int64", "str", "str"))() == 2 + assert reg.get_impl("op2", parse_dtypes("int64", "int64", "str"))() == 3 + assert reg.get_impl("op2", parse_dtypes("int64", "bool", "bool"))() == 4 assert reg.get_impl("op2", parse_dtypes("str", "str", "str"))() == 5 assert reg.get_impl("op2", parse_dtypes("float64", "str", "str"))() == 6 - with pytest.raises(ValueError): - reg.get_impl("op2", parse_dtypes("int", "bool", "float64")) + with pytest.raises(TypeError): + reg.get_impl("op2", parse_dtypes("int64", "bool", "float64")) # Return type reg.add_impl(op3, lambda: 1, "T -> T") - reg.add_impl(op3, lambda: 2, "int, T, U -> T") + reg.add_impl(op3, lambda: 2, "int64, T, U -> T") reg.add_impl(op3, lambda: 3, "str, T, U -> U") with pytest.raises(ValueError, match="already defined."): - reg.add_impl(op3, lambda: 4, "int, T, U -> U") + reg.add_impl(op3, lambda: 4, "int64, T, U -> U") assert isinstance( reg.get_impl("op3", parse_dtypes("str")).return_type, dtypes.String, ) assert isinstance( - reg.get_impl("op3", parse_dtypes("int")).return_type, + reg.get_impl("op3", parse_dtypes("int64")).return_type, dtypes.Int64, ) assert isinstance( - reg.get_impl("op3", parse_dtypes("int", "int", "float64")).return_type, + reg.get_impl("op3", parse_dtypes("int64", "int64", "float64")).return_type, dtypes.Int64, ) assert isinstance( - reg.get_impl("op3", parse_dtypes("str", "int", "float64")).return_type, + reg.get_impl("op3", parse_dtypes("str", "int64", "float64")).return_type, dtypes.Float64, ) @@ -204,25 +202,25 @@ def test_vararg(self): op1 = self.Op1() reg.register_op(op1) - reg.add_impl(op1, lambda: 1, "int... -> int") - reg.add_impl(op1, lambda: 2, "int, int... -> int") - reg.add_impl(op1, lambda: 3, "int, T... -> T") + reg.add_impl(op1, lambda: 1, "int64... -> int64") + reg.add_impl(op1, lambda: 2, "int64, int64... -> int64") + reg.add_impl(op1, lambda: 3, "int64, T... -> T") assert ( reg.get_impl( "op1", parse_dtypes( - "int", + "int64", ), )() == 1 ) - assert reg.get_impl("op1", parse_dtypes("int", "int"))() == 2 - assert reg.get_impl("op1", parse_dtypes("int", "int", "int"))() == 2 - assert reg.get_impl("op1", parse_dtypes("int", "str", "str"))() == 3 + assert reg.get_impl("op1", parse_dtypes("int64", "int64"))() == 2 + assert reg.get_impl("op1", parse_dtypes("int64", "int64", "int64"))() == 2 + assert reg.get_impl("op1", parse_dtypes("int64", "str", "str"))() == 3 assert isinstance( - reg.get_impl("op1", parse_dtypes("int", "str", "str")).return_type, + reg.get_impl("op1", parse_dtypes("int64", "str", "str")).return_type, dtypes.String, ) @@ -233,13 +231,13 @@ def test_variant(self): reg.register_op(op1) with pytest.raises(ValueError): - reg.add_impl(op1, lambda: 2, "-> int", variant="VAR") + reg.add_impl(op1, lambda: 2, "-> int64", variant="VAR") - reg.add_impl(op1, lambda: 1, "-> int") - reg.add_impl(op1, lambda: 2, "-> int", variant="VAR") + reg.add_impl(op1, lambda: 1, "-> int64") + reg.add_impl(op1, lambda: 2, "-> int64", variant="VAR") assert reg.get_impl("op1", tuple())() == 1 assert reg.get_impl("op1", tuple()).get_variant("VAR")() == 2 with pytest.raises(ValueError): - reg.add_impl(op1, lambda: 2, "-> int", variant="VAR") + reg.add_impl(op1, lambda: 2, "-> int64", variant="VAR")