Skip to content

Commit 7abf1a5

Browse files
committed
GH-9110: propagate Reactor context to ReactiveMessageHandler
Fixes: #9110 * Move context propagation utilities to the `IntegrationReactiveUtils` * Capture context into message header in the `IntegrationReactiveUtils.adaptSubscribableChannelToPublisher()` before `sink.tryEmitNext()` * Restore the context from message header in the `flatMap()` for `ReactiveStreamsConsumer.reactiveMessageHandler` **Auto-cherry-pick to `6.2.x`**
1 parent a3814be commit 7abf1a5

File tree

4 files changed

+92
-42
lines changed

4 files changed

+92
-42
lines changed

spring-integration-core/src/main/java/org/springframework/integration/channel/FluxMessageChannel.java

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import java.util.concurrent.atomic.AtomicReference;
2222
import java.util.concurrent.locks.LockSupport;
2323

24-
import io.micrometer.context.ContextSnapshotFactory;
2524
import org.reactivestreams.Publisher;
2625
import org.reactivestreams.Subscriber;
2726
import reactor.core.Disposable;
@@ -31,17 +30,16 @@
3130
import reactor.core.publisher.Sinks;
3231
import reactor.core.scheduler.Scheduler;
3332
import reactor.core.scheduler.Schedulers;
34-
import reactor.util.context.Context;
3533
import reactor.util.context.ContextView;
3634

3735
import org.springframework.core.log.LogMessage;
3836
import org.springframework.integration.IntegrationMessageHeaderAccessor;
3937
import org.springframework.integration.StaticMessageHeaderAccessor;
4038
import org.springframework.integration.support.MutableMessageBuilder;
39+
import org.springframework.integration.util.IntegrationReactiveUtils;
4140
import org.springframework.messaging.Message;
4241
import org.springframework.messaging.MessageDeliveryException;
4342
import org.springframework.util.Assert;
44-
import org.springframework.util.ClassUtils;
4543

