From 8e13a69c57e0d1a97cb6cac64362bdba1cb25e12 Mon Sep 17 00:00:00 2001 From: ritchie Date: Fri, 17 Jan 2025 15:41:15 +0100 Subject: [PATCH] refactor: Small drive-by's --- crates/polars-plan/src/plans/aexpr/mod.rs | 1 + crates/polars-plan/src/plans/aexpr/traverse.rs | 2 +- .../src/physical_plan/lower_group_by.rs | 15 +++++++++------ 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs index 91c9e509b0e..ba0636dd849 100644 --- a/crates/polars-plan/src/plans/aexpr/mod.rs +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -47,6 +47,7 @@ pub enum IRAggExpr { method: QuantileMethod, }, Sum(Node), + // include_nulls Count(Node, bool), Std(Node, u8), Var(Node, u8), diff --git a/crates/polars-plan/src/plans/aexpr/traverse.rs b/crates/polars-plan/src/plans/aexpr/traverse.rs index a5ff7c5af21..1b1b34c535f 100644 --- a/crates/polars-plan/src/plans/aexpr/traverse.rs +++ b/crates/polars-plan/src/plans/aexpr/traverse.rs @@ -3,7 +3,7 @@ use super::*; impl AExpr { /// Push the inputs of this node to the given container, in reverse order. /// This ensures the primary node responsible for the name is pushed last. - pub(crate) fn inputs_rev(&self, container: &mut E) + pub fn inputs_rev(&self, container: &mut E) where E: Extend, { diff --git a/crates/polars-stream/src/physical_plan/lower_group_by.rs b/crates/polars-stream/src/physical_plan/lower_group_by.rs index d07b187f722..f8956e704fa 100644 --- a/crates/polars-stream/src/physical_plan/lower_group_by.rs +++ b/crates/polars-stream/src/physical_plan/lower_group_by.rs @@ -12,6 +12,7 @@ use polars_plan::prelude::GroupbyOptions; use polars_utils::arena::{Arena, Node}; use polars_utils::itertools::Itertools; use polars_utils::pl_str::PlSmallStr; +use recursive::recursive; use slotmap::SlotMap; use super::lower_expr::lower_exprs; @@ -71,6 +72,7 @@ fn build_group_by_fallback( /// /// Such an expression is defined as the elementwise combination of scalar /// aggregations of elementwise combinations of the input columns / scalar literals. +#[recursive] fn try_lower_elementwise_scalar_agg_expr( expr: Node, inside_agg: bool, @@ -178,7 +180,10 @@ fn try_lower_elementwise_scalar_agg_expr( }, AExpr::Agg(agg) => { - let orig_agg = agg.clone(); + // Nested aggregates not supported. + if inside_agg { + return None; + } match agg { IRAggExpr::Min { input, .. } | IRAggExpr::Max { input, .. } @@ -188,10 +193,7 @@ fn try_lower_elementwise_scalar_agg_expr( | IRAggExpr::Sum(input) | IRAggExpr::Var(input, ..) | IRAggExpr::Std(input, ..) => { - // Nested aggregates not supported. - if inside_agg { - return None; - } + let orig_agg = agg.clone(); // Lower and replace input. let trans_input = lower_rec!(*input, true)?; let mut trans_agg = orig_agg; @@ -311,7 +313,8 @@ fn try_build_streaming_group_by( &mut trans_agg_exprs, &trans_input_cols, )?; - trans_output_exprs.push(ExprIR::new(trans_node, agg.output_name_inner().clone())); + let output_name = OutputName::Alias(agg.output_name().clone()); + trans_output_exprs.push(ExprIR::new(trans_node, output_name)); } let input_schema = &phys_sm[trans_input.node].output_schema;