diff --git a/kafka/consumer/subscription_state.py b/kafka/consumer/subscription_state.py index 07a1a109d..77742109b 100644 --- a/kafka/consumer/subscription_state.py +++ b/kafka/consumer/subscription_state.py @@ -6,6 +6,12 @@ from collections import Sequence except ImportError: from collections.abc import Sequence +try: + # enum in stdlib as of py3.4 + from enum import IntEnum # pylint: disable=import-error +except ImportError: + # vendored backport module + from kafka.vendor.enum34 import IntEnum import logging import random import re @@ -20,6 +26,13 @@ log = logging.getLogger(__name__) +class SubscriptionType(IntEnum): + NONE = 0 + AUTO_TOPICS = 1 + AUTO_PATTERN = 2 + USER_ASSIGNED = 3 + + class SubscriptionState(object): """ A class for tracking the topics, partitions, and offsets for the consumer. @@ -67,6 +80,7 @@ def __init__(self, offset_reset_strategy='earliest'): self._default_offset_reset_strategy = offset_reset_strategy self.subscription = None # set() or None + self.subscription_type = SubscriptionType.NONE self.subscribed_pattern = None # regex str or None self._group_subscription = set() self._user_assignment = set() @@ -76,6 +90,14 @@ def __init__(self, offset_reset_strategy='earliest'): # initialize to true for the consumers to fetch offset upon starting up self.needs_fetch_committed_offsets = True + def _set_subscription_type(self, subscription_type): + if not isinstance(subscription_type, SubscriptionType): + raise ValueError('SubscriptionType enum required') + if self.subscription_type == SubscriptionType.NONE: + self.subscription_type = subscription_type + elif self.subscription_type != subscription_type: + raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) + def subscribe(self, topics=(), pattern=None, listener=None): """Subscribe to a list of topics, or a topic regex pattern. @@ -111,17 +133,19 @@ def subscribe(self, topics=(), pattern=None, listener=None): guaranteed, however, that the partitions revoked/assigned through this interface are from topics subscribed in this call. """ - if self._user_assignment or (topics and pattern): - raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) assert topics or pattern, 'Must provide topics or pattern' + if (topics and pattern): + raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) - if pattern: + elif pattern: + self._set_subscription_type(SubscriptionType.AUTO_PATTERN) log.info('Subscribing to pattern: /%s/', pattern) self.subscription = set() self.subscribed_pattern = re.compile(pattern) else: if isinstance(topics, str) or not isinstance(topics, Sequence): raise TypeError('Topics must be a list (or non-str sequence)') + self._set_subscription_type(SubscriptionType.AUTO_TOPICS) self.change_subscription(topics) if listener and not isinstance(listener, ConsumerRebalanceListener): @@ -141,7 +165,7 @@ def change_subscription(self, topics): - a topic name is '.' or '..' or - a topic name does not consist of ASCII-characters/'-'/'_'/'.' """ - if self._user_assignment: + if not self.partitions_auto_assigned(): raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) if isinstance(topics, six.string_types): @@ -168,13 +192,13 @@ def group_subscribe(self, topics): Arguments: topics (list of str): topics to add to the group subscription """ - if self._user_assignment: + if not self.partitions_auto_assigned(): raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) self._group_subscription.update(topics) def reset_group_subscription(self): """Reset the group's subscription to only contain topics subscribed by this consumer.""" - if self._user_assignment: + if not self.partitions_auto_assigned(): raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) assert self.subscription is not None, 'Subscription required' self._group_subscription.intersection_update(self.subscription) @@ -197,9 +221,7 @@ def assign_from_user(self, partitions): Raises: IllegalStateError: if consumer has already called subscribe() """ - if self.subscription is not None: - raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) - + self._set_subscription_type(SubscriptionType.USER_ASSIGNED) if self._user_assignment != set(partitions): self._user_assignment = set(partitions) self._set_assignment({partition: self.assignment.get(partition, TopicPartitionState()) @@ -250,6 +272,7 @@ def unsubscribe(self): self._user_assignment.clear() self.assignment.clear() self.subscribed_pattern = None + self.subscription_type = SubscriptionType.NONE def group_subscription(self): """Get the topic subscription for the group. @@ -300,7 +323,7 @@ def fetchable_partitions(self): def partitions_auto_assigned(self): """Return True unless user supplied partitions manually.""" - return self.subscription is not None + return self.subscription_type in (SubscriptionType.AUTO_TOPICS, SubscriptionType.AUTO_PATTERN) def all_consumed_offsets(self): """Returns consumed offsets as {TopicPartition: OffsetAndMetadata}""" diff --git a/test/test_consumer_integration.py b/test/test_consumer_integration.py index af8ec6829..b181845a4 100644 --- a/test/test_consumer_integration.py +++ b/test/test_consumer_integration.py @@ -68,8 +68,8 @@ def test_kafka_consumer_unsupported_encoding( def test_kafka_consumer__blocking(kafka_consumer_factory, topic, send_messages): TIMEOUT_MS = 500 consumer = kafka_consumer_factory(auto_offset_reset='earliest', - enable_auto_commit=False, - consumer_timeout_ms=TIMEOUT_MS) + enable_auto_commit=False, + consumer_timeout_ms=TIMEOUT_MS) # Manual assignment avoids overhead of consumer group mgmt consumer.unsubscribe() diff --git a/test/test_coordinator.py b/test/test_coordinator.py index 09422790e..35749f84d 100644 --- a/test/test_coordinator.py +++ b/test/test_coordinator.py @@ -189,6 +189,7 @@ def test_subscription_listener_failure(mocker, coordinator): def test_perform_assignment(mocker, coordinator): + coordinator._subscription.subscribe(topics=['foo1']) member_metadata = { 'member-foo': ConsumerProtocolMemberMetadata(0, ['foo1'], b''), 'member-bar': ConsumerProtocolMemberMetadata(0, ['foo1'], b'')