Skip to content

Implement client-side connection throttling / KIP-219 #2510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions kafka/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,16 @@ def connection_delay(self, node_id):
return 0
return conn.connection_delay()

def throttle_delay(self, node_id):
"""
Return the number of milliseconds to wait until a broker is no longer throttled.
When disconnected / connecting, returns 0.
"""
conn = self._conns.get(node_id)
if conn is None:
return 0
return conn.throttle_delay()

def is_ready(self, node_id, metadata_priority=True):
"""Check whether a node is ready to send more requests.

Expand Down Expand Up @@ -793,16 +803,17 @@ def _fire_pending_completed_requests(self):
break
future.success(response)
responses.append(response)

return responses

def least_loaded_node(self):
"""Choose the node with fewest outstanding requests, with fallbacks.

This method will prefer a node with an existing connection and no
in-flight-requests. If no such node is found, a node will be chosen
randomly from disconnected nodes that are not "blacked out" (i.e.,
This method will prefer a node with an existing connection (not throttled)
with no in-flight-requests. If no such node is found, a node will be chosen
randomly from all nodes that are not throttled or "blacked out" (i.e.,
are not subject to a reconnect backoff). If no node metadata has been
obtained, will return a bootstrap node (subject to exponential backoff).
obtained, will return a bootstrap node.

Returns:
node_id or None if no suitable node was found
Expand All @@ -814,11 +825,11 @@ def least_loaded_node(self):
found = None
for node_id in nodes:
conn = self._conns.get(node_id)
connected = conn is not None and conn.connected()
blacked_out = conn is not None and conn.blacked_out()
connected = conn is not None and conn.connected() and conn.can_send_more()
blacked_out = conn is not None and (conn.blacked_out() or conn.throttled())
curr_inflight = len(conn.in_flight_requests) if conn is not None else 0
if connected and curr_inflight == 0:
# if we find an established connection
# if we find an established connection (not throttled)
# with no in-flight requests, we can stop right away
return node_id
elif not blacked_out and curr_inflight < inflight:
Expand All @@ -828,16 +839,23 @@ def least_loaded_node(self):

return found

def _refresh_delay_ms(self, node_id):
conn = self._conns.get(node_id)
if conn is not None and conn.connected():
return self.throttle_delay(node_id)
else:
return self.connection_delay(node_id)

def least_loaded_node_refresh_ms(self):
"""Return connection delay in milliseconds for next available node.
"""Return connection or throttle delay in milliseconds for next available node.

This method is used primarily for retry/backoff during metadata refresh
during / after a cluster outage, in which there are no available nodes.

Returns:
float: delay_ms
"""
return min([self.connection_delay(broker.nodeId) for broker in self.cluster.brokers()])
return min([self._refresh_delay_ms(broker.nodeId) for broker in self.cluster.brokers()])

def set_topics(self, topics):
"""Set specific topics to track for metadata.
Expand Down Expand Up @@ -915,8 +933,8 @@ def _maybe_refresh_metadata(self, wakeup=False):
# Connection attempt failed immediately, need to retry with a different node
return self.config['reconnect_backoff_ms']
else:
# Existing connection with max in flight requests. Wait for request to complete.
return self.config['request_timeout_ms']
# Existing connection throttled or max in flight requests.
return self.throttle_delay(node_id) or self.config['request_timeout_ms']

# Recheck node_id in case we were able to connect immediately above
if self._can_send_request(node_id):
Expand Down
3 changes: 0 additions & 3 deletions kafka/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,6 @@ def update_metadata(self, metadata):

Returns: None
"""
if metadata.API_VERSION >= 3 and metadata.throttle_time_ms > 0:
log.warning("MetadataRequest throttled by broker (%d ms)", metadata.throttle_time_ms)

# In the common case where we ask for a single topic and get back an
# error, we should fail the future
if len(metadata.topics) == 1 and metadata.topics[0][0] != Errors.NoError.errno:
Expand Down
70 changes: 69 additions & 1 deletion kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def __init__(self, host, port, afi, **configs):
self._sock_afi = afi
self._sock_addr = None
self._api_versions = None
self._throttle_time = None

self.config = copy.copy(self.DEFAULT_CONFIG)
for key in self.config:
Expand Down Expand Up @@ -851,6 +852,27 @@ def blacked_out(self):
return self.connection_delay() > 0
return False

def throttled(self):
"""
Return True if we are connected but currently throttled.
"""
if self.state is not ConnectionStates.CONNECTED:
return False
return self.throttle_delay() > 0

def throttle_delay(self):
"""
Return the number of milliseconds to wait until connection is no longer throttled.
"""
if self._throttle_time is not None:
remaining_ms = (self._throttle_time - time.time()) * 1000
if remaining_ms > 0:
return remaining_ms
else:
self._throttle_time = None
return 0
return 0

def connection_delay(self):
"""
Return the number of milliseconds to wait, based on the connection
Expand Down Expand Up @@ -976,6 +998,9 @@ def send(self, request, blocking=True, request_timeout_ms=None):
elif not self.connected():
return future.failure(Errors.KafkaConnectionError(str(self)))
elif not self.can_send_more():
# very small race here, but prefer it over breaking abstraction to check self._throttle_time
if self.throttled():
return future.failure(Errors.ThrottlingQuotaExceededError(str(self)))
return future.failure(Errors.TooManyInFlightRequests(str(self)))
return self._send(request, blocking=blocking, request_timeout_ms=request_timeout_ms)

Expand Down Expand Up @@ -1063,8 +1088,26 @@ def send_pending_requests_v2(self):
self.close(error=error)
return False

def _maybe_throttle(self, response):
throttle_time_ms = getattr(response, 'throttle_time_ms', 0)
if self._sensors:
self._sensors.throttle_time.record(throttle_time_ms)
if not throttle_time_ms:
if self._throttle_time is not None:
self._throttle_time = None
return
# Client side throttling enabled in v2.0 brokers
# prior to that throttling (if present) was managed broker-side
if self.config['api_version'] is not None and self.config['api_version'] >= (2, 0):
throttle_time = time.time() + throttle_time_ms / 1000
self._throttle_time = max(throttle_time, self._throttle_time or 0)
log.warning("%s: %s throttled by broker (%d ms)", self,
response.__class__.__name__, throttle_time_ms)

def can_send_more(self):
"""Return True unless there are max_in_flight_requests_per_connection."""
"""Check for throttling / quota violations and max in-flight-requests"""
if self.throttle_delay() > 0:
return False
max_ifrs = self.config['max_in_flight_requests_per_connection']
return len(self.in_flight_requests) < max_ifrs

Expand Down Expand Up @@ -1097,6 +1140,7 @@ def recv(self):
self._sensors.request_time.record(latency_ms)

log.debug('%s Response %d (%s ms): %s', self, correlation_id, latency_ms, response)
self._maybe_throttle(response)
responses[i] = (response, future)

return responses
Expand Down Expand Up @@ -1399,6 +1443,16 @@ def __init__(self, metrics, metric_group_prefix, node_id):
'The maximum request latency in ms.'),
Max())

