
Support Dask interface #1674


Merged — 15 commits, merged on Nov 7, 2017
7 changes: 7 additions & 0 deletions doc/whats-new.rst
@@ -15,9 +15,16 @@ What's New

.. _whats-new.0.10.0:


v0.10.0 (unreleased)
--------------------

Changes since v0.10.0 rc1 (Unreleased)
--------------------------------------

- Experimental support for the Dask collection interface (:issue:`1674`).
By `Matthew Rocklin <https://github.com/mrocklin>`_.

Bug fixes
~~~~~~~~~

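A minimal sketch of what this entry enables, assuming dask >= 0.16 (the version the tests below require for the new interface); the variable name 'foo' is illustrative:

    import dask
    import dask.array as da
    import xarray as xr

    ds = xr.Dataset({'foo': ('x', da.arange(10, chunks=5))}) + 1

    # xarray objects now plug into dask's generic collection functions
    assert dask.is_dask_collection(ds)
    (computed,) = dask.compute(ds)    # run the graph, get an in-memory Dataset
    (persisted,) = dask.persist(ds)   # materialize chunks but stay dask-backed
    assert not dask.is_dask_collection(computed)
    assert dask.is_dask_collection(persisted)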
29 changes: 29 additions & 0 deletions xarray/core/dataarray.py
@@ -576,6 +576,35 @@ def reset_coords(self, names=None, drop=False, inplace=False):
dataset[self.name] = self.variable
return dataset

def __dask_graph__(self):
return self._to_temp_dataset().__dask_graph__()

def __dask_keys__(self):
return self._to_temp_dataset().__dask_keys__()

@property
def __dask_optimize__(self):
return self._to_temp_dataset().__dask_optimize__

@property
def __dask_scheduler__(self):
return self._to_temp_dataset().__dask_scheduler__

def __dask_postcompute__(self):
func, args = self._to_temp_dataset().__dask_postcompute__()
return self._dask_finalize, (func, args, self.name)

def __dask_postpersist__(self):
func, args = self._to_temp_dataset().__dask_postpersist__()
return self._dask_finalize, (func, args, self.name)

@staticmethod
def _dask_finalize(results, func, args, name):
ds = func(results, *args)
variable = ds._variables.pop(_THIS_ARRAY)
coords = ds._variables
return DataArray(variable, coords, name=name, fastpath=True)

def load(self, **kwargs):
"""Manually trigger loading of this array's data from disk or a
remote source into memory and return this array.
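The DataArray hooks above all defer to a temporary single-variable Dataset, and _dask_finalize pops the variable back out and restores the name. A sketch of that round trip (assumes dask >= 0.16; the name 'a' is illustrative):

    import dask
    import dask.array as da
    import xarray as xr

    arr = xr.DataArray(da.arange(6, chunks=3), dims='x', name='a') + 1

    (computed,) = dask.compute(arr)
    assert computed.name == 'a'   # the name survives the temp-Dataset round trip
    assert not dask.is_dask_collection(computed)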
71 changes: 71 additions & 0 deletions xarray/core/dataset.py
@@ -493,6 +493,77 @@ def load(self, **kwargs):

return self

def __dask_graph__(self):
graphs = {k: v.__dask_graph__() for k, v in self.variables.items()}
graphs = {k: v for k, v in graphs.items() if v is not None}
if not graphs:
return None
else:
from dask import sharedict
return sharedict.merge(*graphs.values())

def __dask_keys__(self):
import dask
return [v.__dask_keys__() for v in self.variables.values()
if dask.is_dask_collection(v)]

@property
def __dask_optimize__(self):
import dask.array as da
return da.Array.__dask_optimize__

@property
def __dask_scheduler__(self):
import dask.array as da
return da.Array.__dask_scheduler__

def __dask_postcompute__(self):
import dask
info = [(True, k, v.__dask_postcompute__())
if dask.is_dask_collection(v) else
(False, k, v) for k, v in self._variables.items()]
return self._dask_postcompute, (info, self._coord_names, self._dims,
self._attrs, self._file_obj,
self._encoding)

def __dask_postpersist__(self):
import dask
info = [(True, k, v.__dask_postpersist__())
if dask.is_dask_collection(v) else
(False, k, v) for k, v in self._variables.items()]
return self._dask_postpersist, (info, self._coord_names, self._dims,
self._attrs, self._file_obj,
self._encoding)

@staticmethod
def _dask_postcompute(results, info, *args):
variables = OrderedDict()
results2 = list(results[::-1])
for is_dask, k, v in info:
if is_dask:
func, args2 = v
r = results2.pop()
result = func(r, *args2)
else:
result = v
variables[k] = result

final = Dataset._construct_direct(variables, *args)
return final

@staticmethod
def _dask_postpersist(dsk, info, *args):
variables = OrderedDict()
for is_dask, k, v in info:
if is_dask:
func, args2 = v
result = func(dsk, *args2)
else:
result = v
variables[k] = result

return Dataset._construct_direct(variables, *args)

def compute(self, **kwargs):
"""Manually trigger loading of this dataset's data from disk or a
remote source into memory and return a new dataset. The original is
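For orientation, a rough sketch of how a scheduler consumes these hooks (the assumed flow of the dask collection protocol, not code from this diff):

    import dask
    import xarray as xr

    ds = xr.Dataset({'foo': ('x', list(range(5)))}).chunk({'x': 2}) + 1

    graph = ds.__dask_graph__()          # merged graph over all dask variables
    keys = ds.__dask_keys__()            # one nested key list per dask variable
    finalize, extra = ds.__dask_postcompute__()

    results = ds.__dask_scheduler__(graph, keys)  # threaded get, defined above
    final = finalize(results, *extra)             # stitched-together Dataset
    assert not dask.is_dask_collection(final)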
35 changes: 35 additions & 0 deletions xarray/core/variable.py
@@ -355,6 +355,41 @@ def compute(self, **kwargs):
new = self.copy(deep=False)
return new.load(**kwargs)

def __dask_graph__(self):
if isinstance(self._data, dask_array_type):
return self._data.__dask_graph__()
else:
return None

def __dask_keys__(self):
return self._data.__dask_keys__()
Review comment (Member): Is it OK if these methods error (with AttributeError) when self._data is not a dask array?

Review comment (Contributor Author): Yes — we always check whether the object is a dask collection first, by calling __dask_graph__.
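A short illustration of that contract: dask treats a None graph as "not a collection", so the remaining hooks are only reached for dask-backed data (a sketch, assuming dask >= 0.16):

    import dask
    import dask.array as da
    import numpy as np
    import xarray as xr

    lazy = xr.Variable('x', da.arange(4, chunks=2))
    eager = xr.Variable('x', np.arange(4))

    assert dask.is_dask_collection(lazy)       # __dask_graph__ returns a graph
    assert not dask.is_dask_collection(eager)  # __dask_graph__ returns None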


@property
def __dask_optimize__(self):
return self._data.__dask_optimize__

@property
def __dask_scheduler__(self):
return self._data.__dask_scheduler__

def __dask_postcompute__(self):
array_func, array_args = self._data.__dask_postcompute__()
return self._dask_finalize, (array_func, array_args, self._dims,
self._attrs, self._encoding)

def __dask_postpersist__(self):
array_func, array_args = self._data.__dask_postpersist__()
return self._dask_finalize, (array_func, array_args, self._dims,
self._attrs, self._encoding)

@staticmethod
def _dask_finalize(results, array_func, array_args, dims, attrs, encoding):
if isinstance(results, dict): # persist case
name = array_args[0]
results = {k: v for k, v in results.items() if k[0] == name} # cull
data = array_func(results, *array_args)
return Variable(dims, data, attrs=attrs, encoding=encoding)

@property
def values(self):
"""The variable's data as a numpy.ndarray"""
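Note the persist branch in _dask_finalize above: after dask.persist, every variable's finalizer receives the same dict of computed chunks and culls it down to its own keys. A sketch of the observable effect (assumes dask >= 0.16):

    import dask
    import dask.array as da
    import xarray as xr

    v = xr.Variable('x', da.arange(6, chunks=3)) + 1
    (v2,) = dask.persist(v)

    assert dask.is_dask_collection(v2)    # still lazy on the outside
    assert len(v2.__dask_graph__()) == 2  # only one key per chunk remains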
137 changes: 115 additions & 22 deletions xarray/tests/test_dask.py
@@ -206,6 +206,34 @@ def test_bivariate_ufunc(self):
self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0))
self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v))

@pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
reason='Need dask 0.16 for new interface')
def test_compute(self):
u = self.eager_var
v = self.lazy_var

assert dask.is_dask_collection(v)
(v2,) = dask.compute(v + 1)
assert not dask.is_dask_collection(v2)

assert ((u + 1).data == v2.data).all()

@pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
reason='Need dask 0.16 for new interface')
def test_persist(self):
u = self.eager_var
v = self.lazy_var + 1

(v2,) = dask.persist(v)
assert v is not v2
assert len(v2.__dask_graph__()) < len(v.__dask_graph__())
assert v2.__dask_keys__() == v.__dask_keys__()
assert dask.is_dask_collection(v)
assert dask.is_dask_collection(v2)

self.assertLazyAndAllClose(u + 1, v)
self.assertLazyAndAllClose(u + 1, v2)


class TestDataArrayAndDataset(DaskTestCase):
def assertLazyAndIdentical(self, expected, actual):
@@ -251,6 +279,34 @@ def test_lazy_array(self):
actual = xr.concat([v[:2], v[2:]], 'x')
self.assertLazyAndAllClose(u, actual)

@pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
reason='Need dask 0.16 for new interface')
def test_compute(self):
u = self.eager_array
v = self.lazy_array

assert dask.is_dask_collection(v)
(v2,) = dask.compute(v + 1)
assert not dask.is_dask_collection(v2)

assert ((u + 1).data == v2.data).all()

@pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
reason='Need dask 0.16 for new interface')
def test_persist(self):
u = self.eager_array
v = self.lazy_array + 1

(v2,) = dask.persist(v)
assert v is not v2
assert len(v2.__dask_graph__()) < len(v.__dask_graph__())
assert v2.__dask_keys__() == v.__dask_keys__()
assert dask.is_dask_collection(v)
assert dask.is_dask_collection(v2)

self.assertLazyAndAllClose(u + 1, v)
self.assertLazyAndAllClose(u + 1, v2)

def test_concat_loads_variables(self):
# Test that concat() computes not-in-memory variables at most once
# and loads them in the output, while leaving the input unaltered.
@@ -402,28 +458,6 @@ def counting_get(*args, **kwargs):
ds.load()
self.assertEqual(count[0], 1)

def test_persist_Dataset(self):
ds = Dataset({'foo': ('x', range(5)),
'bar': ('x', range(5))}).chunk()
ds = ds + 1
n = len(ds.foo.data.dask)

ds2 = ds.persist()

assert len(ds2.foo.data.dask) == 1
assert len(ds.foo.data.dask) == n # doesn't mutate in place

def test_persist_DataArray(self):
x = da.arange(10, chunks=(5,))
y = DataArray(x)
z = y + 1
n = len(z.data.dask)

zz = z.persist()

assert len(z.data.dask) == n
assert len(zz.data.dask) == zz.data.npartitions

def test_stack(self):
data = da.random.normal(size=(2, 3, 4), chunks=(1, 3, 4))
arr = DataArray(data, dims=('w', 'x', 'y'))
@@ -737,3 +771,62 @@ def build_dask_array(name):
return dask.array.Array(
dask={(name, 0): (kernel, name)}, name=name,
chunks=((1,),), dtype=np.int64)


# test both the persist method and the dask.persist function
# (the dask.persist function requires a newer version of dask)
@pytest.mark.parametrize('persist', [
    lambda x: x.persist(),
    pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
                       reason='Need Dask 0.16')(lambda x: dask.persist(x)[0]),
])
def test_persist_Dataset(persist):
ds = Dataset({'foo': ('x', range(5)),
'bar': ('x', range(5))}).chunk()
ds = ds + 1
n = len(ds.foo.data.dask)

ds2 = persist(ds)

assert len(ds2.foo.data.dask) == 1
assert len(ds.foo.data.dask) == n # doesn't mutate in place


@pytest.mark.parametrize('persist', [
    lambda x: x.persist(),
    pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
                       reason='Need Dask 0.16')(lambda x: dask.persist(x)[0]),
])
def test_persist_DataArray(persist):
x = da.arange(10, chunks=(5,))
y = DataArray(x)
z = y + 1
n = len(z.data.dask)

zz = persist(z)

assert len(z.data.dask) == n
assert len(zz.data.dask) == zz.data.npartitions


@pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
reason='Need dask 0.16 for new interface')
def test_dataarray_with_dask_coords():
import toolz
x = xr.Variable('x', da.arange(8, chunks=(4,)))
y = xr.Variable('y', da.arange(8, chunks=(4,)) * 2)
data = da.random.random((8, 8), chunks=(4, 4)) + 1
array = xr.DataArray(data, dims=['x', 'y'])
array.coords['xx'] = x
array.coords['yy'] = y

assert dict(array.__dask_graph__()) == toolz.merge(data.__dask_graph__(),
x.__dask_graph__(),
y.__dask_graph__())

(array2,) = dask.compute(array)
assert not dask.is_dask_collection(array2)

assert all(isinstance(v._variable.data, np.ndarray) for v in array2.coords.values())
34 changes: 33 additions & 1 deletion xarray/tests/test_distributed.py
@@ -4,7 +4,9 @@

distributed = pytest.importorskip('distributed')
da = pytest.importorskip('dask.array')
from distributed.utils_test import cluster, loop
import dask
from distutils.version import LooseVersion
from distributed.utils_test import cluster, loop, gen_cluster
from distributed.client import futures_of, wait

from xarray.tests.test_backends import create_tmp_file, ON_WINDOWS
from xarray.tests.test_dataset import create_test_data
@@ -32,3 +34,33 @@ def test_dask_distributed_integration_test(loop, engine):
assert isinstance(restored.var1.data, da.Array)
computed = restored.compute()
assert_allclose(original, computed)


@pytest.mark.skipif(LooseVersion(distributed.__version__) <= '1.19.3',
reason='Need recent distributed version to clean up get')
@gen_cluster(client=True, timeout=None)
def test_async(c, s, a, b):
x = create_test_data()
assert not dask.is_dask_collection(x)
y = x.chunk({'dim2': 4}) + 10
assert dask.is_dask_collection(y)
assert dask.is_dask_collection(y.var1)
assert dask.is_dask_collection(y.var2)

z = y.persist()
assert str(z)

assert dask.is_dask_collection(z)
assert dask.is_dask_collection(z.var1)
assert dask.is_dask_collection(z.var2)
assert len(y.__dask_graph__()) > len(z.__dask_graph__())

assert not futures_of(y)
assert futures_of(z)

future = c.compute(z)
w = yield future
assert not dask.is_dask_collection(w)
assert_allclose(x + 10, w)

assert s.task_state