Skip to content

Commit b76a13f

Browse files
authored
Fix: make copy of dask_gufunc_kwargs before changing content (#4576)
* Fix: make copy of dask_gufunc_kwargs before changing content (in apply_ufunc), add test * DOC: add entry to whats-new.rst * DOC: fix typo in whats-new.rst [skip-ci]
1 parent 76036bd commit b76a13f

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

doc/whats-new.rst

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ Bug fixes
7575
``DataArray`` objects, previously only the global attributes were retained (:issue:`4497`, :pull:`4510`).
7676
By `Mathias Hauser <https://github.com/mathause>`_.
7777
- Improve performance where reading small slices from huge dimensions was slower than necessary (:pull:`4560`). By `Dion Häfner <https://github.com/dionhaefner>`_.
78+
- Fix bug where ``dask_gufunc_kwargs`` was silently changed in :py:func:`apply_ufunc` (:pull:`4576`). By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
7879

7980
Documentation
8081
~~~~~~~~~~~~~

xarray/core/computation.py

+2
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,8 @@ def apply_variable_ufunc(
646646

647647
if dask_gufunc_kwargs is None:
648648
dask_gufunc_kwargs = {}
649+
else:
650+
dask_gufunc_kwargs = dask_gufunc_kwargs.copy()
649651

650652
allow_rechunk = dask_gufunc_kwargs.get("allow_rechunk", None)
651653
if allow_rechunk is None:

xarray/tests/test_computation.py

+26
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,32 @@ def func(x):
796796
assert_identical(expected, actual)
797797

798798

799+
@requires_dask
def test_apply_dask_new_output_sizes():
    """Check ``apply_ufunc`` with ``output_sizes`` on a multi-variable Dataset.

    A 10x10 Dataset with two data variables is cropped to a 3x6 window
    through ``apply_ufunc`` with ``dask="parallelized"``. The sizes of the
    new core dimensions are passed via ``dask_gufunc_kwargs``. The chunked
    result must report the new dimension sizes and be identical to the
    (chunked) result computed from the unchunked Dataset.

    NOTE(review): the second variable ("bar") is presumably there so that
    the same ``dask_gufunc_kwargs`` is consumed more than once within a
    single call (see :pull:`4576`) -- confirm against ``apply_ufunc``.
    """
    grid = np.arange(10 * 10).reshape((10, 10))
    ds = xr.Dataset({"foo": (["lon", "lat"], grid)})
    ds["bar"] = ds["foo"]
    new_sizes = {"lon_new": 3, "lat_new": 6}

    def crop(block):
        # Rows 1..3 and columns 1..6 -> a 3x6 window.
        return block[1:4, 1:7]

    def extract(obj):
        # A fresh kwargs dict is built on every call.
        gufunc_kwargs = {"output_sizes": new_sizes}
        return apply_ufunc(
            crop,
            obj,
            dask="parallelized",
            input_core_dims=[["lon", "lat"]],
            output_core_dims=[["lon_new", "lat_new"]],
            dask_gufunc_kwargs=gufunc_kwargs,
        )

    # Reference result from the in-memory (unchunked) Dataset.
    expected = extract(ds)

    # Same operation on the dask-backed Dataset.
    actual = extract(ds.chunk())
    assert actual.dims == {"lon_new": 3, "lat_new": 6}
    assert_identical(expected.chunk(), actual)
823+
824+
799825
def pandas_median(x):
    """Return the median of ``x``, computed via ``pandas.Series.median``."""
    series = pd.Series(x)
    return series.median()
801827

0 commit comments

Comments
 (0)