Skip to content

Commit

Permalink
Merge pull request #27 from pydiverse/cast
Browse files Browse the repository at this point in the history
Add casts and string to datetime conversion
  • Loading branch information
finn-rudolph authored Oct 1, 2024
2 parents 6a547d6 + a3933e5 commit 46ed4fe
Show file tree
Hide file tree
Showing 25 changed files with 904 additions and 254 deletions.
12 changes: 11 additions & 1 deletion src/pydiverse/transform/backend/duckdb.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import polars as pl
import sqlalchemy as sqa

from pydiverse.transform.backend import sql
from pydiverse.transform.backend.sql import SqlImpl
from pydiverse.transform.backend.targets import Polars, Target
from pydiverse.transform.tree import dtypes
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import Col
from pydiverse.transform.tree.col_expr import Cast, Col


class DuckDbImpl(SqlImpl):
Expand All @@ -21,3 +23,11 @@ def export(cls, nd: AstNode, target: Target, final_select: list[Col]):
DuckDbImpl.build_query(nd, final_select), connection=conn
)
return SqlImpl.export(nd, target, final_select)

@classmethod
def compile_cast(
    cls, cast: Cast, sqa_col: dict[str, sqa.Label]
) -> sqa.ColumnElement:
    """Compile a ``Cast`` expression into a SQLAlchemy expression for DuckDB.

    A Float64 -> Int64 cast is compiled as ``trunc(value)`` cast to
    ``BigInteger``; every other cast is delegated to the parent
    implementation.

    :param cast: The cast expression to compile.
    :param sqa_col: Mapping passed through to ``compile_col_expr``.
    :return: The compiled SQLAlchemy expression (not a pydiverse ``Cast``).
    """
    float_to_int = (
        cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.Int64
    )
    if not float_to_int:
        return super().compile_cast(cast, sqa_col)

    # NOTE(review): presumably DuckDB's plain float -> integer cast does not
    # truncate toward zero, hence the explicit trunc() -- confirm.
    truncated = sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col))
    return truncated.cast(sqa.BigInteger())
77 changes: 76 additions & 1 deletion src/pydiverse/transform/backend/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any

import sqlalchemy as sqa
from sqlalchemy.dialects.mssql import DATETIME2

