diff --git a/doc/whats-new.rst b/doc/whats-new.rst index dd3120f73fc..9460fc08478 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -75,6 +75,7 @@ Bug fixes ``DataArray`` objects, previously only the global attributes were retained (:issue:`4497`, :pull:`4510`). By `Mathias Hauser `_. - Improve performance where reading small slices from huge dimensions was slower than necessary (:pull:`4560`). By `Dion Häfner `_. +- Fix bug where ``dask_gufunc_kwargs`` was silently changed in :py:func:`apply_ufunc` (:pull:`4576`). By `Kai Mühlbauer `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 7b62c2c705f..9251edf1cb8 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -646,6 +646,8 @@ def apply_variable_ufunc( if dask_gufunc_kwargs is None: dask_gufunc_kwargs = {} + else: + dask_gufunc_kwargs = dask_gufunc_kwargs.copy() allow_rechunk = dask_gufunc_kwargs.get("allow_rechunk", None) if allow_rechunk is None: diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 63bedfaf280..1922977fdeb 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -796,6 +796,32 @@ def func(x): assert_identical(expected, actual) +@requires_dask +def test_apply_dask_new_output_sizes(): + ds = xr.Dataset({"foo": (["lon", "lat"], np.arange(10 * 10).reshape((10, 10)))}) + ds["bar"] = ds["foo"] + newdims = {"lon_new": 3, "lat_new": 6} + + def extract(obj): + def func(da): + return da[1:4, 1:7] + + return apply_ufunc( + func, + obj, + dask="parallelized", + input_core_dims=[["lon", "lat"]], + output_core_dims=[["lon_new", "lat_new"]], + dask_gufunc_kwargs=dict(output_sizes=newdims), + ) + + expected = extract(ds) + + actual = extract(ds.chunk()) + assert actual.dims == {"lon_new": 3, "lat_new": 6} + assert_identical(expected.chunk(), actual) + + def pandas_median(x): return pd.Series(x).median()