throttle_time = metrics.sensor('throttle-time')
throttle_time.add(metrics.metric_name(
'throttle-time-avg', metric_group_name,
'The average throttle time in ms.'),
Avg())
throttle_time.add(metrics.metric_name(
'throttle-time-max', metric_group_name,
'The maximum throttle time in ms.'),
Max())

# if one sensor of the metrics has been registered for the connection,
# then all other sensors should have been registered; and vice versa
node_str = 'node-{0}'.format(node_id)
Expand Down Expand Up @@ -1450,9 +1504,23 @@ def __init__(self, metrics, metric_group_prefix, node_id):
'The maximum request latency in ms.'),
Max())

throttle_time = metrics.sensor(
node_str + '.throttle',
parents=[metrics.get_sensor('throttle-time')])
throttle_time.add(metrics.metric_name(
'throttle-time-avg', metric_group_name,
'The average throttle time in ms.'),
Avg())
throttle_time.add(metrics.metric_name(
'throttle-time-max', metric_group_name,
'The maximum throttle time in ms.'),
Max())


self.bytes_sent = metrics.sensor(node_str + '.bytes-sent')
self.bytes_received = metrics.sensor(node_str + '.bytes-received')
self.request_time = metrics.sensor(node_str + '.latency')
self.throttle_time = metrics.sensor(node_str + '.throttle')