from pydiverse.transform import ops
from pydiverse.transform.backend import sql
Expand All @@ -13,6 +14,7 @@
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import (
CaseExpr,
Cast,
Col,
ColExpr,
ColFn,
Expand All @@ -25,6 +27,37 @@
class MsSqlImpl(SqlImpl):
dialect_name = "mssql"

# Float special values (+inf, -inf, NaN) built as SQLAlchemy expressions from
# float-typed literals: 1.0 / 0.0, its negation, and the sum of the two.
# NOTE(review): presumably built arithmetically because MSSQL offers no literal
# for these values -- confirm that the server evaluates 1.0 / 0.0 to infinity.
INF = sqa.cast(sqa.literal("1.0"), type_=sqa.Float()) / sqa.literal(
    "0.0", type_=sqa.Float()
)
NEG_INF = -INF
NAN = INF + NEG_INF

@classmethod
def compile_cast(
    cls, cast: Cast, sqa_col: dict[str, sqa.Label]
) -> sqa.ColumnElement:
    """Compile a ``Cast`` expression into a SQLAlchemy expression for MSSQL.

    Two conversions get special treatment so that the textual spellings of
    the float special values line up with the lowercase ``"inf"``,
    ``"-inf"`` and ``"nan"`` strings:

    * String -> Float64: those strings (and ``"-nan"``) are mapped to
      ``INF`` / ``NEG_INF`` / ``NAN``; anything else is cast normally.
    * Float64 -> String: the strings ``"1.#QNAN"``, ``"1.#INF"`` and
      ``"-1.#INF"`` produced by the cast are rewritten to ``"nan"``,
      ``"inf"`` and ``"-inf"``.

    Every other combination is a plain ``CAST`` to ``cls.sqa_type`` of the
    target type.

    :param cast: The cast expression to compile.
    :param sqa_col: Mapping passed through to ``compile_col_expr``.
    :return: The compiled SQLAlchemy expression (a ``CAST`` or a ``CASE``).
    """
    # Compile the operand exactly once and reuse it in every branch.
    compiled_val = cls.compile_col_expr(cast.val, sqa_col)
    source_type = cast.val.dtype()

    if source_type == dtypes.String and cast.target_type == dtypes.Float64:
        return sqa.case(
            (compiled_val == "inf", cls.INF),
            (compiled_val == "-inf", cls.NEG_INF),
            (compiled_val.in_(("nan", "-nan")), cls.NAN),
            else_=sqa.cast(compiled_val, cls.sqa_type(cast.target_type)),
        )

    if source_type == dtypes.Float64 and cast.target_type == dtypes.String:
        as_string = sqa.cast(compiled_val, sqa.String)
        return sqa.case(
            (as_string == "1.#QNAN", "nan"),
            (as_string == "1.#INF", "inf"),
            (as_string == "-1.#INF", "-inf"),
            else_=as_string,
        )

    return sqa.cast(compiled_val, cls.sqa_type(cast.target_type))

@classmethod
def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any:
# boolean / bit conversion
Expand Down Expand Up @@ -54,6 +87,13 @@ def build_select(cls, nd: AstNode, final_select: list[Col]) -> Any:
table, query, _ = cls.compile_ast(nd, {col._uuid: 1 for col in final_select})
return cls.compile_query(table, query)

@classmethod
def sqa_type(cls, t: dtypes.Dtype):
    """Return the SQLAlchemy type used for the pydiverse dtype ``t``.

    ``DateTime`` dtypes map to the MSSQL-specific ``DATETIME2``; all other
    dtypes use the parent implementation's mapping.
    """
    if not isinstance(t, dtypes.DateTime):
        return super().sqa_type(t)
    return DATETIME2()


def convert_order_list(order_list: list[Order]) -> list[Order]:
new_list: list[Order] = []
Expand Down Expand Up @@ -93,7 +133,7 @@ def convert_bool_bit(expr: ColExpr | Order, wants_bool_as_bit: bool) -> ColExpr
)

elif isinstance(expr, Col):
if not wants_bool_as_bit and isinstance(expr.dtype(), dtypes.Bool):
if not wants_bool_as_bit and expr.dtype() == dtypes.Bool:
return ColFn("__eq__", expr, LiteralCol(True))
return expr

Expand Down Expand Up @@ -146,6 +186,14 @@ def convert_bool_bit(expr: ColExpr | Order, wants_bool_as_bit: bool) -> ColExpr
elif isinstance(expr, LiteralCol):
return expr

elif isinstance(expr, Cast):
# TODO: does this really work for casting to / from booleans? We probably have
# to use wants_bool_as_bit in some way when casting to bool.
return Cast(
convert_bool_bit(expr.val, wants_bool_as_bit=wants_bool_as_bit),
expr.target_type,
)

raise AssertionError


Expand Down Expand Up @@ -289,3 +337,30 @@ def _day_of_week(x):
@op.auto
def _mean(x):
    """Average ``x`` after casting it to ``Double``; the result is typed ``Double``."""
    as_double = sqa.cast(x, sqa.Double())
    return sqa.func.AVG(as_double, type_=sqa.Double())


with MsSqlImpl.op(ops.Log()) as op:

    @op.auto
    def _log(x):
        """Compile ``log(x)`` as a CASE expression with explicit edge cases.

        Positive input yields ``log(x)``, negative input yields ``NAN``,
        NULL stays NULL, and everything else (e.g. zero) yields ``-INF``.
        """
        # TODO: we still need to handle inf / -inf / nan
        branches = (
            (x > 0, sqa.func.log(x)),
            (x < 0, MsSqlImpl.NAN),
            (x.is_(sqa.null()), None),
        )
        return sqa.case(*branches, else_=-MsSqlImpl.INF)


