From 3c27e4c1177b9984e10bfde84fa87f454ab803d5 Mon Sep 17 00:00:00 2001 From: Finn Rudolph Date: Sat, 11 Jan 2025 17:03:30 +0100 Subject: [PATCH] make some horizontal functions actual operators --- docs/source/reference/operators/index.rst | 3 + .../transform/_internal/backend/table_impl.py | 13 ++++ .../transform/_internal/ops/ops/horizontal.py | 20 +++++- .../transform/_internal/pipe/functions.py | 69 +++++++++++-------- 4 files changed, 74 insertions(+), 31 deletions(-) diff --git a/docs/source/reference/operators/index.rst b/docs/source/reference/operators/index.rst index dba9af8..358169a 100644 --- a/docs/source/reference/operators/index.rst +++ b/docs/source/reference/operators/index.rst @@ -113,6 +113,8 @@ Global functions .. autosummary:: :nosignatures: + all + any coalesce count dense_rank @@ -121,4 +123,5 @@ Global functions min rank row_number + sum when diff --git a/src/pydiverse/transform/_internal/backend/table_impl.py b/src/pydiverse/transform/_internal/backend/table_impl.py index be4e944..d952f1b 100644 --- a/src/pydiverse/transform/_internal/backend/table_impl.py +++ b/src/pydiverse/transform/_internal/backend/table_impl.py @@ -1,5 +1,6 @@ from __future__ import annotations +import functools import uuid from collections.abc import Generator, Iterable, Sequence from typing import TYPE_CHECKING, Any @@ -227,3 +228,15 @@ def _gt(lhs, rhs): @impl(ops.greater_equal) def _ge(lhs, rhs): return lhs >= rhs + + @impl(ops.horizontal_all) + def _horizontal_all(*args): + return functools.reduce(_and, args) + + @impl(ops.horizontal_any) + def _horizontal_any(*args): + return functools.reduce(_or, args) + + @impl(ops.horizontal_sum) + def _horizontal_sum(*args): + return functools.reduce(_add, args) diff --git a/src/pydiverse/transform/_internal/ops/ops/horizontal.py b/src/pydiverse/transform/_internal/ops/ops/horizontal.py index c828bfc..6b62a32 100644 --- a/src/pydiverse/transform/_internal/ops/ops/horizontal.py +++ b/src/pydiverse/transform/_internal/ops/ops/horizontal.py @@ -2,7 +2,14 @@ from pydiverse.transform._internal.ops.op import Operator from pydiverse.transform._internal.ops.signature import Signature -from pydiverse.transform._internal.tree.types import COMPARABLE, D +from pydiverse.transform._internal.tree.types import ( + COMPARABLE, + NUMERIC, + Bool, + D, + Duration, + String, +) class Horizontal(Operator): @@ -127,3 +134,14 @@ def __init__(self, name: str, *signatures: Signature, doc: str = ""): └──────┴──────┴──────┴─────┴─────┘ """, ) + +horizontal_any = Horizontal("any", Signature(Bool(), Bool(), ..., return_type=Bool())) + +horizontal_all = Horizontal("all", Signature(Bool(), Bool(), ..., return_type=Bool())) + +horizontal_sum = Horizontal( + "sum", + *(Signature(dtype, dtype, ..., return_type=dtype) for dtype in NUMERIC), + Signature(String(), String(), ..., return_type=String()), + Signature(Duration(), Duration(), ..., return_type=Duration()), +) diff --git a/src/pydiverse/transform/_internal/pipe/functions.py b/src/pydiverse/transform/_internal/pipe/functions.py index ecad79e..e5f2f7a 100644 --- a/src/pydiverse/transform/_internal/pipe/functions.py +++ b/src/pydiverse/transform/_internal/pipe/functions.py @@ -2,8 +2,6 @@ from __future__ import annotations -import functools -import operator from collections.abc import Iterable from typing import Any, overload @@ -24,6 +22,7 @@ Datetime, Decimal, Dtype, + Duration, Float, Int, String, @@ -55,34 +54,6 @@ def lit(val: Any, dtype: Dtype | None = None) -> LiteralCol: return LiteralCol(val, dtype) -def all(arg: ColExpr[Bool], *args: ColExpr[Bool]) -> ColExpr[Bool]: - return functools.reduce(operator.and_, (arg, *args)) - - -def any(arg: ColExpr[Bool], *args: ColExpr[Bool]) -> ColExpr[Bool]: - return functools.reduce(operator.or_, (arg, *args)) - - -@overload -def sum(arg: ColExpr[Int], *args: ColExpr[Int]) -> ColExpr[Int]: ... - - -@overload -def sum(arg: ColExpr[Float], *args: ColExpr[Float]) -> ColExpr[Float]: ... - - -@overload -def sum(arg: ColExpr[Decimal], *args: ColExpr[Decimal]) -> ColExpr[Decimal]: ... - - -@overload -def sum(arg: ColExpr[String], *args: ColExpr[String]) -> ColExpr[String]: ... - - -def sum(arg: ColExpr, *args: ColExpr) -> ColExpr: - return functools.reduce(operator.add, (arg, *args)) - - # --- from here the code is generated, do not delete this comment --- @@ -187,6 +158,18 @@ def dense_rank( return ColFn(ops.dense_rank, partition_by=partition_by, arrange=arrange) +def all(arg: ColExpr[Bool], *args: ColExpr[Bool]) -> ColExpr[Bool]: + """""" + + return ColFn(ops.horizontal_all, arg, *args) + + +def any(arg: ColExpr[Bool], *args: ColExpr[Bool]) -> ColExpr[Bool]: + """""" + + return ColFn(ops.horizontal_any, arg, *args) + + @overload def max(arg: ColExpr[Int], *args: ColExpr[Int]) -> ColExpr[Int]: ... @@ -301,6 +284,32 @@ def min(arg: ColExpr, *args: ColExpr) -> ColExpr: return ColFn(ops.horizontal_min, arg, *args) +@overload +def sum(arg: ColExpr[Int], *args: ColExpr[Int]) -> ColExpr[Int]: ... + + +@overload +def sum(arg: ColExpr[Float], *args: ColExpr[Float]) -> ColExpr[Float]: ... + + +@overload +def sum(arg: ColExpr[Decimal], *args: ColExpr[Decimal]) -> ColExpr[Decimal]: ... + + +@overload +def sum(arg: ColExpr[String], *args: ColExpr[String]) -> ColExpr[String]: ... + + +@overload +def sum(arg: ColExpr[Duration], *args: ColExpr[Duration]) -> ColExpr[Duration]: ... + + +def sum(arg: ColExpr, *args: ColExpr) -> ColExpr: + """""" + + return ColFn(ops.horizontal_sum, arg, *args) + + def rank( *, partition_by: Col | ColName | Iterable[Col | ColName] | None = None,