diff --git a/distributed/client.py b/distributed/client.py
index d98c33a50f0..dd30c9c4634 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -24,7 +24,7 @@ import weakref
 
 import dask
-from dask.base import tokenize, normalize_token, Base, collections_to_dsk
+from dask.base import tokenize, normalize_token, collections_to_dsk
 from dask.core import flatten, get_dependencies
 from dask.compatibility import apply, unicode
 from dask.context import _globals
 
@@ -891,6 +891,12 @@ def _handle_error(self, exception=None):
     def _close(self, fast=False):
         """ Send close signal and wait until scheduler completes """
         with log_errors():
+            with ignoring(AttributeError):
+                dask.set_options(get=self._previous_get)
+            with ignoring(AttributeError):
+                dask.set_options(shuffle=self._previous_shuffle)
+            if self.get == _globals.get('get'):
+                del _globals['get']
             if self.status == 'closed':
                 raise gen.Return()
             if self.scheduler_comm and self.scheduler_comm.comm and not self.scheduler_comm.comm.closed():
@@ -1963,13 +1969,13 @@ def normalize_collection(self, collection):
 
         Examples
        --------
-        >>> len(x.dask)  # x is a dask collection with 100 tasks  # doctest: +SKIP
+        >>> len(x.__dask_graph__())  # x is a dask collection with 100 tasks  # doctest: +SKIP
        100
-        >>> set(client.futures).intersection(x.dask)  # some overlap exists  # doctest: +SKIP
+        >>> set(client.futures).intersection(x.__dask_graph__())  # some overlap exists  # doctest: +SKIP
        10
 
         >>> x = client.normalize_collection(x)  # doctest: +SKIP
-        >>> len(x.dask)  # smaller computational graph  # doctest: +SKIP
+        >>> len(x.__dask_graph__())  # smaller computational graph  # doctest: +SKIP
        20
 
         See Also
@@ -1977,9 +1983,11 @@ def normalize_collection(self, collection):
         Client.persist: trigger computation of collection's tasks
         """
         with self._lock:
-            dsk = self._optimize_insert_futures(collection.dask, collection._keys())
+            dsk = self._optimize_insert_futures(
+                collection.__dask_graph__(),
+                collection.__dask_keys__())
 
-            if dsk is collection.dask:
+            if dsk is collection.__dask_graph__():
                 return collection
             else:
                 return redict_collection(collection, dsk)
@@ -2047,11 +2055,14 @@ def compute(self, collections, sync=False, optimize_graph=True,
                            if isinstance(a, (list, set, tuple, dict, Iterator))
                            else a for a in collections)
 
-        variables = [a for a in collections if isinstance(a, Base)]
+        variables = [a for a in collections if dask.is_dask_collection(a)]
 
         dsk = self.collections_to_dsk(variables, optimize_graph, **kwargs)
         names = ['finalize-%s' % tokenize(v) for v in variables]
-        dsk2 = {name: (v._finalize, v._keys()) for name, v in zip(names, variables)}
+        dsk2 = {}
+        for name, v in zip(names, variables):
+            func, extra_args = v.__dask_postcompute__()
+            dsk2[name] = (func, v.__dask_keys__()) + extra_args
 
         restrictions, loose_restrictions = self.get_restrictions(collections,
                                                                  workers, allow_other_workers)
@@ -2066,7 +2077,7 @@ def compute(self, collections, sync=False, optimize_graph=True,
         i = 0
         futures = []
         for arg in collections:
-            if isinstance(arg, Base):
+            if dask.is_dask_collection(arg):
                 futures.append(futures_dict[names[i]])
                 i += 1
             else:
@@ -2127,11 +2138,11 @@ def persist(self, collections, optimize_graph=True, workers=None,
             singleton = True
             collections = [collections]
 
-        assert all(isinstance(c, Base) for c in collections)
+        assert all(map(dask.is_dask_collection, collections))
 
         dsk = self.collections_to_dsk(collections, optimize_graph, **kwargs)
 
-        names = {k for c in collections for k in flatten(c._keys())}
+        names = {k for c in collections for k in flatten(c.__dask_keys__())}
 
         restrictions, loose_restrictions = self.get_restrictions(collections,
                                                                  workers, allow_other_workers)
@@ -2142,9 +2153,10 @@ def persist(self, collections, optimize_graph=True, workers=None,
         futures = self._graph_to_futures(dsk, names, restrictions,
                                          loose_restrictions, resources=resources)
 
-        result = [redict_collection(c, {k: futures[k]
-                                        for k in flatten(c._keys())})
-                  for c in collections]
+        postpersists = [c.__dask_postpersist__() for c in collections]
+        result = [func({k: futures[k] for k in flatten(c.__dask_keys__())}, *args)
+                  for (func, args), c in zip(postpersists, collections)]
+
         if singleton:
             return first(result)
         else:
@@ -2913,8 +2925,8 @@ def expand_resources(resources):
         if not isinstance(k, tuple):
             k = (k,)
         for kk in k:
-            if hasattr(kk, '_keys'):
-                for kkk in kk._keys():
+            if dask.is_dask_collection(kk):
+                for kkk in kk.__dask_keys__():
                     out[tokey(kkk)] = v
             else:
                 out[tokey(kk)] = v
@@ -2930,11 +2942,11 @@ def get_restrictions(collections, workers, allow_other_workers):
         for colls, ws in workers.items():
             if isinstance(ws, str):
                 ws = [ws]
-            if hasattr(colls, '._keys'):
-                keys = flatten(colls._keys())
+            if dask.is_dask_collection(colls):
+                keys = flatten(colls.__dask_keys__())
             else:
                 keys = list({k for c in flatten(colls)
-                             for k in flatten(c._keys())})
+                             for k in flatten(c.__dask_keys__())})
             restrictions.update({k: ws for k in keys})
     else:
         restrictions = {}
@@ -2943,7 +2955,7 @@ def get_restrictions(collections, workers, allow_other_workers):
         loose_restrictions = list(restrictions)
     elif allow_other_workers:
         loose_restrictions = list({k for c in flatten(allow_other_workers)
-                                   for k in c._keys()})
+                                   for k in c.__dask_keys__()})
     else:
         loose_restrictions = []
 
@@ -3275,8 +3287,8 @@ def futures_of(o, client=None):
             stack.extend(x.values())
         if isinstance(x, Future):
             futures.add(x)
-        if hasattr(x, 'dask') and isinstance(x.dask, Mapping):
-            stack.extend(x.dask.values())
+        if dask.is_dask_collection(x):
+            stack.extend(x.__dask_graph__().values())
 
     if client is not None:
         bad = {f for f in futures if f.cancelled()}
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 02f69d5efe5..f11283742e3 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -2254,8 +2254,8 @@ def test_async_persist(c, s, a, b):
     assert len(yy.dask) == 1
     assert len(ww.dask) == 1
     assert len(w.dask) > 1
-    assert y._keys() == yy._keys()
-    assert w._keys() == ww._keys()
+    assert y.__dask_keys__() == yy.__dask_keys__()
+    assert w.__dask_keys__() == ww.__dask_keys__()
 
     while y.key not in s.tasks and w.key not in s.tasks:
         yield gen.sleep(0.01)
@@ -2285,7 +2285,7 @@ def test__persist(c, s, a, b):
     assert len(y.dask) == 6
     assert len(yy.dask) == 2
     assert all(isinstance(v, Future) for v in yy.dask.values())
-    assert yy._keys() == y._keys()
+    assert yy.__dask_keys__() == y.__dask_keys__()
 
     g, h = c.compute([y, yy])
 
@@ -2305,7 +2305,7 @@ def test_persist(loop):
             assert len(y.dask) == 6
             assert len(yy.dask) == 2
             assert all(isinstance(v, Future) for v in yy.dask.values())
-            assert yy._keys() == y._keys()
+            assert yy.__dask_keys__() == y.__dask_keys__()
 
             zz = yy.compute(get=c.get)
             z = y.compute(get=c.get)
@@ -2634,7 +2634,7 @@ def test_persist_get(c, s, a, b):
     xxyy3 = delayed(add)(xxyy2, 10)
 
     yield gen.sleep(0.5)
-    result = yield c.get(xxyy3.dask, xxyy3._keys(), sync=False)
+    result = yield c.get(xxyy3.dask, xxyy3.__dask_keys__(), sync=False)
     assert result[0] == ((1 + 1) + (2 + 2)) + 10
 
     result = yield c.compute(xxyy3)
@@ -3331,7 +3331,7 @@ def test_persist_optimize_graph(c, s, a, b):
         b4 = method(b3, optimize_graph=False)
         yield wait(b4)
 
-        assert set(map(tokey, b3._keys())).issubset(s.tasks)
+        assert set(map(tokey, b3.__dask_keys__())).issubset(s.tasks)
 
         b = db.range(i, npartitions=2)
         i += 1
@@ -3341,7 +3341,7 @@ def test_persist_optimize_graph(c, s, a, b):
         b4 = method(b3, optimize_graph=True)
         yield wait(b4)
 
-        assert not any(tokey(k) in s.tasks for k in b2._keys())
+        assert not any(tokey(k) in s.tasks for k in b2.__dask_keys__())
 
 
 @gen_cluster(client=True, ncores=[])
diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py
index 2fae5f942ba..1fd265001cd 100644
--- a/distributed/tests/test_resources.py
+++ b/distributed/tests/test_resources.py
@@ -230,11 +230,11 @@ def test_persist_collections(c, s, a, b):
     z = y.map_blocks(lambda x: 2 * x)
     w = z.sum()
 
-    ww, yy = c.persist([w, y], resources={tuple(y._keys()): {'A': 1}})
+    ww, yy = c.persist([w, y], resources={tuple(y.__dask_keys__()): {'A': 1}})
 
     yield wait([ww, yy])
 
-    assert all(tokey(key) in a.data for key in y._keys())
+    assert all(tokey(key) in a.data for key in y.__dask_keys__())
 
 
 @pytest.mark.xfail(reason="Should protect resource keys from optimization")
@@ -247,7 +247,7 @@ def test_dont_optimize_out(c, s, a, b):
     z = y.map_blocks(lambda x: 2 * x)
     w = z.sum()
 
-    yield c.compute(w, resources={tuple(y._keys()): {'A': 1}},)
+    yield c.compute(w, resources={tuple(y.__dask_keys__()): {'A': 1}},)
 
-    for key in map(tokey, y._keys()):
+    for key in map(tokey, y.__dask_keys__()):
         assert 'executing' in str(a.story(key))
diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py
index 53cd184e260..715f6fae5e8 100755
--- a/distributed/tests/test_utils_test.py
+++ b/distributed/tests/test_utils_test.py
@@ -12,7 +12,7 @@
 from distributed import Scheduler, Worker, Client, config
 from distributed.core import rpc
 from distributed.metrics import time
-from distributed.utils_test import (cluster, gen_cluster,
+from distributed.utils_test import (cluster, gen_cluster, inc,
                                     gen_test, wait_for_port, new_config,
                                     tls_only_security)
 from distributed.utils_test import loop # flake8: noqa
@@ -41,6 +41,21 @@ def test_gen_cluster(c, s, a, b):
     assert s.ncores == {w.address: w.ncores for w in [a, b]}
 
 
+@pytest.mark.skip(reason="This hangs on travis")
+def test_gen_cluster_cleans_up_client(loop):
+    import dask.context
+    assert not dask.context._globals.get('get')
+
+    @gen_cluster(client=True)
+    def f(c, s, a, b):
+        assert dask.context._globals.get('get')
+        yield c.submit(inc, 1)
+
+    f()
+
+    assert not dask.context._globals.get('get')
+
+
 @gen_cluster(client=False)
 def test_gen_cluster_without_client(s, a, b):
     assert isinstance(s, Scheduler)
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index bb7d125c836..4359cb21691 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -25,6 +25,7 @@
 
 import six
 
+from dask.context import _globals
 from toolz import merge, memoize
 from tornado import gen, queues
 from tornado.gen import TimeoutError
@@ -354,6 +355,8 @@ def wait_a_bit():
 def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1,
             scheduler_kwargs={}):
     ws = weakref.WeakSet()
+    old_globals = _globals.copy()
+
     for name, level in logging_levels.items():
         logging.getLogger(name).setLevel(level)
 
@@ -438,6 +441,9 @@ def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1,
 
             for fn in glob('_test_worker-*'):
                 shutil.rmtree(fn)
+
+            _globals.clear()
+            _globals.update(old_globals)
 
     assert not ws
 
@@ -571,6 +577,7 @@ def _(func):
             func = gen.coroutine(func)
 
         def test_func():
+            old_globals = _globals.copy()
             result = None
             workers = []
 
@@ -596,6 +603,8 @@ def coro():
                 if client:
                     yield c._close()
                 yield end_cluster(s, workers)
+                _globals.clear()
+                _globals.update(old_globals)
 
                 raise gen.Return(result)
 
diff --git a/docs/source/locality.rst b/docs/source/locality.rst
index a95657c7c84..4ee6568395d 100644
--- a/docs/source/locality.rst
+++ b/docs/source/locality.rst
@@ -189,7 +189,7 @@ run elsewhere if necessary:
                        allow_other_workers=[x])
 
 This works fine with ``persist`` and with any dask collection (any object with
-a ``._keys()`` method):
+a ``.__dask_graph__()`` method):
 
 .. code-block:: python
 
diff --git a/docs/source/resources.rst b/docs/source/resources.rst
index b239b50abf8..5d501ce4bd2 100644
--- a/docs/source/resources.rst
+++ b/docs/source/resources.rst
@@ -73,7 +73,7 @@ resource requirements during compute or persist calls.
    y = x.map_partitions(func1)
    z = y.map_parititons(func2)
 
-   z.compute(resources={tuple(y._keys()): {'GPU': 1})
+   z.compute(resources={tuple(y.__dask_keys__()): {'GPU': 1}})
 
 In some cases (such as the case above) the keys for ``y`` may be optimized away
 before execution. You can avoid that either by requiring them as an explicit
@@ -82,4 +82,4 @@ output, or by passing the ``optimize_graph=False`` keyword.
 
 .. code-block:: python
 
-   z.compute(resources={tuple(y._keys()): {'GPU': 1}, optimize_graph=False)
+   z.compute(resources={tuple(y.__dask_keys__()): {'GPU': 1}}, optimize_graph=False)
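For reference, below is a minimal, illustrative sketch of the collection protocol this patch switches to (__dask_graph__, __dask_keys__, __dask_postcompute__, __dask_postpersist__, dask.is_dask_collection). The Single class and its key names are hypothetical and exist only to show how Client.compute and Client.persist consume these hooks; it is not part of the patch.

# Minimal sketch, for illustration only -- not part of the patch.
# "Single" and its key are hypothetical; the dunder methods are the
# dask collection hooks that client.py now calls.
import dask


class Single(object):
    """A hypothetical collection with a single output key."""

    def __init__(self, dsk, key):
        self._dsk = dsk   # mapping of key -> task
        self._key = key   # the one output key

    def __dask_graph__(self):
        return self._dsk

    def __dask_keys__(self):
        return [self._key]

    def __dask_postcompute__(self):
        # (finalize, extra_args): finalize(results, *extra_args) turns the
        # computed keys into the final value
        return (lambda results: results[0]), ()

    def __dask_postpersist__(self):
        # (rebuild, extra_args): rebuild(dsk2, *extra_args) wraps a graph in
        # which this collection's keys now point at Futures
        return Single, (self._key,)


x = Single({'x-1': (sum, [1, 2, 3])}, 'x-1')
assert dask.is_dask_collection(x)

# Roughly what Client.compute now does with these hooks, run locally with
# the synchronous scheduler:
finalize, extra_args = x.__dask_postcompute__()
results = dask.get(x.__dask_graph__(), x.__dask_keys__())
assert finalize(results, *extra_args) == 6

The persist path is the mirror image: Client.persist replaces the values of __dask_graph__() with Futures and hands the resulting graph to the rebuild function from __dask_postpersist__() (here, Single(futures_graph, 'x-1')).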