Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement execution of DuckDB on a polars data frame #31

Merged
merged 8 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/pydiverse/transform/_internal/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from .duckdb import DuckDbImpl
from .duckdb_polars import DuckDbPolarsImpl
from .mssql import MsSqlImpl
from .polars import PolarsImpl
from .postgres import PostgresImpl
Expand Down
77 changes: 77 additions & 0 deletions src/pydiverse/transform/_internal/backend/duckdb_polars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import annotations

from uuid import UUID

import duckdb
import duckdb_engine
import polars as pl
import sqlalchemy as sqa

from pydiverse.transform._internal.backend.duckdb import DuckDbImpl
from pydiverse.transform._internal.backend.polars import polars_type_to_pdt
from pydiverse.transform._internal.backend.table_impl import TableImpl
from pydiverse.transform._internal.backend.targets import Polars, Target
from pydiverse.transform._internal.tree.ast import AstNode
from pydiverse.transform._internal.tree.col_expr import Col


# TODO: we should move the engine of SqlImpl in the subclasses and let this thing
# inherit from SqlImpl in order to make the usage of SqlImpl.compile_ast more clean.
# Currently it works only since this class also has a table object, but it should be
# enforced by inheritance.
class DuckDbPolarsImpl(TableImpl):
def __init__(self, name: str, df: pl.DataFrame | pl.LazyFrame):
self.df = df if isinstance(df, pl.LazyFrame) else df.lazy()

super().__init__(
name,
{
name: polars_type_to_pdt(dtype)
for name, dtype in df.collect_schema().items()
},
)

self.table = sqa.Table(
name,
sqa.MetaData(),
*(
sqa.Column(col.name, DuckDbImpl.sqa_type(col.dtype()))
for col in self.cols.values()
),
)

@staticmethod
def build_query(nd: AstNode, final_select: list[Col]) -> str | None:
return DuckDbImpl.build_query(nd, final_select)

@staticmethod
def export(nd: AstNode, target: Target, final_select: list[Col]) -> pl.DataFrame:
if isinstance(target, Polars):
sel = DuckDbImpl.build_select(nd, final_select)
query_str = str(
sel.compile(
dialect=duckdb_engine.Dialect(),
compile_kwargs={"literal_binds": True},
)
)

# tell duckdb which table names in the SQL query correspond to which
# data frames
for desc in nd.iter_subtree():
if isinstance(desc, DuckDbPolarsImpl):
duckdb.register(desc.table.name, desc.df)

return duckdb.sql(query_str).pl()

raise AssertionError

def _clone(self) -> tuple[AstNode, dict[AstNode, AstNode], dict[UUID, UUID]]:
cloned = DuckDbPolarsImpl(self.name, self.df)
return (
cloned,
{self: cloned},
{
self.cols[name]._uuid: cloned.cols[name]._uuid
for name in self.cols.keys()
},
)
4 changes: 2 additions & 2 deletions src/pydiverse/transform/_internal/backend/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def compile_ast(
sqa.Label(lb.name + nd.suffix, lb) for lb in right_query.select
]

