From 7c59fb9ce50a9dcf301b9008ff5552fefce4f622 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 25 Feb 2021 19:25:09 +0000 Subject: [PATCH 1/4] Support dask.graph_manipulation --- xarray/core/dataarray.py | 8 ++-- xarray/core/dataset.py | 97 ++++++++++++++++++++------------------- xarray/core/variable.py | 17 ++----- xarray/tests/test_dask.py | 38 +++++++++++++++ 4 files changed, 96 insertions(+), 64 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 34354da61e2..b48791ef7f3 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -839,15 +839,15 @@ def __dask_scheduler__(self): def __dask_postcompute__(self): func, args = self._to_temp_dataset().__dask_postcompute__() - return self._dask_finalize, (func, args, self.name) + return self._dask_finalize, (self.name, func) + args def __dask_postpersist__(self): func, args = self._to_temp_dataset().__dask_postpersist__() - return self._dask_finalize, (func, args, self.name) + return self._dask_finalize, (self.name, func) + args @staticmethod - def _dask_finalize(results, func, args, name): - ds = func(results, *args) + def _dask_finalize(results, name, func, *args, **kwargs): + ds = func(results, *args, **kwargs) variable = ds._variables.pop(_THIS_ARRAY) coords = ds._variables return DataArray(variable, coords, name=name, fastpath=True) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 9faf74dd4bc..54b592dde65 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -863,15 +863,25 @@ def __dask_scheduler__(self): return da.Array.__dask_scheduler__ def __dask_postcompute__(self): + return self._dask_postcompute, () + + def __dask_postpersist__(self): + return self._dask_postpersist, () + + def _dask_postcompute(self, results: "Iterable[Variable]") -> "Dataset": import dask - info = [ - (k, None) + v.__dask_postcompute__() - if dask.is_dask_collection(v) - else (k, v, None, None) - for k, v in self._variables.items() - ] - construct_direct_args = ( + variables = {} + results_iter = iter(results) + + for k, v in self._variables.items(): + if dask.is_dask_collection(v): + rebuild, args = v.__dask_postcompute__() + v = rebuild(next(results_iter), *args) + variables[k] = v + + return Dataset._construct_direct( + variables, self._coord_names, self._dims, self._attrs, @@ -879,18 +889,40 @@ def __dask_postcompute__(self): self._encoding, self._close, ) - return self._dask_postcompute, (info, construct_direct_args) - def __dask_postpersist__(self): - import dask + def _dask_postpersist( + self, dsk: Mapping, *, rename: Mapping[str, str] = None + ) -> "Dataset": + from dask import is_dask_collection + from dask.highlevelgraph import HighLevelGraph + from dask.optimization import cull - info = [ - (k, None, v.__dask_keys__()) + v.__dask_postpersist__() - if dask.is_dask_collection(v) - else (k, v, None, None, None) - for k, v in self._variables.items() - ] - construct_direct_args = ( + variables = {} + + for k, v in self._variables.items(): + if is_dask_collection(v): + if isinstance(dsk, HighLevelGraph): + # dask >= 2021.3 + # __dask_postpersist__() was called by dask.highlevelgraph + # + # In case of multiple layers, don't pollute a Variable's + # HighLevelGraph with layers belonging exclusively to other + # Variables. However, we need to prevent partial layers: + # https://github.com/dask/dask/issues/7137 + dsk2 = dsk.cull_layers(v.__dask_layers__()) + else: + # __dask_postpersist__() was called by dask.optimize or dask.persist + dsk2, _ = cull(dsk, v.__dask_keys__()) + + rebuild, args = v.__dask_postpersist__() + # rename was added in dask 2021.3 + kwargs = {"rename": rename} if rename else {} + v = rebuild(dsk2, *args, **kwargs) + + variables[k] = v + + return Dataset._construct_direct( + variables, self._coord_names, self._dims, self._attrs, @@ -898,37 +930,6 @@ def __dask_postpersist__(self): self._encoding, self._close, ) - return self._dask_postpersist, (info, construct_direct_args) - - @staticmethod - def _dask_postcompute(results, info, construct_direct_args): - variables = {} - results_iter = iter(results) - for k, v, rebuild, rebuild_args in info: - if v is None: - variables[k] = rebuild(next(results_iter), *rebuild_args) - else: - variables[k] = v - - final = Dataset._construct_direct(variables, *construct_direct_args) - return final - - @staticmethod - def _dask_postpersist(dsk, info, construct_direct_args): - from dask.optimization import cull - - variables = {} - # postpersist is called in both dask.optimize and dask.persist - # When persisting, we want to filter out unrelated keys for - # each Variable's task graph. - for k, v, dask_keys, rebuild, rebuild_args in info: - if v is None: - dsk2, _ = cull(dsk, dask_keys) - variables[k] = rebuild(dsk2, *rebuild_args) - else: - variables[k] = v - - return Dataset._construct_direct(variables, *construct_direct_args) def compute(self, **kwargs) -> "Dataset": """Manually trigger loading and/or computation of this dataset's data diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 9b70f721689..c59cbf1f3e4 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -531,22 +531,15 @@ def __dask_scheduler__(self): def __dask_postcompute__(self): array_func, array_args = self._data.__dask_postcompute__() - return ( - self._dask_finalize, - (array_func, array_args, self._dims, self._attrs, self._encoding), - ) + return self._dask_finalize, (array_func,) + array_args def __dask_postpersist__(self): array_func, array_args = self._data.__dask_postpersist__() - return ( - self._dask_finalize, - (array_func, array_args, self._dims, self._attrs, self._encoding), - ) + return self._dask_finalize, (array_func,) + array_args - @staticmethod - def _dask_finalize(results, array_func, array_args, dims, attrs, encoding): - data = array_func(results, *array_args) - return Variable(dims, data, attrs=attrs, encoding=encoding) + def _dask_finalize(self, results, array_func, *args, **kwargs): + data = array_func(results, *args, **kwargs) + return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding) @property def values(self): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 8220c8b83dc..2dfb0c6c942 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1599,3 +1599,41 @@ def test_optimize(): arr = xr.DataArray(a).chunk(5) (arr2,) = dask.optimize(arr) arr2.compute() + + +# The graph_manipulation module is in dask since 2021.2 but it became usable with +# xarray only since 2021.3 +@pytest.mark.skipif(LooseVersion(dask.__version__) <= "2021.02.0", reason="new module") +def test_graph_manipulation(): + """dask.graph_manipulation passes an optional parameter, "rename", to the rebuilder + function returned by __dask_postperist__; also, the dsk passed to the rebuilder is + a HighLevelGraph whereas with dask.persist() and dask.optimize() it's a plain dict. + """ + import dask.graph_manipulation as gm + + # BasicLayer with 3 keys + Blockwise layer + v = Variable(["x"], [1, 2]).chunk(1) * 2 + da = DataArray(v) + ds = Dataset({"d1": v[0], "d2": v[1], "d3": ("x", [3, 4])}) + v2, da2, ds2 = gm.clone(ds, da, v) + + assert_equal(v2, v) + assert_equal(da2, da) + assert_equal(ds2, ds) + + assert v2.name != v.name + assert v2.__dask_keys__() != v.__dask_keys__() + assert v2.dask.layers.keys() != v.dask.layers.keys() + + assert da2.name != da.name + assert da2.__dask_keys__() != da.__dask_keys__() + assert da2.dask.layers.keys() != da.dask.layers.keys() + + assert ds2.__dask_keys__() != ds.__dask_keys__() + assert ds2.dask.layers.keys() != ds.dask.layers.keys() + + # Above we performed a slice operation; adding the two slices back together creates + # a diamond-shaped dependency graph, which in turn will trigger a collision in layer + # names if we were to use HighLevelGraph.cull() instead of + # HighLevelGraph.cull_layers() in Dataset.__dask_postpersist__(). + assert_equal(ds2.d1 + ds2.d2, ds.d1 + ds.d2) From bfe92f547d435b789ce6f045eeeac6442c50ad6e Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 26 Feb 2021 11:08:31 +0000 Subject: [PATCH 2/4] fix --- xarray/core/dataset.py | 46 ++++++++++++++++++++++++--------------- xarray/tests/test_dask.py | 23 +++++++++----------- 2 files changed, 38 insertions(+), 31 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 54b592dde65..d49cd64c6c7 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -900,26 +900,36 @@ def _dask_postpersist( variables = {} for k, v in self._variables.items(): - if is_dask_collection(v): - if isinstance(dsk, HighLevelGraph): - # dask >= 2021.3 - # __dask_postpersist__() was called by dask.highlevelgraph - # - # In case of multiple layers, don't pollute a Variable's - # HighLevelGraph with layers belonging exclusively to other - # Variables. However, we need to prevent partial layers: - # https://github.com/dask/dask/issues/7137 - dsk2 = dsk.cull_layers(v.__dask_layers__()) - else: - # __dask_postpersist__() was called by dask.optimize or dask.persist - dsk2, _ = cull(dsk, v.__dask_keys__()) + if not is_dask_collection(v): + variables[k] = v + continue - rebuild, args = v.__dask_postpersist__() - # rename was added in dask 2021.3 - kwargs = {"rename": rename} if rename else {} - v = rebuild(dsk2, *args, **kwargs) + if isinstance(dsk, HighLevelGraph): + # dask >= 2021.3 + # __dask_postpersist__() was called by dask.highlevelgraph. + # Don't use dsk.cull(), as we need to prevent partial layers: + # https://github.com/dask/dask/issues/7137 + layers = v.__dask_layers__() + if rename: + layers = [rename.get(k, k) for k in layers] + dsk2 = dsk.cull_layers(layers) + elif rename: # pragma: nocover + # At the moment of writing, this is only for forward compatibility. + # replace_name_in_key requires dask >= 2021.3. + from dask.base import flatten, replace_name_in_key + + keys = [ + replace_name_in_key(k, rename) for k in flatten(v.__dask_keys__()) + ] + dsk2, _ = cull(dsk, keys) + else: + # __dask_postpersist__() was called by dask.optimize or dask.persist + dsk2, _ = cull(dsk, v.__dask_keys__()) - variables[k] = v + rebuild, args = v.__dask_postpersist__() + # rename was added in dask 2021.3 + kwargs = {"rename": rename} if rename else {} + variables[k] = rebuild(dsk2, *args, **kwargs) return Dataset._construct_direct( variables, diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 2dfb0c6c942..908a959db45 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1611,26 +1611,23 @@ def test_graph_manipulation(): """ import dask.graph_manipulation as gm - # BasicLayer with 3 keys + Blockwise layer - v = Variable(["x"], [1, 2]).chunk(1) * 2 + v = Variable(["x"], [1, 2]).chunk(-1).chunk(1) * 2 da = DataArray(v) ds = Dataset({"d1": v[0], "d2": v[1], "d3": ("x", [3, 4])}) - v2, da2, ds2 = gm.clone(ds, da, v) + + v2, da2, ds2 = gm.clone(v, da, ds) assert_equal(v2, v) assert_equal(da2, da) assert_equal(ds2, ds) - assert v2.name != v.name - assert v2.__dask_keys__() != v.__dask_keys__() - assert v2.dask.layers.keys() != v.dask.layers.keys() - - assert da2.name != da.name - assert da2.__dask_keys__() != da.__dask_keys__() - assert da2.dask.layers.keys() != da.dask.layers.keys() - - assert ds2.__dask_keys__() != ds.__dask_keys__() - assert ds2.dask.layers.keys() != ds.dask.layers.keys() + for a, b in ((v, v2), (da, da2), (ds, ds2)): + assert a.__dask_layers__() != b.__dask_layers__() + assert len(a.__dask_layers__()) == len(b.__dask_layers__()) + assert a.__dask_graph__().keys() != b.__dask_graph__().keys() + assert len(a.__dask_graph__()) == len(b.__dask_graph__()) + assert a.__dask_graph__().layers.keys() != b.__dask_graph__().layers.keys() + assert len(a.__dask_graph__().layers) == len(b.__dask_graph__().layers) # Above we performed a slice operation; adding the two slices back together creates # a diamond-shaped dependency graph, which in turn will trigger a collision in layer From 16c63b8dbff9b4dcb7e866faba09b75f68c0702d Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 3 Mar 2021 11:32:25 +0000 Subject: [PATCH 3/4] What's New --- doc/whats-new.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index eed4e16eb62..e54abc680e5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,7 +22,9 @@ v0.17.1 (unreleased) New Features ~~~~~~~~~~~~ - +- Support for `dask.graph_manipulation + `_ (requires dask >=2021.3) + By `Guido Imperiale `_ Breaking changes ~~~~~~~~~~~~~~~~ From c88c949911f3c9f42740b6b5367940433ec0a8c7 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 3 Mar 2021 11:35:42 +0000 Subject: [PATCH 4/4] [test-upstream]