with MsSqlImpl.op(ops.Ceil()) as op:

    @op.auto
    def _ceil(x):
        """Compile the ceil operation as a call to the SQL ``ceiling`` function."""
        ceiling = sqa.func.ceiling
        return ceiling(x)


with MsSqlImpl.op(ops.StrToDateTime()) as op:

    @op.auto
    def _str_to_datetime(x):
        """Compile string-to-datetime conversion as a cast to ``DATETIME2``."""
        return sqa.cast(x, type_=DATETIME2)
95 changes: 66 additions & 29 deletions src/pydiverse/transform/backend/polars.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import datetime
from types import NoneType
from typing import Any
from uuid import UUID

Expand All @@ -15,6 +13,7 @@
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import (
CaseExpr,
Cast,
Col,
ColExpr,
ColFn,
Expand Down Expand Up @@ -159,7 +158,7 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr:

# the function was executed on the ordered arguments. here we
# restore the original order of the table.
inv_permutation = pl.int_range(0, pl.len(), dtype=pl.Int64).sort_by(
inv_permutation = pl.int_range(0, pl.len(), dtype=pl.Int64()).sort_by(
by=order_by,
descending=descending,
nulls_last=nulls_last,
Expand All @@ -182,10 +181,20 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr:
return compiled

elif isinstance(expr, LiteralCol):
if isinstance(expr.dtype(), dtypes.String):
if expr.dtype() == dtypes.String:
return pl.lit(expr.val) # polars interprets strings as column names
return expr.val

elif isinstance(expr, Cast):
compiled = compile_col_expr(expr.val, name_in_df).cast(
pdt_type_to_polars(expr.target_type)
)

if expr.val.dtype() == dtypes.Float64 and expr.target_type == dtypes.String:
compiled = compiled.replace("NaN", "nan")

return compiled

else:
raise AssertionError

Expand Down Expand Up @@ -340,9 +349,9 @@ def has_path_to_leaf_without_agg(expr: ColExpr):

def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype:
if t.is_float():
return dtypes.Float()
return dtypes.Float64()
elif t.is_integer():
return dtypes.Int()
return dtypes.Int64()
elif isinstance(t, pl.Boolean):
return dtypes.Bool()
elif isinstance(t, pl.String):
Expand All @@ -360,9 +369,9 @@ def polars_type_to_pdt(t: pl.DataType) -> dtypes.Dtype:


def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType:
if isinstance(t, dtypes.Float):
if isinstance(t, (dtypes.Float64, dtypes.Decimal)):
return pl.Float64()
elif isinstance(t, dtypes.Int):
elif isinstance(t, dtypes.Int64):
return pl.Int64()
elif isinstance(t, dtypes.Bool):
return pl.Boolean()
Expand All @@ -380,27 +389,6 @@ def pdt_type_to_polars(t: dtypes.Dtype) -> pl.DataType:
raise AssertionError


def python_type_to_polars(t: type) -> pl.DataType:
if t is int:
return pl.Int64()
elif t is float:
return pl.Float64()
elif t is bool:
return pl.Boolean()
elif t is str:
return pl.String()
elif t is datetime.datetime:
return pl.Datetime()
elif t is datetime.date:
return pl.Date()
elif t is datetime.timedelta:
return pl.Duration()
elif t is NoneType:
return pl.Null()

raise TypeError(f"python builtin type {t} is not supported by pydiverse.transform")


with PolarsImpl.op(ops.Mean()) as op:

@op.auto
Expand Down Expand Up @@ -709,3 +697,52 @@ def _greatest(*x):
@op.auto
def _least(*x):
    """Return the row-wise minimum across all given expressions."""
    columns = x
    return pl.min_horizontal(*columns)


with PolarsImpl.op(ops.Round()) as op:

    @op.auto
    def _round(x, digits=0):
        """Round ``x`` via its ``round`` method, to ``digits`` (0 by default)."""
        rounded = x.round(digits)
        return rounded


with PolarsImpl.op(ops.Exp()) as op:

    @op.auto
    def _exp(x):
        """Apply the exponential function to ``x`` via its ``exp`` method."""
        exponentiated = x.exp()
        return exponentiated


with PolarsImpl.op(ops.Log()) as op:

    @op.auto
    def _log(x):
        """Apply ``x.log()`` using the method's default base."""
        logarithm = x.log()
        return logarithm


with PolarsImpl.op(ops.Floor()) as op:

    @op.auto
    def _floor(x):
        """Round ``x`` down via its ``floor`` method."""
        floored = x.floor()
        return floored


with PolarsImpl.op(ops.Ceil()) as op:

    @op.auto
    def _ceil(x):
        """Round ``x`` up via its ``ceil`` method."""
        ceiled = x.ceil()
        return ceiled


with PolarsImpl.op(ops.StrToDateTime()) as op:

    @op.auto
    def _str_to_datetime(x):
        """Parse the string expression ``x`` with ``str.to_datetime()`` defaults."""
        string_ns = x.str
        return string_ns.to_datetime()


with PolarsImpl.op(ops.StrToDate()) as op:

    @op.auto
    def _str_to_date(x):
        """Parse the string expression ``x`` with ``str.to_date()`` defaults."""
        string_ns = x.str
        return string_ns.to_date()
21 changes: 21 additions & 0 deletions src/pydiverse/transform/backend/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,32 @@

from pydiverse.transform import ops
from pydiverse.transform.backend.sql import SqlImpl
from pydiverse.transform.tree import dtypes
from pydiverse.transform.tree.col_expr import Cast


class PostgresImpl(SqlImpl):
    """SQL backend implementation for the PostgreSQL dialect."""

    dialect_name = "postgresql"

    @classmethod
    def compile_cast(
        cls, cast: Cast, sqa_col: dict[str, sqa.Label]
    ) -> sqa.ColumnElement:
        """Compile a ``Cast`` expression into a SQLAlchemy expression.

        Casts from Float64 get special treatment:

        * Float64 -> Int64 is compiled as ``trunc(value)`` cast to
          ``BigInteger``.
        * Float64 -> String rewrites the strings ``"NaN"``, ``"Infinity"``
          and ``"-Infinity"`` produced by the cast to ``"nan"``, ``"inf"``
          and ``"-inf"``.

        Every other combination is a plain ``CAST`` to ``cls.sqa_type`` of
        the target type.

        :param cast: The cast expression to compile.
        :param sqa_col: Mapping passed through to ``compile_col_expr``.
        :return: The compiled SQLAlchemy expression (a ``CAST`` or a ``CASE``).
        """
        compiled_val = cls.compile_col_expr(cast.val, sqa_col)

        # NOTE(review): other backends in this change compare dtypes with
        # ``==`` instead of ``isinstance`` -- confirm both behave the same
        # for ``cast.target_type``.
        if isinstance(cast.val.dtype(), dtypes.Float64):
            if isinstance(cast.target_type, dtypes.Int64):
                return sqa.func.trunc(compiled_val).cast(sqa.BigInteger())

            if isinstance(cast.target_type, dtypes.String):
                as_string = sqa.cast(compiled_val, sqa.String)
                return sqa.case(
                    (as_string == "NaN", "nan"),
                    (as_string == "Infinity", "inf"),
                    (as_string == "-Infinity", "-inf"),
                    else_=as_string,
                )

        return sqa.cast(compiled_val, cls.sqa_type(cast.target_type))


with PostgresImpl.op(ops.Less()) as op:

Expand Down
Loading

0 comments on commit 46ed4fe

Please sign in to comment.