From f5b999ca258c591022c8eeb797d37350bc3c7e74 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 8 Nov 2017 10:04:27 -0800 Subject: [PATCH 1/2] Fix apply_ufunc with dask='parallelized' for scalar arguments Fixes GH1697 --- doc/whats-new.rst | 4 ++++ xarray/core/computation.py | 11 +++++---- xarray/tests/test_computation.py | 41 +++++++++++++++++++++++++++++--- 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b3cd22e32e1..dd775417132 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -47,6 +47,10 @@ Bug fixes coordinates in the DataArray constructor (:issue:`1684`). By `Joe Hamman `_ +- Fixed ``apply_ufunc`` with ``dask='parallelized'`` for scalar arguments + (:issue:`1697`). + By `Stephan Hoyer `_. + Testing ~~~~~~~ diff --git a/xarray/core/computation.py b/xarray/core/computation.py index c8bfbd2985d..37d42ac7041 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -549,8 +549,8 @@ def apply_variable_ufunc(func, *args, **kwargs): 'or load your data into memory first with ' '``.load()`` or ``.compute()``') elif dask == 'parallelized': - input_dims = [broadcast_dims + input_dims - for input_dims in signature.input_core_dims] + input_dims = [broadcast_dims + dims + for dims in signature.input_core_dims] numpy_func = func func = lambda *arrays: _apply_with_dask_atop( numpy_func, arrays, input_dims, output_dims, signature, @@ -619,9 +619,10 @@ def _apply_with_dask_atop(func, args, input_dims, output_dims, signature, (out_ind,) = output_dims # skip leading dimensions that we did not insert with broadcast_compat_data - atop_args = [element - for (arg, dims) in zip(args, input_dims) - for element in (arg, dims[-getattr(arg, 'ndim', 0):])] + atop_args = [ + element + for (arg, dims) in zip(args, input_dims) + for element in (arg, dims[-getattr(arg, 'ndim', 0) or len(dims):])] return da.atop(func, out_ind, *atop_args, dtype=dtype, concatenate=True, new_axes=output_sizes) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 255312ff721..430a1a027cb 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -582,18 +582,53 @@ def dask_safe_identity(x): @requires_dask -def test_apply_dask_parallelized(): +def test_apply_dask_parallelized_one_arg(): import dask.array as da array = da.ones((2, 2), chunks=(1, 1)) data_array = xr.DataArray(array, dims=('x', 'y')) - actual = apply_ufunc(identity, data_array, dask='parallelized', - output_dtypes=[float]) + def parallel_identity(x): + return apply_ufunc(identity, x, dask='parallelized', + output_dtypes=[x.dtype]) + + actual = parallel_identity(data_array) assert isinstance(actual.data, da.Array) assert actual.data.chunks == array.chunks assert_identical(data_array, actual) + computed = data_array.compute() + actual = parallel_identity(computed) + assert_identical(computed, actual) + + +@requires_dask +def test_apply_dask_parallelized_two_args(): + import dask.array as da + + array = da.ones((2, 2), chunks=(1, 1), dtype=np.int64) + data_array = xr.DataArray(array, dims=('x', 'y')) + data_array.name = None + + def parallel_add(x, y): + return apply_ufunc(operator.add, x, y, + dask='parallelized', + output_dtypes=[np.int64]) + + def check(x, y): + actual = parallel_add(x, y) + assert isinstance(actual.data, da.Array) + assert actual.data.chunks == array.chunks + assert_identical(data_array, actual) + + check(data_array, 0), + check(0, data_array) + check(data_array, xr.DataArray(0)) + check(data_array, 0 * data_array) + check(data_array, 0 * data_array[0]) + check(data_array[:, 0], 0 * data_array[0]) + check(data_array, 0 * data_array.compute()) + @requires_dask def test_apply_dask_parallelized_errors(): From d3074f6325c2e83c012042bf29968bce16b18c33 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 8 Nov 2017 21:00:02 -0800 Subject: [PATCH 2/2] Rewrite nested comprehension to use a "for" loop --- xarray/core/computation.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 37d42ac7041..e58b072f752 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -618,11 +618,14 @@ def _apply_with_dask_atop(func, args, input_dims, output_dims, signature, .format(dim, n, {dim: -1})) (out_ind,) = output_dims - # skip leading dimensions that we did not insert with broadcast_compat_data - atop_args = [ - element - for (arg, dims) in zip(args, input_dims) - for element in (arg, dims[-getattr(arg, 'ndim', 0) or len(dims):])] + + atop_args = [] + for arg, dims in zip(args, input_dims): + # skip leading dimensions that are implicitly added by broadcasting + ndim = getattr(arg, 'ndim', 0) + trimmed_dims = dims[-ndim:] if ndim else () + atop_args.extend([arg, trimmed_dims]) + return da.atop(func, out_ind, *atop_args, dtype=dtype, concatenate=True, new_axes=output_sizes)