From a4b439141b2ef35951e46716696e4c01bb88661c Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Sun, 29 Mar 2015 16:45:20 -0700 Subject: [PATCH 1/2] Rollover KafkaClient correlation ids at 2**31 to keep within int32 protocol encoding --- kafka/client.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/kafka/client.py b/kafka/client.py index 48a534e41..c36cd0848 100644 --- a/kafka/client.py +++ b/kafka/client.py @@ -2,7 +2,6 @@ import collections import copy import functools -import itertools import logging import time import kafka.common @@ -23,17 +22,18 @@ class KafkaClient(object): CLIENT_ID = b"kafka-python" - ID_GEN = itertools.count() # NOTE: The timeout given to the client should always be greater than the # one passed to SimpleConsumer.get_message(), otherwise you can get a # socket timeout. def __init__(self, hosts, client_id=CLIENT_ID, - timeout=DEFAULT_SOCKET_TIMEOUT_SECONDS): + timeout=DEFAULT_SOCKET_TIMEOUT_SECONDS, + correlation_id=0): # We need one connection to bootstrap self.client_id = kafka_bytestring(client_id) self.timeout = timeout self.hosts = collect_hosts(hosts) + self.correlation_id = correlation_id # create connections only when we need them self.conns = {} @@ -98,10 +98,10 @@ def _get_leader_for_partition(self, topic, partition): return self.brokers[meta.leader] def _next_id(self): - """ - Generate a new correlation id - """ - return next(KafkaClient.ID_GEN) + """Generate a new correlation id""" + # modulo to keep w/i int32 + self.correlation_id = (self.correlation_id + 1) % 2**31 + return self.correlation_id def _send_broker_unaware_request(self, payloads, encoder_fn, decoder_fn): """ From 1313388662d509ade01f71d0740cd0efe263c01f Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Sun, 29 Mar 2015 18:52:12 -0700 Subject: [PATCH 2/2] Add test for correlation_id rollover --- test/test_client.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/test_client.py b/test/test_client.py index c522d9a53..abda421c6 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -401,3 +401,11 @@ def _timeout(*args, **kwargs): with self.assertRaises(ConnectionError): KafkaConnection("nowhere", 1234, 1.0) self.assertGreaterEqual(t.interval, 1.0) + + def test_correlation_rollover(self): + with patch.object(KafkaClient, 'load_metadata_for_topics'): + big_num = 2**31 - 3 + client = KafkaClient(hosts=[], correlation_id=big_num) + self.assertEqual(big_num + 1, client._next_id()) + self.assertEqual(big_num + 2, client._next_id()) + self.assertEqual(0, client._next_id())