Skip to content

Commit

Permalink
Merge pull request #30 from pydiverse/api
Browse files Browse the repository at this point in the history
Stabilize public API
  • Loading branch information
finn-rudolph authored Oct 4, 2024
2 parents 6d1968a + 781fc9e commit 09fd6ec
Show file tree
Hide file tree
Showing 66 changed files with 712 additions and 560 deletions.
46 changes: 8 additions & 38 deletions src/pydiverse/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,11 @@
from __future__ import annotations

from pydiverse.transform.backend.targets import DuckDb, Polars, SqlAlchemy
from pydiverse.transform.pipe.c import C
from pydiverse.transform.pipe.functions import (
count,
dense_rank,
max,
min,
rank,
row_number,
when,
)
from pydiverse.transform.pipe.pipeable import verb
from pydiverse.transform.pipe.table import Table
from pydiverse.transform.tree.dtypes import (
Bool,
Date,
DateTime,
Duration,
Float64,
Int64,
String,
)
from ._internal.pipe.pipeable import verb
from ._internal.pipe.table import Table
from ._internal.tree.col_expr import ColExpr
from .extended import *
from .extended import __all__ as __extended
from .types import *
from .types import __all__ as __types

__all__ = [
"Polars",
"SqlAlchemy",
"DuckDb",
"Table",
"aligned",
"verb",
"C",
"Float64",
"Int64",
"String",
"Bool",
"DateTime",
"Date",
"Duration",
]
__all__ = __extended + __types + ["Table", "ColExpr", "verb"]
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import sqlalchemy as sqa
from sqlalchemy.sql.type_api import TypeEngine as TypeEngine

from pydiverse.transform import ops
from pydiverse.transform.backend import sql
from pydiverse.transform.backend.sql import SqlImpl
from pydiverse.transform.backend.targets import Polars, Target
from pydiverse.transform.tree import dtypes, verbs
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import Cast, Col, ColFn, LiteralCol
from pydiverse.transform._internal import ops
from pydiverse.transform._internal.backend import sql
from pydiverse.transform._internal.backend.sql import SqlImpl
from pydiverse.transform._internal.backend.targets import Polars, Target
from pydiverse.transform._internal.tree import dtypes, verbs
from pydiverse.transform._internal.tree.ast import AstNode
from pydiverse.transform._internal.tree.col_expr import Cast, Col, ColFn, LiteralCol


class DuckDbImpl(SqlImpl):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import sqlalchemy as sqa
from sqlalchemy.dialects.mssql import DATETIME2

