Skip to content

Commit

Permalink
Make tests work again
Browse files Browse the repository at this point in the history
  • Loading branch information
finn-rudolph committed Oct 1, 2024
2 parents 1baa568 + 15307a3 commit 5789284
Show file tree
Hide file tree
Showing 33 changed files with 1,379 additions and 580 deletions.
27 changes: 26 additions & 1 deletion src/pydiverse/transform/backend/duckdb.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import polars as pl
import sqlalchemy as sqa

from pydiverse.transform import ops
from pydiverse.transform.backend import sql
from pydiverse.transform.backend.sql import SqlImpl
from pydiverse.transform.backend.targets import Polars, Target
from pydiverse.transform.tree import dtypes
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import Col
from pydiverse.transform.tree.col_expr import Cast, Col


class DuckDbImpl(SqlImpl):
Expand All @@ -21,3 +24,25 @@ def export(cls, nd: AstNode, target: Target, final_select: list[Col]):
DuckDbImpl.build_query(nd, final_select), connection=conn
)
return SqlImpl.export(nd, target, final_select)

@classmethod
def compile_cast(
    cls, cast: Cast, sqa_col: dict[str, sqa.Label]
) -> sqa.ColumnElement:
    """Compile a ``Cast`` node into a SQLAlchemy expression for DuckDB.

    A Float64 -> Int64 cast is compiled as the SQL function ``trunc`` applied
    to the compiled operand, followed by a cast to ``BigInteger``. Every
    other cast is delegated to the parent implementation.

    :param cast: The cast expression to compile.
    :param sqa_col: Mapping handed through to ``compile_col_expr`` when
        compiling the operand.
    :return: The SQLAlchemy expression performing the cast (not a
        pydiverse ``Cast`` node).
    """
    if cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.Int64:
        # NOTE(review): presumably `trunc` is applied first so the integer
        # cast drops the fractional part instead of rounding — confirm.
        return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast(
            sqa.BigInteger()
        )
    return super().compile_cast(cast, sqa_col)


# Register the DuckDB-specific implementation of the FloorDiv operator.
with DuckDbImpl.op(ops.FloorDiv()) as op:

    @op.auto
    def _floordiv(lhs, rhs):
        """Compile the floor division as a call to the SQL function ``divide``."""
        return sqa.func.divide(lhs, rhs)


# Register the DuckDB-specific implementation of the reflected RFloorDiv operator.
with DuckDbImpl.op(ops.RFloorDiv()) as op:

    @op.auto
    def _floordiv(rhs, lhs):
        """Compile the reflected floor division as ``divide(lhs, rhs)``.

        The parameters arrive in reflected order (``rhs`` first), so they are
        swapped back when building the ``divide`` call.
        """
        return sqa.func.divide(lhs, rhs)
86 changes: 85 additions & 1 deletion src/pydiverse/transform/backend/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any

import sqlalchemy as sqa
from sqlalchemy.dialects.mssql import DATETIME2

