Skip to content

Commit 4caae2e

Browse files
authored
Fix apply_ufunc with dask='parallelized' for scalar arguments (#1701)
* Fix apply_ufunc with dask='parallelized' for scalar arguments (fixes GH1697).
* Rewrite the nested comprehension to use a "for" loop.
1 parent dbf7b01 commit 4caae2e

File tree

3 files changed

+52
-9
lines changed

3 files changed

+52
-9
lines changed

doc/whats-new.rst

+4
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ Bug fixes
4747
coordinates in the DataArray constructor (:issue:`1684`).
4848
By `Joe Hamman <https://github.com/jhamman>`_
4949

50+
- Fixed ``apply_ufunc`` with ``dask='parallelized'`` for scalar arguments
51+
(:issue:`1697`).
52+
By `Stephan Hoyer <https://github.com/shoyer>`_.
53+
5054
Testing
5155
~~~~~~~
5256

xarray/core/computation.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -549,8 +549,8 @@ def apply_variable_ufunc(func, *args, **kwargs):
549549
'or load your data into memory first with '
550550
'``.load()`` or ``.compute()``')
551551
elif dask == 'parallelized':
552-
input_dims = [broadcast_dims + input_dims
553-
for input_dims in signature.input_core_dims]
552+
input_dims = [broadcast_dims + dims
553+
for dims in signature.input_core_dims]
554554
numpy_func = func
555555
func = lambda *arrays: _apply_with_dask_atop(
556556
numpy_func, arrays, input_dims, output_dims, signature,
@@ -618,10 +618,14 @@ def _apply_with_dask_atop(func, args, input_dims, output_dims, signature,
618618
.format(dim, n, {dim: -1}))
619619

620620
(out_ind,) = output_dims
621-
# skip leading dimensions that we did not insert with broadcast_compat_data
622-
atop_args = [element
623-
for (arg, dims) in zip(args, input_dims)
624-
for element in (arg, dims[-getattr(arg, 'ndim', 0):])]
621+
622+
atop_args = []
623+
for arg, dims in zip(args, input_dims):
624+
# skip leading dimensions that are implicitly added by broadcasting
625+
ndim = getattr(arg, 'ndim', 0)
626+
trimmed_dims = dims[-ndim:] if ndim else ()
627+
atop_args.extend([arg, trimmed_dims])
628+
625629
return da.atop(func, out_ind, *atop_args, dtype=dtype, concatenate=True,
626630
new_axes=output_sizes)
627631

xarray/tests/test_computation.py

+38-3
Original file line numberDiff line numberDiff line change
@@ -582,18 +582,53 @@ def dask_safe_identity(x):
582582

583583

584584
@requires_dask
585-
def test_apply_dask_parallelized():
585+
def test_apply_dask_parallelized_one_arg():
586586
import dask.array as da
587587

588588
array = da.ones((2, 2), chunks=(1, 1))
589589
data_array = xr.DataArray(array, dims=('x', 'y'))
590590

591-
actual = apply_ufunc(identity, data_array, dask='parallelized',
592-
output_dtypes=[float])
591+
def parallel_identity(x):
592+
return apply_ufunc(identity, x, dask='parallelized',
593+
output_dtypes=[x.dtype])
594+
595+
actual = parallel_identity(data_array)
593596
assert isinstance(actual.data, da.Array)
594597
assert actual.data.chunks == array.chunks
595598
assert_identical(data_array, actual)
596599

600+
computed = data_array.compute()
601+
actual = parallel_identity(computed)
602+
assert_identical(computed, actual)
603+
604+
605+
@requires_dask
606+
def test_apply_dask_parallelized_two_args():
607+
import dask.array as da
608+
609+
array = da.ones((2, 2), chunks=(1, 1), dtype=np.int64)
610+
data_array = xr.DataArray(array, dims=('x', 'y'))
611+
data_array.name = None
612+
613+
def parallel_add(x, y):
614+
return apply_ufunc(operator.add, x, y,
615+
dask='parallelized',
616+
output_dtypes=[np.int64])
617+
618+
def check(x, y):
619+
actual = parallel_add(x, y)
620+
assert isinstance(actual.data, da.Array)
621+
assert actual.data.chunks == array.chunks
622+
assert_identical(data_array, actual)
623+
624+
check(data_array, 0),
625+
check(0, data_array)
626+
check(data_array, xr.DataArray(0))
627+
check(data_array, 0 * data_array)
628+
check(data_array, 0 * data_array[0])
629+
check(data_array[:, 0], 0 * data_array[0])
630+
check(data_array, 0 * data_array.compute())
631+
597632

598633
@requires_dask
599634
def test_apply_dask_parallelized_errors():

0 commit comments

Comments (0)