def _address_family(address):
Expand Down
23 changes: 12 additions & 11 deletions kafka/consumer/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,8 +614,6 @@ def _handle_list_offsets_response(self, future, response):
Raises:
AssertionError: if response does not match partition
"""
if response.API_VERSION >= 2 and response.throttle_time_ms > 0:
log.warning("ListOffsetsRequest throttled by broker (%d ms)", response.throttle_time_ms)
timestamp_offset_map = {}
for topic, part_data in response.topics:
for partition_info in part_data:
Expand Down Expand Up @@ -688,7 +686,7 @@ def _create_fetch_requests(self):
"""
# create the fetch info as a dict of lists of partition info tuples
# which can be passed to FetchRequest() via .items()
version = self._client.api_version(FetchRequest, max_version=7)
version = self._client.api_version(FetchRequest, max_version=8)
fetchable = collections.defaultdict(dict)

for partition in self._fetchable_partitions():
Expand Down Expand Up @@ -816,8 +814,6 @@ def _handle_fetch_response(self, node_id, fetch_offsets, send_time, response):
)
self._completed_fetches.append(completed_fetch)

if response.API_VERSION >= 1:
self._sensors.fetch_throttle_time_sensor.record(response.throttle_time_ms)
self._sensors.fetch_latency.record((time.time() - send_time) * 1000)

