Skip to content

Support first, last with dask arrays #7562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ New Features

- Fix :py:meth:`xr.cov` and :py:meth:`xr.corr` now support complex valued arrays (:issue:`7340`, :pull:`7392`).
By `Michael Niklas <https://github.com/headtr1ck>`_.
- Support dask arrays in ``first`` and ``last`` reductions.
By `Deepak Cherian <https://github.com/dcherian>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
37 changes: 37 additions & 0 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from __future__ import annotations

from functools import partial

from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]

from xarray.core import dtypes, nputils


Expand Down Expand Up @@ -92,3 +96,36 @@ def _fill_with_last_one(a, b):
axis=axis,
dtype=array.dtype,
)


def _first_last_wrapper(array, *, axis, op, keepdims):
return op(array, axis, keepdims=keepdims)


def _first_or_last(darray, axis, op):
import dask.array

# This will raise the same error message seen for numpy
axis = normalize_axis_index(axis, darray.ndim)

wrapped_op = partial(_first_last_wrapper, op=op)
return dask.array.reduction(
darray,
chunk=wrapped_op,
aggregate=wrapped_op,
axis=axis,
dtype=darray.dtype,
keepdims=False, # match numpy version
)


def nanfirst(darray, axis):
from xarray.core.duck_array_ops import nanfirst

return _first_or_last(darray, axis, op=nanfirst)


def nanlast(darray, axis):
from xarray.core.duck_array_ops import nanlast

return _first_or_last(darray, axis, op=nanlast)
19 changes: 8 additions & 11 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import datetime
import inspect
import warnings
from functools import partial
from importlib import import_module

import numpy as np
Expand Down Expand Up @@ -637,27 +636,25 @@ def cumsum(array, axis=None, **kwargs):
return _nd_cum_func(cumsum_1d, array, axis, **kwargs)


_fail_on_dask_array_input_skipna = partial(
fail_on_dask_array_input,
msg="%r with skipna=True is not yet implemented on dask arrays",
)


def first(values, axis, skipna=None):
"""Return the first non-NA elements in this array along the given axis"""
if (skipna or skipna is None) and values.dtype.kind not in "iSU":
# only bother for dtypes that can hold NaN
_fail_on_dask_array_input_skipna(values)
return nanfirst(values, axis)
if is_duck_dask_array(values):
return dask_array_ops.nanfirst(values, axis)
else:
return nanfirst(values, axis)
return take(values, 0, axis=axis)


def last(values, axis, skipna=None):
"""Return the last non-NA elements in this array along the given axis"""
if (skipna or skipna is None) and values.dtype.kind not in "iSU":
# only bother for dtypes that can hold NaN
_fail_on_dask_array_input_skipna(values)
return nanlast(values, axis)
if is_duck_dask_array(values):
return dask_array_ops.nanlast(values, axis)
else:
return nanlast(values, axis)
return take(values, -1, axis=axis)


Expand Down
20 changes: 16 additions & 4 deletions xarray/core/nputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,29 @@ def _select_along_axis(values, idx, axis):
return values[sl]


def nanfirst(values, axis):
def nanfirst(values, axis, keepdims=False):
if isinstance(axis, tuple):
(axis,) = axis
axis = normalize_axis_index(axis, values.ndim)
idx_first = np.argmax(~pd.isnull(values), axis=axis)
return _select_along_axis(values, idx_first, axis)
result = _select_along_axis(values, idx_first, axis)
if keepdims:
return np.expand_dims(result, axis=axis)
else:
return result


def nanlast(values, axis):
def nanlast(values, axis, keepdims=False):
if isinstance(axis, tuple):
(axis,) = axis
axis = normalize_axis_index(axis, values.ndim)
rev = (slice(None),) * axis + (slice(None, None, -1),)
idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis)
return _select_along_axis(values, idx_last, axis)
result = _select_along_axis(values, idx_last, axis)
if keepdims:
return np.expand_dims(result, axis=axis)
else:
return result


def inverse_permutation(indices):
Expand Down
15 changes: 10 additions & 5 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,17 +549,22 @@ def test_rolling(self):
actual = v.rolling(x=2).mean()
self.assertLazyAndAllClose(expected, actual)

def test_groupby_first(self):
@pytest.mark.parametrize("func", ["first", "last"])
def test_groupby_first_last(self, func):
method = operator.methodcaller(func)
u = self.eager_array
v = self.lazy_array

for coords in [u.coords, v.coords]:
coords["ab"] = ("x", ["a", "a", "b", "b"])
with pytest.raises(NotImplementedError, match=r"dask"):
v.groupby("ab").first()
expected = u.groupby("ab").first()
expected = method(u.groupby("ab"))

with raise_if_dask_computes():
actual = method(v.groupby("ab"))
self.assertLazyAndAllClose(expected, actual)

with raise_if_dask_computes():
actual = v.groupby("ab").first(skipna=False)
actual = method(v.groupby("ab"))
self.assertLazyAndAllClose(expected, actual)

def test_reindex(self):
Expand Down
29 changes: 28 additions & 1 deletion xarray/tests/test_duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ class TestOps:
def setUp(self):
self.x = array(
[
[[nan, nan, 2.0, nan], [nan, 5.0, 6.0, nan], [8.0, 9.0, 10.0, nan]],
[
[nan, nan, 2.0, nan],
[nan, 5.0, 6.0, nan],
[8.0, 9.0, 10.0, nan],
],
[
[nan, 13.0, 14.0, 15.0],
[nan, 17.0, 18.0, nan],
Expand Down Expand Up @@ -128,6 +132,29 @@ def test_all_nan_arrays(self):
assert np.isnan(mean([np.nan, np.nan]))


@requires_dask
class TestDaskOps(TestOps):
@pytest.fixture(autouse=True)
def setUp(self):
import dask.array

self.x = dask.array.from_array(
[
[
[nan, nan, 2.0, nan],
[nan, 5.0, 6.0, nan],
[8.0, 9.0, 10.0, nan],
],
[
[nan, 13.0, 14.0, 15.0],
[nan, 17.0, 18.0, nan],
[nan, 21.0, nan, nan],
],
],
chunks=(2, 1, 2),
)


def test_cumsum_1d():
inputs = np.array([0, 1, 2, 3])
expected = np.array([0, 1, 3, 6])
Expand Down