Skip to content

Commit 97596fb

Browse files
committed
Allow plugging in a WebSocketHandlerDecorator
The WebSocketMessageBroker config now allows wrapping the SubProtocolWebSocketHandler to enable advanced use cases that may require access to the underlying WebSocketSession. Issue: SPR-12314
1 parent f0323be commit 97596fb

File tree

8 files changed

+283
-35
lines changed

8 files changed

+283
-35
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java

+61-12
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
import java.util.List;
2222
import java.util.Map;
2323

24+
import org.springframework.beans.factory.FactoryBean;
25+
import org.springframework.web.socket.WebSocketHandler;
26+
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
2427
import org.w3c.dom.Element;
2528

2629
import org.springframework.beans.MutablePropertyValues;
@@ -89,7 +92,9 @@
8992
*/
9093
class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
9194

92-
private static final String SOCKJS_SCHEDULER_BEAN_NAME = "messageBrokerSockJsScheduler";
95+
public static final String WEB_SOCKET_HANDLER_BEAN_NAME = "subProtocolWebSocketHandler";
96+
97+
public static final String SOCKJS_SCHEDULER_BEAN_NAME = "messageBrokerSockJsScheduler";
9398

9499
private static final int DEFAULT_MAPPING_ORDER = 1;
95100

@@ -156,7 +161,7 @@ public BeanDefinition parse(Element element, ParserContext context) {
156161
scopeConfigurer.getPropertyValues().add("scopes", scopeMap);
157162
registerBeanDefByName("webSocketScopeConfigurer", scopeConfigurer, context, source);
158163

159-
registerWebSocketMessageBrokerStats(subProtoHandler, broker, inChannel, outChannel, context, source);
164+
registerWebSocketMessageBrokerStats(broker, inChannel, outChannel, context, source);
160165

161166
context.popAndRegisterContainingComponent();
162167
return null;
@@ -228,22 +233,32 @@ private RuntimeBeanReference registerSubProtoHandler(Element element, RuntimeBea
228233
cavs.addIndexedArgumentValue(0, inChannel);
229234
cavs.addIndexedArgumentValue(1, outChannel);
230235

231-
RootBeanDefinition beanDef = new RootBeanDefinition(SubProtocolWebSocketHandler.class, cavs, null);
232-
beanDef.getPropertyValues().addPropertyValue("protocolHandlers", stompHandlerDef);
236+
RootBeanDefinition handlerDef = new RootBeanDefinition(SubProtocolWebSocketHandler.class, cavs, null);
237+
handlerDef.getPropertyValues().addPropertyValue("protocolHandlers", stompHandlerDef);
238+
registerBeanDefByName(WEB_SOCKET_HANDLER_BEAN_NAME, handlerDef, context, source);
239+
RuntimeBeanReference result = new RuntimeBeanReference(WEB_SOCKET_HANDLER_BEAN_NAME);
233240

234241
Element transportElem = DomUtils.getChildElementByTagName(element, "transport");
235242
if (transportElem != null) {
236243
if (transportElem.hasAttribute("message-size")) {
237244
stompHandlerDef.getPropertyValues().add("messageSizeLimit", transportElem.getAttribute("message-size"));
238245
}
239246
if (transportElem.hasAttribute("send-timeout")) {
240-
beanDef.getPropertyValues().add("sendTimeLimit", transportElem.getAttribute("send-timeout"));
247+
handlerDef.getPropertyValues().add("sendTimeLimit", transportElem.getAttribute("send-timeout"));
241248
}
242249
if (transportElem.hasAttribute("send-buffer-size")) {
243-
beanDef.getPropertyValues().add("sendBufferSizeLimit", transportElem.getAttribute("send-buffer-size"));
250+
handlerDef.getPropertyValues().add("sendBufferSizeLimit", transportElem.getAttribute("send-buffer-size"));
251+
}
252+
Element factoriesElement = DomUtils.getChildElementByTagName(transportElem, "decorator-factories");
253+
if (factoriesElement != null) {
254+
ManagedList<Object> factories = extractBeanSubElements(factoriesElement, context);
255+
RootBeanDefinition factoryBean = new RootBeanDefinition(DecoratingFactoryBean.class);
256+
factoryBean.getConstructorArgumentValues().addIndexedArgumentValue(0, handlerDef);
257+
factoryBean.getConstructorArgumentValues().addIndexedArgumentValue(1, factories);
258+
result = new RuntimeBeanReference(registerBeanDef(factoryBean, context, source));
244259
}
245260
}
246-
return new RuntimeBeanReference(registerBeanDef(beanDef, context, source));
261+
return result;
247262
}
248263

249264
private RuntimeBeanReference registerRequestHandler(Element element, RuntimeBeanReference subProtoHandler,
@@ -448,14 +463,15 @@ private RuntimeBeanReference registerUserDestinationMessageHandler(RuntimeBeanRe
448463
return new RuntimeBeanReference(registerBeanDef(beanDef, context, source));
449464
}
450465

451-
private void registerWebSocketMessageBrokerStats(RuntimeBeanReference subProtoHandler,
452-
RootBeanDefinition broker, RuntimeBeanReference inChannel, RuntimeBeanReference outChannel,
453-
ParserContext context, Object source) {
466+
private void registerWebSocketMessageBrokerStats(RootBeanDefinition broker, RuntimeBeanReference inChannel,
467+
RuntimeBeanReference outChannel, ParserContext context, Object source) {
454468

455469
RootBeanDefinition beanDef = new RootBeanDefinition(WebSocketMessageBrokerStats.class);
456-
beanDef.getPropertyValues().add("subProtocolWebSocketHandler", subProtoHandler);
457470

458-
if (StompBrokerRelayMessageHandler.class.equals(broker.getBeanClass())) {
471+
RuntimeBeanReference webSocketHandler = new RuntimeBeanReference(WEB_SOCKET_HANDLER_BEAN_NAME);
472+
beanDef.getPropertyValues().add("subProtocolWebSocketHandler", webSocketHandler);
473+
474+
if (StompBrokerRelayMessageHandler.class.equals(broker.getBeanClass())) {
459475
beanDef.getPropertyValues().add("stompBrokerRelay", broker);
460476
}
461477
String name = inChannel.getBeanName() + "Executor";
@@ -486,4 +502,37 @@ private static void registerBeanDefByName(String name, RootBeanDefinition beanDe
486502
context.registerComponent(new BeanComponentDefinition(beanDef, name));
487503
}
488504

505+
506+
private static class DecoratingFactoryBean implements FactoryBean<WebSocketHandler> {
507+
508+
private final WebSocketHandler handler;
509+
510+
private final List<WebSocketHandlerDecoratorFactory> factories;
511+
512+
513+
private DecoratingFactoryBean(WebSocketHandler handler, List<WebSocketHandlerDecoratorFactory> factories) {
514+
this.handler = handler;
515+
this.factories = factories;
516+
}
517+
518+
@Override
519+
public WebSocketHandler getObject() throws Exception {
520+
WebSocketHandler result = this.handler;
521+
for (WebSocketHandlerDecoratorFactory factory : this.factories) {
522+
result = factory.decorate(result);
523+
}
524+
return result;
525+
}
526+
527+
@Override
528+
public Class<?> getObjectType() {
529+
return WebSocketHandler.class;
530+
}
531+
532+
@Override
533+
public boolean isSingleton() {
534+
return true;
535+
}
536+
}
537+
489538
}

spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupport.java

+13-1
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@
2323
import org.springframework.messaging.simp.config.AbstractMessageBrokerConfiguration;
2424
import org.springframework.messaging.simp.stomp.StompBrokerRelayMessageHandler;
2525
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
26+
import org.springframework.util.Assert;
2627
import org.springframework.web.servlet.HandlerMapping;
2728
import org.springframework.web.socket.WebSocketHandler;
2829
import org.springframework.web.socket.config.WebSocketMessageBrokerStats;
30+
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
31+
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
2932
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
3033

3134
/**
@@ -47,7 +50,9 @@ public abstract class WebSocketMessageBrokerConfigurationSupport extends Abstrac
4750

4851
@Bean
4952
public HandlerMapping stompWebSocketHandlerMapping() {
50-
WebMvcStompEndpointRegistry registry = new WebMvcStompEndpointRegistry(subProtocolWebSocketHandler(),
53+
WebSocketHandler handler = subProtocolWebSocketHandler();
54+
handler = decorateWebSocketHandler(handler);
55+
WebMvcStompEndpointRegistry registry = new WebMvcStompEndpointRegistry(handler,
5156
getTransportRegistration(), userSessionRegistry(), messageBrokerSockJsTaskScheduler());
5257
registry.setApplicationContext(getApplicationContext());
5358
registerStompEndpoints(registry);
@@ -59,6 +64,13 @@ public WebSocketHandler subProtocolWebSocketHandler() {
5964
return new SubProtocolWebSocketHandler(clientInboundChannel(), clientOutboundChannel());
6065
}
6166

67+
protected WebSocketHandler decorateWebSocketHandler(WebSocketHandler handler) {
68+
for (WebSocketHandlerDecoratorFactory factory : getTransportRegistration().getDecoratorFactories()) {
69+
handler = factory.decorate(handler);
70+
}
71+
return handler;
72+
}
73+
6274
protected final WebSocketTransportRegistration getTransportRegistration() {
6375
if (this.transportRegistration == null) {
6476
this.transportRegistration = new WebSocketTransportRegistration();

spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketTransportRegistration.java

+40
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@
1616
package org.springframework.web.socket.config.annotation;
1717

1818

19+
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
20+
21+
import java.util.ArrayList;
22+
import java.util.Arrays;
23+
import java.util.List;
24+
1925
/**
2026
* Configure the processing of messages received from and sent to WebSocket clients.
2127
*
@@ -30,6 +36,9 @@ public class WebSocketTransportRegistration {
3036

3137
private Integer sendBufferSizeLimit;
3238

39+
private final List<WebSocketHandlerDecoratorFactory> decoratorFactories =
40+
new ArrayList<WebSocketHandlerDecoratorFactory>(2);
41+
3342

3443
/**
3544
* Configure the maximum size for an incoming sub-protocol message.
@@ -147,4 +156,35 @@ public WebSocketTransportRegistration setSendBufferSizeLimit(int sendBufferSizeL
147156
protected Integer getSendBufferSizeLimit() {
148157
return this.sendBufferSizeLimit;
149158
}
159+
160+
/**
161+
* Configure one or more factories to decorate the handler used to process
162+
* WebSocket messages. This may be useful in some advanced use cases, for
163+
* example to allow Spring Security to forcibly close the WebSocket session
164+
* when the corresponding HTTP session expires.
165+
* @since 4.1.2
166+
*/
167+
public WebSocketTransportRegistration setDecoratorFactories(WebSocketHandlerDecoratorFactory... factories) {
168+
if (factories != null) {
169+
this.decoratorFactories.addAll(Arrays.asList(factories));
170+
}
171+
return this;
172+
}
173+
174+
/**
175+
* Add a factory that to decorate the handler used to process WebSocket
176+
* messages. This may be useful for some advanced use cases, for example
177+
* to allow Spring Security to forcibly close the WebSocket session when
178+
* the corresponding HTTP session expires.
179+
* @since 4.1.2
180+
*/
181+
public WebSocketTransportRegistration addDecoratorFactory(WebSocketHandlerDecoratorFactory factory) {
182+
this.decoratorFactories.add(factory);
183+
return this;
184+
}
185+
186+
protected List<WebSocketHandlerDecoratorFactory> getDecoratorFactories() {
187+
return this.decoratorFactories;
188+
}
189+
150190
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright 2002-2014 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.web.socket.handler;
18+
19+
import org.springframework.web.socket.WebSocketHandler;
20+
21+
/**
22+
* A factory for applying decorators to a WebSocketHandler.
23+
*
24+
* <p>Decoration should be done through sub-classing
25+
* {@link org.springframework.web.socket.handler.WebSocketHandlerDecorator
26+
* WebSocketHandlerDecorator} to allow any code to traverse decorators and/or
27+
* unwrap the original handler when necessary .
28+
*
29+
* @author Rossen Stoyanchev
30+
* @since 4.1.2
31+
*/
32+
public interface WebSocketHandlerDecoratorFactory {
33+
34+
/**
35+
* Decorate the given WebSocketHandler.
36+
* @param handler the handler to be decorated.
37+
* @return the same handler or the handler wrapped with a sub-class of
38+
* {@code WebSocketHandlerDecorator}.
39+
*/
40+
WebSocketHandler decorate(WebSocketHandler handler);
41+
42+
}

spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd

+32
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,38 @@
497497
]]></xsd:documentation>
498498
</xsd:annotation>
499499
<xsd:complexType>
500+
<xsd:sequence>
501+
<xsd:element name="decorator-factories" maxOccurs="1" minOccurs="0">
502+
<xsd:complexType>
503+
<xsd:annotation>
504+
<xsd:documentation source="org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory"><![CDATA[
505+
Configure one or more factories to decorate the handler used to process WebSocket
506+
messages. This may be useful for some advanced use cases, for example to allow
507+
Spring Security to forcibly close the WebSocket session when the corresponding
508+
HTTP session expires.
509+
]]></xsd:documentation>
510+
</xsd:annotation>
511+
<xsd:sequence>
512+
<xsd:choice minOccurs="1" maxOccurs="unbounded">
513+
<xsd:element ref="beans:bean">
514+
<xsd:annotation>
515+
<xsd:documentation source="org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory"><![CDATA[
516+
A WebSocketHandlerDecoratorFactory bean definition.
517+
]]></xsd:documentation>
518+
</xsd:annotation>
519+
</xsd:element>
520+
<xsd:element ref="beans:ref">
521+
<xsd:annotation>
522+
<xsd:documentation source="org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory"><![CDATA[
523+
A reference to a WebSocketHandlerDecoratorFactory bean.
524+
]]></xsd:documentation>
525+
</xsd:annotation>
526+
</xsd:element>
527+
</xsd:choice>
528+
</xsd:sequence>
529+
</xsd:complexType>
530+
</xsd:element>
531+
</xsd:sequence>
500532
<xsd:attribute name="message-size" type="xsd:string">
501533
<xsd:annotation>
502534
<xsd:documentation><![CDATA[

spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java

+22-2
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@
6262
import org.springframework.web.servlet.HandlerMapping;
6363
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
6464
import org.springframework.web.socket.WebSocketHandler;
65+
import org.springframework.web.socket.WebSocketSession;
66+
import org.springframework.web.socket.handler.TestWebSocketSession;
6567
import org.springframework.web.socket.handler.WebSocketHandlerDecorator;
68+
import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory;
6669
import org.springframework.web.socket.messaging.StompSubProtocolHandler;
6770
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
6871
import org.springframework.web.socket.server.HandshakeHandler;
@@ -93,7 +96,7 @@ public void setup() {
9396

9497

9598
@Test
96-
public void simpleBroker() {
99+
public void simpleBroker() throws Exception {
97100
loadBeanDefinitions("websocket-config-broker-simple.xml");
98101

99102
HandlerMapping hm = this.appContext.getBean(HandlerMapping.class);
@@ -113,6 +116,10 @@ public void simpleBroker() {
113116
List<HandshakeInterceptor> interceptors = wsHttpRequestHandler.getHandshakeInterceptors();
114117
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class)));
115118

119+
WebSocketSession session = new TestWebSocketSession("id");
120+
wsHttpRequestHandler.getWebSocketHandler().afterConnectionEstablished(session);
121+
assertEquals(true, session.getAttributes().get("decorated"));
122+
116123
WebSocketHandler wsHandler = unwrapWebSocketHandler(wsHttpRequestHandler.getWebSocketHandler());
117124
assertNotNull(wsHandler);
118125
assertThat(wsHandler, Matchers.instanceOf(SubProtocolWebSocketHandler.class));
@@ -429,7 +436,6 @@ public boolean supportsParameter(MethodParameter parameter) {
429436
public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
430437
return null;
431438
}
432-
433439
}
434440

435441
class CustomReturnValueHandler implements HandlerMethodReturnValueHandler {
@@ -443,4 +449,18 @@ public boolean supportsReturnType(MethodParameter returnType) {
443449
public void handleReturnValue(Object returnValue, MethodParameter returnType, Message<?> message) throws Exception {
444450

445451
}
452+
}
453+
454+
class TestWebSocketHandlerDecoratorFactory implements WebSocketHandlerDecoratorFactory {
455+
456+
@Override
457+
public WebSocketHandler decorate(WebSocketHandler handler) {
458+
return new WebSocketHandlerDecorator(handler) {
459+
@Override
460+
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
461+
session.getAttributes().put("decorated", true);
462+
super.afterConnectionEstablished(session);
463+
}
464+
};
465+
}
446466
}

0 commit comments

Comments
 (0)