diff --git a/src/main/kotlin/io/github/nomisRev/kafka/receiver/KafkaReceiver.kt b/src/main/kotlin/io/github/nomisRev/kafka/receiver/KafkaReceiver.kt index 9031669b..94cb11b8 100644 --- a/src/main/kotlin/io/github/nomisRev/kafka/receiver/KafkaReceiver.kt +++ b/src/main/kotlin/io/github/nomisRev/kafka/receiver/KafkaReceiver.kt @@ -43,8 +43,8 @@ private class DefaultKafkaReceiver(private val settings: ReceiverSettings< @OptIn(ExperimentalCoroutinesApi::class) override fun receive(topicNames: Collection): Flow> = scopedConsumer(settings.groupId) { scope, dispatcher, consumer -> - val loop = PollLoop(topicNames, settings, consumer, scope) - loop.receive().flowOn(dispatcher) + val loop = PollLoop(topicNames, settings, consumer, scope, currentCoroutineContext()) + loop.receive() .flatMapConcat { records -> records.map { record -> ReceiverRecord(record, loop.toCommittableOffset(record)) @@ -54,8 +54,8 @@ private class DefaultKafkaReceiver(private val settings: ReceiverSettings< override fun receiveAutoAck(topicNames: Collection): Flow>> = scopedConsumer(settings.groupId) { scope, dispatcher, consumer -> - val loop = PollLoop(topicNames, settings, consumer, scope) - loop.receive().flowOn(dispatcher).map { records -> + val loop = PollLoop(topicNames, settings, consumer, scope, currentCoroutineContext()) + loop.receive().map { records -> records.asFlow() .onCompletion { records.forEach { loop.toCommittableOffset(it).acknowledge() } } } @@ -84,14 +84,15 @@ private class DefaultKafkaReceiver(private val settings: ReceiverSettings< @OptIn(ExperimentalCoroutinesApi::class) fun scopedConsumer( groupId: String, - block: (CoroutineScope, ExecutorCoroutineDispatcher, KafkaConsumer) -> Flow + block: suspend (CoroutineScope, ExecutorCoroutineDispatcher, KafkaConsumer) -> Flow ): Flow = flow { kafkaConsumerDispatcher(groupId).use { dispatcher: ExecutorCoroutineDispatcher -> val job = Job() val scope = CoroutineScope(job + dispatcher + defaultCoroutineExceptionHandler) try { - KafkaConsumer(settings.toProperties(), settings.keyDeserializer, settings.valueDeserializer) - .use { emit(block(scope, dispatcher, it)) } + withContext(dispatcher) { + KafkaConsumer(settings.toProperties(), settings.keyDeserializer, settings.valueDeserializer) + }.use { emit(block(scope, dispatcher, it)) } } finally { job.cancelAndJoin() } diff --git a/src/main/kotlin/io/github/nomisRev/kafka/receiver/internals/PollLoop.kt b/src/main/kotlin/io/github/nomisRev/kafka/receiver/internals/PollLoop.kt index 94254b55..0a4581dd 100644 --- a/src/main/kotlin/io/github/nomisRev/kafka/receiver/internals/PollLoop.kt +++ b/src/main/kotlin/io/github/nomisRev/kafka/receiver/internals/PollLoop.kt @@ -1,9 +1,14 @@ package io.github.nomisRev.kafka.receiver.internals import io.github.nomisRev.kafka.receiver.Offset +import io.github.nomisRev.kafka.receiver.ReceiverRecord import io.github.nomisRev.kafka.receiver.ReceiverSettings +import io.github.nomisRev.kafka.receiver.internals.AckMode.ATMOST_ONCE +import io.github.nomisRev.kafka.receiver.internals.AckMode.AUTO_ACK +import io.github.nomisRev.kafka.receiver.internals.AckMode.EXACTLY_ONCE +import io.github.nomisRev.kafka.receiver.internals.AckMode.MANUAL_ACK import io.github.nomisRev.kafka.receiver.size -import java.time.Duration +import java.time.Duration as JavaDuration import java.time.Duration.ofSeconds import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicInteger @@ -14,12 +19,13 @@ import kotlin.coroutines.suspendCoroutine import kotlin.time.toJavaDuration import 
kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineStart +import kotlinx.coroutines.CoroutineStart.UNDISPATCHED import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.Job import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.onClosed import kotlinx.coroutines.channels.onFailure +import kotlinx.coroutines.channels.onSuccess import kotlinx.coroutines.delay import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.consumeAsFlow @@ -32,17 +38,32 @@ import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.errors.WakeupException import org.slf4j.Logger import org.slf4j.LoggerFactory +import kotlin.coroutines.CoroutineContext +import kotlin.time.Duration + +/** + * Testing [ConsumerThread] for internal usage. + * Will be removed before 1.0.0 + */ +private val DEBUG: Boolean = true + +private fun checkConsumerThread(msg: String): Unit = + if (DEBUG) require( + Thread.currentThread().name.startsWith("kotlin-kafka-") + ) { "$msg => should run on kotlin-kafka thread, but found ${Thread.currentThread().name}" } + else Unit internal class PollLoop( // TODO also allow for Pattern, and assign private val topicNames: Collection, private val settings: ReceiverSettings, private val consumer: Consumer, - scope: CoroutineScope, + private val consumerScope: CoroutineScope, + outerContext: CoroutineContext, awaitingTransaction: AtomicBoolean = AtomicBoolean(false), private val isActive: AtomicBoolean = AtomicBoolean(true), - private val ackMode: AckMode = AckMode.MANUAL_ACK, - isRetriableException: (Throwable) -> Boolean = { e -> e is RetriableCommitFailedException }, + private val ackMode: AckMode = MANUAL_ACK, + isRetriableCommit: (Throwable) -> Boolean = { e -> e is RetriableCommitFailedException }, ) { private val reachedMaxCommitBatchSize = Channel(Channel.RENDEZVOUS) private val atMostOnceOffset: AtmostOnceOffsets = AtmostOnceOffsets() @@ -50,8 +71,9 @@ internal class PollLoop( ackMode, settings, consumer, - isRetriableException, - scope, + isRetriableCommit, + consumerScope, + outerContext, isActive, awaitingTransaction, atMostOnceOffset @@ -64,7 +86,7 @@ internal class PollLoop( * This way it optimises sending commits to Kafka in an optimised way. * Either every `Duration` or `x` elements, whichever comes first. 
*/ - private val commitManagerJob = scope.launch( + private val commitManagerJob = consumerScope.launch( start = CoroutineStart.LAZY, context = Dispatchers.Default ) { @@ -76,20 +98,23 @@ internal class PollLoop( ) } - fun receive(): Flow> = - loop.channel.consumeAsFlow() + fun receive(): Flow> { + return loop.channel.consumeAsFlow() .onStart { - if (topicNames.isNotEmpty()) loop.subscriber(topicNames) - loop.schedulePoll() + if (topicNames.isNotEmpty()) loop.subscribe(topicNames) + withContext(consumerScope.coroutineContext) { loop.poll() } commitManagerJob.start() }.onCompletion { stop() } + } private suspend fun stop() { - if (!isActive.compareAndSet(true, false)) Unit + isActive.set(false) reachedMaxCommitBatchSize.close() - commitManagerJob.cancelAndJoin() - consumer.wakeup() - loop.close(settings.closeTimeout) + withContext(consumerScope.coroutineContext) { + commitManagerJob.cancelAndJoin() + consumer.wakeup() + loop.close(settings.closeTimeout) + } } internal fun toCommittableOffset(record: ConsumerRecord): CommittableOffset = @@ -124,7 +149,7 @@ internal class CommittableOffset( } } - private /*suspend*/ fun maybeUpdateOffset(): Int = + private suspend fun maybeUpdateOffset(): Int = if (acknowledged.compareAndSet(false, true)) loop.commitBatch.updateOffset(topicPartition, offset) else loop.commitBatch.batchSize() @@ -138,257 +163,322 @@ internal class EventLoop( private val ackMode: AckMode, private val settings: ReceiverSettings, private val consumer: Consumer, - private val isRetriableException: (Throwable) -> Boolean, + private val isRetriableCommit: (Throwable) -> Boolean, private val scope: CoroutineScope, + private val outerContext: CoroutineContext, private val isActive: AtomicBoolean, private val awaitingTransaction: AtomicBoolean, private val atmostOnceOffsets: AtmostOnceOffsets, ) { - private val requesting = AtomicBoolean(true) + /** Atomic state to check if we're poll'ing, or back-pressuring */ + private val isPolling = AtomicBoolean(true) + + /** Atomic state to tracks if we've paused, or not. */ private val pausedByUs = AtomicBoolean(false) + + /** Channel to which we send records */ val channel: Channel> = Channel() + + /** Cached pollTimeout, converted from [kotlin.time.Duration] to [java.time.Duration] */ private val pollTimeout = settings.pollTimeout.toJavaDuration() - private fun onPartitionsRevoked(partitions: Collection) { - if (!partitions.isEmpty()) { - // It is safe to use the consumer here since we are in a poll() - if (ackMode != AckMode.ATMOST_ONCE) { - runCommitIfRequired(true) + /** + * Subscribes to the given [topicNames], + * and in case of failure we rethrow it through the [Channel] by closing with the caught exception. 
+ */ + suspend fun subscribe(topicNames: Collection): Unit = + withContext(scope.coroutineContext) { + try { + consumer.subscribe(topicNames, RebalanceListener()) + } catch (e: Throwable) { + logger.error("Subscribing to $topicNames failed", e) + val closed = channel.close(e) + if (!closed) throw IllegalStateException("Failed to close EventLoop.channel") + .apply { addSuppressed(e) } } - // TODO Setup user listeners - // for (onRevoke in receiverOptions.revokeListeners()) { - // onRevoke.accept(toSeekable(partitions)) - // } - } - } - - fun subscriber(topicNames: Collection): Job = scope.launch { - try { - consumer.subscribe(topicNames, object : ConsumerRebalanceListener { - override fun onPartitionsAssigned(partitions: MutableCollection) { - logger.debug("onPartitionsAssigned $partitions") - var repausedAll = false - if (partitions.isNotEmpty() && pausedByUs.get()) { - logger.debug("Rebalance during back pressure, re-pausing new assignments") - consumer.pause(partitions) - repausedAll = true - } - if (pausedByUser.isNotEmpty()) { - val toRepause = buildList { - //It is necessary to re-pause any user-paused partitions that are re-assigned after the rebalance. - //Also remove any revoked partitions that the user paused from the userPaused collection. - pausedByUser.forEach { tp -> - if (partitions.contains(tp)) add(tp) - else pausedByUser.remove(tp) - } - } - if (!repausedAll && toRepause.isNotEmpty()) { - consumer.pause(toRepause) - } - } - // TODO Setup user listeners - // for (onAssign in receiverOptions.assignListeners()) { - // onAssign.accept(toSeekable(partitions)) - // } - if (logger.isTraceEnabled) { - try { - val positions = partitions.map { part: TopicPartition -> - "$part pos: ${consumer.position(part, ofSeconds(5))}" - } - logger.trace( - "positions: $positions, committed: ${ - consumer.committed( - partitions.toSet(), ofSeconds(5) - ) - }" - ) - } catch (ex: Exception) { - logger.error("Failed to get positions or committed", ex) - } - } - } - - override fun onPartitionsRevoked(partitions: MutableCollection) { - logger.debug("onPartitionsRevoked $partitions") - this@EventLoop.onPartitionsRevoked(partitions) - commitBatch.onPartitionsRevoked(partitions) - } - }) - } catch (e: Throwable) { - logger.error("Unexpected exception", e) - channel.close(e) } - } - private fun checkAndSetPausedByUs(): Boolean { - logger.debug("checkAndSetPausedByUs") + // + /** + * Checks if we need to pause, + * If it was already paused, then we check if we need to wake up the consumer. + * We wake up the consumer, if we're actively polling records and we're currently not retrying sending commits. + */ + @ConsumerThread + private fun pauseAndWakeupIfNeeded(): Boolean { + checkConsumerThread("pauseAndWakeupIfNeeded") val pausedNow = !pausedByUs.getAndSet(true) - if (pausedNow && requesting.get() && !retrying.get()) { + val shouldWakeUpConsumer = pausedNow && isPolling.get() && !isRetryingCommit.get() + logger.debug("checkAndSetPausedByUs: already paused {}, shouldWakeUpConsumer {}", pausedNow, shouldWakeUpConsumer) + if (shouldWakeUpConsumer) { consumer.wakeup() } return pausedNow } - /* - * TODO this can probably be removed - * Race condition where onRequest was called to increase requested but we - * hadn't yet paused the consumer; wake immediately in this case. - */ - private val scheduled = AtomicBoolean() private val pausedByUser: MutableSet = HashSet() - fun schedulePoll(): Job? 
= - if (!scheduled.getAndSet(true)) scope.launch { - try { - scheduled.set(false) - if (isActive.get()) { - // Ensure that commits are not queued behind polls since number of poll events is chosen by reactor. - runCommitIfRequired(false) - - val pauseForDeferred = - (settings.maxDeferredCommits > 0 && commitBatch.deferredCount() >= settings.maxDeferredCommits) - val shouldPoll: Boolean = if (pauseForDeferred || retrying.get()) false else requesting.get() - - if (shouldPoll) { - if (!awaitingTransaction.get()) { - if (pausedByUs.getAndSet(false)) { - val toResume: MutableSet = HashSet(consumer.assignment()) - toResume.removeAll(pausedByUser) - pausedByUser.clear() - consumer.resume(toResume) - if (logger.isDebugEnabled) { - logger.debug("Resumed partitions: $toResume") - } - } - } else { - if (checkAndSetPausedByUs()) { - pausedByUser.addAll(consumer.paused()) - consumer.pause(consumer.assignment()) - logger.debug("Paused - awaiting transaction") - } - } - } else if (checkAndSetPausedByUs()) { + + @ConsumerThread + fun poll() { + if (!isActive.get()) return + + try { + // Ensure that commits are not queued behind polls since number of poll events is chosen by reactor. + runCommitIfRequired(false) + + val pauseForDeferred = + (settings.maxDeferredCommits > 0 && commitBatch.deferredCount() >= settings.maxDeferredCommits) + val shouldPoll: Boolean = if (pauseForDeferred || isRetryingCommit.get()) false else isPolling.get() + + if (shouldPoll) { + if (!awaitingTransaction.get()) { + if (pausedByUs.getAndSet(false)) { + val toResume: MutableSet = HashSet(consumer.assignment()) + toResume.removeAll(pausedByUser) + pausedByUser.clear() + consumer.resume(toResume) + logger.debug("Resumed partitions: {}", toResume) + } + } else { + if (pauseAndWakeupIfNeeded()) { pausedByUser.addAll(consumer.paused()) consumer.pause(consumer.assignment()) - when { - pauseForDeferred -> logger.debug("Paused - too many deferred commits") - retrying.get() -> logger.debug("Paused - commits are retrying") - else -> logger.debug("Paused - back pressure") - } + logger.debug("Paused - awaiting transaction") } + } + } else if (pauseAndWakeupIfNeeded()) { + pausedByUser.addAll(consumer.paused()) + consumer.pause(consumer.assignment()) + when { + pauseForDeferred -> logger.debug("Paused - too many deferred commits") + isRetryingCommit.get() -> logger.debug("Paused - commits are retrying") + else -> logger.debug("Paused - back pressure") + } + } - val records: ConsumerRecords = try { - consumer.poll(pollTimeout) - } catch (e: WakeupException) { - logger.debug("Consumer woken") - ConsumerRecords.empty() - } - if (isActive.get()) schedulePoll() - if (!records.isEmpty) { - if (settings.maxDeferredCommits > 0) { - commitBatch.addUncommitted(records) - } - logger.debug("Emitting ${records.count()} records") - channel.trySend(records) - .onClosed { logger.error("Channel closed when trying to send records.", it) } - .onFailure { error -> - if (error != null) { - logger.error("Channel send failed when trying to send records.", error) - channel.close(error) - } else logger.debug("Back-pressuring kafka consumer. Might pause KafkaConsumer on next tick.") - - requesting.set(false) - // TODO Can we rely on a dispatcher from above? - // This should not run on the kafka consumer thread - scope.launch(Dispatchers.Default) { - /* - * Send the records down, - * when send returns we attempt to send and empty set of records down to test the backpressure. 
- * If our "backpressure test" returns we start requesting/polling again - */ - channel.send(records) - channel.send(ConsumerRecords.empty()) - if (pausedByUs.get()) { - consumer.wakeup() - } - requesting.set(true) - } + val records: ConsumerRecords = try { + consumer.poll(pollTimeout) + } catch (e: WakeupException) { + logger.debug("Consumer woken") + ConsumerRecords.empty() + } + + if (!records.isEmpty) { + if (settings.maxDeferredCommits > 0) { + commitBatch.addUncommitted(records) + } + logger.debug("Attempting to send ${records.count()} records to Channel") + channel.trySend(records) + .onSuccess { poll() } + .onClosed { error -> logger.error("Channel closed when trying to send records.", error) } + .onFailure { error -> + if (error != null) { + logger.error("Channel send failed when trying to send records.", error) + channel.close(error) + } else logger.debug("Back-pressuring kafka consumer. Might pause KafkaConsumer on next poll tick.") + + isPolling.set(false) + + scope.launch(outerContext) { + /* + * Send the records down, + * when send returns we attempt to send and empty set of records down to test the backpressure. + * If our "backpressure test" returns we start requesting/polling again + */ + channel.send(records) + channel.send(ConsumerRecords.empty()) + if (pausedByUs.get()) { + consumer.wakeup() } + isPolling.set(true) + poll() + } + } + } + } catch (e: Exception) { + logger.error("Polling encountered an unexpected exception", e) + channel.close(e) + } + } + + @ConsumerThread + inner class RebalanceListener : ConsumerRebalanceListener { + @ConsumerThread + override fun onPartitionsAssigned(partitions: MutableCollection) { + checkConsumerThread("RebalanceListener.onPartitionsAssigned") + logger.debug("onPartitionsAssigned {}", partitions) + var repausedAll = false + if (partitions.isNotEmpty() && pausedByUs.get()) { + logger.debug("Rebalance during back pressure, re-pausing new assignments") + consumer.pause(partitions) + repausedAll = true + } + if (pausedByUser.isNotEmpty()) { + val toRepause = buildList { + //It is necessary to re-pause any user-paused partitions that are re-assigned after the rebalance. + //Also remove any revoked partitions that the user paused from the userPaused collection. 
+ pausedByUser.forEach { tp -> + if (partitions.contains(tp)) add(tp) + else pausedByUser.remove(tp) } } - } catch (e: Exception) { - logger.error("Unexpected exception", e) - channel.close(e) + if (!repausedAll && toRepause.isNotEmpty()) { + consumer.pause(toRepause) + } } - } else null + // TODO Setup user listeners + // for (onAssign in receiverOptions.assignListeners()) { + // onAssign.accept(toSeekable(partitions)) + // } + traceCommitted(partitions) + } + + @ConsumerThread + override fun onPartitionsRevoked(partitions: MutableCollection) { + checkConsumerThread("RebalanceListener.onPartitionsRevoked") + logger.debug("onPartitionsRevoked {}", partitions) + partitionsRevoked(partitions) + commitBatch.onPartitionsRevoked(partitions) + } + } + @ConsumerThread + private fun partitionsRevoked(partitions: Collection) { + if (!partitions.isEmpty()) { + // It is safe to use the consumer here since we are in a poll() + if (ackMode != ATMOST_ONCE) { + runCommitIfRequired(true) + } + // TODO Setup user listeners + // for (onRevoke in receiverOptions.revokeListeners()) { + // onRevoke.accept(toSeekable(partitions)) + // } + } + } + // + + // val commitBatch: CommittableBatch = CommittableBatch() - private val isPending = AtomicBoolean() - private val inProgress = AtomicInteger() + private val commitPending = AtomicBoolean() + private val asyncCommitsInProgress = AtomicInteger() private val consecutiveCommitFailures = AtomicInteger() - private val retrying = AtomicBoolean() + private val isRetryingCommit = AtomicBoolean() - // TODO Should reset delay of commitJob + /** + * If we were retrying, schedule a poll and set isRetryingCommit to false + * If we weren't retrying, do nothing. + */ + @ConsumerThread + private fun schedulePollAfterRetrying() { + if (isRetryingCommit.getAndSet(false)) poll() + } + + @ConsumerThread + private fun runCommitIfRequired(force: Boolean) { + if (force) commitPending.set(true) + if (!isRetryingCommit.get() && commitPending.get()) commit() + } + + fun scheduleCommitIfRequired() { + if ( + isActive.get() && + !isRetryingCommit.get() && + commitPending.compareAndSet(false, true) + ) scope.launch { commit() } + } + + // TODO We should reset delay of commitJob + @ConsumerThread private fun commit() { - if (!isPending.compareAndSet(true, false)) return + checkConsumerThread("commit") + if (!commitPending.compareAndSet(true, false)) return val commitArgs: CommittableBatch.CommitArgs = commitBatch.getAndClearOffsets() - try { - if (commitArgs.offsets.isEmpty()) commitSuccess(commitArgs, commitArgs.offsets) - else { - when (ackMode) { - AckMode.MANUAL_ACK, AckMode.AUTO_ACK -> { - inProgress.incrementAndGet() - try { - logger.debug("Async committing: ${commitArgs.offsets}") - consumer.commitAsync(commitArgs.offsets) { offsets, exception -> - inProgress.decrementAndGet() - if (exception == null) commitSuccess(commitArgs, offsets) - else commitFailure(commitArgs, exception) - } - } catch (e: Throwable) { - inProgress.decrementAndGet() - throw e - } - schedulePoll() - } + if (commitArgs.offsets.isEmpty()) commitSuccess(commitArgs, commitArgs.offsets) + else { + when (ackMode) { + MANUAL_ACK, AUTO_ACK -> commitAsync(commitArgs) + ATMOST_ONCE -> commitSync(commitArgs) + /** [AckMode.EXACTLY_ONCE] offsets are committed by a producer. 
*/ + EXACTLY_ONCE -> Unit + } + } + } - AckMode.ATMOST_ONCE -> { - logger.debug("Sync committing: ${commitArgs.offsets}") - consumer.commitSync(commitArgs.offsets) - commitSuccess(commitArgs, commitArgs.offsets) - atmostOnceOffsets.onCommit(commitArgs.offsets) - } - // Handled separately using transactional KafkaPublisher - AckMode.EXACTLY_ONCE -> Unit - } + /** + * Commit async, for [MANUAL_ACK] & [AUTO_ACK]. + * We increment the [asyncCommitsInProgress] and request the SDK to [Consumer.commitAsync]. + * + * We always decrement [asyncCommitsInProgress], + * and invoke the relevant handlers in case of [commitSuccess] or [commitFailure]. + * + * We always need to poll after [Consumer.commitAsync]. + */ + private fun commitAsync(commitArgs: CommittableBatch.CommitArgs) { + runCatching { + asyncCommitsInProgress.incrementAndGet() + logger.debug("Async committing: {}", commitArgs.offsets) + consumer.commitAsync(commitArgs.offsets) { offsets, exception -> + asyncCommitsInProgress.decrementAndGet() + if (exception == null) commitSuccess(commitArgs, offsets) + else commitFailure(commitArgs, exception) } - } catch (e: Exception) { - logger.error("Unexpected exception", e) + poll() + }.recoverCatching { e -> + // TODO("NonFatal") + asyncCommitsInProgress.decrementAndGet() commitFailure(commitArgs, e) } } - private fun commitSuccess(commitArgs: CommittableBatch.CommitArgs?, offsets: Map) { - if (offsets.isNotEmpty()) { - consecutiveCommitFailures.set(0) + /** + * Commit sync, for [ATMOST_ONCE]. + * For [ATMOST_ONCE] we want to guarantee it's only received once, + * so we immediately commit the message before sending it to the [Channel]. + * This is blocking, and invoke the relevant handlers in case of [commitSuccess] or [commitFailure]. + */ + private fun commitSync(commitArgs: CommittableBatch.CommitArgs): Unit { + runCatching { + logger.debug("Sync committing: {}", commitArgs.offsets) + // TODO check if this should be runInterruptible ?? + consumer.commitSync(commitArgs.offsets) + commitSuccess(commitArgs, commitArgs.offsets) + atmostOnceOffsets.onCommit(commitArgs.offsets) + }.recoverCatching { e -> + // TODO("NonFatal") + commitFailure(commitArgs, e) } - pollTaskAfterRetry() - commitArgs?.continuations?.forEach { cont -> + } + + /** + * Commit was successfully: + * - Set commitFailures to 0 + * - Schedule poll if we previously were retrying to commit + * - Complete all the [Offset.commit] continuations + */ + @ConsumerThread + private fun commitSuccess( + commitArgs: CommittableBatch.CommitArgs, + offsets: Map + ) { + checkConsumerThread("commitSuccess") + if (offsets.isNotEmpty()) consecutiveCommitFailures.set(0) + schedulePollAfterRetrying() + commitArgs.continuations?.forEach { cont -> cont.resume(Unit) } } - private fun pollTaskAfterRetry(): Job? 
= - if (retrying.getAndSet(false)) schedulePoll() else null - - private fun commitFailure(commitArgs: CommittableBatch.CommitArgs, exception: Exception) { + @ConsumerThread + private fun commitFailure(commitArgs: CommittableBatch.CommitArgs, exception: Throwable) { + checkConsumerThread("commitFailure") logger.warn("Commit failed", exception) - if (!isRetriableException(exception) && consecutiveCommitFailures.incrementAndGet() < settings.maxCommitAttempts) { - logger.debug("Cannot retry") - pollTaskAfterRetry() + if (!isRetriableCommit(exception) && consecutiveCommitFailures.incrementAndGet() < settings.maxCommitAttempts) { + logger.debug("Commit failed with exception $exception, zero retries remaining") + schedulePollAfterRetrying() val callbackEmitters: List>? = commitArgs.continuations if (callbackEmitters.isNullOrEmpty()) channel.close(exception) else { - isPending.set(false) + commitPending.set(false) commitBatch.restoreOffsets(commitArgs, false) callbackEmitters.forEach { cont -> cont.resumeWithException(exception) @@ -396,62 +486,97 @@ internal class EventLoop( } } else { commitBatch.restoreOffsets(commitArgs, true) + // TODO addSuppressed if in the end we failed to commit within settings.maxCommitAttempts??? logger.warn("Commit failed with exception $exception, retries remaining ${(settings.maxCommitAttempts - consecutiveCommitFailures.get())}") - isPending.set(true) - retrying.set(true) - schedulePoll() - scope.launch { + commitPending.set(true) + isRetryingCommit.set(true) + poll() + // We launch UNDISPATCHED as optimisation since we're already on the consumer thread, + // and we immediately call delay. + scope.launch(start = UNDISPATCHED) { delay(settings.commitRetryInterval) commit() } } } - private fun runCommitIfRequired(force: Boolean) { - if (force) isPending.set(true) - if (!retrying.get() && isPending.get()) commit() + /** + * Close the PollLoop, + * and commit all outstanding [Offset.acknowledge] and [Offset.commit] before closing. + * + * It will make `3` attempts to commit the offsets, + */ + suspend fun close(timeout: Duration): Unit = withContext(scope.coroutineContext) { + checkConsumerThread("close") +// val manualAssignment: Collection = receiverOptions.assignment() +// if (!manualAssignment.isEmpty()) revokePartitions(manualAssignment) + + commitOnClose( + closeEndTime = System.currentTimeMillis() + timeout.inWholeMilliseconds, + maxAttempts = 3 + ) } - fun scheduleCommitIfRequired(): Job? = - if (isActive.get() && !retrying.get() && isPending.compareAndSet(false, true)) scope.launch { commit() } - else null - - // TODO investigate - // https://github.com/akka/alpakka-kafka/blob/aad6a1ccbd4f549b3053988c85cbbe9b11d51542/core/src/main/scala/akka/kafka/internal/KafkaConsumerActor.scala#L514 - private fun waitFor(endTimeMillis: Long) { - while (inProgress.get() > 0 && endTimeMillis - System.currentTimeMillis() > 0) { - consumer.poll(Duration.ofMillis(1)) + /** + * We recurse here in case the consumer has had a recent wakeup call (from user code) + * which will cause a poll() (in waitFor) to be interrupted while we're + * possibly waiting for async commit results. + */ + @ConsumerThread + private suspend fun commitOnClose(closeEndTime: Long, maxAttempts: Int) { + try { + val forceCommit = when (ackMode) { + ATMOST_ONCE -> atmostOnceOffsets.undoCommitAhead(commitBatch) + else -> true + } + // + /** + * [AckMode.EXACTLY_ONCE] offsets are committed by a producer, consumer may be closed immediately. 
+ * For all other [AckMode] we need to commit all [ReceiverRecord] [Offset.acknowledge]. + * Since we want to perform this in an optimal way + */ + if (ackMode != EXACTLY_ONCE) { + runCommitIfRequired(forceCommit) + /** + * The SDK doesn't send commit requests, unless we also poll. + * So poll for 1 millis, until no more commitsInProgress. + */ + while (asyncCommitsInProgress.get() > 0 && closeEndTime - System.currentTimeMillis() > 0) { + consumer.poll(JavaDuration.ofMillis(1)) + } + } + val timeoutRemaining = closeEndTime - System.currentTimeMillis() + consumer.close(JavaDuration.ofMillis(timeoutRemaining.coerceAtLeast(0))) + } catch (e: WakeupException) { + if (maxAttempts == 0) throw e + else commitOnClose(closeEndTime, maxAttempts - 1) } } + // - suspend fun close(timeout: kotlin.time.Duration): Unit = withContext(scope.coroutineContext) { - val closeEndTimeMillis = System.currentTimeMillis() + timeout.inWholeMilliseconds - // val manualAssignment: Collection = receiverOptions.assignment() - // if (manualAssignment != null && !manualAssignment.isEmpty()) onPartitionsRevoked(manualAssignment) - /* - * We loop here in case the consumer has had a recent wakeup call (from user code) - * which will cause a poll() (in waitFor) to be interrupted while we're - * possibly waiting for async commit results. - */ - val maxAttempts = 3 - for (i in 0 until maxAttempts) { + // + /** Trace the position, and last committed offsets for the given partitions */ + @ConsumerThread + private fun traceCommitted(partitions: Collection) { + if (logger.isTraceEnabled) { try { - val forceCommit = when (ackMode) { - AckMode.ATMOST_ONCE -> atmostOnceOffsets.undoCommitAhead(commitBatch) - else -> true - } - // For exactly-once, offsets are committed by a producer, consumer may be closed immediately - if (ackMode != AckMode.EXACTLY_ONCE) { - runCommitIfRequired(forceCommit) - waitFor(closeEndTimeMillis) + val positions = partitions.map { part: TopicPartition -> + "$part position: ${consumer.position(part, ofSeconds(5))}" } - var timeoutMillis: Long = closeEndTimeMillis - System.currentTimeMillis() - if (timeoutMillis < 0) timeoutMillis = 0 - consumer.close(Duration.ofMillis(timeoutMillis)) - break - } catch (e: WakeupException) { - if (i == maxAttempts - 1) throw e + + val committed = + consumer.committed(partitions.toSet(), ofSeconds(5)) + + logger.trace("positions: $positions, committed: $committed") + } catch (ex: Exception) { + logger.error("Failed to get positions or committed", ex) } } } -} \ No newline at end of file + // +} + +/** + * Documentation marker that functions that are only being called from the KafkaConsumer thread + */ +private annotation class ConsumerThread diff --git a/src/test/kotlin/io/github/nomisrev/kafka/KafkaSpec.kt b/src/test/kotlin/io/github/nomisrev/kafka/KafkaSpec.kt index 011f98b7..0c42d951 100644 --- a/src/test/kotlin/io/github/nomisrev/kafka/KafkaSpec.kt +++ b/src/test/kotlin/io/github/nomisrev/kafka/KafkaSpec.kt @@ -13,12 +13,18 @@ import io.github.nomisRev.kafka.receiver.KafkaReceiver import io.github.nomisRev.kafka.receiver.ReceiverSettings import io.kotest.assertions.assertSoftly import io.kotest.assertions.async.shouldTimeout +import io.kotest.assertions.fail +import io.kotest.assertions.failure import io.kotest.core.spec.style.StringSpec import io.kotest.matchers.shouldBe +import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.take 
import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.withTimeout +import kotlinx.coroutines.withTimeoutOrNull import org.apache.kafka.clients.admin.Admin import org.apache.kafka.clients.admin.AdminClientConfig import org.apache.kafka.clients.admin.NewTopic @@ -79,7 +85,7 @@ abstract class KafkaSpec(body: KafkaSpec.() -> Unit = {}) : StringSpec() { private fun adminProperties(): Properties = Properties().apply { put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, container.bootstrapServers) - put(AdminClientConfig.CLIENT_ID_CONFIG, "test-kafka-admin-client") + put(AdminClientConfig.CLIENT_ID_CONFIG, "test-kafka-admin-client-${UUID.randomUUID()}") put(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, "10000") put(AdminClientConfig.CONNECTIONS_MAX_IDLE_MS_CONFIG, "10000") } @@ -117,7 +123,6 @@ abstract class KafkaSpec(body: KafkaSpec.() -> Unit = {}) : StringSpec() { val producer = KafkaProducer(publisherSettings().properties()) val publisher = autoClose(KafkaPublisher(publisherSettings()) { producer }) - private fun nextTopicName(): String = "topic-${UUID.randomUUID()}" @@ -125,13 +130,15 @@ abstract class KafkaSpec(body: KafkaSpec.() -> Unit = {}) : StringSpec() { topicConfig: Map = emptyMap(), partitions: Int = 4, replicationFactor: Short = 1, - action: suspend Admin.(NewTopic) -> A, + action: suspend (NewTopic) -> A, ): A { val topic = NewTopic(nextTopicName(), partitions, replicationFactor).configs(topicConfig) return admin { createTopic(topic) try { - action(topic) + coroutineScope { + action(topic) + } } finally { topic.shouldBeEmpty() deleteTopic(topic.name()) @@ -168,45 +175,17 @@ abstract class KafkaSpec(body: KafkaSpec.() -> Unit = {}) : StringSpec() { } } - @JvmName("shouldHaveAllRecords") - suspend fun NewTopic.shouldHaveRecords( - records: Iterable>> - ) { - val expected = - records.flatten().groupBy({ it.partition() }) { it.value() }.mapValues { it.value.toSet() } - KafkaReceiver(receiverSetting()) - .receive(name()) - .map { record -> - record.also { record.offset.acknowledge() } - } - .take(records.flatten().size) - .toList() - .groupBy({ it.partition() }) { it.value() } - .mapValues { it.value.toSet() } shouldBe expected - } - - suspend fun NewTopic.shouldHaveRecords(records: Iterable>) { - KafkaReceiver(receiverSetting()) - .receive(name()) - .map { record -> - record - .also { record.offset.acknowledge() } - } - .take(records.toList().size) - .toList() - .groupBy({ it.partition() }) { it.value() } shouldBe records.groupBy({ it.partition() }) { it.value() } - } - suspend fun NewTopic.shouldBeEmpty() { - shouldTimeout(100.milliseconds) { + val res = withTimeoutOrNull(100) { KafkaReceiver(receiverSetting()) .receive(name()) .take(1) .toList() } + if (res != null) fail("Expected test to timeout, but found $res") } - suspend fun NewTopic.shouldHaveRecord(records: ProducerRecord) { + suspend infix fun NewTopic.shouldHaveRecord(records: ProducerRecord) { assertSoftly { KafkaReceiver(receiverSetting()) .receive(name()) @@ -219,20 +198,34 @@ abstract class KafkaSpec(body: KafkaSpec.() -> Unit = {}) : StringSpec() { } } - suspend fun topicWithSingleMessage(topic: NewTopic, record: ProducerRecord) = + suspend infix fun NewTopic.shouldHaveRecords(records: Iterable>) { KafkaReceiver(receiverSetting()) - .receive(topic.name()) - .map { - it.apply { offset.acknowledge() } - }.first().value() shouldBe record.value() + .receive(name()) + .map { record -> + record + .also { record.offset.acknowledge() } + } + .take(records.toList().size) + .toList() + .groupBy({ it.partition() }) { 
it.value() } shouldBe records.groupBy({ it.partition() }) { it.value() } + } - suspend fun topicShouldBeEmpty(topic: NewTopic) = - shouldTimeout(1.seconds) { - KafkaReceiver(receiverSetting()) - .receive(topic.name()) - .take(1) - .toList() - } + @JvmName("shouldHaveAllRecords") + suspend infix fun NewTopic.shouldHaveRecords( + records: Iterable>> + ) { + val expected = + records.flatten().groupBy({ it.partition() }) { it.value() }.mapValues { it.value.toSet() } + KafkaReceiver(receiverSetting()) + .receive(name()) + .map { record -> + record.also { record.offset.acknowledge() } + } + .take(records.flatten().size) + .toList() + .groupBy({ it.partition() }) { it.value() } + .mapValues { it.value.toSet() } shouldBe expected + } fun stubProducer(failOnNumber: Int? = null): suspend () -> Producer = suspend { object : Producer { diff --git a/src/test/kotlin/io/github/nomisrev/kafka/consumer/CommitStrategySpec.kt b/src/test/kotlin/io/github/nomisrev/kafka/receiver/CommitStrategySpec.kt similarity index 97% rename from src/test/kotlin/io/github/nomisrev/kafka/consumer/CommitStrategySpec.kt rename to src/test/kotlin/io/github/nomisrev/kafka/receiver/CommitStrategySpec.kt index a69f490f..d807e7e9 100644 --- a/src/test/kotlin/io/github/nomisrev/kafka/consumer/CommitStrategySpec.kt +++ b/src/test/kotlin/io/github/nomisrev/kafka/receiver/CommitStrategySpec.kt @@ -1,4 +1,4 @@ -package io.github.nomisrev.kafka.consumer +package io.github.nomisrev.kafka.receiver import io.github.nomisRev.kafka.receiver.CommitStrategy import io.kotest.assertions.throwables.shouldThrow @@ -22,7 +22,7 @@ class CommitStrategySpec : StringSpec({ }.message shouldBe "Size based auto-commit requires positive non-zero commit batch size but found $size" } } - + "Negative or zero sized BySizeOrTime strategy fails" { checkAll(Arb.int(max = 0)) { size -> shouldThrow { @@ -30,11 +30,11 @@ class CommitStrategySpec : StringSpec({ }.message shouldBe "Size based auto-commit requires positive non-zero commit batch size but found $size" } } - + fun Arb.Companion.duration( min: Long = Long.MIN_VALUE, max: Long = Long.MAX_VALUE, ): Arb = Arb.long(min, max).map { it.nanoseconds } - + "Negative or zero duration BySizeOrTime strategy fails" { checkAll(Arb.duration(max = 0)) { duration -> shouldThrow { @@ -42,7 +42,7 @@ class CommitStrategySpec : StringSpec({ }.message shouldBe "Time based auto-commit requires positive non-zero interval but found $duration" } } - + "Negative or zero duration ByTime strategy fails" { checkAll(Arb.duration(max = 0)) { duration -> shouldThrow { diff --git a/src/test/kotlin/io/github/nomisrev/kafka/consumer/KafakReceiverSpec.kt b/src/test/kotlin/io/github/nomisrev/kafka/receiver/KafakReceiverSpec.kt similarity index 74% rename from src/test/kotlin/io/github/nomisrev/kafka/consumer/KafakReceiverSpec.kt rename to src/test/kotlin/io/github/nomisrev/kafka/receiver/KafakReceiverSpec.kt index 9fb10335..e9c971e8 100644 --- a/src/test/kotlin/io/github/nomisrev/kafka/consumer/KafakReceiverSpec.kt +++ b/src/test/kotlin/io/github/nomisrev/kafka/receiver/KafakReceiverSpec.kt @@ -1,4 +1,4 @@ -package io.github.nomisrev.kafka.consumer +package io.github.nomisrev.kafka.receiver import io.github.nomisRev.kafka.receiver.CommitStrategy import io.github.nomisRev.kafka.receiver.KafkaReceiver @@ -24,41 +24,45 @@ import org.apache.kafka.clients.producer.ProducerRecord @OptIn(ExperimentalCoroutinesApi::class) class KafakReceiverSpec : KafkaSpec({ - - val depth = 100 - val lastIndex = depth - 1 + + val count = 1000 + val lastIndex = count - 
1 fun produced( startIndex: Int = 0, - lastIndex: Int = depth, + lastIndex: Int = count, ): List> = (startIndex until lastIndex).map { n -> Pair("key-$n", "value->$n") } - + + val produced = produced() + "All produced records are received" { withTopic(partitions = 3) { topic -> - publishToKafka(topic, produced()) + publishToKafka(topic, produced) KafkaReceiver(receiverSetting()) .receive(topic.name()) - .map { + .map { record -> yield() - Pair(it.key(), it.value()) - }.take(depth).toList() shouldContainExactlyInAnyOrder produced() + Pair(record.key(), record.value()) + .also { record.offset.acknowledge() } + }.take(count) + .toList() shouldContainExactlyInAnyOrder produced } } - + "All produced records with headers are received" { withTopic(partitions = 1) { topic -> - val producerRecords = produced().map { (key, value) -> + val producerRecords = produced.map { (key, value) -> ProducerRecord(topic.name(), key, value).apply { headers().add("header1", byteArrayOf(0.toByte())) headers().add("header2", value.toByteArray()) } } - + publishToKafka(producerRecords) - + KafkaReceiver(receiverSetting()) .receive(topic.name()) - .take(depth) + .take(count) .collectIndexed { index, received -> assertSoftly(producerRecords[index]) { received.key() shouldBe key() @@ -67,13 +71,14 @@ class KafakReceiverSpec : KafkaSpec({ received.headers().toArray().size shouldBe 2 received.headers() shouldBe headers() } + received.offset.acknowledge() } } } - - "Should receive all records at least once when subscribing several consumers" { + + "Should receive all records when subscribing several consumers" { withTopic(partitions = 3) { topic -> - publishToKafka(topic, produced()) + publishToKafka(topic, produced) val consumer = KafkaReceiver(receiverSetting()) .receive(topic.name()) @@ -81,41 +86,41 @@ class KafakReceiverSpec : KafkaSpec({ yield() Pair(it.key(), it.value()) } - + flowOf(consumer, consumer) .flattenMerge() - .take(depth) - .toList() shouldContainExactlyInAnyOrder produced() + .take(count) + .toList() shouldContainExactlyInAnyOrder produced } } - + "All acknowledged messages are committed on flow completion" { withTopic(partitions = 3) { topic -> - publishToKafka(topic, produced()) + publishToKafka(topic, produced) val receiver = KafkaReceiver( receiverSetting().copy( - commitStrategy = CommitStrategy.BySize(2 * depth) + commitStrategy = CommitStrategy.BySize(2 * count) ) ) receiver.receive(topic.name()) - .take(depth) + .take(count) .collectIndexed { index, value -> if (index == lastIndex) { value.offset.acknowledge() receiver.committedCount(topic.name()) shouldBe 0 } else value.offset.acknowledge() } - - receiver.committedCount(topic.name()) shouldBe 100 + + receiver.committedCount(topic.name()) shouldBe count } } - + "All acknowledged messages are committed on flow failure" { withTopic(partitions = 3) { topic -> - publishToKafka(topic, produced()) + publishToKafka(topic, produced) val receiver = KafkaReceiver( receiverSetting().copy( - commitStrategy = CommitStrategy.BySize(2 * depth) + commitStrategy = CommitStrategy.BySize(2 * count) ) ) val failure = RuntimeException("Flow terminates") @@ -129,18 +134,18 @@ class KafakReceiverSpec : KafkaSpec({ } else value.offset.acknowledge() } }.exceptionOrNull() shouldBe failure - - receiver.committedCount(topic.name()) shouldBe 100 + + receiver.committedCount(topic.name()) shouldBe count } } - + "All acknowledged messages are committed on flow cancellation" { val scope = this withTopic(partitions = 3) { topic -> - publishToKafka(topic, produced()) + 
publishToKafka(topic, produced) val receiver = KafkaReceiver( receiverSetting().copy( - commitStrategy = CommitStrategy.BySize(2 * depth) + commitStrategy = CommitStrategy.BySize(2 * count) ) ) val latch = CompletableDeferred() @@ -152,69 +157,69 @@ class KafakReceiverSpec : KafkaSpec({ require(latch.complete(Unit)) { "Latch completed twice" } } else value.offset.acknowledge() }.launchIn(scope) - + latch.await() job.cancelAndJoin() - - receiver.committedCount(topic.name()) shouldBe 100 + + receiver.committedCount(topic.name()) shouldBe count } } - + "Manual commit also commits all acknowledged offsets" { withTopic(partitions = 3) { topic -> - publishToKafka(topic, produced()) + publishToKafka(topic, produced) val receiver = KafkaReceiver( receiverSetting().copy( - commitStrategy = CommitStrategy.BySize(2 * depth) + commitStrategy = CommitStrategy.BySize(2 * count) ) ) receiver.receive(topic.name()) - .take(depth) + .take(count) .collectIndexed { index, value -> if (index == lastIndex) { value.offset.commit() - receiver.committedCount(topic.name()) shouldBe 100 + receiver.committedCount(topic.name()) shouldBe count } else value.offset.acknowledge() } } } - + "receiveAutoAck" { withTopic(partitions = 3) { topic -> - publishToKafka(topic, produced()) + publishToKafka(topic, produced) val receiver = KafkaReceiver(receiverSetting()) - + receiver.receiveAutoAck(topic.name()) .flatMapConcat { it } - .take(depth) + .take(count) .collect() - - receiver.committedCount(topic.name()) shouldBe 100 + + receiver.committedCount(topic.name()) shouldBe count } } - + "receiveAutoAck does not receive same records" { withTopic(partitions = 3) { topic -> - publishToKafka(topic, produced()) + publishToKafka(topic, produced) val receiver = KafkaReceiver(receiverSetting()) - + receiver.receiveAutoAck(topic.name()) .flatMapConcat { it } - .take(depth) + .take(count) .collect() - - receiver.committedCount(topic.name()) shouldBe 100 - - val seconds = produced(depth + 1, depth + 1 + depth) + + receiver.committedCount(topic.name()) shouldBe count + + val seconds = produced(count + 1, count + 1 + count) publishToKafka(topic, seconds) - + receiver.receiveAutoAck(topic.name()) .flatMapConcat { it } .map { Pair(it.key(), it.value()) } - .take(depth) + .take(count) .toList() shouldContainExactlyInAnyOrder seconds - - receiver.committedCount(topic.name()) shouldBe 200 + + receiver.committedCount(topic.name()) shouldBe (2 * count) } } })
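For context on the public API this diff touches: KafkaReceiver.receive emits ReceiverRecords whose offset the collector acknowledges, exactly as the updated specs do. A minimal usage sketch, assuming the ReceiverSettings parameter names shown below (bootstrapServers, key/value deserializers, groupId); broker, group and topic names are placeholders:

```kotlin
import io.github.nomisRev.kafka.receiver.KafkaReceiver
import io.github.nomisRev.kafka.receiver.ReceiverSettings
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.take
import kotlinx.coroutines.runBlocking
import org.apache.kafka.common.serialization.StringDeserializer

fun main() = runBlocking {
  // Broker, group and topic names are placeholders; parameter names are assumptions.
  val settings = ReceiverSettings(
    bootstrapServers = "localhost:9092",
    keyDeserializer = StringDeserializer(),
    valueDeserializer = StringDeserializer(),
    groupId = "example-group",
  )
  KafkaReceiver(settings)
    .receive("example-topic")
    .take(1_000)
    .collect { record ->
      println("${record.key()} -> ${record.value()}")
      record.offset.acknowledge() // marks the offset; the commit manager decides when to commit
    }
}
```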
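receiveAutoAck, also rewired above, emits one inner Flow per polled batch and acknowledges that batch's records in onCompletion, so collecting it needs no explicit acknowledge call. A sketch of consuming it in order, reusing the hypothetical settings from the previous example:

```kotlin
import io.github.nomisRev.kafka.receiver.KafkaReceiver
import io.github.nomisRev.kafka.receiver.ReceiverSettings
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.flatMapConcat
import kotlinx.coroutines.flow.take

@OptIn(ExperimentalCoroutinesApi::class)
suspend fun processFirstThousand(settings: ReceiverSettings<String, String>) =
  KafkaReceiver(settings)
    .receiveAutoAck("example-topic")    // placeholder topic name
    .flatMapConcat { batch -> batch }   // each inner Flow acknowledges its records on completion
    .take(1_000)
    .collect { record -> println(record.value()) }
```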
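How often acknowledged offsets are flushed is governed by CommitStrategy, which is what the commit manager's "every `Duration` or `x` elements, whichever comes first" comment refers to and what CommitStrategySpec validates. A hedged sketch of configuring the three variants; the constructor parameter order follows the specs above but is otherwise an assumption:

```kotlin
import io.github.nomisRev.kafka.receiver.CommitStrategy
import io.github.nomisRev.kafka.receiver.ReceiverSettings
import kotlin.time.Duration.Companion.seconds

fun <K, V> commitStrategyExamples(settings: ReceiverSettings<K, V>): List<ReceiverSettings<K, V>> =
  listOf(
    // Commit once 100 offsets have been acknowledged.
    settings.copy(commitStrategy = CommitStrategy.BySize(100)),
    // Commit every 5 seconds, however many offsets were acknowledged.
    settings.copy(commitStrategy = CommitStrategy.ByTime(5.seconds)),
    // Commit on whichever threshold is reached first.
    settings.copy(commitStrategy = CommitStrategy.BySizeOrTime(100, 5.seconds)),
  )
```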
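Internally the PR moves KafkaConsumer construction into withContext(dispatcher) and guards consumer-thread-only functions with checkConsumerThread, which keys off a "kotlin-kafka-" thread-name prefix. A minimal sketch of that confinement idea, using a hypothetical stand-in for the library's kafkaConsumerDispatcher(groupId) helper (its real signature is not shown in this diff):

```kotlin
import kotlinx.coroutines.ExecutorCoroutineDispatcher
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicLong

private val consumerThreadCounter = AtomicLong()

// Hypothetical stand-in for kafkaConsumerDispatcher(groupId): one daemon thread whose
// name carries the "kotlin-kafka-" prefix that checkConsumerThread(...) asserts on.
fun singleConsumerDispatcher(groupId: String): ExecutorCoroutineDispatcher =
  Executors.newSingleThreadExecutor { runnable ->
    Thread(runnable, "kotlin-kafka-$groupId-${consumerThreadCounter.incrementAndGet()}")
      .apply { isDaemon = true }
  }.asCoroutineDispatcher()

fun main() = runBlocking {
  singleConsumerDispatcher("example-group").use { dispatcher ->
    withContext(dispatcher) {
      // Everything here, like KafkaConsumer construction and poll() in the PR,
      // runs on the single confined "kotlin-kafka-…" thread.
      check(Thread.currentThread().name.startsWith("kotlin-kafka-"))
    }
  }
}
```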
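The reworked poll() relies on a rendezvous Channel: trySend succeeds while the downstream collector keeps up (onSuccess polls again immediately), and when it fails the loop parks, launches a suspending send of the same batch plus an empty probe batch, and only then resumes polling. A self-contained simulation of that hand-off, with Int batches standing in for ConsumerRecords and a delay standing in for a slow collector:

```kotlin
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.onFailure
import kotlinx.coroutines.channels.onSuccess
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import java.util.concurrent.atomic.AtomicBoolean

fun main(): Unit = runBlocking {
  val channel = Channel<List<Int>>()          // rendezvous, like EventLoop.channel
  val isPolling = AtomicBoolean(true)         // mirrors EventLoop.isPolling (not otherwise used here)
  var next = 0

  fun CoroutineScope.poll() {
    if (next >= 100) {
      channel.close()
      return
    }
    val batch = List(10) { next++ }
    channel.trySend(batch)
      .onSuccess { poll() }                   // collector kept up: poll again right away
      .onFailure {
        isPolling.set(false)                  // back-pressure: stop the fast path
        launch {
          channel.send(batch)                 // suspends until the collector catches up
          channel.send(emptyList())           // probe send, like ConsumerRecords.empty()
          isPolling.set(true)
          poll()
        }
      }
  }

  launch { poll() }
  for (batch in channel) {
    if (batch.isNotEmpty()) delay(5)          // a slow collector exercises the onFailure branch
  }
}
```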
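The commit manager started by commitManagerJob flushes on whichever of the two thresholds is hit first, signalled through the reachedMaxCommitBatchSize rendezvous channel. The real worker is not shown in this diff; a hedged sketch of the select-based shape that comment suggests, where `commit` stands in for EventLoop.scheduleCommitIfRequired():

```kotlin
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.selects.onTimeout
import kotlinx.coroutines.selects.whileSelect

@OptIn(ExperimentalCoroutinesApi::class)
suspend fun commitWorker(
  reachedMaxCommitBatchSize: Channel<Unit>,
  commitIntervalMillis: Long,
  commit: suspend () -> Unit,
): Unit = whileSelect {
  reachedMaxCommitBatchSize.onReceiveCatching { result ->
    if (!result.isClosed) commit()
    !result.isClosed                                 // stop when the size-signal channel closes
  }
  onTimeout(commitIntervalMillis) {
    commit()
    true
  }
}

fun main() = runBlocking {
  val sizeSignal = Channel<Unit>()
  val worker = launch { commitWorker(sizeSignal, commitIntervalMillis = 300) { println("commit") } }
  repeat(3) { delay(100); sizeSignal.send(Unit) }    // batch-size threshold reached three times
  delay(700)                                         // …then let the interval fire on its own
  sizeSignal.close()                                 // closing the signal stops the worker
  worker.join()
}
```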
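Offset.commit() completes via continuations kept on CommittableBatch and resumed from commitSuccess/commitFailure once the commitAsync callback fires. The same bridge in isolation, as a hedged sketch; note that commitAsync callbacks are only delivered while the consumer is polled, which is why commitOnClose keeps calling poll(ofMillis(1)) while async commits are in flight:

```kotlin
import org.apache.kafka.clients.consumer.Consumer
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
import kotlin.coroutines.suspendCoroutine

// Sketch of the continuation bridge behind Offset.commit(): the commitAsync callback
// resumes a suspended caller, the same idea as the continuations resumed in
// commitSuccess/commitFailure. In the real loop this must stay on the poll thread.
suspend fun <K, V> Consumer<K, V>.commitSuspending(
  offsets: Map<TopicPartition, OffsetAndMetadata>,
): Unit = suspendCoroutine { cont ->
  commitAsync(offsets) { _, exception ->
    if (exception == null) cont.resume(Unit)
    else cont.resumeWithException(exception)
  }
}
```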