Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1794362: Support df.summary #2837

Merged
merged 2 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#### New Features

- Added support for `DataFrame.summary()` to compute desired statistics of a DataFrame.
- Added support for the following functions in `functions.py`
- `array_reverse`
- `divnull`
Expand Down
1 change: 1 addition & 0 deletions docs/source/snowpark/dataframe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ DataFrame
DataFrame.show
DataFrame.sort
DataFrame.subtract
DataFrame.summary
DataFrame.take
DataFrame.toDF
DataFrame.toJSON
Expand Down
201 changes: 146 additions & 55 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#

import copy
import functools
import itertools
import re
import sys
Expand Down Expand Up @@ -170,6 +171,7 @@
from snowflake.snowpark.exceptions import SnowparkDataframeException
from snowflake.snowpark.functions import (
abs as abs_,
approx_percentile,
col,
count,
lit,
Expand Down Expand Up @@ -5046,44 +5048,15 @@ def session(self) -> "snowflake.snowpark.Session":
"""
return self._session

@publicapi
def describe(
self, *cols: Union[str, List[str]], _emit_ast: bool = True
def _calculate_statistics(
self,
cols: List[str],
stat_func_dict: Dict[str, Callable],
) -> "DataFrame":
"""
Computes basic statistics for numeric columns, which includes
``count``, ``mean``, ``stddev``, ``min``, and ``max``. If no columns
are provided, this function computes statistics for all numerical or
string columns. Non-numeric and non-string columns will be ignored
when calling this method.

Example::
>>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> desc_result = df.describe().sort("SUMMARY").show()
-------------------------------------------------------
|"SUMMARY" |"A" |"B" |
-------------------------------------------------------
|count |2.0 |2.0 |
|max |3.0 |4.0 |
|mean |2.0 |3.0 |
|min |1.0 |2.0 |
|stddev |1.4142135623730951 |1.4142135623730951 |
-------------------------------------------------------
<BLANKLINE>

Args:
cols: The names of columns whose basic statistics are computed.
Calculates the statistics for the specified columns.
This method is used for the implementation of the `describe` and `summary` method.
"""
stmt = None
if _emit_ast:
stmt = self._session._ast_batch.assign()
expr = with_src_position(stmt.expr.sp_dataframe_describe, stmt)
self._set_ast_ref(expr.df)
col_list, expr.cols.variadic = parse_positional_args_to_list_variadic(*cols)
for c in col_list:
build_expr_from_snowpark_column_or_col_name(expr.cols.args.add(), c)

cols = parse_positional_args_to_list(*cols)
df = self.select(cols, _emit_ast=False) if len(cols) > 0 else self

# ignore non-numeric and non-string columns
Expand All @@ -5093,30 +5066,11 @@ def describe(
if isinstance(field.datatype, (StringType, _NumericType))
}

stat_func_dict = {
"count": count,
"mean": mean,
"stddev": stddev,
"min": min_,
"max": max_,
}

# if no columns should be selected, just return stat names
if len(numerical_string_col_type_dict) == 0:
df = self._session.create_dataframe(
list(stat_func_dict.keys()), schema=["summary"], _emit_ast=False
)
# We need to set the API calls for this to same API calls for describe
# Also add the new API calls for creating this DataFrame to the describe subcalls
adjust_api_subcalls(
df,
"DataFrame.describe",
precalls=self._plan.api_calls,
subcalls=df._plan.api_calls,
)

if _emit_ast:
df._ast_id = stmt.var_id.bitfield1

return df

Expand All @@ -5128,7 +5082,7 @@ def describe(
# for string columns, we need to convert all stats to string
# such that they can be fitted into one column
if isinstance(t, StringType):
if name in ["mean", "stddev"]:
if name.lower() in ["mean", "stddev"] or name.endswith("%"):
agg_cols.append(to_char(func(lit(None))).as_(c))
else:
agg_cols.append(to_char(func(c)))
Expand All @@ -5147,6 +5101,55 @@ def describe(
res_df.union(agg_stat_df, _emit_ast=False) if res_df else agg_stat_df
)

return res_df

@publicapi
def describe(
self, *cols: Union[str, List[str]], _emit_ast: bool = True
) -> "DataFrame":
"""
Computes basic statistics for numeric columns, which includes
``count``, ``mean``, ``stddev``, ``min``, and ``max``. If no columns
are provided, this function computes statistics for all numerical or
string columns. Non-numeric and non-string columns will be ignored
when calling this method.

Example::
>>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> desc_result = df.describe().sort("SUMMARY").show()
-------------------------------------------------------
|"SUMMARY" |"A" |"B" |
-------------------------------------------------------
|count |2.0 |2.0 |
|max |3.0 |4.0 |
|mean |2.0 |3.0 |
|min |1.0 |2.0 |
|stddev |1.4142135623730951 |1.4142135623730951 |
-------------------------------------------------------
<BLANKLINE>

Args:
cols: The names of columns whose basic statistics are computed.
"""
stmt = None
if _emit_ast:
stmt = self._session._ast_batch.assign()
expr = with_src_position(stmt.expr.sp_dataframe_describe, stmt)
self._set_ast_ref(expr.df)
col_list, expr.cols.variadic = parse_positional_args_to_list_variadic(*cols)
for c in col_list:
build_expr_from_snowpark_column_or_col_name(expr.cols.args.add(), c)

stat_func_dict = {
"count": count,
"mean": mean,
"stddev": stddev,
"min": min_,
"max": max_,
}
cols = parse_positional_args_to_list(*cols)
res_df = self._calculate_statistics(cols, stat_func_dict)

adjust_api_subcalls(
res_df,
"DataFrame.describe",
Expand All @@ -5159,6 +5162,94 @@ def describe(

return res_df

@publicapi
def summary(self, *statistics: str, _emit_ast: bool = True) -> "DataFrame":
"""
Computes specified statistics for all numeric and string columns.
Non-numeric and non-string columns will be ignored when calling this method.

Available statistics are: ``count``, ``mean``, ``stddev``, ``min``, ``max`` and
arbitrary approximate percentiles specified as a percentage (e.g., 75%).

If no statistics are given, this function computes ``count``, ``mean``, ``stddev``, ``min``,
approximate quartiles (percentiles at 25%, 50%, and 75%), and ``max``.

If no columns are provided, this function computes statistics for all numerical or
string columns.

Example::
>>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
>>> desc_result = df.summary().sort("SUMMARY").show()
-------------------------------------------------------
|"SUMMARY" |"A" |"B" |
-------------------------------------------------------
|25% |1.5 |2.5 |
|50% |2.0 |3.0 |
|75% |2.5 |3.5 |
|count |2.0 |2.0 |
|max |3.0 |4.0 |
|mean |2.0 |3.0 |
|min |1.0 |2.0 |
|stddev |1.4142135623730951 |1.4142135623730951 |
-------------------------------------------------------
<BLANKLINE>

Args:
statistics: The names of columns whose basic statistics are computed.
"""
# get stats that we want to calculate
stat_func_dict = {}
for s in statistics:
if s.lower() == "count":
stat_func_dict[s] = count
elif s.lower() == "mean":
stat_func_dict[s] = mean
elif s.lower() == "stddev":
stat_func_dict[s] = stddev
elif s.lower() == "min":
stat_func_dict[s] = min_
elif s.lower() == "max":
stat_func_dict[s] = max_
elif s.endswith("%"):
try:
number = float(s[:-1])
except Exception as ex:
raise ValueError(f"Unable to parse {s} as a percentile: {ex}.")
if number < 0 or number > 100:
raise ValueError(
"requirement failed: Percentiles must be in the range [0, 1]."
)
stat_func_dict[s] = functools.partial(
approx_percentile, percentile=number / 100
)
else:
raise ValueError(f"{s} is not a recognised statistic.")

# if stats are not specified, use the following default stats
if not stat_func_dict:
stat_func_dict = {
"count": count,
"mean": mean,
"stddev": stddev,
"min": min_,
"25%": lambda c: approx_percentile(c, 0.25),
"50%": lambda c: approx_percentile(c, 0.50),
"75%": lambda c: approx_percentile(c, 0.75),
"max": max_,
}

# calculate stats on all columns
res_df = self._calculate_statistics([], stat_func_dict)

adjust_api_subcalls(
res_df,
"DataFrame.summary",
precalls=self._plan.api_calls,
subcalls=res_df._plan.api_calls.copy(),
)

return res_df

@df_api_usage
@publicapi
def rename(
Expand Down
Loading
Loading