Skip to content

Commit dcf93ee

Browse files
[Review] Add fast path for multi-column sorting (#229)
* add fast path for multi-column sorting
* lint
* Prevent single column Dask dataframes from calling sort_values
* Wrap dask_cudf import in try/except block
* Add test for fast multi column sort
* Move multi_col_sort contents to apply_sort
* Ignore index for dask-cudf sorting
* Fix show tables test for cudf enabled fixture
* Trigger CI
* Add single partition sort case
* Return cudf sorted dataframe without persisting
* Update nan sort test to reflect Pandas' sort_values ordering
* Add comments tracking relevant [dask-]cudf issues
* Move GPU sorting tests to test_sort.py
* Remove unnecessary isin import

Co-authored-by: Charles Blackmon-Luca <[email protected]>
1 parent 54d72cd commit dcf93ee

File tree

4 files changed

+278
-48
lines changed

4 files changed

+278
-48
lines changed

dask_sql/physical/utils/sort.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,14 @@
22

33
import dask.dataframe as dd
44
import pandas as pd
5+
from dask.utils import M
56

6-
from dask_sql.utils import make_pickable_without_dask_sql, new_temporary_column
7+
from dask_sql.utils import make_pickable_without_dask_sql
8+
9+
try:
10+
import dask_cudf
11+
except ImportError:
12+
dask_cudf = None
713

814

915
def apply_sort(
@@ -12,6 +18,46 @@ def apply_sort(
1218
sort_ascending: List[bool],
1319
sort_null_first: List[bool],
1420
) -> dd.DataFrame:
21+
# if we have a single partition, we can sometimes sort with map_partitions
22+
if df.npartitions == 1:
23+
if dask_cudf is not None and isinstance(df, dask_cudf.DataFrame):
24+
# cudf only supports null positioning if `ascending` is a single boolean:
25+
# https://github.com/rapidsai/cudf/issues/9400
26+
if (all(sort_ascending) or not any(sort_ascending)) and not any(
27+
sort_null_first[1:]
28+
):
29+
return df.map_partitions(
30+
M.sort_values,
31+
by=sort_columns,
32+
ascending=all(sort_ascending),
33+
na_position="first" if sort_null_first[0] else "last",
34+
)
35+
if not any(sort_null_first):
36+
return df.map_partitions(
37+
M.sort_values, by=sort_columns, ascending=sort_ascending
38+
)
39+
elif not any(sort_null_first[1:]):
40+
return df.map_partitions(
41+
M.sort_values,
42+
by=sort_columns,
43+
ascending=sort_ascending,
44+
na_position="first" if sort_null_first[0] else "last",
45+
)
46+
47+
# dask-cudf only supports ascending sort / nulls last:
48+
# https://github.com/rapidsai/cudf/pull/9250
49+
# https://github.com/rapidsai/cudf/pull/9264
50+
if (
51+
dask_cudf is not None
52+
and isinstance(df, dask_cudf.DataFrame)
53+
and all(sort_ascending)
54+
and not any(sort_null_first)
55+
):
56+
try:
57+
return df.sort_values(sort_columns, ignore_index=True)
58+
except ValueError:
59+
pass
60+
1561
# Split the first column. We need to handle this one with set_index
1662
first_sort_column = sort_columns[0]
1763
first_sort_ascending = sort_ascending[0]

tests/integration/fixtures.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
from dask.distributed import Client
1010
from pandas.testing import assert_frame_equal
1111

12+
try:
13+
import cudf
14+
except ImportError:
15+
cudf = None
16+
1217

1318
@pytest.fixture()
1419
def timeseries_df(c):
@@ -86,6 +91,21 @@ def datetime_table():
8691
)
8792

8893

94+
@pytest.fixture()
95+
def gpu_user_table_1(user_table_1):
96+
return cudf.from_pandas(user_table_1) if cudf else None
97+
98+
99+
@pytest.fixture()
100+
def gpu_df(df):
101+
return cudf.from_pandas(df) if cudf else None
102+
103+
104+
@pytest.fixture()
105+
def gpu_long_table(long_table):
106+
return cudf.from_pandas(long_table) if cudf else None
107+
108+
89109
@pytest.fixture()
90110
def c(
91111
df_simple,
@@ -97,6 +117,9 @@ def c(
97117
user_table_nan,
98118
string_table,
99119
datetime_table,
120+
gpu_user_table_1,
121+
gpu_df,
122+
gpu_long_table,
100123
):
101124
dfs = {
102125
"df_simple": df_simple,
@@ -108,13 +131,18 @@ def c(
108131
"user_table_nan": user_table_nan,
109132
"string_table": string_table,
110133
"datetime_table": datetime_table,
134+
"gpu_user_table_1": gpu_user_table_1,
135+
"gpu_df": gpu_df,
136+
"gpu_long_table": gpu_long_table,
111137
}
112138

113139
# Lazy import, otherwise the pytest framework has problems
114140
from dask_sql.context import Context
115141

116142
c = Context()
117143
for df_name, df in dfs.items():
144+
if df is None:
145+
continue
118146
dask_df = dd.from_pandas(df, npartitions=3)
119147
c.create_table(df_name, dask_df)
120148

tests/integration/test_show.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
import pytest
33
from pandas.testing import assert_frame_equal
44

5+
try:
6+
import cudf
7+
except ImportError:
8+
cudf = None
9+
510

611
def test_schemas(c):
712
df = c.sql("SHOW SCHEMAS")
@@ -36,6 +41,21 @@ def test_tables(c):
3641
"string_table",
3742
"datetime_table",
3843
]
44+
if cudf is None
45+
else [
46+
"df",
47+
"df_simple",
48+
"user_table_1",
49+
"user_table_2",
50+
"long_table",
51+
"user_table_inf",
52+
"user_table_nan",
53+
"string_table",
54+
"datetime_table",
55+
"gpu_user_table_1",
56+
"gpu_df",
57+
"gpu_long_table",
58+
]
3959
}
4060
)
4161

0 commit comments

Comments
 (0)