Skip to content

Commit da8a8dc

Browse files
committed
support backwards compatibility
1 parent ffb0ca1 commit da8a8dc

File tree

5 files changed

+25
-15
lines changed

5 files changed

+25
-15
lines changed

xarray/core/dataarray.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -646,9 +646,8 @@ def compute(self, **kwargs):
646646
--------
647647
dask.array.compute
648648
"""
649-
import dask
650-
(result,) = dask.compute(self, **kwargs)
651-
return result
649+
new = self.copy(deep=False)
650+
return new.load(**kwargs)
652651

653652
def persist(self, **kwargs):
654653
""" Trigger computation in constituent dask arrays
@@ -666,9 +665,8 @@ def persist(self, **kwargs):
666665
--------
667666
dask.persist
668667
"""
669-
import dask
670-
(result,) = dask.persist(self, **kwargs)
671-
return result
668+
ds = self._to_temp_dataset().persist(**kwargs)
669+
return self._from_temp_dataset(ds)
672670

673671
def copy(self, deep=True):
674672
"""Returns a copy of this array.

xarray/core/dataset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -618,9 +618,8 @@ def persist(self, **kwargs):
618618
--------
619619
dask.persist
620620
"""
621-
import dask
622-
(result,) = dask.persist(self, **kwargs)
623-
return result
621+
new = self.copy(deep=False)
622+
return new._persist_inplace(**kwargs)
624623

625624
@classmethod
626625
def _construct_direct(cls, variables, coord_names, dims=None, attrs=None,

xarray/core/variable.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,9 +360,8 @@ def compute(self, **kwargs):
360360
--------
361361
dask.array.compute
362362
"""
363-
import dask
364-
(result,) = dask.compute(self, **kwargs)
365-
return result
363+
new = self.copy(deep=False)
364+
return new.load(**kwargs)
366365

367366
def __dask_graph__(self):
368367
if isinstance(self._data, dask_array_type):

xarray/tests/test_dask.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,8 @@ def test_bivariate_ufunc(self):
206206
self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0))
207207
self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v))
208208

209+
@pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16',
210+
reason='Need dask 0.16+ for new interface')
209211
def test_compute(self):
210212
u = self.eager_var
211213
v = self.lazy_var
@@ -216,6 +218,8 @@ def test_compute(self):
216218

217219
assert ((u + 1).data == v2.data).all()
218220

221+
@pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16',
222+
reason='Need dask 0.16+ for new interface')
219223
def test_persist(self):
220224
u = self.eager_var
221225
v = self.lazy_var + 1
@@ -275,6 +279,8 @@ def test_lazy_array(self):
275279
actual = xr.concat([v[:2], v[2:]], 'x')
276280
self.assertLazyAndAllClose(u, actual)
277281

282+
@pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16',
283+
reason='Need dask 0.16+ for new interface')
278284
def test_compute(self):
279285
u = self.eager_array
280286
v = self.lazy_array
@@ -285,6 +291,8 @@ def test_compute(self):
285291

286292
assert ((u + 1).data == v2.data).all()
287293

294+
@pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16',
295+
reason='Need dask 0.16+ for new interface')
288296
def test_persist(self):
289297
u = self.eager_array
290298
v = self.lazy_array + 1
@@ -730,6 +738,8 @@ def build_dask_array(name):
730738
chunks=((1,),), dtype=np.int64)
731739

732740

741+
@pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16',
742+
reason='Need dask 0.16+ for new interface')
733743
@pytest.mark.parametrize('persist', [lambda x: x.persist(),
734744
lambda x: dask.persist(x)[0]])
735745
def test_persist_Dataset(persist):
@@ -743,8 +753,12 @@ def test_persist_Dataset(persist):
743753
assert len(ds2.foo.data.dask) == 1
744754
assert len(ds.foo.data.dask) == n # doesn't mutate in place
745755

746-
@pytest.mark.parametrize('persist', [lambda x: x.persist(),
747-
lambda x: dask.persist(x)[0]])
756+
@pytest.mark.parametrize('persist', [
757+
lambda x: x.persist(),
758+
pytest.mark.skipif(LooseVersion(dask.__version__) < '0.16',
759+
lambda x: dask.persist(x)[0],
760+
reason='Need Dask 0.16+')
761+
])
748762
def test_persist_DataArray(persist):
749763
x = da.arange(10, chunks=(5,))
750764
y = DataArray(x)

xarray/tests/test_distributed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def test_dask_distributed_integration_test(loop, engine):
3636
assert_allclose(original, computed)
3737

3838

39+
@pytest.mark.skipif(distributed.__version__ <= '1.19.3')
3940
@gen_cluster(client=True, timeout=None)
4041
def test_async(c, s, a, b):
4142
x = create_test_data()
@@ -63,5 +64,4 @@ def test_async(c, s, a, b):
6364
assert not dask.is_dask_collection(w)
6465
assert_allclose(x + 10, w)
6566

66-
6767
assert s.task_state

0 commit comments

Comments
 (0)