From a825faaf60dc75e0365f18a0f24acb0fe288b263 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Tue, 9 Feb 2021 16:03:14 +0000
Subject: [PATCH 1/2] Compatibility with dask 2021.02.0

---
 ci/requirements/environment-windows.yml |  2 +-
 ci/requirements/environment.yml         |  2 +-
 xarray/core/dataset.py                  | 32 +++++++++++++++++++------
 3 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml
index 9455ef2f127..6de2bc8dc64 100644
--- a/ci/requirements/environment-windows.yml
+++ b/ci/requirements/environment-windows.yml
@@ -8,7 +8,7 @@ dependencies:
   # - cdms2  # Not available on Windows
   # - cfgrib  # Causes Python interpreter crash on Windows: https://github.com/pydata/xarray/pull/3340
   - cftime
-  - dask<2021.02.0
+  - dask
   - distributed
   - h5netcdf
   - h5py=2
diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml
index 7261b5b6954..0f59d9570c8 100644
--- a/ci/requirements/environment.yml
+++ b/ci/requirements/environment.yml
@@ -9,7 +9,7 @@ dependencies:
   - cdms2
   - cfgrib
   - cftime
-  - dask<2021.02.0
+  - dask
   - distributed
   - h5netcdf
   - h5py=2
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 7d51adb5244..64c3899f597 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -920,20 +920,38 @@ def _dask_postcompute(results, info, *args):
 
     @staticmethod
     def _dask_postpersist(dsk, info, *args):
+        from dask.core import flatten
+        from dask.highlevelgraph import HighLevelGraph
+        from dask.optimization import cull
+
         variables = {}
         # postpersist is called in both dask.optimize and dask.persist
         # When persisting, we want to filter out unrelated keys for
         # each Variable's task graph.
-        is_persist = len(dsk) == len(info)
         for is_dask, k, v in info:
             if is_dask:
-                func, args2 = v
-                if is_persist:
-                    name = args2[1][0]
-                    dsk2 = {k: v for k, v in dsk.items() if k[0] == name}
+                rebuild, rebuild_args = v
+                tmp = rebuild(dsk, *rebuild_args)
+                keys = set(flatten(tmp.__dask_keys__()))
+                if isinstance(dsk, HighLevelGraph):
+                    # __dask_postpersist__() was invoked by various functions in the
+                    # dask.graph_manipulation module.
+                    #
+                    # In case of multiple layers, don't pollute a Variable's
+                    # HighLevelGraph with layers belonging exclusively to other
+                    # Variables. However, we need to prevent partial layers:
+                    # https://github.com/dask/dask/issues/7137
+                    # TODO We're wasting a lot of key-level work. We should write a fast
+                    #      variant of HighLevelGraph.cull() that works at layer level
+                    #      only.
+                    dsk2 = dsk.cull(keys)
+                    dsk3 = HighLevelGraph(
+                        {k: dsk.layers[k] for k in dsk2.layers}, dsk2.dependencies
+                    )
                 else:
-                    dsk2 = dsk
-                result = func(dsk2, *args2)
+                    # __dask_postpersist__() was invoked by dask.persist()
+                    dsk3, _ = cull(dsk, keys)
+                result = rebuild(dsk3, *rebuild_args)
             else:
                 result = v
             variables[k] = result

From 628263d3034ad190e4f52429ae55fb41564bdea8 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Thu, 11 Feb 2021 14:46:17 +0000
Subject: [PATCH 2/2] Rework postpersist and postcompute

