Use new dask interface #1513

Merged: 6 commits, Nov 6, 2017
56 changes: 34 additions & 22 deletions distributed/client.py
@@ -24,7 +24,7 @@
import weakref

import dask
from dask.base import tokenize, normalize_token, Base, collections_to_dsk
from dask.base import tokenize, normalize_token, collections_to_dsk
from dask.core import flatten, get_dependencies
from dask.compatibility import apply, unicode
from dask.context import _globals
@@ -891,6 +891,12 @@ def _handle_error(self, exception=None):
def _close(self, fast=False):
""" Send close signal and wait until scheduler completes """
with log_errors():
with ignoring(AttributeError):
dask.set_options(get=self._previous_get)
with ignoring(AttributeError):
dask.set_options(shuffle=self._previous_shuffle)
if self.get == _globals.get('get'):
del _globals['get']
if self.status == 'closed':
raise gen.Return()
if self.scheduler_comm and self.scheduler_comm.comm and not self.scheduler_comm.comm.closed():
@@ -1963,23 +1969,25 @@ def normalize_collection(self, collection):

Examples
--------
>>> len(x.dask) # x is a dask collection with 100 tasks # doctest: +SKIP
>>> len(x.__dask_graph__()) # x is a dask collection with 100 tasks # doctest: +SKIP
100
>>> set(client.futures).intersection(x.dask) # some overlap exists # doctest: +SKIP
>>> set(client.futures).intersection(x.__dask_graph__()) # some overlap exists # doctest: +SKIP
10

>>> x = client.normalize_collection(x) # doctest: +SKIP
>>> len(x.dask) # smaller computational graph # doctest: +SKIP
>>> len(x.__dask_graph__()) # smaller computational graph # doctest: +SKIP
20

See Also
--------
Client.persist: trigger computation of collection's tasks
"""
with self._lock:
dsk = self._optimize_insert_futures(collection.dask, collection._keys())
dsk = self._optimize_insert_futures(
collection.__dask_graph__(),
collection.__dask_keys__())

if dsk is collection.dask:
if dsk is collection.__dask_graph__():
return collection
else:
return redict_collection(collection, dsk)
@@ -2047,11 +2055,14 @@ def compute(self, collections, sync=False, optimize_graph=True,
if isinstance(a, (list, set, tuple, dict, Iterator))
else a for a in collections)

variables = [a for a in collections if isinstance(a, Base)]
variables = [a for a in collections if dask.is_dask_collection(a)]

dsk = self.collections_to_dsk(variables, optimize_graph, **kwargs)
names = ['finalize-%s' % tokenize(v) for v in variables]
dsk2 = {name: (v._finalize, v._keys()) for name, v in zip(names, variables)}
dsk2 = {}
for name, v in zip(names, variables):
func, extra_args = v.__dask_postcompute__()
dsk2[name] = (func, v.__dask_keys__()) + extra_args

restrictions, loose_restrictions = self.get_restrictions(collections,
workers, allow_other_workers)
@@ -2066,7 +2077,7 @@ def compute(self, collections, sync=False, optimize_graph=True,
i = 0
futures = []
for arg in collections:
if isinstance(arg, Base):
if dask.is_dask_collection(arg):
futures.append(futures_dict[names[i]])
i += 1
else:
@@ -2127,11 +2138,11 @@ def persist(self, collections, optimize_graph=True, workers=None,
singleton = True
collections = [collections]

assert all(isinstance(c, Base) for c in collections)
assert all(map(dask.is_dask_collection, collections))

dsk = self.collections_to_dsk(collections, optimize_graph, **kwargs)

names = {k for c in collections for k in flatten(c._keys())}
names = {k for c in collections for k in flatten(c.__dask_keys__())}

restrictions, loose_restrictions = self.get_restrictions(collections,
workers, allow_other_workers)
@@ -2142,9 +2153,10 @@ def persist(self, collections, optimize_graph=True, workers=None,
futures = self._graph_to_futures(dsk, names, restrictions,
loose_restrictions, resources=resources)

result = [redict_collection(c, {k: futures[k]
for k in flatten(c._keys())})
for c in collections]
postpersists = [c.__dask_postpersist__() for c in collections]
result = [func({k: futures[k] for k in flatten(c.__dask_keys__())}, *args)
for (func, args), c in zip(postpersists, collections)]

if singleton:
return first(result)
else:
@@ -2913,8 +2925,8 @@ def expand_resources(resources):
if not isinstance(k, tuple):
k = (k,)
for kk in k:
if hasattr(kk, '_keys'):
for kkk in kk._keys():
if dask.is_dask_collection(kk):
for kkk in kk.__dask_keys__():
out[tokey(kkk)] = v
else:
out[tokey(kk)] = v
@@ -2930,11 +2942,11 @@ def get_restrictions(collections, workers, allow_other_workers):
for colls, ws in workers.items():
if isinstance(ws, str):
ws = [ws]
if hasattr(colls, '._keys'):
keys = flatten(colls._keys())
if dask.is_dask_collection(colls):
keys = flatten(colls.__dask_keys__())
else:
keys = list({k for c in flatten(colls)
for k in flatten(c._keys())})
for k in flatten(c.__dask_keys__())})
restrictions.update({k: ws for k in keys})
else:
restrictions = {}
@@ -2943,7 +2955,7 @@ def get_restrictions(collections, workers, allow_other_workers):
loose_restrictions = list(restrictions)
elif allow_other_workers:
loose_restrictions = list({k for c in flatten(allow_other_workers)
for k in c._keys()})
for k in c.__dask_keys__()})
else:
loose_restrictions = []

@@ -3275,8 +3287,8 @@ def futures_of(o, client=None):
stack.extend(x.values())
if isinstance(x, Future):
futures.add(x)
if hasattr(x, 'dask') and isinstance(x.dask, Mapping):
stack.extend(x.dask.values())
if dask.is_dask_collection(x):
stack.extend(x.__dask_graph__().values())

if client is not None:
bad = {f for f in futures if f.cancelled()}
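For context, here is a minimal sketch of an object satisfying the collection protocol that these client changes consume. The class and helper names below are illustrative assumptions, not part of dask or this pull request; real collections also define `__dask_optimize__` and `__dask_scheduler__`.

```python
# Illustrative sketch only: a tiny object implementing the dask collection
# protocol used above (__dask_graph__, __dask_keys__, __dask_postcompute__,
# __dask_postpersist__).  Names are made up for the example.
import dask


class TinyCollection(object):
    """Wrap a single task graph with one output key."""

    def __init__(self, dsk, key):
        self._dsk = dsk
        self._key = key

    def __dask_graph__(self):
        # Mapping of key -> task; this is what collections_to_dsk() consumes
        return self._dsk

    def __dask_keys__(self):
        # Keys the client schedules and later replaces with Futures
        return [self._key]

    def __dask_postcompute__(self):
        # (finalize_function, extra_args); compute() builds the task
        # (finalize_function, __dask_keys__()) + extra_args
        return TinyCollection._first, ()

    def __dask_postpersist__(self):
        # (rebuild_function, extra_args); persist() calls
        # rebuild_function({key: future, ...}, *extra_args)
        return TinyCollection, (self._key,)

    @staticmethod
    def _first(results):
        return results[0]


coll = TinyCollection({'x': (sum, [1, 2, 3])}, 'x')
assert dask.is_dask_collection(coll)
```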
14 changes: 7 additions & 7 deletions distributed/tests/test_client.py
@@ -2254,8 +2254,8 @@ def test_async_persist(c, s, a, b):
assert len(yy.dask) == 1
assert len(ww.dask) == 1
assert len(w.dask) > 1
assert y._keys() == yy._keys()
assert w._keys() == ww._keys()
assert y.__dask_keys__() == yy.__dask_keys__()
assert w.__dask_keys__() == ww.__dask_keys__()

while y.key not in s.tasks and w.key not in s.tasks:
yield gen.sleep(0.01)
@@ -2285,7 +2285,7 @@ def test__persist(c, s, a, b):
assert len(y.dask) == 6
assert len(yy.dask) == 2
assert all(isinstance(v, Future) for v in yy.dask.values())
assert yy._keys() == y._keys()
assert yy.__dask_keys__() == y.__dask_keys__()

g, h = c.compute([y, yy])

@@ -2305,7 +2305,7 @@ def test_persist(loop):
assert len(y.dask) == 6
assert len(yy.dask) == 2
assert all(isinstance(v, Future) for v in yy.dask.values())
assert yy._keys() == y._keys()
assert yy.__dask_keys__() == y.__dask_keys__()

zz = yy.compute(get=c.get)
z = y.compute(get=c.get)
@@ -2634,7 +2634,7 @@ def test_persist_get(c, s, a, b):
xxyy3 = delayed(add)(xxyy2, 10)

yield gen.sleep(0.5)
result = yield c.get(xxyy3.dask, xxyy3._keys(), sync=False)
result = yield c.get(xxyy3.dask, xxyy3.__dask_keys__(), sync=False)
assert result[0] == ((1 + 1) + (2 + 2)) + 10

result = yield c.compute(xxyy3)
@@ -3331,7 +3331,7 @@ def test_persist_optimize_graph(c, s, a, b):
b4 = method(b3, optimize_graph=False)
yield wait(b4)

assert set(map(tokey, b3._keys())).issubset(s.tasks)
assert set(map(tokey, b3.__dask_keys__())).issubset(s.tasks)

b = db.range(i, npartitions=2)
i += 1
@@ -3341,7 +3341,7 @@ def test_persist_optimize_graph(c, s, a, b):
b4 = method(b3, optimize_graph=True)
yield wait(b4)

assert not any(tokey(k) in s.tasks for k in b2._keys())
assert not any(tokey(k) in s.tasks for k in b2.__dask_keys__())


@gen_cluster(client=True, ncores=[])
8 changes: 4 additions & 4 deletions distributed/tests/test_resources.py
@@ -230,11 +230,11 @@ def test_persist_collections(c, s, a, b):
z = y.map_blocks(lambda x: 2 * x)
w = z.sum()

ww, yy = c.persist([w, y], resources={tuple(y._keys()): {'A': 1}})
ww, yy = c.persist([w, y], resources={tuple(y.__dask_keys__()): {'A': 1}})

yield wait([ww, yy])

assert all(tokey(key) in a.data for key in y._keys())
assert all(tokey(key) in a.data for key in y.__dask_keys__())


@pytest.mark.xfail(reason="Should protect resource keys from optimization")
@@ -247,7 +247,7 @@ def test_dont_optimize_out(c, s, a, b):
z = y.map_blocks(lambda x: 2 * x)
w = z.sum()

yield c.compute(w, resources={tuple(y._keys()): {'A': 1}},)
yield c.compute(w, resources={tuple(y.__dask_keys__()): {'A': 1}},)

for key in map(tokey, y._keys()):
for key in map(tokey, y.__dask_keys__()):
assert 'executing' in str(a.story(key))
17 changes: 16 additions & 1 deletion distributed/tests/test_utils_test.py
@@ -12,7 +12,7 @@
from distributed import Scheduler, Worker, Client, config
from distributed.core import rpc
from distributed.metrics import time
from distributed.utils_test import (cluster, gen_cluster,
from distributed.utils_test import (cluster, gen_cluster, inc,
gen_test, wait_for_port, new_config,
tls_only_security)
from distributed.utils_test import loop # flake8: noqa
@@ -41,6 +41,21 @@ def test_gen_cluster(c, s, a, b):
assert s.ncores == {w.address: w.ncores for w in [a, b]}


@pytest.mark.skip(reason="This hangs on travis")
def test_gen_cluster_cleans_up_client(loop):
import dask.context
assert not dask.context._globals.get('get')

@gen_cluster(client=True)
def f(c, s, a, b):
assert dask.context._globals.get('get')
yield c.submit(inc, 1)

f()

assert not dask.context._globals.get('get')


@gen_cluster(client=False)
def test_gen_cluster_without_client(s, a, b):
assert isinstance(s, Scheduler)
9 changes: 9 additions & 0 deletions distributed/utils_test.py
@@ -25,6 +25,7 @@

import six

from dask.context import _globals
from toolz import merge, memoize
from tornado import gen, queues
from tornado.gen import TimeoutError
@@ -354,6 +355,8 @@ def wait_a_bit():
def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1,
scheduler_kwargs={}):
ws = weakref.WeakSet()
old_globals = _globals.copy()

for name, level in logging_levels.items():
logging.getLogger(name).setLevel(level)

@@ -438,6 +441,9 @@ def cluster(nworkers=2, nanny=False, worker_kwargs={}, active_rpc_timeout=1,

for fn in glob('_test_worker-*'):
shutil.rmtree(fn)

_globals.clear()
_globals.update(old_globals)
assert not ws


@@ -571,6 +577,7 @@ def _(func):
func = gen.coroutine(func)

def test_func():
old_globals = _globals.copy()
result = None
workers = []

@@ -596,6 +603,8 @@ def coro():
if client:
yield c._close()
yield end_cluster(s, workers)
_globals.clear()
_globals.update(old_globals)

raise gen.Return(result)

2 changes: 1 addition & 1 deletion docs/source/locality.rst
@@ -189,7 +189,7 @@ run elsewhere if necessary:
allow_other_workers=[x])

This works fine with ``persist`` and with any dask collection (any object with
a ``._keys()`` method):
a ``.__dask_graph__()`` method):

.. code-block:: python

4 changes: 2 additions & 2 deletions docs/source/resources.rst
@@ -73,7 +73,7 @@ resource requirements during compute or persist calls.
y = x.map_partitions(func1)
z = y.map_partitions(func2)

z.compute(resources={tuple(y._keys()): {'GPU': 1}})
z.compute(resources={tuple(y.__dask_keys__()): {'GPU': 1}})

In some cases (such as the case above) the keys for ``y`` may be optimized away
before execution. You can avoid that either by requiring them as an explicit
@@ -82,4 +82,4 @@ output, or by passing the ``optimize_graph=False`` keyword.

.. code-block:: python

z.compute(resources={tuple(y._keys()): {'GPU': 1}}, optimize_graph=False)
z.compute(resources={tuple(y.__dask_keys__()): {'GPU': 1}}, optimize_graph=False)
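A sketch of the other workaround mentioned in that paragraph, requesting ``y`` as an explicit output so its keys survive optimization. This assumes a distributed ``Client`` named ``c`` plus the ``y`` and ``z`` collections from the snippet above; it is illustrative, not part of this change.

```python
# Illustrative only: request the intermediate collection as an explicit output
# so its keys are not fused away, then attach the resource restriction to them.
z_future, y_future = c.compute([z, y],
                               resources={tuple(y.__dask_keys__()): {'GPU': 1}})
```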