Skip to content

Trace websocket for spring webflux reactive handlers #8831

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import datadog.trace.instrumentation.netty41.server.HttpServerResponseTracingHandler;
import datadog.trace.instrumentation.netty41.server.HttpServerTracingHandler;
import datadog.trace.instrumentation.netty41.server.MaybeBlockResponseHandler;
import datadog.trace.instrumentation.netty41.server.websocket.WebSocketServerRequestTracingHandler;
import datadog.trace.instrumentation.netty41.server.websocket.WebSocketServerResponseTracingHandler;
import datadog.trace.instrumentation.netty41.server.websocket.WebSocketServerInboundTracingHandler;
import datadog.trace.instrumentation.netty41.server.websocket.WebSocketServerOutboundTracingHandler;
import datadog.trace.instrumentation.netty41.server.websocket.WebSocketServerTracingHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelPipeline;
Expand All @@ -34,6 +34,8 @@
import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketFrameDecoder;
import io.netty.handler.codec.http.websocketx.WebSocketFrameEncoder;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.util.Attribute;
import net.bytebuddy.asm.Advice;
Expand Down Expand Up @@ -82,8 +84,8 @@ public String[] helperClassNames() {
packageName + ".server.HttpServerTracingHandler",
packageName + ".server.MaybeBlockResponseHandler",
packageName + ".server.websocket.WebSocketServerTracingHandler",
packageName + ".server.websocket.WebSocketServerResponseTracingHandler",
packageName + ".server.websocket.WebSocketServerRequestTracingHandler",
packageName + ".server.websocket.WebSocketServerOutboundTracingHandler",
packageName + ".server.websocket.WebSocketServerInboundTracingHandler",
packageName + ".NettyHttp2Helper",
packageName + ".NettyPipelineHelper",
};
Expand Down Expand Up @@ -162,23 +164,31 @@ public static void addHandler(
HttpServerResponseTracingHandler.INSTANCE,
MaybeBlockResponseHandler.INSTANCE);
} else if (handler instanceof WebSocketServerProtocolHandler) {
if (InstrumenterConfig.get().isWebsocketTracingEnabled()) {
if (pipeline.get(HttpServerTracingHandler.class) != null) {
NettyPipelineHelper.addHandlerAfter(
pipeline, "HttpServerTracingHandler#0", new WebSocketServerTracingHandler());
if (InstrumenterConfig.get().isWebsocketTracingEnabled()
&& pipeline.get(HttpServerTracingHandler.class) != null) {
// remove single websocket handler if added before
if (pipeline.get(WebSocketServerInboundTracingHandler.class) != null) {
pipeline.remove(WebSocketServerInboundTracingHandler.class);
}
if (pipeline.get(HttpServerRequestTracingHandler.class) != null) {
NettyPipelineHelper.addHandlerAfter(
pipeline,
"HttpServerRequestTracingHandler#0",
WebSocketServerRequestTracingHandler.INSTANCE);
}
if (pipeline.get(HttpServerResponseTracingHandler.class) != null) {
NettyPipelineHelper.addHandlerAfter(
pipeline,
"HttpServerResponseTracingHandler#0",
WebSocketServerResponseTracingHandler.INSTANCE);
if (pipeline.get(WebSocketServerOutboundTracingHandler.class) != null) {
pipeline.remove(WebSocketServerOutboundTracingHandler.class);
}
NettyPipelineHelper.addHandlerAfter(
pipeline,
pipeline.get(HttpServerTracingHandler.class),
new WebSocketServerTracingHandler());
}
} else if (handler instanceof WebSocketFrameDecoder) {
if (InstrumenterConfig.get().isWebsocketTracingEnabled()
&& pipeline.get(WebSocketServerTracingHandler.class) == null) {
NettyPipelineHelper.addHandlerAfter(
pipeline, handler, WebSocketServerInboundTracingHandler.INSTANCE);
}
} else if (handler instanceof WebSocketFrameEncoder) {
if (InstrumenterConfig.get().isWebsocketTracingEnabled()
&& pipeline.get(WebSocketServerTracingHandler.class) == null) {
NettyPipelineHelper.addHandlerAfter(
pipeline, handler, WebSocketServerOutboundTracingHandler.INSTANCE);
}
}
// Client pipeline handlers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@
import io.netty.handler.codec.http.websocketx.WebSocketFrame;

@ChannelHandler.Sharable
public class WebSocketServerRequestTracingHandler extends ChannelInboundHandlerAdapter {
public static WebSocketServerRequestTracingHandler INSTANCE =
new WebSocketServerRequestTracingHandler();
public class WebSocketServerInboundTracingHandler extends ChannelInboundHandlerAdapter {
public static WebSocketServerInboundTracingHandler INSTANCE =
new WebSocketServerInboundTracingHandler();

@Override
public void channelRead(ChannelHandlerContext ctx, Object frame) {

if (frame instanceof WebSocketFrame) {
Channel channel = ctx.channel();
HandlerContext.Receiver receiverContext =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import io.netty.handler.codec.http.websocketx.WebSocketFrame;

@ChannelHandler.Sharable
public class WebSocketServerResponseTracingHandler extends ChannelOutboundHandlerAdapter {
public static WebSocketServerResponseTracingHandler INSTANCE =
new WebSocketServerResponseTracingHandler();
public class WebSocketServerOutboundTracingHandler extends ChannelOutboundHandlerAdapter {
public static WebSocketServerOutboundTracingHandler INSTANCE =
new WebSocketServerOutboundTracingHandler();

@Override
public void write(ChannelHandlerContext ctx, Object frame, ChannelPromise promise)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

public class WebSocketServerTracingHandler
extends CombinedChannelDuplexHandler<
WebSocketServerRequestTracingHandler, WebSocketServerResponseTracingHandler> {
WebSocketServerInboundTracingHandler, WebSocketServerOutboundTracingHandler> {

public WebSocketServerTracingHandler() {
super(
WebSocketServerRequestTracingHandler.INSTANCE,
WebSocketServerResponseTracingHandler.INSTANCE);
WebSocketServerInboundTracingHandler.INSTANCE,
WebSocketServerOutboundTracingHandler.INSTANCE);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datadog.trace.agent.test.AgentTestRunner
import datadog.trace.agent.test.asserts.TraceAssert
import datadog.trace.agent.test.base.OkHttpWebsocketClient
import datadog.trace.api.DDSpanTypes
import datadog.trace.api.DDTags
import datadog.trace.bootstrap.instrumentation.api.Tags
Expand All @@ -9,6 +10,9 @@ import dd.trace.instrumentation.springwebflux.server.EchoHandlerFunction
import dd.trace.instrumentation.springwebflux.server.FooModel
import dd.trace.instrumentation.springwebflux.server.SpringWebFluxTestApplication
import dd.trace.instrumentation.springwebflux.server.TestController
import dd.trace.instrumentation.springwebflux.server.WsHandler
import net.bytebuddy.utility.RandomString
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.boot.test.context.TestConfiguration
import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory
Expand All @@ -21,6 +25,10 @@ import org.springframework.web.reactive.function.client.WebClient
import org.springframework.web.server.ResponseStatusException
import reactor.core.publisher.Mono

import static datadog.trace.agent.test.base.HttpServerTest.websocketCloseSpan
import static datadog.trace.agent.test.base.HttpServerTest.websocketReceiveSpan
import static datadog.trace.agent.test.base.HttpServerTest.websocketSendSpan

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
classes = [SpringWebFluxTestApplication, ForceNettyAutoConfiguration],
properties = "server.http2.enabled=true")
Expand All @@ -40,13 +48,22 @@ class SpringWebfluxTest extends AgentTestRunner {
@LocalServerPort
int port

WebClient client = WebClient.builder().clientConnector (new ReactorClientHttpConnector()).build()
@Autowired
private WsHandler wsHandler

WebClient client = WebClient.builder().clientConnector(new ReactorClientHttpConnector()).build()

@Override
boolean useStrictTraceWrites() {
false
}

@Override
protected void configurePreAgent() {
super.configurePreAgent()
injectSysConfig("trace.websocket.messages.enabled", "true")
}

def "Basic GET test #testName"() {
setup:
String url = "http://localhost:$port$urlPath"
Expand All @@ -61,7 +78,7 @@ class SpringWebfluxTest extends AgentTestRunner {
sortSpansByStart()
trace(2) {
clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url))
traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url))
traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url))
}
trace(2) {
span {
Expand Down Expand Up @@ -142,7 +159,7 @@ class SpringWebfluxTest extends AgentTestRunner {
def traceParent
trace(2) {
clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url))
traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url))
traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url))
}
trace(3) {
span {
Expand Down Expand Up @@ -237,7 +254,7 @@ class SpringWebfluxTest extends AgentTestRunner {
def traceParent
trace(2) {
clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url))
traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url))
traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url))
}
trace(3) {
span {
Expand Down Expand Up @@ -285,7 +302,7 @@ class SpringWebfluxTest extends AgentTestRunner {
def traceParent
trace(2) {
clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url), 404, true)
traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 404, true)
traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 404, true)
}
trace(2) {
span {
Expand Down Expand Up @@ -331,7 +348,7 @@ class SpringWebfluxTest extends AgentTestRunner {
String url = "http://localhost:$port/echo"

when:
def response = client.post().uri(url).body(BodyInserters.fromPublisher(Mono.just(echoString),String)).exchange().block()
def response = client.post().uri(url).body(BodyInserters.fromPublisher(Mono.just(echoString), String)).exchange().block()

then:
response.statusCode().value() == 202
Expand All @@ -341,7 +358,7 @@ class SpringWebfluxTest extends AgentTestRunner {
def traceParent
trace(2) {
clientSpan(it, null, "http.request", "spring-webflux-client", "POST", URI.create(url), 202)
traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "POST", URI.create(url), 202)
traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "POST", URI.create(url), 202)
}
trace(3) {
span {
Expand Down Expand Up @@ -406,7 +423,7 @@ class SpringWebfluxTest extends AgentTestRunner {
def traceParent
trace(2) {
clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url), 500)
traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 500)
traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 500)
}
trace(2) {
span {
Expand Down Expand Up @@ -495,7 +512,7 @@ class SpringWebfluxTest extends AgentTestRunner {
trace(2) {
sortSpansByStart()
clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url), 307)
traceParent1 = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 307)
traceParent1 = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 307)
}

