diff --git a/src/pydiverse/transform/_internal/backend/__init__.py b/src/pydiverse/transform/_internal/backend/__init__.py index cc21fb29..1c07c218 100644 --- a/src/pydiverse/transform/_internal/backend/__init__.py +++ b/src/pydiverse/transform/_internal/backend/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations from .duckdb import DuckDbImpl +from .duckdb_polars import DuckDbPolarsImpl from .mssql import MsSqlImpl from .polars import PolarsImpl from .postgres import PostgresImpl diff --git a/src/pydiverse/transform/_internal/backend/duckdb_polars.py b/src/pydiverse/transform/_internal/backend/duckdb_polars.py new file mode 100644 index 00000000..d8a403b4 --- /dev/null +++ b/src/pydiverse/transform/_internal/backend/duckdb_polars.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from uuid import UUID + +import duckdb +import duckdb_engine +import polars as pl +import sqlalchemy as sqa + +from pydiverse.transform._internal.backend.duckdb import DuckDbImpl +from pydiverse.transform._internal.backend.polars import polars_type_to_pdt +from pydiverse.transform._internal.backend.table_impl import TableImpl +from pydiverse.transform._internal.backend.targets import Polars, Target +from pydiverse.transform._internal.tree.ast import AstNode +from pydiverse.transform._internal.tree.col_expr import Col + + +# TODO: we should move the engine of SqlImpl in the subclasses and let this thing +# inherit from SqlImpl in order to make the usage of SqlImpl.compile_ast more clean. +# Currently it works only since this class also has a table object, but it should be +# enforced by inheritance. 
+class DuckDbPolarsImpl(TableImpl): + def __init__(self, name: str, df: pl.DataFrame | pl.LazyFrame): + self.df = df if isinstance(df, pl.LazyFrame) else df.lazy() + + super().__init__( + name, + { + name: polars_type_to_pdt(dtype) + for name, dtype in df.collect_schema().items() + }, + ) + + self.table = sqa.Table( + name, + sqa.MetaData(), + *( + sqa.Column(col.name, DuckDbImpl.sqa_type(col.dtype())) + for col in self.cols.values() + ), + ) + + @staticmethod + def build_query(nd: AstNode, final_select: list[Col]) -> str | None: + return DuckDbImpl.build_query(nd, final_select) + + @staticmethod + def export(nd: AstNode, target: Target, final_select: list[Col]) -> pl.DataFrame: + if isinstance(target, Polars): + sel = DuckDbImpl.build_select(nd, final_select) + query_str = str( + sel.compile( + dialect=duckdb_engine.Dialect(), + compile_kwargs={"literal_binds": True}, + ) + ) + + # tell duckdb which table names in the SQL query correspond to which + # data frames + for desc in nd.iter_subtree(): + if isinstance(desc, DuckDbPolarsImpl): + duckdb.register(desc.table.name, desc.df) + + return duckdb.sql(query_str).pl() + + raise AssertionError + + def _clone(self) -> tuple[AstNode, dict[AstNode, AstNode], dict[UUID, UUID]]: + cloned = DuckDbPolarsImpl(self.name, self.df) + return ( + cloned, + {self: cloned}, + { + self.cols[name]._uuid: cloned.cols[name]._uuid + for name in self.cols.keys() + }, + ) diff --git a/src/pydiverse/transform/_internal/backend/sql.py b/src/pydiverse/transform/_internal/backend/sql.py index 6fd109ce..c825240a 100644 --- a/src/pydiverse/transform/_internal/backend/sql.py +++ b/src/pydiverse/transform/_internal/backend/sql.py @@ -505,7 +505,7 @@ def compile_ast( sqa.Label(lb.name + nd.suffix, lb) for lb in right_query.select ] - elif isinstance(nd, SqlImpl): + elif isinstance(nd, TableImpl): table = nd.table cols = [ sqa.type_coerce(col, cls.sqa_type(nd.cols[col.name].dtype())).label( @@ -627,7 +627,7 @@ def create_aliases(nd: AstNode, 
num_occurences: dict[str, int]) -> dict[str, int if isinstance(nd, verbs.Join): num_occurences = create_aliases(nd.right, num_occurences) - elif isinstance(nd, SqlImpl): + elif isinstance(nd, TableImpl): if cnt := num_occurences.get(nd.table.name): nd.table = nd.table.alias(f"{nd.table.name}_{cnt}") else: diff --git a/src/pydiverse/transform/_internal/backend/table_impl.py b/src/pydiverse/transform/_internal/backend/table_impl.py index 7d22db68..061f3c53 100644 --- a/src/pydiverse/transform/_internal/backend/table_impl.py +++ b/src/pydiverse/transform/_internal/backend/table_impl.py @@ -3,6 +3,10 @@ import uuid from collections.abc import Iterable from typing import TYPE_CHECKING, Any +from uuid import UUID + +import polars as pl +import sqlalchemy as sqa from pydiverse.transform._internal import ops from pydiverse.transform._internal.backend.targets import Target @@ -48,14 +52,60 @@ def __init_subclass__(cls, **kwargs): break cls.registry = OperatorRegistry(cls, super_reg) + @staticmethod + def from_resource( + resource: Any, + backend: Target | None = None, + *, + name: str | None = None, + uuids: dict[str, UUID] | None = None, + ) -> TableImpl: + from pydiverse.transform._internal.backend import ( + DuckDb, + DuckDbPolarsImpl, + Polars, + PolarsImpl, + SqlAlchemy, + SqlImpl, + ) + + if isinstance(resource, TableImpl): + res = resource + + elif isinstance(resource, (pl.DataFrame, pl.LazyFrame)): + if name is None: + # If the data frame has been previously exported by transform, a + # name attribute was added. + if hasattr(resource, "name"): + name = resource.name + else: + name = "?" 
+ if backend is None or isinstance(backend, Polars): + res = PolarsImpl(name, resource) + elif isinstance(backend, DuckDb): + res = DuckDbPolarsImpl(name, resource) + + elif isinstance(resource, (str, sqa.Table)): + if isinstance(backend, SqlAlchemy): + res = SqlImpl(resource, backend, name) + + else: + raise AssertionError + + if uuids is not None: + for name, col in res.cols.items(): + col._uuid = uuids[name] + + return res + def iter_subtree(self) -> Iterable[AstNode]: yield self - @staticmethod - def build_query(nd: AstNode, final_select: list[Col]) -> str | None: ... + @classmethod + def build_query(cls, nd: AstNode, final_select: list[Col]) -> str | None: ... - @staticmethod - def export(nd: AstNode, target: Target, final_select: list[Col]) -> Any: ... + @classmethod + def export(cls, nd: AstNode, target: Target, final_select: list[Col]) -> Any: ... @classmethod def _html_repr_expr(cls, expr): diff --git a/src/pydiverse/transform/_internal/pipe/table.py b/src/pydiverse/transform/_internal/pipe/table.py index 0a93c09f..793330e0 100644 --- a/src/pydiverse/transform/_internal/pipe/table.py +++ b/src/pydiverse/transform/_internal/pipe/table.py @@ -5,11 +5,11 @@ import inspect from collections.abc import Callable, Iterable from html import escape +from typing import Any from uuid import UUID -import sqlalchemy as sqa - from pydiverse.transform._internal.backend.table_impl import TableImpl +from pydiverse.transform._internal.backend.targets import Target from pydiverse.transform._internal.pipe.pipeable import Pipeable from pydiverse.transform._internal.tree.ast import AstNode from pydiverse.transform._internal.tree.col_expr import Col, ColName @@ -25,36 +25,17 @@ class Table: # TODO: define exactly what can be given for the two and do type checks # maybe call the second one execution_engine or similar? 
- def __init__(self, resource, backend=None, *, name: str | None = None): - import polars as pl - - from pydiverse.transform._internal.backend import ( - PolarsImpl, - SqlAlchemy, - SqlImpl, - ) - - if isinstance(resource, TableImpl): - self._ast: AstNode = resource - elif isinstance(resource, (pl.DataFrame, pl.LazyFrame)): - if name is None: - # TODO: we could look whether the df has a name attr set (which is the - # case if it was previously exported) - name = "?" - self._ast = PolarsImpl(name, resource) - elif isinstance(resource, (str, sqa.Table)): - if isinstance(backend, SqlAlchemy): - self._ast = SqlImpl(resource, backend, name) - - if not isinstance(self._ast, TableImpl): - raise AssertionError - + def __init__( + self, resource: Any, backend: Target | None = None, *, name: str | None = None + ): + self._ast = TableImpl.from_resource(resource, backend, name=name) self._cache = Cache( self._ast.cols, list(self._ast.cols.values()), {col._uuid: col.name for col in self._ast.cols.values()}, [], {self._ast}, + {col._uuid: col for col in self._ast.cols.values()}, ) def __getitem__(self, key: str) -> Col: @@ -165,7 +146,10 @@ class Cache: select: list[Col] uuid_to_name: dict[UUID, str] # only the selected UUIDs partition_by: list[Col] - nodes: set[AstNode] + # all nodes that this table is derived from (it cannot be joined with another node + # having nonempty intersection of `derived_from`) + derived_from: set[AstNode] + all_cols: dict[UUID, Col] # all columns in current scope (including unnamed ones) def has_col(self, col: str | Col | ColName) -> bool: if isinstance(col, Col): @@ -184,6 +168,9 @@ def update( self.select = new_select if new_cols is not None: self.cols = new_cols + self.all_cols = self.all_cols | { + col._uuid: col for col in new_cols.values() + } if new_select is not None or new_cols is not None: selected_uuids = ( diff --git a/src/pydiverse/transform/_internal/pipe/verbs.py b/src/pydiverse/transform/_internal/pipe/verbs.py index 
d5da29ee..a67d38cb 100644 --- a/src/pydiverse/transform/_internal/pipe/verbs.py +++ b/src/pydiverse/transform/_internal/pipe/verbs.py @@ -80,25 +80,27 @@ def alias(table: Table, new_name: str | None = None): # We could also do lazy alias, e.g. wait until a join happens and then only copy # the common subtree. + new._cache.all_cols = { + uuid_map[uid]: Col( + col.name, nd_map[col._ast], uuid_map[uid], col._dtype, col._ftype + ) + for uid, col in table._cache.all_cols.items() + } new._cache.partition_by = [ - Col(col.name, nd_map[col._ast], uuid_map[col._uuid], col._dtype, col._ftype) - for col in table._cache.partition_by + new._cache.all_cols[uuid_map[col._uuid]] for col in table._cache.partition_by ] new._cache.update( new_select=[ - Col(col.name, nd_map[col._ast], uuid_map[col._uuid], col._dtype, col._ftype) - for col in table._cache.select + new._cache.all_cols[uuid_map[col._uuid]] for col in table._cache.select ], new_cols={ - name: Col( - name, nd_map[col._ast], uuid_map[col._uuid], col._dtype, col._ftype - ) + name: new._cache.all_cols[uuid_map[col._uuid]] for name, col in table._cache.cols.items() }, ) - new._cache.nodes = set(new._ast.iter_subtree()) + new._cache.derived_from = set(new._ast.iter_subtree()) return new @@ -107,11 +109,25 @@ def alias(table: Table, new_name: str | None = None): def collect(table: Table, target: Target | None = None) -> Table: errors.check_arg_type(Target | None, "collect", "target", target) - df = table >> export(Polars(lazy=False)) + df = table >> select(*table._cache.all_cols.values()) >> export(Polars(lazy=False)) if target is None: target = Polars() - return Table(df, target) + new = Table( + TableImpl.from_resource( + df, + target, + name=table._ast.name, + # preserve UUIDs and thereby column references across collect() + uuids={name: col._uuid for name, col in table._cache.cols.items()}, + ) + ) + new._cache.derived_from = table._cache.derived_from | {new._ast} + new._cache.select = [preprocess_arg(col, new) for col in 
table._cache.select] + new._cache.partition_by = [ + preprocess_arg(col, new) for col in table._cache.partition_by + ] + return new @verb @@ -147,7 +163,7 @@ def select(table: Table, *cols: Col | ColName): new._cache = copy.copy(table._cache) # TODO: prevent selection of overwritten columns new._cache.update(new_select=new._ast.select) - new._cache.nodes = table._cache.nodes | {new._ast} + new._cache.derived_from = table._cache.derived_from | {new._ast} return new @@ -187,7 +203,7 @@ def rename(table: Table, name_map: dict[str, str]): new_cols[replacement] = table._cache.cols[name] new._cache.update(new_cols=new_cols) - new._cache.nodes = table._cache.nodes | {new._ast} + new._cache.derived_from = table._cache.derived_from | {new._ast} return new @@ -226,7 +242,7 @@ def mutate(table: Table, **kwargs: ColExpr): + [new_cols[name] for name in new._ast.names], new_cols=new_cols, ) - new._cache.nodes = table._cache.nodes | {new._ast} + new._cache.derived_from = table._cache.derived_from | {new._ast} return new @@ -258,7 +274,7 @@ def filter(table: Table, *predicates: ColExpr): ) new._cache = copy.copy(table._cache) - new._cache.nodes = table._cache.nodes | {new._ast} + new._cache.derived_from = table._cache.derived_from | {new._ast} return new @@ -274,7 +290,7 @@ def arrange(table: Table, *order_by: ColExpr): preprocess_arg((Order.from_col_expr(ord) for ord in order_by), table), ) - new._cache.nodes = table._cache.nodes | {new._ast} + new._cache.derived_from = table._cache.derived_from | {new._ast} return new @@ -295,7 +311,7 @@ def group_by(table: Table, *cols: Col | ColName, add=False): else: new._cache.partition_by = new._ast.group_by - new._cache.nodes = table._cache.nodes | {new._ast} + new._cache.derived_from = table._cache.derived_from | {new._ast} return new @@ -360,7 +376,7 @@ def check_summarize_col_expr(expr: ColExpr, agg_fn_above: bool): new_cols=new_cols, ) new._cache.partition_by = [] - new._cache.nodes = table._cache.nodes | {new._ast} + 
new._cache.derived_from = table._cache.derived_from | {new._ast} return new @@ -376,7 +392,7 @@ def slice_head(table: Table, n: int, *, offset: int = 0): new = copy.copy(table) new._ast = SliceHead(table._ast, n, offset) new._cache = copy.copy(table._cache) - new._cache.nodes = table._cache.nodes | {new._ast} + new._cache.derived_from = table._cache.derived_from | {new._ast} return new @@ -403,7 +419,7 @@ def join( elif right._cache.partition_by: raise ValueError(f"cannot join grouped table `{right._ast.name}`") - if intersection := left._cache.nodes & right._cache.nodes: + if intersection := left._cache.derived_from & right._cache.derived_from: raise ValueError( f"table `{list(intersection)[0]}` occurs twice in the table " "tree.\nhint: To join two tables derived from a common table, " @@ -448,7 +464,9 @@ def join( | {name + suffix: col for name, col in right._cache.cols.items()}, new_select=left._cache.select + right._cache.select, ) - new._cache.nodes = left._cache.nodes | right._cache.nodes | {new._ast} + new._cache.derived_from = ( + left._cache.derived_from | right._cache.derived_from | {new._ast} + ) new._ast.on = preprocess_arg(new._ast.on, new, update_partition_by=False) return new @@ -522,24 +540,26 @@ def preprocess_arg(arg: Any, table: Table, *, update_partition_by: bool = True) arg = wrap_literal(arg) assert isinstance(arg, ColExpr) - arg: ColExpr = arg.map_subtree( - lambda col: col if not isinstance(col, ColName) else table[col.name] - ) - - for cexpr in arg.iter_subtree(): - if isinstance(cexpr, Col) and cexpr._ast not in table._cache.nodes: + for expr in arg.iter_subtree(): + if isinstance(expr, Col) and expr._ast not in table._cache.derived_from: raise ValueError( - f"table `{cexpr._ast.name}` used to reference the column " - f"`{repr(cexpr)}` cannot be used at this point. The current table " + f"table `{expr._ast.name}` used to reference the column " + f"`{repr(expr)}` cannot be used at this point. The current table " "is not derived from it." 
) if ( update_partition_by - and isinstance(cexpr, ColFn) - and "partition_by" not in cexpr.context_kwargs - and (cexpr.op().ftype in (Ftype.WINDOW, Ftype.AGGREGATE)) + and isinstance(expr, ColFn) + and "partition_by" not in expr.context_kwargs + and (expr.op().ftype in (Ftype.WINDOW, Ftype.AGGREGATE)) ): - cexpr.context_kwargs["partition_by"] = table._cache.partition_by + expr.context_kwargs["partition_by"] = table._cache.partition_by + + arg: ColExpr = arg.map_subtree( + lambda col: table[col.name] + if isinstance(col, ColName) + else (table._cache.all_cols[col._uuid] if isinstance(col, Col) else col) + ) return arg diff --git a/tests/test_polars_table.py b/tests/test_polars_table.py index e3c21315..904ce5f0 100644 --- a/tests/test_polars_table.py +++ b/tests/test_polars_table.py @@ -554,6 +554,32 @@ def test_datetime(self, tbl_dt): ), ) + def test_duckdb_execution(self, tbl2, tbl3): + assert_equal( + tbl3 + >> mutate(u=tbl3.col1 * 2) + >> collect(DuckDb()) + >> mutate(v=tbl3.col3 + C.u), + tbl3 >> mutate(u=tbl3.col1 * 2) >> mutate(v=C.col3 + C.u), + ) + + assert_equal( + tbl3 + >> collect(DuckDb()) + >> left_join( + tbl2 >> collect(DuckDb()), tbl3.col1 == tbl2.col1, suffix="_right" + ) + >> mutate(v=tbl3.col3 + tbl2.col2) + >> group_by(C.col2) + >> summarize(y=C.col3_right.sum()), + tbl3 + >> left_join(tbl2, C.col1 == C.col1_right, suffix="_right") + >> mutate(v=C.col3 + C.col2_right) + >> group_by(C.col2) + >> summarize(y=C.col3_right.sum()), + check_row_order=False, + ) + class TestPrintAndRepr: def test_table_str(self, tbl1):