from pydiverse.transform import ops
from pydiverse.transform.backend import sql
Expand All @@ -13,6 +14,7 @@
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import (
CaseExpr,
Cast,
Col,
ColExpr,
ColFn,
Expand All @@ -25,6 +27,37 @@
class MsSqlImpl(SqlImpl):
dialect_name = "mssql"

# SQL expressions used by this backend wherever an infinite or NaN value is
# needed. INF is the literal "1.0" cast to Double, divided by a Double-typed
# "0.0" literal.
INF = sqa.cast(sqa.literal("1.0"), type_=sqa.Double) / sqa.literal(
    "0.0", type_=sqa.Double
)
# Negation of the INF expression.
NEG_INF = -INF
# INF + (-INF). NOTE(review): presumably evaluates to NaN on SQL Server — confirm.
NAN = INF + NEG_INF

@classmethod
def compile_cast(
    cls, cast: Cast, sqa_col: dict[str, sqa.Label]
) -> sqa.ColumnElement:
    """Compile a ``Cast`` node into a SQLAlchemy expression for MSSQL.

    Two conversions get special treatment so that infinite and NaN values
    use the spellings ``"inf"``, ``"-inf"`` and ``"nan"``:

    * String -> Float64: the strings ``"inf"``, ``"-inf"``, ``"nan"`` and
      ``"-nan"`` are mapped to the ``INF`` / ``NEG_INF`` / ``NAN`` class
      expressions; any other string is cast normally.
    * Float64 -> String: the strings ``"1.#QNAN"``, ``"1.#INF"`` and
      ``"-1.#INF"`` resulting from the cast are replaced by ``"nan"``,
      ``"inf"`` and ``"-inf"``.

    Every other cast is a plain ``CAST`` to ``cls.sqa_type(target_type)``.

    :param cast: The cast expression to compile.
    :param sqa_col: Mapping handed through to ``compile_col_expr`` when
        compiling the operand.
    :return: The SQLAlchemy expression performing the cast; a ``CASE``
        expression for the two special conversions, a ``CAST`` otherwise.
    """
    # The operand is compiled once and shared by all branches.
    compiled_val = cls.compile_col_expr(cast.val, sqa_col)
    source_type = cast.val.dtype()

    if source_type == dtypes.String and cast.target_type == dtypes.Float64:
        return sqa.case(
            (compiled_val == "inf", cls.INF),
            (compiled_val == "-inf", cls.NEG_INF),
            (compiled_val.in_(("nan", "-nan")), cls.NAN),
            else_=sqa.cast(compiled_val, cls.sqa_type(cast.target_type)),
        )

    if source_type == dtypes.Float64 and cast.target_type == dtypes.String:
        as_string = sqa.cast(compiled_val, sqa.String)
        # NOTE(review): "1.#QNAN" / "1.#INF" / "-1.#INF" are presumably the
        # strings SQL Server produces for these values — confirm.
        return sqa.case(
            (as_string == "1.#QNAN", "nan"),
            (as_string == "1.#INF", "inf"),
            (as_string == "-1.#INF", "-inf"),
            else_=as_string,
        )

    return sqa.cast(compiled_val, cls.sqa_type(cast.target_type))

@classmethod
def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any:
# boolean / bit conversion
Expand Down Expand Up @@ -54,6 +87,13 @@ def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any:
table, query, _ = cls.compile_ast(nd, {col._uuid: 1 for col in final_select})
return cls.compile_query(table, query)

@classmethod
def sqa_type(cls, t: dtypes.Dtype):
    """Return the SQLAlchemy type used for the pydiverse dtype ``t``.

    ``DateTime`` dtypes map to the MSSQL dialect type ``DATETIME2``; every
    other dtype is resolved by the parent implementation.
    """
    return DATETIME2 if isinstance(t, dtypes.DateTime) else super().sqa_type(t)


def convert_order_list(order_list: list[Order]) -> list[Order]:
new_list: list[Order] = []
Expand Down Expand Up @@ -93,7 +133,7 @@ def convert_bool_bit(expr: ColExpr | Order, wants_bool_as_bit: bool) -> ColExpr
)

elif isinstance(expr, Col):
if not wants_bool_as_bit and isinstance(expr.dtype(), dtypes.Bool):
if not wants_bool_as_bit and expr.dtype() == dtypes.Bool:
return ColFn("__eq__", expr, LiteralCol(True))
return expr

Expand Down Expand Up @@ -146,6 +186,14 @@ def convert_bool_bit(expr: ColExpr | Order, wants_bool_as_bit: bool) -> ColExpr
elif isinstance(expr, LiteralCol):
return expr

elif isinstance(expr, Cast):
# TODO: does this really work for casting onto / from booleans? we probably have
# to use wants_bool_as_bit in some way when casting to bool
return Cast(
convert_bool_bit(expr.val, wants_bool_as_bit=wants_bool_as_bit),
expr.target_type,
)

raise AssertionError


Expand Down Expand Up @@ -289,3 +337,39 @@ def _day_of_week(x):
@op.auto
def _mean(x):
    """Compile the mean as ``AVG`` over ``x`` cast to ``Double``.

    The ``AVG`` call itself is also typed as ``Double``.
    """
    return sqa.func.AVG(sqa.cast(x, sqa.Double()), type_=sqa.Double())


# MSSQL implementation of the Log operator, with explicit handling of
# non-positive and NULL inputs.
with MsSqlImpl.op(ops.Log()) as op:

    @op.auto
    def _log(x):
        """Compile the logarithm of ``x`` as a ``CASE`` expression.

        Branches, in order: ``x > 0`` yields ``log(x)``; ``x < 0`` yields the
        backend's ``NAN`` expression; a NULL ``x`` yields NULL; anything else
        yields negative ``INF``.
        """
        positive = (x > 0, sqa.func.log(x))
        negative = (x < 0, MsSqlImpl.NAN)
        missing = (x.is_(sqa.null()), None)
        return sqa.case(positive, negative, missing, else_=-MsSqlImpl.INF)


# Register the MSSQL-specific implementation of the Ceil operator.
with MsSqlImpl.op(ops.Ceil()) as op:

    @op.auto
    def _ceil(x):
        """Compile the ceiling of ``x`` as a call to the SQL function ``ceiling``."""
        return sqa.func.ceiling(x)


