From 93723027e1ec49de53a9b2f13fbc792b5b2c90df Mon Sep 17 00:00:00 2001 From: Lucas Brutschy Date: Tue, 9 Sep 2025 14:09:00 +0200 Subject: [PATCH 1/7] KAFKA-19694: Trigger StreamsRebalanceListener in Consumer.close In the consumer, we invoke the consumer rebalance onPartitionRevoked or onPartitionLost callbacks, when the consumer closes. The point is that the application may want to commit, or wipe the state if we are closing unsuccessfully. In the StreamsRebalanceListener, we did not implement this behavior, which means when closing the consumer we may lose some progress, and in the worst case also miss that we have to wipe our local state state since we got fenced. In this PR we implement StreamsRebalanceListenerInvoker, very similarly to ConsumerRebalanceListenerInvoker and invoke it in Consumer.close. --- .../internals/AsyncKafkaConsumer.java | 93 +++-- .../StreamsRebalanceListenerInvoker.java | 114 ++++++ .../internals/AsyncKafkaConsumerTest.java | 69 ++++ .../StreamsRebalanceListenerInvokerTest.java | 337 ++++++++++++++++++ .../DefaultStreamsRebalanceListener.java | 2 + .../DefaultStreamsRebalanceListenerTest.java | 100 +++--- 6 files changed, 623 insertions(+), 92 deletions(-) create mode 100644 clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java create mode 100644 clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java index 5c72c2babbb00..16bb640c32ed1 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java @@ -187,23 +187,14 @@ public class AsyncKafkaConsumer implements ConsumerDelegate { */ private class BackgroundEventProcessor implements EventProcessor { - private Optional streamsRebalanceListener = Optional.empty(); - private final Optional streamsRebalanceData; + private final Optional streamsRebalanceListenerInvoker; public BackgroundEventProcessor() { - this.streamsRebalanceData = Optional.empty(); + this.streamsRebalanceListenerInvoker = Optional.empty(); } - public BackgroundEventProcessor(final Optional streamsRebalanceData) { - this.streamsRebalanceData = streamsRebalanceData; - } - - private void setStreamsRebalanceListener(final StreamsRebalanceListener streamsRebalanceListener) { - if (streamsRebalanceData.isEmpty()) { - throw new IllegalStateException("Background event processor was not created to be used with Streams " + - "rebalance protocol events"); - } - this.streamsRebalanceListener = Optional.of(streamsRebalanceListener); + public BackgroundEventProcessor(final Optional streamsRebalanceListenerInvoker) { + this.streamsRebalanceListenerInvoker = streamsRebalanceListenerInvoker; } @Override @@ -278,7 +269,7 @@ private void processStreamsOnAllTasksLostCallbackNeededEvent(final StreamsOnAllT private StreamsOnTasksRevokedCallbackCompletedEvent invokeOnTasksRevokedCallback(final Set activeTasksToRevoke, final CompletableFuture future) { - final Optional exceptionFromCallback = streamsRebalanceListener().onTasksRevoked(activeTasksToRevoke); + final Optional exceptionFromCallback = Optional.ofNullable(streamsRebalanceListenerInvoker().invokeTasksRevoked(activeTasksToRevoke)); final Optional error = exceptionFromCallback.map(e -> ConsumerUtils.maybeWrapAsKafkaException(e, "Task revocation callback throws an error")); return new StreamsOnTasksRevokedCallbackCompletedEvent(future, error); } @@ -286,36 +277,20 @@ private StreamsOnTasksRevokedCallbackCompletedEvent invokeOnTasksRevokedCallback private StreamsOnTasksAssignedCallbackCompletedEvent invokeOnTasksAssignedCallback(final StreamsRebalanceData.Assignment assignment, final CompletableFuture future) { final Optional error; - final Optional exceptionFromCallback = streamsRebalanceListener().onTasksAssigned(assignment); - if (exceptionFromCallback.isPresent()) { - error = Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exceptionFromCallback.get(), "Task assignment callback throws an error")); - } else { - error = Optional.empty(); - streamsRebalanceData().setReconciledAssignment(assignment); - } + final Optional exceptionFromCallback = Optional.ofNullable(streamsRebalanceListenerInvoker().invokeTasksAssigned(assignment)); + error = exceptionFromCallback.map(e -> ConsumerUtils.maybeWrapAsKafkaException(e, "Task assignment callback throws an error")); return new StreamsOnTasksAssignedCallbackCompletedEvent(future, error); } private StreamsOnAllTasksLostCallbackCompletedEvent invokeOnAllTasksLostCallback(final CompletableFuture future) { final Optional error; - final Optional exceptionFromCallback = streamsRebalanceListener().onAllTasksLost(); - if (exceptionFromCallback.isPresent()) { - error = Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exceptionFromCallback.get(), "All tasks lost callback throws an error")); - } else { - error = Optional.empty(); - streamsRebalanceData().setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY); - } + final Optional exceptionFromCallback = Optional.ofNullable(streamsRebalanceListenerInvoker().invokeAllTasksLost()); + error = exceptionFromCallback.map(e -> ConsumerUtils.maybeWrapAsKafkaException(e, "All tasks lost callback throws an error")); return new StreamsOnAllTasksLostCallbackCompletedEvent(future, error); } - private StreamsRebalanceData streamsRebalanceData() { - return streamsRebalanceData.orElseThrow( - () -> new IllegalStateException("Background event processor was not created to be used with Streams " + - "rebalance protocol events")); - } - - private StreamsRebalanceListener streamsRebalanceListener() { - return streamsRebalanceListener.orElseThrow( + private StreamsRebalanceListenerInvoker streamsRebalanceListenerInvoker() { + return streamsRebalanceListenerInvoker.orElseThrow( () -> new IllegalStateException("Background event processor was not created to be used with Streams " + "rebalance protocol events")); } @@ -367,6 +342,7 @@ private StreamsRebalanceListener streamsRebalanceListener() { private final WakeupTrigger wakeupTrigger = new WakeupTrigger(); private final OffsetCommitCallbackInvoker offsetCommitCallbackInvoker; private final ConsumerRebalanceListenerInvoker rebalanceListenerInvoker; + private final Optional streamsRebalanceListenerInvoker; // Last triggered async commit future. Used to wait until all previous async commits are completed. // We only need to keep track of the last one, since they are guaranteed to complete in order. private CompletableFuture> lastPendingAsyncCommit = null; @@ -517,7 +493,9 @@ public AsyncKafkaConsumer(final ConsumerConfig config, time, new RebalanceCallbackMetricsManager(metrics) ); - this.backgroundEventProcessor = new BackgroundEventProcessor(streamsRebalanceData); + this.streamsRebalanceListenerInvoker = streamsRebalanceData.map(s -> + new StreamsRebalanceListenerInvoker(logContext, s)); + this.backgroundEventProcessor = new BackgroundEventProcessor(streamsRebalanceListenerInvoker); this.backgroundEventReaper = backgroundEventReaperFactory.build(logContext); // The FetchCollector is only used on the application thread. @@ -577,6 +555,7 @@ public AsyncKafkaConsumer(final ConsumerConfig config, this.time = time; this.backgroundEventQueue = backgroundEventQueue; this.rebalanceListenerInvoker = rebalanceListenerInvoker; + this.streamsRebalanceListenerInvoker = Optional.empty(); this.backgroundEventProcessor = new BackgroundEventProcessor(); this.backgroundEventReaper = backgroundEventReaper; this.metrics = metrics; @@ -699,6 +678,7 @@ public AsyncKafkaConsumer(final ConsumerConfig config, networkClientDelegateSupplier, requestManagersSupplier, asyncConsumerMetrics); + this.streamsRebalanceListenerInvoker = Optional.empty(); this.backgroundEventProcessor = new BackgroundEventProcessor(); this.backgroundEventReaper = new CompletableEventReaper(logContext); } @@ -1532,21 +1512,32 @@ private void runRebalanceCallbacksOnClose() { int memberEpoch = groupMetadata.get().get().generationId(); - Set assignedPartitions = groupAssignmentSnapshot.get(); + final Exception error; + + if (streamsRebalanceListenerInvoker.isPresent()) { - if (assignedPartitions.isEmpty()) - // Nothing to revoke. - return; + if (memberEpoch > 0) + error = streamsRebalanceListenerInvoker.get().invokeAllTasksRevoked(); + else + error = streamsRebalanceListenerInvoker.get().invokeAllTasksLost(); - SortedSet droppedPartitions = new TreeSet<>(TOPIC_PARTITION_COMPARATOR); - droppedPartitions.addAll(assignedPartitions); + } else { - final Exception error; + Set assignedPartitions = groupAssignmentSnapshot.get(); - if (memberEpoch > 0) - error = rebalanceListenerInvoker.invokePartitionsRevoked(droppedPartitions); - else - error = rebalanceListenerInvoker.invokePartitionsLost(droppedPartitions); + if (assignedPartitions.isEmpty()) + // Nothing to revoke. + return; + + SortedSet droppedPartitions = new TreeSet<>(TOPIC_PARTITION_COMPARATOR); + droppedPartitions.addAll(assignedPartitions); + + if (memberEpoch > 0) + error = rebalanceListenerInvoker.invokePartitionsRevoked(droppedPartitions); + else + error = rebalanceListenerInvoker.invokePartitionsLost(droppedPartitions); + + } if (error != null) throw ConsumerUtils.maybeWrapAsKafkaException(error); @@ -1964,7 +1955,11 @@ public void subscribe(Collection topics, ConsumerRebalanceListener liste public void subscribe(Collection topics, StreamsRebalanceListener streamsRebalanceListener) { subscribeInternal(topics, Optional.empty()); - backgroundEventProcessor.setStreamsRebalanceListener(streamsRebalanceListener); + if (streamsRebalanceListenerInvoker.isPresent()) { + streamsRebalanceListenerInvoker.get().setRebalanceListener(streamsRebalanceListener); + } else { + throw new IllegalStateException("Consumer was not created to be used with Streams rebalance protocol events"); + } } @Override diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java new file mode 100644 index 0000000000000..fc7a47635a62a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.utils.LogContext; + +import org.slf4j.Logger; + +import java.util.Optional; +import java.util.Set; + +/** + * This class encapsulates the invocation of the callback methods defined in the {@link StreamsRebalanceListener} + * interface. When streams group task assignment changes, these methods are invoked. This class wraps those + * callback calls with some logging and error handling. + */ +public class StreamsRebalanceListenerInvoker { + + private final Logger log; + + private final StreamsRebalanceData streamsRebalanceData; + private Optional listener; + + StreamsRebalanceListenerInvoker(LogContext logContext, StreamsRebalanceData streamsRebalanceData) { + this.log = logContext.logger(getClass()); + this.listener = Optional.empty(); + this.streamsRebalanceData = streamsRebalanceData; + } + + public void setRebalanceListener(StreamsRebalanceListener streamsRebalanceListener) { + this.listener = Optional.ofNullable(streamsRebalanceListener); + } + + public Exception invokeAllTasksRevoked() { + if (listener.isPresent()) { + return invokeTasksRevoked(streamsRebalanceData.reconciledAssignment().activeTasks()); + } + + return null; + } + + public Exception invokeTasksAssigned(final StreamsRebalanceData.Assignment assignment) { + if (listener.isPresent()) { + log.info("Adding newly assigned tasks: {}", assignment); + try { + listener.get().onTasksAssigned(assignment); + } catch (WakeupException | InterruptException e) { + throw e; + } catch (Exception e) { + log.error( + "Streams rebalance listener failed on invocation of onTasksAssigned for tasks {}", + assignment, + e + ); + return e; + } + } + return null; + } + + public Exception invokeTasksRevoked(final Set tasks) { + if (listener.isPresent()) { + log.info("Revoke previously assigned tasks {}", tasks); + try { + listener.get().onTasksRevoked(tasks); + } catch (WakeupException | InterruptException e) { + throw e; + } catch (Exception e) { + log.error( + "Streams rebalance listener failed on invocation of onTasksRevoked for tasks {}", + tasks, + e + ); + return e; + } + } + + return null; + } + + public Exception invokeAllTasksLost() { + if (listener.isPresent()) { + log.info("Lost all previously assigned tasks"); + try { + listener.get().onAllTasksLost(); + } catch (WakeupException | InterruptException e) { + throw e; + } catch (Exception e) { + log.error( + "Streams rebalance listener failed on invocation of onTasksLost.", + e + ); + return e; + } + } + return null; + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java index dcf604d6b819c..a0281b00cb226 100644 --- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java @@ -2210,6 +2210,75 @@ private void markOffsetsReadyForCommitEvent() { }).when(applicationEventHandler).add(ArgumentMatchers.isA(CommitEvent.class)); } + @Test + public void testCloseInvokesStreamsRebalanceListenerOnTasksRevokedWhenMemberEpochPositive() { + final String groupId = "streamsGroup"; + final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()); + + try (final MockedStatic requestManagers = mockStatic(RequestManagers.class)) { + consumer = newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId), streamsRebalanceData); + StreamsRebalanceListener mockStreamsListener = mock(StreamsRebalanceListener.class); + when(mockStreamsListener.onTasksRevoked(any())).thenReturn(Optional.empty()); + consumer.subscribe(singletonList("topic"), mockStreamsListener); + final MemberStateListener groupMetadataUpdateListener = captureGroupMetadataUpdateListener(requestManagers); + final int memberEpoch = 42; + final String memberId = "memberId"; + groupMetadataUpdateListener.onMemberEpochUpdated(Optional.of(memberEpoch), memberId); + + consumer.close(CloseOptions.timeout(Duration.ZERO)); + + verify(mockStreamsListener).onTasksRevoked(any()); + } + } + + @Test + public void testCloseInvokesStreamsRebalanceListenerOnAllTasksLostWhenMemberEpochZeroOrNegative() { + // Test that close() calls streamsRebalanceListener.invokeAllTasksLost() when memberEpoch <= 0 + final String groupId = "streamsGroup"; + final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()); + + try (final MockedStatic requestManagers = mockStatic(RequestManagers.class)) { + consumer = newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId), streamsRebalanceData); + StreamsRebalanceListener mockStreamsListener = mock(StreamsRebalanceListener.class); + when(mockStreamsListener.onAllTasksLost()).thenReturn(Optional.empty()); + consumer.subscribe(singletonList("topic"), mockStreamsListener); + final MemberStateListener groupMetadataUpdateListener = captureGroupMetadataUpdateListener(requestManagers); + final int memberEpoch = 0; + final String memberId = "memberId"; + groupMetadataUpdateListener.onMemberEpochUpdated(Optional.of(memberEpoch), memberId); + + consumer.close(CloseOptions.timeout(Duration.ZERO)); + + verify(mockStreamsListener).onAllTasksLost(); + } + } + + @Test + public void testCloseWrapsStreamsRebalanceListenerException() { + final String groupId = "streamsGroup"; + final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()); + + try (final MockedStatic requestManagers = mockStatic(RequestManagers.class)) { + consumer = newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId), streamsRebalanceData); + StreamsRebalanceListener mockStreamsListener = mock(StreamsRebalanceListener.class); + RuntimeException testException = new RuntimeException("Test streams listener exception"); + doThrow(testException).when(mockStreamsListener).onTasksRevoked(any()); + consumer.subscribe(singletonList("topic"), mockStreamsListener); + final MemberStateListener groupMetadataUpdateListener = captureGroupMetadataUpdateListener(requestManagers); + final int memberEpoch = 1; + final String memberId = "memberId"; + groupMetadataUpdateListener.onMemberEpochUpdated(Optional.of(memberEpoch), memberId); + + KafkaException thrownException = assertThrows(KafkaException.class, + () -> consumer.close(CloseOptions.timeout(Duration.ZERO))); + + assertNotNull(thrownException.getCause()); + assertTrue(thrownException.getCause() instanceof RuntimeException); + assertTrue(thrownException.getCause().getMessage().contains("Test streams listener exception")); + verify(mockStreamsListener).onTasksRevoked(any()); + } + } + private void markReconcileAndAutoCommitCompleteForPollEvent() { doAnswer(invocation -> { PollEvent event = invocation.getArgument(0); diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java new file mode 100644 index 0000000000000..1c13fdb013163 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.utils.LogContext; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; + +import java.util.Optional; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.STRICT_STUBS) +public class StreamsRebalanceListenerInvokerTest { + + @Mock + private StreamsRebalanceListener mockListener; + + @Mock + private StreamsRebalanceData streamsRebalanceData; + + private StreamsRebalanceListenerInvoker invoker; + private final LogContext logContext = new LogContext(); + + @BeforeEach + public void setup() { + invoker = new StreamsRebalanceListenerInvoker(logContext, streamsRebalanceData); + } + + @Test + public void testConstructorInitializesWithEmptyListener() { + // When invoker is constructed, it should have no listener set initially + // This is verified by testing that invoke methods return null when no listener is present + assertNull(invoker.invokeAllTasksRevoked()); + assertNull(invoker.invokeAllTasksLost()); + } + + @Test + public void testSetRebalanceListener() { + // Test setting a listener + invoker.setRebalanceListener(mockListener); + + // Verify listener is set by checking that methods no longer return null immediately + // (we'll mock the dependencies needed for actual invocation) + StreamsRebalanceData.Assignment mockAssignment = createMockAssignment(); + when(streamsRebalanceData.reconciledAssignment()).thenReturn(mockAssignment); + when(mockListener.onTasksRevoked(any())).thenReturn(Optional.empty()); + + // Should now invoke the listener instead of returning null immediately + Exception result = invoker.invokeAllTasksRevoked(); + assertNull(result); // No exception thrown by mock listener + verify(mockListener).onTasksRevoked(eq(mockAssignment.activeTasks())); + } + + @Test + public void testSetRebalanceListenerWithNull() { + // Test setting listener to null + invoker.setRebalanceListener(null); + + // Should behave as if no listener is set + assertNull(invoker.invokeAllTasksRevoked()); + assertNull(invoker.invokeAllTasksLost()); + } + + @Test + public void testSetRebalanceListenerOverwritesExisting() { + StreamsRebalanceListener firstListener = org.mockito.Mockito.mock(StreamsRebalanceListener.class); + StreamsRebalanceListener secondListener = org.mockito.Mockito.mock(StreamsRebalanceListener.class); + + StreamsRebalanceData.Assignment mockAssignment = createMockAssignment(); + when(streamsRebalanceData.reconciledAssignment()).thenReturn(mockAssignment); + when(secondListener.onTasksRevoked(any())).thenReturn(Optional.empty()); + + // Set first listener + invoker.setRebalanceListener(firstListener); + + // Overwrite with second listener + invoker.setRebalanceListener(secondListener); + + // Should use second listener + invoker.invokeAllTasksRevoked(); + verify(firstListener, never()).onTasksRevoked(any()); + verify(secondListener).onTasksRevoked(eq(mockAssignment.activeTasks())); + } + + @Test + public void testInvokeAllTasksRevokedWithNoListener() { + // When no listener is set, should return null + Exception result = invoker.invokeAllTasksRevoked(); + assertNull(result); + } + + @Test + public void testInvokeAllTasksRevokedWithListener() { + invoker.setRebalanceListener(mockListener); + + StreamsRebalanceData.Assignment mockAssignment = createMockAssignment(); + when(streamsRebalanceData.reconciledAssignment()).thenReturn(mockAssignment); + when(mockListener.onTasksRevoked(any())).thenReturn(Optional.empty()); + + Exception result = invoker.invokeAllTasksRevoked(); + + assertNull(result); + verify(mockListener).onTasksRevoked(eq(mockAssignment.activeTasks())); + } + + @Test + public void testInvokeTasksAssignedWithNoListener() { + StreamsRebalanceData.Assignment assignment = createMockAssignment(); + + Exception result = invoker.invokeTasksAssigned(assignment); + + assertNull(result); + verify(mockListener, never()).onTasksAssigned(any()); + } + + @Test + public void testInvokeTasksAssignedWithListener() { + invoker.setRebalanceListener(mockListener); + StreamsRebalanceData.Assignment assignment = createMockAssignment(); + when(mockListener.onTasksAssigned(assignment)).thenReturn(Optional.empty()); + + Exception result = invoker.invokeTasksAssigned(assignment); + + assertNull(result); + verify(mockListener).onTasksAssigned(eq(assignment)); + } + + @Test + public void testInvokeTasksAssignedWithWakeupException() { + invoker.setRebalanceListener(mockListener); + StreamsRebalanceData.Assignment assignment = createMockAssignment(); + WakeupException wakeupException = new WakeupException(); + doThrow(wakeupException).when(mockListener).onTasksAssigned(assignment); + + WakeupException thrownException = assertThrows(WakeupException.class, + () -> invoker.invokeTasksAssigned(assignment)); + + assertEquals(wakeupException, thrownException); + verify(mockListener).onTasksAssigned(eq(assignment)); + } + + @Test + public void testInvokeTasksAssignedWithInterruptException() { + invoker.setRebalanceListener(mockListener); + StreamsRebalanceData.Assignment assignment = createMockAssignment(); + InterruptException interruptException = new InterruptException("Test interrupt"); + doThrow(interruptException).when(mockListener).onTasksAssigned(assignment); + + InterruptException thrownException = assertThrows(InterruptException.class, + () -> invoker.invokeTasksAssigned(assignment)); + + assertEquals(interruptException, thrownException); + verify(mockListener).onTasksAssigned(eq(assignment)); + } + + @Test + public void testInvokeTasksAssignedWithOtherException() { + invoker.setRebalanceListener(mockListener); + StreamsRebalanceData.Assignment assignment = createMockAssignment(); + RuntimeException runtimeException = new RuntimeException("Test exception"); + doThrow(runtimeException).when(mockListener).onTasksAssigned(assignment); + + Exception result = invoker.invokeTasksAssigned(assignment); + + assertEquals(runtimeException, result); + verify(mockListener).onTasksAssigned(eq(assignment)); + } + + @Test + public void testInvokeTasksRevokedWithNoListener() { + Set tasks = createMockTasks(); + + Exception result = invoker.invokeTasksRevoked(tasks); + + assertNull(result); + verify(mockListener, never()).onTasksRevoked(any()); + } + + @Test + public void testInvokeTasksRevokedWithListener() { + invoker.setRebalanceListener(mockListener); + Set tasks = createMockTasks(); + when(mockListener.onTasksRevoked(tasks)).thenReturn(Optional.empty()); + + Exception result = invoker.invokeTasksRevoked(tasks); + + assertNull(result); + verify(mockListener).onTasksRevoked(eq(tasks)); + } + + @Test + public void testInvokeTasksRevokedWithWakeupException() { + invoker.setRebalanceListener(mockListener); + Set tasks = createMockTasks(); + WakeupException wakeupException = new WakeupException(); + doThrow(wakeupException).when(mockListener).onTasksRevoked(tasks); + + WakeupException thrownException = assertThrows(WakeupException.class, + () -> invoker.invokeTasksRevoked(tasks)); + + assertEquals(wakeupException, thrownException); + verify(mockListener).onTasksRevoked(eq(tasks)); + } + + @Test + public void testInvokeTasksRevokedWithInterruptException() { + invoker.setRebalanceListener(mockListener); + Set tasks = createMockTasks(); + InterruptException interruptException = new InterruptException("Test interrupt"); + doThrow(interruptException).when(mockListener).onTasksRevoked(tasks); + + InterruptException thrownException = assertThrows(InterruptException.class, + () -> invoker.invokeTasksRevoked(tasks)); + + assertEquals(interruptException, thrownException); + verify(mockListener).onTasksRevoked(eq(tasks)); + } + + @Test + public void testInvokeTasksRevokedWithOtherException() { + invoker.setRebalanceListener(mockListener); + Set tasks = createMockTasks(); + RuntimeException runtimeException = new RuntimeException("Test exception"); + doThrow(runtimeException).when(mockListener).onTasksRevoked(tasks); + + Exception result = invoker.invokeTasksRevoked(tasks); + + assertEquals(runtimeException, result); + verify(mockListener).onTasksRevoked(eq(tasks)); + } + + @Test + public void testInvokeAllTasksLostWithNoListener() { + Exception result = invoker.invokeAllTasksLost(); + + assertNull(result); + verify(mockListener, never()).onAllTasksLost(); + } + + @Test + public void testInvokeAllTasksLostWithListener() { + invoker.setRebalanceListener(mockListener); + when(mockListener.onAllTasksLost()).thenReturn(Optional.empty()); + + Exception result = invoker.invokeAllTasksLost(); + + assertNull(result); + verify(mockListener).onAllTasksLost(); + } + + @Test + public void testInvokeAllTasksLostWithWakeupException() { + invoker.setRebalanceListener(mockListener); + WakeupException wakeupException = new WakeupException(); + doThrow(wakeupException).when(mockListener).onAllTasksLost(); + + WakeupException thrownException = assertThrows(WakeupException.class, + () -> invoker.invokeAllTasksLost()); + + assertEquals(wakeupException, thrownException); + verify(mockListener).onAllTasksLost(); + } + + @Test + public void testInvokeAllTasksLostWithInterruptException() { + invoker.setRebalanceListener(mockListener); + InterruptException interruptException = new InterruptException("Test interrupt"); + doThrow(interruptException).when(mockListener).onAllTasksLost(); + + InterruptException thrownException = assertThrows(InterruptException.class, + () -> invoker.invokeAllTasksLost()); + + assertEquals(interruptException, thrownException); + verify(mockListener).onAllTasksLost(); + } + + @Test + public void testInvokeAllTasksLostWithOtherException() { + invoker.setRebalanceListener(mockListener); + RuntimeException runtimeException = new RuntimeException("Test exception"); + doThrow(runtimeException).when(mockListener).onAllTasksLost(); + + Exception result = invoker.invokeAllTasksLost(); + + assertEquals(runtimeException, result); + verify(mockListener).onAllTasksLost(); + } + + private StreamsRebalanceData.Assignment createMockAssignment() { + Set activeTasks = createMockTasks(); + Set standbyTasks = Set.of(); + Set warmupTasks = Set.of(); + + return new StreamsRebalanceData.Assignment(activeTasks, standbyTasks, warmupTasks); + } + + private Set createMockTasks() { + return Set.of( + new StreamsRebalanceData.TaskId("subtopology1", 0), + new StreamsRebalanceData.TaskId("subtopology1", 1) + ); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListener.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListener.java index dcc4821f2a8c4..a95fcef5a6c8a 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListener.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListener.java @@ -89,6 +89,7 @@ public Optional onTasksAssigned(final StreamsRebalanceData.Assignment taskManager.handleAssignment(activeTasksWithPartitions, standbyTasksWithPartitions); streamThread.setState(StreamThread.State.PARTITIONS_ASSIGNED); taskManager.handleRebalanceComplete(); + streamsRebalanceData.setReconciledAssignment(assignment); } catch (final Exception exception) { return Optional.of(exception); } @@ -99,6 +100,7 @@ public Optional onTasksAssigned(final StreamsRebalanceData.Assignment public Optional onAllTasksLost() { try { taskManager.handleLostAll(); + streamsRebalanceData.setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY); } catch (final Exception exception) { return Optional.of(exception); } diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java index 66cb8e5185b15..3902fff4924b1 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java @@ -118,49 +118,46 @@ void testOnTasksRevokedWithException() { @Test void testOnTasksAssigned() { - createRebalanceListenerWithRebalanceData(new StreamsRebalanceData( - UUID.randomUUID(), - Optional.empty(), - Map.of( - "1", - new StreamsRebalanceData.Subtopology( - Set.of("source1"), - Set.of(), - Map.of("repartition1", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())), - Map.of(), - Set.of() - ), - "2", - new StreamsRebalanceData.Subtopology( - Set.of("source2"), - Set.of(), - Map.of("repartition2", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())), - Map.of(), - Set.of() - ), - "3", - new StreamsRebalanceData.Subtopology( - Set.of("source3"), - Set.of(), - Map.of("repartition3", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())), - Map.of(), - Set.of() - ) + final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class); + when(streamsRebalanceData.subtopologies()).thenReturn(Map.of( + "1", + new StreamsRebalanceData.Subtopology( + Set.of("source1"), + Set.of(), + Map.of("repartition1", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())), + Map.of(), + Set.of() ), - Map.of() + "2", + new StreamsRebalanceData.Subtopology( + Set.of("source2"), + Set.of(), + Map.of("repartition2", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())), + Map.of(), + Set.of() + ), + "3", + new StreamsRebalanceData.Subtopology( + Set.of("source3"), + Set.of(), + Map.of("repartition3", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())), + Map.of(), + Set.of() + ) )); + createRebalanceListenerWithRebalanceData(streamsRebalanceData); - final Optional result = defaultStreamsRebalanceListener.onTasksAssigned( - new StreamsRebalanceData.Assignment( - Set.of(new StreamsRebalanceData.TaskId("1", 0)), - Set.of(new StreamsRebalanceData.TaskId("2", 0)), - Set.of(new StreamsRebalanceData.TaskId("3", 0)) - ) + final StreamsRebalanceData.Assignment assignment = new StreamsRebalanceData.Assignment( + Set.of(new StreamsRebalanceData.TaskId("1", 0)), + Set.of(new StreamsRebalanceData.TaskId("2", 0)), + Set.of(new StreamsRebalanceData.TaskId("3", 0)) ); + final Optional result = defaultStreamsRebalanceListener.onTasksAssigned(assignment); + assertTrue(result.isEmpty()); - final InOrder inOrder = inOrder(taskManager, streamThread); + final InOrder inOrder = inOrder(taskManager, streamThread, streamsRebalanceData); inOrder.verify(taskManager).handleAssignment( Map.of(new TaskId(1, 0), Set.of(new TopicPartition("source1", 0), new TopicPartition("repartition1", 0))), Map.of( @@ -170,6 +167,7 @@ void testOnTasksAssigned() { ); inOrder.verify(streamThread).setState(StreamThread.State.PARTITIONS_ASSIGNED); inOrder.verify(taskManager).handleRebalanceComplete(); + inOrder.verify(streamsRebalanceData).setReconciledAssignment(assignment); } @Test @@ -177,21 +175,33 @@ void testOnTasksAssignedWithException() { final Exception exception = new RuntimeException("sample exception"); doThrow(exception).when(taskManager).handleAssignment(any(), any()); - createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of())); - final Optional result = defaultStreamsRebalanceListener.onTasksAssigned(new StreamsRebalanceData.Assignment(Set.of(), Set.of(), Set.of())); - assertTrue(defaultStreamsRebalanceListener.onAllTasksLost().isEmpty()); + final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class); + when(streamsRebalanceData.subtopologies()).thenReturn(Map.of()); + createRebalanceListenerWithRebalanceData(streamsRebalanceData); + + final Optional result = defaultStreamsRebalanceListener.onTasksAssigned( + new StreamsRebalanceData.Assignment(Set.of(), Set.of(), Set.of()) + ); + assertTrue(result.isPresent()); assertEquals(exception, result.get()); - verify(taskManager).handleLostAll(); + verify(taskManager).handleAssignment(any(), any()); verify(streamThread, never()).setState(StreamThread.State.PARTITIONS_ASSIGNED); verify(taskManager, never()).handleRebalanceComplete(); + verify(streamsRebalanceData, never()).setReconciledAssignment(any()); } @Test void testOnAllTasksLost() { - createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of())); + final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class); + when(streamsRebalanceData.subtopologies()).thenReturn(Map.of()); + createRebalanceListenerWithRebalanceData(streamsRebalanceData); + assertTrue(defaultStreamsRebalanceListener.onAllTasksLost().isEmpty()); - verify(taskManager).handleLostAll(); + + final InOrder inOrder = inOrder(taskManager, streamsRebalanceData); + inOrder.verify(taskManager).handleLostAll(); + inOrder.verify(streamsRebalanceData).setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY); } @Test @@ -199,10 +209,14 @@ void testOnAllTasksLostWithException() { final Exception exception = new RuntimeException("sample exception"); doThrow(exception).when(taskManager).handleLostAll(); - createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of())); + final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class); + when(streamsRebalanceData.subtopologies()).thenReturn(Map.of()); + createRebalanceListenerWithRebalanceData(streamsRebalanceData); + final Optional result = defaultStreamsRebalanceListener.onAllTasksLost(); assertTrue(result.isPresent()); assertEquals(exception, result.get()); verify(taskManager).handleLostAll(); + verify(streamsRebalanceData, never()).setReconciledAssignment(any()); } } From 0e725787b71b305487cc70c5b0e3bcc12fe341fe Mon Sep 17 00:00:00 2001 From: Lucas Brutschy Date: Tue, 9 Sep 2025 19:39:37 +0200 Subject: [PATCH 2/7] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../clients/consumer/internals/AsyncKafkaConsumer.java | 10 ++++++---- .../internals/DefaultStreamsRebalanceListenerTest.java | 4 +--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java index 16bb640c32ed1..972ba4b866fd8 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java @@ -1516,10 +1516,11 @@ private void runRebalanceCallbacksOnClose() { if (streamsRebalanceListenerInvoker.isPresent()) { - if (memberEpoch > 0) + if (memberEpoch > 0) { error = streamsRebalanceListenerInvoker.get().invokeAllTasksRevoked(); - else + } else { error = streamsRebalanceListenerInvoker.get().invokeAllTasksLost(); + } } else { @@ -1532,10 +1533,11 @@ private void runRebalanceCallbacksOnClose() { SortedSet droppedPartitions = new TreeSet<>(TOPIC_PARTITION_COMPARATOR); droppedPartitions.addAll(assignedPartitions); - if (memberEpoch > 0) + if (memberEpoch > 0) { error = rebalanceListenerInvoker.invokePartitionsRevoked(droppedPartitions); - else + } else { error = rebalanceListenerInvoker.invokePartitionsLost(droppedPartitions); + } } diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java index 3902fff4924b1..1297df7b1eeb6 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStreamsRebalanceListenerTest.java @@ -178,11 +178,10 @@ void testOnTasksAssignedWithException() { final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class); when(streamsRebalanceData.subtopologies()).thenReturn(Map.of()); createRebalanceListenerWithRebalanceData(streamsRebalanceData); - + final Optional result = defaultStreamsRebalanceListener.onTasksAssigned( new StreamsRebalanceData.Assignment(Set.of(), Set.of(), Set.of()) ); - assertTrue(result.isPresent()); assertEquals(exception, result.get()); verify(taskManager).handleAssignment(any(), any()); @@ -212,7 +211,6 @@ void testOnAllTasksLostWithException() { final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class); when(streamsRebalanceData.subtopologies()).thenReturn(Map.of()); createRebalanceListenerWithRebalanceData(streamsRebalanceData); - final Optional result = defaultStreamsRebalanceListener.onAllTasksLost(); assertTrue(result.isPresent()); assertEquals(exception, result.get()); From 9003acc5f772a3d3231bf73d3065995ff69c7e07 Mon Sep 17 00:00:00 2001 From: Lucas Brutschy Date: Thu, 11 Sep 2025 12:08:48 +0200 Subject: [PATCH 3/7] Address commens --- .../internals/AsyncKafkaConsumer.java | 30 ++--- .../StreamsRebalanceListenerInvoker.java | 100 +++++++++-------- .../internals/AsyncKafkaConsumerTest.java | 6 +- .../StreamsRebalanceListenerInvokerTest.java | 105 +++++------------- 4 files changed, 91 insertions(+), 150 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java index 972ba4b866fd8..02e18fda2b3f5 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java @@ -187,16 +187,6 @@ public class AsyncKafkaConsumer implements ConsumerDelegate { */ private class BackgroundEventProcessor implements EventProcessor { - private final Optional streamsRebalanceListenerInvoker; - - public BackgroundEventProcessor() { - this.streamsRebalanceListenerInvoker = Optional.empty(); - } - - public BackgroundEventProcessor(final Optional streamsRebalanceListenerInvoker) { - this.streamsRebalanceListenerInvoker = streamsRebalanceListenerInvoker; - } - @Override public void process(final BackgroundEvent event) { switch (event.type()) { @@ -276,16 +266,14 @@ private StreamsOnTasksRevokedCallbackCompletedEvent invokeOnTasksRevokedCallback private StreamsOnTasksAssignedCallbackCompletedEvent invokeOnTasksAssignedCallback(final StreamsRebalanceData.Assignment assignment, final CompletableFuture future) { - final Optional error; final Optional exceptionFromCallback = Optional.ofNullable(streamsRebalanceListenerInvoker().invokeTasksAssigned(assignment)); - error = exceptionFromCallback.map(e -> ConsumerUtils.maybeWrapAsKafkaException(e, "Task assignment callback throws an error")); + final Optional error = exceptionFromCallback.map(e -> ConsumerUtils.maybeWrapAsKafkaException(e, "Task assignment callback throws an error")); return new StreamsOnTasksAssignedCallbackCompletedEvent(future, error); } private StreamsOnAllTasksLostCallbackCompletedEvent invokeOnAllTasksLostCallback(final CompletableFuture future) { - final Optional error; final Optional exceptionFromCallback = Optional.ofNullable(streamsRebalanceListenerInvoker().invokeAllTasksLost()); - error = exceptionFromCallback.map(e -> ConsumerUtils.maybeWrapAsKafkaException(e, "All tasks lost callback throws an error")); + final Optional error = exceptionFromCallback.map(e -> ConsumerUtils.maybeWrapAsKafkaException(e, "All tasks lost callback throws an error")); return new StreamsOnAllTasksLostCallbackCompletedEvent(future, error); } @@ -495,7 +483,7 @@ public AsyncKafkaConsumer(final ConsumerConfig config, ); this.streamsRebalanceListenerInvoker = streamsRebalanceData.map(s -> new StreamsRebalanceListenerInvoker(logContext, s)); - this.backgroundEventProcessor = new BackgroundEventProcessor(streamsRebalanceListenerInvoker); + this.backgroundEventProcessor = new BackgroundEventProcessor(); this.backgroundEventReaper = backgroundEventReaperFactory.build(logContext); // The FetchCollector is only used on the application thread. @@ -1507,7 +1495,7 @@ private void autoCommitOnClose(final Timer timer) { } private void runRebalanceCallbacksOnClose() { - if (groupMetadata.get().isEmpty() || rebalanceListenerInvoker == null) + if (groupMetadata.get().isEmpty()) return; int memberEpoch = groupMetadata.get().get().generationId(); @@ -1956,12 +1944,12 @@ public void subscribe(Collection topics, ConsumerRebalanceListener liste } public void subscribe(Collection topics, StreamsRebalanceListener streamsRebalanceListener) { + + streamsRebalanceListenerInvoker + .orElseThrow(() -> new IllegalStateException("Consumer was not created to be used with Streams rebalance protocol events")) + .setRebalanceListener(streamsRebalanceListener); + subscribeInternal(topics, Optional.empty()); - if (streamsRebalanceListenerInvoker.isPresent()) { - streamsRebalanceListenerInvoker.get().setRebalanceListener(streamsRebalanceListener); - } else { - throw new IllegalStateException("Consumer was not created to be used with Streams rebalance protocol events"); - } } @Override diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java index fc7a47635a62a..f90b326cdc52f 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java @@ -22,6 +22,7 @@ import org.slf4j.Logger; +import java.util.Objects; import java.util.Optional; import java.util.Set; @@ -44,70 +45,75 @@ public class StreamsRebalanceListenerInvoker { } public void setRebalanceListener(StreamsRebalanceListener streamsRebalanceListener) { - this.listener = Optional.ofNullable(streamsRebalanceListener); + Objects.requireNonNull(streamsRebalanceListener, "StreamsRebalanceListener cannot be null"); + if (listener.isPresent()) { + throw new IllegalStateException("StreamsRebalanceListener can only be set once"); + } + this.listener = Optional.of(streamsRebalanceListener); } public Exception invokeAllTasksRevoked() { - if (listener.isPresent()) { - return invokeTasksRevoked(streamsRebalanceData.reconciledAssignment().activeTasks()); + if (listener.isEmpty()) { + throw new IllegalStateException("StreamsRebalanceListener is not defined"); } - - return null; + return invokeTasksRevoked(streamsRebalanceData.reconciledAssignment().activeTasks()); } public Exception invokeTasksAssigned(final StreamsRebalanceData.Assignment assignment) { - if (listener.isPresent()) { - log.info("Adding newly assigned tasks: {}", assignment); - try { - listener.get().onTasksAssigned(assignment); - } catch (WakeupException | InterruptException e) { - throw e; - } catch (Exception e) { - log.error( - "Streams rebalance listener failed on invocation of onTasksAssigned for tasks {}", - assignment, - e - ); - return e; - } + if (listener.isEmpty()) { + throw new IllegalStateException("StreamsRebalanceListener is not defined"); + } + log.info("Invoking tasks assigned callback for new assignment: {}", assignment); + try { + listener.get().onTasksAssigned(assignment); + } catch (WakeupException | InterruptException e) { + throw e; + } catch (Exception e) { + log.error( + "Streams rebalance listener failed on invocation of onTasksAssigned for tasks {}", + assignment, + e + ); + return e; } return null; } public Exception invokeTasksRevoked(final Set tasks) { - if (listener.isPresent()) { - log.info("Revoke previously assigned tasks {}", tasks); - try { - listener.get().onTasksRevoked(tasks); - } catch (WakeupException | InterruptException e) { - throw e; - } catch (Exception e) { - log.error( - "Streams rebalance listener failed on invocation of onTasksRevoked for tasks {}", - tasks, - e - ); - return e; - } + if (listener.isEmpty()) { + throw new IllegalStateException("StreamsRebalanceListener is not defined"); + } + log.info("Invoking task revoked callback for revoked active tasks {}", tasks); + try { + listener.get().onTasksRevoked(tasks); + } catch (WakeupException | InterruptException e) { + throw e; + } catch (Exception e) { + log.error( + "Streams rebalance listener failed on invocation of onTasksRevoked for tasks {}", + tasks, + e + ); + return e; } - return null; } public Exception invokeAllTasksLost() { - if (listener.isPresent()) { - log.info("Lost all previously assigned tasks"); - try { - listener.get().onAllTasksLost(); - } catch (WakeupException | InterruptException e) { - throw e; - } catch (Exception e) { - log.error( - "Streams rebalance listener failed on invocation of onTasksLost.", - e - ); - return e; - } + if (listener.isEmpty()) { + throw new IllegalStateException("StreamsRebalanceListener is not defined"); + } + log.info("Invoking tasks lost callback for all tasks"); + try { + listener.get().onAllTasksLost(); + } catch (WakeupException | InterruptException e) { + throw e; + } catch (Exception e) { + log.error( + "Streams rebalance listener failed on invocation of onTasksLost.", + e + ); + return e; } return null; } diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java index a0281b00cb226..9aef4cf7f5293 100644 --- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumerTest.java @@ -2233,7 +2233,6 @@ public void testCloseInvokesStreamsRebalanceListenerOnTasksRevokedWhenMemberEpoc @Test public void testCloseInvokesStreamsRebalanceListenerOnAllTasksLostWhenMemberEpochZeroOrNegative() { - // Test that close() calls streamsRebalanceListener.invokeAllTasksLost() when memberEpoch <= 0 final String groupId = "streamsGroup"; final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()); @@ -2271,9 +2270,8 @@ public void testCloseWrapsStreamsRebalanceListenerException() { KafkaException thrownException = assertThrows(KafkaException.class, () -> consumer.close(CloseOptions.timeout(Duration.ZERO))); - - assertNotNull(thrownException.getCause()); - assertTrue(thrownException.getCause() instanceof RuntimeException); + + assertInstanceOf(RuntimeException.class, thrownException.getCause()); assertTrue(thrownException.getCause().getMessage().contains("Test streams listener exception")); verify(mockStreamsListener).onTasksRevoked(any()); } diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java index 1c13fdb013163..44c90b67993d1 100644 --- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java @@ -37,7 +37,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.never; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -59,67 +59,44 @@ public void setup() { invoker = new StreamsRebalanceListenerInvoker(logContext, streamsRebalanceData); } - @Test - public void testConstructorInitializesWithEmptyListener() { - // When invoker is constructed, it should have no listener set initially - // This is verified by testing that invoke methods return null when no listener is present - assertNull(invoker.invokeAllTasksRevoked()); - assertNull(invoker.invokeAllTasksLost()); - } - - @Test - public void testSetRebalanceListener() { - // Test setting a listener - invoker.setRebalanceListener(mockListener); - - // Verify listener is set by checking that methods no longer return null immediately - // (we'll mock the dependencies needed for actual invocation) - StreamsRebalanceData.Assignment mockAssignment = createMockAssignment(); - when(streamsRebalanceData.reconciledAssignment()).thenReturn(mockAssignment); - when(mockListener.onTasksRevoked(any())).thenReturn(Optional.empty()); - - // Should now invoke the listener instead of returning null immediately - Exception result = invoker.invokeAllTasksRevoked(); - assertNull(result); // No exception thrown by mock listener - verify(mockListener).onTasksRevoked(eq(mockAssignment.activeTasks())); - } - @Test public void testSetRebalanceListenerWithNull() { - // Test setting listener to null - invoker.setRebalanceListener(null); - - // Should behave as if no listener is set - assertNull(invoker.invokeAllTasksRevoked()); - assertNull(invoker.invokeAllTasksLost()); + NullPointerException exception = assertThrows(NullPointerException.class, + () -> invoker.setRebalanceListener(null)); + assertEquals("StreamsRebalanceListener cannot be null", exception.getMessage()); } @Test - public void testSetRebalanceListenerOverwritesExisting() { - StreamsRebalanceListener firstListener = org.mockito.Mockito.mock(StreamsRebalanceListener.class); - StreamsRebalanceListener secondListener = org.mockito.Mockito.mock(StreamsRebalanceListener.class); - - StreamsRebalanceData.Assignment mockAssignment = createMockAssignment(); - when(streamsRebalanceData.reconciledAssignment()).thenReturn(mockAssignment); - when(secondListener.onTasksRevoked(any())).thenReturn(Optional.empty()); + public void testSetRebalanceListenerThrowsWhenCalledTwice() { + StreamsRebalanceListener firstListener = mock(StreamsRebalanceListener.class); + StreamsRebalanceListener secondListener = mock(StreamsRebalanceListener.class); // Set first listener invoker.setRebalanceListener(firstListener); - // Overwrite with second listener - invoker.setRebalanceListener(secondListener); - - // Should use second listener - invoker.invokeAllTasksRevoked(); - verify(firstListener, never()).onTasksRevoked(any()); - verify(secondListener).onTasksRevoked(eq(mockAssignment.activeTasks())); + // Attempting to set second listener should throw IllegalStateException + IllegalStateException exception = assertThrows(IllegalStateException.class, + () -> invoker.setRebalanceListener(secondListener)); + assertEquals("StreamsRebalanceListener can only be set once", exception.getMessage()); } @Test - public void testInvokeAllTasksRevokedWithNoListener() { - // When no listener is set, should return null - Exception result = invoker.invokeAllTasksRevoked(); - assertNull(result); + public void testInvokeMethodsWithNoListener() { + IllegalStateException exception1 = assertThrows(IllegalStateException.class, + () -> invoker.invokeAllTasksRevoked()); + assertEquals("StreamsRebalanceListener is not defined", exception1.getMessage()); + + IllegalStateException exception2 = assertThrows(IllegalStateException.class, + () -> invoker.invokeTasksAssigned(createMockAssignment())); + assertEquals("StreamsRebalanceListener is not defined", exception2.getMessage()); + + IllegalStateException exception3 = assertThrows(IllegalStateException.class, + () -> invoker.invokeTasksRevoked(createMockTasks())); + assertEquals("StreamsRebalanceListener is not defined", exception3.getMessage()); + + IllegalStateException exception4 = assertThrows(IllegalStateException.class, + () -> invoker.invokeAllTasksLost()); + assertEquals("StreamsRebalanceListener is not defined", exception4.getMessage()); } @Test @@ -136,16 +113,6 @@ public void testInvokeAllTasksRevokedWithListener() { verify(mockListener).onTasksRevoked(eq(mockAssignment.activeTasks())); } - @Test - public void testInvokeTasksAssignedWithNoListener() { - StreamsRebalanceData.Assignment assignment = createMockAssignment(); - - Exception result = invoker.invokeTasksAssigned(assignment); - - assertNull(result); - verify(mockListener, never()).onTasksAssigned(any()); - } - @Test public void testInvokeTasksAssignedWithListener() { invoker.setRebalanceListener(mockListener); @@ -199,16 +166,6 @@ public void testInvokeTasksAssignedWithOtherException() { verify(mockListener).onTasksAssigned(eq(assignment)); } - @Test - public void testInvokeTasksRevokedWithNoListener() { - Set tasks = createMockTasks(); - - Exception result = invoker.invokeTasksRevoked(tasks); - - assertNull(result); - verify(mockListener, never()).onTasksRevoked(any()); - } - @Test public void testInvokeTasksRevokedWithListener() { invoker.setRebalanceListener(mockListener); @@ -262,14 +219,6 @@ public void testInvokeTasksRevokedWithOtherException() { verify(mockListener).onTasksRevoked(eq(tasks)); } - @Test - public void testInvokeAllTasksLostWithNoListener() { - Exception result = invoker.invokeAllTasksLost(); - - assertNull(result); - verify(mockListener, never()).onAllTasksLost(); - } - @Test public void testInvokeAllTasksLostWithListener() { invoker.setRebalanceListener(mockListener); From 479b9a95d02ff9d8f23a86a57e21a87d5256166f Mon Sep 17 00:00:00 2001 From: Lucas Brutschy Date: Thu, 11 Sep 2025 13:58:41 +0200 Subject: [PATCH 4/7] allow resetting the same listener --- .../consumer/internals/StreamsRebalanceListenerInvoker.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java index f90b326cdc52f..318e0f5dc556e 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java @@ -46,7 +46,7 @@ public class StreamsRebalanceListenerInvoker { public void setRebalanceListener(StreamsRebalanceListener streamsRebalanceListener) { Objects.requireNonNull(streamsRebalanceListener, "StreamsRebalanceListener cannot be null"); - if (listener.isPresent()) { + if (listener.isPresent() && listener.get() != streamsRebalanceListener) { throw new IllegalStateException("StreamsRebalanceListener can only be set once"); } this.listener = Optional.of(streamsRebalanceListener); From d1ef6e688a1108a4bba397147f3fe1e08621e864 Mon Sep 17 00:00:00 2001 From: Lucas Brutschy Date: Mon, 15 Sep 2025 17:11:26 +0200 Subject: [PATCH 5/7] comments --- .../clients/consumer/internals/AsyncKafkaConsumer.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java index 02e18fda2b3f5..632f43d036098 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java @@ -1445,7 +1445,7 @@ private void close(Duration timeout, CloseOptions.GroupMembershipOperation membe () -> autoCommitOnClose(closeTimer), firstException); swallow(log, Level.ERROR, "Failed to stop finding coordinator", this::stopFindCoordinatorOnClose, firstException); - swallow(log, Level.ERROR, "Failed to release group assignment", + swallow(log, Level.ERROR, "Failed to run rebalance callbacks", this::runRebalanceCallbacksOnClose, firstException); swallow(log, Level.ERROR, "Failed to leave group while closing consumer", () -> leaveGroupOnClose(closeTimer, membershipOperation), firstException); @@ -1502,7 +1502,7 @@ private void runRebalanceCallbacksOnClose() { final Exception error; - if (streamsRebalanceListenerInvoker.isPresent()) { + if (streamsRebalanceListenerInvoker != null && streamsRebalanceListenerInvoker.isPresent()) { if (memberEpoch > 0) { error = streamsRebalanceListenerInvoker.get().invokeAllTasksRevoked(); @@ -1510,7 +1510,7 @@ private void runRebalanceCallbacksOnClose() { error = streamsRebalanceListenerInvoker.get().invokeAllTasksLost(); } - } else { + } else if (rebalanceListenerInvoker != null) { Set assignedPartitions = groupAssignmentSnapshot.get(); From 3ec89749cb3699accc6581cdc8964203665b686b Mon Sep 17 00:00:00 2001 From: Lucas Brutschy Date: Mon, 15 Sep 2025 17:15:02 +0200 Subject: [PATCH 6/7] address comments --- .../kafka/clients/consumer/internals/AsyncKafkaConsumer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java index 632f43d036098..938ae909027d0 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncKafkaConsumer.java @@ -1500,7 +1500,7 @@ private void runRebalanceCallbacksOnClose() { int memberEpoch = groupMetadata.get().get().generationId(); - final Exception error; + Exception error = null; if (streamsRebalanceListenerInvoker != null && streamsRebalanceListenerInvoker.isPresent()) { From add3b44d7df8e0d7b89ae76e33420527afecea36 Mon Sep 17 00:00:00 2001 From: Lucas Brutschy Date: Tue, 16 Sep 2025 13:53:46 +0200 Subject: [PATCH 7/7] fix --- .../StreamsRebalanceListenerInvoker.java | 3 --- .../StreamsRebalanceListenerInvokerTest.java | 27 ++++++++++++------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java index 318e0f5dc556e..f4c5aa4addc5b 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvoker.java @@ -46,9 +46,6 @@ public class StreamsRebalanceListenerInvoker { public void setRebalanceListener(StreamsRebalanceListener streamsRebalanceListener) { Objects.requireNonNull(streamsRebalanceListener, "StreamsRebalanceListener cannot be null"); - if (listener.isPresent() && listener.get() != streamsRebalanceListener) { - throw new IllegalStateException("StreamsRebalanceListener can only be set once"); - } this.listener = Optional.of(streamsRebalanceListener); } diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java index 44c90b67993d1..2f3e5ab05230c 100644 --- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/StreamsRebalanceListenerInvokerTest.java @@ -37,7 +37,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -67,17 +67,24 @@ public void testSetRebalanceListenerWithNull() { } @Test - public void testSetRebalanceListenerThrowsWhenCalledTwice() { - StreamsRebalanceListener firstListener = mock(StreamsRebalanceListener.class); - StreamsRebalanceListener secondListener = mock(StreamsRebalanceListener.class); - + public void testSetRebalanceListenerOverwritesExisting() { + StreamsRebalanceListener firstListener = org.mockito.Mockito.mock(StreamsRebalanceListener.class); + StreamsRebalanceListener secondListener = org.mockito.Mockito.mock(StreamsRebalanceListener.class); + + StreamsRebalanceData.Assignment mockAssignment = createMockAssignment(); + when(streamsRebalanceData.reconciledAssignment()).thenReturn(mockAssignment); + when(secondListener.onTasksRevoked(any())).thenReturn(Optional.empty()); + // Set first listener invoker.setRebalanceListener(firstListener); - - // Attempting to set second listener should throw IllegalStateException - IllegalStateException exception = assertThrows(IllegalStateException.class, - () -> invoker.setRebalanceListener(secondListener)); - assertEquals("StreamsRebalanceListener can only be set once", exception.getMessage()); + + // Overwrite with second listener + invoker.setRebalanceListener(secondListener); + + // Should use second listener + invoker.invokeAllTasksRevoked(); + verify(firstListener, never()).onTasksRevoked(any()); + verify(secondListener).onTasksRevoked(eq(mockAssignment.activeTasks())); } @Test