Support for dask.graph_manipulation #4965

Merged · 6 commits · Mar 5, 2021
Changes from all commits
doc/whats-new.rst (3 additions, 1 deletion)
@@ -22,7 +22,9 @@ v0.17.1 (unreleased)

 New Features
 ~~~~~~~~~~~~
 
+- Support for `dask.graph_manipulation
+  <https://docs.dask.org/en/latest/graph_manipulation.html>`_ (requires dask >=2021.3)
+  By `Guido Imperiale <https://github.com/crusaderky>`_
 
 Breaking changes
 ~~~~~~~~~~~~~~~~
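As a usage note on this entry: the functions in `dask.graph_manipulation` now round-trip xarray objects. A minimal sketch, assuming dask >= 2021.3; the names below are illustrative, not part of this PR:

```python
import dask.graph_manipulation as gm
import numpy as np
import xarray as xr

arr = xr.DataArray(np.arange(6), dims="x").chunk(2)

# clone() rebuilds the collection with brand-new task keys, so the copy
# shares no graph nodes with the original.
arr2 = gm.clone(arr)
assert arr2.data.name != arr.data.name  # distinct dask names

# bind() rewrites arr's graph so it runs only after `done` has completed;
# checkpoint() collapses a collection into a single Delayed dependency.
done = gm.checkpoint(arr2)
blocked = gm.bind(arr, done)
xr.testing.assert_identical(blocked.compute(), arr.compute())
```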
xarray/core/dataarray.py (4 additions, 4 deletions)
@@ -839,15 +839,15 @@ def __dask_scheduler__(self):

     def __dask_postcompute__(self):
         func, args = self._to_temp_dataset().__dask_postcompute__()
-        return self._dask_finalize, (func, args, self.name)
+        return self._dask_finalize, (self.name, func) + args
 
     def __dask_postpersist__(self):
         func, args = self._to_temp_dataset().__dask_postpersist__()
-        return self._dask_finalize, (func, args, self.name)
+        return self._dask_finalize, (self.name, func) + args
 
     @staticmethod
-    def _dask_finalize(results, func, args, name):
-        ds = func(results, *args)
+    def _dask_finalize(results, name, func, *args, **kwargs):
+        ds = func(results, *args, **kwargs)
         variable = ds._variables.pop(_THIS_ARRAY)
         coords = ds._variables
         return DataArray(variable, coords, name=name, fastpath=True)
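For context on the reshuffle above: dask's collection protocol calls the finalizer as `finalize(results, *extra_args)`, so everything the finalizer needs must travel in that tuple. Flattening the tuple lets extra keyword arguments (such as `rename`, used later in this PR) pass through unchanged. A rough sketch of the contract, simplified from what `dask.compute()` does — not dask's actual implementation:

```python
def toy_compute(coll):
    # Simplified model of dask.compute() for a single collection: run the
    # graph, then hand the raw chunk results to the collection's finalizer.
    finalize, extra_args = coll.__dask_postcompute__()
    schedule = coll.__dask_scheduler__
    results = schedule(coll.__dask_graph__(), coll.__dask_keys__())
    # For a DataArray, extra_args is now (name, dataset_finalize, *dataset_args),
    # so _dask_finalize peels off `name` and forwards the rest verbatim.
    return finalize(results, *extra_args)
```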
xarray/core/dataset.py (59 additions, 48 deletions)
@@ -863,72 +863,83 @@ def __dask_scheduler__(self):
         return da.Array.__dask_scheduler__
 
     def __dask_postcompute__(self):
+        return self._dask_postcompute, ()
+
+    def __dask_postpersist__(self):
+        return self._dask_postpersist, ()
+
+    def _dask_postcompute(self, results: "Iterable[Variable]") -> "Dataset":
         import dask
 
-        info = [
-            (k, None) + v.__dask_postcompute__()
-            if dask.is_dask_collection(v)
-            else (k, v, None, None)
-            for k, v in self._variables.items()
-        ]
-        construct_direct_args = (
+        variables = {}
+        results_iter = iter(results)
+
+        for k, v in self._variables.items():
+            if dask.is_dask_collection(v):
+                rebuild, args = v.__dask_postcompute__()
+                v = rebuild(next(results_iter), *args)
+            variables[k] = v
+
+        return Dataset._construct_direct(
+            variables,
             self._coord_names,
             self._dims,
             self._attrs,
             self._indexes,
             self._encoding,
             self._close,
         )
-        return self._dask_postcompute, (info, construct_direct_args)
 
-    def __dask_postpersist__(self):
-        import dask
+    def _dask_postpersist(
+        self, dsk: Mapping, *, rename: Mapping[str, str] = None
+    ) -> "Dataset":
+        from dask import is_dask_collection
+        from dask.highlevelgraph import HighLevelGraph
+        from dask.optimization import cull
 
-        info = [
-            (k, None, v.__dask_keys__()) + v.__dask_postpersist__()
-            if dask.is_dask_collection(v)
-            else (k, v, None, None, None)
-            for k, v in self._variables.items()
-        ]
-        construct_direct_args = (
+        variables = {}
+
+        for k, v in self._variables.items():
+            if not is_dask_collection(v):
+                variables[k] = v
+                continue
+
+            if isinstance(dsk, HighLevelGraph):
+                # dask >= 2021.3
+                # __dask_postpersist__() was called by dask.highlevelgraph.
+                # Don't use dsk.cull(), as we need to prevent partial layers:
+                # https://github.com/dask/dask/issues/7137
+                layers = v.__dask_layers__()
+                if rename:
+                    layers = [rename.get(k, k) for k in layers]
+                dsk2 = dsk.cull_layers(layers)
+            elif rename:  # pragma: nocover
+                # At the moment of writing, this is only for forward compatibility.
+                # replace_name_in_key requires dask >= 2021.3.
+                from dask.base import flatten, replace_name_in_key
+
+                keys = [
+                    replace_name_in_key(k, rename) for k in flatten(v.__dask_keys__())
+                ]
+                dsk2, _ = cull(dsk, keys)
+            else:
+                # __dask_postpersist__() was called by dask.optimize or dask.persist
+                dsk2, _ = cull(dsk, v.__dask_keys__())
+
+            rebuild, args = v.__dask_postpersist__()
+            # rename was added in dask 2021.3
+            kwargs = {"rename": rename} if rename else {}
+            variables[k] = rebuild(dsk2, *args, **kwargs)
+
+        return Dataset._construct_direct(
+            variables,
             self._coord_names,
             self._dims,
             self._attrs,
             self._indexes,
             self._encoding,
             self._close,
         )
-        return self._dask_postpersist, (info, construct_direct_args)
 
-    @staticmethod
-    def _dask_postcompute(results, info, construct_direct_args):
-        variables = {}
-        results_iter = iter(results)
-        for k, v, rebuild, rebuild_args in info:
-            if v is None:
-                variables[k] = rebuild(next(results_iter), *rebuild_args)
-            else:
-                variables[k] = v
-
-        final = Dataset._construct_direct(variables, *construct_direct_args)
-        return final
-
-    @staticmethod
-    def _dask_postpersist(dsk, info, construct_direct_args):
-        from dask.optimization import cull
-
-        variables = {}
-        # postpersist is called in both dask.optimize and dask.persist
-        # When persisting, we want to filter out unrelated keys for
-        # each Variable's task graph.
-        for k, v, dask_keys, rebuild, rebuild_args in info:
-            if v is None:
-                dsk2, _ = cull(dsk, dask_keys)
-                variables[k] = rebuild(dsk2, *rebuild_args)
-            else:
-                variables[k] = v
-
-        return Dataset._construct_direct(variables, *construct_direct_args)
 
     def compute(self, **kwargs) -> "Dataset":
         """Manually trigger loading and/or computation of this dataset's data
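The per-variable cull has an observable effect: after `dask.persist(ds)`, every variable's rebuild receives the same merged graph, and culling keeps each Variable from carrying its siblings' tasks. A small demonstration, assuming dask >= 2021.3 (so the `cull_layers()` branch is taken):

```python
import dask
import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(4)), "b": ("x", np.arange(4) * 2)}).chunk(2)
(persisted,) = dask.persist(ds)

# Each variable's graph retains only its own (whole) layers, because
# _dask_postpersist culled the shared graph once per variable instead of
# handing the full merged graph to every rebuild.
for name in ["a", "b"]:
    print(name, sorted(persisted[name].data.__dask_graph__().layers))
```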
xarray/core/variable.py (5 additions, 12 deletions)
@@ -531,22 +531,15 @@ def __dask_scheduler__(self):

     def __dask_postcompute__(self):
         array_func, array_args = self._data.__dask_postcompute__()
-        return (
-            self._dask_finalize,
-            (array_func, array_args, self._dims, self._attrs, self._encoding),
-        )
+        return self._dask_finalize, (array_func,) + array_args
 
     def __dask_postpersist__(self):
         array_func, array_args = self._data.__dask_postpersist__()
-        return (
-            self._dask_finalize,
-            (array_func, array_args, self._dims, self._attrs, self._encoding),
-        )
+        return self._dask_finalize, (array_func,) + array_args
 
-    @staticmethod
-    def _dask_finalize(results, array_func, array_args, dims, attrs, encoding):
-        data = array_func(results, *array_args)
-        return Variable(dims, data, attrs=attrs, encoding=encoding)
+    def _dask_finalize(self, results, array_func, *args, **kwargs):
+        data = array_func(results, *args, **kwargs)
+        return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding)
 
     @property
     def values(self):
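A quick round-trip showing what the slimmer rebuild tuple now carries: dims, attrs, and encoding travel on the bound method's `self` rather than inside the args tuple. A sketch only; the hard-coded chunk list stands in for what a scheduler would return:

```python
import numpy as np
import xarray as xr

v = xr.Variable(("x",), np.arange(4)).chunk(2)

# args is (dask_array_finalizer,) + that finalizer's own args, nothing more.
finalize, args = v.__dask_postcompute__()
chunks = [np.arange(0, 2), np.arange(2, 4)]  # what the scheduler would hand back
v2 = finalize(chunks, *args)

assert isinstance(v2, xr.Variable)
assert (v2.values == v.values).all()
```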
xarray/tests/test_dask.py (35 additions)
@@ -1599,3 +1599,38 @@ def test_optimize():
     arr = xr.DataArray(a).chunk(5)
     (arr2,) = dask.optimize(arr)
     arr2.compute()
+
+
+# The graph_manipulation module is in dask since 2021.2 but it became usable with
+# xarray only since 2021.3
+@pytest.mark.skipif(LooseVersion(dask.__version__) <= "2021.02.0", reason="new module")
+def test_graph_manipulation():
+    """dask.graph_manipulation passes an optional parameter, "rename", to the rebuilder
+    function returned by __dask_postpersist__; also, the dsk passed to the rebuilder is
+    a HighLevelGraph whereas with dask.persist() and dask.optimize() it's a plain dict.
+    """
+    import dask.graph_manipulation as gm
+
+    v = Variable(["x"], [1, 2]).chunk(-1).chunk(1) * 2
+    da = DataArray(v)
+    ds = Dataset({"d1": v[0], "d2": v[1], "d3": ("x", [3, 4])})
+
+    v2, da2, ds2 = gm.clone(v, da, ds)
+
+    assert_equal(v2, v)
+    assert_equal(da2, da)
+    assert_equal(ds2, ds)
+
+    for a, b in ((v, v2), (da, da2), (ds, ds2)):
+        assert a.__dask_layers__() != b.__dask_layers__()
+        assert len(a.__dask_layers__()) == len(b.__dask_layers__())
+        assert a.__dask_graph__().keys() != b.__dask_graph__().keys()
+        assert len(a.__dask_graph__()) == len(b.__dask_graph__())
+        assert a.__dask_graph__().layers.keys() != b.__dask_graph__().layers.keys()
+        assert len(a.__dask_graph__().layers) == len(b.__dask_graph__().layers)
+
+    # Above we performed a slice operation; adding the two slices back together creates
+    # a diamond-shaped dependency graph, which in turn will trigger a collision in layer
+    # names if we were to use HighLevelGraph.cull() instead of
+    # HighLevelGraph.cull_layers() in Dataset.__dask_postpersist__().
+    assert_equal(ds2.d1 + ds2.d2, ds.d1 + ds.d2)
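To make the closing comment concrete: `HighLevelGraph.cull()` trims at key granularity and can emit a partial layer under its original name, which is what collides on the diamond (dask/dask#7137), whereas `cull_layers()` keeps whole layers plus their dependencies. A toy illustration on a plain dask array, assuming dask >= 2021.3:

```python
import dask.array as dask_array

x = dask_array.ones(4, chunks=2)
y = x + 1

hlg = y.__dask_graph__()  # a HighLevelGraph with two layers
culled = hlg.cull_layers(y.__dask_layers__())

# cull_layers() keeps the requested layers and their dependency layers whole;
# it never replaces a layer with a partial copy under the same name, which is
# what key-level cull() can produce.
assert set(culled.layers) == set(hlg.layers)
```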