Skip to content

Commit

Permalink
make some horizontal functions actual operators
Browse files Browse the repository at this point in the history
  • Loading branch information
finn-rudolph committed Jan 11, 2025
1 parent 74c242d commit 3c27e4c
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 31 deletions.
3 changes: 3 additions & 0 deletions docs/source/reference/operators/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ Global functions
.. autosummary::
:nosignatures:

all
any
coalesce
count
dense_rank
Expand All @@ -121,4 +123,5 @@ Global functions
min
rank
row_number
sum
when
13 changes: 13 additions & 0 deletions src/pydiverse/transform/_internal/backend/table_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import uuid
from collections.abc import Generator, Iterable, Sequence
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -227,3 +228,15 @@ def _gt(lhs, rhs):
@impl(ops.greater_equal)
def _ge(lhs, rhs):
return lhs >= rhs

@impl(ops.horizontal_all)
def _horizontal_all(*args):
return functools.reduce(_and, args)

@impl(ops.horizontal_any)
def _horizontal_any(*args):
return functools.reduce(_or, args)

@impl(ops.horizontal_sum)
def _horizontal_sum(*args):
return functools.reduce(_add, args)
20 changes: 19 additions & 1 deletion src/pydiverse/transform/_internal/ops/ops/horizontal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@

from pydiverse.transform._internal.ops.op import Operator
from pydiverse.transform._internal.ops.signature import Signature
from pydiverse.transform._internal.tree.types import COMPARABLE, D
from pydiverse.transform._internal.tree.types import (
COMPARABLE,
NUMERIC,
Bool,
D,
Duration,
String,
)


class Horizontal(Operator):
Expand Down Expand Up @@ -127,3 +134,14 @@ def __init__(self, name: str, *signatures: Signature, doc: str = ""):
└──────┴──────┴──────┴─────┴─────┘
""",
)

horizontal_any = Horizontal("any", Signature(Bool(), Bool(), ..., return_type=Bool()))

horizontal_all = Horizontal("all", Signature(Bool(), Bool(), ..., return_type=Bool()))

horizontal_sum = Horizontal(
"sum",
*(Signature(dtype, dtype, ..., return_type=dtype) for dtype in NUMERIC),
Signature(String(), String(), ..., return_type=String()),
Signature(Duration(), Duration(), ..., return_type=Duration()),
)
69 changes: 39 additions & 30 deletions src/pydiverse/transform/_internal/pipe/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

import functools
import operator
from collections.abc import Iterable
from typing import Any, overload

Expand All @@ -24,6 +22,7 @@
Datetime,
Decimal,
Dtype,
Duration,
Float,
Int,
String,
Expand Down Expand Up @@ -55,34 +54,6 @@ def lit(val: Any, dtype: Dtype | None = None) -> LiteralCol:
return LiteralCol(val, dtype)


def all(arg: ColExpr[Bool], *args: ColExpr[Bool]) -> ColExpr[Bool]:
return functools.reduce(operator.and_, (arg, *args))


def any(arg: ColExpr[Bool], *args: ColExpr[Bool]) -> ColExpr[Bool]:
return functools.reduce(operator.or_, (arg, *args))


@overload
def sum(arg: ColExpr[Int], *args: ColExpr[Int]) -> ColExpr[Int]: ...


@overload
def sum(arg: ColExpr[Float], *args: ColExpr[Float]) -> ColExpr[Float]: ...


@overload
def sum(arg: ColExpr[Decimal], *args: ColExpr[Decimal]) -> ColExpr[Decimal]: ...


@overload
def sum(arg: ColExpr[String], *args: ColExpr[String]) -> ColExpr[String]: ...


def sum(arg: ColExpr, *args: ColExpr) -> ColExpr:
return functools.reduce(operator.add, (arg, *args))


# --- from here the code is generated, do not delete this comment ---


Expand Down Expand Up @@ -187,6 +158,18 @@ def dense_rank(
return ColFn(ops.dense_rank, partition_by=partition_by, arrange=arrange)


def all(arg: ColExpr[Bool], *args: ColExpr[Bool]) -> ColExpr[Bool]:
""""""

return ColFn(ops.horizontal_all, arg, *args)


def any(arg: ColExpr[Bool], *args: ColExpr[Bool]) -> ColExpr[Bool]:
""""""

return ColFn(ops.horizontal_any, arg, *args)


@overload
def max(arg: ColExpr[Int], *args: ColExpr[Int]) -> ColExpr[Int]: ...

Expand Down Expand Up @@ -301,6 +284,32 @@ def min(arg: ColExpr, *args: ColExpr) -> ColExpr:
return ColFn(ops.horizontal_min, arg, *args)


@overload
def sum(arg: ColExpr[Int], *args: ColExpr[Int]) -> ColExpr[Int]: ...


@overload
def sum(arg: ColExpr[Float], *args: ColExpr[Float]) -> ColExpr[Float]: ...


@overload
def sum(arg: ColExpr[Decimal], *args: ColExpr[Decimal]) -> ColExpr[Decimal]: ...


@overload
def sum(arg: ColExpr[String], *args: ColExpr[String]) -> ColExpr[String]: ...


@overload
def sum(arg: ColExpr[Duration], *args: ColExpr[Duration]) -> ColExpr[Duration]: ...


def sum(arg: ColExpr, *args: ColExpr) -> ColExpr:
""""""

return ColFn(ops.horizontal_sum, arg, *args)


def rank(
*,
partition_by: Col | ColName | Iterable[Col | ColName] | None = None,
Expand Down

0 comments on commit 3c27e4c

Please sign in to comment.