Skip to content

Commit

Permalink
Merge pull request #31 from pydiverse/polars-duckdb
Browse files Browse the repository at this point in the history
Implement execution of DuckDB on a polars data frame
  • Loading branch information
finn-rudolph authored Oct 7, 2024
2 parents 09fd6ec + 58f59bf commit 997ac78
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 65 deletions.
1 change: 1 addition & 0 deletions src/pydiverse/transform/_internal/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from .duckdb import DuckDbImpl
from .duckdb_polars import DuckDbPolarsImpl
from .mssql import MsSqlImpl
from .polars import PolarsImpl
from .postgres import PostgresImpl
Expand Down
77 changes: 77 additions & 0 deletions src/pydiverse/transform/_internal/backend/duckdb_polars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import annotations

from uuid import UUID

import duckdb
import duckdb_engine
import polars as pl
import sqlalchemy as sqa

from pydiverse.transform._internal.backend.duckdb import DuckDbImpl
from pydiverse.transform._internal.backend.polars import polars_type_to_pdt
from pydiverse.transform._internal.backend.table_impl import TableImpl
from pydiverse.transform._internal.backend.targets import Polars, Target
from pydiverse.transform._internal.tree.ast import AstNode
from pydiverse.transform._internal.tree.col_expr import Col


# TODO: we should move the engine of SqlImpl in the subclasses and let this thing
# inherit from SqlImpl in order to make the usage of SqlImpl.compile_ast more clean.
# Currently it works only since this class also has a table object, but it should be
# enforced by inheritance.
class DuckDbPolarsImpl(TableImpl):
def __init__(self, name: str, df: pl.DataFrame | pl.LazyFrame):
self.df = df if isinstance(df, pl.LazyFrame) else df.lazy()

super().__init__(
name,
{
name: polars_type_to_pdt(dtype)
for name, dtype in df.collect_schema().items()
},
)

self.table = sqa.Table(
name,
sqa.MetaData(),
*(
sqa.Column(col.name, DuckDbImpl.sqa_type(col.dtype()))
for col in self.cols.values()
),
)

@staticmethod
def build_query(nd: AstNode, final_select: list[Col]) -> str | None:
return DuckDbImpl.build_query(nd, final_select)

@staticmethod
def export(nd: AstNode, target: Target, final_select: list[Col]) -> pl.DataFrame:
if isinstance(target, Polars):
sel = DuckDbImpl.build_select(nd, final_select)
query_str = str(
sel.compile(
dialect=duckdb_engine.Dialect(),
compile_kwargs={"literal_binds": True},
)
)

# tell duckdb which table names in the SQL query correspond to which
# data frames
for desc in nd.iter_subtree():
if isinstance(desc, DuckDbPolarsImpl):
duckdb.register(desc.table.name, desc.df)

return duckdb.sql(query_str).pl()

raise AssertionError

def _clone(self) -> tuple[AstNode, dict[AstNode, AstNode], dict[UUID, UUID]]:
cloned = DuckDbPolarsImpl(self.name, self.df)
return (
cloned,
{self: cloned},
{
self.cols[name]._uuid: cloned.cols[name]._uuid
for name in self.cols.keys()
},
)
4 changes: 2 additions & 2 deletions src/pydiverse/transform/_internal/backend/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def compile_ast(
sqa.Label(lb.name + nd.suffix, lb) for lb in right_query.select
]