from pydiverse.transform import ops
from pydiverse.transform.backend import sql
from pydiverse.transform.backend.sql import SqlImpl
from pydiverse.transform.errors import NotSupportedError
from pydiverse.transform.tree import dtypes, verbs
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import (
from pydiverse.transform._internal import ops
from pydiverse.transform._internal.backend import sql
from pydiverse.transform._internal.backend.sql import SqlImpl
from pydiverse.transform._internal.errors import NotSupportedError
from pydiverse.transform._internal.tree import dtypes, verbs
from pydiverse.transform._internal.tree.ast import AstNode
from pydiverse.transform._internal.tree.col_expr import (
CaseExpr,
Cast,
Col,
Expand All @@ -22,7 +22,7 @@
LiteralCol,
Order,
)
from pydiverse.transform.util.warnings import warn_non_standard
from pydiverse.transform._internal.util.warnings import warn_non_standard


class MsSqlImpl(SqlImpl):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

import polars as pl

from pydiverse.transform import ops
from pydiverse.transform.backend.table_impl import TableImpl
from pydiverse.transform.backend.targets import Polars, Target
from pydiverse.transform.ops.core import Ftype
from pydiverse.transform.tree import dtypes, verbs
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import (
from pydiverse.transform._internal import ops
from pydiverse.transform._internal.backend.table_impl import TableImpl
from pydiverse.transform._internal.backend.targets import Polars, Target
from pydiverse.transform._internal.ops.core import Ftype
from pydiverse.transform._internal.tree import dtypes, verbs
from pydiverse.transform._internal.tree.ast import AstNode
from pydiverse.transform._internal.tree.col_expr import (
CaseExpr,
Cast,
Col,
Expand Down Expand Up @@ -227,7 +227,7 @@ def compile_ast(
if isinstance(nd, verbs.Verb):
df, name_in_df, select, partition_by = compile_ast(nd.child)

if isinstance(nd, (verbs.Mutate, verbs.Summarise)):
if isinstance(nd, (verbs.Mutate, verbs.Summarize)):
overwritten = set(name for name in nd.names if name in set(select))
if overwritten:
# We rename overwritten cols to some unique dummy name
Expand Down Expand Up @@ -279,16 +279,13 @@ def compile_ast(
maintain_order=True,
)

elif isinstance(nd, verbs.Summarise):
# We support usage of aggregated columns in expressions in summarise, but polars
elif isinstance(nd, verbs.Summarize):
# We support usage of aggregated columns in expressions in summarize, but polars
# creates arrays when doing that. Thus we unwrap the arrays when necessary.
def has_path_to_leaf_without_agg(expr: ColExpr):
if isinstance(expr, Col):
return True
if (
isinstance(expr, ColFn)
and PolarsImpl.registry.get_op(expr.name).ftype == Ftype.AGGREGATE
):
if isinstance(expr, ColFn) and expr.op().ftype == Ftype.AGGREGATE:
return False
return any(
has_path_to_leaf_without_agg(child) for child in expr.iter_children()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import sqlalchemy as sqa

from pydiverse.transform import ops
from pydiverse.transform.backend.sql import SqlImpl
from pydiverse.transform.tree import dtypes
from pydiverse.transform.tree.col_expr import Cast
from pydiverse.transform._internal import ops
from pydiverse.transform._internal.backend.sql import SqlImpl
from pydiverse.transform._internal.tree import dtypes
from pydiverse.transform._internal.tree.col_expr import Cast


class PostgresImpl(SqlImpl):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,21 @@
import math
import operator
from collections.abc import Iterable
from typing import Any
from typing import Any, Literal
from uuid import UUID

import polars as pl
import sqlalchemy as sqa

from pydiverse.transform import ops
from pydiverse.transform.backend.polars import pdt_type_to_polars
from pydiverse.transform.backend.table_impl import TableImpl
from pydiverse.transform.backend.targets import Polars, SqlAlchemy, Target
from pydiverse.transform.ops.core import Ftype
from pydiverse.transform.tree import dtypes, verbs
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import (
from pydiverse.transform._internal import ops
from pydiverse.transform._internal.backend.polars import pdt_type_to_polars
from pydiverse.transform._internal.backend.table_impl import TableImpl
from pydiverse.transform._internal.backend.targets import Polars, SqlAlchemy, Target
from pydiverse.transform._internal.errors import SubqueryError
from pydiverse.transform._internal.ops.core import Ftype
from pydiverse.transform._internal.tree import dtypes, verbs
from pydiverse.transform._internal.tree.ast import AstNode
from pydiverse.transform._internal.tree.col_expr import (
CaseExpr,
Cast,
Col,
Expand All @@ -29,7 +30,7 @@
LiteralCol,
Order,
)
from pydiverse.transform.tree.dtypes import Dtype
from pydiverse.transform._internal.tree.dtypes import Dtype


class SqlImpl(TableImpl):
Expand Down Expand Up @@ -251,7 +252,7 @@ def compile_query(cls, table: sqa.Table, query: Query) -> sqa.sql.Select:
j.right,
onclause=j.on,
isouter=j.how != "inner",
full=j.how == "outer",
full=j.how == "full",
)

if query.where:
Expand Down Expand Up @@ -297,7 +298,7 @@ def compile_ast(
nd,
(
verbs.Filter,
verbs.Summarise,
verbs.Summarize,
verbs.Arrange,
verbs.GroupBy,
verbs.Join,
Expand All @@ -306,15 +307,30 @@ def compile_ast(
and query.limit is not None
)
or (
isinstance(nd, (verbs.Mutate, verbs.Filter))
isinstance(nd, verbs.Mutate)
and any(
node.ftype(agg_is_window=True) == Ftype.WINDOW
for node in nd.iter_col_nodes()
if isinstance(node, Col)
any(
col.ftype(agg_is_window=True) in (Ftype.WINDOW, Ftype.AGGREGATE)
for col in fn.iter_subtree()
if isinstance(col, Col)
)
for fn in nd.iter_col_nodes()
if (
isinstance(fn, ColFn)
and fn.op().ftype in (Ftype.AGGREGATE, Ftype.WINDOW)
)
)
)
or (
isinstance(nd, verbs.Filter)
and any(
col.ftype(agg_is_window=True) == Ftype.WINDOW
for col in nd.iter_col_nodes()
if isinstance(col, Col)
)
)
or (
isinstance(nd, verbs.Summarise)
isinstance(nd, verbs.Summarize)
and (
(
bool(query.group_by)
Expand All @@ -331,6 +347,15 @@ def compile_ast(
)
)
):
if not isinstance(nd.child, verbs.Alias):
raise SubqueryError(
f"forbidden subquery required during compilation of `{repr(nd)}`\n"
"hint: If you are sure you want to do a subquery, put an "
"`>> alias()` before this verb. On the other hand, if you want to "
"write out the table of the subquery, put `>> materialize()` "
"before this verb."
)

if needed_cols.keys().isdisjoint(sqa_col.keys()):
# We cannot select zero columns from a subquery. This happens when the
# user only uses 0-ary functions after the subquery, e.g. `count`.
Expand Down Expand Up @@ -371,7 +396,7 @@ def compile_ast(
],
)

if isinstance(nd, (verbs.Mutate, verbs.Summarise)):
if isinstance(nd, (verbs.Mutate, verbs.Summarize)):
query.select = [lb for lb in query.select if lb.name not in set(nd.names)]

if isinstance(nd, verbs.Select):
Expand Down Expand Up @@ -420,7 +445,7 @@ def compile_ast(
)
)

elif isinstance(nd, verbs.Summarise):
elif isinstance(nd, verbs.Summarize):
query.group_by.extend(query.partition_by)

for name, val, uid in zip(nd.names, nd.values, nd.uuids):
Expand Down Expand Up @@ -471,9 +496,9 @@ def compile_ast(
query.where.extend(right_query.where)
elif nd.how == "left":
j.on = functools.reduce(operator.and_, (j.on, *right_query.where))
elif nd.how == "outer":
elif nd.how == "full":
if query.where or right_query.where:
raise ValueError("invalid filter before outer join")
raise ValueError("invalid filter before full join")

query.join.append(j)
query.select += [
Expand Down Expand Up @@ -570,7 +595,7 @@ class Query:
class SqlJoin:
right: sqa.Subquery
on: sqa.ColumnElement
how: verbs.JoinHow
how: Literal["inner", "left", "full"]


# MSSQL complains about duplicates in ORDER BY.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import sqlalchemy as sqa

from pydiverse.transform import ops
from pydiverse.transform.backend.sql import SqlImpl
from pydiverse.transform.errors import NotSupportedError
from pydiverse.transform.tree import dtypes
from pydiverse.transform.tree.col_expr import Cast
from pydiverse.transform.util.warnings import warn_non_standard
from pydiverse.transform._internal import ops
from pydiverse.transform._internal.backend.sql import SqlImpl
from pydiverse.transform._internal.errors import NotSupportedError
from pydiverse.transform._internal.tree import dtypes
from pydiverse.transform._internal.tree.col_expr import Cast
from pydiverse.transform._internal.util.warnings import warn_non_standard


class SqliteImpl(SqlImpl):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any

from pydiverse.transform import ops
from pydiverse.transform.backend.targets import Target
from pydiverse.transform.ops.core import Ftype
from pydiverse.transform.tree.ast import AstNode
from pydiverse.transform.tree.col_expr import (
from pydiverse.transform._internal import ops
from pydiverse.transform._internal.backend.targets import Target
from pydiverse.transform._internal.ops.core import Ftype
from pydiverse.transform._internal.tree.ast import AstNode
from pydiverse.transform._internal.tree.col_expr import (
Col,
)
from pydiverse.transform.tree.dtypes import Dtype
from pydiverse.transform.tree.registry import (
from pydiverse.transform._internal.tree.dtypes import Dtype
from pydiverse.transform._internal.tree.registry import (
OperatorRegistrationContextManager,
OperatorRegistry,
)

if TYPE_CHECKING:
from pydiverse.transform.ops import Operator
from pydiverse.transform._internal.ops import Operator


class TableImpl(AstNode):
Expand Down
File renamed without changes.
Loading

0 comments on commit 09fd6ec

Please sign in to comment.