trace(2) {
Expand Down Expand Up @@ -540,7 +557,7 @@ class SpringWebfluxTest extends AgentTestRunner {
trace(2) {
sortSpansByStart()
clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(finalUrl))
traceParent2 = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(finalUrl))
traceParent2 = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(finalUrl))
}
trace(2) {
sortSpansByStart()
Expand Down Expand Up @@ -599,7 +616,7 @@ class SpringWebfluxTest extends AgentTestRunner {
def traceParent
trace(2) {
clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url))
traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url))
traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url))
}
trace(2) {
span {
Expand Down Expand Up @@ -660,6 +677,73 @@ class SpringWebfluxTest extends AgentTestRunner {
"annotation API delayed response" | "/foo-delayed" | "/foo-delayed" | "getFooDelayed" | new FooModel(3L, "delayed").toString()
}

def 'test websocket server receive #msgType message of size #size and #chunks chunks'() {
when:
String url = "http://localhost:$port/websocket"
def wsClient = new OkHttpWebsocketClient()
wsClient.connect(url)
wsHandler.awaitConnected()
if (message instanceof String) {
wsClient.send(message as String)
} else {
wsClient.send(message as byte[])
}
wsHandler.awaitExchangeComplete()
wsClient.close(1001, "goodbye")

then:
assertTraces(3, {
DDSpan handshake
trace(2) {
sortSpansByStart()
handshake = span(0)
span {
resourceName "GET /websocket"
operationName "netty.request"
spanType DDSpanTypes.HTTP_SERVER
tags {
"$Tags.COMPONENT" "netty"
"$Tags.SPAN_KIND" Tags.SPAN_KIND_SERVER
"$Tags.PEER_HOST_IPV4" "127.0.0.1"
"$Tags.PEER_PORT" Integer
"$Tags.HTTP_URL" url
"$Tags.HTTP_HOSTNAME" "localhost"
"$Tags.HTTP_METHOD" "GET"
"$Tags.HTTP_STATUS" 101
"$Tags.HTTP_USER_AGENT" String
"$Tags.HTTP_CLIENT_IP" "127.0.0.1"
"$Tags.HTTP_ROUTE" "/websocket"
defaultTags()
}
}
span {
resourceName "WsHandler.handle"
operationName "WsHandler.handle"
spanType DDSpanTypes.HTTP_SERVER
childOfPrevious()
tags {
"$Tags.COMPONENT" "spring-webflux-controller"
"$Tags.SPAN_KIND" Tags.SPAN_KIND_SERVER
"handler.type" WsHandler.getName()
defaultTags()
}
}
}
trace(2) {
sortSpansByStart()
websocketReceiveSpan(it, handshake, msgType, size, chunks)
websocketSendSpan(it, handshake, msgType, size, chunks)
}
trace(1) {
websocketCloseSpan(it, handshake, false, 1001, "goodbye")
}
})
where:
message | msgType | chunks | size
RandomString.make(10) | "text" | 1 | 10
RandomString.make(20).getBytes("UTF-8") | "binary" | 1 | 20
}

def clientSpan(
TraceAssert trace,
Object parentSpan,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ import org.springframework.boot.autoconfigure.SpringBootApplication
import org.springframework.context.annotation.Bean
import org.springframework.http.MediaType
import org.springframework.stereotype.Component
import org.springframework.web.reactive.HandlerMapping
import org.springframework.web.reactive.function.BodyInserters
import org.springframework.web.reactive.function.server.HandlerFunction
import org.springframework.web.reactive.function.server.RouterFunction
import org.springframework.web.reactive.function.server.ServerRequest
import org.springframework.web.reactive.function.server.ServerResponse
import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping
import org.springframework.web.reactive.socket.WebSocketHandler
import org.springframework.web.reactive.socket.server.support.WebSocketHandlerAdapter
import reactor.core.publisher.Mono

import java.time.Duration
Expand All @@ -26,6 +30,22 @@ class SpringWebFluxTestApplication {
return route(POST("/echo"), new EchoHandlerFunction(echoHandler))
}

@Bean
WebSocketHandlerAdapter webSocketHandlerAdapter() {
return new WebSocketHandlerAdapter()
}

@Bean
HandlerMapping wsHandlerMapping(WsHandler wsHandler) {
Map<String, WebSocketHandler> map = new HashMap<>()
map.put("/websocket", wsHandler)

SimpleUrlHandlerMapping handlerMapping = new SimpleUrlHandlerMapping()
handlerMapping.setOrder(1)
handlerMapping.setUrlMap(map)
return handlerMapping
}

@Bean
RouterFunction<ServerResponse> greetRouterFunction(GreetingHandler greetingHandler) {
return route(GET("/greet"), new HandlerFunction<ServerResponse>() {
Expand Down
Loading
Loading