4644
/**
4745
* The {@link AbstractMessageChannel} implementation for the
@@ -56,9 +54,6 @@
5654
public class FluxMessageChannel extends AbstractMessageChannel
5755
implements Publisher<Message<?>>, ReactiveStreamsSubscribableChannel {
5856

59-
private static final boolean isContextPropagationPresent = ClassUtils.isPresent(
60-
"io.micrometer.context.ContextSnapshot", FluxMessageChannel.class.getClassLoader());
61-
6257
private final Scheduler scheduler = Schedulers.boundedElastic();
6358

6459
private final Sinks.Many<Message<?>> sink = Sinks.many().multicast().onBackpressureBuffer(1, false);
@@ -91,8 +86,8 @@ protected boolean doSend(Message<?> message, long timeout) {
9186

9287
private boolean tryEmitMessage(Message<?> message) {
9388
Message<?> messageToEmit = message;
94-
if (isContextPropagationPresent) {
95-
ContextView contextView = ContextSnapshotHelper.captureContext();
89+
if (IntegrationReactiveUtils.isContextPropagationPresent) {
90+
ContextView contextView = IntegrationReactiveUtils.captureReactorContext();
9691
if (!contextView.isEmpty()) {
9792
messageToEmit = MutableMessageBuilder.fromMessage(message)
9893
.setHeader(IntegrationMessageHeaderAccessor.REACTOR_CONTEXT, contextView)
@@ -196,14 +191,4 @@ public void destroy() {
196191
super.destroy();
197192
}
198193

199-
private static final class ContextSnapshotHelper {
200-
201-
private static final ContextSnapshotFactory CONTEXT_SNAPSHOT_FACTORY = ContextSnapshotFactory.builder().build();
202-
203-
static ContextView captureContext() {
204-
return CONTEXT_SNAPSHOT_FACTORY.captureAll().updateContext(Context.empty());
205-
}
206-
207-
}
208-
209194
}

spring-integration-core/src/main/java/org/springframework/integration/endpoint/ReactiveStreamsConsumer.java

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import java.util.function.Consumer;
2020
import java.util.function.Function;
2121

22-
import io.micrometer.context.ContextSnapshotFactory;
2322
import org.reactivestreams.Publisher;
2423
import org.reactivestreams.Subscriber;
2524
import org.reactivestreams.Subscription;
@@ -31,6 +30,7 @@
3130

3231
import org.springframework.context.Lifecycle;
3332
import org.springframework.integration.IntegrationMessageHeaderAccessor;
33+
import org.springframework.integration.StaticMessageHeaderAccessor;
3434
import org.springframework.integration.channel.ChannelUtils;
3535
import org.springframework.integration.channel.NullChannel;
3636
import org.springframework.integration.core.MessageProducer;
@@ -44,7 +44,6 @@
4444
import org.springframework.messaging.MessageHandler;
4545
import org.springframework.messaging.ReactiveMessageHandler;
4646
import org.springframework.util.Assert;
47-
import org.springframework.util.ClassUtils;
4847
import org.springframework.util.ErrorHandler;
4948

5049
/**
@@ -58,9 +57,6 @@
5857
*/
5958
public class ReactiveStreamsConsumer extends AbstractEndpoint implements IntegrationConsumer {
6059

61-
private static final boolean isContextPropagationPresent = ClassUtils.isPresent(
62-
"io.micrometer.context.ContextSnapshot", ReactiveStreamsConsumer.class.getClassLoader());
63-
6460
private final MessageChannel inputChannel;
6561

6662
private final Publisher<Message<Object>> publisher;
@@ -188,7 +184,9 @@ protected void doStart() {
188184
if (this.reactiveMessageHandler != null) {
189185
this.subscription =
190186
fluxFromChannel
191-
.flatMap(this.reactiveMessageHandler::handleMessage)
187+
.flatMap((message) ->
188+
this.reactiveMessageHandler.handleMessage(message)
189+
.contextWrite(StaticMessageHeaderAccessor.getReactorContext(message)))
192190
.onErrorContinue((ex, data) -> this.errorHandler.handleError(ex))
193191
.subscribe();
194192
}
@@ -300,7 +298,7 @@ protected void hookOnSubscribe(Subscription subscription) {
300298
protected void hookOnNext(Message<?> message) {
301299
Message<?> messageToDeliver = message;
302300

303-
if (isContextPropagationPresent) {
301+
if (IntegrationReactiveUtils.isContextPropagationPresent) {
304302
ContextView reactorContext = message.getHeaders()
305303
.get(IntegrationMessageHeaderAccessor.REACTOR_CONTEXT, ContextView.class);
306304

@@ -310,7 +308,7 @@ protected void hookOnNext(Message<?> message) {
310308
.removeHeader(IntegrationMessageHeaderAccessor.REACTOR_CONTEXT)
311309
.build();
312310

313-
try (AutoCloseable scope = ContextSnapshotHelper.setContext(reactorContext)) {
311+
try (AutoCloseable scope = IntegrationReactiveUtils.setThreadLocalsFromReactorContext(reactorContext)) {
314312
this.delegate.onNext(messageToDeliver);
315313
}
316314
catch (Exception ex) {
@@ -335,14 +333,4 @@ protected void hookOnComplete() {
335333

336334
}
337335

338-
private static final class ContextSnapshotHelper {
339-
340-
private static final ContextSnapshotFactory CONTEXT_SNAPSHOT_FACTORY = ContextSnapshotFactory.builder().build();
341-
342-
static AutoCloseable setContext(ContextView context) {
343-
return CONTEXT_SNAPSHOT_FACTORY.setThreadLocalsFrom(context);
344-
}
345-
346-
}
347-
348336
}

spring-integration-core/src/main/java/org/springframework/integration/util/IntegrationReactiveUtils.java

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020-2022 the original author or authors.
2+
* Copyright 2020-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,24 +19,30 @@
1919
import java.time.Duration;
2020
import java.util.concurrent.locks.LockSupport;
2121

22+
import io.micrometer.context.ContextSnapshotFactory;
2223
import org.apache.commons.logging.Log;
2324
import org.apache.commons.logging.LogFactory;
2425
import org.reactivestreams.Publisher;
2526
import reactor.core.publisher.Flux;
2627
import reactor.core.publisher.Mono;
2728
import reactor.core.publisher.Sinks;
2829
import reactor.core.scheduler.Schedulers;
30+
import reactor.util.context.Context;
31+
import reactor.util.context.ContextView;
2932
import reactor.util.retry.Retry;
3033

34+
import org.springframework.integration.IntegrationMessageHeaderAccessor;
3135
import org.springframework.integration.StaticMessageHeaderAccessor;
3236
import org.springframework.integration.acks.AckUtils;
3337
import org.springframework.integration.core.MessageSource;
38+
import org.springframework.integration.support.MutableMessageBuilder;
3439
import org.springframework.messaging.Message;
3540
import org.springframework.messaging.MessageChannel;
3641
import org.springframework.messaging.MessageHandler;
3742
import org.springframework.messaging.MessagingException;
3843
import org.springframework.messaging.PollableChannel;
3944
import org.springframework.messaging.SubscribableChannel;
45+
import org.springframework.util.ClassUtils;
4046

4147
/**
4248
* Utilities for adapting integration components to/from reactive types.
@@ -60,9 +66,40 @@ public final class IntegrationReactiveUtils {
6066
*/
6167
public static final Duration DEFAULT_DELAY_WHEN_EMPTY = Duration.ofSeconds(1);
6268

69+
/**
70+
* The indicator that {@code io.micrometer:context-propagation} library is on classpath.
71+
* @since 6.2.5
72+
*/
73+
public static final boolean isContextPropagationPresent = ClassUtils.isPresent(
74+
"io.micrometer.context.ContextSnapshot", IntegrationReactiveUtils.class.getClassLoader());
75+
76+
private static final ContextSnapshotFactory CONTEXT_SNAPSHOT_FACTORY = ContextSnapshotFactory.builder().build();
77+
6378
private IntegrationReactiveUtils() {
6479
}
6580

81+
/**
82+
* Capture a Reactor {@link ContextView} from the current thread local state
83+
* according to the {@link ContextSnapshotFactory} logic.
84+
* @return the Reactor {@link ContextView} from the current thread local state.
85+
* @since 6.2.5
86+
*/
87+
public static ContextView captureReactorContext() {
88+
return CONTEXT_SNAPSHOT_FACTORY.captureAll().updateContext(Context.empty());
89+
}
90+
91+
/**
92+
* Populate thread local variables from the provided Reactor {@link ContextView}
93+
* according to the {@link ContextSnapshotFactory} logic.
94+
* @param context the Reactor {@link ContextView} to populate from.
95+
* @return the {@link io.micrometer.context.ContextSnapshot.Scope} as a {@link AutoCloseable}
96+
* to not pollute the target classpath. Can be cast if necessary.
97+
* @since 6.2.5
98+
*/
99+
public static AutoCloseable setThreadLocalsFromReactorContext(ContextView context) {
100+
return CONTEXT_SNAPSHOT_FACTORY.setThreadLocalsFrom(context);
101+
}
102+
66103
/**
67104
* Wrap a provided {@link MessageSource} into a {@link Flux} for pulling the on demand.
68105
* When {@link MessageSource#receive()} returns {@code null}, the source {@link Mono}
@@ -137,8 +174,17 @@ private static <T> Flux<Message<T>> adaptSubscribableChannelToPublisher(Subscrib
137174
return Flux.defer(() -> {
138175
Sinks.Many<Message<T>> sink = Sinks.many().unicast().onBackpressureError();
139176
MessageHandler messageHandler = (message) -> {
177+
Message<?> messageToEmit = message;
178+
if (IntegrationReactiveUtils.isContextPropagationPresent) {
179+
ContextView contextView = IntegrationReactiveUtils.captureReactorContext();
180+
if (!contextView.isEmpty()) {
181+
messageToEmit = MutableMessageBuilder.fromMessage(message)
182+
.setHeader(IntegrationMessageHeaderAccessor.REACTOR_CONTEXT, contextView)
183+
.build();
184+
}
185+
}
140186
while (true) {
141-
switch (sink.tryEmitNext((Message<T>) message)) {
187+
switch (sink.tryEmitNext((Message<T>) messageToEmit)) {
142188
case FAIL_NON_SERIALIZED:
143189
case FAIL_OVERFLOW:
144190
LockSupport.parkNanos(1000); // NOSONAR

spring-integration-core/src/test/java/org/springframework/integration/support/management/observation/IntegrationObservabilityZipkinTests.java

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,31 @@
1818

1919
import java.util.concurrent.CountDownLatch;
2020
import java.util.concurrent.TimeUnit;
21+
import java.util.concurrent.atomic.AtomicReference;
2122

2223
import io.micrometer.common.KeyValues;
2324
import io.micrometer.core.tck.MeterRegistryAssert;
25+
import io.micrometer.observation.Observation;
2426
import io.micrometer.observation.ObservationRegistry;
2527
import io.micrometer.tracing.Span;
2628
import io.micrometer.tracing.test.SampleTestRunner;
2729
import io.micrometer.tracing.test.simple.SpansAssert;
30+
import reactor.core.publisher.Mono;
2831

2932
import org.springframework.beans.factory.annotation.Qualifier;
3033
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
3134
import org.springframework.context.annotation.Bean;
3235
import org.springframework.context.annotation.Configuration;
36+
import org.springframework.integration.annotation.BridgeTo;
3337
import org.springframework.integration.annotation.EndpointId;
3438
import org.springframework.integration.annotation.Poller;
3539
import org.springframework.integration.annotation.ServiceActivator;
3640
import org.springframework.integration.channel.NullChannel;
41+
import org.springframework.integration.channel.PublishSubscribeChannel;
3742
import org.springframework.integration.channel.QueueChannel;
3843
import org.springframework.integration.config.EnableIntegration;
3944
import org.springframework.integration.config.EnableIntegrationManagement;
45+
import org.springframework.integration.dsl.IntegrationFlow;
4046
import org.springframework.integration.gateway.MessagingGatewaySupport;
4147
import org.springframework.integration.handler.BridgeHandler;
4248
import org.springframework.integration.handler.advice.HandleMessageAdvice;
@@ -46,6 +52,7 @@
4652
import org.springframework.messaging.support.GenericMessage;
4753

4854
import static org.assertj.core.api.Assertions.assertThat;
55+
import static org.awaitility.Awaitility.await;
4956

5057
/**
5158
* @author Artem Bilan
@@ -93,6 +100,8 @@ public SampleTestRunnerConsumer yourCode() {
93100
assertThat(receive).isNull();
94101

95102
assertThat(configuration.observedHandlerLatch.await(10, TimeUnit.SECONDS)).isTrue();
103+
104+
await().untilAsserted(() -> assertThat(configuration.observationReference.get()).isNotNull());
96105
}
97106

98107
SpansAssert.assertThat(bb.getFinishedSpans())
@@ -110,7 +119,7 @@ public SampleTestRunnerConsumer yourCode() {
110119
.hasTag(IntegrationObservation.ProducerTags.COMPONENT_NAME.asString(), "queueChannel")
111120
.hasTag(IntegrationObservation.ProducerTags.COMPONENT_TYPE.asString(), "producer")
112121
.hasKindEqualTo(Span.Kind.PRODUCER))
113-
.hasSize(3);
122+
.hasSize(4);
114123

115124
MeterRegistryAssert.assertThat(getMeterRegistry())
116125
.hasTimerWithNameAndTags("spring.integration.handler",
@@ -125,7 +134,7 @@ public SampleTestRunnerConsumer yourCode() {
125134
@EnableIntegration
126135
@EnableIntegrationManagement(
127136
observationPatterns = {
128-
"${spring.integration.management.observation-patterns:testInboundGateway,skippedObservationInboundGateway,queueChannel,observedEndpoint}",
137+
"${spring.integration.management.observation-patterns:testInboundGateway,skippedObservationInboundGateway,queueChannel,observedEndpoint,publishSubscribeChannel}",
129138
"${spring.integration.management.observation-patterns:}"
130139
})
131140
public static class ObservationIntegrationTestConfiguration {
@@ -166,8 +175,10 @@ TestMessagingGatewaySupport skippedObservationInboundGateway() {
166175
@ServiceActivator(inputChannel = "queueChannel",
167176
poller = @Poller(fixedDelay = "100"),
168177
adviceChain = "observedHandlerAdvice")
169-
BridgeHandler bridgeHandler() {
170-
return new BridgeHandler();
178+
BridgeHandler bridgeHandler(PublishSubscribeChannel publishSubscribeChannel) {
179+
BridgeHandler bridgeHandler = new BridgeHandler();
180+
bridgeHandler.setOutputChannel(publishSubscribeChannel);
181+
return bridgeHandler;
171182
}
172183

173184
@Bean
@@ -182,6 +193,26 @@ HandleMessageAdvice observedHandlerAdvice() {
182193
};
183194
}
184195

196+
@Bean
197+
@BridgeTo
198+
PublishSubscribeChannel publishSubscribeChannel() {
199+
return new PublishSubscribeChannel();
200+
}
201+
202+
AtomicReference<Observation> observationReference = new AtomicReference<>();
203+
204+
@Bean
205+
IntegrationFlow handleReactiveFlow(PublishSubscribeChannel publishSubscribeChannel,
206+
ObservationRegistry observationRegistry) {
207+
208+
return IntegrationFlow.from(publishSubscribeChannel)
209+
.handleReactive(m ->
210+
Mono.just("Hi There")
211+
.doOnSuccess(val ->
212+
observationReference.set(observationRegistry.getCurrentObservation()))
213+
.then());
214+
}
215+
185216
}
186217

187218
private static class TestMessagingGatewaySupport extends MessagingGatewaySupport {

0 commit comments

Comments
 (0)