Skip to content

Commit 17b9936

Browse files
committed
Add tests.
1 parent ab2148a commit 17b9936

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

xarray/core/parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def _wrapper(func, args, kwargs, arg_is_array, expected_shapes):
350350
if isinstance(template, DataArray):
351351
output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore
352352
else:
353-
output_chunks = template.chunks
353+
output_chunks = dict(template.chunks)
354354

355355
if isinstance(template, DataArray):
356356
result_is_array = True

xarray/tests/test_dask.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,36 @@ def test_map_blocks_convert_args_to_list(obj):
10921092
assert_identical(actual, expected)
10931093

10941094

1095+
def test_map_blocks_dask_args():
1096+
da1 = xr.DataArray(
1097+
np.ones((10, 20)),
1098+
dims=["x", "y"],
1099+
coords={"x": np.arange(10), "y": np.arange(20)},
1100+
).chunk({"x": 5, "y": 4})
1101+
1102+
# check that block shapes are the same
1103+
def sumda(da1, da2):
1104+
assert da1.shape == da2.shape
1105+
return da1 + da2
1106+
1107+
da2 = da1 + 1
1108+
with raise_if_dask_computes():
1109+
mapped = xr.map_blocks(sumda, da1, args=[da2])
1110+
xr.testing.assert_equal(da1 + da2, mapped)
1111+
1112+
# one dimension in common
1113+
da2 = (da1 + 1).isel(x=1, drop=True)
1114+
with raise_if_dask_computes():
1115+
mapped = xr.map_blocks(operator.add, da1, args=[da2])
1116+
xr.testing.assert_equal(da1 + da2, mapped)
1117+
1118+
# test that everything works when dimension names are different
1119+
da2 = (da1 + 1).isel(x=1, drop=True).rename({"y": "k"})
1120+
with raise_if_dask_computes():
1121+
mapped = xr.map_blocks(operator.add, da1, args=[da2])
1122+
xr.testing.assert_equal(da1 + da2, mapped)
1123+
1124+
10951125
@pytest.mark.parametrize("obj", [make_da(), make_ds()])
10961126
def test_map_blocks_add_attrs(obj):
10971127
def add_attrs(obj):

0 commit comments

Comments
 (0)