Skip to content

Commit 3d1ae9c

Browse files
committed
Efficient and consistent setAllowedOrigins collection type
Issue: SPR-13761
1 parent cd4ce87 commit 3d1ae9c

File tree

7 files changed

+93
-100
lines changed

7 files changed

+93
-100
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java

+12-15
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616

1717
package org.springframework.web.socket.server.support;
1818

19-
import java.util.ArrayList;
2019
import java.util.Collection;
2120
import java.util.Collections;
22-
import java.util.List;
21+
import java.util.LinkedHashSet;
2322
import java.util.Map;
23+
import java.util.Set;
2424

2525
import org.apache.commons.logging.Log;
2626
import org.apache.commons.logging.LogFactory;
@@ -34,8 +34,8 @@
3434
import org.springframework.web.util.WebUtils;
3535

3636
/**
37-
* An interceptor to check request {@code Origin} header value against a collection of
38-
* allowed origins.
37+
* An interceptor to check request {@code Origin} header value against a
38+
* collection of allowed origins.
3939
*
4040
* @author Sebastien Deleuze
4141
* @since 4.1.2
@@ -44,60 +44,57 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
4444

4545
protected Log logger = LogFactory.getLog(getClass());
4646

47-
private final List<String> allowedOrigins;
47+
private final Set<String> allowedOrigins = new LinkedHashSet<String>();
4848

4949

5050
/**
5151
* Default constructor with only same origin requests allowed.
5252
*/
5353
public OriginHandshakeInterceptor() {
54-
this.allowedOrigins = new ArrayList<String>();
5554
}
5655

5756
/**
5857
* Constructor using the specified allowed origin values.
59-
*
6058
* @see #setAllowedOrigins(Collection)
6159
*/
6260
public OriginHandshakeInterceptor(Collection<String> allowedOrigins) {
63-
this();
6461
setAllowedOrigins(allowedOrigins);
6562
}
6663

64+
6765
/**
6866
* Configure allowed {@code Origin} header values. This check is mostly
6967
* designed for browsers. There is nothing preventing other types of client
7068
* to modify the {@code Origin} header value.
71-
*
7269
* <p>Each provided allowed origin must have a scheme, and optionally a port
7370
* (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
7471
* string may also be "*" in which case all origins are allowed.
75-
*
7672
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
7773
*/
7874
public void setAllowedOrigins(Collection<String> allowedOrigins) {
79-
Assert.notNull(allowedOrigins, "Allowed origin Collection must not be null");
75+
Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null");
8076
this.allowedOrigins.clear();
8177
this.allowedOrigins.addAll(allowedOrigins);
8278
}
8379

8480
/**
85-
* @see #setAllowedOrigins(Collection)
8681
* @since 4.1.5
82+
* @see #setAllowedOrigins
8783
*/
8884
public Collection<String> getAllowedOrigins() {
89-
return Collections.unmodifiableList(this.allowedOrigins);
85+
return Collections.unmodifiableSet(this.allowedOrigins);
9086
}
9187

9288

9389
@Override
9490
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
9591
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
92+
9693
if (!WebUtils.isSameOrigin(request) && !WebUtils.isValidOrigin(request, this.allowedOrigins)) {
9794
response.setStatusCode(HttpStatus.FORBIDDEN);
9895
if (logger.isDebugEnabled()) {
99-
logger.debug("Handshake request rejected, Origin header value "
100-
+ request.getHeaders().getOrigin() + " not allowed");
96+
logger.debug("Handshake request rejected, Origin header value " +
97+
request.getHeaders().getOrigin() + " not allowed");
10198
}
10299
return false;
103100
}

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java

+47-45
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818

1919
import java.io.IOException;
2020
import java.nio.charset.Charset;
21-
import java.util.ArrayList;
2221
import java.util.Arrays;
22+
import java.util.Collection;
2323
import java.util.Collections;
2424
import java.util.Date;
2525
import java.util.HashSet;
26+
import java.util.LinkedHashSet;
2627
import java.util.List;
2728
import java.util.Random;
29+
import java.util.Set;
2830
import java.util.concurrent.TimeUnit;
2931
import javax.servlet.http.HttpServletRequest;
3032