---
 xarray/core/dataset.py | 72 +++++++++++++-----------------------------
 1 file changed, 22 insertions(+), 50 deletions(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 64c3899f597..066a2f690b0 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -866,13 +866,12 @@ def __dask_postcompute__(self):
         import dask
 
         info = [
-            (True, k, v.__dask_postcompute__())
+            (k, None) + v.__dask_postcompute__()
             if dask.is_dask_collection(v)
-            else (False, k, v)
+            else (k, v, None, None)
             for k, v in self._variables.items()
         ]
-        args = (
-            info,
+        construct_direct_args = (
             self._coord_names,
             self._dims,
             self._attrs,
@@ -880,19 +879,18 @@ def __dask_postcompute__(self):
             self._encoding,
             self._close,
         )
-        return self._dask_postcompute, args
+        return self._dask_postcompute, (info, construct_direct_args)
 
     def __dask_postpersist__(self):
         import dask
 
         info = [
-            (True, k, v.__dask_postpersist__())
+            (k, None, v.__dask_keys__()) + v.__dask_postpersist__()
             if dask.is_dask_collection(v)
-            else (False, k, v)
+            else (k, v, None, None, None)
             for k, v in self._variables.items()
         ]
-        args = (
-            info,
+        construct_direct_args = (
             self._coord_names,
             self._dims,
             self._attrs,
@@ -900,63 +898,37 @@ def __dask_postpersist__(self):
             self._encoding,
             self._close,
         )
-        return self._dask_postpersist, args
+        return self._dask_postpersist, (info, construct_direct_args)
 
     @staticmethod
-    def _dask_postcompute(results, info, *args):
+    def _dask_postcompute(results, info, construct_direct_args):
         variables = {}
-        results2 = list(results[::-1])
-        for is_dask, k, v in info:
-            if is_dask:
-                func, args2 = v
-                r = results2.pop()
-                result = func(r, *args2)
+        results_iter = iter(results)
+        for k, v, rebuild, rebuild_args in info:
+            if v is None:
+                variables[k] = rebuild(next(results_iter), *rebuild_args)
             else:
-                result = v
-            variables[k] = result
+                variables[k] = v
 
-        final = Dataset._construct_direct(variables, *args)
+        final = Dataset._construct_direct(variables, *construct_direct_args)
         return final
 
     @staticmethod
-    def _dask_postpersist(dsk, info, *args):
-        from dask.core import flatten
-        from dask.highlevelgraph import HighLevelGraph
+    def _dask_postpersist(dsk, info, construct_direct_args):
         from dask.optimization import cull
 
         variables = {}
         # postpersist is called in both dask.optimize and dask.persist
         # When persisting, we want to filter out unrelated keys for
         # each Variable's task graph.
-        for is_dask, k, v in info:
-            if is_dask:
-                rebuild, rebuild_args = v
-                tmp = rebuild(dsk, *rebuild_args)
-                keys = set(flatten(tmp.__dask_keys__()))
-                if isinstance(dsk, HighLevelGraph):
-                    # __dask_postpersist__() was invoked by various functions in the
-                    # dask.graph_manipulation module.
-                    #
-                    # In case of multiple layers, don't pollute a Variable's
-                    # HighLevelGraph with layers belonging exclusively to other
-                    # Variables. However, we need to prevent partial layers:
-                    # https://github.com/dask/dask/issues/7137
-                    # TODO We're wasting a lot of key-level work. We should write a fast
-                    #      variant of HighLevelGraph.cull() that works at layer level
-                    #      only.
-                    dsk2 = dsk.cull(keys)
-                    dsk3 = HighLevelGraph(
-                        {k: dsk.layers[k] for k in dsk2.layers}, dsk2.dependencies
-                    )
-                else:
-                    # __dask_postpersist__() was invoked by dask.persist()
-                    dsk3, _ = cull(dsk, keys)
-                result = rebuild(dsk3, *rebuild_args)
+        for k, v, dask_keys, rebuild, rebuild_args in info:
+            if v is None:
+                dsk2, _ = cull(dsk, dask_keys)
+                variables[k] = rebuild(dsk2, *rebuild_args)
             else:
-                result = v
-            variables[k] = result
+                variables[k] = v
 
-        return Dataset._construct_direct(variables, *args)
+        return Dataset._construct_direct(variables, *construct_direct_args)
 
     def compute(self, **kwargs) -> "Dataset":
         """Manually trigger loading and/or computation of this dataset's data