7 changes: 4 additions & 3 deletions kafka/conn.py
@@ -594,7 +594,8 @@ def _handle_api_versions_response(self, future, response):
future.failure(error_type())
if error_type is Errors.UnsupportedVersionError:
self._api_versions_idx -= 1
for api_key, min_version, max_version, *rest in response.api_versions:
for api_version_data in response.api_versions:
api_key, min_version, max_version = api_version_data[:3]
# If broker provides a lower max_version, skip to that
if api_key == response.API_KEY:
self._api_versions_idx = min(self._api_versions_idx, max_version)
@@ -607,8 +608,8 @@ def _handle_api_versions_response(self, future, response):
self.close(error=error_type())
return
self._api_versions = dict([
(api_key, (min_version, max_version))
for api_key, min_version, max_version, *rest in response.api_versions
(api_version_data[0], (api_version_data[1], api_version_data[2]))
for api_version_data in response.api_versions
])
self._api_version = self._infer_broker_version_from_api_versions(self._api_versions)
log.info('%s: Broker version identified as %s', self, '.'.join(map(str, self._api_version)))
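Note on the kafka/conn.py change: extended tuple unpacking (`api_key, min_version, max_version, *rest = ...`) is a SyntaxError on Python 2, so the hunk switches to slicing, which behaves the same on both interpreters. A minimal illustrative sketch, with an assumed entry layout that is not part of the patch:

entry = (18, 0, 3, 'extra-field')              # hypothetical ApiVersions entry
api_key, min_version, max_version = entry[:3]  # slicing works on Python 2 and 3
assert (api_key, min_version, max_version) == (18, 0, 3)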
3 changes: 3 additions & 0 deletions kafka/consumer/fetcher.py
@@ -867,6 +867,9 @@ def _maybe_skip_record(self, record):
def __bool__(self):
return self.record_iterator is not None

# py2
__nonzero__ = __bool__

def drain(self):
if self.record_iterator is not None:
self.record_iterator = None
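Note on the kafka/consumer/fetcher.py change: Python 2 calls `__nonzero__` for truth testing while Python 3 calls `__bool__`, so the alias keeps `if result:` working under both. A small stand-alone sketch of the idiom (the class name here is illustrative, not the library's):

class Result(object):
    def __init__(self, record_iterator=None):
        self.record_iterator = record_iterator

    def __bool__(self):            # Python 3 truth hook
        return self.record_iterator is not None

    __nonzero__ = __bool__         # Python 2 truth hook

assert not Result()
assert Result(record_iterator=iter([]))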
6 changes: 5 additions & 1 deletion kafka/consumer/subscription_state.py
@@ -381,7 +381,11 @@ def resume(self, partition):

def move_partition_to_end(self, partition):
if partition in self.assignment:
self.assignment.move_to_end(partition)
try:
self.assignment.move_to_end(partition)
except AttributeError:
state = self.assignment.pop(partition)
self.assignment[partition] = state


class TopicPartitionState(object):
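Note on the kafka/consumer/subscription_state.py change: `OrderedDict.move_to_end` only exists on Python 3.2+, so the patch falls back to pop-and-reinsert, which yields the same ordering. A minimal sketch of the equivalence, using toy keys rather than library objects:

from collections import OrderedDict

assignment = OrderedDict([('p0', 'a'), ('p1', 'b'), ('p2', 'c')])
state = assignment.pop('p0')   # remove from its current position...
assignment['p0'] = state       # ...and re-append, i.e. move_to_end('p0')
assert list(assignment) == ['p1', 'p2', 'p0']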
5 changes: 5 additions & 0 deletions test/record/test_default_records.py
@@ -11,6 +11,8 @@
)
from kafka.errors import UnsupportedCodecError

from test.testutil import maybe_skip_unsupported_compression


@pytest.mark.parametrize("compression_type", [
DefaultRecordBatch.CODEC_NONE,
@@ -19,6 +21,7 @@
DefaultRecordBatch.CODEC_LZ4
])
def test_read_write_serde_v2(compression_type):
maybe_skip_unsupported_compression(compression_type)
builder = DefaultRecordBatchBuilder(
magic=2, compression_type=compression_type, is_transactional=1,
producer_id=123456, producer_epoch=123, base_sequence=9999,
@@ -186,6 +189,8 @@ def test_default_batch_size_limit():
])
@pytest.mark.parametrize("magic", [0, 1])
def test_unavailable_codec(magic, compression_type, name, checker_name):
if not getattr(kafka.codec, checker_name)():
pytest.skip('%s compression_type not installed' % (compression_type,))
builder = DefaultRecordBatchBuilder(
magic=2, compression_type=compression_type, is_transactional=0,
producer_id=-1, producer_epoch=-1, base_sequence=-1,
4 changes: 4 additions & 0 deletions test/record/test_legacy_records.py
@@ -10,6 +10,8 @@
import kafka.codec
from kafka.errors import UnsupportedCodecError

from test.testutil import maybe_skip_unsupported_compression


@pytest.mark.parametrize("magic", [0, 1])
def test_read_write_serde_v0_v1_no_compression(magic):
@@ -39,6 +41,7 @@ def test_read_write_serde_v0_v1_no_compression(magic):
])
@pytest.mark.parametrize("magic", [0, 1])
def test_read_write_serde_v0_v1_with_compression(compression_type, magic):
maybe_skip_unsupported_compression(compression_type)
builder = LegacyRecordBatchBuilder(
magic=magic, compression_type=compression_type, batch_size=9999999)
for offset in range(10):
@@ -179,6 +182,7 @@ def test_legacy_batch_size_limit(magic):
])
@pytest.mark.parametrize("magic", [0, 1])
def test_unavailable_codec(magic, compression_type, name, checker_name):
maybe_skip_unsupported_compression(compression_type)
builder = LegacyRecordBatchBuilder(
magic=magic, compression_type=compression_type, batch_size=1024)
builder.append(0, timestamp=None, key=None, value=b"M")
3 changes: 3 additions & 0 deletions test/record/test_records.py
@@ -4,6 +4,8 @@
from kafka.record import MemoryRecords, MemoryRecordsBuilder
from kafka.errors import CorruptRecordException

from test.testutil import maybe_skip_unsupported_compression

# This is real live data from Kafka 11 broker
record_batch_data_v2 = [
# First Batch value == "123"
@@ -179,6 +181,7 @@ def test_memory_records_corrupt():
@pytest.mark.parametrize("compression_type", [0, 1, 2, 3])
@pytest.mark.parametrize("magic", [0, 1, 2])
def test_memory_records_builder(magic, compression_type):
maybe_skip_unsupported_compression(compression_type)
builder = MemoryRecordsBuilder(
magic=magic, compression_type=compression_type, batch_size=1024 * 10)
base_size = builder.size_in_bytes() # V2 has a header before
5 changes: 4 additions & 1 deletion test/test_conn.py
@@ -11,6 +11,7 @@
import pytest

from kafka.conn import BrokerConnection, ConnectionStates, collect_hosts
from kafka.future import Future
from kafka.protocol.api import RequestHeader
from kafka.protocol.group import HeartbeatResponse
from kafka.protocol.metadata import MetadataRequest
@@ -69,8 +70,10 @@ def test_connect(_socket, conn, states):
assert conn.state is state


def test_api_versions_check(_socket):
def test_api_versions_check(_socket, mocker):
conn = BrokerConnection('localhost', 9092, socket.AF_INET)
mocker.patch.object(conn, '_send', return_value=Future())
mocker.patch.object(conn, 'recv', return_value=[])
assert conn._api_versions_future is None
conn.connect()
assert conn._api_versions_future is not None
3 changes: 2 additions & 1 deletion test/test_consumer_group.py
@@ -68,7 +68,8 @@ def consumer_thread(i):

num_consumers = 4
for i in range(num_consumers):
t = threading.Thread(target=consumer_thread, args=(i,), daemon=True)
t = threading.Thread(target=consumer_thread, args=(i,))
t.daemon = True
t.start()
threads[i] = t

24 changes: 19 additions & 5 deletions test/test_coordinator.py
@@ -25,14 +25,25 @@
from kafka.structs import OffsetAndMetadata, TopicPartition
from kafka.util import WeakMethod


@pytest.fixture
def client(conn):
return KafkaClient(api_version=(0, 9))
def client(conn, mocker):
cli = KafkaClient(api_version=(0, 9))
mocker.patch.object(cli, '_init_connect', return_value=True)
try:
yield cli
finally:
cli._close()

@pytest.fixture
def coordinator(client):
return ConsumerCoordinator(client, SubscriptionState(), Metrics())
def coordinator(client, mocker):
metrics = Metrics()
coord = ConsumerCoordinator(client, SubscriptionState(), metrics)
try:
yield coord
finally:
mocker.patch.object(coord, 'coordinator_unknown', return_value=True) # avoid attempting to leave group during close()
coord.close(timeout_ms=0)
metrics.close()


def test_init(client, coordinator):
@@ -55,6 +66,7 @@ def test_autocommit_enable_api_version(conn, api_version):
assert coordinator.config['enable_auto_commit'] is False
else:
assert coordinator.config['enable_auto_commit'] is True
coordinator.close()


def test_protocol_type(coordinator):
@@ -117,6 +129,7 @@ def test_pattern_subscription(conn, api_version):
else:
assert set(coordinator._subscription.assignment.keys()) == {TopicPartition('foo1', 0),
TopicPartition('foo2', 0)}
coordinator.close()


def test_lookup_assignor(coordinator):
@@ -398,6 +411,7 @@ def test_maybe_auto_commit_offsets_sync(mocker, api_version, group_id, enable,
assert commit_sync.call_count == (1 if commit_offsets else 0)
assert mock_warn.call_count == (1 if warn else 0)
assert mock_exc.call_count == (1 if exc else 0)
coordinator.close()


@pytest.fixture
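Note on the test/test_coordinator.py change: the client and coordinator fixtures become yield fixtures so that close/cleanup runs after every test, pass or fail. A generic sketch of that pytest pattern, with purely illustrative names:

import pytest

class Resource(object):
    def close(self):
        self.closed = True

@pytest.fixture
def resource():
    res = Resource()
    try:
        yield res        # the test body runs here
    finally:
        res.close()      # teardown always runs, even on failure

def test_uses_resource(resource):
    assert not getattr(resource, 'closed', False)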
4 changes: 3 additions & 1 deletion test/test_producer.py
@@ -8,7 +8,7 @@

from kafka import KafkaConsumer, KafkaProducer, TopicPartition
from kafka.producer.buffer import SimpleBufferPool
from test.testutil import env_kafka_version, random_string
from test.testutil import env_kafka_version, random_string, maybe_skip_unsupported_compression


def test_buffer_pool():
@@ -44,6 +44,7 @@ def consumer_factory(**kwargs):
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
@pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4', 'zstd'])
def test_end_to_end(kafka_broker, compression):
maybe_skip_unsupported_compression(compression)
if compression == 'lz4':
if env_kafka_version() < (0, 8, 2):
pytest.skip('LZ4 requires 0.8.2')
@@ -104,6 +105,7 @@ def test_kafka_producer_gc_cleanup():
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
@pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4', 'zstd'])
def test_kafka_producer_proper_record_metadata(kafka_broker, compression):
maybe_skip_unsupported_compression(compression)
if compression == 'zstd' and env_kafka_version() < (2, 1, 0):
pytest.skip('zstd requires 2.1.0 or more')
connect_str = ':'.join([kafka_broker.host, str(kafka_broker.port)])
4 changes: 2 additions & 2 deletions test/test_subscription_state.py
@@ -44,8 +44,8 @@ def test_assign_from_subscribed():

s.assign_from_subscribed([TopicPartition('foo', 0), TopicPartition('foo', 1)])
assert set(s.assignment.keys()) == set([TopicPartition('foo', 0), TopicPartition('foo', 1)])
assert all([isinstance(s, TopicPartitionState) for s in six.itervalues(s.assignment)])
assert all([not s.has_valid_position for s in six.itervalues(s.assignment)])
assert all([isinstance(tps, TopicPartitionState) for tps in six.itervalues(s.assignment)])
assert all([not tps.has_valid_position for tps in six.itervalues(s.assignment)])


def test_change_subscription_after_assignment():
16 changes: 16 additions & 0 deletions test/testutil.py
@@ -6,6 +6,10 @@
import string
import time

import pytest

import kafka.codec


def special_to_underscore(string, _matcher=re.compile(r'[^a-zA-Z0-9_]+')):
return _matcher.sub('_', string)
@@ -36,6 +40,18 @@ def assert_message_count(messages, num_messages):
assert len(unique_messages) == num_messages, 'Expected %d unique messages, got %d' % (num_messages, len(unique_messages))


def maybe_skip_unsupported_compression(compression_type):
codecs = {1: 'gzip', 2: 'snappy', 3: 'lz4', 4: 'zstd'}
if not compression_type:
return
elif compression_type in codecs:
compression_type = codecs[compression_type]

checker = getattr(kafka.codec, 'has_' + compression_type, None)
if checker and not checker():
pytest.skip("Compression libraries not installed for %s" % (compression_type,))


class Timer(object):
def __enter__(self):
self.start = time.time()