From b7d2f4fd3ba0b7ea78c9470a309a81910e9dd209 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Fri, 4 Apr 2025 14:16:07 -0700 Subject: [PATCH 01/27] Add transactional kwargs to MemoryRecordsBuilder --- kafka/record/default_records.py | 6 +++++- kafka/record/memory_records.py | 25 +++++++++++++++++-------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/kafka/record/default_records.py b/kafka/record/default_records.py index c8305c88e..465d3e3cf 100644 --- a/kafka/record/default_records.py +++ b/kafka/record/default_records.py @@ -493,10 +493,14 @@ def __init__( self._buffer = bytearray(self.HEADER_STRUCT.size) - def set_producer_state(self, producer_id, producer_epoch, base_sequence): + def set_producer_state(self, producer_id, producer_epoch, base_sequence, is_transactional): + assert not is_transactional or producer_id != -1, "Cannot write transactional messages without a valid producer ID" + assert producer_id == -1 or producer_epoch != -1, "Invalid negative producer epoch" + assert producer_id == -1 or base_sequence != -1, "Invalid negative sequence number" self._producer_id = producer_id self._producer_epoch = producer_epoch self._base_sequence = base_sequence + self._is_transactional = is_transactional @property def producer_id(self): diff --git a/kafka/record/memory_records.py b/kafka/record/memory_records.py index 77e38b9ed..c71b3bd4c 100644 --- a/kafka/record/memory_records.py +++ b/kafka/record/memory_records.py @@ -115,16 +115,23 @@ class MemoryRecordsBuilder(object): __slots__ = ("_builder", "_batch_size", "_buffer", "_next_offset", "_closed", "_magic", "_bytes_written", "_producer_id") - def __init__(self, magic, compression_type, batch_size, offset=0): + def __init__(self, magic, compression_type, batch_size, offset=0, + transactional=False, producer_id=-1, producer_epoch=-1, base_sequence=-1): assert magic in [0, 1, 2], "Not supported magic" assert compression_type in [0, 1, 2, 3, 4], "Not valid compression type" if magic >= 2: + assert not transactional or producer_id != -1, "Cannot write transactional messages without a valid producer ID" + assert producer_id == -1 or producer_epoch != -1, "Invalid negative producer epoch" + assert producer_id == -1 or base_sequence != -1, "Invalid negative sequence number used" + self._builder = DefaultRecordBatchBuilder( magic=magic, compression_type=compression_type, - is_transactional=False, producer_id=-1, producer_epoch=-1, - base_sequence=-1, batch_size=batch_size) - self._producer_id = -1 + is_transactional=transactional, producer_id=producer_id, + producer_epoch=producer_epoch, base_sequence=base_sequence, + batch_size=batch_size) + self._producer_id = producer_id else: + assert not transactional and producer_id == -1, "Idempotent messages are not supported for magic %s" % (magic,) self._builder = LegacyRecordBatchBuilder( magic=magic, compression_type=compression_type, batch_size=batch_size) @@ -158,7 +165,7 @@ def append(self, timestamp, key, value, headers=[]): self._next_offset += 1 return metadata - def set_producer_state(self, producer_id, producer_epoch, base_sequence): + def set_producer_state(self, producer_id, producer_epoch, base_sequence, is_transactional): if self._magic < 2: raise UnsupportedVersionError('Producer State requires Message format v2+') elif self._closed: @@ -167,15 +174,17 @@ def set_producer_state(self, producer_id, producer_epoch, base_sequence): # be re queued. In this case, we should not attempt to set the state again, since changing the pid and sequence # once a batch has been sent to the broker risks introducing duplicates. raise IllegalStateError("Trying to set producer state of an already closed batch. This indicates a bug on the client.") - self._builder.set_producer_state(producer_id, producer_epoch, base_sequence) + self._builder.set_producer_state(producer_id, producer_epoch, base_sequence, is_transactional) self._producer_id = producer_id @property def producer_id(self): - if self._magic < 2: - raise UnsupportedVersionError('Producer State requires Message format v2+') return self._producer_id + @property + def producer_epoch(self): + return self._producer_epoch + def close(self): # This method may be called multiple times on the same batch # i.e., on retries From f9b7c0030683820a4127a53556f03aae738846de Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 7 Apr 2025 12:42:21 -0700 Subject: [PATCH 02/27] DefaultRecordBatch.has_sequence / ProducerBatch.has_sequence --- kafka/producer/record_accumulator.py | 4 ++++ kafka/record/default_records.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/kafka/producer/record_accumulator.py b/kafka/producer/record_accumulator.py index 6490f48aa..75514102b 100644 --- a/kafka/producer/record_accumulator.py +++ b/kafka/producer/record_accumulator.py @@ -56,6 +56,10 @@ def record_count(self): def producer_id(self): return self.records.producer_id if self.records else None + @property + def has_sequence(self): + return self.records.has_sequence if self.records else False + def try_append(self, timestamp_ms, key, value, headers, now=None): metadata = self.records.append(timestamp_ms, key, value, headers) if metadata is None: diff --git a/kafka/record/default_records.py b/kafka/record/default_records.py index 465d3e3cf..96d60d6cb 100644 --- a/kafka/record/default_records.py +++ b/kafka/record/default_records.py @@ -210,6 +210,10 @@ def producer_epoch(self): def base_sequence(self): return self._header_data[11] + @property + def has_sequence(self): + return self._header_data[11] != -1 # NO_SEQUENCE + @property def last_sequence(self): if self.base_sequence == self.NO_SEQUENCE: From a9434424b50ee1db823c72bd07c510a47138a342 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 7 Apr 2025 14:15:12 -0700 Subject: [PATCH 03/27] Add producer_epoch to records builder --- kafka/record/default_records.py | 4 ++++ kafka/record/memory_records.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/kafka/record/default_records.py b/kafka/record/default_records.py index 96d60d6cb..c91d62977 100644 --- a/kafka/record/default_records.py +++ b/kafka/record/default_records.py @@ -510,6 +510,10 @@ def set_producer_state(self, producer_id, producer_epoch, base_sequence, is_tran def producer_id(self): return self._producer_id + @property + def producer_epoch(self): + return self._producer_epoch + def _get_attributes(self, include_compression_type=True): attrs = 0 if include_compression_type: diff --git a/kafka/record/memory_records.py b/kafka/record/memory_records.py index c71b3bd4c..4bf3115c8 100644 --- a/kafka/record/memory_records.py +++ b/kafka/record/memory_records.py @@ -113,7 +113,7 @@ def next_batch(self, _min_slice=MIN_SLICE, class MemoryRecordsBuilder(object): __slots__ = ("_builder", "_batch_size", "_buffer", "_next_offset", "_closed", - "_magic", "_bytes_written", "_producer_id") + "_magic", "_bytes_written", "_producer_id", "_producer_epoch") def __init__(self, magic, compression_type, batch_size, offset=0, transactional=False, producer_id=-1, producer_epoch=-1, base_sequence=-1): @@ -130,6 +130,7 @@ def __init__(self, magic, compression_type, batch_size, offset=0, producer_epoch=producer_epoch, base_sequence=base_sequence, batch_size=batch_size) self._producer_id = producer_id + self._producer_epoch = producer_epoch else: assert not transactional and producer_id == -1, "Idempotent messages are not supported for magic %s" % (magic,) self._builder = LegacyRecordBatchBuilder( @@ -196,6 +197,7 @@ def close(self): self._buffer = bytes(self._builder.build()) if self._magic == 2: self._producer_id = self._builder.producer_id + self._producer_epoch = self._builder.producer_epoch self._builder = None self._closed = True From a1145ac5328a60a450c623bccf107afb276f1661 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 7 Apr 2025 12:42:46 -0700 Subject: [PATCH 04/27] Add __str__ to DefaultRecordBatch + Builder --- kafka/record/default_records.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/kafka/record/default_records.py b/kafka/record/default_records.py index c91d62977..91d4a9d62 100644 --- a/kafka/record/default_records.py +++ b/kafka/record/default_records.py @@ -360,6 +360,17 @@ def validate_crc(self): verify_crc = calc_crc32c(data_view.tobytes()) return crc == verify_crc + def __str__(self): + return ( + "DefaultRecordBatch(magic={}, base_offset={}, last_offset_delta={}," + " first_timestamp={}, max_timestamp={}," + " is_transactional={}, producer_id={}, producer_epoch={}, base_sequence={}," + " records_count={})".format( + self.magic, self.base_offset, self.last_offset_delta, + self.first_timestamp, self.max_timestamp, + self.is_transactional, self.producer_id, self.producer_epoch, self.base_sequence, + self.records_count)) + class DefaultRecord(ABCRecord): @@ -718,6 +729,17 @@ def estimate_size_in_bytes(cls, key, value, headers): cls.size_of(key, value, headers) ) + def __str__(self): + return ( + "DefaultRecordBatchBuilder(magic={}, base_offset={}, last_offset_delta={}," + " first_timestamp={}, max_timestamp={}," + " is_transactional={}, producer_id={}, producer_epoch={}, base_sequence={}," + " records_count={})".format( + self._magic, 0, self._last_offset, + self._first_timestamp or 0, self._max_timestamp or 0, + self._is_transactional, self._producer_id, self._producer_epoch, self._base_sequence, + self._num_records)) + class DefaultRecordMetadata(object): From c980f0c968be75291be6e062196e50730cf2ed3b Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Sat, 5 Apr 2025 12:20:13 -0700 Subject: [PATCH 05/27] TransactionManager - producer only (no consumer offsets) --- kafka/producer/transaction_manager.py | 1047 +++++++++++++++++++++++++ 1 file changed, 1047 insertions(+) create mode 100644 kafka/producer/transaction_manager.py diff --git a/kafka/producer/transaction_manager.py b/kafka/producer/transaction_manager.py new file mode 100644 index 000000000..dea0456ad --- /dev/null +++ b/kafka/producer/transaction_manager.py @@ -0,0 +1,1047 @@ +from __future__ import absolute_import, division + +import abc +import collections +import heapq +import logging +import threading + +from kafka.vendor import six + +try: + # enum in stdlib as of py3.4 + from enum import IntEnum # pylint: disable=import-error +except ImportError: + # vendored backport module + from kafka.vendor.enum34 import IntEnum + +import kafka.errors as Errors +from kafka.protocol.add_partitions_to_txn import AddPartitionsToTxnRequest +from kafka.protocol.end_txn import EndTxnRequest +from kafka.protocol.find_coordinator import FindCoordinatorRequest +from kafka.protocol.init_producer_id import InitProducerIdRequest +from kafka.structs import TopicPartition + + +log = logging.getLogger(__name__) + + +NO_PRODUCER_ID = -1 +NO_PRODUCER_EPOCH = -1 +NO_SEQUENCE = -1 + + +class ProducerIdAndEpoch(object): + __slots__ = ('producer_id', 'epoch') + + def __init__(self, producer_id, epoch): + self.producer_id = producer_id + self.epoch = epoch + + @property + def is_valid(self): + return NO_PRODUCER_ID < self.producer_id + + def match(self, batch): + return self.producer_id == batch.producer_id and self.epoch == batch.producer_epoch + + def __str__(self): + return "ProducerIdAndEpoch(producer_id={}, epoch={})".format(self.producer_id, self.epoch) + + +class TransactionState(IntEnum): + UNINITIALIZED = 0 + INITIALIZING = 1 + READY = 2 + IN_TRANSACTION = 3 + COMMITTING_TRANSACTION = 4 + ABORTING_TRANSACTION = 5 + ABORTABLE_ERROR = 6 + FATAL_ERROR = 7 + + @classmethod + def is_transition_valid(cls, source, target): + if target == cls.INITIALIZING: + return source == cls.UNINITIALIZED + elif target == cls.READY: + return source in (cls.INITIALIZING, cls.COMMITTING_TRANSACTION, cls.ABORTING_TRANSACTION) + elif target == cls.IN_TRANSACTION: + return source == cls.READY + elif target == cls.COMMITTING_TRANSACTION: + return source == cls.IN_TRANSACTION + elif target == cls.ABORTING_TRANSACTION: + return source in (cls.IN_TRANSACTION, cls.ABORTABLE_ERROR) + elif target == cls.ABORTABLE_ERROR: + return source in (cls.IN_TRANSACTION, cls.COMMITTING_TRANSACTION, cls.ABORTABLE_ERROR) + elif target == cls.UNINITIALIZED: + # Disallow transitions to UNITIALIZED + return False + elif target == cls.FATAL_ERROR: + # We can transition to FATAL_ERROR unconditionally. + # FATAL_ERROR is never a valid starting state for any transition. So the only option is to close the + # producer or do purely non transactional requests. + return True + + +class Priority(IntEnum): + # We use the priority to determine the order in which requests need to be sent out. For instance, if we have + # a pending FindCoordinator request, that must always go first. Next, If we need a producer id, that must go second. + # The endTxn request must always go last. + FIND_COORDINATOR = 0 + INIT_PRODUCER_ID = 1 + ADD_PARTITIONS_OR_OFFSETS = 2 + END_TXN = 3 + + +class TransactionManager(object): + """ + A class which maintains state for transactions. Also keeps the state necessary to ensure idempotent production. + """ + NO_INFLIGHT_REQUEST_CORRELATION_ID = -1 + # The retry_backoff_ms is overridden to the following value if the first AddPartitions receives a + # CONCURRENT_TRANSACTIONS error. + ADD_PARTITIONS_RETRY_BACKOFF_MS = 20 + + def __init__(self, transactional_id=None, transaction_timeout_ms=0, retry_backoff_ms=100, api_version=(0, 11), metadata=None): + self._api_version = api_version + self._metadata = metadata + # Keep track of the in flight batches bound for a partition, ordered by sequence. This helps us to ensure that + # we continue to order batches by the sequence numbers even when the responses come back out of order during + # leader failover. We add a batch to the queue when it is drained, and remove it when the batch completes + # (either successfully or through a fatal failure). + # use heapq methods to push/pop from queues + self._in_flight_batches_by_sequence = collections.defaultdict(list) + self._in_flight_batches_sort_id = 0 + + # The base sequence of the next batch bound for a given partition. + self._next_sequence = collections.defaultdict(lambda: 0) + # The sequence of the last record of the last ack'd batch from the given partition. When there are no + # in flight requests for a partition, the self._last_acked_sequence(topicPartition) == nextSequence(topicPartition) - 1. + self._last_acked_sequence = collections.defaultdict(lambda: -1) + self.transactional_id = transactional_id + self.transaction_timeout_ms = transaction_timeout_ms + self._transaction_coordinator = None + self._consumer_group_coordinator = None + self._new_partitions_in_transaction = set() + self._pending_partitions_in_transaction = set() + self._partitions_in_transaction = set() + + self._current_state = TransactionState.UNINITIALIZED + self._last_error = None + self.producer_id_and_epoch = ProducerIdAndEpoch(NO_PRODUCER_ID, NO_PRODUCER_EPOCH) + + self._transaction_started = False + + self._pending_requests = [] # priority queue via heapq + self._pending_requests_sort_id = 0 + self._in_flight_request_correlation_id = self.NO_INFLIGHT_REQUEST_CORRELATION_ID + + # If a batch bound for a partition expired locally after being sent at least once, the partition has is considered + # to have an unresolved state. We keep track fo such partitions here, and cannot assign any more sequence numbers + # for this partition until the unresolved state gets cleared. This may happen if other inflight batches returned + # successfully (indicating that the expired batch actually made it to the broker). If we don't get any successful + # responses for the partition once the inflight request count falls to zero, we reset the producer id and + # consequently clear this data structure as well. + self._partitions_with_unresolved_sequences = set() + self._inflight_batches_by_sequence = dict() + # We keep track of the last acknowledged offset on a per partition basis in order to disambiguate UnknownProducer + # responses which are due to the retention period elapsing, and those which are due to actual lost data. + self._last_acked_offset = collections.defaultdict(lambda: -1) + + # This is used by the TxnRequestHandlers to control how long to back off before a given request is retried. + # For instance, this value is lowered by the AddPartitionsToTxnHandler when it receives a CONCURRENT_TRANSACTIONS + # error for the first AddPartitionsRequest in a transaction. + self.retry_backoff_ms = retry_backoff_ms + self._lock = threading.Condition() + + def initialize_transactions(self): + with self._lock: + self._ensure_transactional() + self._transition_to(TransactionState.INITIALIZING) + self.set_producer_id_and_epoch(ProducerIdAndEpoch(NO_PRODUCER_ID, NO_PRODUCER_EPOCH)) + self._next_sequence.clear() + handler = InitProducerIdHandler(self, self.transactional_id, self.transaction_timeout_ms) + self._enqueue_request(handler) + return handler.result + + def begin_transaction(self): + with self._lock: + self._ensure_transactional() + self._maybe_fail_with_error() + self._transition_to(TransactionState.IN_TRANSACTION) + + def begin_commit(self): + with self._lock: + self._ensure_transactional() + self._maybe_fail_with_error() + self._transition_to(TransactionState.COMMITTING_TRANSACTION) + return self._begin_completing_transaction(True) + + def begin_abort(self): + with self._lock: + self._ensure_transactional() + if self._current_state != TransactionState.ABORTABLE_ERROR: + self._maybe_fail_with_error() + self._transition_to(TransactionState.ABORTING_TRANSACTION) + + # We're aborting the transaction, so there should be no need to add new partitions + self._new_partitions_in_transaction.clear() + return self._begin_completing_transaction(False) + + def _begin_completing_transaction(self, committed): + if self._new_partitions_in_transaction: + self._enqueue_request(self._add_partitions_to_transaction_handler()) + handler = EndTxnHandler(self, self.transactional_id, self.producer_id_and_epoch.producer_id, self.producer_id_and_epoch.epoch, committed) + self._enqueue_request(handler) + return handler.result + + def maybe_add_partition_to_transaction(self, topic_partition): + with self._lock: + self._fail_if_not_ready_for_send() + + if self.is_partition_added(topic_partition) or self.is_partition_pending_add(topic_partition): + return + + log.debug("Begin adding new partition %s to transaction", topic_partition) + self._new_partitions_in_transaction.add(topic_partition) + + def _fail_if_not_ready_for_send(self): + with self._lock: + if self.has_error(): + raise Errors.KafkaError( + "Cannot perform send because at least one previous transactional or" + " idempotent request has failed with errors.", self._last_error) + + if self.is_transactional(): + if not self.has_producer_id(): + raise Errors.IllegalStateError( + "Cannot perform a 'send' before completing a call to initTransactions" + " when transactions are enabled.") + + if self._current_state != TransactionState.IN_TRANSACTION: + raise Errors.IllegalStateError("Cannot call send in state %s" % (self._current_state.name,)) + + def is_send_to_partition_allowed(self, tp): + with self._lock: + if self.has_fatal_error(): + return False + return not self.is_transactional() or tp in self._partitions_in_transaction + + def has_producer_id(self, producer_id=None): + if producer_id is None: + return self.producer_id_and_epoch.is_valid + else: + return self.producer_id_and_epoch.producer_id == producer_id + + def is_transactional(self): + return self.transactional_id is not None + + def has_partitions_to_add(self): + with self._lock: + return bool(self._new_partitions_in_transaction) or bool(self._pending_partitions_in_transaction) + + def is_completing(self): + with self._lock: + return self._current_state in ( + TransactionState.COMMITTING_TRANSACTION, + TransactionState.ABORTING_TRANSACTION) + + @property + def last_error(self): + return self._last_error + + def has_error(self): + with self._lock: + return self._current_state in ( + TransactionState.ABORTABLE_ERROR, + TransactionState.FATAL_ERROR) + + def is_aborting(self): + with self._lock: + return self._current_state == TransactionState.ABORTING_TRANSACTION + + def transition_to_abortable_error(self, exc): + with self._lock: + if self._current_state == TransactionState.ABORTING_TRANSACTION: + log.debug("Skipping transition to abortable error state since the transaction is already being " + " aborted. Underlying exception: ", exc) + return + self._transition_to(TransactionState.ABORTABLE_ERROR, error=exc) + + def transition_to_fatal_error(self, exc): + with self._lock: + self._transition_to(TransactionState.FATAL_ERROR, error=exc) + + # visible for testing + def is_partition_added(self, partition): + with self._lock: + return partition in self._partitions_in_transaction + + # visible for testing + def is_partition_pending_add(self, partition): + return partition in self._new_partitions_in_transaction or partition in self._pending_partitions_in_transaction + + def has_producer_id_and_epoch(self, producer_id, producer_epoch): + return ( + self.producer_id_and_epoch.producer_id == producer_id and + self.producer_id_and_epoch.epoch == producer_epoch + ) + + def set_producer_id_and_epoch(self, producer_id_and_epoch): + if not isinstance(producer_id_and_epoch, ProducerIdAndEpoch): + raise TypeError("ProducerAndIdEpoch type required") + log.info("ProducerId set to %s with epoch %s", + producer_id_and_epoch.producer_id, producer_id_and_epoch.epoch) + self.producer_id_and_epoch = producer_id_and_epoch + + def reset_producer_id(self): + """ + This method is used when the producer needs to reset its internal state because of an irrecoverable exception + from the broker. + + We need to reset the producer id and associated state when we have sent a batch to the broker, but we either get + a non-retriable exception or we run out of retries, or the batch expired in the producer queue after it was already + sent to the broker. + + In all of these cases, we don't know whether batch was actually committed on the broker, and hence whether the + sequence number was actually updated. If we don't reset the producer state, we risk the chance that all future + messages will return an OutOfOrderSequenceNumberError. + + Note that we can't reset the producer state for the transactional producer as this would mean bumping the epoch + for the same producer id. This might involve aborting the ongoing transaction during the initProducerIdRequest, + and the user would not have any way of knowing this happened. So for the transactional producer, + it's best to return the produce error to the user and let them abort the transaction and close the producer explicitly. + """ + with self._lock: + if self.is_transactional: + raise Errors.IllegalStateError( + "Cannot reset producer state for a transactional producer." + " You must either abort the ongoing transaction or" + " reinitialize the transactional producer instead") + self.set_producer_id_and_epoch(ProducerIdAndEpoch(NO_PRODUCER_ID, NO_PRODUCER_EPOCH)) + self._next_sequence.clear() + self._last_acked_sequence.clear() + self._inflight_batches_by_sequence.clear() + self._partitions_with_unresolved_sequences.clear() + self._last_acked_offset.clear() + + def sequence_number(self, tp): + with self._lock: + return self._next_sequence[tp] + + def increment_sequence_number(self, tp, increment): + with self._lock: + if tp not in self._next_sequence: + raise Errors.IllegalStateError("Attempt to increment sequence number for a partition with no current sequence.") + # Sequence number wraps at java max int + base = self._next_sequence[tp] + if base > (2147483647 - increment): + self._next_sequence[tp] = increment - (2147483647 - base) - 1 + else: + self._next_sequence[tp] += increment + + def _next_in_flight_batches_sort_id(self): + self._in_flight_batches_sort_id += 1 + return self._in_flight_batches_sort_id + + def add_in_flight_batch(self, batch): + with self._lock: + if not batch.has_sequence(): + raise Errors.IllegalStateError("Can't track batch for partition %s when sequence is not set." % (batch.topic_partition,)) + heapq.heappush( + self._in_flight_batches_by_sequence[batch.topic_partition], + (batch.base_sequence, self._next_in_flight_batches_sort_id(), batch) + ) + + def first_in_flight_sequence(self, tp): + """ + Returns the first inflight sequence for a given partition. This is the base sequence of an inflight batch with + the lowest sequence number. If there are no inflight requests being tracked for this partition, this method will return -1 + """ + with self._lock: + if not self._in_flight_batches_by_sequence[tp]: + return NO_SEQUENCE + else: + return self._in_flight_batches_by_sequence[tp][0][2].base_sequence + + def next_batch_by_sequence(self, tp): + with self._lock: + if not self._in_flight_batches_by_sequence[tp]: + return None + else: + return self._in_flight_batches_by_sequence[tp][0][2] + + def remove_in_flight_batch(self, batch): + with self._lock: + if not self._in_flight_batches_by_sequence[batch.topic_partition]: + return + else: + try: + # see https://stackoverflow.com/questions/10162679/python-delete-element-from-heap + queue = self._in_flight_batches_by_sequence[batch.topic_partition] + idx = [item[2] for item in queue].index(batch) + queue[idx] = queue[-1] + queue.pop() + heapq.heapify(queue) + except ValueError: + pass + + def maybe_update_last_acked_sequence(self, tp, sequence): + with self._lock: + if sequence > self._last_acked_sequence[tp]: + self._last_acked_sequence[tp] = sequence + + def update_last_acked_offset(self, base_offset, batch): + if base_offset == -1: + return + last_offset = base_offset + batch.record_count - 1 + if last_offset > self._last_acked_offset[batch.topic_partition]: + self._last_acked_offset[batch.topic_partition] = last_offset + else: + log.debug("Partition %s keeps last_offset at %s", batch.topic_partition, last_offset) + + def adjust_sequences_due_to_failed_batch(self, batch): + # If a batch is failed fatally, the sequence numbers for future batches bound for the partition must be adjusted + # so that they don't fail with the OutOfOrderSequenceNumberError. + # + # This method must only be called when we know that the batch is question has been unequivocally failed by the broker, + # ie. it has received a confirmed fatal status code like 'Message Too Large' or something similar. + with self._lock: + if batch.topic_partition not in self._next_sequence: + # Sequence numbers are not being tracked for this partition. This could happen if the producer id was just + # reset due to a previous OutOfOrderSequenceNumberError. + return + log.debug("producer_id: %s, send to partition %s failed fatally. Reducing future sequence numbers by %s", + batch.producer_id, batch.topic_partition, batch.record_count) + current_sequence = self.sequence_number(batch.topic_partition) + current_sequence -= batch.record_count + if current_sequence < 0: + raise Errors.IllegalStateError( + "Sequence number for partition %s is going to become negative: %s" % (batch.topic_partition, current_sequence)) + + self._set_next_sequence(batch.topic_partition, current_sequence) + + for in_flight_batch in self._in_flight_batches_by_sequence[batch.topic_partition]: + if in_flight_batch.base_sequence < batch.base_sequence: + continue + new_sequence = in_flight_batch.base_sequence - batch.record_count + if new_sequence < 0: + raise Errors.IllegalStateError( + "Sequence number for batch with sequence %s for partition %s is going to become negative: %s" % ( + in_flight_batch.base_sequence, batch.topic_partition, new_sequence)) + + log.info("Resetting sequence number of batch with current sequence %s for partition %s to %s", + in_flight_batch.base_sequence(), batch.topic_partition, new_sequence) + in_flight_batch.reset_producer_state( + ProducerIdAndEpoch(in_flight_batch.producer_id, in_flight_batch.producer_epoch), + new_sequence, + in_flight_batch.is_transactional()) + + def _start_sequences_at_beginning(self, tp): + with self._lock: + sequence = 0 + for in_flight_batch in self._in_flight_batches_by_sequence[tp]: + log.info("Resetting sequence number of batch with current sequence %s for partition %s to %s", + in_flight_batch.base_sequence, in_flight_batch.topic_partition, sequence) + in_flight_batch.reset_producer_state( + ProducerIdAndEpoch(in_flight_batch.producer_id, in_flight_batch.producer_epoch), + sequence, + in_flight_batch.is_transactional()) + sequence += in_flight_batch.record_count + self._set_next_sequence(tp, sequence) + try: + del self._last_acked_sequence[tp] + except KeyError: + pass + + def has_in_flight_batches(self, tp): + with self._lock: + return len(self._in_flight_batches_by_sequence[tp]) > 0 + + def has_unresolved_sequences(self): + with self._lock: + return len(self._partitions_with_unresolved_sequences) > 0 + + def has_unresolved_sequence(self, tp): + with self._lock: + return tp in self._partitions_with_unresolved_sequences + + def mark_sequence_unresolved(self, tp): + with self._lock: + log.debug("Marking partition %s unresolved", tp) + self._partitions_with_unresolved_sequences.add(tp) + + # Checks if there are any partitions with unresolved partitions which may now be resolved. Returns True if + # the producer id needs a reset, False otherwise. + def should_reset_producer_state_after_resolving_sequences(self): + with self._lock: + try: + remove = set() + if self.is_transactional(): + # We should not reset producer state if we are transactional. We will transition to a fatal error instead. + return False + for tp in self._partitions_with_unresolved_sequences: + if not self.has_in_flight_batches(tp): + # The partition has been fully drained. At this point, the last ack'd sequence should be once less than + # next sequence destined for the partition. If so, the partition is fully resolved. If not, we should + # reset the sequence number if necessary. + if self.is_next_sequence(tp, self.sequence_number(tp)): + # This would happen when a batch was expired, but subsequent batches succeeded. + remove.add(tp) + else: + # We would enter this branch if all in flight batches were ultimately expired in the producer. + log.info("No inflight batches remaining for %s, last ack'd sequence for partition is %s, next sequence is %s." + " Going to reset producer state.", tp, self._last_acked_sequence(tp), self.sequence_number(tp)) + return True + return False + finally: + self._partitions_with_unresolved_sequences -= remove + + def is_next_sequence(self, tp, sequence): + with self._lock: + return sequence - self._last_acked_sequence(tp) == 1 + + def _set_next_sequence(self, tp, sequence): + with self._lock: + if tp not in self._next_sequence and sequence != 0: + raise Errors.IllegalStateError( + "Trying to set the sequence number for %s to %s but the sequence number was never set for this partition." % ( + tp, sequence)) + self._next_sequence[tp] = sequence + + def next_request_handler(self, has_incomplete_batches): + with self._lock: + if self._new_partitions_in_transaction: + self._enqueue_request(self._add_partitions_to_transaction_handler()) + + if not self._pending_requests: + return None + + _, _, next_request_handler = self._pending_requests[0] + # Do not send the EndTxn until all batches have been flushed + if isinstance(next_request_handler, EndTxnHandler) and has_incomplete_batches: + return None + + heapq.heappop(self._pending_requests) + if self._maybe_terminate_request_with_error(next_request_handler): + log.debug("Not sending transactional request %s because we are in an error state", + next_request_handler.request) + return None + + if isinstance(next_request_handler, EndTxnHandler) and not self._transaction_started: + next_request_handler.result.done() + if self._current_state != TransactionState.FATAL_ERROR: + log.debug("Not sending EndTxn for completed transaction since no partitions" + " or offsets were successfully added") + self._complete_transaction() + try: + _, _, next_request_handler = heapq.heappop(self._pending_requests) + except IndexError: + next_request_handler = None + + if next_request_handler: + log.debug("Request %s dequeued for sending", next_request_handler.request) + + return next_request_handler + + def retry(self, request): + with self._lock: + request.set_retry() + self._enqueue_request(request) + + def authentication_failed(self, exc): + with self._lock: + for _, _, request in self._pending_requests: + request.fatal_error(exc) + + def coordinator(self, coord_type): + if coord_type == 'group': + return self._consumer_group_coordinator + elif coord_type == 'transaction': + return self._transaction_coordinator + else: + raise Errors.IllegalStateError("Received an invalid coordinator type: %s" % (coord_type,)) + + def lookup_coordinator_for_request(self, request): + self._lookup_coordinator(request.coordinator_type, request.coordinator_key) + + def next_in_flight_request_correlation_id(self): + self._in_flight_request_correlation_id += 1 + return self._in_flight_request_correlation_id + + def clear_in_flight_transactional_request_correlation_id(self): + self._in_flight_request_correlation_id = self.NO_INFLIGHT_REQUEST_CORRELATION_ID + + def has_in_flight_transactional_request(self): + return self._in_flight_request_correlation_id != self.NO_INFLIGHT_REQUEST_CORRELATION_ID + + # visible for testing. + def has_fatal_error(self): + return self._current_state == TransactionState.FATAL_ERROR + + # visible for testing. + def has_abortable_error(self): + return self._current_state == TransactionState.ABORTABLE_ERROR + + # visible for testing + def transactionContainsPartition(self, tp): + with self._lock: + return tp in self._partitions_in_transaction + + # visible for testing + def has_ongoing_transaction(self): + with self._lock: + # transactions are considered ongoing once started until completion or a fatal error + return self._current_state == TransactionState.IN_TRANSACTION or self.is_completing() or self.has_abortable_error() + + def can_retry(self, batch, error, log_start_offset): + with self._lock: + if not self.has_producer_id(batch.producer_id): + return False + + elif ( + error is Errors.OutOfOrderSequenceNumberError + and not self.has_unresolved_sequence(batch.topic_partition) + and (batch.sequence_has_been_reset() or not self.is_next_sequence(batch.topic_partition, batch.base_sequence)) + ): + # We should retry the OutOfOrderSequenceNumberError if the batch is _not_ the next batch, ie. its base + # sequence isn't the self._last_acked_sequence + 1. However, if the first in flight batch fails fatally, we will + # adjust the sequences of the other inflight batches to account for the 'loss' of the sequence range in + # the batch which failed. In this case, an inflight batch will have a base sequence which is + # the self._last_acked_sequence + 1 after adjustment. When this batch fails with an OutOfOrderSequenceNumberError, we want to retry it. + # To account for the latter case, we check whether the sequence has been reset since the last drain. + # If it has, we will retry it anyway. + return True + + elif error is Errors.UnknownProducerIdError: + if log_start_offset == -1: + # We don't know the log start offset with this response. We should just retry the request until we get it. + # The UNKNOWN_PRODUCER_ID error code was added along with the new ProduceResponse which includes the + # logStartOffset. So the '-1' sentinel is not for backward compatibility. Instead, it is possible for + # a broker to not know the logStartOffset at when it is returning the response because the partition + # may have moved away from the broker from the time the error was initially raised to the time the + # response was being constructed. In these cases, we should just retry the request: we are guaranteed + # to eventually get a logStartOffset once things settle down. + return True + + if batch.sequence_has_been_reset(): + # When the first inflight batch fails due to the truncation case, then the sequences of all the other + # in flight batches would have been restarted from the beginning. However, when those responses + # come back from the broker, they would also come with an UNKNOWN_PRODUCER_ID error. In this case, we should not + # reset the sequence numbers to the beginning. + return True + elif self._last_acked_offset(batch.topic_partition) < log_start_offset: + # The head of the log has been removed, probably due to the retention time elapsing. In this case, + # we expect to lose the producer state. Reset the sequences of all inflight batches to be from the beginning + # and retry them. + self._start_sequences_at_beginning(batch.topic_partition) + return True + return False + + # visible for testing + def is_ready(self): + with self._lock: + return self.is_transactional() and self._current_state == TransactionState.READY + + def _transition_to(self, target, error=None): + with self._lock: + if not self._current_state.is_transition_valid(self._current_state, target): + raise Errors.KafkaError("TransactionalId %s: Invalid transition attempted from state %s to state %s" % ( + self.transactional_id, self._current_state.name, target.name)) + + if target in (TransactionState.FATAL_ERROR, TransactionState.ABORTABLE_ERROR): + if error is None: + raise Errors.IllegalArgumentError("Cannot transition to %s with an None exception" % (target.name,)) + self._last_error = error + else: + self._last_error = None + + if self._last_error is not None: + log.debug("Transition from state %s to error state %s (%s)", self._current_state.name, target.name, self._last_error) + else: + log.debug("Transition from state %s to %s", self._current_state, target) + self._current_state = target + + def _ensure_transactional(self): + if not self.is_transactional(): + raise Errors.IllegalStateError("Transactional method invoked on a non-transactional producer.") + + def _maybe_fail_with_error(self): + if self.has_error(): + raise Errors.KafkaError("Cannot execute transactional method because we are in an error state: %s" % (self._last_error,)) + + def _maybe_terminate_request_with_error(self, request_handler): + if self.has_error(): + if self.has_abortable_error() and isinstance(request_handler, FindCoordinatorHandler): + # No harm letting the FindCoordinator request go through if we're expecting to abort + return False + request_handler.fail(self._last_error) + return True + return False + + def _next_pending_requests_sort_id(self): + self._pending_requests_sort_id += 1 + return self._pending_requests_sort_id + + def _enqueue_request(self, request_handler): + log.debug("Enqueuing transactional request %s", request_handler.request) + heapq.heappush( + self._pending_requests, + ( + request_handler.priority, # keep lowest priority at head of queue + self._next_pending_requests_sort_id(), # break ties + request_handler + ) + ) + + def _lookup_coordinator(self, coord_type, coord_key): + with self._lock: + if coord_type == 'group': + self._consumer_group_coordinator = None + elif coord_type == 'transaction': + self._transaction_coordinator = None + else: + raise Errors.IllegalStateError("Invalid coordinator type: %s" % (coord_type,)) + self._enqueue_request(FindCoordinatorHandler(self, coord_type, coord_key)) + + def _complete_transaction(self): + with self._lock: + self._transition_to(TransactionState.READY) + self._transaction_started = False + self._new_partitions_in_transaction.clear() + self._pending_partitions_in_transaction.clear() + self._partitions_in_transaction.clear() + + def _add_partitions_to_transaction_handler(self): + with self._lock: + self._pending_partitions_in_transaction.update(self._new_partitions_in_transaction) + self._new_partitions_in_transaction.clear() + return AddPartitionsToTxnHandler(self, self.transactional_id, self.producer_id_and_epoch.producer_id, self.producer_id_and_epoch.epoch, self._pending_partitions_in_transaction) + + +class TransactionalRequestResult(object): + def __init__(self): + self._latch = threading.Event() + self._error = None + + def done(self, error=None): + self._error = error + self._latch.set() + + def wait(self, timeout_ms=None): + timeout = timeout_ms / 1000 if timeout_ms is not None else None + success = self._latch.wait(timeout) + if self._error: + raise self._error + return success + + @property + def is_done(self): + return self._latch.is_set() + + @property + def succeeded(self): + return self._error is None and self._latch.is_set() + + +@six.add_metaclass(abc.ABCMeta) +class TxnRequestHandler(object): + def __init__(self, transaction_manager, result=None): + self.transaction_manager = transaction_manager + self.retry_backoff_ms = transaction_manager.retry_backoff_ms + self.request = None + self._result = result or TransactionalRequestResult() + self._is_retry = False + + def fatal_error(self, exc): + self.transaction_manager._transition_to_fatal_error(exc) + self._result.done(error=exc) + + def abortable_error(self, exc): + self.transaction_manager._transition_to_abortable_error(exc) + self._result.done(error=exc) + + def fail(self, exc): + self._result.done(error=exc) + + def reenqueue(self): + with self.transaction_manager._lock: + self._is_retry = True + self.transaction_manager._enqueue_request(self) + + def on_complete(self, correlation_id, response_or_exc): + if correlation_id != self.transaction_manager._in_flight_request_correlation_id: + self.fatal_error(RuntimeError("Detected more than one in-flight transactional request.")) + else: + self.transaction_manager.clear_in_flight_transactional_request_correlation_id() + if isinstance(response_or_exc, Errors.KafkaConnectionError): + log.debug("Disconnected from node. Will retry.") + if self.needs_coordinator(): + self.transaction_manager._lookup_coordinator(self.coordinator_type, self.coordinator_key) + self.reenqueue() + elif isinstance(response_or_exc, Errors.UnsupportedVersionError): + self.fatal_error(response_or_exc) + elif not isinstance(response_or_exc, (Exception, type(None))): + log.debug("Received transactional response %s for request %s", response_or_exc, self.request) + with self.transaction_manager._lock: + self.handle_response(response_or_exc) + else: + self.fatal_error(Errors.KafkaError("Could not execute transactional request for unknown reasons: %s" % response_or_exc)) + + def needs_coordinator(self): + return self.coordinator_type is not None + + @property + def result(self): + return self._result + + @property + def coordinator_type(self): + return 'transaction' + + @property + def coordinator_key(self): + return self.transaction_manager.transactional_id + + def set_retry(self): + self._is_retry = True + + @property + def is_retry(self): + return self._is_retry + + @abc.abstractmethod + def handle_response(self, response): + pass + + @abc.abstractproperty + def priority(self): + pass + + +class InitProducerIdHandler(TxnRequestHandler): + def __init__(self, transaction_manager, transactional_id, transaction_timeout_ms): + super(InitProducerIdHandler, self).__init__(transaction_manager) + + self.transactional_id = transactional_id + if transaction_manager._api_version >= (2, 0): + version = 1 + else: + version = 0 + self.request = InitProducerIdRequest[version]( + transactional_id=transactional_id, + transaction_timeout_ms=transaction_timeout_ms) + + @property + def priority(self): + return Priority.INIT_PRODUCER_ID + + def handle_response(self, response): + error = Errors.for_code(response.error_code) + + if error is Errors.NoError: + self.transaction_manager.set_producer_id_and_epoch(ProducerIdAndEpoch(response.producer_id, response.producer_epoch)) + self.transaction_manager._transition_to(TransactionState.READY) + self._result.done() + elif error in (Errors.NotCoordinatorError, Errors.CoordinatorNotAvailableError): + self.transaction_manager._lookup_coordinator('transaction', self.transactional_id) + self.reenqueue() + elif error in (Errors.CoordinatorLoadInProgressError, Errors.ConcurrentTransactionsError): + self.reenqueue() + elif error is Errors.TransactionalIdAuthorizationFailedError: + self.fatal_error(error()) + else: + self.fatal_error(Errors.KafkaError("Unexpected error in InitProducerIdResponse: %s" % (error()))) + +class AddPartitionsToTxnHandler(TxnRequestHandler): + def __init__(self, transaction_manager, transactional_id, producer_id, producer_epoch, topic_partitions): + super(AddPartitionsToTxnHandler, self).__init__(transaction_manager) + + self.transactional_id = transactional_id + if transaction_manager._api_version >= (2, 7): + version = 2 + elif transaction_manager._api_version >= (2, 0): + version = 1 + else: + version = 0 + topic_data = collections.defaultdict(list) + for tp in topic_partitions: + topic_data[tp.topic].append(tp.partition) + self.request = AddPartitionsToTxnRequest[version]( + transactional_id=transactional_id, + producer_id=producer_id, + producer_epoch=producer_epoch, + topics=list(topic_data.items())) + + @property + def priority(self): + return Priority.ADD_PARTITIONS_OR_OFFSETS + + def handle_response(self, response): + has_partition_errors = False + unauthorized_topics = set() + self.retry_backoff_ms = self.transaction_manager.retry_backoff_ms + + results = {TopicPartition(topic, partition): Errors.for_code(error_code) + for topic, partition_data in response.results + for partition, error_code in partition_data} + + for tp, error in six.iteritems(results): + if error is Errors.NoError: + continue + elif error in (Errors.CoordinatorNotAvailableError, Errors.NotCoordinatorError): + self.transaction_manager._lookup_coordinator('transaction', self.transactiona_id) + self.reenqueue() + return + elif error is Errors.ConcurrentTransactionError: + self.maybe_override_retry_backoff_ms() + self.reenqueue() + return + elif error in (Errors.CoordinatorLoadInProgressError, Errors.UnknownTopicOrPartitionError): + self.reenqueue() + return + elif error is Errors.InvalidProducerEpochError: + self.fatal_error(error()) + return + elif error is Errors.TransactionalIdAuthorizationFailedError: + self.fatal_error(error()) + return + elif error in (Errors.InvalidProducerIdMappingError, Errors.InvalidTxnStateError): + self.fatal_error(Errors.KafkaError(error())) + return + elif error is Errors.TopicAuthorizationFailedError: + unauthorized_topics.add(tp.topic) + elif error is Errors.OperationNotAttemptedError: + log.debug("Did not attempt to add partition %s to transaction because other partitions in the" + " batch had errors.", tp) + has_partition_errors = True + else: + log.error("Could not add partition %s due to unexpected error %s", tp, error()) + has_partition_errors = True + + partitions = set(results) + + # Remove the partitions from the pending set regardless of the result. We use the presence + # of partitions in the pending set to know when it is not safe to send batches. However, if + # the partitions failed to be added and we enter an error state, we expect the batches to be + # aborted anyway. In this case, we must be able to continue sending the batches which are in + # retry for partitions that were successfully added. + self.transaction_manager._pending_partitions_in_transaction -= partitions + + if unauthorized_topics: + self.abortable_error(Errors.TopicAuthorizationError(unauthorized_topics)) + elif has_partition_errors: + self.abortable_error(Errors.KafkaError("Could not add partitions to transaction due to errors: %s" % (results))) + else: + log.debug("Successfully added partitions %s to transaction", partitions) + self.transaction_manager._partitions_in_transaction.update(partitions) + self.transaction_manager._transaction_started = True + self._result.done() + + def maybe_override_retry_backoff_ms(self): + # We only want to reduce the backoff when retrying the first AddPartition which errored out due to a + # CONCURRENT_TRANSACTIONS error since this means that the previous transaction is still completing and + # we don't want to wait too long before trying to start the new one. + # + # This is only a temporary fix, the long term solution is being tracked in + # https://issues.apache.org/jira/browse/KAFKA-5482 + if not self._partitions_in_transaction: + self.retry_backoff_ms = min(self.transaction_manager.ADD_PARTITIONS_RETRY_BACKOFF_MS, self.retry_backoff_ms) + + +class FindCoordinatorHandler(TxnRequestHandler): + def __init__(self, transaction_manager, coord_type, coord_key): + super(FindCoordinatorHandler, self).__init__(transaction_manager) + + self._coord_type = coord_type + self._coord_key = coord_key + if transaction_manager._api_version >= (2, 0): + version = 2 + else: + version = 1 + if coord_type == 'group': + coord_type_int8 = 0 + elif coord_type == 'transaction': + coord_type_int8 = 1 + else: + raise ValueError("Unrecognized coordinator type: %s" % (coord_type,)) + self.request = FindCoordinatorRequest[version]( + coordinator_key=coord_key, + coordinator_type=coord_type_int8, + ) + + @property + def priority(self): + return Priority.FIND_COORDINATOR + + @property + def coordinator_type(self): + return None + + @property + def coordinator_key(self): + return None + + def handle_response(self, response): + error = Errors.for_code(response.error_code) + + if error is Errors.NoError: + coordinator_id = self.transaction_manager._metadata.add_coordinator( + response, self._coord_type, self._coord_key) + if self._coord_type == 'group': + self.transaction_manager._consumer_group_coordinator = coordinator_id + elif self._coord_type == 'transaction': + self.transaction_manager._transaction_coordinator = coordinator_id + self._result.done() + elif error is Errors.CoordinatorNotAvailableError: + self.reenqueue() + elif error is Errors.TransactionalIdAuthorizationFailedError: + self.fatal_error(error()) + elif error is Errors.GroupAuthorizationFailedError: + self.abortable_error(Errors.GroupAuthorizationError(self._coord_key)) + else: + self.fatal_error(Errors.KafkaError( + "Could not find a coordinator with type %s with key %s due to" + " unexpected error: %s" % (self._coord_type, self._coord_key, error()))) + + +class EndTxnHandler(TxnRequestHandler): + def __init__(self, transaction_manager, transactional_id, producer_id, producer_epoch, committed): + super(EndTxnHandler, self).__init__(transaction_manager) + + self.transactional_id = transactional_id + if self.transaction_manager._api_version >= (2, 7): + version = 2 + elif self.transaction_manager._api_version >= (2, 0): + version = 1 + else: + version = 0 + self.request = EndTxnRequest[version]( + transactional_id=transactional_id, + producer_id=producer_id, + producer_epoch=producer_epoch, + committed=committed) + + @property + def priority(self): + return Priority.END_TXN + + def handle_response(self, response): + error = Errors.for_code(response.error_code) + + if error is Errors.NoError: + self.transaction_manager._complete_transaction() + self._result.done() + elif error in (Errors.CoordinatorNotAvailableError, Errors.NotCoordinatorError): + self.transaction_manager._lookup_coordinator('transaction', self.transactional_id) + self.reenqueue() + elif error in (Errors.CoordinatorLoadInProgressError, Errors.ConcurrentTransactionsError): + self.reenqueue() + elif error is Errors.InvalidProducerEpochError: + self.fatal_error(error()) + elif error is Errors.TransactionalIdAuthorizationFailedError: + self.fatal_error(error()) + elif error is Errors.InvalidTxnStateError: + self.fatal_error(error()) + else: + self.fatal_error(Errors.KafkaError("Unhandled error in EndTxnResponse: %s" % (error()))) From 1561d45e1bd6028c4adccd93e9aa9611d83c865a Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Sat, 5 Apr 2025 15:28:40 -0700 Subject: [PATCH 06/27] Drop producer TransactionState --- kafka/producer/transaction_state.py | 96 ----------------------------- 1 file changed, 96 deletions(-) delete mode 100644 kafka/producer/transaction_state.py diff --git a/kafka/producer/transaction_state.py b/kafka/producer/transaction_state.py deleted file mode 100644 index 05cdc5766..000000000 --- a/kafka/producer/transaction_state.py +++ /dev/null @@ -1,96 +0,0 @@ -from __future__ import absolute_import, division - -import collections -import threading -import time - -from kafka.errors import IllegalStateError - - -NO_PRODUCER_ID = -1 -NO_PRODUCER_EPOCH = -1 - - -class ProducerIdAndEpoch(object): - __slots__ = ('producer_id', 'epoch') - - def __init__(self, producer_id, epoch): - self.producer_id = producer_id - self.epoch = epoch - - @property - def is_valid(self): - return NO_PRODUCER_ID < self.producer_id - - def __str__(self): - return "ProducerIdAndEpoch(producer_id={}, epoch={})".format(self.producer_id, self.epoch) - -class TransactionState(object): - __slots__ = ('producer_id_and_epoch', '_sequence_numbers', '_lock') - - def __init__(self): - self.producer_id_and_epoch = ProducerIdAndEpoch(NO_PRODUCER_ID, NO_PRODUCER_EPOCH) - self._sequence_numbers = collections.defaultdict(lambda: 0) - self._lock = threading.Condition() - - def has_pid(self): - return self.producer_id_and_epoch.is_valid - - - def await_producer_id_and_epoch(self, max_wait_time_ms): - """ - A blocking call to get the pid and epoch for the producer. If the PID and epoch has not been set, this method - will block for at most maxWaitTimeMs. It is expected that this method be called from application thread - contexts (ie. through Producer.send). The PID it self will be retrieved in the background thread. - - Arguments: - max_wait_time_ms (numeric): The maximum time to block. - - Returns: - ProducerIdAndEpoch object. Callers must check the 'is_valid' property of the returned object to ensure that a - valid pid and epoch is actually returned. - """ - with self._lock: - start = time.time() - elapsed = 0 - while not self.has_pid() and elapsed < max_wait_time_ms: - self._lock.wait(max_wait_time_ms / 1000) - elapsed = time.time() - start - return self.producer_id_and_epoch - - def set_producer_id_and_epoch(self, producer_id, epoch): - """ - Set the pid and epoch atomically. This method will signal any callers blocked on the `pidAndEpoch` method - once the pid is set. This method will be called on the background thread when the broker responds with the pid. - """ - with self._lock: - self.producer_id_and_epoch = ProducerIdAndEpoch(producer_id, epoch) - if self.producer_id_and_epoch.is_valid: - self._lock.notify_all() - - def reset_producer_id(self): - """ - This method is used when the producer needs to reset it's internal state because of an irrecoverable exception - from the broker. - - We need to reset the producer id and associated state when we have sent a batch to the broker, but we either get - a non-retriable exception or we run out of retries, or the batch expired in the producer queue after it was already - sent to the broker. - - In all of these cases, we don't know whether batch was actually committed on the broker, and hence whether the - sequence number was actually updated. If we don't reset the producer state, we risk the chance that all future - messages will return an OutOfOrderSequenceException. - """ - with self._lock: - self.producer_id_and_epoch = ProducerIdAndEpoch(NO_PRODUCER_ID, NO_PRODUCER_EPOCH) - self._sequence_numbers.clear() - - def sequence_number(self, tp): - with self._lock: - return self._sequence_numbers[tp] - - def increment_sequence_number(self, tp, increment): - with self._lock: - if tp not in self._sequence_numbers: - raise IllegalStateError("Attempt to increment sequence number for a partition with no current sequence.") - self._sequence_numbers[tp] += increment From e9cc2028299793c88963d9ff0b7685dd03332a4e Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Sun, 6 Apr 2025 07:54:52 -0700 Subject: [PATCH 07/27] RecordAccumulator transaction state -> manager has_unsent -> has_undrained has_incomplete abort_undrained_batches --- kafka/producer/record_accumulator.py | 59 ++++++++++++++++++++++------ kafka/producer/sender.py | 2 +- 2 files changed, 47 insertions(+), 14 deletions(-) diff --git a/kafka/producer/record_accumulator.py b/kafka/producer/record_accumulator.py index 75514102b..3e484d3c7 100644 --- a/kafka/producer/record_accumulator.py +++ b/kafka/producer/record_accumulator.py @@ -56,6 +56,10 @@ def record_count(self): def producer_id(self): return self.records.producer_id if self.records else None + @property + def producer_epoch(self): + return self.records.producer_epoch if self.records else None + @property def has_sequence(self): return self.records.has_sequence if self.records else False @@ -174,7 +178,7 @@ class RecordAccumulator(object): 'compression_attrs': 0, 'linger_ms': 0, 'retry_backoff_ms': 100, - 'transaction_state': None, + 'transaction_manager': None, 'message_version': 0, } @@ -185,7 +189,7 @@ def __init__(self, **configs): self.config[key] = configs.pop(key) self._closed = False - self._transaction_state = self.config['transaction_state'] + self._transaction_manager = self.config['transaction_manager'] self._flushes_in_progress = AtomicInteger() self._appends_in_progress = AtomicInteger() self._batches = collections.defaultdict(collections.deque) # TopicPartition: [ProducerBatch] @@ -248,7 +252,7 @@ def append(self, tp, timestamp_ms, key, value, headers): batch_is_full = len(dq) > 1 or last.records.is_full() return future, batch_is_full, False - if self._transaction_state and self.config['message_version'] < 2: + if self._transaction_manager and self.config['message_version'] < 2: raise Errors.UnsupportedVersionError("Attempting to use idempotence with a broker which" " does not support the required message format (v2)." " The broker must be version 0.11 or later.") @@ -422,8 +426,8 @@ def ready(self, cluster, now=None): return ready_nodes, next_ready_check, unknown_leaders_exist - def has_unsent(self): - """Return whether there is any unsent record in the accumulator.""" + def has_undrained(self): + """Check whether there are any batches which haven't been drained""" for tp in list(self._batches.keys()): with self._tp_locks[tp]: dq = self._batches[tp] @@ -483,8 +487,8 @@ def drain(self, cluster, nodes, max_size, now=None): break else: producer_id_and_epoch = None - if self._transaction_state: - producer_id_and_epoch = self._transaction_state.producer_id_and_epoch + if self._transaction_manager: + producer_id_and_epoch = self._transaction_manager.producer_id_and_epoch if not producer_id_and_epoch.is_valid: # we cannot send the batch until we have refreshed the PID log.debug("Waiting to send ready batches because transaction producer id is not valid") @@ -497,11 +501,16 @@ def drain(self, cluster, nodes, max_size, now=None): # the previous attempt may actually have been accepted, and if we change # the pid and sequence here, this attempt will also be accepted, causing # a duplicate. - sequence_number = self._transaction_state.sequence_number(batch.topic_partition) + sequence_number = self._transaction_manager.sequence_number(batch.topic_partition) log.debug("Dest: %s: %s producer_id=%s epoch=%s sequence=%s", node_id, batch.topic_partition, producer_id_and_epoch.producer_id, producer_id_and_epoch.epoch, sequence_number) - batch.records.set_producer_state(producer_id_and_epoch.producer_id, producer_id_and_epoch.epoch, sequence_number) + batch.records.set_producer_state( + producer_id_and_epoch.producer_id, + producer_id_and_epoch.epoch, + sequence_number, + self._transaction_manager.is_transactional() + ) batch.records.close() size += batch.records.size_in_bytes() ready.append(batch) @@ -548,6 +557,10 @@ def await_flush_completion(self, timeout=None): finally: self._flushes_in_progress.decrement() + @property + def has_incomplete(self): + return bool(self._incomplete) + def abort_incomplete_batches(self): """ This function is only called when sender is closed forcefully. It will fail all the @@ -557,27 +570,41 @@ def abort_incomplete_batches(self): # 1. Avoid losing batches. # 2. Free up memory in case appending threads are blocked on buffer full. # This is a tight loop but should be able to get through very quickly. + error = Errors.IllegalStateError("Producer is closed forcefully.") while True: - self._abort_batches() + self._abort_batches(error) if not self._appends_in_progress.get(): break # After this point, no thread will append any messages because they will see the close # flag set. We need to do the last abort after no thread was appending in case the there was a new # batch appended by the last appending thread. - self._abort_batches() + self._abort_batches(error) self._batches.clear() - def _abort_batches(self): + def _abort_batches(self, error): """Go through incomplete batches and abort them.""" - error = Errors.IllegalStateError("Producer is closed forcefully.") for batch in self._incomplete.all(): tp = batch.topic_partition # Close the batch before aborting with self._tp_locks[tp]: batch.records.close() + self._batches[tp].remove(batch) batch.done(exception=error) self.deallocate(batch) + def abort_undrained_batches(self, error): + for batch in self._incomplete.all(): + tp = batch.topic_partition + with self._tp_locks[tp]: + aborted = False + if (self._transaction_manager and not batch.has_sequence) or (not self._transaction_manager and not batch.is_done): + aborted = True + batch.records.close() + self._batches[tp].remove(batch) + if aborted: + batch.done(exception=error) + self.deallocate(batch) + def close(self): """Close this accumulator and force all the record buffers to be drained.""" self._closed = True @@ -604,3 +631,9 @@ def remove(self, batch): def all(self): with self._lock: return list(self._incomplete) + + def __bool__(self): + return bool(self._incomplete) + + + __nonzero__ = __bool__ diff --git a/kafka/producer/sender.py b/kafka/producer/sender.py index 96a50cbbc..6cdba2a0e 100644 --- a/kafka/producer/sender.py +++ b/kafka/producer/sender.py @@ -77,7 +77,7 @@ def run(self): # requests in the accumulator or waiting for acknowledgment, # wait until these are completed. while (not self._force_close - and (self._accumulator.has_unsent() + and (self._accumulator.has_undrained() or self._client.in_flight_request_count() > 0)): try: self.run_once() From 85601675ba2748b9b863560fffd9374d16199792 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Sun, 6 Apr 2025 07:55:54 -0700 Subject: [PATCH 08/27] Sender --- kafka/producer/sender.py | 218 ++++++++++++++++++++++++++++++--------- 1 file changed, 167 insertions(+), 51 deletions(-) diff --git a/kafka/producer/sender.py b/kafka/producer/sender.py index 6cdba2a0e..46330a251 100644 --- a/kafka/producer/sender.py +++ b/kafka/producer/sender.py @@ -33,7 +33,7 @@ class Sender(threading.Thread): 'retry_backoff_ms': 100, 'metrics': None, 'guarantee_message_order': False, - 'transaction_state': None, + 'transaction_manager': None, 'transactional_id': None, 'transaction_timeout_ms': 60000, 'client_id': 'kafka-python-' + __version__, @@ -57,7 +57,7 @@ def __init__(self, client, metadata, accumulator, **configs): self._sensors = SenderMetrics(self.config['metrics'], self._client, self._metadata) else: self._sensors = None - self._transaction_state = self.config['transaction_state'] + self._transaction_manager = self.config['transaction_manager'] def run(self): """The main run loop for the sender thread.""" @@ -101,8 +101,44 @@ def run_once(self): while self._topics_to_add: self._client.add_topic(self._topics_to_add.pop()) - self._maybe_wait_for_producer_id() + if self._transaction_manager: + try: + if self._transaction_manager.should_reset_producer_state_after_resolving_sequences(): + # Check if the previous run expired batches which requires a reset of the producer state. + self._transaction_manager.reset_producer_id() + if not self._transaction_manager.is_transactional(): + # this is an idempotent producer, so make sure we have a producer id + self._maybe_wait_for_producer_id() + elif self._transaction_manager.has_unresolved_sequences() and not self._transaction_manager.has_fatal_error(): + self._transaction_manager.transition_to_fatal_error( + Errors.KafkaError("The client hasn't received acknowledgment for" + " some previously sent messages and can no longer retry them." + " It isn't safe to continue.")) + elif self._transaction_manager.has_in_flight_transactional_request() or self._maybe_send_transactional_request(): + # as long as there are outstanding transactional requests, we simply wait for them to return + self._client.poll(timeout_ms=self.config['retry_backoff_ms']) + return + + # do not continue sending if the transaction manager is in a failed state or if there + # is no producer id (for the idempotent case). + if self._transaction_manager.has_fatal_error() or not self._transaction_manager.has_producer_id(): + last_error = self._transaction_manager.last_error + if last_error is not None: + self._maybe_abort_batches(last_error) + self._client.poll(timeout_ms=self.config['retry_backoff_ms']) + return + elif self._transaction_manager.has_abortable_error(): + self._accumulator.abort_undrained_batches(self._transaction_manager.last_error) + + except Errors.SaslAuthenticationFailedError as e: + # This is already logged as error, but propagated here to perform any clean ups. + log.debug("Authentication exception while processing transactional request: %s", e) + self._transaction_manager.authentication_failed(e) + poll_timeout_ms = self._send_producer_data() + self._client.poll(timeout_ms=poll_timeout_ms) + + def _send_producer_data(self): # get the list of partitions with data ready to send result = self._accumulator.ready(self._metadata) ready_nodes, next_ready_check_delay, unknown_leaders_exist = result @@ -139,9 +175,9 @@ def run_once(self): # Reset the producer_id if an expired batch has previously been sent to the broker. # See the documentation of `TransactionState.reset_producer_id` to understand why # we need to reset the producer id here. - if self._transaction_state and any([batch.in_retry() for batch in expired_batches]): - self._transaction_state.reset_producer_id() - return + if self._transaction_manager and any([batch.in_retry() for batch in expired_batches]): + self._transaction_manager.reset_producer_id() + return 0 if self._sensors: for expired_batch in expired_batches: @@ -160,6 +196,12 @@ def run_once(self): if ready_nodes: log.debug("Nodes with data ready to send: %s", ready_nodes) # trace log.debug("Created %d produce requests: %s", len(requests), requests) # trace + # if some partitions are already ready to be sent, the select time + # would be 0; otherwise if some partition already has some data + # accumulated but not ready yet, the select time will be the time + # difference between now and its linger expiry time; otherwise the + # select time will be the time difference between now and the + # metadata expiry time poll_timeout_ms = 0 for node_id, request in six.iteritems(requests): @@ -170,14 +212,67 @@ def run_once(self): self._handle_produce_response, node_id, time.time(), batches) .add_errback( self._failed_produce, batches, node_id)) + return poll_timeout_ms + + def _maybe_send_transactional_request(self): + if self._transaction_manager.is_completing() and self._accumulator.has_incomplete: + if self._transaction_manager.is_aborting(): + self._accumulator.abort_undrained_batches(Errors.KafkaError("Failing batch since transaction was aborted")) + # There may still be requests left which are being retried. Since we do not know whether they had + # been successfully appended to the broker log, we must resend them until their final status is clear. + # If they had been appended and we did not receive the error, then our sequence number would no longer + # be correct which would lead to an OutOfSequenceNumberError. + if not self._accumulator.flush_in_progress(): + self._accumulator.begin_flush() + + next_request_handler = self._transaction_manager.next_request_handler(self._accumulator.has_incomplete) + if next_request_handler is None: + return False + + log.debug("transactional_id: %s -- Sending transactional request %s", self._transaction_manager.transactional_id, next_request_handler.request) + while not self._force_close: + target_node = None + try: + if next_request_handler.needs_coordinator(): + target_node = self._transaction_manager.coordinator(next_request_handler.coordinator_type) + if target_node is None: + self._transaction_manager.lookup_coordinator_for_request(next_request_handler) + break + elif not self._client.await_ready(target_node, timeout_ms=self.config['request_timeout_ms']): + self._transaction_manager.lookup_coordinator_for_request(next_request_handler) + target_node = None + break + else: + target_node = self._client.least_loaded_node() + if target_node is not None and not self._client.await_ready(target_node, timeout_ms=self.config['request_timeout_ms']): + target_node = None + + if target_node is not None: + if next_request_handler.is_retry: + time.sleep(self.config['retry_backoff_ms'] / 1000) + txn_correlation_id = self._transaction_manager.next_in_flight_request_correlation_id() + future = self._client.send(target_node, next_request_handler.request) + future.add_both(next_request_handler.on_complete, txn_correlation_id) + return True + + except Exception as e: + log.warn("Got an exception when trying to find a node to send a transactional request to. Going to back off and retry", e) + if next_request_handler.needs_coordinator(): + self._transaction_manager.lookup_coordinator_for_request(next_request_handler) + break - # if some partitions are already ready to be sent, the select time - # would be 0; otherwise if some partition already has some data - # accumulated but not ready yet, the select time will be the time - # difference between now and its linger expiry time; otherwise the - # select time will be the time difference between now and the - # metadata expiry time - self._client.poll(timeout_ms=poll_timeout_ms) + time.sleep(self.config['retry_backoff_ms'] / 1000) + self._metadata.request_update() + + if target_node is None: + self._transaction_manager.retry(next_request_handler) + + return True + + def _maybe_abort_batches(self, exc): + if self._accumulator.has_incomplete: + log.error("Aborting producer batches due to fatal error: %s", exc) + self._accumulator.abort_batches(exc) def initiate_close(self): """Start closing the sender (won't complete until all data is sent).""" @@ -201,10 +296,7 @@ def add_topic(self, topic): self.wakeup() def _maybe_wait_for_producer_id(self): - if not self._transaction_state: - return - - while not self._transaction_state.has_pid(): + while not self._transaction_manager.has_producer_id(): try: node_id = self._client.least_loaded_node() if node_id is None or not self._client.await_ready(node_id): @@ -220,19 +312,20 @@ def _maybe_wait_for_producer_id(self): response = self._client.send_and_receive(node_id, request) error_type = Errors.for_code(response.error_code) if error_type is Errors.NoError: - self._transaction_state.set_producer_id_and_epoch(response.producer_id, response.producer_epoch) + self._transaction_manager.set_producer_id_and_epoch(response.producer_id, response.producer_epoch) return elif getattr(error_type, 'retriable', False): log.debug("Retriable error from InitProducerId response: %s", error_type.__name__) if getattr(error_type, 'invalid_metadata', False): self._metadata.request_update() else: - log.error("Received a non-retriable error from InitProducerId response: %s", error_type.__name__) + self._transaction_manager.transition_to_fatal_error(error_type()) break except Errors.KafkaConnectionError: log.debug("Broker %s disconnected while awaiting InitProducerId response", node_id) except Errors.RequestTimedOutError: log.debug("InitProducerId request to node %s timed out", node_id) + log.debug("Retry InitProducerIdRequest in %sms.", self.config['retry_backoff_ms']) time.sleep(self.config['retry_backoff_ms'] / 1000) def _failed_produce(self, batches, node_id, error): @@ -271,13 +364,29 @@ def _handle_produce_response(self, node_id, send_time, batches, response): for batch in batches: self._complete_batch(batch, None, -1) - def _fail_batch(self, batch, *args, **kwargs): - if self._transaction_state and self._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id: - # Reset the transaction state since we have hit an irrecoverable exception and cannot make any guarantees - # about the previously committed message. Note that this will discard the producer id and sequence - # numbers for all existing partitions. - self._transaction_state.reset_producer_id() - batch.done(*args, **kwargs) + def _fail_batch(self, batch, exception, base_offset=None, timestamp_ms=None, log_start_offset=None): + log.exception(exception) + if self._transaction_manager: + if isinstance(exception, Errors.OutOfOrderSequenceNumberError) and \ + not self._transaction_manager.is_transactional() and \ + self._transaction_manager.has_producer_id(batch.producer_id): + log.error("The broker received an out of order sequence number for topic-partition %s" + " at offset %s. This indicates data loss on the broker, and should be investigated.", + batch.topic_partition, base_offset) + + # Reset the transaction state since we have hit an irrecoverable exception and cannot make any guarantees + # about the previously committed message. Note that this will discard the producer id and sequence + # numbers for all existing partitions. + self._transaction_manager.reset_producer_id() + elif isinstance(exception, (Errors.ClusterAuthorizationFailedError, + Errors.TransactionalIdAuthorizationFailedError, + Errors.ProducerFencedError, + Errors.InvalidTxnStateError)): + self._transaction_manager.transition_to_fatal_error(exception) + elif self._transaction_manager.is_transactional(): + self._transaction_manager.transition_to_abortable_error(exception) + + batch.done(base_offset=base_offset, timestamp_ms=timestamp_ms, exception=exception, log_start_offset=log_start_offset) self._accumulator.deallocate(batch) if self._sensors: self._sensors.record_errors(batch.topic_partition.topic, batch.record_count) @@ -286,7 +395,7 @@ def _complete_batch(self, batch, error, base_offset, timestamp_ms=None, log_star """Complete or retry the given batch of records. Arguments: - batch (RecordBatch): The record batch + batch (ProducerBatch): The record batch error (Exception): The error (or None if none) base_offset (int): The base offset assigned to the records if successful timestamp_ms (int, optional): The timestamp returned by the broker for this batch @@ -305,29 +414,25 @@ def _complete_batch(self, batch, error, base_offset, timestamp_ms=None, log_star self.config['retries'] - batch.attempts - 1, error) - # If idempotence is enabled only retry the request if the current PID is the same as the pid of the batch. - if not self._transaction_state or self._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id: + # If idempotence is enabled only retry the request if the batch matches our current producer id and epoch + if not self._transaction_manager or self._transaction_manager.producer_id_and_epoch.match(batch): log.debug("Retrying batch to topic-partition %s. Sequence number: %s", batch.topic_partition, - self._transaction_state.sequence_number(batch.topic_partition) if self._transaction_state else None) + self._transaction_manager.sequence_number(batch.topic_partition) if self._transaction_manager else None) self._accumulator.reenqueue(batch) if self._sensors: self._sensors.record_retries(batch.topic_partition.topic, batch.record_count) else: - log.warning("Attempted to retry sending a batch but the producer id changed from %s to %s. This batch will be dropped" % ( - batch.producer_id, self._transaction_state.producer_id_and_epoch.producer_id)) - self._fail_batch(batch, base_offset=base_offset, timestamp_ms=timestamp_ms, exception=error, log_start_offset=log_start_offset) + log.warning("Attempted to retry sending a batch but the producer id/epoch changed from %s/%s to %s/%s. This batch will be dropped" % ( + batch.producer_id, batch.producer_epoch, + self._transaction_manager.producer_id_and_epoch.producer_id, self._transaction_manager.producer_id_and_epoch.epoch)) + self._fail_batch(batch, error(), base_offset=base_offset, timestamp_ms=timestamp_ms, log_start_offset=log_start_offset) else: - if error is Errors.OutOfOrderSequenceNumberError and batch.producer_id == self._transaction_state.producer_id_and_epoch.producer_id: - log.error("The broker received an out of order sequence number error for produer_id %s, topic-partition %s" - " at offset %s. This indicates data loss on the broker, and should be investigated.", - batch.producer_id, batch.topic_partition, base_offset) - if error is Errors.TopicAuthorizationFailedError: error = error(batch.topic_partition.topic) # tell the user the result of their request - self._fail_batch(batch, base_offset=base_offset, timestamp_ms=timestamp_ms, exception=error, log_start_offset=log_start_offset) + self._fail_batch(batch, error(), base_offset=base_offset, timestamp_ms=timestamp_ms, log_start_offset=log_start_offset) if error is Errors.UnknownTopicOrPartitionError: log.warning("Received unknown topic or partition error in produce request on partition %s." @@ -341,10 +446,10 @@ def _complete_batch(self, batch, error, base_offset, timestamp_ms=None, log_star batch.done(base_offset=base_offset, timestamp_ms=timestamp_ms, log_start_offset=log_start_offset) self._accumulator.deallocate(batch) - if self._transaction_state and self._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id: - self._transaction_state.increment_sequence_number(batch.topic_partition, batch.record_count) + if self._transaction_manager and self._transaction_manager.producer_id_and_epoch.match(batch): + self._transaction_manager.increment_sequence_number(batch.topic_partition, batch.record_count) log.debug("Incremented sequence number for topic-partition %s to %s", batch.topic_partition, - self._transaction_state.sequence_number(batch.topic_partition)) + self._transaction_manager.sequence_number(batch.topic_partition)) # Unmute the completed partition. if self.config['guarantee_message_order']: @@ -364,7 +469,7 @@ def _create_produce_requests(self, collated): per-node basis. Arguments: - collated: {node_id: [RecordBatch]} + collated: {node_id: [ProducerBatch]} Returns: dict: {node_id: ProduceRequest} (version depends on client api_versions) @@ -391,14 +496,25 @@ def _produce_request(self, node_id, acks, timeout, batches): produce_records_by_partition[topic][partition] = buf version = self._client.api_version(ProduceRequest, max_version=7) - # TODO: support transactional_id - return ProduceRequest[version]( - required_acks=acks, - timeout=timeout, - topics=[(topic, list(partition_info.items())) - for topic, partition_info - in six.iteritems(produce_records_by_partition)], - ) + topic_partition_data = [ + (topic, list(partition_info.items())) + for topic, partition_info in six.iteritems(produce_records_by_partition)] + transactional_id = self._transaction_manager.transactional_id if self._transaction_manager else None + if version >= 3: + return ProduceRequest[version]( + transactional_id=transactional_id, + required_acks=acks, + timeout=timeout, + topics=topic_partition_data, + ) + else: + if transactional_id is not None: + log.warning('Broker does not support ProduceRequest v3+, required for transactional_id') + return ProduceRequest[version]( + required_acks=acks, + timeout=timeout, + topics=topic_partition_data, + ) def wakeup(self): """Wake up the selector associated with this send thread.""" From 9d57758318336a1726b0134632b56e6bdf4923a3 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Sun, 6 Apr 2025 07:56:34 -0700 Subject: [PATCH 09/27] KafkaProducer --- kafka/producer/kafka.py | 109 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 103 insertions(+), 6 deletions(-) diff --git a/kafka/producer/kafka.py b/kafka/producer/kafka.py index 320a1657f..6bd5c0bde 100644 --- a/kafka/producer/kafka.py +++ b/kafka/producer/kafka.py @@ -19,7 +19,7 @@ from kafka.producer.future import FutureRecordMetadata, FutureProduceResult from kafka.producer.record_accumulator import AtomicInteger, RecordAccumulator from kafka.producer.sender import Sender -from kafka.producer.transaction_state import TransactionState +from kafka.producer.transaction_manager import TransactionManager from kafka.record.default_records import DefaultRecordBatchBuilder from kafka.record.legacy_records import LegacyRecordBatchBuilder from kafka.serializer import Serializer @@ -318,6 +318,8 @@ class KafkaProducer(object): 'key_serializer': None, 'value_serializer': None, 'enable_idempotence': False, + 'transactional_id': None, + 'transaction_timeout_ms': 60000, 'acks': 1, 'bootstrap_topics_filter': set(), 'compression_type': None, @@ -444,9 +446,28 @@ def __init__(self, **configs): assert checker(), "Libraries for {} compression codec not found".format(ct) self.config['compression_attrs'] = compression_attrs - self._transaction_state = None + self._metadata = client.cluster + self._transaction_manager = None + self._init_transactions_result = None + if 'enable_idempotence' in user_provided_configs and not self.config['enable_idempotence'] and self.config['transactional_id']: + raise Errors.KafkaConfigurationError("Cannot set transactional_id without enable_idempotence.") + + if self.config['transactional_id']: + self.config['enable_idempotence'] = True + if self.config['enable_idempotence']: - self._transaction_state = TransactionState() + self._transaction_manager = TransactionManager( + transactional_id=self.config['transactional_id'], + transaction_timeout_ms=self.config['transaction_timeout_ms'], + retry_backoff_ms=self.config['retry_backoff_ms'], + api_version=self.config['api_version'], + metadata=self._metadata, + ) + if self._transaction_manager.is_transactional(): + log.info("Instantiated a transactional producer.") + else: + log.info("Instantiated an idempotent producer.") + if 'retries' not in user_provided_configs: log.info("Overriding the default 'retries' config to 3 since the idempotent producer is enabled.") self.config['retries'] = 3 @@ -470,15 +491,14 @@ def __init__(self, **configs): message_version = self.max_usable_produce_magic(self.config['api_version']) self._accumulator = RecordAccumulator( - transaction_state=self._transaction_state, + transaction_manager=self._transaction_manager, message_version=message_version, **self.config) - self._metadata = client.cluster guarantee_message_order = bool(self.config['max_in_flight_requests_per_connection'] == 1) self._sender = Sender(client, self._metadata, self._accumulator, metrics=self._metrics, - transaction_state=self._transaction_state, + transaction_manager=self._transaction_manager, guarantee_message_order=guarantee_message_order, **self.config) self._sender.daemon = True @@ -610,6 +630,79 @@ def _estimate_size_in_bytes(self, key, value, headers=[]): return LegacyRecordBatchBuilder.estimate_size_in_bytes( magic, self.config['compression_type'], key, value) + def init_transactions(self): + """ + Needs to be called before any other methods when the transactional.id is set in the configuration. + + This method does the following: + 1. Ensures any transactions initiated by previous instances of the producer with the same + transactional_id are completed. If the previous instance had failed with a transaction in + progress, it will be aborted. If the last transaction had begun completion, + but not yet finished, this method awaits its completion. + 2. Gets the internal producer id and epoch, used in all future transactional + messages issued by the producer. + + Note that this method will raise KafkaTimeoutError if the transactional state cannot + be initialized before expiration of `max_block_ms`. It is safe to retry, but once the + transactional state has been successfully initialized, this method should no longer be used. + + Raises: + IllegalStateError: if no transactional_id has been configured + AuthorizationError: fatal error indicating that the configured + transactional_id is not authorized. + KafkaError: if the producer has encountered a previous fatal error or for any other unexpected error + KafkaTimeoutError: if the time taken for initialize the transaction has surpassed `max.block.ms`. + """ + if not self._transaction_manager: + raise Errors.IllegalStateError("Cannot call init_transactions without setting a transactional_id.") + if self._init_transactions_result is None: + self._init_transactions_result = self._transaction_manager.initialize_transactions() + self._sender.wakeup() + + if self._init_transactions_result.wait(timeout_ms=self.config['max_block_ms']): + self._init_transactions_result = None + else: + raise Errors.KafkaTimeoutError("Timeout expired while initializing transactional state in %s ms." % (self.config['max_block_ms'],)) + + def begin_transaction(self): + """ Should be called before the start of each new transaction. + + Note that prior to the first invocation of this method, + you must invoke `init_transactions()` exactly one time. + + Raises: + ProducerFencedError if another producer is with the same + transactional_id is active. + """ + # Set the transactional bit in the producer. + if not self._transaction_manager: + raise Errors.IllegalStateError("Cannot use transactional methods without enabling transactions") + self._transaction_manager.begin_transaction() + + def commit_transaction(self): + """ Commits the ongoing transaction. + + Raises: ProducerFencedError if another producer with the same + transactional_id is active. + """ + if not self._transaction_manager: + raise Errors.IllegalStateError("Cannot commit transaction since transactions are not enabled") + result = self._transaction_manager.begin_commit() + self._sender.wakeup() + result.wait() + + def abort_transaction(self): + """ Aborts the ongoing transaction. + + Raises: ProducerFencedError if another producer with the same + transactional_id is active. + """ + if not self._transaction_manager: + raise Errors.IllegalStateError("Cannot abort transaction since transactions are not enabled.") + result = self._transaction_manager.begin_abort() + self._sender.wakeup() + result.wait() + def send(self, topic, value=None, key=None, headers=None, partition=None, timestamp_ms=None): """Publish a message to a topic. @@ -687,6 +780,10 @@ def send(self, topic, value=None, key=None, headers=None, partition=None, timest tp = TopicPartition(topic, partition) log.debug("Sending (key=%r value=%r headers=%r) to %s", key, value, headers, tp) + + if self._transaction_manager and self._transaction_manager.is_transactional(): + self._transaction_manager.maybe_add_partition_to_transaction(tp) + result = self._accumulator.append(tp, timestamp_ms, key_bytes, value_bytes, headers) future, batch_is_full, new_batch_created = result From ff27b06b406d1fd0803d956ab40bebc7e03d9cfc Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Mon, 7 Apr 2025 18:36:22 -0700 Subject: [PATCH 10/27] fixup record_accumulator test --- test/test_record_accumulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_record_accumulator.py b/test/test_record_accumulator.py index babff5617..42f980712 100644 --- a/test/test_record_accumulator.py +++ b/test/test_record_accumulator.py @@ -17,7 +17,7 @@ def test_producer_batch_producer_id(): magic=2, compression_type=0, batch_size=100000) batch = ProducerBatch(tp, records) assert batch.producer_id == -1 - batch.records.set_producer_state(123, 456, 789) + batch.records.set_producer_state(123, 456, 789, False) assert batch.producer_id == 123 records.close() assert batch.producer_id == 123 From 937bfa97cd8b6bb58494bf03f9cbdf29232eaa16 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 8 Apr 2025 10:10:17 -0700 Subject: [PATCH 11/27] transaction manager comment typo --- kafka/producer/transaction_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kafka/producer/transaction_manager.py b/kafka/producer/transaction_manager.py index dea0456ad..fe65f322a 100644 --- a/kafka/producer/transaction_manager.py +++ b/kafka/producer/transaction_manager.py @@ -404,7 +404,7 @@ def adjust_sequences_due_to_failed_batch(self, batch): # If a batch is failed fatally, the sequence numbers for future batches bound for the partition must be adjusted # so that they don't fail with the OutOfOrderSequenceNumberError. # - # This method must only be called when we know that the batch is question has been unequivocally failed by the broker, + # This method must only be called when we know that the batch in question has been unequivocally failed by the broker, # ie. it has received a confirmed fatal status code like 'Message Too Large' or something similar. with self._lock: if batch.topic_partition not in self._next_sequence: From 1f997696077744feb45c81987d7ab5d340a9aadf Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 8 Apr 2025 10:19:27 -0700 Subject: [PATCH 12/27] Revert KAFKA-5494 / idempotent producer with multiple inflight requests --- kafka/producer/record_accumulator.py | 2 +- kafka/producer/sender.py | 9 - kafka/producer/transaction_manager.py | 261 ++------------------------ 3 files changed, 12 insertions(+), 260 deletions(-) diff --git a/kafka/producer/record_accumulator.py b/kafka/producer/record_accumulator.py index 3e484d3c7..73bced9e1 100644 --- a/kafka/producer/record_accumulator.py +++ b/kafka/producer/record_accumulator.py @@ -597,7 +597,7 @@ def abort_undrained_batches(self, error): tp = batch.topic_partition with self._tp_locks[tp]: aborted = False - if (self._transaction_manager and not batch.has_sequence) or (not self._transaction_manager and not batch.is_done): + if not batch.is_done: aborted = True batch.records.close() self._batches[tp].remove(batch) diff --git a/kafka/producer/sender.py b/kafka/producer/sender.py index 46330a251..52154dc5b 100644 --- a/kafka/producer/sender.py +++ b/kafka/producer/sender.py @@ -103,17 +103,9 @@ def run_once(self): if self._transaction_manager: try: - if self._transaction_manager.should_reset_producer_state_after_resolving_sequences(): - # Check if the previous run expired batches which requires a reset of the producer state. - self._transaction_manager.reset_producer_id() if not self._transaction_manager.is_transactional(): # this is an idempotent producer, so make sure we have a producer id self._maybe_wait_for_producer_id() - elif self._transaction_manager.has_unresolved_sequences() and not self._transaction_manager.has_fatal_error(): - self._transaction_manager.transition_to_fatal_error( - Errors.KafkaError("The client hasn't received acknowledgment for" - " some previously sent messages and can no longer retry them." - " It isn't safe to continue.")) elif self._transaction_manager.has_in_flight_transactional_request() or self._maybe_send_transactional_request(): # as long as there are outstanding transactional requests, we simply wait for them to return self._client.poll(timeout_ms=self.config['retry_backoff_ms']) @@ -313,7 +305,6 @@ def _maybe_wait_for_producer_id(self): error_type = Errors.for_code(response.error_code) if error_type is Errors.NoError: self._transaction_manager.set_producer_id_and_epoch(response.producer_id, response.producer_epoch) - return elif getattr(error_type, 'retriable', False): log.debug("Retriable error from InitProducerId response: %s", error_type.__name__) if getattr(error_type, 'invalid_metadata', False): diff --git a/kafka/producer/transaction_manager.py b/kafka/producer/transaction_manager.py index fe65f322a..4ce36d289 100644 --- a/kafka/producer/transaction_manager.py +++ b/kafka/producer/transaction_manager.py @@ -105,19 +105,9 @@ class TransactionManager(object): def __init__(self, transactional_id=None, transaction_timeout_ms=0, retry_backoff_ms=100, api_version=(0, 11), metadata=None): self._api_version = api_version self._metadata = metadata - # Keep track of the in flight batches bound for a partition, ordered by sequence. This helps us to ensure that - # we continue to order batches by the sequence numbers even when the responses come back out of order during - # leader failover. We add a batch to the queue when it is drained, and remove it when the batch completes - # (either successfully or through a fatal failure). - # use heapq methods to push/pop from queues - self._in_flight_batches_by_sequence = collections.defaultdict(list) - self._in_flight_batches_sort_id = 0 - - # The base sequence of the next batch bound for a given partition. - self._next_sequence = collections.defaultdict(lambda: 0) - # The sequence of the last record of the last ack'd batch from the given partition. When there are no - # in flight requests for a partition, the self._last_acked_sequence(topicPartition) == nextSequence(topicPartition) - 1. - self._last_acked_sequence = collections.defaultdict(lambda: -1) + + self._sequence_numbers = collections.defaultdict(lambda: 0) + self.transactional_id = transactional_id self.transaction_timeout_ms = transaction_timeout_ms self._transaction_coordinator = None @@ -136,18 +126,6 @@ def __init__(self, transactional_id=None, transaction_timeout_ms=0, retry_backof self._pending_requests_sort_id = 0 self._in_flight_request_correlation_id = self.NO_INFLIGHT_REQUEST_CORRELATION_ID - # If a batch bound for a partition expired locally after being sent at least once, the partition has is considered - # to have an unresolved state. We keep track fo such partitions here, and cannot assign any more sequence numbers - # for this partition until the unresolved state gets cleared. This may happen if other inflight batches returned - # successfully (indicating that the expired batch actually made it to the broker). If we don't get any successful - # responses for the partition once the inflight request count falls to zero, we reset the producer id and - # consequently clear this data structure as well. - self._partitions_with_unresolved_sequences = set() - self._inflight_batches_by_sequence = dict() - # We keep track of the last acknowledged offset on a per partition basis in order to disambiguate UnknownProducer - # responses which are due to the retention period elapsing, and those which are due to actual lost data. - self._last_acked_offset = collections.defaultdict(lambda: -1) - # This is used by the TxnRequestHandlers to control how long to back off before a given request is retried. # For instance, this value is lowered by the AddPartitionsToTxnHandler when it receives a CONCURRENT_TRANSACTIONS # error for the first AddPartitionsRequest in a transaction. @@ -159,7 +137,7 @@ def initialize_transactions(self): self._ensure_transactional() self._transition_to(TransactionState.INITIALIZING) self.set_producer_id_and_epoch(ProducerIdAndEpoch(NO_PRODUCER_ID, NO_PRODUCER_EPOCH)) - self._next_sequence.clear() + self._sequence_numbers.clear() handler = InitProducerIdHandler(self, self.transactional_id, self.transaction_timeout_ms) self._enqueue_request(handler) return handler.result @@ -319,195 +297,22 @@ def reset_producer_id(self): " You must either abort the ongoing transaction or" " reinitialize the transactional producer instead") self.set_producer_id_and_epoch(ProducerIdAndEpoch(NO_PRODUCER_ID, NO_PRODUCER_EPOCH)) - self._next_sequence.clear() - self._last_acked_sequence.clear() - self._inflight_batches_by_sequence.clear() - self._partitions_with_unresolved_sequences.clear() - self._last_acked_offset.clear() + self._sequence_numbers.clear() def sequence_number(self, tp): with self._lock: - return self._next_sequence[tp] + return self._sequence_numbers[tp] def increment_sequence_number(self, tp, increment): with self._lock: - if tp not in self._next_sequence: + if tp not in self._sequence_numbers: raise Errors.IllegalStateError("Attempt to increment sequence number for a partition with no current sequence.") # Sequence number wraps at java max int - base = self._next_sequence[tp] + base = self._sequence_numbers[tp] if base > (2147483647 - increment): - self._next_sequence[tp] = increment - (2147483647 - base) - 1 + self._sequence_numbers[tp] = increment - (2147483647 - base) - 1 else: - self._next_sequence[tp] += increment - - def _next_in_flight_batches_sort_id(self): - self._in_flight_batches_sort_id += 1 - return self._in_flight_batches_sort_id - - def add_in_flight_batch(self, batch): - with self._lock: - if not batch.has_sequence(): - raise Errors.IllegalStateError("Can't track batch for partition %s when sequence is not set." % (batch.topic_partition,)) - heapq.heappush( - self._in_flight_batches_by_sequence[batch.topic_partition], - (batch.base_sequence, self._next_in_flight_batches_sort_id(), batch) - ) - - def first_in_flight_sequence(self, tp): - """ - Returns the first inflight sequence for a given partition. This is the base sequence of an inflight batch with - the lowest sequence number. If there are no inflight requests being tracked for this partition, this method will return -1 - """ - with self._lock: - if not self._in_flight_batches_by_sequence[tp]: - return NO_SEQUENCE - else: - return self._in_flight_batches_by_sequence[tp][0][2].base_sequence - - def next_batch_by_sequence(self, tp): - with self._lock: - if not self._in_flight_batches_by_sequence[tp]: - return None - else: - return self._in_flight_batches_by_sequence[tp][0][2] - - def remove_in_flight_batch(self, batch): - with self._lock: - if not self._in_flight_batches_by_sequence[batch.topic_partition]: - return - else: - try: - # see https://stackoverflow.com/questions/10162679/python-delete-element-from-heap - queue = self._in_flight_batches_by_sequence[batch.topic_partition] - idx = [item[2] for item in queue].index(batch) - queue[idx] = queue[-1] - queue.pop() - heapq.heapify(queue) - except ValueError: - pass - - def maybe_update_last_acked_sequence(self, tp, sequence): - with self._lock: - if sequence > self._last_acked_sequence[tp]: - self._last_acked_sequence[tp] = sequence - - def update_last_acked_offset(self, base_offset, batch): - if base_offset == -1: - return - last_offset = base_offset + batch.record_count - 1 - if last_offset > self._last_acked_offset[batch.topic_partition]: - self._last_acked_offset[batch.topic_partition] = last_offset - else: - log.debug("Partition %s keeps last_offset at %s", batch.topic_partition, last_offset) - - def adjust_sequences_due_to_failed_batch(self, batch): - # If a batch is failed fatally, the sequence numbers for future batches bound for the partition must be adjusted - # so that they don't fail with the OutOfOrderSequenceNumberError. - # - # This method must only be called when we know that the batch in question has been unequivocally failed by the broker, - # ie. it has received a confirmed fatal status code like 'Message Too Large' or something similar. - with self._lock: - if batch.topic_partition not in self._next_sequence: - # Sequence numbers are not being tracked for this partition. This could happen if the producer id was just - # reset due to a previous OutOfOrderSequenceNumberError. - return - log.debug("producer_id: %s, send to partition %s failed fatally. Reducing future sequence numbers by %s", - batch.producer_id, batch.topic_partition, batch.record_count) - current_sequence = self.sequence_number(batch.topic_partition) - current_sequence -= batch.record_count - if current_sequence < 0: - raise Errors.IllegalStateError( - "Sequence number for partition %s is going to become negative: %s" % (batch.topic_partition, current_sequence)) - - self._set_next_sequence(batch.topic_partition, current_sequence) - - for in_flight_batch in self._in_flight_batches_by_sequence[batch.topic_partition]: - if in_flight_batch.base_sequence < batch.base_sequence: - continue - new_sequence = in_flight_batch.base_sequence - batch.record_count - if new_sequence < 0: - raise Errors.IllegalStateError( - "Sequence number for batch with sequence %s for partition %s is going to become negative: %s" % ( - in_flight_batch.base_sequence, batch.topic_partition, new_sequence)) - - log.info("Resetting sequence number of batch with current sequence %s for partition %s to %s", - in_flight_batch.base_sequence(), batch.topic_partition, new_sequence) - in_flight_batch.reset_producer_state( - ProducerIdAndEpoch(in_flight_batch.producer_id, in_flight_batch.producer_epoch), - new_sequence, - in_flight_batch.is_transactional()) - - def _start_sequences_at_beginning(self, tp): - with self._lock: - sequence = 0 - for in_flight_batch in self._in_flight_batches_by_sequence[tp]: - log.info("Resetting sequence number of batch with current sequence %s for partition %s to %s", - in_flight_batch.base_sequence, in_flight_batch.topic_partition, sequence) - in_flight_batch.reset_producer_state( - ProducerIdAndEpoch(in_flight_batch.producer_id, in_flight_batch.producer_epoch), - sequence, - in_flight_batch.is_transactional()) - sequence += in_flight_batch.record_count - self._set_next_sequence(tp, sequence) - try: - del self._last_acked_sequence[tp] - except KeyError: - pass - - def has_in_flight_batches(self, tp): - with self._lock: - return len(self._in_flight_batches_by_sequence[tp]) > 0 - - def has_unresolved_sequences(self): - with self._lock: - return len(self._partitions_with_unresolved_sequences) > 0 - - def has_unresolved_sequence(self, tp): - with self._lock: - return tp in self._partitions_with_unresolved_sequences - - def mark_sequence_unresolved(self, tp): - with self._lock: - log.debug("Marking partition %s unresolved", tp) - self._partitions_with_unresolved_sequences.add(tp) - - # Checks if there are any partitions with unresolved partitions which may now be resolved. Returns True if - # the producer id needs a reset, False otherwise. - def should_reset_producer_state_after_resolving_sequences(self): - with self._lock: - try: - remove = set() - if self.is_transactional(): - # We should not reset producer state if we are transactional. We will transition to a fatal error instead. - return False - for tp in self._partitions_with_unresolved_sequences: - if not self.has_in_flight_batches(tp): - # The partition has been fully drained. At this point, the last ack'd sequence should be once less than - # next sequence destined for the partition. If so, the partition is fully resolved. If not, we should - # reset the sequence number if necessary. - if self.is_next_sequence(tp, self.sequence_number(tp)): - # This would happen when a batch was expired, but subsequent batches succeeded. - remove.add(tp) - else: - # We would enter this branch if all in flight batches were ultimately expired in the producer. - log.info("No inflight batches remaining for %s, last ack'd sequence for partition is %s, next sequence is %s." - " Going to reset producer state.", tp, self._last_acked_sequence(tp), self.sequence_number(tp)) - return True - return False - finally: - self._partitions_with_unresolved_sequences -= remove - - def is_next_sequence(self, tp, sequence): - with self._lock: - return sequence - self._last_acked_sequence(tp) == 1 - - def _set_next_sequence(self, tp, sequence): - with self._lock: - if tp not in self._next_sequence and sequence != 0: - raise Errors.IllegalStateError( - "Trying to set the sequence number for %s to %s but the sequence number was never set for this partition." % ( - tp, sequence)) - self._next_sequence[tp] = sequence + self._sequence_numbers[tp] += increment def next_request_handler(self, has_incomplete_batches): with self._lock: @@ -584,7 +389,7 @@ def has_abortable_error(self): return self._current_state == TransactionState.ABORTABLE_ERROR # visible for testing - def transactionContainsPartition(self, tp): + def transaction_contains_partition(self, tp): with self._lock: return tp in self._partitions_in_transaction @@ -594,50 +399,6 @@ def has_ongoing_transaction(self): # transactions are considered ongoing once started until completion or a fatal error return self._current_state == TransactionState.IN_TRANSACTION or self.is_completing() or self.has_abortable_error() - def can_retry(self, batch, error, log_start_offset): - with self._lock: - if not self.has_producer_id(batch.producer_id): - return False - - elif ( - error is Errors.OutOfOrderSequenceNumberError - and not self.has_unresolved_sequence(batch.topic_partition) - and (batch.sequence_has_been_reset() or not self.is_next_sequence(batch.topic_partition, batch.base_sequence)) - ): - # We should retry the OutOfOrderSequenceNumberError if the batch is _not_ the next batch, ie. its base - # sequence isn't the self._last_acked_sequence + 1. However, if the first in flight batch fails fatally, we will - # adjust the sequences of the other inflight batches to account for the 'loss' of the sequence range in - # the batch which failed. In this case, an inflight batch will have a base sequence which is - # the self._last_acked_sequence + 1 after adjustment. When this batch fails with an OutOfOrderSequenceNumberError, we want to retry it. - # To account for the latter case, we check whether the sequence has been reset since the last drain. - # If it has, we will retry it anyway. - return True - - elif error is Errors.UnknownProducerIdError: - if log_start_offset == -1: - # We don't know the log start offset with this response. We should just retry the request until we get it. - # The UNKNOWN_PRODUCER_ID error code was added along with the new ProduceResponse which includes the - # logStartOffset. So the '-1' sentinel is not for backward compatibility. Instead, it is possible for - # a broker to not know the logStartOffset at when it is returning the response because the partition - # may have moved away from the broker from the time the error was initially raised to the time the - # response was being constructed. In these cases, we should just retry the request: we are guaranteed - # to eventually get a logStartOffset once things settle down. - return True - - if batch.sequence_has_been_reset(): - # When the first inflight batch fails due to the truncation case, then the sequences of all the other - # in flight batches would have been restarted from the beginning. However, when those responses - # come back from the broker, they would also come with an UNKNOWN_PRODUCER_ID error. In this case, we should not - # reset the sequence numbers to the beginning. - return True - elif self._last_acked_offset(batch.topic_partition) < log_start_offset: - # The head of the log has been removed, probably due to the retention time elapsing. In this case, - # we expect to lose the producer state. Reset the sequences of all inflight batches to be from the beginning - # and retry them. - self._start_sequences_at_beginning(batch.topic_partition) - return True - return False - # visible for testing def is_ready(self): with self._lock: From a586ad7c36bca630eb2a67996d57364ea5540687 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 8 Apr 2025 10:21:04 -0700 Subject: [PATCH 13/27] sender fixes --- kafka/producer/sender.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/kafka/producer/sender.py b/kafka/producer/sender.py index 52154dc5b..741c4c88e 100644 --- a/kafka/producer/sender.py +++ b/kafka/producer/sender.py @@ -164,18 +164,31 @@ def _send_producer_data(self): expired_batches = self._accumulator.abort_expired_batches( self.config['request_timeout_ms'], self._metadata) + if expired_batches: + log.debug("Expired %s batches in accumulator", len(expired_batches)) + # Reset the producer_id if an expired batch has previously been sent to the broker. # See the documentation of `TransactionState.reset_producer_id` to understand why # we need to reset the producer id here. if self._transaction_manager and any([batch.in_retry() for batch in expired_batches]): - self._transaction_manager.reset_producer_id() - return 0 + needs_transaction_state_reset = True + else: + needs_transaction_state_reset = False + + for expired_batch in expired_batches: + error = Errors.KafkaTimeoutError( + "Expiring %d record(s) for %s: %s ms has passed since batch creation" % ( + expired_batch.record_count, expired_batch.topic_partition, + int((time.time() - expired_batch.created) * 1000))) + self._fail_batch(expired_batch, error, base_offset=-1) if self._sensors: - for expired_batch in expired_batches: - self._sensors.record_errors(expired_batch.topic_partition.topic, expired_batch.record_count) self._sensors.update_produce_request_metrics(batches_by_node) + if needs_transaction_state_reset: + self._transaction_manager.reset_producer_id() + return 0 + requests = self._create_produce_requests(batches_by_node) # If we have any nodes that are ready to send + have sendable data, # poll with 0 timeout so this can immediately loop and try sending more @@ -356,7 +369,7 @@ def _handle_produce_response(self, node_id, send_time, batches, response): self._complete_batch(batch, None, -1) def _fail_batch(self, batch, exception, base_offset=None, timestamp_ms=None, log_start_offset=None): - log.exception(exception) + exception = exception if type(exception) is not type else exception() if self._transaction_manager: if isinstance(exception, Errors.OutOfOrderSequenceNumberError) and \ not self._transaction_manager.is_transactional() and \ @@ -417,13 +430,13 @@ def _complete_batch(self, batch, error, base_offset, timestamp_ms=None, log_star log.warning("Attempted to retry sending a batch but the producer id/epoch changed from %s/%s to %s/%s. This batch will be dropped" % ( batch.producer_id, batch.producer_epoch, self._transaction_manager.producer_id_and_epoch.producer_id, self._transaction_manager.producer_id_and_epoch.epoch)) - self._fail_batch(batch, error(), base_offset=base_offset, timestamp_ms=timestamp_ms, log_start_offset=log_start_offset) + self._fail_batch(batch, error, base_offset=base_offset, timestamp_ms=timestamp_ms, log_start_offset=log_start_offset) else: if error is Errors.TopicAuthorizationFailedError: error = error(batch.topic_partition.topic) # tell the user the result of their request - self._fail_batch(batch, error(), base_offset=base_offset, timestamp_ms=timestamp_ms, log_start_offset=log_start_offset) + self._fail_batch(batch, error, base_offset=base_offset, timestamp_ms=timestamp_ms, log_start_offset=log_start_offset) if error is Errors.UnknownTopicOrPartitionError: log.warning("Received unknown topic or partition error in produce request on partition %s." From 1ed98914c8ae2b7edee8bc4117d7f832ea194576 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 8 Apr 2025 10:44:27 -0700 Subject: [PATCH 14/27] check is_send_to_partitions_allowed --- kafka/producer/record_accumulator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/kafka/producer/record_accumulator.py b/kafka/producer/record_accumulator.py index 73bced9e1..83802ef96 100644 --- a/kafka/producer/record_accumulator.py +++ b/kafka/producer/record_accumulator.py @@ -488,6 +488,8 @@ def drain(self, cluster, nodes, max_size, now=None): else: producer_id_and_epoch = None if self._transaction_manager: + if not self._transaction_manager.is_send_to_partition_allowed(tp): + break producer_id_and_epoch = self._transaction_manager.producer_id_and_epoch if not producer_id_and_epoch.is_valid: # we cannot send the batch until we have refreshed the PID From 5b205c36f25524813661bdad72e35483f89526cb Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 8 Apr 2025 10:48:52 -0700 Subject: [PATCH 15/27] prefix test-only methods --- kafka/producer/transaction_manager.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/kafka/producer/transaction_manager.py b/kafka/producer/transaction_manager.py index 4ce36d289..3c04bcd10 100644 --- a/kafka/producer/transaction_manager.py +++ b/kafka/producer/transaction_manager.py @@ -251,12 +251,12 @@ def transition_to_fatal_error(self, exc): self._transition_to(TransactionState.FATAL_ERROR, error=exc) # visible for testing - def is_partition_added(self, partition): + def _test_is_partition_added(self, partition): with self._lock: return partition in self._partitions_in_transaction # visible for testing - def is_partition_pending_add(self, partition): + def _test_is_partition_pending_add(self, partition): return partition in self._new_partitions_in_transaction or partition in self._pending_partitions_in_transaction def has_producer_id_and_epoch(self, producer_id, producer_epoch): @@ -380,27 +380,25 @@ def clear_in_flight_transactional_request_correlation_id(self): def has_in_flight_transactional_request(self): return self._in_flight_request_correlation_id != self.NO_INFLIGHT_REQUEST_CORRELATION_ID - # visible for testing. def has_fatal_error(self): return self._current_state == TransactionState.FATAL_ERROR - # visible for testing. def has_abortable_error(self): return self._current_state == TransactionState.ABORTABLE_ERROR # visible for testing - def transaction_contains_partition(self, tp): + def _test_transaction_contains_partition(self, tp): with self._lock: return tp in self._partitions_in_transaction # visible for testing - def has_ongoing_transaction(self): + def _test_has_ongoing_transaction(self): with self._lock: # transactions are considered ongoing once started until completion or a fatal error return self._current_state == TransactionState.IN_TRANSACTION or self.is_completing() or self.has_abortable_error() # visible for testing - def is_ready(self): + def _test_is_ready(self): with self._lock: return self.is_transactional() and self._current_state == TransactionState.READY From 4cad6f13508d1ed47f02b0895447a83eb41a7835 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 8 Apr 2025 10:21:53 -0700 Subject: [PATCH 16/27] test_sender updates for transaction manager --- test/test_sender.py | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/test/test_sender.py b/test/test_sender.py index a1a775b59..a833c58cb 100644 --- a/test/test_sender.py +++ b/test/test_sender.py @@ -11,13 +11,14 @@ from kafka.vendor import six from kafka.client_async import KafkaClient +from kafka.cluster import ClusterMetadata import kafka.errors as Errors from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS from kafka.producer.kafka import KafkaProducer from kafka.protocol.produce import ProduceRequest from kafka.producer.record_accumulator import RecordAccumulator, ProducerBatch from kafka.producer.sender import Sender -from kafka.producer.transaction_state import TransactionState +from kafka.producer.transaction_manager import TransactionManager from kafka.record.memory_records import MemoryRecordsBuilder from kafka.structs import TopicPartition @@ -42,6 +43,16 @@ def producer_batch(topic='foo', partition=0, magic=2): return batch +@pytest.fixture +def transaction_manager(): + return TransactionManager( + transactional_id=None, + transaction_timeout_ms=60000, + retry_backoff_ms=100, + api_version=(2, 1), + metadata=ClusterMetadata()) + + @pytest.mark.parametrize(("api_version", "produce_version"), [ ((2, 1), 7), ((0, 10, 0), 2), @@ -85,16 +96,16 @@ def test_complete_batch_success(sender): assert batch.produce_future.value == (0, 123, 456) -def test_complete_batch_transaction(sender): - sender._transaction_state = TransactionState() +def test_complete_batch_transaction(sender, transaction_manager): + sender._transaction_manager = transaction_manager batch = producer_batch() - assert sender._transaction_state.sequence_number(batch.topic_partition) == 0 - assert sender._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id + assert sender._transaction_manager.sequence_number(batch.topic_partition) == 0 + assert sender._transaction_manager.producer_id_and_epoch.producer_id == batch.producer_id # No error, base_offset 0 sender._complete_batch(batch, None, 0) assert batch.is_done - assert sender._transaction_state.sequence_number(batch.topic_partition) == batch.record_count + assert sender._transaction_manager.sequence_number(batch.topic_partition) == batch.record_count @pytest.mark.parametrize(("error", "refresh_metadata"), [ @@ -164,8 +175,8 @@ def test_complete_batch_retry(sender, accumulator, mocker, error, retry): assert isinstance(batch.produce_future.exception, error) -def test_complete_batch_producer_id_changed_no_retry(sender, accumulator, mocker): - sender._transaction_state = TransactionState() +def test_complete_batch_producer_id_changed_no_retry(sender, accumulator, transaction_manager, mocker): + sender._transaction_manager = transaction_manager sender.config['retries'] = 1 mocker.spy(sender, '_fail_batch') mocker.patch.object(accumulator, 'reenqueue') @@ -175,21 +186,21 @@ def test_complete_batch_producer_id_changed_no_retry(sender, accumulator, mocker assert not batch.is_done accumulator.reenqueue.assert_called_with(batch) batch.records._producer_id = 123 # simulate different producer_id - assert batch.producer_id != sender._transaction_state.producer_id_and_epoch.producer_id + assert batch.producer_id != sender._transaction_manager.producer_id_and_epoch.producer_id sender._complete_batch(batch, error, -1) assert batch.is_done assert isinstance(batch.produce_future.exception, error) -def test_fail_batch(sender, accumulator, mocker): - sender._transaction_state = TransactionState() - mocker.patch.object(TransactionState, 'reset_producer_id') +def test_fail_batch(sender, accumulator, transaction_manager, mocker): + sender._transaction_manager = transaction_manager + mocker.patch.object(TransactionManager, 'reset_producer_id') batch = producer_batch() mocker.patch.object(batch, 'done') - assert sender._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id + assert sender._transaction_manager.producer_id_and_epoch.producer_id == batch.producer_id error = Exception('error') sender._fail_batch(batch, base_offset=0, timestamp_ms=None, exception=error, log_start_offset=None) - sender._transaction_state.reset_producer_id.assert_called_once() + sender._transaction_manager.reset_producer_id.assert_called_once() batch.done.assert_called_with(base_offset=0, timestamp_ms=None, exception=error, log_start_offset=None) From ab2857299a396b963166217ab3b83cb1760cf8a1 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 8 Apr 2025 10:56:20 -0700 Subject: [PATCH 17/27] Check reset_producer_id for out of order sequence number on idempotent producer only --- test/test_sender.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/test_sender.py b/test/test_sender.py index a833c58cb..ba20759a5 100644 --- a/test/test_sender.py +++ b/test/test_sender.py @@ -194,12 +194,23 @@ def test_complete_batch_producer_id_changed_no_retry(sender, accumulator, transa def test_fail_batch(sender, accumulator, transaction_manager, mocker): sender._transaction_manager = transaction_manager - mocker.patch.object(TransactionManager, 'reset_producer_id') batch = producer_batch() mocker.patch.object(batch, 'done') assert sender._transaction_manager.producer_id_and_epoch.producer_id == batch.producer_id error = Exception('error') sender._fail_batch(batch, base_offset=0, timestamp_ms=None, exception=error, log_start_offset=None) + batch.done.assert_called_with(base_offset=0, timestamp_ms=None, exception=error, log_start_offset=None) + + +def test_out_of_order_sequence_number_reset_producer_id(sender, accumulator, transaction_manager, mocker): + sender._transaction_manager = transaction_manager + assert transaction_manager.transactional_id is None # this test is for idempotent producer only + mocker.patch.object(TransactionManager, 'reset_producer_id') + batch = producer_batch() + mocker.patch.object(batch, 'done') + assert sender._transaction_manager.producer_id_and_epoch.producer_id == batch.producer_id + error = Errors.OutOfOrderSequenceNumberError() + sender._fail_batch(batch, base_offset=0, timestamp_ms=None, exception=error, log_start_offset=None) sender._transaction_manager.reset_producer_id.assert_called_once() batch.done.assert_called_with(base_offset=0, timestamp_ms=None, exception=error, log_start_offset=None) From 5ef659506fb3bef837006742454cbc238533b257 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 15 Apr 2025 11:40:41 -0700 Subject: [PATCH 18/27] fixup set_producer_id_and_epoch from sender --- kafka/producer/sender.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kafka/producer/sender.py b/kafka/producer/sender.py index 741c4c88e..ba0dfba89 100644 --- a/kafka/producer/sender.py +++ b/kafka/producer/sender.py @@ -11,6 +11,7 @@ from kafka import errors as Errors from kafka.metrics.measurable import AnonMeasurable from kafka.metrics.stats import Avg, Max, Rate +from kafka.producer.transaction_manager import ProducerIdAndEpoch from kafka.protocol.init_producer_id import InitProducerIdRequest from kafka.protocol.produce import ProduceRequest from kafka.structs import TopicPartition @@ -317,7 +318,7 @@ def _maybe_wait_for_producer_id(self): response = self._client.send_and_receive(node_id, request) error_type = Errors.for_code(response.error_code) if error_type is Errors.NoError: - self._transaction_manager.set_producer_id_and_epoch(response.producer_id, response.producer_epoch) + self._transaction_manager.set_producer_id_and_epoch(ProducerIdAndEpoch(response.producer_id, response.producer_epoch)) elif getattr(error_type, 'retriable', False): log.debug("Retriable error from InitProducerId response: %s", error_type.__name__) if getattr(error_type, 'invalid_metadata', False): From 243b55bc2a995bb78bea4b92157219d3f61452f9 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 15 Apr 2025 11:41:15 -0700 Subject: [PATCH 19/27] add small timeout to producer/consumer fixture closes in test_producer --- test/integration/test_producer_integration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/integration/test_producer_integration.py b/test/integration/test_producer_integration.py index 303832b9f..f53d6f1ce 100644 --- a/test/integration/test_producer_integration.py +++ b/test/integration/test_producer_integration.py @@ -16,7 +16,7 @@ def producer_factory(**kwargs): try: yield producer finally: - producer.close(timeout=0) + producer.close(timeout=1) @contextmanager @@ -25,7 +25,7 @@ def consumer_factory(**kwargs): try: yield consumer finally: - consumer.close(timeout_ms=0) + consumer.close(timeout_ms=100) @pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") @@ -82,7 +82,7 @@ def test_end_to_end(kafka_broker, compression): def test_kafka_producer_gc_cleanup(): gc.collect() threads = threading.active_count() - producer = KafkaProducer(api_version='0.9') # set api_version explicitly to avoid auto-detection + producer = KafkaProducer(api_version=(2, 1)) # set api_version explicitly to avoid auto-detection assert threading.active_count() == threads + 1 del(producer) gc.collect() From 1321f984e02d490904874b5fb8daa63bd61c40e0 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 15 Apr 2025 14:01:43 -0700 Subject: [PATCH 20/27] Do not update reconnect backoff when closing w/o error --- kafka/conn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kafka/conn.py b/kafka/conn.py index 85a9658d4..31e1f8be9 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -934,7 +934,8 @@ def close(self, error=None): if self.state is ConnectionStates.DISCONNECTED: return log.log(logging.ERROR if error else logging.INFO, '%s: Closing connection. %s', self, error or '') - self._update_reconnect_backoff() + if error: + self._update_reconnect_backoff() self._api_versions_future = None self._sasl_auth_future = None self._init_sasl_mechanism() From 549d8cb325138c6aebf3f7098a462dceb7bec9eb Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 15 Apr 2025 14:52:09 -0700 Subject: [PATCH 21/27] Move producer unit test to test/test_producer.py --- test/integration/test_producer_integration.py | 16 ++----------- test/test_producer.py | 23 +++++++++++++++++++ 2 files changed, 25 insertions(+), 14 deletions(-) create mode 100644 test/test_producer.py diff --git a/test/integration/test_producer_integration.py b/test/integration/test_producer_integration.py index f53d6f1ce..dd1750471 100644 --- a/test/integration/test_producer_integration.py +++ b/test/integration/test_producer_integration.py @@ -1,8 +1,8 @@ +from __future__ import absolute_import + from contextlib import contextmanager -import gc import platform import time -import threading import pytest @@ -77,18 +77,6 @@ def test_end_to_end(kafka_broker, compression): assert msgs == set(['msg %d' % (i,) for i in range(messages)]) -@pytest.mark.skipif(platform.python_implementation() != 'CPython', - reason='Test relies on CPython-specific gc policies') -def test_kafka_producer_gc_cleanup(): - gc.collect() - threads = threading.active_count() - producer = KafkaProducer(api_version=(2, 1)) # set api_version explicitly to avoid auto-detection - assert threading.active_count() == threads + 1 - del(producer) - gc.collect() - assert threading.active_count() == threads - - @pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") @pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4', 'zstd']) def test_kafka_producer_proper_record_metadata(kafka_broker, compression): diff --git a/test/test_producer.py b/test/test_producer.py new file mode 100644 index 000000000..569df79f9 --- /dev/null +++ b/test/test_producer.py @@ -0,0 +1,23 @@ +from __future__ import absolute_import + +import gc +import platform +import threading + +import pytest + +from kafka import KafkaProducer + +@pytest.mark.skipif(platform.python_implementation() != 'CPython', + reason='Test relies on CPython-specific gc policies') +def test_kafka_producer_gc_cleanup(): + gc.collect() + threads = threading.active_count() + producer = KafkaProducer(api_version=(2, 1)) # set api_version explicitly to avoid auto-detection + assert threading.active_count() == threads + 1 + del(producer) + gc.collect() + assert threading.active_count() == threads + + + From 716b912255a57e02e4a0dd0ee69d9a966f8e9887 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 15 Apr 2025 15:02:11 -0700 Subject: [PATCH 22/27] Transactional/Idempotent producer requires 0.11+ --- kafka/producer/kafka.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/kafka/producer/kafka.py b/kafka/producer/kafka.py index 6bd5c0bde..063376b4d 100644 --- a/kafka/producer/kafka.py +++ b/kafka/producer/kafka.py @@ -456,6 +456,8 @@ def __init__(self, **configs): self.config['enable_idempotence'] = True if self.config['enable_idempotence']: + assert self.config['api_version'] >= (0, 11), "Transactional/Idempotent producer requires >= Kafka 0.11 Brokers" + self._transaction_manager = TransactionManager( transactional_id=self.config['transactional_id'], transaction_timeout_ms=self.config['transaction_timeout_ms'], From 318f3313794b5eb2ee0542dc2820444cf2dab586 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 15 Apr 2025 15:02:30 -0700 Subject: [PATCH 23/27] Very basic idempotent producer integration test --- test/integration/test_producer_integration.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/integration/test_producer_integration.py b/test/integration/test_producer_integration.py index dd1750471..2bb46a914 100644 --- a/test/integration/test_producer_integration.py +++ b/test/integration/test_producer_integration.py @@ -133,3 +133,14 @@ def test_kafka_producer_proper_record_metadata(kafka_broker, compression): partition=0) record = future.get(timeout=5) assert abs(record.timestamp - send_time) <= 1000 # Allow 1s deviation + + +@pytest.mark.skipif(env_kafka_version() < (0, 11), reason="Idempotent producer requires broker >=0.11") +def test_idempotent_producer(kafka_broker): + connect_str = ':'.join([kafka_broker.host, str(kafka_broker.port)]) + producer = KafkaProducer(bootstrap_servers=connect_str, enable_idempotence=True) + try: + for _ in range(10): + producer.send('foo', value=b'idempotent_msg').get(timeout=1) + finally: + producer.close() From f23198308ef704332afeaea0b8ecc4a6df8291e7 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 15 Apr 2025 16:09:39 -0700 Subject: [PATCH 24/27] Handle empty batch lists from accumulator.drain --- kafka/producer/sender.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/kafka/producer/sender.py b/kafka/producer/sender.py index ba0dfba89..707d46bf3 100644 --- a/kafka/producer/sender.py +++ b/kafka/producer/sender.py @@ -481,9 +481,10 @@ def _create_produce_requests(self, collated): """ requests = {} for node_id, batches in six.iteritems(collated): - requests[node_id] = self._produce_request( - node_id, self.config['acks'], - self.config['request_timeout_ms'], batches) + if batches: + requests[node_id] = self._produce_request( + node_id, self.config['acks'], + self.config['request_timeout_ms'], batches) return requests def _produce_request(self, node_id, acks, timeout, batches): @@ -682,8 +683,9 @@ def update_produce_request_metrics(self, batches_map): records += batch.record_count total_bytes += batch.records.size_in_bytes() - self.records_per_request_sensor.record(records) - self.byte_rate_sensor.record(total_bytes) + if node_batch: + self.records_per_request_sensor.record(records) + self.byte_rate_sensor.record(total_bytes) def record_retries(self, topic, count): self.retry_sensor.record(count) From 852e550cc334ccf6d28d9fb32a0a8f84fd781860 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 15 Apr 2025 16:24:05 -0700 Subject: [PATCH 25/27] more transaction manager typos --- kafka/producer/transaction_manager.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/kafka/producer/transaction_manager.py b/kafka/producer/transaction_manager.py index 3c04bcd10..2df9ca8b0 100644 --- a/kafka/producer/transaction_manager.py +++ b/kafka/producer/transaction_manager.py @@ -250,13 +250,11 @@ def transition_to_fatal_error(self, exc): with self._lock: self._transition_to(TransactionState.FATAL_ERROR, error=exc) - # visible for testing - def _test_is_partition_added(self, partition): + def is_partition_added(self, partition): with self._lock: return partition in self._partitions_in_transaction - # visible for testing - def _test_is_partition_pending_add(self, partition): + def is_partition_pending_add(self, partition): return partition in self._new_partitions_in_transaction or partition in self._pending_partitions_in_transaction def has_producer_id_and_epoch(self, producer_id, producer_epoch): @@ -652,7 +650,7 @@ def handle_response(self, response): self.transaction_manager._lookup_coordinator('transaction', self.transactiona_id) self.reenqueue() return - elif error is Errors.ConcurrentTransactionError: + elif error is Errors.ConcurrentTransactionsError: self.maybe_override_retry_backoff_ms() self.reenqueue() return @@ -704,7 +702,7 @@ def maybe_override_retry_backoff_ms(self): # # This is only a temporary fix, the long term solution is being tracked in # https://issues.apache.org/jira/browse/KAFKA-5482 - if not self._partitions_in_transaction: + if not self.transaction_manager._partitions_in_transaction: self.retry_backoff_ms = min(self.transaction_manager.ADD_PARTITIONS_RETRY_BACKOFF_MS, self.retry_backoff_ms) From f67ca6a5d8531ce5a5cb3873f40c1dc3d8be1950 Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 15 Apr 2025 16:27:08 -0700 Subject: [PATCH 26/27] Improve producer.init_transactions() semantics --- kafka/producer/kafka.py | 17 +++++++++++------ kafka/producer/transaction_manager.py | 10 +++++++++- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/kafka/producer/kafka.py b/kafka/producer/kafka.py index 063376b4d..1468cec55 100644 --- a/kafka/producer/kafka.py +++ b/kafka/producer/kafka.py @@ -645,8 +645,11 @@ def init_transactions(self): messages issued by the producer. Note that this method will raise KafkaTimeoutError if the transactional state cannot - be initialized before expiration of `max_block_ms`. It is safe to retry, but once the - transactional state has been successfully initialized, this method should no longer be used. + be initialized before expiration of `max_block_ms`. + + Retrying after a KafkaTimeoutError will continue to wait for the prior request to succeed or fail. + Retrying after any other exception will start a new initialization attempt. + Retrying after a successful initialization will do nothing. Raises: IllegalStateError: if no transactional_id has been configured @@ -661,10 +664,12 @@ def init_transactions(self): self._init_transactions_result = self._transaction_manager.initialize_transactions() self._sender.wakeup() - if self._init_transactions_result.wait(timeout_ms=self.config['max_block_ms']): - self._init_transactions_result = None - else: - raise Errors.KafkaTimeoutError("Timeout expired while initializing transactional state in %s ms." % (self.config['max_block_ms'],)) + try: + if not self._init_transactions_result.wait(timeout_ms=self.config['max_block_ms']): + raise Errors.KafkaTimeoutError("Timeout expired while initializing transactional state in %s ms." % (self.config['max_block_ms'],)) + finally: + if self._init_transactions_result.failed: + self._init_transactions_result = None def begin_transaction(self): """ Should be called before the start of each new transaction. diff --git a/kafka/producer/transaction_manager.py b/kafka/producer/transaction_manager.py index 2df9ca8b0..f5111c780 100644 --- a/kafka/producer/transaction_manager.py +++ b/kafka/producer/transaction_manager.py @@ -498,7 +498,15 @@ def is_done(self): @property def succeeded(self): - return self._error is None and self._latch.is_set() + return self._latch.is_set() and self._error is None + + @property + def failed(self): + return self._latch.is_set() and self._error is not None + + @property + def exception(self): + return self._error @six.add_metaclass(abc.ABCMeta) From 97d2895ee650b7814bf1a237ac5a4cd1ea684ccd Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Tue, 15 Apr 2025 16:27:36 -0700 Subject: [PATCH 27/27] Test producer transaction --- test/integration/test_producer_integration.py | 39 ++++++++++++++++--- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/test/integration/test_producer_integration.py b/test/integration/test_producer_integration.py index 2bb46a914..0739d8eba 100644 --- a/test/integration/test_producer_integration.py +++ b/test/integration/test_producer_integration.py @@ -138,9 +138,38 @@ def test_kafka_producer_proper_record_metadata(kafka_broker, compression): @pytest.mark.skipif(env_kafka_version() < (0, 11), reason="Idempotent producer requires broker >=0.11") def test_idempotent_producer(kafka_broker): connect_str = ':'.join([kafka_broker.host, str(kafka_broker.port)]) - producer = KafkaProducer(bootstrap_servers=connect_str, enable_idempotence=True) - try: + with producer_factory(bootstrap_servers=connect_str, enable_idempotence=True) as producer: for _ in range(10): - producer.send('foo', value=b'idempotent_msg').get(timeout=1) - finally: - producer.close() + producer.send('idempotent_test_topic', value=b'idempotent_msg').get(timeout=1) + + +@pytest.mark.skipif(env_kafka_version() < (0, 11), reason="Idempotent producer requires broker >=0.11") +def test_transactional_producer(kafka_broker): + connect_str = ':'.join([kafka_broker.host, str(kafka_broker.port)]) + with producer_factory(bootstrap_servers=connect_str, transactional_id='testing') as producer: + producer.init_transactions() + producer.begin_transaction() + producer.send('transactional_test_topic', partition=0, value=b'msg1').get() + producer.send('transactional_test_topic', partition=0, value=b'msg2').get() + producer.abort_transaction() + producer.begin_transaction() + producer.send('transactional_test_topic', partition=0, value=b'msg3').get() + producer.send('transactional_test_topic', partition=0, value=b'msg4').get() + producer.commit_transaction() + + messages = set() + consumer_opts = { + 'bootstrap_servers': connect_str, + 'group_id': None, + 'consumer_timeout_ms': 10000, + 'auto_offset_reset': 'earliest', + 'isolation_level': 'read_committed', + } + with consumer_factory(**consumer_opts) as consumer: + consumer.assign([TopicPartition('transactional_test_topic', 0)]) + for msg in consumer: + assert msg.value in {b'msg3', b'msg4'} + messages.add(msg.value) + if messages == {b'msg3', b'msg4'}: + break + assert messages == {b'msg3', b'msg4'}