From 85024ff1f23942443ea0ac56386ec540b9fd7101 Mon Sep 17 00:00:00 2001 From: Cindy Chiao Date: Tue, 21 Dec 2021 03:05:35 +0000 Subject: [PATCH 1/4] fix --- xarray/core/parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index f20256346da..aad1d285377 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -353,8 +353,8 @@ def _wrapper( # all xarray objects must be aligned. This is consistent with apply_ufunc. aligned = align(*xarray_objs, join="exact") xarray_objs = tuple( - dataarray_to_dataset(arg) if is_da else arg - for is_da, arg in zip(is_array, aligned) + dataarray_to_dataset(arg) if isinstance(arg, DataArray) else arg + for arg in aligned ) _, npargs = unzip( From 650f701aca800461ac18b6dbaf34e81f3c618e2e Mon Sep 17 00:00:00 2001 From: Cindy Chiao Date: Tue, 21 Dec 2021 05:14:33 +0000 Subject: [PATCH 2/4] add test --- xarray/tests/test_dask.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 3b962cb2c5c..8e70523900f 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1167,6 +1167,19 @@ def func(obj): assert_identical(actual, expected) +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_mixed_type_inputs(obj): + def func(obj, non_xarray_input, obj2): + result = obj + obj.x + 5 * obj.y + return result + + with raise_if_dask_computes(): + actual = xr.map_blocks(func, obj, args=["non_xarray_input", obj]) + expected = func(obj, "non_xarray_input", obj) + assert_chunks_equal(expected.chunk(), actual) + assert_identical(actual, expected) + + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_convert_args_to_list(obj): expected = obj + 10 From 7cc3a1523d214aaf66d278d828fe6a7c7291a553 Mon Sep 17 00:00:00 2001 From: Cindy Chiao Date: Tue, 21 Dec 2021 05:34:10 +0000 Subject: [PATCH 3/4] kick test --- xarray/tests/test_dask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 8e70523900f..3b1e0ad658e 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1169,8 +1169,8 @@ def func(obj): @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_mixed_type_inputs(obj): - def func(obj, non_xarray_input, obj2): - result = obj + obj.x + 5 * obj.y + def func(obj1, non_xarray_input, obj2): + result = obj1 + obj1.x + 5 * obj1.y return result with raise_if_dask_computes(): From 8252a9c73d243840866b6ba97b6a3179c655a409 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 29 Dec 2021 09:53:35 -0700 Subject: [PATCH 4/4] [skip-ci] add whats-new --- doc/whats-new.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f991a4e2a89..2572651415d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,7 +37,8 @@ Deprecations Bug fixes ~~~~~~~~~ - +- Fix applying function with non-xarray arguments using :py:func:`xr.map_blocks`. + By `Cindy Chiao `_. Documentation ~~~~~~~~~~~~~