From a825faaf60dc75e0365f18a0f24acb0fe288b263 Mon Sep 17 00:00:00 2001 From: crusaderky <crusaderky@gmail.com> Date: Tue, 9 Feb 2021 16:03:14 +0000 Subject: [PATCH 1/2] Compatibility with dask 2021.02.0 --- ci/requirements/environment-windows.yml | 2 +- ci/requirements/environment.yml | 2 +- xarray/core/dataset.py | 32 +++++++++++++++++++------ 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 9455ef2f127..6de2bc8dc64 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -8,7 +8,7 @@ dependencies: # - cdms2 # Not available on Windows # - cfgrib # Causes Python interpreter crash on Windows: https://github.com/pydata/xarray/pull/3340 - cftime - - dask<2021.02.0 + - dask - distributed - h5netcdf - h5py=2 diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 7261b5b6954..0f59d9570c8 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -9,7 +9,7 @@ dependencies: - cdms2 - cfgrib - cftime - - dask<2021.02.0 + - dask - distributed - h5netcdf - h5py=2 diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7d51adb5244..64c3899f597 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -920,20 +920,38 @@ def _dask_postcompute(results, info, *args): @staticmethod def _dask_postpersist(dsk, info, *args): + from dask.core import flatten + from dask.highlevelgraph import HighLevelGraph + from dask.optimization import cull + variables = {} # postpersist is called in both dask.optimize and dask.persist # When persisting, we want to filter out unrelated keys for # each Variable's task graph. - is_persist = len(dsk) == len(info) for is_dask, k, v in info: if is_dask: - func, args2 = v - if is_persist: - name = args2[1][0] - dsk2 = {k: v for k, v in dsk.items() if k[0] == name} + rebuild, rebuild_args = v + tmp = rebuild(dsk, *rebuild_args) + keys = set(flatten(tmp.__dask_keys__())) + if isinstance(dsk, HighLevelGraph): + # __dask_postpersist__() was invoked by various functions in the + # dask.graph_manipulation module. + # + # In case of multiple layers, don't pollute a Variable's + # HighLevelGraph with layers belonging exclusively to other + # Variables. However, we need to prevent partial layers: + # https://github.com/dask/dask/issues/7137 + # TODO We're wasting a lot of key-level work. We should write a fast + # variant of HighLevelGraph.cull() that works at layer level + # only. + dsk2 = dsk.cull(keys) + dsk3 = HighLevelGraph( + {k: dsk.layers[k] for k in dsk2.layers}, dsk2.dependencies + ) else: - dsk2 = dsk - result = func(dsk2, *args2) + # __dask_postpersist__() was invoked by dask.persist() + dsk3, _ = cull(dsk, keys) + result = rebuild(dsk3, *rebuild_args) else: result = v variables[k] = result From 628263d3034ad190e4f52429ae55fb41564bdea8 Mon Sep 17 00:00:00 2001 From: crusaderky <crusaderky@gmail.com> Date: Thu, 11 Feb 2021 14:46:17 +0000 Subject: [PATCH 2/2] Rework postpersist and postcompute --- xarray/core/dataset.py | 72 +++++++++++++----------------------------- 1 file changed, 22 insertions(+), 50 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 64c3899f597..066a2f690b0 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -866,13 +866,12 @@ def __dask_postcompute__(self): import dask info = [ - (True, k, v.__dask_postcompute__()) + (k, None) + v.__dask_postcompute__() if dask.is_dask_collection(v) - else (False, k, v) + else (k, v, None, None) for k, v in self._variables.items() ] - args = ( - info, + construct_direct_args = ( self._coord_names, self._dims, self._attrs, @@ -880,19 +879,18 @@ def __dask_postcompute__(self): self._encoding, self._close, ) - return self._dask_postcompute, args + return self._dask_postcompute, (info, construct_direct_args) def __dask_postpersist__(self): import dask info = [ - (True, k, v.__dask_postpersist__()) + (k, None, v.__dask_keys__()) + v.__dask_postpersist__() if dask.is_dask_collection(v) - else (False, k, v) + else (k, v, None, None, None) for k, v in self._variables.items() ] - args = ( - info, + construct_direct_args = ( self._coord_names, self._dims, self._attrs, @@ -900,63 +898,37 @@ def __dask_postpersist__(self): self._encoding, self._close, ) - return self._dask_postpersist, args + return self._dask_postpersist, (info, construct_direct_args) @staticmethod - def _dask_postcompute(results, info, *args): + def _dask_postcompute(results, info, construct_direct_args): variables = {} - results2 = list(results[::-1]) - for is_dask, k, v in info: - if is_dask: - func, args2 = v - r = results2.pop() - result = func(r, *args2) + results_iter = iter(results) + for k, v, rebuild, rebuild_args in info: + if v is None: + variables[k] = rebuild(next(results_iter), *rebuild_args) else: - result = v - variables[k] = result + variables[k] = v - final = Dataset._construct_direct(variables, *args) + final = Dataset._construct_direct(variables, *construct_direct_args) return final @staticmethod - def _dask_postpersist(dsk, info, *args): - from dask.core import flatten - from dask.highlevelgraph import HighLevelGraph + def _dask_postpersist(dsk, info, construct_direct_args): from dask.optimization import cull variables = {} # postpersist is called in both dask.optimize and dask.persist # When persisting, we want to filter out unrelated keys for # each Variable's task graph. - for is_dask, k, v in info: - if is_dask: - rebuild, rebuild_args = v - tmp = rebuild(dsk, *rebuild_args) - keys = set(flatten(tmp.__dask_keys__())) - if isinstance(dsk, HighLevelGraph): - # __dask_postpersist__() was invoked by various functions in the - # dask.graph_manipulation module. - # - # In case of multiple layers, don't pollute a Variable's - # HighLevelGraph with layers belonging exclusively to other - # Variables. However, we need to prevent partial layers: - # https://github.com/dask/dask/issues/7137 - # TODO We're wasting a lot of key-level work. We should write a fast - # variant of HighLevelGraph.cull() that works at layer level - # only. - dsk2 = dsk.cull(keys) - dsk3 = HighLevelGraph( - {k: dsk.layers[k] for k in dsk2.layers}, dsk2.dependencies - ) - else: - # __dask_postpersist__() was invoked by dask.persist() - dsk3, _ = cull(dsk, keys) - result = rebuild(dsk3, *rebuild_args) + for k, v, dask_keys, rebuild, rebuild_args in info: + if v is None: + dsk2, _ = cull(dsk, dask_keys) + variables[k] = rebuild(dsk2, *rebuild_args) else: - result = v - variables[k] = result + variables[k] = v - return Dataset._construct_direct(variables, *args) + return Dataset._construct_direct(variables, *construct_direct_args) def compute(self, **kwargs) -> "Dataset": """Manually trigger loading and/or computation of this dataset's data