Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions kafka/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,15 +542,21 @@ def _can_send_request(self, node_id):
return False
return conn.connected() and conn.can_send_more()

def send(self, node_id, request, wakeup=True):
def send(self, node_id, request, wakeup=True, request_timeout_ms=None):
"""Send a request to a specific node. Bytes are placed on an
internal per-connection send-queue. Actual network I/O will be
triggered in a subsequent call to .poll()

Arguments:
node_id (int): destination node
request (Struct): request object (not-encoded)
wakeup (bool): optional flag to disable thread-wakeup

Keyword Arguments:
wakeup (bool, optional): optional flag to disable thread-wakeup.
request_timeout_ms (int, optional): Provide custom timeout in milliseconds.
If response is not processed before timeout, client will fail the
request and close the connection.
Default: None (uses value from client configuration)

Raises:
AssertionError: if node_id is not in current cluster metadata
Expand All @@ -566,7 +572,7 @@ def send(self, node_id, request, wakeup=True):
# conn.send will queue the request internally
# we will need to call send_pending_requests()
# to trigger network I/O
future = conn.send(request, blocking=False)
future = conn.send(request, blocking=False, request_timeout_ms=request_timeout_ms)
if not future.is_done:
self._sending.add(conn)

Expand Down Expand Up @@ -725,11 +731,13 @@ def _poll(self, timeout):

for conn in six.itervalues(self._conns):
if conn.requests_timed_out():
timed_out = conn.timed_out_ifrs()
timeout_ms = (timed_out[0][2] - timed_out[0][1]) * 1000
log.warning('%s timed out after %s ms. Closing connection.',
conn, conn.config['request_timeout_ms'])
conn, timeout_ms)
conn.close(error=Errors.RequestTimedOutError(
'Request timed out after %s ms' %
conn.config['request_timeout_ms']))
timeout_ms))

if self._sensors:
self._sensors.io_time.record((time.time() - end_select) * 1000000000)
Expand Down
48 changes: 34 additions & 14 deletions kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,26 +948,38 @@ def close(self, error=None):
# drop lock before state change callback and processing futures
self.config['state_change_callback'](self.node_id, sock, self)
sock.close()
for (_correlation_id, (future, _timestamp)) in ifrs:
for (_correlation_id, (future, _timestamp, _timeout)) in ifrs:
future.failure(error)

def _can_send_recv(self):
"""Return True iff socket is ready for requests / responses"""
return self.state in (ConnectionStates.AUTHENTICATING,
ConnectionStates.CONNECTED)

def send(self, request, blocking=True):
"""Queue request for async network send, return Future()"""
def send(self, request, blocking=True, request_timeout_ms=None):
"""Queue request for async network send, return Future()

Arguments:
request (Request): kafka protocol request object to send.

Keyword Arguments:
blocking (bool, optional): Whether to immediately send via
blocking socket I/O. Default: True.
request_timeout_ms: Custom timeout in milliseconds for request.
Default: None (uses value from connection configuration)

Returns: future
"""
future = Future()
if self.connecting():
return future.failure(Errors.NodeNotReadyError(str(self)))
elif not self.connected():
return future.failure(Errors.KafkaConnectionError(str(self)))
elif not self.can_send_more():
return future.failure(Errors.TooManyInFlightRequests(str(self)))
return self._send(request, blocking=blocking)
return self._send(request, blocking=blocking, request_timeout_ms=request_timeout_ms)

def _send(self, request, blocking=True):
def _send(self, request, blocking=True, request_timeout_ms=None):
future = Future()
with self._lock:
if not self._can_send_recv():
Expand All @@ -980,9 +992,11 @@ def _send(self, request, blocking=True):

log.debug('%s Request %d: %s', self, correlation_id, request)
if request.expect_response():
sent_time = time.time()
assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!'
self.in_flight_requests[correlation_id] = (future, sent_time)
sent_time = time.time()
request_timeout_ms = request_timeout_ms or self.config['request_timeout_ms']
timeout_at = sent_time + (request_timeout_ms / 1000)
self.in_flight_requests[correlation_id] = (future, sent_time, timeout_at)
else:
future.success(None)

Expand Down Expand Up @@ -1061,18 +1075,20 @@ def recv(self):
"""
responses = self._recv()
if not responses and self.requests_timed_out():
timed_out = self.timed_out_ifrs()
timeout_ms = (timed_out[0][2] - timed_out[0][1]) * 1000
log.warning('%s timed out after %s ms. Closing connection.',
self, self.config['request_timeout_ms'])
self, timeout_ms)
self.close(error=Errors.RequestTimedOutError(
'Request timed out after %s ms' %
self.config['request_timeout_ms']))
timeout_ms))
return ()