@@ -56,7 +58,7 @@
5658
* path resolution and handling of static SockJS requests (e.g. "/info", "/iframe.html",
5759
* etc). Sub-classes must handle session URLs (i.e. transport-specific requests).
5860
*
59-
* By default, only same origin requests are allowed. Use {@link #setAllowedOrigins(List)}
61+
* By default, only same origin requests are allowed. Use {@link #setAllowedOrigins}
6062
* to specify a list of allowed origins (a list containing "*" will allow all origins).
6163
*
6264
* @author Rossen Stoyanchev
@@ -94,10 +96,10 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
9496

9597
private boolean webSocketEnabled = true;
9698

97-
private final List<String> allowedOrigins = new ArrayList<String>();
98-
9999
private boolean suppressCors = false;
100100

101+
protected final Set<String> allowedOrigins = new LinkedHashSet<String>();
102+
101103

102104
public AbstractSockJsService(TaskScheduler scheduler) {
103105
Assert.notNull(scheduler, "TaskScheduler must not be null");
@@ -274,6 +276,24 @@ public boolean isWebSocketEnabled() {
274276
return this.webSocketEnabled;
275277
}
276278

279+
/**
280+
* This option can be used to disable automatic addition of CORS headers for
281+
* SockJS requests.
282+
* <p>The default value is "false".
283+
* @since 4.1.2
284+
*/
285+
public void setSuppressCors(boolean suppressCors) {
286+
this.suppressCors = suppressCors;
287+
}
288+
289+
/**
290+
* @since 4.1.2
291+
* @see #setSuppressCors(boolean)
292+
*/
293+
public boolean shouldSuppressCors() {
294+
return this.suppressCors;
295+
}
296+
277297
/**
278298
* Configure allowed {@code Origin} header values. This check is mostly
279299
* designed for browsers. There is nothing preventing other types of client
@@ -289,36 +309,18 @@ public boolean isWebSocketEnabled() {
289309
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
290310
* @see <a href="https://github.com/sockjs/sockjs-client#supported-transports-by-browser-html-served-from-http-or-https">SockJS supported transports by browser</a>
291311
*/
292-
public void setAllowedOrigins(List<String> allowedOrigins) {
293-
Assert.notNull(allowedOrigins, "Allowed origin List must not be null");
312+
public void setAllowedOrigins(Collection<String> allowedOrigins) {
313+
Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null");
294314
this.allowedOrigins.clear();
295315
this.allowedOrigins.addAll(allowedOrigins);
296316
}
297317

298318
/**
299319
* @since 4.1.2
300-
* @see #setAllowedOrigins(List)
301-
*/
302-
public List<String> getAllowedOrigins() {
303-
return Collections.unmodifiableList(this.allowedOrigins);
304-
}
305-
306-
/**
307-
* This option can be used to disable automatic addition of CORS headers for
308-
* SockJS requests.
309-
* <p>The default value is "false".
310-
* @since 4.1.2
311-
*/
312-
public void setSuppressCors(boolean suppressCors) {
313-
this.suppressCors = suppressCors;
314-
}
315-
316-
/**
317-
* @since 4.1.2
318-
* @see #setSuppressCors(boolean)
320+
* @see #setAllowedOrigins
319321
*/
320-
public boolean shouldSuppressCors() {
321-
return this.suppressCors;
322+
public Collection<String> getAllowedOrigins() {
323+
return Collections.unmodifiableSet(this.allowedOrigins);
322324
}
323325

324326

@@ -465,24 +467,11 @@ private boolean validatePath(ServerHttpRequest request) {
465467
String path = request.getURI().getPath();
466468
int index = path.lastIndexOf('/') + 1;
467469
String filename = path.substring(index);
468-
return filename.indexOf(';') == -1;
470+
return (filename.indexOf(';') == -1);
469471
}
470472

471-
/**
472-
* Handle request for raw WebSocket communication, i.e. without any SockJS message framing.
473-
*/
474-
protected abstract void handleRawWebSocketRequest(ServerHttpRequest request,
475-
ServerHttpResponse response, WebSocketHandler webSocketHandler) throws IOException;
476-
477-
/**
478-
* Handle a SockJS session URL (i.e. transport-specific request).
479-
*/
480-
protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
481-
WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException;
482-
483-
484-
protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response,
485-
HttpMethod... httpMethods) throws IOException {
473+
protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods)
474+
throws IOException {
486475

487476
if (WebUtils.isSameOrigin(request)) {
488477
return true;
@@ -529,6 +518,19 @@ protected void sendMethodNotAllowed(ServerHttpResponse response, HttpMethod... h
529518
}
530519

531520

521+
/**
522+
* Handle request for raw WebSocket communication, i.e. without any SockJS message framing.
523+
*/
524+
protected abstract void handleRawWebSocketRequest(ServerHttpRequest request,
525+
ServerHttpResponse response, WebSocketHandler webSocketHandler) throws IOException;
526+
527+
/**
528+
* Handle a SockJS session URL (i.e. transport-specific request).
529+
*/
530+
protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
531+
WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException;
532+
533+
532534
private interface SockJsRequestHandler {
533535

534536
void handle(ServerHttpRequest request, ServerHttpResponse response) throws IOException;
@@ -546,8 +548,8 @@ public void handle(ServerHttpRequest request, ServerHttpResponse response) throw
546548
addNoCacheHeaders(response);
547549
if (checkOrigin(request, response)) {
548550
response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET));
549-
String content = String.format(INFO_CONTENT, random.nextInt(),
550-
isSessionCookieNeeded(), isWebSocketEnabled());
551+
String content = String.format(
552+
INFO_CONTENT, random.nextInt(), isSessionCookieNeeded(), isWebSocketEnabled());
551553
response.getBody().write(content.getBytes());
552554
}
553555

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ protected boolean validateRequest(String serverId, String sessionId, String tran
326326
return false;
327327
}
328328

329-
if (!getAllowedOrigins().contains("*")) {
329+
if (!this.allowedOrigins.contains("*")) {
330330
TransportType transportType = TransportType.fromValue(transport);
331331
if (transportType == null || !transportType.supportsOrigin()) {
332332
if (logger.isWarnEnabled()) {

spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java

+15-15
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,13 @@
1616

1717
package org.springframework.web.socket.config;
1818

19-
import static org.hamcrest.Matchers.*;
20-
import static org.junit.Assert.*;
21-
2219
import java.io.IOException;
2320
import java.io.InputStream;
24-
import java.util.Arrays;
2521
import java.util.Date;
2622
import java.util.List;
2723
import java.util.Map;
2824
import java.util.concurrent.ScheduledFuture;
2925

30-
import org.junit.Before;
3126
import org.junit.Test;
3227

3328
import org.springframework.beans.factory.xml.XmlBeanDefinitionReader;
@@ -67,6 +62,9 @@
6762
import org.springframework.web.socket.sockjs.transport.handler.XhrReceivingTransportHandler;
6863
import org.springframework.web.socket.sockjs.transport.handler.XhrStreamingTransportHandler;
6964

65+
import static org.hamcrest.Matchers.*;
66+
import static org.junit.Assert.*;
67+
7068
/**
7169
* Test fixture for HandlersBeanDefinitionParser.
7270
* See test configuration files websocket-config-handlers-*.xml.
@@ -76,13 +74,7 @@
7674
*/
7775
public class HandlersBeanDefinitionParserTests {
7876

79-
private GenericWebApplicationContext appContext;
80-
81-
82-
@Before
83-
public void setup() {
84-
this.appContext = new GenericWebApplicationContext();
85-
}
77+
private GenericWebApplicationContext appContext = new GenericWebApplicationContext();
8678

8779

8880
@Test
@@ -234,10 +226,12 @@ public void sockJsAttributes() {
234226

235227
List<HandshakeInterceptor> interceptors = transportService.getHandshakeInterceptors();
236228
assertThat(interceptors, contains(instanceOf(OriginHandshakeInterceptor.class)));
237-
assertEquals(Arrays.asList("http://mydomain1.com", "http://mydomain2.com"), transportService.getAllowedOrigins());
238229
assertTrue(transportService.shouldSuppressCors());
230+
assertTrue(transportService.getAllowedOrigins().contains("http://mydomain1.com"));
231+
assertTrue(transportService.getAllowedOrigins().contains("http://mydomain2.com"));
239232
}
240233

234+
241235
private void loadBeanDefinitions(String fileName) {
242236
XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(this.appContext);
243237
ClassPathResource resource = new ClassPathResource(fileName, HandlersBeanDefinitionParserTests.class);
@@ -278,9 +272,11 @@ public boolean supportsPartialMessages() {
278272
}
279273
}
280274

275+
281276
class FooWebSocketHandler extends TestWebSocketHandler {
282277
}
283278

279+
284280
class TestHandshakeHandler implements HandshakeHandler {
285281

286282
@Override
@@ -291,9 +287,11 @@ public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse respons
291287
}
292288
}
293289

290+
294291
class TestChannelInterceptor extends ChannelInterceptorAdapter {
295292
}
296293

294+
297295
class FooTestInterceptor implements HandshakeInterceptor {
298296

299297
@Override
@@ -309,9 +307,11 @@ public void afterHandshake(ServerHttpRequest request, ServerHttpResponse respons
309307
}
310308
}
311309

310+
312311
class BarTestInterceptor extends FooTestInterceptor {
313312
}
314313

314+
315315
@SuppressWarnings({ "unchecked", "rawtypes" })
316316
class TestTaskScheduler implements TaskScheduler {
317317

@@ -344,9 +344,9 @@ public ScheduledFuture scheduleWithFixedDelay(Runnable task, Date startTime, lon
344344
public ScheduledFuture scheduleWithFixedDelay(Runnable task, long delay) {
345345
return null;
346346
}
347-
348347
}
349348

349+
350350
class TestMessageCodec implements SockJsMessageCodec {
351351

352352
@Override
@@ -363,4 +363,4 @@ public String[] decode(String content) throws IOException {
363363
public String[] decodeInputStream(InputStream content) throws IOException {
364364
return new String[0];
365365
}
366-
}
366+
}

0 commit comments

Comments
 (0)