elif isinstance(nd, SqlImpl):
elif isinstance(nd, TableImpl):
table = nd.table
cols = [
sqa.type_coerce(col, cls.sqa_type(nd.cols[col.name].dtype())).label(
Expand Down Expand Up @@ -627,7 +627,7 @@ def create_aliases(nd: AstNode, num_occurences: dict[str, int]) -> dict[str, int
if isinstance(nd, verbs.Join):
num_occurences = create_aliases(nd.right, num_occurences)

elif isinstance(nd, SqlImpl):
elif isinstance(nd, TableImpl):
if cnt := num_occurences.get(nd.table.name):
nd.table = nd.table.alias(f"{nd.table.name}_{cnt}")
else:
Expand Down
58 changes: 54 additions & 4 deletions src/pydiverse/transform/_internal/backend/table_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import uuid
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
from uuid import UUID

import polars as pl
import sqlalchemy as sqa

from pydiverse.transform._internal import ops
from pydiverse.transform._internal.backend.targets import Target
Expand Down Expand Up @@ -48,14 +52,60 @@ def __init_subclass__(cls, **kwargs):
break
cls.registry = OperatorRegistry(cls, super_reg)

@staticmethod
def from_resource(
resource: Any,
backend: Target | None = None,
*,
name: str | None = None,
uuids: dict[str, UUID] | None = None,
) -> TableImpl:
from pydiverse.transform._internal.backend import (
DuckDb,
DuckDbPolarsImpl,
Polars,
PolarsImpl,
SqlAlchemy,
SqlImpl,
)

if isinstance(resource, TableImpl):
res = resource

elif isinstance(resource, (pl.DataFrame, pl.LazyFrame)):
if name is None:
# If the data frame has be previously exported by transform, a
# name attribute was added.
if hasattr(resource, "name"):
name = resource.name
else:
name = "?"
if backend is None or isinstance(backend, Polars):
res = PolarsImpl(name, resource)
elif isinstance(backend, DuckDb):
res = DuckDbPolarsImpl(name, resource)

elif isinstance(resource, (str, sqa.Table)):
if isinstance(backend, SqlAlchemy):
res = SqlImpl(resource, backend, name)

else:
raise AssertionError

if uuids is not None:
for name, col in res.cols.items():
col._uuid = uuids[name]

return res

def iter_subtree(self) -> Iterable[AstNode]:
yield self

@staticmethod
def build_query(nd: AstNode, final_select: list[Col]) -> str | None: ...
@classmethod
def build_query(cls, nd: AstNode, final_select: list[Col]) -> str | None: ...

@staticmethod
def export(nd: AstNode, target: Target, final_select: list[Col]) -> Any: ...
@classmethod
def export(cls, nd: AstNode, target: Target, final_select: list[Col]) -> Any: ...

@classmethod
def _html_repr_expr(cls, expr):
Expand Down
41 changes: 14 additions & 27 deletions src/pydiverse/transform/_internal/pipe/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import inspect
from collections.abc import Callable, Iterable
from html import escape
from typing import Any
from uuid import UUID

import sqlalchemy as sqa

from pydiverse.transform._internal.backend.table_impl import TableImpl
from pydiverse.transform._internal.backend.targets import Target
from pydiverse.transform._internal.pipe.pipeable import Pipeable
from pydiverse.transform._internal.tree.ast import AstNode
from pydiverse.transform._internal.tree.col_expr import Col, ColName
Expand All @@ -25,36 +25,17 @@ class Table:

# TODO: define exactly what can be given for the two and do type checks
# maybe call the second one execution_engine or similar?
def __init__(self, resource, backend=None, *, name: str | None = None):
import polars as pl

from pydiverse.transform._internal.backend import (
PolarsImpl,
SqlAlchemy,
SqlImpl,
)

if isinstance(resource, TableImpl):
self._ast: AstNode = resource
elif isinstance(resource, (pl.DataFrame, pl.LazyFrame)):
if name is None:
# TODO: we could look whether the df has a name attr set (which is the
# case if it was previously exported)
name = "?"
self._ast = PolarsImpl(name, resource)
elif isinstance(resource, (str, sqa.Table)):
if isinstance(backend, SqlAlchemy):
self._ast = SqlImpl(resource, backend, name)

if not isinstance(self._ast, TableImpl):
raise AssertionError

def __init__(
self, resource: Any, backend: Target | None = None, *, name: str | None = None
):
self._ast = TableImpl.from_resource(resource, backend, name=name)
self._cache = Cache(
self._ast.cols,
list(self._ast.cols.values()),
{col._uuid: col.name for col in self._ast.cols.values()},
[],
{self._ast},
{col._uuid: col for col in self._ast.cols.values()},
)

def __getitem__(self, key: str) -> Col:
Expand Down Expand Up @@ -165,7 +146,10 @@ class Cache:
select: list[Col]
uuid_to_name: dict[UUID, str] # only the selected UUIDs
partition_by: list[Col]
nodes: set[AstNode]
# all nodes that this table is derived from (it cannot be joined with another node
# having nonempty intersection of `derived_from`)
derived_from: set[AstNode]
all_cols: dict[UUID, Col] # all columns in current scope (including unnamed ones)

def has_col(self, col: str | Col | ColName) -> bool:
if isinstance(col, Col):
Expand All @@ -184,6 +168,9 @@ def update(
self.select = new_select
if new_cols is not None:
self.cols = new_cols
self.all_cols = self.all_cols | {
col._uuid: col for col in new_cols.values()
}

if new_select is not None or new_cols is not None:
selected_uuids = (
Expand Down
Loading

0 comments on commit 997ac78

Please sign in to comment.