# augment responses w/ correlation_id, future, and timestamp
for i, (correlation_id, response) in enumerate(responses):
try:
with self._lock:
(future, timestamp) = self.in_flight_requests.pop(correlation_id)
(future, timestamp, _timeout) = self.in_flight_requests.pop(correlation_id)
except KeyError:
self.close(Errors.KafkaConnectionError('Received unrecognized correlation id'))
return ()
Expand Down Expand Up @@ -1143,13 +1159,17 @@ def _recv(self):
def requests_timed_out(self):
return self.next_ifr_request_timeout_ms() == 0

def timed_out_ifrs(self):
now = time.time()
ifrs = sorted(self.in_flight_requests.values(), reverse=True, key=lambda ifr: ifr[2])
return list(filter(lambda ifr: ifr[2] <= now, ifrs))

def next_ifr_request_timeout_ms(self):
with self._lock:
if self.in_flight_requests:
get_timestamp = lambda v: v[1]
oldest_at = min(map(get_timestamp,
self.in_flight_requests.values()))
next_timeout = oldest_at + self.config['request_timeout_ms'] / 1000.0
get_timeout = lambda v: v[2]
next_timeout = min(map(get_timeout,
self.in_flight_requests.values()))
return max(0, (next_timeout - time.time()) * 1000)
else:
return float('inf')
Expand Down
6 changes: 3 additions & 3 deletions test/test_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_bootstrap(mocker, conn):
kwargs.pop('state_change_callback')
kwargs.pop('node_id')
assert kwargs == cli.config
conn.send.assert_called_once_with(MetadataRequest[0]([]), blocking=False)
conn.send.assert_called_once_with(MetadataRequest[0]([]), blocking=False, request_timeout_ms=None)
assert cli._bootstrap_fails == 0
assert cli.cluster.brokers() == set([BrokerMetadata(0, 'foo', 12, None),
BrokerMetadata(1, 'bar', 34, None)])
Expand Down Expand Up @@ -220,12 +220,12 @@ def test_send(cli, conn):
request = ProduceRequest[0](0, 0, [])
assert request.expect_response() is False
ret = cli.send(0, request)
conn.send.assert_called_with(request, blocking=False)
conn.send.assert_called_with(request, blocking=False, request_timeout_ms=None)
assert isinstance(ret, Future)

request = MetadataRequest[0]([])
cli.send(0, request)
conn.send.assert_called_with(request, blocking=False)
conn.send.assert_called_with(request, blocking=False, request_timeout_ms=None)


def test_poll(mocker):
Expand Down
6 changes: 3 additions & 3 deletions test/test_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,14 +347,14 @@ def test_requests_timed_out(conn):
# No in-flight requests, not timed out
assert not conn.requests_timed_out()

# Single request, timestamp = now (0)
conn.in_flight_requests[0] = ('foo', 0)
# Single request, timeout_at > now (0)
conn.in_flight_requests[0] = ('foo', 0, 1)
assert not conn.requests_timed_out()

# Add another request w/ timestamp > request_timeout ago
request_timeout = conn.config['request_timeout_ms']
expired_timestamp = 0 - request_timeout - 1
conn.in_flight_requests[1] = ('bar', expired_timestamp)
conn.in_flight_requests[1] = ('bar', 0, expired_timestamp)
assert conn.requests_timed_out()

# Drop the expired request and we should be good to go again
Expand Down