# Register the MSSQL-specific implementation of the StrToDateTime operator.
with MsSqlImpl.op(ops.StrToDateTime()) as op:

    @op.auto
    def _str_to_datetime(x):
        """Compile the string-to-datetime conversion as a cast to ``DATETIME2``."""
        return sqa.cast(x, DATETIME2)


# MSSQL implementation of the Round operator.
with MsSqlImpl.op(ops.Round()) as op:

    @op.auto
    def _round(x, decimals=0):
        """Compile rounding of ``x`` to ``decimals`` places.

        Rows where ``x != x`` yield the backend's ``NAN`` expression; all
        other rows yield ``round(x, decimals)``, typed like ``x``.
        """
        # NOTE(review): `x != x` is presumably meant as a NaN test — confirm.
        is_nan = x != x
        rounded = sqa.func.round(x, decimals, type_=x.type)
        return sqa.case((is_nan, MsSqlImpl.NAN), else_=rounded)
113 changes: 84 additions & 29 deletions src/pydiverse/transform/backend/polars.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import datetime
from types import NoneType
from typing import Any
from uuid import UUID

Expand All @@ -15,6 +13,7 @@
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import (
CaseExpr,
Cast,
Col,
ColExpr,
ColFn,
Expand Down Expand Up @@ -159,7 +158,7 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr:

# the function was executed on the ordered arguments. here we
# restore the original order of the table.
inv_permutation = pl.int_range(0, pl.len(), dtype=pl.Int64).sort_by(
inv_permutation = pl.int_range(0, pl.len(), dtype=pl.Int64()).sort_by(
by=order_by,
descending=descending,
nulls_last=nulls_last,
Expand All @@ -182,10 +181,20 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr:
return compiled

elif isinstance(expr, LiteralCol):
if isinstance(expr.dtype(), dtypes.String):
if expr.dtype() == dtypes.String:
return pl.lit(expr.val) # polars interprets strings as column names
return expr.val

elif isinstance(expr, Cast):
compiled = compile_col_expr(expr.val, name_in_df).cast(
pdt_type_to_polars(expr.target_type)
)

if expr.val.dtype() == dtypes.Float64 and expr.target_type == dtypes.String:
compiled = compiled.replace("NaN", "nan")

return compiled

else:
raise AssertionError

Expand Down Expand Up @@ -340,9 +349,9 @@ def has_path_to_leaf_without_agg(expr: ColExpr):

def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype:
if t.is_float():
return dtypes.Float()
return dtypes.Float64()
elif t.is_integer():
return dtypes.Int()
return dtypes.Int64()
elif isinstance(t, pl.Boolean):
return dtypes.Bool()
elif isinstance(t, pl.String):
Expand All @@ -360,9 +369,9 @@ def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype:


def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType:
if isinstance(t, dtypes.Float):
if isinstance(t, (dtypes.Float64, dtypes.Decimal)):
return pl.Float64()
elif isinstance(t, dtypes.Int):
elif isinstance(t, dtypes.Int64):
return pl.Int64()
elif isinstance(t, dtypes.Bool):
return pl.Boolean()
Expand All @@ -380,27 +389,6 @@ def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType:
raise AssertionError


def python_type_to_polars(t: type) -> pl.DataType:
if t is int:
return pl.Int64()
elif t is float:
return pl.Float64()
elif t is bool:
return pl.Boolean()
elif t is str:
return pl.String()
elif t is datetime.datetime:
return pl.Datetime()
elif t is datetime.date:
return pl.Date()
elif t is datetime.timedelta:
return pl.Duration()
elif t is NoneType:
return pl.Null()

raise TypeError(f"python builtin type {t} is not supported by pydiverse.transform")


with PolarsImpl.op(ops.Mean()) as op:

@op.auto
Expand Down Expand Up @@ -709,3 +697,70 @@ def _greatest(*x):
@op.auto
def _least(*x):
return pl.min_horizontal(*x)


# Register the polars implementation of the Round operator.
with PolarsImpl.op(ops.Round()) as op:

    @op.auto
    def _round(x, digits=0):
        """Round ``x`` to ``digits`` places via the expression's ``round`` method."""
        return x.round(digits)


# Register the polars implementation of the Exp operator.
with PolarsImpl.op(ops.Exp()) as op:

    @op.auto
    def _exp(x):
        """Apply the expression's ``exp`` method to ``x``."""
        return x.exp()


# Register the polars implementation of the Log operator.
with PolarsImpl.op(ops.Log()) as op:

    @op.auto
    def _log(x):
        """Apply the expression's ``log`` method to ``x`` with its default arguments."""
        return x.log()


# Register the polars implementation of the Floor operator.
with PolarsImpl.op(ops.Floor()) as op:

    @op.auto
    def _floor(x):
        """Apply the expression's ``floor`` method to ``x``."""
        return x.floor()


# Register the polars implementation of the Ceil operator.
with PolarsImpl.op(ops.Ceil()) as op:

    @op.auto
    def _ceil(x):
        """Apply the expression's ``ceil`` method to ``x``."""
        return x.ceil()


# Register the polars implementation of the StrToDateTime operator.
with PolarsImpl.op(ops.StrToDateTime()) as op:

    @op.auto
    def _str_to_datetime(x):
        """Convert ``x`` via ``str.to_datetime()`` with its default arguments."""
        return x.str.to_datetime()


# Register the polars implementation of the StrToDate operator.
with PolarsImpl.op(ops.StrToDate()) as op:

    @op.auto
    def _str_to_date(x):
        """Convert ``x`` via ``str.to_date()`` with its default arguments."""
        return x.str.to_date()


# Polars implementation of the FloorDiv operator.
with PolarsImpl.op(ops.FloorDiv()) as op:

    @op.auto
    def _floordiv(lhs, rhs):
        """Divide the absolute values, then reapply the sign of the quotient.

        The result is ``abs(lhs) // abs(rhs)``, multiplied by ``-1`` when
        exactly one of the operands is negative and by ``1`` otherwise.
        """
        # TODO: test some alternatives if this is too slow
        signs_differ = (lhs < 0) ^ (rhs < 0)
        magnitude = abs(lhs) // abs(rhs)
        sign = pl.when(signs_differ).then(-1).otherwise(1)
        return magnitude * sign


# Polars implementation of the Mod operator.
with PolarsImpl.op(ops.Mod()) as op:

    @op.auto
    def _mod(lhs, rhs):
        """Compute ``lhs`` modulo a divisor that carries the sign of ``lhs``.

        The divisor is ``abs(rhs)`` when ``lhs >= 0`` and ``-abs(rhs)``
        otherwise.
        """
        # TODO: see whether the following is faster:
        # pl.when(lhs >= 0).then(lhs % abs(rhs)).otherwise(lhs % -abs(rhs))
        lhs_sign = pl.when(lhs >= 0).then(1).otherwise(-1)
        signed_divisor = abs(rhs) * lhs_sign
        return lhs % signed_divisor
21 changes: 21 additions & 0 deletions src/pydiverse/transform/backend/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,32 @@

from pydiverse.transform import ops
from pydiverse.transform.backend.sql import SqlImpl
from pydiverse.transform.tree import dtypes
from pydiverse.transform.tree.col_expr import Cast


class PostgresImpl(SqlImpl):
    """SQL backend implementation for the ``postgresql`` dialect."""

    dialect_name = "postgresql"

    @classmethod
    def compile_cast(
        cls, cast: Cast, sqa_col: dict[str, sqa.Label]
    ) -> sqa.ColumnElement:
        """Compile a ``Cast`` node into a SQLAlchemy expression for Postgres.

        Casts from Float64 get special treatment:

        * Float64 -> Int64 is compiled as the SQL function ``trunc`` applied
          to the operand, followed by a cast to ``BigInteger``.
        * Float64 -> String casts to ``String`` and then replaces the strings
          ``"NaN"``, ``"Infinity"`` and ``"-Infinity"`` by ``"nan"``,
          ``"inf"`` and ``"-inf"``.

        Every other cast is a plain ``CAST`` to ``cls.sqa_type(target_type)``.

        :param cast: The cast expression to compile.
        :param sqa_col: Mapping handed through to ``compile_col_expr`` when
            compiling the operand.
        :return: The SQLAlchemy expression performing the cast (not a
            pydiverse ``Cast`` node).
        """
        compiled_val = cls.compile_col_expr(cast.val, sqa_col)

        # NOTE(review): the DuckDB and MSSQL backends compare dtypes with
        # `==` instead of `isinstance` — confirm both forms are equivalent.
        if isinstance(cast.val.dtype(), dtypes.Float64):
            if isinstance(cast.target_type, dtypes.Int64):
                return sqa.func.trunc(compiled_val).cast(sqa.BigInteger())

            if isinstance(cast.target_type, dtypes.String):
                compiled = sqa.cast(compiled_val, sqa.String)
                return sqa.case(
                    (compiled == "NaN", "nan"),
                    (compiled == "Infinity", "inf"),
                    (compiled == "-Infinity", "-inf"),
                    else_=compiled,
                )

        return sqa.cast(compiled_val, cls.sqa_type(cast.target_type))


with PostgresImpl.op(ops.Less()) as op:

Expand Down
Loading

0 comments on commit 5789284

Please sign in to comment.