Skip to content

Commit 369a908

Browse files
committed
Support nanargmin, nanargmax
1 parent 6a5969f commit 369a908

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

flox/aggregations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def _pick_second(*x):
421421
chunk=("nanmax", "nanargmax"), # order is important
422422
combine=("max", "argmax"),
423423
reduction_type="argreduce",
424-
fill_value=(dtypes.NINF, -1),
424+
fill_value=(dtypes.NINF, 0),
425425
final_fill_value=-1,
426426
finalize=_pick_second,
427427
dtypes=(None, np.intp),
@@ -434,7 +434,7 @@ def _pick_second(*x):
434434
chunk=("nanmin", "nanargmin"), # order is important
435435
combine=("min", "argmin"),
436436
reduction_type="argreduce",
437-
fill_value=(dtypes.INF, -1),
437+
fill_value=(dtypes.INF, 0),
438438
final_fill_value=-1,
439439
finalize=_pick_second,
440440
dtypes=(None, np.intp),

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33

4-
@pytest.fixture(scope="module", params=["flox"])
4+
@pytest.fixture(scope="module", params=["flox", "numpy", "numba"])
55
def engine(request):
66
if request.param == "numba":
77
try:

tests/test_core.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def dask_array_ones(*args):
5555
"nansum",
5656
"argmax",
5757
"nanfirst",
58-
pytest.param("nanargmax", marks=(pytest.mark.skip,)),
58+
"nanargmax",
5959
"prod",
6060
"nanprod",
6161
"mean",
@@ -233,8 +233,9 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
233233
# computing silences a bunch of dask warnings
234234
array_ = array.compute() if chunks is not None else array
235235
if "arg" in func and add_nan_by:
236-
array_[..., nanmask] = np.nan
237-
expected = getattr(np, "nan" + func)(array_, axis=-1, **kwargs)
236+
func_ = f"nan{func}" if "nan" not in func else func
237+
array[..., nanmask] = np.nan
238+
expected = getattr(np, func_)(array, axis=-1, **kwargs)
238239
# elif func in ["first", "last"]:
239240
# expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs)
240241
elif func in ["nanfirst", "nanlast"]:

0 commit comments

Comments
 (0)