diff --git a/src/pydiverse/transform/backend/duckdb.py b/src/pydiverse/transform/backend/duckdb.py index 5937e3e5..33f1fa78 100644 --- a/src/pydiverse/transform/backend/duckdb.py +++ b/src/pydiverse/transform/backend/duckdb.py @@ -1,12 +1,14 @@ from __future__ import annotations import polars as pl +import sqlalchemy as sqa from pydiverse.transform.backend import sql from pydiverse.transform.backend.sql import SqlImpl from pydiverse.transform.backend.targets import Polars, Target +from pydiverse.transform.tree import dtypes from pydiverse.transform.tree.ast import AstNode -from pydiverse.transform.tree.col_expr import Col +from pydiverse.transform.tree.col_expr import Cast, Col class DuckDbImpl(SqlImpl): @@ -21,3 +23,11 @@ def export(cls, nd: AstNode, target: Target, final_select: list[Col]): DuckDbImpl.build_query(nd, final_select), connection=conn ) return SqlImpl.export(nd, target, final_select) + + @classmethod + def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: + if cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.Int64: + return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast( + sqa.BigInteger() + ) + return super().compile_cast(cast, sqa_col) diff --git a/src/pydiverse/transform/backend/mssql.py b/src/pydiverse/transform/backend/mssql.py index e00a6309..b098bf5f 100644 --- a/src/pydiverse/transform/backend/mssql.py +++ b/src/pydiverse/transform/backend/mssql.py @@ -5,6 +5,7 @@ from typing import Any import sqlalchemy as sqa +from sqlalchemy.dialects.mssql import DATETIME2 from pydiverse.transform import ops from pydiverse.transform.backend import sql @@ -13,6 +14,7 @@ from pydiverse.transform.tree.ast import AstNode from pydiverse.transform.tree.col_expr import ( CaseExpr, + Cast, Col, ColExpr, ColFn, @@ -25,6 +27,37 @@ class MsSqlImpl(SqlImpl): dialect_name = "mssql" + INF = sqa.cast(sqa.literal("1.0"), type_=sqa.Float()) / sqa.literal( + "0.0", type_=sqa.Float() + ) + NEG_INF = -INF + NAN = INF + NEG_INF + + @classmethod + def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: + compiled_val = cls.compile_col_expr(cast.val, sqa_col) + if cast.val.dtype() == dtypes.String and cast.target_type == dtypes.Float64: + return sqa.case( + (compiled_val == "inf", cls.INF), + (compiled_val == "-inf", -cls.INF), + (compiled_val.in_(("nan", "-nan")), cls.NAN), + else_=sqa.cast( + compiled_val, + cls.sqa_type(cast.target_type), + ), + ) + + if cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.String: + compiled = sqa.cast(cls.compile_col_expr(cast.val, sqa_col), sqa.String) + return sqa.case( + (compiled == "1.#QNAN", "nan"), + (compiled == "1.#INF", "inf"), + (compiled == "-1.#INF", "-inf"), + else_=compiled, + ) + + return sqa.cast(compiled_val, cls.sqa_type(cast.target_type)) + @classmethod def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any: # boolean / bit conversion @@ -54,6 +87,13 @@ def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any: table, query, _ = cls.compile_ast(nd, {col._uuid: 1 for col in final_select}) return cls.compile_query(table, query) + @classmethod + def sqa_type(cls, t: dtypes.Dtype): + if isinstance(t, dtypes.DateTime): + return DATETIME2() + + return super().sqa_type(t) + def convert_order_list(order_list: list[Order]) -> list[Order]: new_list: list[Order] = [] @@ -93,7 +133,7 @@ def convert_bool_bit(expr: ColExpr | Order, wants_bool_as_bit: bool) -> ColExpr ) elif isinstance(expr, Col): - if not wants_bool_as_bit and isinstance(expr.dtype(), dtypes.Bool): + if not wants_bool_as_bit and expr.dtype() == dtypes.Bool: return ColFn("__eq__", expr, LiteralCol(True)) return expr @@ -146,6 +186,14 @@ def convert_bool_bit(expr: ColExpr | Order, wants_bool_as_bit: bool) -> ColExpr elif isinstance(expr, LiteralCol): return expr + elif isinstance(expr, Cast): + # TODO: does this really work for casting onto / from booleans? we probably have + # to use wants_bool_as_bit in some way when casting to bool + return Cast( + convert_bool_bit(expr.val, wants_bool_as_bit=wants_bool_as_bit), + expr.target_type, + ) + raise AssertionError @@ -289,3 +337,30 @@ def _day_of_week(x): @op.auto def _mean(x): return sqa.func.AVG(sqa.cast(x, sqa.Double()), type_=sqa.Double()) + + +with MsSqlImpl.op(ops.Log()) as op: + + @op.auto + def _log(x): + # TODO: we still need to handle inf / -inf / nan + return sqa.case( + (x > 0, sqa.func.log(x)), + (x < 0, MsSqlImpl.NAN), + (x.is_(sqa.null()), None), + else_=-MsSqlImpl.INF, + ) + + +with MsSqlImpl.op(ops.Ceil()) as op: + + @op.auto + def _ceil(x): + return sqa.func.ceiling(x) + + +with MsSqlImpl.op(ops.StrToDateTime()) as op: + + @op.auto + def _str_to_datetime(x): + return sqa.cast(x, DATETIME2) diff --git a/src/pydiverse/transform/backend/polars.py b/src/pydiverse/transform/backend/polars.py index c2ed6e8a..783e62f0 100644 --- a/src/pydiverse/transform/backend/polars.py +++ b/src/pydiverse/transform/backend/polars.py @@ -1,7 +1,5 @@ from __future__ import annotations -import datetime -from types import NoneType from typing import Any from uuid import UUID @@ -15,6 +13,7 @@ from pydiverse.transform.tree.ast import AstNode from pydiverse.transform.tree.col_expr import ( CaseExpr, + Cast, Col, ColExpr, ColFn, @@ -159,7 +158,7 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr: # the function was executed on the ordered arguments. here we # restore the original order of the table. - inv_permutation = pl.int_range(0, pl.len(), dtype=pl.Int64).sort_by( + inv_permutation = pl.int_range(0, pl.len(), dtype=pl.Int64()).sort_by( by=order_by, descending=descending, nulls_last=nulls_last, @@ -182,10 +181,20 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr: return compiled elif isinstance(expr, LiteralCol): - if isinstance(expr.dtype(), dtypes.String): + if expr.dtype() == dtypes.String: return pl.lit(expr.val) # polars interprets strings as column names return expr.val + elif isinstance(expr, Cast): + compiled = compile_col_expr(expr.val, name_in_df).cast( + pdt_type_to_polars(expr.target_type) + ) + + if expr.val.dtype() == dtypes.Float64 and expr.target_type == dtypes.String: + compiled = compiled.replace("NaN", "nan") + + return compiled + else: raise AssertionError @@ -340,9 +349,9 @@ def has_path_to_leaf_without_agg(expr: ColExpr): def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype: if t.is_float(): - return dtypes.Float() + return dtypes.Float64() elif t.is_integer(): - return dtypes.Int() + return dtypes.Int64() elif isinstance(t, pl.Boolean): return dtypes.Bool() elif isinstance(t, pl.String): @@ -360,9 +369,9 @@ def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype: def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType: - if isinstance(t, dtypes.Float): + if isinstance(t, (dtypes.Float64, dtypes.Decimal)): return pl.Float64() - elif isinstance(t, dtypes.Int): + elif isinstance(t, dtypes.Int64): return pl.Int64() elif isinstance(t, dtypes.Bool): return pl.Boolean() @@ -380,27 +389,6 @@ def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType: raise AssertionError -def python_type_to_polars(t: type) -> pl.DataType: - if t is int: - return pl.Int64() - elif t is float: - return pl.Float64() - elif t is bool: - return pl.Boolean() - elif t is str: - return pl.String() - elif t is datetime.datetime: - return pl.Datetime() - elif t is datetime.date: - return pl.Date() - elif t is datetime.timedelta: - return pl.Duration() - elif t is NoneType: - return pl.Null() - - raise TypeError(f"python builtin type {t} is not supported by pydiverse.transform") - - with PolarsImpl.op(ops.Mean()) as op: @op.auto @@ -709,3 +697,52 @@ def _greatest(*x): @op.auto def _least(*x): return pl.min_horizontal(*x) + + +with PolarsImpl.op(ops.Round()) as op: + + @op.auto + def _round(x, digits=0): + return x.round(digits) + + +with PolarsImpl.op(ops.Exp()) as op: + + @op.auto + def _exp(x): + return x.exp() + + +with PolarsImpl.op(ops.Log()) as op: + + @op.auto + def _log(x): + return x.log() + + +with PolarsImpl.op(ops.Floor()) as op: + + @op.auto + def _floor(x): + return x.floor() + + +with PolarsImpl.op(ops.Ceil()) as op: + + @op.auto + def _ceil(x): + return x.ceil() + + +with PolarsImpl.op(ops.StrToDateTime()) as op: + + @op.auto + def _str_to_datetime(x): + return x.str.to_datetime() + + +with PolarsImpl.op(ops.StrToDate()) as op: + + @op.auto + def _str_to_date(x): + return x.str.to_date() diff --git a/src/pydiverse/transform/backend/postgres.py b/src/pydiverse/transform/backend/postgres.py index b83a32db..183ad3e0 100644 --- a/src/pydiverse/transform/backend/postgres.py +++ b/src/pydiverse/transform/backend/postgres.py @@ -4,11 +4,32 @@ from pydiverse.transform import ops from pydiverse.transform.backend.sql import SqlImpl +from pydiverse.transform.tree import dtypes +from pydiverse.transform.tree.col_expr import Cast class PostgresImpl(SqlImpl): dialect_name = "postgresql" + @classmethod + def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast: + compiled_val = cls.compile_col_expr(cast.val, sqa_col) + + if isinstance(cast.val.dtype(), dtypes.Float64): + if isinstance(cast.target_type, dtypes.Int64): + return sqa.func.trunc(compiled_val).cast(sqa.BigInteger()) + + if isinstance(cast.target_type, dtypes.String): + compiled = sqa.cast(compiled_val, sqa.String) + return sqa.case( + (compiled == "NaN", "nan"), + (compiled == "Infinity", "inf"), + (compiled == "-Infinity", "-inf"), + else_=compiled, + ) + + return sqa.cast(compiled_val, cls.sqa_type(cast.target_type)) + with PostgresImpl.op(ops.Less()) as op: diff --git a/src/pydiverse/transform/backend/sql.py b/src/pydiverse/transform/backend/sql.py index 05be684c..8bae93c0 100644 --- a/src/pydiverse/transform/backend/sql.py +++ b/src/pydiverse/transform/backend/sql.py @@ -4,6 +4,7 @@ import functools import inspect import itertools +import math import operator from collections.abc import Iterable from typing import Any @@ -21,6 +22,7 @@ from pydiverse.transform.tree.ast import AstNode from pydiverse.transform.tree.col_expr import ( CaseExpr, + Cast, Col, ColExpr, ColFn, @@ -33,6 +35,10 @@ class SqlImpl(TableImpl): Dialects: dict[str, type[TableImpl]] = {} + INF = sqa.cast(sqa.literal("inf"), sqa.Float()) + NEG_INF = sqa.cast(sqa.literal("-inf"), sqa.Float()) + NAN = sqa.cast(sqa.literal("nan"), sqa.Float()) + def __new__(cls, *args, **kwargs) -> SqlImpl: engine: str | sqa.Engine = ( inspect.signature(cls.__init__) @@ -69,7 +75,7 @@ def __init__(self, table: str | sqa.Table, conf: SqlAlchemy, name: str | None): super().__init__( name, - {col.name: sqa_type_to_pdt(col.type) for col in self.table.columns}, + {col.name: self.pdt_type(col.type) for col in self.table.columns}, ) def __init_subclass__(cls, **kwargs): @@ -80,7 +86,7 @@ def col_names(self) -> list[str]: return [col.name for col in self.table.columns] def schema(self) -> dict[str, Dtype]: - return {col.name: sqa_type_to_pdt(col.type) for col in self.table.columns} + return {col.name: self.pdt_type(col.type) for col in self.table.columns} def _clone(self) -> tuple[SqlImpl, dict[AstNode, AstNode], dict[UUID, UUID]]: cloned = self.__class__(self.table, SqlAlchemy(self.engine), self.name) @@ -140,6 +146,12 @@ def compile_order( ) return order_expr + @classmethod + def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: + return cls.compile_col_expr(cast.val, sqa_col).cast( + cls.sqa_type(cast.target_type) + ) + @classmethod def compile_col_expr( cls, expr: ColExpr, sqa_col: dict[str, sqa.Label] @@ -205,8 +217,16 @@ def compile_col_expr( ) elif isinstance(expr, LiteralCol): + if isinstance(expr.val, float): + if math.isnan(expr.val): + return cls.NAN + elif math.isinf(expr.val): + return cls.INF if expr.val > 0 else cls.NEG_INF return expr.val + elif isinstance(expr, Cast): + return cls.compile_cast(expr, sqa_col) + raise AssertionError @classmethod @@ -467,6 +487,52 @@ def compile_ast( return table, query, sqa_col + @classmethod + def sqa_type(cls, t: Dtype) -> sqa.types.TypeEngine: + if isinstance(t, dtypes.Int64): + return sqa.BigInteger() + elif isinstance(t, dtypes.Float64): + return sqa.Double() + elif isinstance(t, dtypes.Decimal): + return sqa.DECIMAL() + elif isinstance(t, dtypes.String): + return sqa.String() + elif isinstance(t, dtypes.Bool): + return sqa.Boolean() + elif isinstance(t, dtypes.DateTime): + return sqa.DateTime() + elif isinstance(t, dtypes.Date): + return sqa.Date() + elif isinstance(t, dtypes.Duration): + return sqa.Interval() + elif isinstance(t, dtypes.NoneDtype): + return sqa.types.NullType() + + raise AssertionError + + @classmethod + def pdt_type(cls, t: sqa.types.TypeEngine) -> Dtype: + if isinstance(t, sqa.Integer): + return dtypes.Int64() + elif isinstance(t, sqa.Float): + return dtypes.Float64() + elif isinstance(t, (sqa.DECIMAL, sqa.NUMERIC)): + return dtypes.Decimal() + elif isinstance(t, sqa.String): + return dtypes.String() + elif isinstance(t, sqa.Boolean): + return dtypes.Bool() + elif isinstance(t, sqa.DateTime): + return dtypes.DateTime() + elif isinstance(t, sqa.Date): + return dtypes.Date() + elif isinstance(t, sqa.Interval): + return dtypes.Duration() + elif isinstance(t, sqa.Null): + return dtypes.NoneDtype() + + raise TypeError(f"SQLAlchemy type {t} not supported by pydiverse.transform") + @dataclasses.dataclass(slots=True) class Query: @@ -546,48 +612,6 @@ def get_engine(nd: AstNode) -> sqa.Engine: return engine -def sqa_type_to_pdt(t: sqa.types.TypeEngine) -> Dtype: - if isinstance(t, sqa.Integer): - return dtypes.Int() - elif isinstance(t, sqa.Numeric): - return dtypes.Float() - elif isinstance(t, sqa.String): - return dtypes.String() - elif isinstance(t, sqa.Boolean): - return dtypes.Bool() - elif isinstance(t, sqa.DateTime): - return dtypes.DateTime() - elif isinstance(t, sqa.Date): - return dtypes.Date() - elif isinstance(t, sqa.Interval): - return dtypes.Duration() - elif isinstance(t, sqa.Null): - return dtypes.NoneDtype() - - raise TypeError(f"SQLAlchemy type {t} not supported by pydiverse.transform") - - -def pdt_type_to_sqa(t: Dtype) -> sqa.types.TypeEngine: - if isinstance(t, dtypes.Int): - return sqa.Integer() - elif isinstance(t, dtypes.Float): - return sqa.Numeric() - elif isinstance(t, dtypes.String): - return sqa.String() - elif isinstance(t, dtypes.Bool): - return sqa.Boolean() - elif isinstance(t, dtypes.DateTime): - return sqa.DateTime() - elif isinstance(t, dtypes.Date): - return sqa.Date() - elif isinstance(t, dtypes.Duration): - return sqa.Interval() - elif isinstance(t, dtypes.NoneDtype): - return sqa.types.NullType() - - raise AssertionError - - with SqlImpl.op(ops.FloorDiv(), check_super=False) as op: if sqa.__version__ < "2": @@ -686,9 +710,6 @@ def _is_not_null(x): return x.is_not(sqa.null()) -#### String Functions #### - - with SqlImpl.op(ops.StrStrip()) as op: @op.auto @@ -753,9 +774,6 @@ def _str_slice(x, offset, length): return sqa.func.SUBSTR(x, offset + 1, length) -#### Datetime Functions #### - - with SqlImpl.op(ops.DtYear()) as op: @op.auto @@ -819,9 +837,6 @@ def _day_of_year(x): return sqa.extract("doy", x) -#### Generic Functions #### - - with SqlImpl.op(ops.Greatest()) as op: @op.auto @@ -838,9 +853,6 @@ def _least(*x): return sqa.func.LEAST(*x) -#### Summarising Functions #### - - with SqlImpl.op(ops.Mean()) as op: @op.auto @@ -919,9 +931,6 @@ def _count(x=None): return sqa.func.count(x) -#### Window Functions #### - - with SqlImpl.op(ops.Shift()) as op: @op.auto @@ -968,3 +977,51 @@ def _rank(): @op.auto def _dense_rank(): return sqa.func.dense_rank() + + +with SqlImpl.op(ops.Exp()) as op: + + @op.auto + def _exp(x): + return sqa.func.exp(x) + + +with SqlImpl.op(ops.Log()) as op: + + @op.auto + def _log(x): + # TODO: we still need to handle inf / -inf / nan + return sqa.case( + (x > 0, sqa.func.ln(x)), + (x < 0, sqa.literal("nan")), + (x.is_(sqa.null()), None), + else_=sqa.literal("-inf"), + ) + + +with SqlImpl.op(ops.Floor()) as op: + + @op.auto + def _floor(x): + return sqa.func.floor(x) + + +with SqlImpl.op(ops.Ceil()) as op: + + @op.auto + def _ceil(x): + return sqa.func.ceil(x) + + +with SqlImpl.op(ops.StrToDateTime()) as op: + + @op.auto + def _str_to_datetime(x): + return sqa.cast(x, sqa.DateTime) + + +with SqlImpl.op(ops.StrToDate()) as op: + + @op.auto + def _str_to_datetime(x): + return sqa.cast(x, sqa.Date) diff --git a/src/pydiverse/transform/backend/sqlite.py b/src/pydiverse/transform/backend/sqlite.py index 1dc14f07..a39297ed 100644 --- a/src/pydiverse/transform/backend/sqlite.py +++ b/src/pydiverse/transform/backend/sqlite.py @@ -4,12 +4,45 @@ from pydiverse.transform import ops from pydiverse.transform.backend.sql import SqlImpl +from pydiverse.transform.tree import dtypes +from pydiverse.transform.tree.col_expr import Cast from pydiverse.transform.util.warnings import warn_non_standard class SqliteImpl(SqlImpl): dialect_name = "sqlite" + INF = sqa.cast(sqa.literal("1e314"), sqa.Float) + NEG_INF = -INF + NAN = sqa.null() + + @classmethod + def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast: + compiled_val = cls.compile_col_expr(cast.val, sqa_col) + + if cast.val.dtype() == dtypes.String and cast.target_type == dtypes.Float64: + return sqa.case( + (compiled_val == "inf", cls.INF), + (compiled_val == "-inf", cls.NEG_INF), + (compiled_val.in_(("nan", "-nan")), cls.NAN), + else_=sqa.cast( + compiled_val, + cls.sqa_type(cast.target_type), + ), + ) + + elif cast.val.dtype() == dtypes.DateTime and cast.target_type == dtypes.Date: + return sqa.type_coerce(sqa.func.date(compiled_val), sqa.Date()) + + elif cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.String: + return sqa.case( + (compiled_val == cls.INF, "inf"), + (compiled_val == cls.NEG_INF, "-inf"), + else_=sqa.cast(compiled_val, sqa.String), + ) + + return sqa.cast(compiled_val, cls.sqa_type(cast.target_type)) + with SqliteImpl.op(ops.Round()) as op: @@ -101,3 +134,18 @@ def _least(*x): # TODO: Determine return type return sqa.func.coalesce(sqa.func.MIN(left, right), left, right) + + +# TODO: we need to get the string in the right format here (so sqlite can work with it) +with SqliteImpl.op(ops.StrToDateTime()) as op: + + @op.auto + def _str_to_datetime(x): + return sqa.type_coerce(x, sqa.DateTime) + + +with SqliteImpl.op(ops.StrToDate()) as op: + + @op.auto + def _str_to_datetime(x): + return sqa.type_coerce(x, sqa.Date) diff --git a/src/pydiverse/transform/ops/aggregate.py b/src/pydiverse/transform/ops/aggregate.py index 2261142e..7da557f6 100644 --- a/src/pydiverse/transform/ops/aggregate.py +++ b/src/pydiverse/transform/ops/aggregate.py @@ -16,8 +16,8 @@ class Min(Aggregate, Unary): name = "min" signatures = [ - "int -> int", - "float -> float", + "int64 -> int64", + "float64 -> float64", "str -> str", "datetime -> datetime", "date -> date", @@ -27,8 +27,8 @@ class Min(Aggregate, Unary): class Max(Aggregate, Unary): name = "max" signatures = [ - "int -> int", - "float -> float", + "int64 -> int64", + "float64 -> float64", "str -> str", "datetime -> datetime", "date -> date", @@ -38,16 +38,16 @@ class Max(Aggregate, Unary): class Mean(Aggregate, Unary): name = "mean" signatures = [ - "int -> float", - "float -> float", + "int64 -> float64", + "float64 -> float64", ] class Sum(Aggregate, Unary): name = "sum" signatures = [ - "int -> int", - "float -> float", + "int64 -> int64", + "float64 -> float64", ] @@ -68,6 +68,6 @@ class All(Aggregate, Unary): class Count(Aggregate): name = "count" signatures = [ - "-> int", - "T -> int", + "-> int64", + "T -> int64", ] diff --git a/src/pydiverse/transform/ops/datetime.py b/src/pydiverse/transform/ops/datetime.py index e7788791..4ee9c294 100644 --- a/src/pydiverse/transform/ops/datetime.py +++ b/src/pydiverse/transform/ops/datetime.py @@ -26,11 +26,11 @@ class DtExtract(ElementWise, Unary): - signatures = ["datetime -> int"] + signatures = ["datetime -> int64"] class DateExtract(ElementWise, Unary): - signatures = ["datetime -> int", "date -> int"] + signatures = ["datetime -> int64", "date -> int64"] class DtYear(DateExtract): @@ -70,7 +70,7 @@ class DtDayOfYear(DateExtract): class DurationToUnit(ElementWise, Unary): - signatures = ["duration -> int"] + signatures = ["duration -> int64"] class DtDays(DurationToUnit): diff --git a/src/pydiverse/transform/ops/logical.py b/src/pydiverse/transform/ops/logical.py index f4c1ef9c..11a5c814 100644 --- a/src/pydiverse/transform/ops/logical.py +++ b/src/pydiverse/transform/ops/logical.py @@ -35,8 +35,8 @@ class Logical(Operator): class Comparison(ElementWise, Binary, Logical): signatures = [ - "int, int -> bool", - "float, float -> bool", + "int64, int64 -> bool", + "float64, float64 -> bool", "str, str -> bool", "bool, bool -> bool", "datetime, datetime -> bool", diff --git a/src/pydiverse/transform/ops/numeric.py b/src/pydiverse/transform/ops/numeric.py index 4c0c4350..fc12a529 100644 --- a/src/pydiverse/transform/ops/numeric.py +++ b/src/pydiverse/transform/ops/numeric.py @@ -21,16 +21,19 @@ "Pos", "Abs", "Round", + "Floor", + "Ceil", + "Exp", + "Log", ] class Add(ElementWise, Binary): name = "__add__" signatures = [ - "int, int -> int", - "int, float -> float", - "float, int -> float", - "float, float -> float", + "int64, int64 -> int64", + "float64, float64 -> float64", + "decimal, decimal -> decimal", ] @@ -41,10 +44,9 @@ class RAdd(Add): class Sub(ElementWise, Binary): name = "__sub__" signatures = [ - "int, int -> int", - "int, float -> float", - "float, int -> float", - "float, float -> float", + "int64, int64 -> int64", + "float64, float64 -> float64", + "decimal, decimal -> decimal", ] @@ -55,10 +57,9 @@ class RSub(Sub): class Mul(ElementWise, Binary): name = "__mul__" signatures = [ - "int, int -> int", - "int, float -> float", - "float, int -> float", - "float, float -> float", + "int64, int64 -> int64", + "float64, float64 -> float64", + "decimal, decimal -> decimal", ] @@ -69,8 +70,9 @@ class RMul(Mul): class TrueDiv(ElementWise, Binary): name = "__truediv__" signatures = [ - "int, int -> float", - "float, float -> float", + "int64, int64 -> float64", + "float64, float64 -> float64", + "decimal, decimal -> decimal", ] @@ -81,7 +83,7 @@ class RTrueDiv(TrueDiv): class FloorDiv(ElementWise, Binary): name = "__floordiv__" signatures = [ - "int, int -> int", + "int64, int64 -> int64", ] @@ -92,8 +94,9 @@ class RFloorDiv(FloorDiv): class Pow(ElementWise, Binary): name = "__pow__" signatures = [ - "int, int -> float", - "float, float -> float", + "int64, int64 -> float64", + "float64, float64 -> float64", + "decimal, decimal -> decimal", ] @@ -104,7 +107,7 @@ class RPow(Pow): class Mod(ElementWise, Binary): name = "__mod__" signatures = [ - "int, int -> int", + "int64, int64 -> int64", ] @@ -115,32 +118,59 @@ class RMod(Mod): class Neg(ElementWise, Unary): name = "__neg__" signatures = [ - "int -> int", - "float -> float", + "int64 -> int64", + "float64 -> float64", + "decimal -> decimal", ] class Pos(ElementWise, Unary): name = "__pos__" signatures = [ - "int -> int", - "float -> float", + "int64 -> int64", + "float64 -> float64", + "decimal -> decimal", ] class Abs(ElementWise, Unary): name = "__abs__" signatures = [ - "int -> int", - "float -> float", + "int64 -> int64", + "float64 -> float64", + "decimal -> decimal", ] class Round(ElementWise): name = "__round__" signatures = [ - "int -> int", - "int, const int -> int", - "float -> float", - "float, const int -> float", + "int64 -> int64", + "int64, const int64 -> int64", + "float64 -> float64", + "float64, const int64 -> float64", + "decimal -> decimal", + "decimal, const int64 -> decimal", ] + + +class Floor(ElementWise): + name = "floor" + signatures = [ + "float64 -> float64", + "decimal -> decimal", + ] + + +class Ceil(Floor): + name = "ceil" + + +class Log(ElementWise): + name = "log" + signatures = ["float64 -> float64"] + + +class Exp(Log): + name = "exp" + signatures = ["float64 -> float64"] diff --git a/src/pydiverse/transform/ops/string.py b/src/pydiverse/transform/ops/string.py index fdd4ef5b..7fecf1bb 100644 --- a/src/pydiverse/transform/ops/string.py +++ b/src/pydiverse/transform/ops/string.py @@ -16,6 +16,8 @@ "StrEndsWith", "StrContains", "StrSlice", + "StrToDateTime", + "StrToDate", ] @@ -49,7 +51,7 @@ class StrStrip(StrUnary): class StrLen(StrUnary): name = "str.len" signatures = [ - "str -> int", + "str -> int64", ] @@ -91,4 +93,14 @@ class StrContains(ElementWise, Logical): class StrSlice(ElementWise): name = "str.slice" - signatures = ["str, int, int -> str"] + signatures = ["str, int64, int64 -> str"] + + +class StrToDateTime(ElementWise): + name = "str.to_datetime" + signatures = ["str -> datetime"] + + +class StrToDate(ElementWise): + name = "str.to_date" + signatures = ["str -> date"] diff --git a/src/pydiverse/transform/ops/window.py b/src/pydiverse/transform/ops/window.py index 65ccf052..cf3a1af2 100644 --- a/src/pydiverse/transform/ops/window.py +++ b/src/pydiverse/transform/ops/window.py @@ -13,27 +13,27 @@ class Shift(Window): name = "shift" signatures = [ - "T, const int -> T", - "T, const int, const T -> T", + "T, const int64 -> T", + "T, const int64, const T -> T", ] class RowNumber(Window, Nullary): name = "row_number" signatures = [ - "-> int", + "-> int64", ] class Rank(Window, Nullary): name = "rank" signatures = [ - "-> int", + "-> int64", ] class DenseRank(Window, Nullary): name = "dense_rank" signatures = [ - "-> int", + "-> int64", ] diff --git a/src/pydiverse/transform/pipe/functions.py b/src/pydiverse/transform/pipe/functions.py index 6837ae0d..2fa28b8c 100644 --- a/src/pydiverse/transform/pipe/functions.py +++ b/src/pydiverse/transform/pipe/functions.py @@ -18,7 +18,7 @@ def clean_kwargs(**kwargs) -> dict[str, list[ColExpr]]: def when(condition: ColExpr) -> WhenClause: - if condition.dtype() is not None and not isinstance(condition.dtype(), dtypes.Bool): + if condition.dtype() is not None and condition.dtype() != dtypes.Bool: raise TypeError( "argument for `when` must be of boolean type, but has type " f"`{condition.dtype()}`" diff --git a/src/pydiverse/transform/pipe/table.py b/src/pydiverse/transform/pipe/table.py index 503d1dbb..1b58c7c3 100644 --- a/src/pydiverse/transform/pipe/table.py +++ b/src/pydiverse/transform/pipe/table.py @@ -2,22 +2,18 @@ import copy import dataclasses -from collections.abc import Iterable +import inspect +from collections.abc import Callable, Iterable from html import escape import sqlalchemy as sqa from pydiverse.transform.backend.table_impl import TableImpl +from pydiverse.transform.pipe.pipeable import Pipeable from pydiverse.transform.tree.ast import AstNode -from pydiverse.transform.tree.col_expr import ( - Col, - ColExpr, -) +from pydiverse.transform.tree.col_expr import Col -# TODO: if we decide that select controls the C-space, the columns in _select will -# always be the same as those that we have to keep in _schema. However, we still need -# _select for the order. class Table: __slots__ = ["_ast", "_cache"] @@ -26,7 +22,8 @@ class Table: which is a reference to the underlying abstract syntax tree. """ - # TODO: define exactly what can be given for the two + # TODO: define exactly what can be given for the two and do type checks + # maybe call the second one execution_engine or similar? def __init__(self, resource, backend=None, *, name: str | None = None): import polars as pl @@ -40,6 +37,8 @@ def __init__(self, resource, backend=None, *, name: str | None = None): self._ast: AstNode = resource elif isinstance(resource, (pl.DataFrame, pl.LazyFrame)): if name is None: + # TODO: we could look whether the df has a name attr set (which is the + # case if it was previously exported) name = "?" self._ast = PolarsImpl(name, resource) elif isinstance(resource, (str, sqa.Table)): @@ -77,13 +76,32 @@ def __setstate__(self, d): # to avoid very annoying AttributeErrors for slot, val in d[1].items(): setattr(self, slot, val) - def __iter__(self) -> Iterable[ColExpr]: + def __iter__(self) -> Iterable[Col]: cols = copy.copy(self._cache.select) yield from cols def __len__(self) -> int: return len(self._cache.select) + def __rshift__(self, rhs): + if isinstance(rhs, Pipeable): + return rhs.__rrshift__(self) + if isinstance(rhs, Callable): + num_params = len(inspect.signature(rhs).parameters) + if num_params != 1: + raise TypeError( + "only functions with one parameter can be used in a pipe, got " + f"function with {num_params} parameters." + ) + return rhs(self) + + raise TypeError( + f"found instance of invalid type `{type(rhs)}` in the pipe. \n" + "hint: You can use a `Table` or a Callable taking a single argument in a " + "pipe. If you use a Callable, it will receive the current table as an " + "and must return a `Table`." + ) + def __str__(self): try: from pydiverse.transform.backend.targets import Polars diff --git a/src/pydiverse/transform/pipe/verbs.py b/src/pydiverse/transform/pipe/verbs.py index 8348645e..9500a6e1 100644 --- a/src/pydiverse/transform/pipe/verbs.py +++ b/src/pydiverse/transform/pipe/verbs.py @@ -198,7 +198,7 @@ def filter(table: Table, *predicates: ColExpr): new._ast = Filter(table._ast, preprocess_arg(predicates, table)) for cond in new._ast.filters: - if not isinstance(cond.dtype(), dtypes.Bool): + if cond.dtype() != dtypes.Bool: raise TypeError( "predicates given to `filter` must be of boolean type.\n" f"hint: {cond} is of type {cond.dtype()} instead." diff --git a/src/pydiverse/transform/tree/col_expr.py b/src/pydiverse/transform/tree/col_expr.py index 5ad49d05..c56a0f40 100644 --- a/src/pydiverse/transform/tree/col_expr.py +++ b/src/pydiverse/transform/tree/col_expr.py @@ -115,7 +115,7 @@ def __str__(self) -> str: ) def __hash__(self) -> int: - return hash(self.uuid) + return hash(self._uuid) class ColName(ColExpr): @@ -296,6 +296,11 @@ def __getattr__(self, name) -> FnAttr: return FnAttr(f"{self.name}.{name}", self.arg) def __call__(self, *args, **kwargs) -> ColExpr: + if self.name == "cast": + if len(kwargs) > 0: + raise ValueError("`cast` does not take any keyword arguments") + return Cast(self.arg, *args) + return ColFn( self.name, wrap_literal(self.arg), @@ -379,7 +384,7 @@ def dtype(self): raise TypeError(f"invalid case expression: {e}") from e for cond, _ in self.cases: - if cond.dtype() is not None and not isinstance(cond.dtype(), dtypes.Bool): + if cond.dtype() is not None and cond.dtype() != dtypes.Bool: raise TypeError( f"argument `{cond}` for `when` must be of boolean type, but has " f"type `{cond.dtype()}`" @@ -438,6 +443,64 @@ def otherwise(self, value: ColExpr) -> CaseExpr: return CaseExpr(self.cases, wrap_literal(value)) +class Cast(ColExpr): + __slots__ = ["val", "target_type"] + + def __init__(self, val: ColExpr, target_type: Dtype): + self.val = val + self.target_type = target_type + super().__init__(target_type) + self.dtype() + + def dtype(self) -> Dtype: + # Since `ColExpr.dtype` is also responsible for type checking, we may not set + # `_dtype` until we are able to retrieve the type of `val`. + if self.val.dtype() is None: + return None + + if not self.val.dtype().can_promote_to(self.target_type): + valid_casts = { + (dtypes.String, dtypes.Int64), + (dtypes.String, dtypes.Float64), + (dtypes.Float64, dtypes.Int64), + (dtypes.DateTime, dtypes.Date), + (dtypes.Int64, dtypes.String), + (dtypes.Float64, dtypes.String), + (dtypes.DateTime, dtypes.String), + (dtypes.Date, dtypes.String), + } + + if ( + self.val.dtype().__class__, + self.target_type.__class__, + ) not in valid_casts: + hint = "" + if self.val.dtype() == dtypes.String and self.target_type in ( + dtypes.DateTime, + dtypes.Date, + ): + hint = ( + "\nhint: to convert a str to datetime, call " + f"`.str.to_{self.target_type.name}()` on the expression." + ) + + raise TypeError( + f"cannot cast type {self.val.dtype()} to {self.target_type}." + f"{hint}" + ) + + return self._dtype + + def ftype(self, *, agg_is_window: bool) -> Ftype: + return self.val.ftype(agg_is_window=agg_is_window) + + def iter_children(self) -> Iterable[ColExpr]: + yield self.val + + def map_subtree(self, g: Callable[[ColExpr], ColExpr]) -> ColExpr: + return g(Cast(self.val.map_subtree(g), self.target_type)) + + @dataclasses.dataclass(slots=True) class Order: order_by: ColExpr diff --git a/src/pydiverse/transform/tree/dtypes.py b/src/pydiverse/transform/tree/dtypes.py index be06ec35..9f5152a3 100644 --- a/src/pydiverse/transform/tree/dtypes.py +++ b/src/pydiverse/transform/tree/dtypes.py @@ -13,17 +13,22 @@ def __init__(self, *, const: bool = False, vararg: bool = False): self.const = const self.vararg = vararg - def __eq__(self, other): - if type(self) is not type(other): + def __eq__(self, rhs): + if type(self) is rhs: + return True + if type(self) is not type(rhs): return False - if self.const != other.const: + if self.const != rhs.const: return False - if self.vararg != other.vararg: + if self.vararg != rhs.vararg: return False - if self.name != other.name: + if self.name != rhs.name: return False return True + def __ne__(self, rhs: object) -> bool: + return not self.__eq__(rhs) + def __hash__(self): return hash((self.name, self.const, self.vararg, type(self).__qualname__)) @@ -68,15 +73,15 @@ def can_promote_to(self, other: Dtype) -> bool: return other.same_kind(self) -class Int(Dtype): - name = "int" +class Int64(Dtype): + name = "int64" def can_promote_to(self, other: Dtype) -> bool: if super().can_promote_to(other): return True - # int can be promoted to float - if Float().same_kind(other): + # int64 can be promoted to float64 + if Float64().same_kind(other): if other.const and not self.const: return False @@ -85,8 +90,12 @@ def can_promote_to(self, other: Dtype) -> bool: return False -class Float(Dtype): - name = "float" +class Float64(Dtype): + name = "float64" + + +class Decimal(Dtype): + name = "decimal" class String(Dtype): @@ -142,9 +151,9 @@ class NoneDtype(Dtype): def python_type_to_pdt(t: type) -> Dtype: if t is int: - return Int() + return Int64() elif t is float: - return Float() + return Float64() elif t is bool: return Bool() elif t is str: @@ -190,10 +199,12 @@ def dtype_from_string(t: str) -> Dtype: if is_template: return Template(base_type, const=is_const, vararg=is_vararg) - if base_type == "int": - return Int(const=is_const, vararg=is_vararg) - if base_type == "float": - return Float(const=is_const, vararg=is_vararg) + if base_type == "int64": + return Int64(const=is_const, vararg=is_vararg) + if base_type == "float64": + return Float64(const=is_const, vararg=is_vararg) + if base_type == "decimal": + return Decimal(const=is_const, vararg=is_vararg) if base_type == "str": return String(const=is_const, vararg=is_vararg) if base_type == "bool": diff --git a/src/pydiverse/transform/tree/registry.py b/src/pydiverse/transform/tree/registry.py index 023d9ed5..70637f6e 100644 --- a/src/pydiverse/transform/tree/registry.py +++ b/src/pydiverse/transform/tree/registry.py @@ -296,21 +296,21 @@ class OperatorSignature: terminal_arg ::= modifiers (dtype | vararg) vararg ::= dtype "..." rtype ::= dtype - dtype ::= template | "int" | "float" | "str" | "bool" | and others... + dtype ::= template | "int64" | "float64" | "str" | "bool" | and others... modifiers ::= "const"? template ::= single uppercase character Examples: Function that takes two integers and returns an integer: - int, int -> int + int64, int64 -> int64 Templated argument (templates consist of single uppercase characters): T, T -> T T, U -> bool Variable number of arguments: - int... -> int + int64... -> int64 """ @@ -523,7 +523,7 @@ def does_match( elif not node.value.same_kind(dtype): # Needs type promotion # This only works when types can be promoted once - # -> (uint > int) wouldn't be preferred over (uint > int > float) + # -> (uint > int64) wouldn't be preferred over (uint > int64 > float64) type_promotion_indices = (*type_promotion_indices, s_i) if s_i + 1 == len(signature): diff --git a/tests/test_backend_equivalence/conftest.py b/tests/test_backend_equivalence/conftest.py index da619439..59673f4b 100644 --- a/tests/test_backend_equivalence/conftest.py +++ b/tests/test_backend_equivalence/conftest.py @@ -69,6 +69,32 @@ None, "% _.AbAbAb", ], + "c": [ + "4352.0", + "-21", + "-nan", + "3.313", + None, + "-inf", + "inf", + "nan", + "-0.000", + "-0.0", + "0.0", + ], + "d": [ + None, + "-123124", + "21241", + "010101", + "0", + "1", + "-12", + "42", + "197", + "1729", + "-100110", + ], } ), "df_datetime": pl.DataFrame( @@ -119,6 +145,46 @@ # ], } ), + "df_num": pl.DataFrame( + { + "a": [0.4, -1.1, -0.0, 0.0, 9.0, 2.0, -344.0053, -1000.0], + "b": [None, 2, 0, -11, 4, 19, -5190, 2000000], + "c": [0.0, None, None, 2.9, -0.0, 10.0, -10.0, 3.1415926535], + "d": [None, 2352.0230, 0.577, 901234, -6.0, 4.0, None, -99.0], + "e": [1.0, 2.0, 3.0, 4.99, -442.0, 6.0, 7.0, 500.0], + "f": [3.0, None, 0.0, 4.3, 10.0, -1.2, -9999.1, -34.1], + "g": [-5.5, None, None, 1.100212, -3.412351, 1000.4252, 0.0, -1.6], + "zero": [0.0, -0.0] * 4, + "pos": [ + 1.123, + 1297.324, + 9192.9793, + 7.5, + 912.097, + 32.9, + 2.712834, + 5002352.434, + ], + "neg": [ + -9623.1, + -0.1, + -1.0, + -923737552.5, + -5.5, + -0.12083, + -93.4, + -6699917733.1242, + ], + "null_s": [0.0, None, None, None, None, None, None, None], + } + ), + "df_int": pl.DataFrame( + { + "a": [3, 1, 0, -12, 4, 5, 1 << 20, 5], + "b": [-23, 18282, -42, 1729, None, -2323, 11, 1], + "null_s": [0] + [None] * 7, + } + ), } # compare one dataframe and one SQL backend to all others diff --git a/tests/test_backend_equivalence/test_ops/test_cast.py b/tests/test_backend_equivalence/test_ops/test_cast.py new file mode 100644 index 00000000..019edc86 --- /dev/null +++ b/tests/test_backend_equivalence/test_ops/test_cast.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import pydiverse.transform as pdt +from pydiverse.transform.pipe.c import C +from pydiverse.transform.pipe.verbs import mutate +from tests.test_backend_equivalence.test_ops.test_ops_numerical import add_nan_inf_cols +from tests.util.assertion import assert_result_equal + + +def test_string_to_float(df_strings): + assert_result_equal( + df_strings, + lambda t: t >> mutate(u=t.c.cast(pdt.Float64())), + ) + + +def test_string_to_int(df_strings): + assert_result_equal( + df_strings, + lambda t: t >> mutate(u=t.d.cast(pdt.Int64())), + ) + + +def test_float_to_int(df_num): + assert_result_equal( + df_num, + lambda t: t >> mutate(**{col.name: col.cast(pdt.Int64()) for col in t}), + ) + + assert_result_equal( + df_num, + lambda t: t >> add_nan_inf_cols() >> mutate(u=C.inf.cast(pdt.Int64())), + exception=Exception, + may_throw=True, + ) + assert_result_equal( + df_num, + lambda t: t >> add_nan_inf_cols() >> mutate(u=C.nan.cast(pdt.Int64())), + exception=Exception, + may_throw=True, + ) + + +def test_datetime_to_date(df_datetime): + assert_result_equal( + df_datetime, + lambda t: t >> mutate(u=t.col1.cast(pdt.Date()), v=t.col2.cast(pdt.Date())), + ) + + +def test_int_to_string(df_int): + assert_result_equal( + df_int, lambda t: t >> mutate(**{c.name: c.cast(pdt.String()) for c in t}) + ) + + +def test_float_to_string(df_num): + assert_result_equal( + df_num, + lambda t: t + >> add_nan_inf_cols() + >> (lambda s: s >> mutate(**{c.name: c.cast(pdt.String()) for c in s})) + >> (lambda s: s >> mutate(**{c.name: c.cast(pdt.Float64()) for c in s})), + ) + + +def test_datetime_to_string(df_datetime): + assert_result_equal( + df_datetime, + lambda t: t + >> mutate( + x=t.col1.cast(pdt.String()), + y=t.col2.cast(pdt.String()), + ) + >> mutate( + x=C.x.str.to_datetime(), + y=C.y.str.to_datetime(), + ), + ) + + +def test_date_to_string(df_datetime): + assert_result_equal( + df_datetime, + lambda t: t + >> mutate( + x=t.col1.cast(pdt.Date()).cast(pdt.String()), + y=t.col2.cast(pdt.Date()).cast(pdt.String()), + z=t.cdate.cast(pdt.String()), + ), + ) diff --git a/tests/test_backend_equivalence/test_ops/test_ops_numerical.py b/tests/test_backend_equivalence/test_ops/test_ops_numerical.py new file mode 100644 index 00000000..f2a57400 --- /dev/null +++ b/tests/test_backend_equivalence/test_ops/test_ops_numerical.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import pydiverse.transform as pdt +from pydiverse.transform.pipe.pipeable import verb +from pydiverse.transform.pipe.verbs import mutate +from tests.util.assertion import assert_result_equal + + +@verb +def add_nan_inf_cols(table: pdt.Table) -> pdt.Table: + return table >> mutate( + **{ + "nan": float("nan"), + "negnan": float("-nan"), + "inf": float("inf"), + "neginf": float("-inf"), + } + ) + + +def test_exp(df_num): + assert_result_equal( + df_num, + lambda t: t >> add_nan_inf_cols() >> mutate(**{c.name: c.exp() for c in t}), + ) + + +def test_log(df_num): + assert_result_equal( + df_num, + lambda t: t >> mutate(**{c.name: c.log() for c in t}), + ) + + +def test_abs(df_num): + assert_result_equal( + df_num, + lambda t: t >> mutate(**{c.name: abs(c) for c in t}), + ) + + +def test_round(df_num): + assert_result_equal( + df_num, + lambda t: t >> mutate(**{c.name: round(c) for c in t}), + ) + + +def test_add(df_num): + assert_result_equal( + df_num, + lambda t: t + >> add_nan_inf_cols() + >> ( + lambda s: s + >> mutate(**{f"add_{c.name}_{d.name}": c + d for d in s for c in s}) + ), + ) + + +def test_sub(df_num): + assert_result_equal( + df_num, + lambda t: t + >> add_nan_inf_cols() + >> ( + lambda s: s + >> mutate(**{f"sub_{c.name}_{d.name}": c - d for d in s for c in s}) + ), + ) + + +def test_neg(df_num): + assert_result_equal( + df_num, + lambda t: t + >> add_nan_inf_cols() + >> (lambda s: s >> mutate(**{f"neg_{c.name}": -c for c in s})), + ) + + +def test_mul(df_num): + assert_result_equal( + df_num, + lambda t: t >> mutate(**{f"{c.name}*{d.name}": c * d for d in t for c in t}), + ) + + +def test_div(df_num): + assert_result_equal( + df_num, + lambda t: t >> mutate(**{f"{c.name}/{d.name}": c / d for d in t for c in t}), + ) + + +def test_decimal(df_num): + # TODO: test the decimal here + assert_result_equal(df_num, lambda t: t >> mutate(u=t.f + t.g, z=t.f * t.g)) + + +def test_floor(df_num): + assert_result_equal( + df_num, + lambda t: t >> mutate(**{c.name: c.floor() for c in t}), + ) + + +def test_ceil(df_num): + assert_result_equal( + df_num, + lambda t: t >> mutate(**{c.name: c.ceil() for c in t}), + ) diff --git a/tests/test_operator_registry.py b/tests/test_operator_registry.py index 7014e7f7..edcde881 100644 --- a/tests/test_operator_registry.py +++ b/tests/test_operator_registry.py @@ -23,27 +23,27 @@ def assert_signature( class TestOperatorSignature: def test_parse_simple(self): - s = OperatorSignature.parse("int, int -> int") - assert_signature(s, [dtypes.Int(), dtypes.Int()], dtypes.Int()) + s = OperatorSignature.parse("int64, int64 -> int64") + assert_signature(s, [dtypes.Int64(), dtypes.Int64()], dtypes.Int64()) s = OperatorSignature.parse("bool->bool ") assert_signature(s, [dtypes.Bool()], dtypes.Bool()) - s = OperatorSignature.parse("-> int") - assert_signature(s, [], dtypes.Int()) + s = OperatorSignature.parse("-> int64") + assert_signature(s, [], dtypes.Int64()) with pytest.raises(ValueError): - OperatorSignature.parse("int, int -> ") + OperatorSignature.parse("int64, int64 -> ") with pytest.raises(ValueError): - OperatorSignature.parse("int, int -> int, int") + OperatorSignature.parse("int64, int64 -> int64, int64") with pytest.raises(ValueError): - OperatorSignature.parse("o#r -> int") + OperatorSignature.parse("o#r -> int64") with pytest.raises(ValueError): - OperatorSignature.parse("int -> a#") + OperatorSignature.parse("int64 -> a#") def test_parse_template(self): - s = OperatorSignature.parse("T, int -> int") + s = OperatorSignature.parse("T, int64 -> int64") assert isinstance(s.args[0], dtypes.Template) s = OperatorSignature.parse("T -> T") @@ -54,26 +54,26 @@ def test_parse_template(self): OperatorSignature.parse("T, T -> U") def test_parse_varargs(self): - s = OperatorSignature.parse("int, str... -> int") + s = OperatorSignature.parse("int64, str... -> int64") assert not s.args[0].vararg assert s.args[1].vararg - s = OperatorSignature.parse("int... -> bool") + s = OperatorSignature.parse("int64... -> bool") assert s.args[0].vararg with pytest.raises(ValueError): - OperatorSignature.parse("int..., str -> int") + OperatorSignature.parse("int64..., str -> int64") with pytest.raises(ValueError): - OperatorSignature.parse("int, str -> int...") + OperatorSignature.parse("int64, str -> int64...") - s0 = OperatorSignature.parse("int, str -> int") - s1 = OperatorSignature.parse("int, str... -> int") + s0 = OperatorSignature.parse("int64, str -> int64") + s1 = OperatorSignature.parse("int64, str... -> int64") assert not s0.is_vararg assert s1.is_vararg def test_parse_const(self): - s = OperatorSignature.parse("const int -> int") + s = OperatorSignature.parse("const int64 -> int64") assert s.args[0].const assert not s.rtype.const @@ -103,33 +103,31 @@ def test_simple(self): reg.register_op(op1) reg.register_op(op2) - reg.add_impl(op1, lambda: 1, "int, int -> int") + reg.add_impl(op1, lambda: 1, "int64, int64 -> int64") reg.add_impl(op1, lambda: 2, "str, str -> str") - reg.add_impl(op2, lambda: 10, "int, int -> int") + reg.add_impl(op2, lambda: 10, "int64, int64 -> int64") reg.add_impl(op2, lambda: 20, "str, str -> str") - assert reg.get_impl("op1", parse_dtypes("int", "int"))() == 1 + assert reg.get_impl("op1", parse_dtypes("int64", "int64"))() == 1 assert isinstance( - reg.get_impl("op1", parse_dtypes("int", "int")).return_type, - dtypes.Int, + reg.get_impl("op1", parse_dtypes("int64", "int64")).return_type, + dtypes.Int64, ) - assert reg.get_impl("op2", parse_dtypes("int", "int"))() == 10 + assert reg.get_impl("op2", parse_dtypes("int64", "int64"))() == 10 assert reg.get_impl("op1", parse_dtypes("str", "str"))() == 2 assert reg.get_impl("op2", parse_dtypes("str", "str"))() == 20 - with pytest.raises(ValueError): - reg.get_impl("op1", parse_dtypes("int", "str")) + with pytest.raises(TypeError): + reg.get_impl("op1", parse_dtypes("int64", "str")) with pytest.raises(ValueError): reg.get_impl( "not_implemented", - parse_dtypes( - "int", - ), + parse_dtypes("int64"), ) - reg.add_impl(op1, lambda: 100, "-> int") + reg.add_impl(op1, lambda: 100, "-> int64") assert reg.get_impl("op1", tuple())() == 100 def test_template(self): @@ -149,53 +147,53 @@ def test_template(self): with pytest.raises(ValueError, match="already defined"): reg.add_impl(op1, lambda: 3, "T, U -> U") - assert reg.get_impl("op1", parse_dtypes("int", "int"))() == 1 - assert reg.get_impl("op1", parse_dtypes("int", "str"))() == 2 - # int can be promoted to float; results in "float, float -> bool" signature - assert reg.get_impl("op1", parse_dtypes("int", "float"))() == 1 - assert reg.get_impl("op1", parse_dtypes("float", "int"))() == 1 + assert reg.get_impl("op1", parse_dtypes("int64", "int64"))() == 1 + assert reg.get_impl("op1", parse_dtypes("int64", "str"))() == 2 + # int64 can be promoted to float; results in "float, float -> bool" signature + assert reg.get_impl("op1", parse_dtypes("int64", "float64"))() == 1 + assert reg.get_impl("op1", parse_dtypes("float64", "int64"))() == 1 # More template matching... Also check matching precedence - reg.add_impl(op2, lambda: 1, "int, int, int -> int") - reg.add_impl(op2, lambda: 2, "int, str, T -> int") - reg.add_impl(op2, lambda: 3, "int, T, str -> int") - reg.add_impl(op2, lambda: 4, "int, T, T -> int") - reg.add_impl(op2, lambda: 5, "T, T, T -> int") - reg.add_impl(op2, lambda: 6, "A, T, T -> int") - - assert reg.get_impl("op2", parse_dtypes("int", "int", "int"))() == 1 - assert reg.get_impl("op2", parse_dtypes("int", "str", "str"))() == 2 - assert reg.get_impl("op2", parse_dtypes("int", "int", "str"))() == 3 - assert reg.get_impl("op2", parse_dtypes("int", "bool", "bool"))() == 4 + reg.add_impl(op2, lambda: 1, "int64, int64, int64 -> int64") + reg.add_impl(op2, lambda: 2, "int64, str, T -> int64") + reg.add_impl(op2, lambda: 3, "int64, T, str -> int64") + reg.add_impl(op2, lambda: 4, "int64, T, T -> int64") + reg.add_impl(op2, lambda: 5, "T, T, T -> int64") + reg.add_impl(op2, lambda: 6, "A, T, T -> int64") + + assert reg.get_impl("op2", parse_dtypes("int64", "int64", "int64"))() == 1 + assert reg.get_impl("op2", parse_dtypes("int64", "str", "str"))() == 2 + assert reg.get_impl("op2", parse_dtypes("int64", "int64", "str"))() == 3 + assert reg.get_impl("op2", parse_dtypes("int64", "bool", "bool"))() == 4 assert reg.get_impl("op2", parse_dtypes("str", "str", "str"))() == 5 - assert reg.get_impl("op2", parse_dtypes("float", "str", "str"))() == 6 + assert reg.get_impl("op2", parse_dtypes("float64", "str", "str"))() == 6 - with pytest.raises(ValueError): - reg.get_impl("op2", parse_dtypes("int", "bool", "float")) + with pytest.raises(TypeError): + reg.get_impl("op2", parse_dtypes("int64", "bool", "float64")) # Return type reg.add_impl(op3, lambda: 1, "T -> T") - reg.add_impl(op3, lambda: 2, "int, T, U -> T") + reg.add_impl(op3, lambda: 2, "int64, T, U -> T") reg.add_impl(op3, lambda: 3, "str, T, U -> U") with pytest.raises(ValueError, match="already defined."): - reg.add_impl(op3, lambda: 4, "int, T, U -> U") + reg.add_impl(op3, lambda: 4, "int64, T, U -> U") assert isinstance( reg.get_impl("op3", parse_dtypes("str")).return_type, dtypes.String, ) assert isinstance( - reg.get_impl("op3", parse_dtypes("int")).return_type, - dtypes.Int, + reg.get_impl("op3", parse_dtypes("int64")).return_type, + dtypes.Int64, ) assert isinstance( - reg.get_impl("op3", parse_dtypes("int", "int", "float")).return_type, - dtypes.Int, + reg.get_impl("op3", parse_dtypes("int64", "int64", "float64")).return_type, + dtypes.Int64, ) assert isinstance( - reg.get_impl("op3", parse_dtypes("str", "int", "float")).return_type, - dtypes.Float, + reg.get_impl("op3", parse_dtypes("str", "int64", "float64")).return_type, + dtypes.Float64, ) def test_vararg(self): @@ -204,25 +202,25 @@ def test_vararg(self): op1 = self.Op1() reg.register_op(op1) - reg.add_impl(op1, lambda: 1, "int... -> int") - reg.add_impl(op1, lambda: 2, "int, int... -> int") - reg.add_impl(op1, lambda: 3, "int, T... -> T") + reg.add_impl(op1, lambda: 1, "int64... -> int64") + reg.add_impl(op1, lambda: 2, "int64, int64... -> int64") + reg.add_impl(op1, lambda: 3, "int64, T... -> T") assert ( reg.get_impl( "op1", parse_dtypes( - "int", + "int64", ), )() == 1 ) - assert reg.get_impl("op1", parse_dtypes("int", "int"))() == 2 - assert reg.get_impl("op1", parse_dtypes("int", "int", "int"))() == 2 - assert reg.get_impl("op1", parse_dtypes("int", "str", "str"))() == 3 + assert reg.get_impl("op1", parse_dtypes("int64", "int64"))() == 2 + assert reg.get_impl("op1", parse_dtypes("int64", "int64", "int64"))() == 2 + assert reg.get_impl("op1", parse_dtypes("int64", "str", "str"))() == 3 assert isinstance( - reg.get_impl("op1", parse_dtypes("int", "str", "str")).return_type, + reg.get_impl("op1", parse_dtypes("int64", "str", "str")).return_type, dtypes.String, ) @@ -233,13 +231,13 @@ def test_variant(self): reg.register_op(op1) with pytest.raises(ValueError): - reg.add_impl(op1, lambda: 2, "-> int", variant="VAR") + reg.add_impl(op1, lambda: 2, "-> int64", variant="VAR") - reg.add_impl(op1, lambda: 1, "-> int") - reg.add_impl(op1, lambda: 2, "-> int", variant="VAR") + reg.add_impl(op1, lambda: 1, "-> int64") + reg.add_impl(op1, lambda: 2, "-> int64", variant="VAR") assert reg.get_impl("op1", tuple())() == 1 assert reg.get_impl("op1", tuple()).get_variant("VAR")() == 2 with pytest.raises(ValueError): - reg.add_impl(op1, lambda: 2, "-> int", variant="VAR") + reg.add_impl(op1, lambda: 2, "-> int64", variant="VAR") diff --git a/tests/test_polars_table.py b/tests/test_polars_table.py index d3b53cc7..af2a48e5 100644 --- a/tests/test_polars_table.py +++ b/tests/test_polars_table.py @@ -116,12 +116,12 @@ def tbl_dt(): class TestPolarsLazyImpl: def test_dtype(self, tbl1, tbl2): - assert isinstance(tbl1.col1.dtype(), dtypes.Int) + assert isinstance(tbl1.col1.dtype(), dtypes.Int64) assert isinstance(tbl1.col2.dtype(), dtypes.String) - assert isinstance(tbl2.col1.dtype(), dtypes.Int) - assert isinstance(tbl2.col2.dtype(), dtypes.Int) - assert isinstance(tbl2.col3.dtype(), dtypes.Float) + assert isinstance(tbl2.col1.dtype(), dtypes.Int64) + assert isinstance(tbl2.col2.dtype(), dtypes.Int64) + assert isinstance(tbl2.col3.dtype(), dtypes.Float64) # test that column expression type errors are checked immediately with pytest.raises(TypeError): diff --git a/tests/util/assertion.py b/tests/util/assertion.py index 0fec9339..479d957f 100644 --- a/tests/util/assertion.py +++ b/tests/util/assertion.py @@ -8,9 +8,10 @@ from polars.testing import assert_frame_equal from pydiverse.transform import Table +from pydiverse.transform.backend.sqlite import SqliteImpl from pydiverse.transform.backend.targets import Polars from pydiverse.transform.errors import NonStandardBehaviourWarning -from pydiverse.transform.pipe.verbs import export, show_query +from pydiverse.transform.pipe.verbs import export, get_backend, show_query def assert_equal(left, right, check_dtypes=False, check_row_order=True): @@ -85,6 +86,10 @@ def assert_result_equal( pl.col(pl.Decimal(scale=10)).cast(pl.Float64) ) + # sqlite does not know NaN + if get_backend(query_y._ast) is SqliteImpl: + dfx = dfx.fill_nan(None) + # after a join, cols containing only null values get type Null on SQLite and # Postgres. maybe we can fix this but for now we just ignore such cols assert dfx.columns == dfy.columns diff --git a/tests/util/backend.py b/tests/util/backend.py index 957cad28..006181e7 100644 --- a/tests/util/backend.py +++ b/tests/util/backend.py @@ -32,7 +32,7 @@ def polars_table(df: pl.DataFrame, name: str): _sql_engine_cache = {} -def sql_table(df: pl.DataFrame, name: str, url: str, dtypes_map: dict = None): +def sql_table(df: pl.DataFrame, name: str, url: str, dtypes_map: dict | None = None): import sqlalchemy as sqa global _sql_engine_cache @@ -69,6 +69,7 @@ def duckdb_table(df: pl.DataFrame, name: str): @_cached_table def postgres_table(df: pl.DataFrame, name: str): url = "postgresql://sa:Pydiverse23@127.0.0.1:6543" + return sql_table(df, name, url) @@ -80,12 +81,7 @@ def mssql_table(df: pl.DataFrame, name: str): "mssql+pyodbc://sa:PydiQuant27@127.0.0.1:1433" "/master?driver=ODBC+Driver+18+for+SQL+Server&encrypt=no" ) - return sql_table( - df, - name, - url, - dtypes_map={pl.Datetime(): DATETIME2()}, - ) + return sql_table(df, name, url, dtypes_map={pl.Datetime(): DATETIME2()}) BACKEND_TABLES = {