Skip to content

Commit 073f414

Browse files
Remove dask_array_type checks (#7023)
* remove dask_array_type checks * remove dask_array_compat * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * restore xarray wraps dask precedence * fully remove dask_aray_compat.py file * remove one reference to cupy * simplify __apply_ufunc__ check Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1afba66 commit 073f414

File tree

5 files changed

+18
-97
lines changed

5 files changed

+18
-97
lines changed

xarray/core/arithmetic.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .common import ImplementsArrayReduce, ImplementsDatasetReduce
1717
from .ops import IncludeCumMethods, IncludeNumpySameMethods, IncludeReduceMethods
1818
from .options import OPTIONS, _get_keep_attrs
19-
from .pycompat import dask_array_type
19+
from .pycompat import is_duck_array
2020

2121

2222
class SupportsArithmetic:
@@ -33,20 +33,21 @@ class SupportsArithmetic:
3333

3434
# TODO: allow extending this with some sort of registration system
3535
_HANDLED_TYPES = (
36-
np.ndarray,
3736
np.generic,
3837
numbers.Number,
3938
bytes,
4039
str,
41-
) + dask_array_type
40+
)
4241

4342
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
4443
from .computation import apply_ufunc
4544

4645
# See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin.
4746
out = kwargs.get("out", ())
4847
for x in inputs + out:
49-
if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)):
48+
if not is_duck_array(x) and not isinstance(
49+
x, self._HANDLED_TYPES + (SupportsArithmetic,)
50+
):
5051
return NotImplemented
5152

5253
if ufunc.signature is not None:

xarray/core/dask_array_compat.py

-62
This file was deleted.

xarray/core/duck_array_ops.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
from numpy import take, tensordot, transpose, unravel_index # noqa
2424
from numpy import where as _where
2525

26-
from . import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils
26+
from . import dask_array_ops, dtypes, npcompat, nputils
2727
from .nputils import nanfirst, nanlast
28-
from .pycompat import cupy_array_type, dask_array_type, is_duck_dask_array
28+
from .pycompat import cupy_array_type, is_duck_dask_array
2929
from .utils import is_duck_array
3030

3131
try:
@@ -113,7 +113,7 @@ def isnull(data):
113113
return zeros_like(data, dtype=bool)
114114
else:
115115
# at this point, array should have dtype=object
116-
if isinstance(data, (np.ndarray, dask_array_type)):
116+
if isinstance(data, np.ndarray):
117117
return pandas_isnull(data)
118118
else:
119119
# Not reachable yet, but intended for use with other duck array
@@ -631,7 +631,9 @@ def sliding_window_view(array, window_shape, axis):
631631
The rolling dimension will be placed at the last dimension.
632632
"""
633633
if is_duck_dask_array(array):
634-
return dask_array_compat.sliding_window_view(array, window_shape, axis)
634+
import dask.array as da
635+
636+
return da.lib.stride_tricks.sliding_window_view(array, window_shape, axis)
635637
else:
636638
return npcompat.sliding_window_view(array, window_shape, axis)
637639

xarray/core/nanops.py

+4-19
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,6 @@
66

77
from . import dtypes, nputils, utils
88
from .duck_array_ops import count, fillna, isnull, where, where_method
9-
from .pycompat import dask_array_type
10-
11-
try:
12-
import dask.array as dask_array
13-
14-
from . import dask_array_compat
15-
except ImportError:
16-
dask_array = None # type: ignore[assignment]
17-
dask_array_compat = None # type: ignore[assignment]
189

1910

2011
def _maybe_null_out(result, axis, mask, min_count=1):
@@ -65,34 +56,30 @@ def nanmin(a, axis=None, out=None):
6556
if a.dtype.kind == "O":
6657
return _nan_minmax_object("min", dtypes.get_pos_infinity(a.dtype), a, axis)
6758

68-
module = dask_array if isinstance(a, dask_array_type) else nputils
69-
return module.nanmin(a, axis=axis)
59+
return nputils.nanmin(a, axis=axis)
7060

7161

7262
def nanmax(a, axis=None, out=None):
7363
if a.dtype.kind == "O":
7464
return _nan_minmax_object("max", dtypes.get_neg_infinity(a.dtype), a, axis)
7565

76-
module = dask_array if isinstance(a, dask_array_type) else nputils
77-
return module.nanmax(a, axis=axis)
66+
return nputils.nanmax(a, axis=axis)
7867

7968

8069
def nanargmin(a, axis=None):
8170
if a.dtype.kind == "O":
8271
fill_value = dtypes.get_pos_infinity(a.dtype)
8372
return _nan_argminmax_object("argmin", fill_value, a, axis=axis)
8473

85-
module = dask_array if isinstance(a, dask_array_type) else nputils
86-
return module.nanargmin(a, axis=axis)
74+
return nputils.nanargmin(a, axis=axis)
8775

8876

8977
def nanargmax(a, axis=None):
9078
if a.dtype.kind == "O":
9179
fill_value = dtypes.get_neg_infinity(a.dtype)
9280
return _nan_argminmax_object("argmax", fill_value, a, axis=axis)
9381

94-
module = dask_array if isinstance(a, dask_array_type) else nputils
95-
return module.nanargmax(a, axis=axis)
82+
return nputils.nanargmax(a, axis=axis)
9683

9784

9885
def nansum(a, axis=None, dtype=None, out=None, min_count=None):
@@ -128,8 +115,6 @@ def nanmean(a, axis=None, dtype=None, out=None):
128115
warnings.filterwarnings(
129116
"ignore", r"Mean of empty slice", category=RuntimeWarning
130117
)
131-
if isinstance(a, dask_array_type):
132-
return dask_array.nanmean(a, axis=axis, dtype=dtype)
133118

134119
return np.nanmean(a, axis=axis, dtype=dtype)
135120

xarray/core/variable.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
from .pycompat import (
4040
DuckArrayModule,
4141
cupy_array_type,
42-
dask_array_type,
4342
integer_types,
4443
is_duck_dask_array,
4544
sparse_array_type,
@@ -59,12 +58,8 @@
5958
)
6059

6160
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
62-
(
63-
indexing.ExplicitlyIndexed,
64-
pd.Index,
65-
)
66-
+ dask_array_type
67-
+ cupy_array_type
61+
indexing.ExplicitlyIndexed,
62+
pd.Index,
6863
)
6964
# https://github.com/python/mypy/issues/224
7065
BASIC_INDEXING_TYPES = integer_types + (slice,)
@@ -1150,7 +1145,7 @@ def to_numpy(self) -> np.ndarray:
11501145
data = self.data
11511146

11521147
# TODO first attempt to call .to_numpy() once some libraries implement it
1153-
if isinstance(data, dask_array_type):
1148+
if hasattr(data, "chunks"):
11541149
data = data.compute()
11551150
if isinstance(data, cupy_array_type):
11561151
data = data.get()

0 commit comments

Comments
 (0)