Skip to content

Commit dd08920

Browse files
committed
Fix tests
1 parent 41abaac commit dd08920

File tree

2 files changed

+9
-16
lines changed

2 files changed

+9
-16
lines changed

flox/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def _is_bool_supported_reduction(func: T_Agg) -> bool:
259259
def _is_sparse_supported_reduction(func: T_Agg) -> bool:
260260
if isinstance(func, Aggregation):
261261
func = func.name
262-
return not HAS_SPARSE or all(f not in func for f in ["first", "last", "prod", "var", "std"])
262+
return HAS_SPARSE and all(f not in func for f in ["first", "last", "prod", "var", "std"])
263263

264264

265265
def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:

tests/test_core.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ def dask_array_ones(*args):
7676

7777

7878
DEFAULT_QUANTILE = 0.9
79+
REINDEX_SPARSE_STRAT = ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO)
80+
REINDEX_SPARSE_PARAM = pytest.param(
81+
REINDEX_SPARSE_STRAT, marks=(requires_dask, pytest.mark.skipif(not has_sparse, reason="no sparse"))
82+
)
7983

8084
if TYPE_CHECKING:
8185
from flox.core import T_Agg, T_Engine, T_ExpectedGroupsOpt, T_Method
@@ -325,7 +329,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
325329
params = list(
326330
itertools.product(
327331
["map-reduce"],
328-
[True, False, None, ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO)],
332+
[True, False, None, REINDEX_SPARSE_STRAT],
329333
)
330334
)
331335
params.extend(itertools.product(["cohorts"], [False, None]))
@@ -460,18 +464,7 @@ def test_numpy_reduce_nd_md():
460464

461465

462466
@requires_dask
463-
@pytest.mark.parametrize(
464-
"reindex",
465-
[
466-
None,
467-
False,
468-
True,
469-
pytest.param(
470-
ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO),
471-
marks=pytest.mark.skipif(not has_sparse, reason="no sparse"),
472-
),
473-
],
474-
)
467+
@pytest.mark.parametrize("reindex", [None, False, True, REINDEX_SPARSE_PARAM])
475468
@pytest.mark.parametrize("func", ALL_FUNCS)
476469
@pytest.mark.parametrize("add_nan", [False, True])
477470
@pytest.mark.parametrize("dtype", (float,))
@@ -803,7 +796,7 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
803796
pytest.param(False, (2, 2, 3), marks=requires_dask),
804797
pytest.param(True, (2, 2, 3), marks=requires_dask),
805798
pytest.param(
806-
ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO),
799+
REINDEX_SPARSE_STRAT,
807800
(2, 2, 3),
808801
marks=(requires_dask, pytest.mark.skipif(not has_sparse, reason="no sparse")),
809802
),
@@ -860,7 +853,7 @@ def _maybe_chunk(arr):
860853
([0, 1, 2], False),
861854
pytest.param(
862855
[0, 1, 2],
863-
ReindexStrategy(blockwise=False, array_type=ReindexArrayType.SPARSE_COO),
856+
REINDEX_SPARSE_STRAT,
864857
marks=pytest.mark.skipif(not has_sparse, reason="no sparse"),
865858
),
866859
],

0 commit comments

Comments
 (0)