elif isinstance(nd, SqlImpl):
elif isinstance(nd, TableImpl):
table = nd.table
cols = [
sqa.type_coerce(col, cls.sqa_type(nd.cols[col.name].dtype())).label(
Expand Down Expand Up @@ -627,7 +627,7 @@ def create_aliases(nd: AstNode, num_occurences: dict[str, int]) -> dict[str, int
if isinstance(nd, verbs.Join):
num_occurences = create_aliases(nd.right, num_occurences)

elif isinstance(nd, SqlImpl):
elif isinstance(nd, TableImpl):
if cnt := num_occurences.get(nd.table.name):
nd.table = nd.table.alias(f"{nd.table.name}_{cnt}")
else:
Expand Down
58 changes: 54 additions & 4 deletions src/pydiverse/transform/_internal/backend/table_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import uuid
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
from uuid import UUID

import polars as pl
import sqlalchemy as sqa

from pydiverse.transform._internal import ops
from pydiverse.transform._internal.backend.targets import Target
Expand Down Expand Up @@ -48,14 +52,60 @@ def __init_subclass__(cls, **kwargs):
break
cls.registry = OperatorRegistry(cls, super_reg)

@staticmethod
def from_resource(
resource: Any,
backend: Target | None = None,
*,
name: str | None = None,
uuids: dict[str, UUID] | None = None,
) -> TableImpl:
from pydiverse.transform._internal.backend import (
DuckDb,
DuckDbPolarsImpl,
Polars,
PolarsImpl,
SqlAlchemy,
SqlImpl,
)

if isinstance(resource, TableImpl):
res = resource

elif isinstance(resource, (pl.DataFrame, pl.LazyFrame)):
if name is None:
# If the data frame has be previously exported by transform, a
# name attribute was added.
if hasattr(resource, "name"):
name = resource.name
else:
name = "?"
if backend is None or isinstance(backend, Polars):
res = PolarsImpl(name, resource)
elif isinstance(backend, DuckDb):
res = DuckDbPolarsImpl(name, resource)

elif isinstance(resource, (str, sqa.Table)):
if isinstance(backend, SqlAlchemy):
res = SqlImpl(resource, backend, name)

else:
raise AssertionError

if uuids is not None:
for name, col in res.cols.items():
col._uuid = uuids[name]

return res

def iter_subtree(self) -> Iterable[AstNode]:
yield self

@staticmethod
def build_query(nd: AstNode, final_select: list[Col]) -> str | None: ...
@classmethod
def build_query(cls, nd: AstNode, final_select: list[Col]) -> str | None: ...

@staticmethod
def export(nd: AstNode, target: Target, final_select: list[Col]) -> Any: ...
@classmethod
def export(cls, nd: AstNode, target: Target, final_select: list[Col]) -> Any: ...

@classmethod
def _html_repr_expr(cls, expr):
Expand Down
41 changes: 14 additions & 27 deletions src/pydiverse/transform/_internal/pipe/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import inspect
from collections.abc import Callable, Iterable
from html import escape
from typing import Any
from uuid import UUID

import sqlalchemy as sqa

from pydiverse.transform._internal.backend.table_impl import TableImpl
from pydiverse.transform._internal.backend.targets import Target
from pydiverse.transform._internal.pipe.pipeable import Pipeable
from pydiverse.transform._internal.tree.ast import AstNode
from pydiverse.transform._internal.tree.col_expr import Col, ColName
Expand All @@ -25,36 +25,17 @@ class Table:

# TODO: define exactly what can be given for the two and do type checks
# maybe call the second one execution_engine or similar?
def __init__(self, resource, backend=None, *, name: str | None = None):
import polars as pl

from pydiverse.transform._internal.backend import (
PolarsImpl,
SqlAlchemy,
SqlImpl,
)

if isinstance(resource, TableImpl):
self._ast: AstNode = resource
elif isinstance(resource, (pl.DataFrame, pl.LazyFrame)):
if name is None:
# TODO: we could look whether the df has a name attr set (which is the
# case if it was previously exported)
name = "?"
self._ast = PolarsImpl(name, resource)
elif isinstance(resource, (str, sqa.Table)):
if isinstance(backend, SqlAlchemy):
self._ast = SqlImpl(resource, backend, name)

if not isinstance(self._ast, TableImpl):
raise AssertionError

def __init__(
self, resource: Any, backend: Target | None = None, *, name: str | None = None
):
self._ast = TableImpl.from_resource(resource, backend, name=name)
self._cache = Cache(
self._ast.cols,
list(self._ast.cols.values()),
{col._uuid: col.name for col in self._ast.cols.values()},
[],
{self._ast},
{col._uuid: col for col in self._ast.cols.values()},
)

def __getitem__(self, key: str) -> Col:
Expand Down Expand Up @@ -165,7 +146,10 @@ class Cache:
select: list[Col]
uuid_to_name: dict[UUID, str] # only the selected UUIDs
partition_by: list[Col]
nodes: set[AstNode]
# all nodes that this table is derived from (it cannot be joined with another node
# having nonempty intersection of `derived_from`)
derived_from: set[AstNode]
all_cols: dict[UUID, Col] # all columns in current scope (including unnamed ones)

def has_col(self, col: str | Col | ColName) -> bool:
if isinstance(col, Col):
Expand All @@ -184,6 +168,9 @@ def update(
self.select = new_select
if new_cols is not None:
self.cols = new_cols
self.all_cols = self.all_cols | {
col._uuid: col for col in new_cols.values()
}

if new_select is not None or new_cols is not None:
selected_uuids = (
Expand Down
Loading