diff --git a/.travis.yml b/.travis.yml
index d5b5e98b..2e8a3637 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -13,6 +13,6 @@ install:
 - pip install --upgrade nose coveralls coverage
 - if [[ $TRAVIS_PYTHON_VERSION == 'pypy'* ]]; then export TRAVIS_WAIT=45; else export TRAVIS_WAIT=20; fi
 script:
-- travis_wait "${TRAVIS_WAIT}" nosetests --with-coverage --cover-package=neat -vd
+- nosetests --with-coverage --cover-package=neat -vd
 after_success: coveralls
diff --git a/neat/distributed.py b/neat/distributed.py
index 36e1847a..6b083c17 100644
--- a/neat/distributed.py
+++ b/neat/distributed.py
@@ -56,23 +56,22 @@
 from __future__ import print_function

 import socket
+import select
+import struct
 import sys
 import time
 import warnings
+import multiprocessing
+import pickle
+import json
+import threading

-# below still needed for queue.Empty
 try:
-    # pylint: disable=import-error
     import Queue as queue
 except ImportError:
-    # pylint: disable=import-error
     import queue

-import multiprocessing
-from multiprocessing import managers
-from argparse import Namespace
-
-# Some of this code is based on
+# Some of the original code is based on
 # http://eli.thegreenplace.net/2012/01/24/distributed-computing-in-python-with-multiprocessing
 # According to the website, the code is in the public domain
 # ('public domain' links to unlicense.org).
@@ -91,6 +90,14 @@
 _STATE_RUNNING = 0
 _STATE_SHUTDOWN = 1
 _STATE_FORCED_SHUTDOWN = 2
+_STATE_ERROR = 3
+
+
+# constants for network communication
+_LENGTH_PREFIX = "!Q"
+_LENGTH_PREFIX_LENGTH = struct.calcsize(_LENGTH_PREFIX)
+_DEFAULT_NETWORK_ENCODING = "utf-8"  # encoding for json messages
+_DEFAULT_PICKLE_ENCODING = "latin1"  # encoding for pickle


 class ModeError(RuntimeError):
@@ -102,25 +109,51 @@ class ModeError(RuntimeError):
     pass


+class ProtocolError(IOError):
+    """
+    An exception raised when either the client or the server does not
+    send a valid response.
+    """
+    pass
+
+
+class AuthError(Exception):
+    """Raised if the authentication failed."""
+    pass
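
The "!Q" length prefix defined above frames every message on the wire as an 8-byte big-endian length followed by the payload. A minimal sketch of that framing, using only the two constants from the patch (the frame/unframe helper names are illustrative, not part of the patch):

    import struct

    _LENGTH_PREFIX = "!Q"  # big-endian unsigned 64-bit length
    _LENGTH_PREFIX_LENGTH = struct.calcsize(_LENGTH_PREFIX)  # 8 bytes

    def frame(payload):
        # prepend the payload length so the receiver knows where the message ends
        return struct.pack(_LENGTH_PREFIX, len(payload)) + payload

    def unframe(buff):
        # split one complete framed message off the front of a receive buffer
        (length,) = struct.unpack(_LENGTH_PREFIX, buff[:_LENGTH_PREFIX_LENGTH])
        start = _LENGTH_PREFIX_LENGTH
        return buff[start:start + length], buff[start + length:]

    msg, rest = unframe(frame(b"hello") + b"more")
    assert msg == b"hello" and rest == b"more"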
""" - hostname = socket.getfqdn(hostname) - if hostname in ("localhost", "0.0.0.0", "127.0.0.1", "1.0.0.127.in-addr.arpa", - "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa"): - return True - localhost = socket.gethostname() - if hostname == localhost: - return True - localaddrs = socket.getaddrinfo(localhost, port) - targetaddrs = socket.getaddrinfo(hostname, port) - for (ignored_family, ignored_socktype, ignored_proto, ignored_canonname, - sockaddr) in localaddrs: - for (ignored_rfamily, ignored_rsocktype, ignored_rproto, - ignored_rcanonname, rsockaddr) in targetaddrs: - if rsockaddr[0] == sockaddr[0]: - return True + try: + fqdn = socket.getfqdn(hostname) + except TypeError: + # sometimes fails on pypy3 + fqdn = None + hostnames = [hostname, fqdn] + for hn in hostnames: + if hn in ( + # for py2/py3 compatibility, check for both binary and native strings + b"localhost", "localhost", + b"0.0.0.0", "0.0.0.0", + b"127.0.0.1", "127.0.0.1", + b"1.0.0.127.in-addr.arpa", "1.0.0.127.in-addr.arpa", + b"1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa", + "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa" + ): + return True + localhost = socket.gethostname() + if hn == localhost: + return True + localaddrs = socket.getaddrinfo(localhost, port) + targetaddrs = socket.getaddrinfo(hn, port) + for (ignored_family, ignored_socktype, ignored_proto, ignored_canonname, + sockaddr) in localaddrs: + for (ignored_rfamily, ignored_rsocktype, ignored_rproto, + ignored_rcanonname, rsockaddr) in targetaddrs: + if rsockaddr[0] == sockaddr[0]: + return True return False @@ -137,7 +170,7 @@ def _determine_mode(addr, mode): elif isinstance(addr, bytes): host = addr else: - raise TypeError("'addr' needs to be a tuple or an bytestring!") + raise TypeError("'addr' needs to be a tuple or an bytestring (instead got {a!r})!".format(a=addr)) if mode == MODE_AUTO: if host_is_local(host): return MODE_PRIMARY @@ -169,153 +202,109 @@ def chunked(data, chunksize): return res -class _ExtendedManager(object): - """A class for managing the multiprocessing.managers.SyncManager""" - __safe_for_unpickling__ = True # this may not be safe for unpickling, - # but this is required by pickle. +def json_bytes_dumps(obj): + """ + Encodes obj into json, returning a bytestring. + This is mainly used for py2/py3 compatibility. + """ + dumped = json.dumps(obj, ensure_ascii=True) + encoded = dumped.encode(_DEFAULT_NETWORK_ENCODING) + return encoded - def __init__(self, addr, authkey, mode, start=False): - self.addr = addr - self.authkey = authkey - self.mode = _determine_mode(addr, mode) - self.manager = None - self._secondary_state = multiprocessing.managers.Value(int, _STATE_RUNNING) - if start: - self.start() +def json_bytes_loads(bytestr): + """ + Decodes a bytestring into json. + This is mainly used for py2/py3 cimpatibility. + """ + decoded = bytestr.decode(_DEFAULT_NETWORK_ENCODING) + loaded = json.loads(decoded) + return loaded - def __reduce__(self): - """ - This method is used by pickle to serialize instances of this class. 
- """ - return ( - self.__class__, - (self.addr, self.authkey, self.mode, True), - ) - def start(self): - """Starts or connects to the manager.""" - if self.mode == MODE_PRIMARY: - i = self._start() - else: - i = self._connect() - self.manager = i - - def stop(self): - """Stops the manager.""" - self.manager.shutdown() - - def set_secondary_state(self, value): - """Sets the value for 'secondary_state'.""" - if value not in (_STATE_RUNNING, _STATE_SHUTDOWN, _STATE_FORCED_SHUTDOWN): - raise ValueError( - "State {!r} is invalid - needs to be one of _STATE_RUNNING, _STATE_SHUTDOWN, or _STATE_FORCED_SHUTDOWN".format( - value) - ) - if self.manager is None: - raise RuntimeError("Manager not started") - self.manager.set_state(value) - - def _get_secondary_state(self): - """ - Returns the value for 'secondary_state'. - This is required for the manager. - """ - return self._secondary_state +def _serialize_tasks(tasks): + """serialize a tasklist.""" + return pickle.dumps(tasks, -1).decode(_DEFAULT_PICKLE_ENCODING) - def _get_manager_class(self, register_callables=False): - """ - Returns a new 'Manager' subclass with registered methods. - If 'register_callable' is True, defines the 'callable' arguments. - """ - class _EvaluatorSyncManager(managers.BaseManager): - """ - A custom BaseManager. - Please see the documentation of `multiprocessing` for more - information. - """ - pass +def _load_tasks(s): + """loads a tasklist from a string returned by _serialize_tasks()""" + return pickle.loads(s.encode(_DEFAULT_PICKLE_ENCODING)) - inqueue = queue.Queue() - outqueue = queue.Queue() - namespace = Namespace() - - if register_callables: - _EvaluatorSyncManager.register( - "get_inqueue", - callable=lambda: inqueue, - ) - _EvaluatorSyncManager.register( - "get_outqueue", - callable=lambda: outqueue, - ) - _EvaluatorSyncManager.register( - "get_state", - callable=self._get_secondary_state, - ) - _EvaluatorSyncManager.register( - "set_state", - callable=lambda v: self._secondary_state.set(v), - ) - _EvaluatorSyncManager.register( - "get_namespace", - callable=lambda: namespace, - ) + +class _MessageHandler(object): + """ + Class for managing a socket connection. + This includes detecting incomplete messages and completing them with + later messages. + """ + + # constants for managing the current state + _STATE_RECV_PREFIX = 0 # we are currently waiting for the length prefix + _STATE_RECV_MESSAGE = 1 # we arer currently receiving a message + + def __init__(self, s): + self._s = s + self._state = self._STATE_RECV_PREFIX + self._msg_size = _LENGTH_PREFIX_LENGTH + self._cur_buff = b"" + self.messages = [] + + def feed(self, data): + """ + Process received data. + Returns the number which still need to be received for this message. 
+ """ + received_a_whole_message = False + self._cur_buff += data + while len(self._cur_buff) >= self._msg_size: + received_a_whole_message = received_a_whole_message or (len(self._cur_buff) >= self._msg_size) + msg = self._cur_buff[:self._msg_size] + self._cur_buff = self._cur_buff[self._msg_size:] + self._handle_message(msg) + if received_a_whole_message: + return 0 + else: + remaining = self._msg_size - len(self._cur_buff) + return remaining + + def _handle_message(self, msg): + """handle an incomming message as required by self._state""" + if self._state == self._STATE_RECV_PREFIX: + self._msg_size = struct.unpack(_LENGTH_PREFIX, msg)[0] + self._state = self._STATE_RECV_MESSAGE + elif self._state == self._STATE_RECV_MESSAGE: + self._msg_size = _LENGTH_PREFIX_LENGTH + self._state = self._STATE_RECV_PREFIX + self.messages.append(msg) else: - _EvaluatorSyncManager.register( - "get_inqueue", - ) - _EvaluatorSyncManager.register( - "get_outqueue", - ) - _EvaluatorSyncManager.register( - "get_state", - ) - _EvaluatorSyncManager.register( - "set_state", - ) - _EvaluatorSyncManager.register( - "get_namespace", - ) - return _EvaluatorSyncManager - - def _connect(self): - """Connects to the manager.""" - cls = self._get_manager_class(register_callables=False) - ins = cls(address=self.addr, authkey=self.authkey) - ins.connect() - return ins - - def _start(self): - """Starts the manager.""" - cls = self._get_manager_class(register_callables=True) - ins = cls(address=self.addr, authkey=self.authkey) - ins.start() - return ins - - @property - def secondary_state(self): - """Whether the secondary nodes should still process elements.""" - v = self.manager.get_state() - return v.get() - - def get_inqueue(self): - """Returns the inqueue.""" - if self.manager is None: - raise RuntimeError("Manager not started") - return self.manager.get_inqueue() - - def get_outqueue(self): - """Returns the outqueue.""" - if self.manager is None: - raise RuntimeError("Manager not started") - return self.manager.get_outqueue() - - def get_namespace(self): - """Returns the namespace.""" - if self.manager is None: - raise RuntimeError("Manager not started") - return self.manager.get_namespace() + raise RuntimeError("Internal error: invalid state!") + + def send_message(self, msg): + """sends a message.""" + length = len(msg) + prefix = struct.pack(_LENGTH_PREFIX, length) + data = prefix + msg + self._s.send(data) + + def send_json(self, d): + """serializes d into json, then sends the message.""" + ser = json_bytes_dumps(d) + return self.send_message(ser) + + def recv(self): + """receives a message from the socket (blocking).""" + to_recv = 1 # receive one byte initialy + while True: + data = self._s.recv(to_recv) + to_recv = self.feed(data) + if to_recv == 0: + return + + def get_message(self): + """if a message was received, return it. Otherwise, receive a message and return it.""" + while len(self.messages) == 0: + self.recv() + return self.messages.pop(0) class DistributedEvaluator(object): @@ -338,7 +327,7 @@ def __init__( ``authkey`` is the password used to restrict access to the manager; see ``Authentication Keys`` in the `multiprocessing` manual for more information. All DistributedEvaluators need to use the same authkey. Note that this needs - to be a `bytes` object for Python 3.X, and should be in 2.7 for compatibility + to be a `str` object for Python 3.X, and should be in 2.7 for compatibility (identical in 2.7 to a `str` object). 


 class DistributedEvaluator(object):
@@ -338,7 +327,7 @@ def __init__(
     ``authkey`` is the password used to restrict access to the manager; see
     ``Authentication Keys`` in the `multiprocessing` manual for more information.
     All DistributedEvaluators need to use the same authkey. Note that this needs
-    to be a `bytes` object for Python 3.X, and should be in 2.7 for compatibility
+    to be a `str` object for Python 3.X, and should be one in 2.7 as well for compatibility
     (identical in 2.7 to a `str` object).
     ``eval_function`` should take two arguments (a genome object and the
     configuration) and return a single float (the genome's fitness).
@@ -356,7 +345,7 @@ def __init__(
         self.authkey = authkey
         self.eval_function = eval_function
         self.secondary_chunksize = secondary_chunksize
-        self.slave_chunksize = secondary_chunksize  # backward compatibility
+        self.slave_chunksize = secondary_chunksize  # backwards compatibility
         if num_workers:
             self.num_workers = num_workers
         else:
@@ -368,21 +357,12 @@ def __init__(
                 self.num_workers = 1
         self.worker_timeout = worker_timeout
         self.mode = _determine_mode(self.addr, mode)
-        self.em = _ExtendedManager(self.addr, self.authkey, mode=self.mode, start=False)
-        self.inqueue = None
-        self.outqueue = None
-        self.namespace = None
         self.started = False
-
-    def __getstate__(self):
-        """Required by the pickle protocol."""
-        # we do not actually save any state, but we need __getstate__ to be
-        # called.
-        return True  # return some nonzero value
-
-    def __setstate__(self, state):
-        """Called when instances of this class are unpickled."""
-        self._set_shared_instances()
+        self._state = _STATE_RUNNING  # current run state (see the _STATE_* constants)
+        self._inqueue = queue.Queue()
+        self._outqueue = queue.Queue()
+        self._sock_thread = None
+        self._va_lock = threading.Lock()  # lock to prevent parallel access to some vars
+        self._stopwaitevent = threading.Event()  # event for waiting for the network thread to stop

     def is_primary(self):
         """Returns True if the caller is the primary node"""
@@ -416,7 +396,6 @@ def start(self, exit_on_stop=True, secondary_wait=0, reconnect=False):
             self._start_primary()
         elif self.mode == MODE_SECONDARY:
             time.sleep(secondary_wait)
-            self._start_secondary()
             self._secondary_loop(reconnect=reconnect)
             if exit_on_stop:
                 sys.exit(0)
@@ -437,38 +416,251 @@ def stop(self, wait=1, shutdown=True, force_secondary_shutdown=False):
         if not self.started:
             raise RuntimeError("Not yet started!")
         if force_secondary_shutdown:
-            state = _STATE_FORCED_SHUTDOWN
+            self._state = _STATE_FORCED_SHUTDOWN
         else:
-            state = _STATE_SHUTDOWN
-        self.em.set_secondary_state(state)
-        time.sleep(wait)
+            self._state = _STATE_SHUTDOWN
+        time.sleep(wait)  # wait is now mostly for backwards compatibility
         if shutdown:
-            self.em.stop()
+            try:
+                self._listen_s.close()
+            except:
+                pass
+            self._listen_s = None
         self.started = False
-        self.inqueue = self.outqueue = self.namespace = None
+        self._stopwaitevent.wait()

     def _start_primary(self):
-        """Start as the primary"""
-        self.em.start()
-        self.em.set_secondary_state(_STATE_RUNNING)
-        self._set_shared_instances()
-
-    def _start_secondary(self):
-        """Start as a secondary."""
-        self.em.start()
-        self._set_shared_instances()
-
-    def _set_shared_instances(self):
-        """Sets attributes from the shared instances."""
-        self.inqueue = self.em.get_inqueue()
-        self.outqueue = self.em.get_outqueue()
-        self.namespace = self.em.get_namespace()
-
-    def _reset_em(self):
-        """Resets self.em and the shared instances."""
-        self.em = _ExtendedManager(self.addr, self.authkey, mode=self.mode, start=False)
-        self.em.start()
-        self._set_shared_instances()
+        """Start as the primary node."""
+        # setup primary specific vars
+        self._clients = {}  # socket -> _MessageHandler
+        self._s2tasks = {}  # socket -> tasks
+        self._authenticated_clients = []  # list of authenticated secondaries
+        self._waiting_clients = []  # list of secondaries waiting for tasks
+
+        # create socket, bind and listen
+        self._bind_and_listen()
+
+        # create and start network thread
+        self._sock_thread = threading.Thread(
+            name="{c} network thread".format(c=self.__class__),
+            target=self._primary_sock_thread,
+        )
+        self._sock_thread.start()
+
+    def _bind_and_listen(self):
+        """Creates a socket, binds it and starts listening for connections."""
+        self._listen_s = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)  # todo: ipv6 support
+        self._listen_s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        self._listen_s.bind(self.addr)
+        self._listen_s.listen(3)
+
+    def _primary_sock_thread(self):
+        """Method for the socket thread of the primary node."""
+        if self.mode != MODE_PRIMARY:
+            raise ModeError("Not a primary node!")
+        self._stopwaitevent.clear()  # just to be sure
+        while self.started:
+            to_check_read = [self._listen_s] + list(self._clients.keys())  # list() for python3 compatibility
+            to_check_err = [self._listen_s] + list(self._clients.keys())   # ^^ TODO: is there another way to do this?
+            to_read, to_write, has_err = select.select(to_check_read, [], to_check_err, 0.1)
+            if (len(to_read) + len(to_write) + len(has_err)) == 0:
+                continue
+
+            for s in to_read:
+                if s is self._listen_s:
+                    # new connection
+                    try:
+                        c, addr = s.accept()
+                    except Exception:
+                        if not self.started:
+                            break
+                        raise
+                    mh = _MessageHandler(c)
+                    self._clients[c] = mh
+                else:
+                    # data available
+                    data = s.recv(4096)  # receive at most 4k bytes; if more are available, recv them in the next iteration.
+                    if len(data) == 0:
+                        # connection closed.
+                        self._remove_client(s)
+                        continue
+                    mh = self._clients.get(s, None)
+                    if mh is None:
+                        self._remove_client(s)
+                        continue
+                    mh.feed(data)
+
+                    while len(mh.messages) > 0:
+                        msg = mh.messages.pop(0)
+                        try:
+                            loaded = json_bytes_loads(msg)
+                        except:
+                            self._remove_client(s)
+                            break
+                        action = loaded.get("action", None)
+
+                        # authentication
+                        if action == "auth":
+                            authkey = loaded.get("authkey")
+                            if authkey == self.authkey:
+                                if s not in self._authenticated_clients:
+                                    self._authenticated_clients.append(s)
+                                mh.send_json({"action": "auth_response", "success": True})
+                            else:
+                                mh.send_json({"action": "auth_response", "success": False})
+                                self._remove_client(s)
+                                break
+                        elif s not in self._authenticated_clients:
+                            # client did not authenticate
+                            self._remove_client(s)
+                            break
+
+                        # task distribution
+                        elif action == "get_task":
+                            try:
+                                tasks = self._inqueue.get(timeout=0)
+                            except queue.Empty:
+                                self._va_lock.acquire()
+                                self._waiting_clients.append(s)
+                                self._va_lock.release()
+                            else:
+                                self._send_tasks(mh, tasks)
+
+                        # results
+                        elif action == "results":
+                            results = loaded.get("results", None)
+                            if results is not None:
+                                self._outqueue.put(results)
+                                self._va_lock.acquire()
+                                if s in self._s2tasks:
+                                    del self._s2tasks[s]
+                                self._va_lock.release()
+                            else:
+                                # results not sent; reissue tasks if required.
+                                self._va_lock.acquire()
+                                if s in self._s2tasks:
+                                    tasks = self._s2tasks[s]
+                                    del self._s2tasks[s]
+                                else:
+                                    tasks = None
+                                self._va_lock.release()
+                                if tasks is not None:
+                                    self._inqueue.put(tasks)
+
+                        else:
+                            # unknown message; this is probably an error.
+                            self._remove_client(s)
+                            break
+
+            for s in has_err:
+                if s is self._listen_s:
+                    # can't listen anymore
+                    # try to rebind
+                    try:
+                        s.close()
+                    except:
+                        # ignore exception during socket close,
+                        # socket may already be closed.
+                        pass
+                    if self._state == _STATE_RUNNING:
+                        try:
+                            self._bind_and_listen()
+                        except:
+                            self._state = _STATE_ERROR
+                            break
+                    else:
+                        # server stopped or not yet started,
+                        # do not rebind and break this loop instead.
+                        break
+                else:
+                    self._remove_client(s)
+
+        # stop connected clients
+        forced = (self._state == _STATE_FORCED_SHUTDOWN)
+        try:
+            for s in list(self._clients.keys()):  # list() for py3 compatibility
+                self._send_stop(s, forced=forced)
+                self._remove_client(s)
+        finally:
+            self._stopwaitevent.set()
+
+    def _remove_client(self, s):
+        """Closes and removes a client connection."""
+        self._va_lock.acquire()
+        try:
+            s.close()
+        except:
+            pass
+        if s in self._clients:
+            del self._clients[s]
+        if s in self._authenticated_clients:
+            self._authenticated_clients.remove(s)
+        if s in self._s2tasks:
+            tasks = self._s2tasks[s]
+            del self._s2tasks[s]
+        else:
+            tasks = None
+        self._va_lock.release()
+        if tasks is not None:
+            self._add_tasks(tasks)
+
+    def _send_tasks(self, mh, tasks):
+        """Sends some tasks to a secondary connected through the message handler mh."""
+        ser_tasks = _serialize_tasks(tasks)
+        mh.send_json(
+            {
+                "action": "tasks",
+                "tasks": ser_tasks,
+            },
+        )
+
+    def _send_stop(self, s, forced=False):
+        """Sends a stop message to a client."""
+        self._va_lock.acquire()
+        mh = self._clients.get(s, _MessageHandler(s))
+        mh.send_json(
+            {
+                "action": "stop",
+                "forced": forced,
+            }
+        )
+        self._va_lock.release()
+
+    def _add_tasks(self, tasks):
+        """Adds a list of tasks for evaluation."""
+        if len(self._waiting_clients) > 0:
+            self._va_lock.acquire()
+            s = self._waiting_clients.pop(0)
+            mh = self._clients.get(s, None)
+            self._va_lock.release()
+            if mh is None:
+                self._remove_client(s)
+                return self._add_tasks(tasks)
+            self._send_tasks(mh, tasks)
+        else:
+            self._inqueue.put(tasks)
+
+    def _reset(self):
+        """(Re)connects this secondary node to the primary and authenticates."""
+        # connect
+        self._s = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
+        self._s.connect(self.addr)
+
+        # create a _MessageHandler
+        self._mh = _MessageHandler(self._s)
+
+        # auth
+        self._mh.send_json(
+            {
+                "action": "auth",
+                "authkey": self.authkey,
+            }
+        )
+        response = json_bytes_loads(self._mh.get_message())
+        if response.get("action", None) != "auth_response":
+            self._s.close()
+            raise ProtocolError("Server did not send an auth response!")
+        success = response.get("success", False)
+        if not success:
+            self._s.close()
+            raise AuthError("Invalid authkey!")
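
Taken together, _reset() and the primary's socket thread speak a small JSON protocol over the length-prefixed frames; roughly (field values are examples, and the tasklist travels as the latin1-decoded pickle string described above):

    # secondary -> primary                     primary -> secondary
    # {"action": "auth", "authkey": "..."}     {"action": "auth_response", "success": true}
    # {"action": "get_task"}                   {"action": "tasks", "tasks": "<pickled tasklist>"}
    # {"action": "results", "results": [...]}  {"action": "stop", "forced": false}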

     def _secondary_loop(self, reconnect=False):
         """The worker loop for the secondary nodes."""
@@ -476,45 +668,20 @@
             pool = multiprocessing.Pool(self.num_workers)
         else:
             pool = None
-        should_reconnect = True
-        while should_reconnect:
-            i = 0
+        self._should_reconnect = True
+        while self._should_reconnect:
             running = True
             try:
-                self._reset_em()
-            except (socket.error, EOFError, IOError, OSError, socket.gaierror, TypeError):
+                self._reset()
+            except (socket.error, EOFError, IOError, OSError, socket.gaierror):
                 continue
             while running:
-                i += 1
-                if i % 5 == 0:
-                    # for better performance, only check every 5 cycles
-                    try:
-                        state = self.em.secondary_state
-                    except (socket.error, EOFError, IOError, OSError, socket.gaierror, TypeError):
-                        if not reconnect:
-                            raise
-                        else:
-                            break
-                    if state == _STATE_FORCED_SHUTDOWN:
-                        running = False
-                        should_reconnect = False
-                    elif state == _STATE_SHUTDOWN:
-                        running = False
-                    if not running:
-                        continue
                 try:
-                    tasks = self.inqueue.get(block=True, timeout=0.2)
-                except queue.Empty:
-                    continue
-                except (socket.error, EOFError, IOError, OSError, socket.gaierror, TypeError):
-                    break
-                except (managers.RemoteError, multiprocessing.ProcessError) as e:
-                    if ('Empty' in repr(e)) or ('TimeoutError' in repr(e)):
-                        continue
-                    if (('EOFError' in repr(e)) or ('PipeError' in repr(e)) or
-                        ('AuthenticationError' in repr(e))):  # Second for Python 3.X, Third for 3.6+
+                    tasks = self._get_tasks()
+                    if tasks is None:
                         break
-                    raise
+                except (socket.error, EOFError, IOError, OSError, socket.gaierror):
+                    break
                 if pool is None:
                     res = []
                     for genome_id, genome, config in tasks:
@@ -533,25 +700,54 @@
                     results = [
                         job.get(timeout=self.worker_timeout) for job in jobs
                     ]
-                    res = zip(genome_ids, results)
+                    res = list(zip(genome_ids, results))  # list() for py3 compatibility
                 try:
-                    self.outqueue.put(res)
-                except (socket.error, EOFError, IOError, OSError, socket.gaierror, TypeError):
+                    self._send_results(res)
+                except (socket.error, EOFError, IOError, OSError, socket.gaierror):
                     break
-                except (managers.RemoteError, multiprocessing.ProcessError) as e:
-                    if ('Empty' in repr(e)) or ('TimeoutError' in repr(e)):
-                        continue
-                    if (('EOFError' in repr(e)) or ('PipeError' in repr(e)) or
-                        ('AuthenticationError' in repr(e))):  # Second for Python 3.X, Third for 3.6+
-                        break
-                    raise
                 if not reconnect:
-                    should_reconnect = False
+                    self._should_reconnect = False
                     break
+            try:
+                self._s.close()
+            except:
+                pass
         if pool is not None:
             pool.terminate()

+    def _get_tasks(self):
+        """
+        Receives some tasks from the primary.
+        This method returns either the received tasks or None if the
+        secondary was stopped by the primary.
+        """
+        tosend = {"action": "get_task"}
+        self._mh.send_json(tosend)
+        while True:
+            msg = json_bytes_loads(self._mh.get_message())
+            action = msg.get("action", None)
+
+            if action == "stop":
+                forced = msg.get("forced", False)
+                if forced:
+                    self._should_reconnect = False
+                return None
+
+            elif action == "tasks":
+                tasks = msg.get("tasks", None)
+                if tasks is not None:
+                    return _load_tasks(tasks)
+
+    def _send_results(self, results):
+        """Sends the results to the primary node."""
+        self._mh.send_json(
+            {
+                "action": "results",
+                "results": results,
+            }
+        )

     def evaluate(self, genomes, config):
         """
         Evaluates the genomes.
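
Putting _get_tasks() and _send_results() together, one pass of the secondary's work cycle is roughly the following (error handling and the worker pool omitted; self stands for a started secondary evaluator):

    tasks = self._get_tasks()  # blocks until a "tasks" or "stop" message arrives
    if tasks is not None:
        results = [(genome_id, self.eval_function(genome, config))
                   for genome_id, genome, config in tasks]
        self._send_results(results)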
@@ -564,13 +760,13 @@
         id2genome = {genome_id: genome for genome_id, genome in genomes}
         tasks = chunked(tasks, self.secondary_chunksize)
         n_tasks = len(tasks)
-        for task in tasks:
-            self.inqueue.put(task)
+        for tasklist in tasks:
+            self._add_tasks(tasklist)
         tresults = []
         while len(tresults) < n_tasks:
             try:
-                sr = self.outqueue.get(block=True, timeout=0.2)
-            except (queue.Empty, managers.RemoteError):
+                sr = self._outqueue.get(block=True, timeout=0.2)
+            except queue.Empty:
                 continue
             tresults.append(sr)
         results = []
diff --git a/tests/test_distributed.py b/tests/test_distributed.py
index 1db23598..3c5791b0 100644
--- a/tests/test_distributed.py
+++ b/tests/test_distributed.py
@@ -31,7 +31,7 @@ def eval_dummy_genome_nn(genome, config):


 def test_chunked():
-    """Test for neat.distributed.chunked"""
+    """Test for neat.distributed.chunked()"""
     # test chunked(range(110), 10)
     # => 11 chunks of 10 elements
     d110 = list(range(110))
@@ -74,18 +74,17 @@ def test_chunked():


 def test_host_is_local():
-    """test for neat.distributed.host_is_local"""
+    """test for neat.distributed.host_is_local()"""
     tests = (
         # (hostname or ip, expected value)
-        ("localhost", True),
-        ("0.0.0.0", True),
-        ("127.0.0.1", True),
-        # ("::1", True), # depends on IP, etc setup on host to work right
+        (b"localhost", True),
+        (b"0.0.0.0", True),
+        (b"127.0.0.1", True),
         (socket.gethostname(), True),
         (socket.getfqdn(), True),
-        ("github.com", False),
-        ("google.de", False),
-    )
+        (b"github.com", False),
+        (b"google.de", False),
+    )
     for hostname, islocal in tests:
         try:
             result = neat.host_is_local(hostname)
@@ -93,62 +92,62 @@
             print("test_host_is_local: Error with hostname {0!r} (expected {1!r})".format(hostname, islocal))
             raise
-        else:  # if do not want to do 'raise' above for some cases
-            assert result == islocal, "Hostname/IP: {h}; Expected: {e}; Got: {r!r}".format(
+        else:  # if do not want to do 'raise' above for some cases
+            assert result == islocal, "Unexpected results for hostname/IP: {h}; Expected: {e}; Got: {r!r}".format(
                 h=hostname, e=islocal, r=result)


 def test_DistributedEvaluator_mode():
-    """Tests for the mode determination of DistributedEvaluator"""
+    """Tests for the mode determination of neat.distributed.DistributedEvaluator()"""
     # test auto mode setting
     # we also test that the mode is not
     # automatically determined when explicitly given.
     tests = (
         # (hostname or ip, mode to pass, expected mode)
-        ("localhost", MODE_PRIMARY, MODE_PRIMARY),
-        ("0.0.0.0", MODE_PRIMARY, MODE_PRIMARY),
-        ("localhost", MODE_SECONDARY, MODE_SECONDARY),
-        ("example.org", MODE_PRIMARY, MODE_PRIMARY),
-        (socket.gethostname(), MODE_SECONDARY, MODE_SECONDARY),
-        ("localhost", MODE_AUTO, MODE_PRIMARY),
-        (socket.gethostname(), MODE_AUTO, MODE_PRIMARY),
-        (socket.getfqdn(), MODE_AUTO, MODE_PRIMARY),
-        ("example.org", MODE_AUTO, MODE_SECONDARY),
-    )
+        (b"localhost", MODE_PRIMARY, MODE_PRIMARY),
+        (b"0.0.0.0", MODE_PRIMARY, MODE_PRIMARY),
+        (b"localhost", MODE_SECONDARY, MODE_SECONDARY),
+        (b"example.org", MODE_PRIMARY, MODE_PRIMARY),
+        (socket.gethostname().encode("ascii"), MODE_SECONDARY, MODE_SECONDARY),
+        (b"localhost", MODE_AUTO, MODE_PRIMARY),
+        (socket.gethostname().encode("ascii"), MODE_AUTO, MODE_PRIMARY),
+        (socket.getfqdn().encode("ascii"), MODE_AUTO, MODE_PRIMARY),
+        (b"example.org", MODE_AUTO, MODE_SECONDARY),
+    )
     for hostname, mode, expected in tests:
-        addr = (hostname, 8022)
-        try:
-            de = neat.DistributedEvaluator(
-                addr,
-                authkey=b"abcd1234",
-                eval_function=eval_dummy_genome_nn,
-                mode=mode,
-            )
-        except EnvironmentError:
-            print("test_DistributedEvaluator_mode(): Error with hostname " +
-                  "{!r}".format(hostname))
-            raise
-        result = de.mode
-        assert result == expected, "Mode determination failed! Hostname: {h}; expected: {e}; got: {r!r}!".format(
-            h=hostname, e=expected, r=result)
-
-        if result == MODE_AUTO:
-            raise Exception(
-                "DistributedEvaluator.__init__(mode=MODE_AUTO) did not automatically determine its mode!"
-            )
-        elif (result == MODE_PRIMARY) and (not de.is_primary()):
-            raise Exception(
-                "DistributedEvaluator.is_primary() returns False even if the evaluator is in primary mode!"
-            )
-        elif (result == MODE_SECONDARY) and de.is_primary():
-            raise Exception(
-                "DistributedEvaluator.is_primary() returns True even if the evaluator is in secondary mode!"
-            )
+        for addr in ((hostname, 8022), hostname):
+            try:
+                de = neat.DistributedEvaluator(
+                    addr,
+                    authkey="abcd1234",
+                    eval_function=eval_dummy_genome_nn,
+                    mode=mode,
+                )
+            except EnvironmentError:
+                print("test_DistributedEvaluator_mode(): Error with hostname " +
+                      "{!r}".format(hostname))
+                raise
+            result = de.mode
+            assert result == expected, "Mode determination failed! Hostname: {h}; expected: {e}; got: {r!r}!".format(
+                h=hostname, e=expected, r=result)
+
+            if result == MODE_AUTO:
+                raise Exception(
+                    "DistributedEvaluator.__init__(mode=MODE_AUTO) did not automatically determine its mode!"
+                )
+            elif (result == MODE_PRIMARY) and (not de.is_primary()):
+                raise Exception(
+                    "DistributedEvaluator.is_primary() returns False even if the evaluator is in primary mode!"
+                )
+            elif (result == MODE_SECONDARY) and de.is_primary():
+                raise Exception(
+                    "DistributedEvaluator.is_primary() returns True even if the evaluator is in secondary mode!"
+                )
     # test invalid mode error
     try:
         de = neat.DistributedEvaluator(
             addr,
-            authkey=b"abcd1234",
+            authkey=u"abcd1234",
             eval_function=eval_dummy_genome_nn,
             mode="#invalid MODE!",
         )
@@ -159,11 +158,35 @@
         raise Exception("Passing an invalid mode did not cause an exception to be raised on start()!")


+def test_json_bytes_dumps_loads():
+    """
+    Test for json_bytes_dumps() and json_bytes_loads().
+ """ + test_objs = [ + 1, + 2, + 3, + [1, 2, 3], + {"one": 1, "two": 2, "three": "three"}, + {"1": 2, "3":4, "5":6}, + u"unicode_string", + "native_string", + ] + bytestype = type(b"some bytestring") + for obj in test_objs: + dumped = neat.distributed.json_bytes_dumps(obj) + if type(dumped) != bytestype: + raise Exception("neat.distributed.json_bytes_dumps({o}) did not return a bytestring!".format(o=repr(obj))) + loaded = neat.distributed.json_bytes_loads(dumped) + if loaded != obj: + raise Exception("neat.distributed.json_bytes_loads(): {lo}; expected: {o}!".format(lo=repr(loaded), o=repr(obj))) + + def test_DistributedEvaluator_primary_restrictions(): - """Tests that some primary-exclusive methods fail when called by the secondaries""" + """Test that primary-exclusive methods fail when called by the secondary nodes""" secondary = neat.DistributedEvaluator( - ("localhost", 8022), - authkey=b"abcd1234", + (b"localhost", 8022), + authkey=u"abcd1234", eval_function=eval_dummy_genome_nn, mode=MODE_SECONDARY, ) @@ -185,88 +208,6 @@ def test_DistributedEvaluator_primary_restrictions(): raise Exception("A DistributedEvaluator in secondary mode could call evaluate()!") -def test_DistributedEvaluator_state_error1(): - """Tests that attempts to use an unstarted manager for set_secondary_state cause an error.""" - primary = neat.DistributedEvaluator( - ("localhost", 8022), - authkey=b"abcd1234", - eval_function=eval_dummy_genome_nn, - mode=MODE_PRIMARY, - ) - try: - primary.em.set_secondary_state(_STATE_RUNNING) - except RuntimeError: - pass - else: - raise Exception("primary.em.set_secondary_state with unstarted manager did not raise a RuntimeError!") - - -def test_DistributedEvaluator_state_error2(): - """Tests that attempts to use an unstarted manager for get_inqueue cause an error.""" - primary = neat.DistributedEvaluator( - ("localhost", 8022), - authkey=b"abcd1234", - eval_function=eval_dummy_genome_nn, - mode=MODE_PRIMARY, - ) - try: - ignored = primary.em.get_inqueue() - except RuntimeError: - pass - else: - raise Exception("primary.em.get_inqueue() with unstarted manager did not raise a RuntimeError!") - - -def test_DistributedEvaluator_state_error3(): - """Tests that attempts to use an unstarted manager for get_outqueue cause an error.""" - primary = neat.DistributedEvaluator( - ("localhost", 8022), - authkey=b"abcd1234", - eval_function=eval_dummy_genome_nn, - mode=MODE_PRIMARY, - ) - try: - ignored = primary.em.get_outqueue() - except RuntimeError: - pass - else: - raise Exception("primary.em.get_outqueue() with unstarted manager did not raise a RuntimeError!") - - -def test_DistributedEvaluator_state_error4(): - """Tests that attempts to use an unstarted manager for get_namespace cause an error.""" - primary = neat.DistributedEvaluator( - ("localhost", 8022), - authkey=b"abcd1234", - eval_function=eval_dummy_genome_nn, - mode=MODE_PRIMARY, - ) - try: - ignored = primary.em.get_namespace() - except RuntimeError: - pass - else: - raise Exception("primary.em.get_namespace() with unstarted manager did not raise a RuntimeError!") - - -def test_DistributedEvaluator_state_error5(): - """Tests that attempts to set an invalid state cause an error.""" - primary = neat.DistributedEvaluator( - ("localhost", 8022), - authkey=b"abcd1234", - eval_function=eval_dummy_genome_nn, - mode=MODE_PRIMARY, - ) - primary.start() - try: - primary.em.set_secondary_state(-1) - except ValueError: - pass - else: - raise Exception("primary.em.set_secondary_state(-1) did not raise a ValueError!") - - 
-@unittest.skipIf(ON_PYPY, "This test fails on pypy during travis builds but usually works locally.")
 def test_distributed_evaluation_multiprocessing(do_mwcp=True):
     """
     Full test run using the Distributed Evaluator (fake nodes using processes).
@@ -276,8 +217,8 @@
     We emulate the other machines using subprocesses
     created using the multiprocessing module.
     """
-    addr = ("localhost", random.randint(12000, 30000))
-    authkey = b"abcd1234"
+    addr = (b"localhost", random.randint(12000, 30000))
+    authkey = u"abcd1234"
     mp = multiprocessing.Process(
         name="Primary evaluation process",
         target=run_primary,
@@ -298,27 +239,32 @@
     swcp.daemon = True  # we cannot set this on mwcp
     if do_mwcp:
         mwcp.start()
+    else:
+        print("[distributed] Tests/Process-control: Skipping start of the multiworker secondary node process...")
     swcp.start()
     try:
-        print("Joining primary process")
+        print("[distributed] Tests/Process-control: Joining process of primary node...")
         sys.stdout.flush()
         mp.join()
+        print("[distributed] Tests/Process-control: successfully joined process of primary node.")
         if mp.exitcode != 0:
-            raise Exception("Primary-process exited with status {s}!".format(s=mp.exitcode))
+            raise Exception("Primary-node-process exited with status {s}!".format(s=mp.exitcode))
         if do_mwcp:
             if not mwcp.is_alive():
-                print("mwcp is not 'alive'")
-                print("children: {c}".format(c=multiprocessing.active_children()))
-            print("Joining multiworker-secondary process")
+                print("[distributed] Tests/Process-control: mwcp is not 'alive'!")
+                print("[distributed] Tests/Process-control: children: {c}.".format(c=multiprocessing.active_children()))
+            print("[distributed] Tests/Process-control: Joining multiworker-secondary process...")
             sys.stdout.flush()
             mwcp.join()
+            print("[distributed] Tests/Process-control: successfully joined multiworker-secondary process.")
            if mwcp.exitcode != 0:
                 raise Exception("Multiworker-secondary-process exited with status {s}!".format(s=mwcp.exitcode))
         if not swcp.is_alive():
-            print("swcp is not 'alive'")
-        print("Joining singleworker-secondary process")
+            print("[distributed] Tests/Process-control: swcp is not 'alive'!")
+        print("[distributed] Tests/Process-control: Joining singleworker-secondary process.")
         sys.stdout.flush()
         swcp.join()
+        print("[distributed] Tests/Process-control: successfully joined singleworker-secondary process.")
         if swcp.exitcode != 0:
             raise Exception("Singleworker-secondary-process exited with status {s}!".format(s=swcp.exitcode))
@@ -331,7 +277,40 @@
         swcp.terminate()


-@unittest.skipIf(ON_PYPY, "Pypy has problems with threading.")
+def test_distributed_evaluation_invalid_authkey_multiprocessing(pc=multiprocessing.Process):
+    """Test for DistributedEvaluator-behavior on invalid authkey (fake nodes using processes)"""
+    addr = (b"localhost", random.randint(12000, 30000))
+    valid_authkey = u"Linux>Windows"
+    invalid_authkey = u"Windows>Linux"
+    print("[distributed] Tests/Process-control: starting primary node/thread...")
+    mp = pc(
+        name="Primary evaluation process",
+        target=run_primary,
+        args=(addr, valid_authkey, 19),  # 19 because stagnation is at 20
+    )
+    mp.daemon = True
+    mp.start()
+    print("[distributed] Tests/Process-control: running secondary node in main process/thread...")
+    try:
+        run_secondary(addr, authkey=invalid_authkey, num_workers=1)
+    except neat.distributed.AuthError:
+        # expected
print("[distributed] Tests/Process-control: catched expected AuthError.") + pass + else: + print("[distributed] Tests/Process-control: did not catch expected AuthError! This is an Error!") + raise Exception("Expected an auth error!") + finally: + print("[distributed] Tests/Process-control: Terminating primary process/thread.") + if hasattr(mp, "terminate"): + mp.terminate() + + +def test_distributed_evaluation_invalid_authkey_threaded(): + """Test for DistributedEvaluator-behavior on invalid authkey (Fake nodes using threads)""" + test_distributed_evaluation_invalid_authkey_multiprocessing(pc=threading.Thread) + + def test_distributed_evaluation_threaded(): """ Full test run using the Distributed Evaluator (fake nodes using threads). @@ -345,8 +324,8 @@ def test_distributed_evaluation_threaded(): """ if not HAVE_THREADING: raise unittest.SkipTest("Platform does not have threading") - addr = ("localhost", random.randint(12000, 30000)) - authkey = b"abcd1234" + addr = (b"localhost", random.randint(12000, 30000)) + authkey = u"abcd1234" mp = threading.Thread( name="Primary evaluation thread", target=run_primary, @@ -402,17 +381,17 @@ def run_primary(addr, authkey, generations): eval_function=eval_dummy_genome_nn, mode=MODE_PRIMARY, secondary_chunksize=15, - ) - print("Starting DistributedEvaluator") + ) + print("[distributed] primary: starting DistributedEvaluator") sys.stdout.flush() de.start() - print("Running evaluate") + print("[distributed] primary: starting evaluation of genomes...") sys.stdout.flush() p.run(de.evaluate, generations) - print("Evaluated") + print("[distributed] primary: evaluation finished.") sys.stdout.flush() de.stop(wait=5) - print("Did de.stop") + print("[distributed] primary: did de.stop") sys.stdout.flush() stats.save() @@ -442,12 +421,18 @@ def run_secondary(addr, authkey, num_workers=1): eval_function=eval_dummy_genome_nn, mode=MODE_SECONDARY, num_workers=num_workers, - ) + ) + print("[distributed] secondary: starting DistributedEvaluator...") + sys.stdout.flush() try: - de.start(secondary_wait=3, exit_on_stop=True) + de.start(secondary_wait=3, exit_on_stop=True, reconnect=False) except SystemExit: + print("[distributed] secondary: caught expected SystemExit.") + sys.stdout.flush() pass else: + print("[distributed] secondary: expected a SystemExit; not SystemExit caught!") + sys.stdout.flush() raise Exception("DistributedEvaluator in secondary mode did not try to exit!") diff --git a/tests/test_xor_example_distributed.py b/tests/test_xor_example_distributed.py index ea60f24e..c6eca0b7 100644 --- a/tests/test_xor_example_distributed.py +++ b/tests/test_xor_example_distributed.py @@ -62,6 +62,7 @@ def run_primary(addr, authkey, generations): winner = p.run(de.evaluate, generations) print("===== stopping DistributedEvaluator =====") de.stop(wait=3, shutdown=True, force_secondary_shutdown=False) + print("===== DistributedEvaluator stopped. =====") if winner: # Display the winning genome. @@ -85,9 +86,11 @@ def run_primary(addr, authkey, generations): winner2 = None time.sleep(3) de.start() + winner2 = p2.run(de.evaluate, (100-checkpointer.last_generation_checkpoint)) winner2 = p2.run(de.evaluate, (100 - checkpointer.last_generation_checkpoint)) print("===== stopping DistributedEvaluator (forced) =====") de.stop(wait=3, shutdown=True, force_secondary_shutdown=True) + print("===== DistributedEvaluator stopped. 
=====") if winner2: if not winner: @@ -123,24 +126,30 @@ def run_secondary(addr, authkey, num_workers=1): eval_function=eval_genome_distributed, mode=MODE_SECONDARY, num_workers=num_workers, - ) - try: - de.start(secondary_wait=3, exit_on_stop=True, reconnect=True) - except SystemExit: - pass - else: - raise Exception("DistributedEvaluator in secondary mode did not try to exit!") - - -@unittest.skipIf(ON_PYPY, - "This test fails on pypy during travis builds (frequently due to timeouts) but usually works locally.") + ) + max_tries = 3 + for i in range(max_tries): + # sometimes it may take a while before the port used is available again + try: + de.start(secondary_wait=3, exit_on_stop=True, reconnect=True) + except SystemExit: + # expected + break + except OSError: + if i == (max_tries -1): + raise + else: + raise Exception("DistributedEvaluator in secondary mode did not try to exit!") + + +# @unittest.skipIf(ON_PYPY, "This test fails on pypy during travis builds (frequently due to timeouts) but usually works locally.") def test_xor_example_distributed(): """ Test to make sure restoration after checkpoint works with distributed. """ addr = ("localhost", random.randint(12000, 30000)) - authkey = b"abcd1234" + authkey = "abcd1234" mp = multiprocessing.Process( name="Primary evaluation process", target=run_primary,