def _handle_fetch_error(self, node_id, exception):
Expand Down Expand Up @@ -1032,6 +1028,11 @@ def handle_response(self, response):
self.node_id, len(response_tps))
self.next_metadata = FetchMetadata.INITIAL
return True
elif response.session_id == FetchMetadata.THROTTLED_SESSION_ID:
log.debug("Node %s sent a empty full fetch response due to a quota violation (%s partitions)",
self.node_id, len(response_tps))
# Keep current metadata
return True
else:
# The server created a new incremental fetch session.
log.debug("Node %s sent a full fetch response that created a new incremental fetch session %s"
Expand All @@ -1054,6 +1055,11 @@ def handle_response(self, response):
len(response_tps), len(self.session_partitions) - len(response_tps))
self.next_metadata = FetchMetadata.INITIAL
return True
elif response.session_id == FetchMetadata.THROTTLED_SESSION_ID:
log.debug("Node %s sent a empty incremental fetch response due to a quota violation (%s partitions)",
self.node_id, len(response_tps))
# Keep current metadata
return True
else:
# The incremental fetch session was continued by the server.
log.debug("Node %s sent an incremental fetch response for session %s"
Expand All @@ -1077,6 +1083,7 @@ class FetchMetadata(object):

MAX_EPOCH = 2147483647
INVALID_SESSION_ID = 0 # used by clients with no session.
THROTTLED_SESSION_ID = -1 # returned with empty response on quota violation
INITIAL_EPOCH = 0 # client wants to create or recreate a session.
FINAL_EPOCH = -1 # client wants to close any existing session, and not create a new one.

Expand Down Expand Up @@ -1217,12 +1224,6 @@ def __init__(self, metrics, prefix):
self.records_fetch_lag.add(metrics.metric_name('records-lag-max', self.group_name,
'The maximum lag in terms of number of records for any partition in self window'), Max())

self.fetch_throttle_time_sensor = metrics.sensor('fetch-throttle-time')
self.fetch_throttle_time_sensor.add(metrics.metric_name('fetch-throttle-time-avg', self.group_name,
'The average throttle time in ms'), Avg())
self.fetch_throttle_time_sensor.add(metrics.metric_name('fetch-throttle-time-max', self.group_name,
'The maximum throttle time in ms'), Max())

def record_topic_fetch_metrics(self, topic, num_bytes, num_records):
# record bytes fetched
name = '.'.join(['topic', topic, 'bytes-fetched'])
Expand Down
31 changes: 0 additions & 31 deletions kafka/coordinator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,11 +488,6 @@ def _failed_request(self, node_id, request, future, error):
future.failure(error)

def _handle_join_group_response(self, future, send_time, response):
if response.API_VERSION >= 2:
self.sensors.throttle_time.record(response.throttle_time_ms)
if response.throttle_time_ms > 0:
log.warning("JoinGroupRequest throttled by broker (%d ms)", response.throttle_time_ms)

error_type = Errors.for_code(response.error_code)
if error_type is Errors.NoError:
log.debug("Received successful JoinGroup response for group %s: %s",
Expand Down Expand Up @@ -614,11 +609,6 @@ def _send_sync_group_request(self, request):
return future

def _handle_sync_group_response(self, future, send_time, response):
if response.API_VERSION >= 1:
self.sensors.throttle_time.record(response.throttle_time_ms)
if response.throttle_time_ms > 0:
log.warning("SyncGroupRequest throttled by broker (%d ms)", response.throttle_time_ms)

error_type = Errors.for_code(response.error_code)
if error_type is Errors.NoError:
self.sensors.sync_latency.record((time.time() - send_time) * 1000)
Expand Down Expand Up @@ -678,9 +668,6 @@ def _send_group_coordinator_request(self):
return future

def _handle_group_coordinator_response(self, future, response):
if response.API_VERSION >= 1 and response.throttle_time_ms > 0:
log.warning("FindCoordinatorRequest throttled by broker (%d ms)", response.throttle_time_ms)

log.debug("Received group coordinator response %s", response)

error_type = Errors.for_code(response.error_code)
Expand Down Expand Up @@ -785,11 +772,6 @@ def maybe_leave_group(self):
self.reset_generation()

def _handle_leave_group_response(self, response):
if response.API_VERSION >= 1:
self.sensors.throttle_time.record(response.throttle_time_ms)
if response.throttle_time_ms > 0:
log.warning("LeaveGroupRequest throttled by broker (%d ms)", response.throttle_time_ms)

error_type = Errors.for_code(response.error_code)
if error_type is Errors.NoError:
log.debug("LeaveGroup request for group %s returned successfully",
Expand Down Expand Up @@ -821,11 +803,6 @@ def _send_heartbeat_request(self):
return future

def _handle_heartbeat_response(self, future, send_time, response):
if response.API_VERSION >= 1:
self.sensors.throttle_time.record(response.throttle_time_ms)
if response.throttle_time_ms > 0:
log.warning("HeartbeatRequest throttled by broker (%d ms)", response.throttle_time_ms)

self.sensors.heartbeat_latency.record((time.time() - send_time) * 1000)
error_type = Errors.for_code(response.error_code)
if error_type is Errors.NoError:
Expand Down Expand Up @@ -914,14 +891,6 @@ def __init__(self, heartbeat, metrics, prefix, tags=None):
tags), AnonMeasurable(
lambda _, now: (now / 1000) - self.heartbeat.last_send))

self.throttle_time = metrics.sensor('throttle-time')
self.throttle_time.add(metrics.metric_name(
'throttle-time-avg', self.metric_group_name,
'The average throttle time in ms'), Avg())
self.throttle_time.add(metrics.metric_name(
'throttle-time-max', self.metric_group_name,
'The maximum throttle time in ms'), Max())


class HeartbeatThread(threading.Thread):
def __init__(self, coordinator):
Expand Down
6 changes: 0 additions & 6 deletions kafka/coordinator/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,9 +665,6 @@ def _send_offset_commit_request(self, offsets):
return future

def _handle_offset_commit_response(self, offsets, future, send_time, response):
if response.API_VERSION >= 3 and response.throttle_time_ms > 0:
log.warning("OffsetCommitRequest throttled by broker (%d ms)", response.throttle_time_ms)

# TODO look at adding request_latency_ms to response (like java kafka)
self.consumer_sensors.commit_latency.record((time.time() - send_time) * 1000)
unauthorized_topics = set()
Expand Down Expand Up @@ -785,9 +782,6 @@ def _send_offset_fetch_request(self, partitions):
return future

def _handle_offset_fetch_response(self, future, response):
if response.API_VERSION >= 3 and response.throttle_time_ms > 0:
log.warning("OffsetFetchRequest throttled by broker (%d ms)", response.throttle_time_ms)

if response.API_VERSION >= 2 and response.error_code != Errors.NoError.errno:
error_type = Errors.for_code(response.error_code)
log.debug("Offset fetch failed: %s", error_type.__name__)
Expand Down
Loading