[REVIEW] Fast path when possible for non numeric aggregation #236

Merged 15 commits on Oct 12, 2021
Changes from all commits
59 changes: 42 additions & 17 deletions dask_sql/physical/rel/logical/aggregate.py
@@ -7,6 +7,11 @@
import dask.dataframe as dd
import pandas as pd

try:
    import dask_cudf
except ImportError:
    dask_cudf = None

from dask_sql.datacontainer import ColumnContainer, DataContainer
from dask_sql.physical.rel.base import BaseRelPlugin
from dask_sql.physical.rex.core.call import IsNullOperation
@@ -48,18 +53,42 @@ class AggregationSpecification:
"""
Most of the aggregations in SQL are already
implemented 1:1 in dask and can just be called via their name
- (e.g. AVG is the mean). However sometimes those already
- implemented functions only work well for numerical
- functions. This small container class therefore
- can have an additional aggregation function, which is
- valid for non-numerical types.
+ (e.g. AVG is the mean). However sometimes those
+ implemented functions only work well for some datatypes.
+ This small container class therefore
+ can have a custom aggregation function, which is
+ valid for unsupported dtypes.
"""

- def __init__(self, numerical_aggregation, non_numerical_aggregation=None):
-     self.numerical_aggregation = numerical_aggregation
-     self.non_numerical_aggregation = (
-         non_numerical_aggregation or numerical_aggregation
-     )
+ def __init__(self, built_in_aggregation, custom_aggregation=None):
+     self.built_in_aggregation = built_in_aggregation
+     self.custom_aggregation = custom_aggregation or built_in_aggregation

def get_supported_aggregation(self, series):
    built_in_aggregation = self.built_in_aggregation

    # built-in aggregations work well for numeric types
    if pd.api.types.is_numeric_dtype(series.dtype):
        return built_in_aggregation

    # TODO: add Categorical once dask-sql supports it
    if built_in_aggregation in ["min", "max"]:
        if pd.api.types.is_datetime64_any_dtype(series.dtype):
            return built_in_aggregation

        if pd.api.types.is_string_dtype(series.dtype):
            # dask_cudf string series work with the built-in aggregation
            if dask_cudf is not None and isinstance(series, dask_cudf.Series):
                return built_in_aggregation

            # built-in aggregations work with pandas StringDtype,
            # but fail with object dtype when nulls are present
            if isinstance(series, dd.Series) and isinstance(
                series.dtype, pd.StringDtype
            ):
                return built_in_aggregation

    return self.custom_aggregation


class LogicalAggregatePlugin(BaseRelPlugin):
@@ -303,13 +332,9 @@ def _collect_aggregations(
f"Aggregation function {aggregation_name} not implemented (yet)."
)
if isinstance(aggregation_function, AggregationSpecification):
- dtype = df[input_col].dtype
- if pd.api.types.is_numeric_dtype(dtype):
-     aggregation_function = aggregation_function.numerical_aggregation
- else:
-     aggregation_function = (
-         aggregation_function.non_numerical_aggregation
-     )
+ aggregation_function = aggregation_function.get_supported_aggregation(
+     df[input_col]
+ )

# Finally, extract the output column name
output_col = str(agg_call.getValue())
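
For reference, a minimal usage sketch (not part of the diff) of how the new dtype dispatch can be exercised; the import path follows the file above, while custom_min is a hypothetical stand-in for the fallback aggregations dask-sql registers for unsupported dtypes:

import dask.dataframe as dd
import pandas as pd

from dask_sql.physical.rel.logical.aggregate import AggregationSpecification

# Hypothetical fallback; the real fallbacks registered by dask-sql may differ.
custom_min = dd.Aggregation("custom_min", lambda s: s.min(), lambda s0: s0.min())
spec = AggregationSpecification("min", custom_min)

numeric = dd.from_pandas(pd.Series([3, 1, 2]), npartitions=1)
dates = dd.from_pandas(pd.Series(pd.to_datetime(["2021-10-12", "2021-10-11"])), npartitions=1)
objects = dd.from_pandas(pd.Series(["b", None, "a"], dtype="object"), npartitions=1)

spec.get_supported_aggregation(numeric)  # -> "min": numeric fast path
spec.get_supported_aggregation(dates)    # -> "min": datetimes supported for min/max
spec.get_supported_aggregation(objects)  # -> custom_min: object strings with nulls fall back
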
43 changes: 41 additions & 2 deletions tests/integration/test_compatibility.py
@@ -19,6 +19,14 @@
from dask_sql import Context


def cast_datetime_to_string(df):
    cols = df.select_dtypes(include=["datetime64[ns]"]).columns
    # Cast to object first, as converting datetimes directly
    # to string loses second precision
    df[cols] = df[cols].astype("object").astype("string")
    return df


def eq_sqlite(sql, **dfs):
c = Context()
engine = sqlite3.connect(":memory:")
@@ -30,6 +38,10 @@ def eq_sqlite(sql, **dfs):
dask_result = c.sql(sql).compute().reset_index(drop=True)
sqlite_result = pd.read_sql(sql, engine).reset_index(drop=True)

# cast datetime columns to string to ensure equality with SQLite,
# which returns object dtype for datetime inputs
dask_result = cast_datetime_to_string(dask_result)

# Make sure SQL and Dask use the same "NULL" value
dask_result = dask_result.fillna(np.NaN)
sqlite_result = sqlite_result.fillna(np.NaN)
@@ -54,6 +66,11 @@ def make_rand_df(size: int, **kwargs):
r = [f"ssssss{x}" for x in range(10)]
c = np.random.randint(10, size=size)
s = np.array([r[x] for x in c])
elif dt is pd.StringDtype:
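# nullable pandas "string" dtype: exercises the new built-in min/max fast path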
r = [f"ssssss{x}" for x in range(10)]
c = np.random.randint(10, size=size)
s = np.array([r[x] for x in c])
s = pd.array(s, dtype="string")
elif dt is datetime:
rt = [datetime(2020, 1, 1) + timedelta(days=x) for x in range(10)]
c = np.random.randint(10, size=size)
@@ -337,7 +354,14 @@ def test_agg_sum_avg():

def test_agg_min_max_no_group_by():
a = make_rand_df(
- 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)
+ 100,
+ a=(int, 50),
+ b=(str, 50),
+ c=(int, 30),
+ d=(str, 40),
+ e=(float, 40),
+ f=(pd.StringDtype, 40),
+ g=(datetime, 40),
)
eq_sqlite(
"""
@@ -352,6 +376,10 @@ def test_agg_min_max_no_group_by():
MAX(d) AS max_d,
MIN(e) AS min_e,
MAX(e) AS max_e,
MIN(f) AS min_f,
MAX(f) AS max_f,
MIN(g) AS min_g,
MAX(g) AS max_g,
MIN(a+e) AS mix_1,
MIN(a)+MIN(e) AS mix_2
FROM a
@@ -362,7 +390,14 @@ def test_agg_min_max():

def test_agg_min_max():
a = make_rand_df(
- 100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)
+ 100,
+ a=(int, 50),
+ b=(str, 50),
+ c=(int, 30),
+ d=(str, 40),
+ e=(float, 40),
+ f=(pd.StringDtype, 40),
+ g=(datetime, 40),
)
eq_sqlite(
"""
@@ -374,6 +409,10 @@
MAX(d) AS max_d,
MIN(e) AS min_e,
MAX(e) AS max_e,
MIN(f) AS min_f,
MAX(f) AS max_f,
MIN(g) AS min_g,
MAX(g) AS max_g,
MIN(a+e) AS mix_1,
MIN(a)+MIN(e) AS mix_2
FROM a GROUP BY a, b
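
And a small standalone sketch (mirroring the test helper above, values illustrative only) of the datetime normalization the compatibility tests apply before comparing against SQLite's object-dtype output:

import pandas as pd

def cast_datetime_to_string(df):
    # Same approach as the test helper: go through object dtype first so the
    # string conversion keeps full second precision.
    cols = df.select_dtypes(include=["datetime64[ns]"]).columns
    df[cols] = df[cols].astype("object").astype("string")
    return df

df = pd.DataFrame({"ts": pd.to_datetime(["2021-10-12 08:30:15"]), "x": [1]})
out = cast_datetime_to_string(df)
print(out.dtypes)  # "ts" is now pandas "string" dtype, "x" stays int64
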