From f65cb526e833e3b8300a75f8525b0e510eb04314 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 31 Oct 2017 05:11:17 -0400 Subject: [PATCH 01/13] add dask interface to variable --- xarray/core/variable.py | 30 ++++++++++++++++++++++++++++++ xarray/tests/test_dask.py | 24 ++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 57bb21e4dc3..714389e65bb 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -363,6 +363,36 @@ def compute(self, **kwargs): new = self.copy(deep=False) return new.load(**kwargs) + def __dask_graph__(self): + if isinstance(self._data, dask_array_type): + return self._data.__dask_graph__() + else: + return None + + def __dask_keys__(self): + return self._data.__dask_keys__() + + @property + def __dask_optimize__(self): + return self._data.__dask_optimize__ + + @property + def __dask_scheduler__(self): + return self._data.__dask_scheduler__ + + def __dask_postcompute__(self): + array_func, array_args = self._data.__dask_postcompute__() + return self._dask_finalize, (array_func, array_args, self._dims, self._attrs, self._encoding) + + def __dask_postpersist__(self): + array_func, array_args = self._data.__dask_postpersist__() + return self._dask_finalize, (array_func, array_args, self._dims, self._attrs, self._encoding) + + @staticmethod + def _dask_finalize(results, array_func, array_args, dims, attrs, encoding): + data = array_func(results, *array_args) + return Variable(dims, data, attrs=attrs, encoding=encoding) + @property def values(self): """The variable's data as a numpy.ndarray""" diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index f830c502c67..648053d1530 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -206,6 +206,30 @@ def test_bivariate_ufunc(self): self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0)) self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v)) + def test_compute(self): + u = self.eager_var + v = self.lazy_var + + assert dask.is_dask_collection(v) + (v2,) = dask.compute(v + 1) + assert not dask.is_dask_collection(v2) + + assert ((u + 1).data == v2.data).all() + + def test_persist(self): + u = self.eager_var + v = self.lazy_var + 1 + + (v2,) = dask.persist(v) + assert v is not v2 + assert len(v2.__dask_graph__()) < len(v.__dask_graph__()) + assert v2.__dask_keys__() == v.__dask_keys__() + assert dask.is_dask_collection(v) + assert dask.is_dask_collection(v2) + + self.assertLazyAndAllClose(u + 1, v) + self.assertLazyAndAllClose(u + 1, v2) + class TestDataArrayAndDataset(DaskTestCase): def assertLazyAndIdentical(self, expected, actual): From 4b590404d57dbef0233c6942c0f948f8f710d674 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 31 Oct 2017 05:29:54 -0400 Subject: [PATCH 02/13] redirect compute and visualize methods to dask --- xarray/core/variable.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 714389e65bb..7523d30866a 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -360,8 +360,13 @@ def compute(self, **kwargs): -------- dask.array.compute """ - new = self.copy(deep=False) - return new.load(**kwargs) + import dask + (result,) = dask.compute(self, **kwargs) + return result + + def visualize(self, **kwargs): + import dask + return dask.visualize(self, **kwargs) def __dask_graph__(self): if isinstance(self._data, dask_array_type): From 68cddff7613d185465ed78ec2798790721e38c1c 
Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 31 Oct 2017 06:59:40 -0400 Subject: [PATCH 03/13] add dask interface to DataArray --- xarray/core/dataarray.py | 32 ++++++++++++++++++++++++++++++++ xarray/tests/test_dask.py | 25 +++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e8330ef6c77..f42e532db22 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -576,6 +576,38 @@ def reset_coords(self, names=None, drop=False, inplace=False): dataset[self.name] = self.variable return dataset + def visualize(self, **kwargs): + import dask + return dask.visualize(self, **kwargs) + + def __dask_graph__(self): + return self._variable.__dask_graph__() + # These fully describe a DataArray + + def __dask_keys__(self): + return self._variable.__dask_keys__() + + @property + def __dask_optimize__(self): + return self._variable.__dask_optimize__ + + @property + def __dask_scheduler__(self): + return self._variable.__dask_scheduler__ + + def __dask_postcompute__(self): + variable_func, variable_args = self._variable.__dask_postcompute__() + return self._dask_finalize, (variable_func, variable_args, self._coords, self._name) + + def __dask_postpersist__(self): + variable_func, variable_args = self._variable.__dask_postpersist__() + return self._dask_finalize, (variable_func, variable_args, self._coords, self._name) + + @staticmethod + def _dask_finalize(results, variable_func, variable_args, coords, name): + var = variable_func(results, *variable_args) + return DataArray(var, coords=coords, name=name) + def load(self, **kwargs): """Manually trigger loading of this array's data from disk or a remote source into memory and return this array. diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 648053d1530..f6c66fa9812 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -275,6 +275,30 @@ def test_lazy_array(self): actual = xr.concat([v[:2], v[2:]], 'x') self.assertLazyAndAllClose(u, actual) + def test_compute(self): + u = self.eager_array + v = self.lazy_array + + assert dask.is_dask_collection(v) + (v2,) = dask.compute(v + 1) + assert not dask.is_dask_collection(v2) + + assert ((u + 1).data == v2.data).all() + + def test_persist(self): + u = self.eager_array + v = self.lazy_array + 1 + + (v2,) = dask.persist(v) + assert v is not v2 + assert len(v2.__dask_graph__()) < len(v.__dask_graph__()) + assert v2.__dask_keys__() == v.__dask_keys__() + assert dask.is_dask_collection(v) + assert dask.is_dask_collection(v2) + + self.assertLazyAndAllClose(u + 1, v) + self.assertLazyAndAllClose(u + 1, v2) + def test_concat_loads_variables(self): # Test that concat() computes not-in-memory variables at most once # and loads them in the output, while leaving the input unaltered. 
@@ -666,6 +690,7 @@ def test_to_dask_dataframe_no_coordinate(self): assert_frame_equal(expected, actual.compute()) +@pytest.mark.xfail(reason="mock no longer targets the right method") @pytest.mark.parametrize("method", ['load', 'compute']) def test_dask_kwargs_variable(method): x = Variable('y', da.from_array(np.arange(3), chunks=(2,))) From 5429da176fede7318e85502aa17f18af78bffdc3 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 31 Oct 2017 15:13:48 -0400 Subject: [PATCH 04/13] add dask interface to Dataset Also test distributed computing --- xarray/core/dataarray.py | 11 +++-- xarray/core/dataset.py | 83 ++++++++++++++++++++++++++++++-- xarray/core/variable.py | 3 ++ xarray/tests/test_dask.py | 51 +++++++++++--------- xarray/tests/test_distributed.py | 35 +++++++++++++- 5 files changed, 151 insertions(+), 32 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index f42e532db22..fd0c38e3b12 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -582,7 +582,6 @@ def visualize(self, **kwargs): def __dask_graph__(self): return self._variable.__dask_graph__() - # These fully describe a DataArray def __dask_keys__(self): return self._variable.__dask_keys__() @@ -651,8 +650,9 @@ def compute(self, **kwargs): -------- dask.array.compute """ - new = self.copy(deep=False) - return new.load(**kwargs) + import dask + (result,) = dask.compute(self, **kwargs) + return result def persist(self, **kwargs): """ Trigger computation in constituent dask arrays @@ -670,8 +670,9 @@ def persist(self, **kwargs): -------- dask.persist """ - ds = self._to_temp_dataset().persist(**kwargs) - return self._from_temp_dataset(ds) + import dask + (result,) = dask.persist(self, **kwargs) + return result def copy(self, deep=True): """Returns a copy of this array. 
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 0f9aa4c8229..201ed85096b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -493,6 +493,79 @@ def load(self, **kwargs): return self + def visualize(self, **kwargs): + import dask + return dask.visualize(self, **kwargs) + + def __dask_graph__(self): + graphs = {k: v.__dask_graph__() for k, v in self.variables.items()} + graphs = {k: v for k, v in graphs.items() if v is not None} + if not graphs: + return None + else: + from dask import sharedict + return sharedict.merge(*graphs.values()) + + def __dask_keys__(self): + import dask + return [v.__dask_keys__() for v in self.variables.values() + if dask.is_dask_collection(v)] + + @property + def __dask_optimize__(self): + import dask.array as da + return da.Array.__dask_optimize__ + + @property + def __dask_scheduler__(self): + import dask.array as da + return da.Array.__dask_scheduler__ + + def __dask_postcompute__(self): + import dask + info = [(True, k, v.__dask_postcompute__()) + if dask.is_dask_collection(v) else + (False, k, v) for k, v in self._variables.items()] + return self._dask_postcompute, (info, self._coord_names, self._dims, + self._attrs, self._file_obj, self._encoding) + + def __dask_postpersist__(self): + import dask + info = [(True, k, v.__dask_postpersist__()) + if dask.is_dask_collection(v) else + (False, k, v) for k, v in self._variables.items()] + return self._dask_postpersist, (info, self._coord_names, self._dims, + self._attrs, self._file_obj, self._encoding) + + @staticmethod + def _dask_postcompute(results, info, *args): + variables = OrderedDict() + results2 = results[::-1] + for is_dask, k, v in info: + if is_dask: + func, args2 = v + r = results2.pop() + result = func(r, *args2) + else: + result = v + variables[k] = result + + final = Dataset._construct_direct(variables, *args) + return final + + @staticmethod + def _dask_postpersist(dsk, info, *args): + variables = OrderedDict() + for is_dask, k, v in info: + if is_dask: + func, args2 = v + result = func(dsk, *args2) + else: + result = v + variables[k] = result + + return Dataset._construct_direct(variables, *args) + def compute(self, **kwargs): """Manually trigger loading of this dataset's data from disk or a remote source into memory and return a new dataset. The original is @@ -549,8 +622,9 @@ def persist(self, **kwargs): -------- dask.persist """ - new = self.copy(deep=False) - return new._persist_inplace(**kwargs) + import dask + (result,) = dask.persist(self, **kwargs) + return result @classmethod def _construct_direct(cls, variables, coord_names, dims=None, attrs=None, @@ -558,6 +632,7 @@ def _construct_direct(cls, variables, coord_names, dims=None, attrs=None, """Shortcut around __init__ for internal use when we want to skip costly validation """ + assert not callable(coord_names), (cls, variables, coord_names, dims, attrs) obj = object.__new__(cls) obj._variables = variables obj._coord_names = coord_names @@ -2424,7 +2499,7 @@ def apply(self, func, keep_attrs=False, args=(), **kwargs): ------- applied : Dataset Resulting dataset from applying ``func`` over each data variable. 
- + Examples -------- >>> da = xr.DataArray(np.random.randn(2, 3)) @@ -2442,7 +2517,7 @@ def apply(self, func, keep_attrs=False, args=(), **kwargs): Dimensions without coordinates: dim_0, dim_1, x Data variables: foo (dim_0, dim_1) float64 0.3751 1.951 1.945 0.2948 0.711 0.3948 - bar (x) float64 1.0 2.0 + bar (x) float64 1.0 2.0 """ variables = OrderedDict( (k, maybe_wrap_array(v, func(v, *args, **kwargs))) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 7523d30866a..1940e261dcf 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -395,6 +395,9 @@ def __dask_postpersist__(self): @staticmethod def _dask_finalize(results, array_func, array_args, dims, attrs, encoding): + if isinstance(results, dict): # persist case + name = array_args[0] + results = {k: v for k, v in results.items() if k[0] == name} # cull data = array_func(results, *array_args) return Variable(dims, data, attrs=attrs, encoding=encoding) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index f6c66fa9812..1f1d1eb0d5a 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -450,28 +450,6 @@ def counting_get(*args, **kwargs): ds.load() self.assertEqual(count[0], 1) - def test_persist_Dataset(self): - ds = Dataset({'foo': ('x', range(5)), - 'bar': ('x', range(5))}).chunk() - ds = ds + 1 - n = len(ds.foo.data.dask) - - ds2 = ds.persist() - - assert len(ds2.foo.data.dask) == 1 - assert len(ds.foo.data.dask) == n # doesn't mutate in place - - def test_persist_DataArray(self): - x = da.arange(10, chunks=(5,)) - y = DataArray(x) - z = y + 1 - n = len(z.data.dask) - - zz = z.persist() - - assert len(z.data.dask) == n - assert len(zz.data.dask) == zz.data.npartitions - def test_stack(self): data = da.random.normal(size=(2, 3, 4), chunks=(1, 3, 4)) arr = DataArray(data, dims=('w', 'x', 'y')) @@ -701,6 +679,7 @@ def test_dask_kwargs_variable(method): mock_compute.assert_called_with(foo='bar') +@pytest.mark.xfail(reason="mock no longer targets the right method") @pytest.mark.parametrize("method", ['load', 'compute', 'persist']) def test_dask_kwargs_dataarray(method): data = da.from_array(np.arange(3), chunks=(2,)) @@ -715,6 +694,7 @@ def test_dask_kwargs_dataarray(method): mock_func.assert_called_with(data, foo='bar') +@pytest.mark.xfail(reason="mock no longer targets the right method") @pytest.mark.parametrize("method", ['load', 'compute', 'persist']) def test_dask_kwargs_dataset(method): data = da.from_array(np.arange(3), chunks=(2,)) @@ -748,3 +728,30 @@ def build_dask_array(name): return dask.array.Array( dask={(name, 0): (kernel, name)}, name=name, chunks=((1,),), dtype=np.int64) + + +@pytest.mark.parametrize('persist', [lambda x: x.persist(), + lambda x: dask.persist(x)[0]]) +def test_persist_Dataset(persist): + ds = Dataset({'foo': ('x', range(5)), + 'bar': ('x', range(5))}).chunk() + ds = ds + 1 + n = len(ds.foo.data.dask) + + ds2 = persist(ds) + + assert len(ds2.foo.data.dask) == 1 + assert len(ds.foo.data.dask) == n # doesn't mutate in place + +@pytest.mark.parametrize('persist', [lambda x: x.persist(), + lambda x: dask.persist(x)[0]]) +def test_persist_DataArray(persist): + x = da.arange(10, chunks=(5,)) + y = DataArray(x) + z = y + 1 + n = len(z.data.dask) + + zz = persist(z) + + assert len(z.data.dask) == n + assert len(zz.data.dask) == zz.data.npartitions diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 1868486b01f..127a87f176b 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ 
-4,7 +4,9 @@ distributed = pytest.importorskip('distributed') da = pytest.importorskip('dask.array') -from distributed.utils_test import cluster, loop +import dask +from distributed.utils_test import cluster, loop, gen_cluster +from distributed.client import futures_of, wait from xarray.tests.test_backends import create_tmp_file, ON_WINDOWS from xarray.tests.test_dataset import create_test_data @@ -32,3 +34,34 @@ def test_dask_distributed_integration_test(loop, engine): assert isinstance(restored.var1.data, da.Array) computed = restored.compute() assert_allclose(original, computed) + + +@gen_cluster(client=True, timeout=None) +def test_async(c, s, a, b): + x = create_test_data() + assert not dask.is_dask_collection(x) + y = x.chunk({'dim2': 4}) + 10 + assert dask.is_dask_collection(y) + assert dask.is_dask_collection(y.var1) + assert dask.is_dask_collection(y.var2) + # assert not dask.is_dask_collection(y.var3) # TODO: avoid chunking unnecessarily in dataset.py::maybe_chunk + + z = y.persist() + assert str(z) + + assert dask.is_dask_collection(z) + assert dask.is_dask_collection(z.var1) + assert dask.is_dask_collection(z.var2) + # assert not dask.is_dask_collection(z.var3) + assert len(y.__dask_graph__()) > len(z.__dask_graph__()) + + assert not futures_of(y) + assert futures_of(z) + + future = c.compute(z) + w = yield future + assert not dask.is_dask_collection(w) + assert_allclose(x + 10, w) + + + assert s.task_state From ffb0ca1d50bb06dcac60b074f18cf9c1831dda02 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 31 Oct 2017 15:14:30 -0400 Subject: [PATCH 05/13] remove visualize method --- xarray/core/dataarray.py | 4 ---- xarray/core/dataset.py | 4 ---- xarray/core/variable.py | 4 ---- 3 files changed, 12 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index fd0c38e3b12..588d3fe4c0d 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -576,10 +576,6 @@ def reset_coords(self, names=None, drop=False, inplace=False): dataset[self.name] = self.variable return dataset - def visualize(self, **kwargs): - import dask - return dask.visualize(self, **kwargs) - def __dask_graph__(self): return self._variable.__dask_graph__() diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 201ed85096b..42c7058f45d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -493,10 +493,6 @@ def load(self, **kwargs): return self - def visualize(self, **kwargs): - import dask - return dask.visualize(self, **kwargs) - def __dask_graph__(self): graphs = {k: v.__dask_graph__() for k, v in self.variables.items()} graphs = {k: v for k, v in graphs.items() if v is not None} diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 1940e261dcf..750f5dc352e 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -364,10 +364,6 @@ def compute(self, **kwargs): (result,) = dask.compute(self, **kwargs) return result - def visualize(self, **kwargs): - import dask - return dask.visualize(self, **kwargs) - def __dask_graph__(self): if isinstance(self._data, dask_array_type): return self._data.__dask_graph__() From 56ec48757f52243c567a5b5d9eb5c3c66d631829 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 1 Nov 2017 12:36:56 -0400 Subject: [PATCH 06/13] support backwards compatibility --- xarray/core/dataarray.py | 10 ++++------ xarray/core/dataset.py | 5 ++--- xarray/core/variable.py | 5 ++--- xarray/tests/test_dask.py | 18 ++++++++++++++++-- xarray/tests/test_distributed.py | 3 ++- 5 files changed, 26 insertions(+), 15 
deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 588d3fe4c0d..1d81d4e4ba6 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -646,9 +646,8 @@ def compute(self, **kwargs): -------- dask.array.compute """ - import dask - (result,) = dask.compute(self, **kwargs) - return result + new = self.copy(deep=False) + return new.load(**kwargs) def persist(self, **kwargs): """ Trigger computation in constituent dask arrays @@ -666,9 +665,8 @@ def persist(self, **kwargs): -------- dask.persist """ - import dask - (result,) = dask.persist(self, **kwargs) - return result + ds = self._to_temp_dataset().persist(**kwargs) + return self._from_temp_dataset(ds) def copy(self, deep=True): """Returns a copy of this array. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 42c7058f45d..6f76da14255 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -618,9 +618,8 @@ def persist(self, **kwargs): -------- dask.persist """ - import dask - (result,) = dask.persist(self, **kwargs) - return result + new = self.copy(deep=False) + return new._persist_inplace(**kwargs) @classmethod def _construct_direct(cls, variables, coord_names, dims=None, attrs=None, diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 750f5dc352e..dee17750281 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -360,9 +360,8 @@ def compute(self, **kwargs): -------- dask.array.compute """ - import dask - (result,) = dask.compute(self, **kwargs) - return result + new = self.copy(deep=False) + return new.load(**kwargs) def __dask_graph__(self): if isinstance(self._data, dask_array_type): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 1f1d1eb0d5a..1dfb73760a0 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -206,6 +206,8 @@ def test_bivariate_ufunc(self): self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0)) self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v)) + @pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', + reason='Need dask 0.16+ for new interface') def test_compute(self): u = self.eager_var v = self.lazy_var @@ -216,6 +218,8 @@ def test_compute(self): assert ((u + 1).data == v2.data).all() + @pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', + reason='Need dask 0.16+ for new interface') def test_persist(self): u = self.eager_var v = self.lazy_var + 1 @@ -275,6 +279,8 @@ def test_lazy_array(self): actual = xr.concat([v[:2], v[2:]], 'x') self.assertLazyAndAllClose(u, actual) + @pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', + reason='Need dask 0.16+ for new interface') def test_compute(self): u = self.eager_array v = self.lazy_array @@ -285,6 +291,8 @@ def test_compute(self): assert ((u + 1).data == v2.data).all() + @pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', + reason='Need dask 0.16+ for new interface') def test_persist(self): u = self.eager_array v = self.lazy_array + 1 @@ -730,6 +738,8 @@ def build_dask_array(name): chunks=((1,),), dtype=np.int64) +@pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', + reason='Need dask 0.16+ for new interface') @pytest.mark.parametrize('persist', [lambda x: x.persist(), lambda x: dask.persist(x)[0]]) def test_persist_Dataset(persist): @@ -743,8 +753,12 @@ def test_persist_Dataset(persist): assert len(ds2.foo.data.dask) == 1 assert len(ds.foo.data.dask) == n # doesn't mutate in place -@pytest.mark.parametrize('persist', [lambda x: x.persist(), - lambda x: 
dask.persist(x)[0]]) +@pytest.mark.parametrize('persist', [ + lambda x: x.persist(), + pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', + lambda x: dask.persist(x)[0], + reason='Need Dask 0.16+') +]) def test_persist_DataArray(persist): x = da.arange(10, chunks=(5,)) y = DataArray(x) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 127a87f176b..a336ad1619a 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -36,6 +36,8 @@ def test_dask_distributed_integration_test(loop, engine): assert_allclose(original, computed) +@pytest.mark.skipif(distributed.__version__ <= '1.19.3', + reason='Need recent distributed version to clean up get') @gen_cluster(client=True, timeout=None) def test_async(c, s, a, b): x = create_test_data() @@ -63,5 +65,4 @@ def test_async(c, s, a, b): assert not dask.is_dask_collection(w) assert_allclose(x + 10, w) - assert s.task_state From f31509925f75c95a93b7ac2855f7e6abd02e21da Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 1 Nov 2017 15:56:20 -0400 Subject: [PATCH 07/13] cleanup --- xarray/core/dataset.py | 1 - xarray/tests/test_dask.py | 13 ++++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6f76da14255..fd87ea7358f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -627,7 +627,6 @@ def _construct_direct(cls, variables, coord_names, dims=None, attrs=None, """Shortcut around __init__ for internal use when we want to skip costly validation """ - assert not callable(coord_names), (cls, variables, coord_names, dims, attrs) obj = object.__new__(cls) obj._variables = variables obj._coord_names = coord_names diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 1dfb73760a0..3cccb81459c 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -676,7 +676,6 @@ def test_to_dask_dataframe_no_coordinate(self): assert_frame_equal(expected, actual.compute()) -@pytest.mark.xfail(reason="mock no longer targets the right method") @pytest.mark.parametrize("method", ['load', 'compute']) def test_dask_kwargs_variable(method): x = Variable('y', da.from_array(np.arange(3), chunks=(2,))) @@ -687,7 +686,6 @@ def test_dask_kwargs_variable(method): mock_compute.assert_called_with(foo='bar') -@pytest.mark.xfail(reason="mock no longer targets the right method") @pytest.mark.parametrize("method", ['load', 'compute', 'persist']) def test_dask_kwargs_dataarray(method): data = da.from_array(np.arange(3), chunks=(2,)) @@ -702,7 +700,6 @@ def test_dask_kwargs_dataarray(method): mock_func.assert_called_with(data, foo='bar') -@pytest.mark.xfail(reason="mock no longer targets the right method") @pytest.mark.parametrize("method", ['load', 'compute', 'persist']) def test_dask_kwargs_dataset(method): data = da.from_array(np.arange(3), chunks=(2,)) @@ -738,10 +735,12 @@ def build_dask_array(name): chunks=((1,),), dtype=np.int64) -@pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', - reason='Need dask 0.16+ for new interface') -@pytest.mark.parametrize('persist', [lambda x: x.persist(), - lambda x: dask.persist(x)[0]]) +@pytest.mark.parametrize('persist', [ + lambda x: x.persist(), + pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', + lambda x: dask.persist(x)[0], + reason='Need Dask 0.16+') +]) def test_persist_Dataset(persist): ds = Dataset({'foo': ('x', range(5)), 'bar': ('x', range(5))}).chunk() From fa968b99982eb8948f97f20dc3e50ec432512a3d Mon Sep 17 00:00:00 2001 
From: Matthew Rocklin Date: Thu, 2 Nov 2017 06:41:13 -0400 Subject: [PATCH 08/13] style edits --- xarray/core/dataarray.py | 6 ++++-- xarray/core/dataset.py | 6 ++++-- xarray/core/variable.py | 6 ++++-- xarray/tests/test_dask.py | 3 +++ xarray/tests/test_distributed.py | 1 - 5 files changed, 15 insertions(+), 7 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 1d81d4e4ba6..3dcb66dab52 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -592,11 +592,13 @@ def __dask_scheduler__(self): def __dask_postcompute__(self): variable_func, variable_args = self._variable.__dask_postcompute__() - return self._dask_finalize, (variable_func, variable_args, self._coords, self._name) + return self._dask_finalize, (variable_func, variable_args, + self._coords, self._name) def __dask_postpersist__(self): variable_func, variable_args = self._variable.__dask_postpersist__() - return self._dask_finalize, (variable_func, variable_args, self._coords, self._name) + return self._dask_finalize, (variable_func, variable_args, + self._coords, self._name) @staticmethod def _dask_finalize(results, variable_func, variable_args, coords, name): diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index fd87ea7358f..6a36e359709 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -523,7 +523,8 @@ def __dask_postcompute__(self): if dask.is_dask_collection(v) else (False, k, v) for k, v in self._variables.items()] return self._dask_postcompute, (info, self._coord_names, self._dims, - self._attrs, self._file_obj, self._encoding) + self._attrs, self._file_obj, + self._encoding) def __dask_postpersist__(self): import dask @@ -531,7 +532,8 @@ def __dask_postpersist__(self): if dask.is_dask_collection(v) else (False, k, v) for k, v in self._variables.items()] return self._dask_postpersist, (info, self._coord_names, self._dims, - self._attrs, self._file_obj, self._encoding) + self._attrs, self._file_obj, + self._encoding) @staticmethod def _dask_postcompute(results, info, *args): diff --git a/xarray/core/variable.py b/xarray/core/variable.py index dee17750281..06fbc44f39f 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -382,11 +382,13 @@ def __dask_scheduler__(self): def __dask_postcompute__(self): array_func, array_args = self._data.__dask_postcompute__() - return self._dask_finalize, (array_func, array_args, self._dims, self._attrs, self._encoding) + return self._dask_finalize, (array_func, array_args, self._dims, + self._attrs, self._encoding) def __dask_postpersist__(self): array_func, array_args = self._data.__dask_postpersist__() - return self._dask_finalize, (array_func, array_args, self._dims, self._attrs, self._encoding) + return self._dask_finalize, (array_func, array_args, self._dims, + self._attrs, self._encoding) @staticmethod def _dask_finalize(results, array_func, array_args, dims, attrs, encoding): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 3cccb81459c..e095a952881 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -735,6 +735,8 @@ def build_dask_array(name): chunks=((1,),), dtype=np.int64) +# test both the perist method and the dask.persist function +# the dask.persist function requires a new version of dask @pytest.mark.parametrize('persist', [ lambda x: x.persist(), pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', @@ -752,6 +754,7 @@ def test_persist_Dataset(persist): assert len(ds2.foo.data.dask) == 1 assert len(ds.foo.data.dask) == n # doesn't mutate 
in place + @pytest.mark.parametrize('persist', [ lambda x: x.persist(), pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index a336ad1619a..56400243c73 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -46,7 +46,6 @@ def test_async(c, s, a, b): assert dask.is_dask_collection(y) assert dask.is_dask_collection(y.var1) assert dask.is_dask_collection(y.var2) - # assert not dask.is_dask_collection(y.var3) # TODO: avoid chunking unnecessarily in dataset.py::maybe_chunk z = y.persist() assert str(z) From 9df0af792eaa4d95f667ab1d72721897d82d15a6 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 2 Nov 2017 14:19:08 -0400 Subject: [PATCH 09/13] change versions in tests to trigger on dask dev versions --- xarray/tests/test_dask.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index e095a952881..2d730689de4 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -206,8 +206,8 @@ def test_bivariate_ufunc(self): self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0)) self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v)) - @pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', - reason='Need dask 0.16+ for new interface') + @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4', + reason='Need dask 0.16 for new interface') def test_compute(self): u = self.eager_var v = self.lazy_var @@ -218,8 +218,8 @@ def test_compute(self): assert ((u + 1).data == v2.data).all() - @pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', - reason='Need dask 0.16+ for new interface') + @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4', + reason='Need dask 0.16 for new interface') def test_persist(self): u = self.eager_var v = self.lazy_var + 1 @@ -279,8 +279,8 @@ def test_lazy_array(self): actual = xr.concat([v[:2], v[2:]], 'x') self.assertLazyAndAllClose(u, actual) - @pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', - reason='Need dask 0.16+ for new interface') + @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4', + reason='Need dask 0.16 for new interface') def test_compute(self): u = self.eager_array v = self.lazy_array @@ -291,8 +291,8 @@ def test_compute(self): assert ((u + 1).data == v2.data).all() - @pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', - reason='Need dask 0.16+ for new interface') + @pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4', + reason='Need dask 0.16 for new interface') def test_persist(self): u = self.eager_array v = self.lazy_array + 1 @@ -739,9 +739,9 @@ def build_dask_array(name): # the dask.persist function requires a new version of dask @pytest.mark.parametrize('persist', [ lambda x: x.persist(), - pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', + pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4', lambda x: dask.persist(x)[0], - reason='Need Dask 0.16+') + reason='Need Dask 0.16') ]) def test_persist_Dataset(persist): ds = Dataset({'foo': ('x', range(5)), @@ -757,9 +757,9 @@ def test_persist_Dataset(persist): @pytest.mark.parametrize('persist', [ lambda x: x.persist(), - pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16', + pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4', lambda x: dask.persist(x)[0], - reason='Need Dask 0.16+') + reason='Need Dask 0.16') ]) def test_persist_DataArray(persist): x = 
da.arange(10, chunks=(5,)) From d65569b271c36f7fa0a5ed458f9e3094d0b0b5b3 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 3 Nov 2017 17:55:19 -0400 Subject: [PATCH 10/13] support dask arrays in DataArray coordinates --- xarray/core/dataarray.py | 26 +++++++++++++------------- xarray/core/dataset.py | 2 +- xarray/tests/test_dask.py | 21 +++++++++++++++++++++ 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 3dcb66dab52..1dac72335d2 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -577,33 +577,33 @@ def reset_coords(self, names=None, drop=False, inplace=False): return dataset def __dask_graph__(self): - return self._variable.__dask_graph__() + return self._to_temp_dataset().__dask_graph__() def __dask_keys__(self): - return self._variable.__dask_keys__() + return self._to_temp_dataset().__dask_keys__() @property def __dask_optimize__(self): - return self._variable.__dask_optimize__ + return self._to_temp_dataset().__dask_optimize__ @property def __dask_scheduler__(self): - return self._variable.__dask_scheduler__ + return self._to_temp_dataset().__dask_optimize__ def __dask_postcompute__(self): - variable_func, variable_args = self._variable.__dask_postcompute__() - return self._dask_finalize, (variable_func, variable_args, - self._coords, self._name) + func, args = self._to_temp_dataset().__dask_postcompute__() + return self._dask_finalize, (func, args, self.name) def __dask_postpersist__(self): - variable_func, variable_args = self._variable.__dask_postpersist__() - return self._dask_finalize, (variable_func, variable_args, - self._coords, self._name) + func, args = self._to_temp_dataset().__dask_postpersist__() + return self._dask_finalize, (func, args, self.name) @staticmethod - def _dask_finalize(results, variable_func, variable_args, coords, name): - var = variable_func(results, *variable_args) - return DataArray(var, coords=coords, name=name) + def _dask_finalize(results, func, args, name): + ds = func(results, *args) + variable = ds._variables.pop(_THIS_ARRAY) + coords = ds._variables + return DataArray(variable, coords, name=name, fastpath=True) def load(self, **kwargs): """Manually trigger loading of this array's data from disk or a diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6a36e359709..6f9a1416e56 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -538,7 +538,7 @@ def __dask_postpersist__(self): @staticmethod def _dask_postcompute(results, info, *args): variables = OrderedDict() - results2 = results[::-1] + results2 = list(results[::-1]) for is_dask, k, v in info: if is_dask: func, args2 = v diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 2d730689de4..ac5b76357b5 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -771,3 +771,24 @@ def test_persist_DataArray(persist): assert len(z.data.dask) == n assert len(zz.data.dask) == zz.data.npartitions + + +@pytest.mark.skipif(LooseVersion(dask.__version__) <= '0.15.4', + reason='Need dask 0.16 for new interface') +def test_dataarray_with_dask_coords(): + import toolz + x = xr.Variable('x', da.arange(8, chunks=(4,))) + y = xr.Variable('y', da.arange(8, chunks=(4,)) * 2) + data = da.random.random((8, 8), chunks=(4, 4)) + 1 + array = xr.DataArray(data, dims=['x', 'y']) + array.coords['xx'] = x + array.coords['yy'] = y + + assert dict(array.__dask_graph__()) == toolz.merge(data.__dask_graph__(), + x.__dask_graph__(), + y.__dask_graph__()) + + (array2,) 
= dask.compute(array) + assert not dask.is_dask_collection(array2) + + assert all(isinstance(v._variable.data, np.ndarray) for v in array2.coords.values()) From bbeafec2ec61677fb30b5acfebf73a3bfcc101c9 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 3 Nov 2017 23:48:38 -0400 Subject: [PATCH 11/13] remove commented assertion --- xarray/tests/test_distributed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 56400243c73..9999ed9a669 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -53,7 +53,6 @@ def test_async(c, s, a, b): assert dask.is_dask_collection(z) assert dask.is_dask_collection(z.var1) assert dask.is_dask_collection(z.var2) - # assert not dask.is_dask_collection(z.var3) assert len(y.__dask_graph__()) > len(z.__dask_graph__()) assert not futures_of(y) From ff94d95dfc913c4c5496ecb0bcabd16e8d7ceb5c Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 3 Nov 2017 23:48:45 -0400 Subject: [PATCH 12/13] whats new --- doc/whats-new.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7d6b1803e37..de0d3169683 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -27,6 +27,7 @@ backwards incompatible changes. Highlights include: - :py:func:`~xarray.apply_ufunc` facilitates wrapping and parallelizing functions written for NumPy arrays. - Performance improvements, particularly for dask and :py:func:`open_mfdataset`. +- Support Dask collection interface Breaking changes ~~~~~~~~~~~~~~~~ From a6cde57ac61b5d864ab2aed3440953a045a2d157 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 6 Nov 2017 17:45:49 -0800 Subject: [PATCH 13/13] elaborate on what's new --- doc/whats-new.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index de0d3169683..409bf72c2a9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -15,6 +15,9 @@ What's New .. _whats-new.0.9.7: +- Experimental support for the Dask collection interface (:issue:`1674`). + By `Matthew Rocklin `_. + v0.10.0 (unreleased) -------------------- @@ -27,7 +30,6 @@ backwards incompatible changes. Highlights include: - :py:func:`~xarray.apply_ufunc` facilitates wrapping and parallelizing functions written for NumPy arrays. - Performance improvements, particularly for dask and :py:func:`open_mfdataset`. -- Support Dask collection interface Breaking changes ~~~~~~~~~~~~~~~~