Commit 10495be

mrocklin authored and shoyer committed
Support Dask interface (#1674)
* add dask interface to variable
* redirect compute and visualize methods to dask
* add dask interface to DataArray
* add dask interface to Dataset; also test distributed computing
* remove visualize method
* support backwards compatibility
* cleanup
* style edits
* change versions in tests to trigger on dask dev versions
* support dask arrays in DataArray coordinates
* remove commented assertion
* whats new
* elaborate on what's new
1 parent 2a1d392 commit 10495be

File tree: 6 files changed, +290 / -23 lines

doc/whats-new.rst

Lines changed: 7 additions & 0 deletions
@@ -15,9 +15,16 @@ What's New
 
 .. _whats-new.0.10.0:
 
+v0.10.0 (unreleased)
+--------------------
+
+
 Changes since v0.10.0 rc1 (Unreleased)
 --------------------------------------
 
+- Experimental support for the Dask collection interface (:issue:`1674`).
+  By `Matthew Rocklin <https://github.com/mrocklin>`_.
+
 Bug fixes
 ~~~~~~~~~
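As a minimal sketch of what this entry enables (based on the tests added
below, and assuming dask >= 0.16, where the collection interface landed),
xarray objects can now be passed directly to dask's top-level functions:

    import dask
    import xarray as xr

    # a lazily evaluated Dataset backed by dask arrays
    ds = xr.Dataset({'foo': ('x', range(5))}).chunk({'x': 2}) + 1

    # dask.compute / dask.persist now accept xarray objects directly
    (computed,) = dask.compute(ds)   # in-memory Dataset
    (persisted,) = dask.persist(ds)  # still lazy, graph collapsed
    assert not dask.is_dask_collection(computed)
    assert dask.is_dask_collection(persisted)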

xarray/core/dataarray.py

Lines changed: 29 additions & 0 deletions
@@ -576,6 +576,35 @@ def reset_coords(self, names=None, drop=False, inplace=False):
         dataset[self.name] = self.variable
         return dataset
 
+    def __dask_graph__(self):
+        return self._to_temp_dataset().__dask_graph__()
+
+    def __dask_keys__(self):
+        return self._to_temp_dataset().__dask_keys__()
+
+    @property
+    def __dask_optimize__(self):
+        return self._to_temp_dataset().__dask_optimize__
+
+    @property
+    def __dask_scheduler__(self):
+        return self._to_temp_dataset().__dask_scheduler__
+
+    def __dask_postcompute__(self):
+        func, args = self._to_temp_dataset().__dask_postcompute__()
+        return self._dask_finalize, (func, args, self.name)
+
+    def __dask_postpersist__(self):
+        func, args = self._to_temp_dataset().__dask_postpersist__()
+        return self._dask_finalize, (func, args, self.name)
+
+    @staticmethod
+    def _dask_finalize(results, func, args, name):
+        ds = func(results, *args)
+        variable = ds._variables.pop(_THIS_ARRAY)
+        coords = ds._variables
+        return DataArray(variable, coords, name=name, fastpath=True)
+
     def load(self, **kwargs):
         """Manually trigger loading of this array's data from disk or a
         remote source into memory and return this array.
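Every hook here round-trips through a temporary single-variable Dataset, so
DataArray only needs `_dask_finalize` to pull the variable and its coordinates
back out. Roughly, dask.compute drives these hooks as follows — a sketch for
illustration (it skips graph optimization), not code from the commit:

    import dask.array as da
    import xarray as xr

    arr = xr.DataArray(da.arange(6, chunks=3), dims='x', name='a')

    # what dask.compute(arr) does, approximately:
    graph = arr.__dask_graph__()                   # merged task graph
    keys = arr.__dask_keys__()                     # keys to materialize
    finalize, extra = arr.__dask_postcompute__()
    results = arr.__dask_scheduler__(graph, keys)  # run a scheduler
    rebuilt = finalize(results, *extra)            # in-memory DataArray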

xarray/core/dataset.py

Lines changed: 71 additions & 0 deletions
@@ -493,6 +493,77 @@ def load(self, **kwargs):
 
         return self
 
+    def __dask_graph__(self):
+        graphs = {k: v.__dask_graph__() for k, v in self.variables.items()}
+        graphs = {k: v for k, v in graphs.items() if v is not None}
+        if not graphs:
+            return None
+        else:
+            from dask import sharedict
+            return sharedict.merge(*graphs.values())
+
+    def __dask_keys__(self):
+        import dask
+        return [v.__dask_keys__() for v in self.variables.values()
+                if dask.is_dask_collection(v)]
+
+    @property
+    def __dask_optimize__(self):
+        import dask.array as da
+        return da.Array.__dask_optimize__
+
+    @property
+    def __dask_scheduler__(self):
+        import dask.array as da
+        return da.Array.__dask_scheduler__
+
+    def __dask_postcompute__(self):
+        import dask
+        info = [(True, k, v.__dask_postcompute__())
+                if dask.is_dask_collection(v) else
+                (False, k, v) for k, v in self._variables.items()]
+        return self._dask_postcompute, (info, self._coord_names, self._dims,
+                                        self._attrs, self._file_obj,
+                                        self._encoding)
+
+    def __dask_postpersist__(self):
+        import dask
+        info = [(True, k, v.__dask_postpersist__())
+                if dask.is_dask_collection(v) else
+                (False, k, v) for k, v in self._variables.items()]
+        return self._dask_postpersist, (info, self._coord_names, self._dims,
+                                        self._attrs, self._file_obj,
+                                        self._encoding)
+
+    @staticmethod
+    def _dask_postcompute(results, info, *args):
+        variables = OrderedDict()
+        results2 = list(results[::-1])
+        for is_dask, k, v in info:
+            if is_dask:
+                func, args2 = v
+                r = results2.pop()
+                result = func(r, *args2)
+            else:
+                result = v
+            variables[k] = result
+
+        final = Dataset._construct_direct(variables, *args)
+        return final
+
+    @staticmethod
+    def _dask_postpersist(dsk, info, *args):
+        variables = OrderedDict()
+        for is_dask, k, v in info:
+            if is_dask:
+                func, args2 = v
+                result = func(dsk, *args2)
+            else:
+                result = v
+            variables[k] = result
+
+        return Dataset._construct_direct(variables, *args)
+
     def compute(self, **kwargs):
         """Manually trigger loading of this dataset's data from disk or a
         remote source into memory and return a new dataset. The original is
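The `info` list records, per variable, whether it is dask-backed;
`_dask_postcompute` then pops computed results back into the dask-backed
variables in order and passes in-memory variables through untouched. A small
sketch of the resulting behavior (assuming dask >= 0.16):

    import dask
    import dask.array as da
    import numpy as np
    import xarray as xr

    # one dask-backed variable, one already in memory
    ds = xr.Dataset({'lazy': ('x', da.arange(4, chunks=2)),
                     'eager': ('x', np.arange(4))})

    (result,) = dask.compute(ds)
    # both variables are now plain numpy; 'eager' was passed through
    assert isinstance(result['lazy'].data, np.ndarray)
    assert isinstance(result['eager'].data, np.ndarray)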

xarray/core/variable.py

Lines changed: 35 additions & 0 deletions
@@ -355,6 +355,41 @@ def compute(self, **kwargs):
         new = self.copy(deep=False)
         return new.load(**kwargs)
 
+    def __dask_graph__(self):
+        if isinstance(self._data, dask_array_type):
+            return self._data.__dask_graph__()
+        else:
+            return None
+
+    def __dask_keys__(self):
+        return self._data.__dask_keys__()
+
+    @property
+    def __dask_optimize__(self):
+        return self._data.__dask_optimize__
+
+    @property
+    def __dask_scheduler__(self):
+        return self._data.__dask_scheduler__
+
+    def __dask_postcompute__(self):
+        array_func, array_args = self._data.__dask_postcompute__()
+        return self._dask_finalize, (array_func, array_args, self._dims,
+                                     self._attrs, self._encoding)
+
+    def __dask_postpersist__(self):
+        array_func, array_args = self._data.__dask_postpersist__()
+        return self._dask_finalize, (array_func, array_args, self._dims,
+                                     self._attrs, self._encoding)
+
+    @staticmethod
+    def _dask_finalize(results, array_func, array_args, dims, attrs, encoding):
+        if isinstance(results, dict):  # persist case
+            name = array_args[0]
+            results = {k: v for k, v in results.items() if k[0] == name}  # cull
+        data = array_func(results, *array_args)
+        return Variable(dims, data, attrs=attrs, encoding=encoding)
+
     @property
     def values(self):
         """The variable's data as a numpy.ndarray"""

xarray/tests/test_dask.py

Lines changed: 115 additions & 22 deletions
@@ -206,6 +206,34 @@ def test_bivariate_ufunc(self):
         self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0))
         self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v))
 
+    @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
+                        reason='Need dask 0.16 for new interface')
+    def test_compute(self):
+        u = self.eager_var
+        v = self.lazy_var
+
+        assert dask.is_dask_collection(v)
+        (v2,) = dask.compute(v + 1)
+        assert not dask.is_dask_collection(v2)
+
+        assert ((u + 1).data == v2.data).all()
+
+    @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
+                        reason='Need dask 0.16 for new interface')
+    def test_persist(self):
+        u = self.eager_var
+        v = self.lazy_var + 1
+
+        (v2,) = dask.persist(v)
+        assert v is not v2
+        assert len(v2.__dask_graph__()) < len(v.__dask_graph__())
+        assert v2.__dask_keys__() == v.__dask_keys__()
+        assert dask.is_dask_collection(v)
+        assert dask.is_dask_collection(v2)
+
+        self.assertLazyAndAllClose(u + 1, v)
+        self.assertLazyAndAllClose(u + 1, v2)
+
 
 class TestDataArrayAndDataset(DaskTestCase):
     def assertLazyAndIdentical(self, expected, actual):
@@ -251,6 +279,34 @@ def test_lazy_array(self):
         actual = xr.concat([v[:2], v[2:]], 'x')
         self.assertLazyAndAllClose(u, actual)
 
+    @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
+                        reason='Need dask 0.16 for new interface')
+    def test_compute(self):
+        u = self.eager_array
+        v = self.lazy_array
+
+        assert dask.is_dask_collection(v)
+        (v2,) = dask.compute(v + 1)
+        assert not dask.is_dask_collection(v2)
+
+        assert ((u + 1).data == v2.data).all()
+
+    @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
+                        reason='Need dask 0.16 for new interface')
+    def test_persist(self):
+        u = self.eager_array
+        v = self.lazy_array + 1
+
+        (v2,) = dask.persist(v)
+        assert v is not v2
+        assert len(v2.__dask_graph__()) < len(v.__dask_graph__())
+        assert v2.__dask_keys__() == v.__dask_keys__()
+        assert dask.is_dask_collection(v)
+        assert dask.is_dask_collection(v2)
+
+        self.assertLazyAndAllClose(u + 1, v)
+        self.assertLazyAndAllClose(u + 1, v2)
+
     def test_concat_loads_variables(self):
         # Test that concat() computes not-in-memory variables at most once
         # and loads them in the output, while leaving the input unaltered.
@@ -402,28 +458,6 @@ def counting_get(*args, **kwargs):
         ds.load()
         self.assertEqual(count[0], 1)
 
-    def test_persist_Dataset(self):
-        ds = Dataset({'foo': ('x', range(5)),
-                      'bar': ('x', range(5))}).chunk()
-        ds = ds + 1
-        n = len(ds.foo.data.dask)
-
-        ds2 = ds.persist()
-
-        assert len(ds2.foo.data.dask) == 1
-        assert len(ds.foo.data.dask) == n  # doesn't mutate in place
-
-    def test_persist_DataArray(self):
-        x = da.arange(10, chunks=(5,))
-        y = DataArray(x)
-        z = y + 1
-        n = len(z.data.dask)
-
-        zz = z.persist()
-
-        assert len(z.data.dask) == n
-        assert len(zz.data.dask) == zz.data.npartitions
-
     def test_stack(self):
         data = da.random.normal(size=(2, 3, 4), chunks=(1, 3, 4))
         arr = DataArray(data, dims=('w', 'x', 'y'))
@@ -737,3 +771,62 @@ def build_dask_array(name):
     return dask.array.Array(
         dask={(name, 0): (kernel, name)}, name=name,
         chunks=((1,),), dtype=np.int64)
+
+
+# test both the persist method and the dask.persist function;
+# the dask.persist function requires a newer version of dask
+@pytest.mark.parametrize('persist', [
+    lambda x: x.persist(),
+    pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
+                       lambda x: dask.persist(x)[0],
+                       reason='Need Dask 0.16')
+])
+def test_persist_Dataset(persist):
+    ds = Dataset({'foo': ('x', range(5)),
+                  'bar': ('x', range(5))}).chunk()
+    ds = ds + 1
+    n = len(ds.foo.data.dask)
+
+    ds2 = persist(ds)
+
+    assert len(ds2.foo.data.dask) == 1
+    assert len(ds.foo.data.dask) == n  # doesn't mutate in place
+
+
+@pytest.mark.parametrize('persist', [
+    lambda x: x.persist(),
+    pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
+                       lambda x: dask.persist(x)[0],
+                       reason='Need Dask 0.16')
+])
+def test_persist_DataArray(persist):
+    x = da.arange(10, chunks=(5,))
+    y = DataArray(x)
+    z = y + 1
+    n = len(z.data.dask)
+
+    zz = persist(z)
+
+    assert len(z.data.dask) == n
+    assert len(zz.data.dask) == zz.data.npartitions
+
+
+@pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4',
+                    reason='Need dask 0.16 for new interface')
+def test_dataarray_with_dask_coords():
+    import toolz
+    x = xr.Variable('x', da.arange(8, chunks=(4,)))
+    y = xr.Variable('y', da.arange(8, chunks=(4,)) * 2)
+    data = da.random.random((8, 8), chunks=(4, 4)) + 1
+    array = xr.DataArray(data, dims=['x', 'y'])
+    array.coords['xx'] = x
+    array.coords['yy'] = y
+
+    assert dict(array.__dask_graph__()) == toolz.merge(data.__dask_graph__(),
+                                                       x.__dask_graph__(),
+                                                       y.__dask_graph__())
+
+    (array2,) = dask.compute(array)
+    assert not dask.is_dask_collection(array2)
+
+    assert all(isinstance(v._variable.data, np.ndarray)
+               for v in array2.coords.values())
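One consequence exercised by the last test: because Dataset merges the graphs
of all variables, computing a DataArray materializes any dask-backed
coordinates in the same pass. A brief sketch:

    import dask
    import dask.array as da
    import numpy as np
    import xarray as xr

    arr = xr.DataArray(da.ones((4,), chunks=2), dims='x')
    arr.coords['xx'] = xr.Variable('x', da.arange(4, chunks=2))

    # a single compute call materializes data and coordinates together
    (arr2,) = dask.compute(arr)
    assert isinstance(arr2.data, np.ndarray)
    assert isinstance(arr2.coords['xx'].variable.data, np.ndarray)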

xarray/tests/test_distributed.py

Lines changed: 33 additions & 1 deletion
@@ -4,7 +4,9 @@
 
 distributed = pytest.importorskip('distributed')
 da = pytest.importorskip('dask.array')
-from distributed.utils_test import cluster, loop
+import dask
+from distributed.utils_test import cluster, loop, gen_cluster
+from distributed.client import futures_of, wait
 
 from xarray.tests.test_backends import create_tmp_file, ON_WINDOWS
 from xarray.tests.test_dataset import create_test_data
@@ -32,3 +34,33 @@ def test_dask_distributed_integration_test(loop, engine):
     assert isinstance(restored.var1.data, da.Array)
     computed = restored.compute()
     assert_allclose(original, computed)
+
+
+@pytest.mark.skipif(distributed.__version__ <= '1.19.3',
+                    reason='Need recent distributed version to clean up get')
+@gen_cluster(client=True, timeout=None)
+def test_async(c, s, a, b):
+    x = create_test_data()
+    assert not dask.is_dask_collection(x)
+    y = x.chunk({'dim2': 4}) + 10
+    assert dask.is_dask_collection(y)
+    assert dask.is_dask_collection(y.var1)
+    assert dask.is_dask_collection(y.var2)
+
+    z = y.persist()
+    assert str(z)
+
+    assert dask.is_dask_collection(z)
+    assert dask.is_dask_collection(z.var1)
+    assert dask.is_dask_collection(z.var2)
+    assert len(y.__dask_graph__()) > len(z.__dask_graph__())
+
+    assert not futures_of(y)
+    assert futures_of(z)
+
+    future = c.compute(z)
+    w = yield future
+    assert not dask.is_dask_collection(w)
+    assert_allclose(x + 10, w)
+
+    assert s.task_state
