diff --git a/kafka/client_async.py b/kafka/client_async.py
index 301a5fd26..9d7492c14 100644
--- a/kafka/client_async.py
+++ b/kafka/client_async.py
@@ -542,7 +542,7 @@ def _can_send_request(self, node_id):
             return False
         return conn.connected() and conn.can_send_more()
 
-    def send(self, node_id, request, wakeup=True):
+    def send(self, node_id, request, wakeup=True, request_timeout_ms=None):
         """Send a request to a specific node. Bytes are placed on an
         internal per-connection send-queue. Actual network I/O will be
         triggered in a subsequent call to .poll()
@@ -550,7 +550,13 @@ def send(self, node_id, request, wakeup=True):
         Arguments:
             node_id (int): destination node
             request (Struct): request object (not-encoded)
-            wakeup (bool): optional flag to disable thread-wakeup
+
+        Keyword Arguments:
+            wakeup (bool, optional): optional flag to disable thread-wakeup.
+            request_timeout_ms (int, optional): Custom request timeout in milliseconds.
+                If the response is not processed before the timeout expires, the
+                client fails the request and closes the connection.
+                Default: None (uses value from client configuration)
 
         Raises:
             AssertionError: if node_id is not in current cluster metadata
@@ -566,7 +572,7 @@ def send(self, node_id, request, wakeup=True):
         # conn.send will queue the request internally
         # we will need to call send_pending_requests()
         # to trigger network I/O
-        future = conn.send(request, blocking=False)
+        future = conn.send(request, blocking=False, request_timeout_ms=request_timeout_ms)
 
         if not future.is_done:
             self._sending.add(conn)
@@ -725,11 +731,13 @@ def _poll(self, timeout):
 
         for conn in six.itervalues(self._conns):
             if conn.requests_timed_out():
+                timed_out = conn.timed_out_ifrs()
+                timeout_ms = (timed_out[0][2] - timed_out[0][1]) * 1000
                 log.warning('%s timed out after %s ms. Closing connection.',
-                            conn, conn.config['request_timeout_ms'])
+                            conn, timeout_ms)
                 conn.close(error=Errors.RequestTimedOutError(
                     'Request timed out after %s ms' %
-                    conn.config['request_timeout_ms']))
+                    timeout_ms))
 
         if self._sensors:
             self._sensors.io_time.record((time.time() - end_select) * 1000000000)
diff --git a/kafka/conn.py b/kafka/conn.py
index 2a4f1df17..347e5000b 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -948,7 +948,7 @@ def close(self, error=None):
             # drop lock before state change callback and processing futures
             self.config['state_change_callback'](self.node_id, sock, self)
             sock.close()
-        for (_correlation_id, (future, _timestamp)) in ifrs:
+        for (_correlation_id, (future, _timestamp, _timeout)) in ifrs:
             future.failure(error)
 
     def _can_send_recv(self):
@@ -956,8 +956,20 @@ def _can_send_recv(self):
         return self.state in (ConnectionStates.AUTHENTICATING,
                               ConnectionStates.CONNECTED)
 
-    def send(self, request, blocking=True):
-        """Queue request for async network send, return Future()"""
+    def send(self, request, blocking=True, request_timeout_ms=None):
+        """Queue request for async network send, return Future()
+
+        Arguments:
+            request (Request): kafka protocol request object to send.
+
+        Keyword Arguments:
+            blocking (bool, optional): Whether to immediately send via
+                blocking socket I/O. Default: True.
+            request_timeout_ms (int, optional): Custom request timeout in milliseconds.
+                Default: None (uses value from connection configuration)
+
+        Returns: future
+        """
         future = Future()
         if self.connecting():
             return future.failure(Errors.NodeNotReadyError(str(self)))
@@ -965,9 +977,9 @@ def send(self, request, blocking=True):
             return future.failure(Errors.KafkaConnectionError(str(self)))
         elif not self.can_send_more():
             return future.failure(Errors.TooManyInFlightRequests(str(self)))
-        return self._send(request, blocking=blocking)
+        return self._send(request, blocking=blocking, request_timeout_ms=request_timeout_ms)
 
-    def _send(self, request, blocking=True):
+    def _send(self, request, blocking=True, request_timeout_ms=None):
         future = Future()
         with self._lock:
             if not self._can_send_recv():
@@ -980,9 +992,11 @@ def _send(self, request, blocking=True):
 
             log.debug('%s Request %d: %s', self, correlation_id, request)
             if request.expect_response():
-                sent_time = time.time()
                 assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
-                self.in_flight_requests[correlation_id] = (future, sent_time)
+                sent_time = time.time()
+                request_timeout_ms = request_timeout_ms or self.config['request_timeout_ms']
+                timeout_at = sent_time + (request_timeout_ms / 1000)
+                self.in_flight_requests[correlation_id] = (future, sent_time, timeout_at)
             else:
                 future.success(None)
 
@@ -1061,18 +1075,20 @@ def recv(self):
         """
         responses = self._recv()
         if not responses and self.requests_timed_out():
+            timed_out = self.timed_out_ifrs()
+            timeout_ms = (timed_out[0][2] - timed_out[0][1]) * 1000
             log.warning('%s timed out after %s ms. Closing connection.',
-                        self, self.config['request_timeout_ms'])
+                        self, timeout_ms)
             self.close(error=Errors.RequestTimedOutError(
                 'Request timed out after %s ms' %
-                self.config['request_timeout_ms']))
+                timeout_ms))
             return ()
 
         # augment responses w/ correlation_id, future, and timestamp
         for i, (correlation_id, response) in enumerate(responses):
             try:
                 with self._lock:
-                    (future, timestamp) = self.in_flight_requests.pop(correlation_id)
+                    (future, timestamp, _timeout) = self.in_flight_requests.pop(correlation_id)
             except KeyError:
                 self.close(Errors.KafkaConnectionError('Received unrecognized correlation id'))
                 return ()
@@ -1143,13 +1159,17 @@ def _recv(self):
     def requests_timed_out(self):
         return self.next_ifr_request_timeout_ms() == 0
 
+    def timed_out_ifrs(self):
+        now = time.time()
+        ifrs = sorted(self.in_flight_requests.values(), reverse=True, key=lambda ifr: ifr[2])
+        return list(filter(lambda ifr: ifr[2] <= now, ifrs))
+
     def next_ifr_request_timeout_ms(self):
         with self._lock:
             if self.in_flight_requests:
-                get_timestamp = lambda v: v[1]
-                oldest_at = min(map(get_timestamp,
-                                    self.in_flight_requests.values()))
-                next_timeout = oldest_at + self.config['request_timeout_ms'] / 1000.0
+                get_timeout = lambda v: v[2]
+                next_timeout = min(map(get_timeout,
+                                       self.in_flight_requests.values()))
                 return max(0, (next_timeout - time.time()) * 1000)
             else:
                 return float('inf')
diff --git a/test/test_client_async.py b/test/test_client_async.py
index b9b415012..ccdd57037 100644
--- a/test/test_client_async.py
+++ b/test/test_client_async.py
@@ -43,7 +43,7 @@ def test_bootstrap(mocker, conn):
     kwargs.pop('state_change_callback')
     kwargs.pop('node_id')
     assert kwargs == cli.config
-    conn.send.assert_called_once_with(MetadataRequest[0]([]), blocking=False)
+    conn.send.assert_called_once_with(MetadataRequest[0]([]), blocking=False, request_timeout_ms=None)
     assert cli._bootstrap_fails == 0
     assert cli.cluster.brokers() == set([BrokerMetadata(0, 'foo', 12, None),
                                          BrokerMetadata(1, 'bar', 34, None)])
@@ -220,12 +220,12 @@ def test_send(cli, conn):
     request = ProduceRequest[0](0, 0, [])
     assert request.expect_response() is False
     ret = cli.send(0, request)
-    conn.send.assert_called_with(request, blocking=False)
+    conn.send.assert_called_with(request, blocking=False, request_timeout_ms=None)
     assert isinstance(ret, Future)
 
     request = MetadataRequest[0]([])
     cli.send(0, request)
-    conn.send.assert_called_with(request, blocking=False)
+    conn.send.assert_called_with(request, blocking=False, request_timeout_ms=None)
 
 
 def test_poll(mocker):
diff --git a/test/test_conn.py b/test/test_conn.py
index fb4172814..f41153fc4 100644
--- a/test/test_conn.py
+++ b/test/test_conn.py
@@ -347,14 +347,14 @@ def test_requests_timed_out(conn):
     # No in-flight requests, not timed out
     assert not conn.requests_timed_out()
 
-    # Single request, timestamp = now (0)
-    conn.in_flight_requests[0] = ('foo', 0)
+    # Single request, timeout_at > now (0)
+    conn.in_flight_requests[0] = ('foo', 0, 1)
     assert not conn.requests_timed_out()
 
     # Add another request w/ timestamp > request_timeout ago
     request_timeout = conn.config['request_timeout_ms']
     expired_timestamp = 0 - request_timeout - 1
-    conn.in_flight_requests[1] = ('bar', expired_timestamp)
+    conn.in_flight_requests[1] = ('bar', 0, expired_timestamp)
     assert conn.requests_timed_out()
 
     # Drop the expired request and we should be good to go again
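For context, a minimal usage sketch of the new per-request timeout, assuming this patch is applied and a broker is reachable at localhost:9092 (the address, the 5000 ms value, and the polling loop are illustrative and not part of the patch):

```python
# Minimal sketch: send one MetadataRequest with a 5s per-request timeout
# instead of the connection-wide request_timeout_ms.
from kafka.client_async import KafkaClient
from kafka.protocol.metadata import MetadataRequest

client = KafkaClient(bootstrap_servers='localhost:9092')  # assumed broker address
client.poll(timeout_ms=500)  # give bootstrap a chance to complete

node_id = client.least_loaded_node()
while not client.ready(node_id):
    client.poll(timeout_ms=100)

# Per-request override: if no response is processed within 5000 ms, the
# client fails this future and closes the connection to node_id.
future = client.send(node_id, MetadataRequest[0]([]), request_timeout_ms=5000)

while not future.is_done:
    client.poll(timeout_ms=100)

if future.succeeded():
    print(future.value)
else:
    print('request failed:', future.exception)
```

Passing request_timeout_ms=None (the default) keeps the old behavior of using the connection's request_timeout_ms configuration.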
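The timeout bookkeeping can also be seen in isolation. The snippet below re-implements the patch's (future, sent_time, timeout_at) tuple handling over plain data to show why requests_timed_out() becomes true as soon as the earliest timeout_at passes; the names and the 5 ms value are illustrative, not library code:

```python
# Standalone illustration of the new in-flight-request bookkeeping: each
# entry is (future, sent_time, timeout_at) and timeout checks key off the
# third element rather than a single connection-wide request_timeout_ms.
import time

in_flight_requests = {}

def add_request(correlation_id, future, request_timeout_ms):
    sent_time = time.time()
    timeout_at = sent_time + request_timeout_ms / 1000
    in_flight_requests[correlation_id] = (future, sent_time, timeout_at)

def next_ifr_request_timeout_ms():
    # Milliseconds until the earliest timeout_at; 0 means a request expired.
    if not in_flight_requests:
        return float('inf')
    next_timeout = min(ifr[2] for ifr in in_flight_requests.values())
    return max(0, (next_timeout - time.time()) * 1000)

def timed_out_ifrs():
    now = time.time()
    return [ifr for ifr in in_flight_requests.values() if ifr[2] <= now]

add_request(1, 'future-1', 30000)  # default-style 30s timeout
add_request(2, 'future-2', 5)      # aggressive 5 ms per-request override
time.sleep(0.01)

print(next_ifr_request_timeout_ms())  # 0 -> requests_timed_out() would be True
print(timed_out_ifrs())               # only the 5 ms request has expired
```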