Skip to content

Support allowedOriginPatterns in SockJS config #26108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ public class SockJsServiceRegistration {

private final List<String> allowedOrigins = new ArrayList<>();

private final List<String> allowedOriginPatterns = new ArrayList<>();

@Nullable
private Boolean suppressCors;

Expand Down Expand Up @@ -232,6 +234,18 @@ protected SockJsServiceRegistration setAllowedOrigins(String... allowedOrigins)
return this;
}

/**
* Configure allowed {@code Origin} pattern header values.
* @since 5.3.2
*/
protected SockJsServiceRegistration setAllowedOriginPatterns(String... allowedOriginPatterns) {
this.allowedOriginPatterns.clear();
if (!ObjectUtils.isEmpty(allowedOriginPatterns)) {
this.allowedOriginPatterns.addAll(Arrays.asList(allowedOriginPatterns));
}
return this;
}

/**
* This option can be used to disable automatic addition of CORS headers for
* SockJS requests.
Expand Down Expand Up @@ -284,6 +298,7 @@ protected SockJsService getSockJsService() {
service.setSuppressCors(this.suppressCors);
}
service.setAllowedOrigins(this.allowedOrigins);
service.setAllowedOriginPatterns(this.allowedOriginPatterns);

if (this.messageCodec != null) {
service.setMessageCodec(this.messageCodec);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,11 @@ public interface StompWebSocketEndpointRegistration {
*/
StompWebSocketEndpointRegistration setAllowedOrigins(String... origins);

/**
* Configure allowed {@code Origin} header values.
*
* @see org.springframework.web.cors.CorsConfiguration#setAllowedOriginPatterns(java.util.List)
*/
StompWebSocketEndpointRegistration setAllowedOriginPatterns(String... originPatterns);

}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE

private final List<String> allowedOrigins = new ArrayList<>();

private final List<String> allowedOriginPatterns = new ArrayList<>();

@Nullable
private SockJsServiceRegistration registration;

Expand Down Expand Up @@ -97,6 +99,15 @@ public StompWebSocketEndpointRegistration setAllowedOrigins(String... allowedOri
return this;
}

@Override
public StompWebSocketEndpointRegistration setAllowedOriginPatterns(String... allowedOriginPatterns) {
this.allowedOriginPatterns.clear();
if (!ObjectUtils.isEmpty(allowedOriginPatterns)) {
this.allowedOriginPatterns.addAll(Arrays.asList(allowedOriginPatterns));
}
return this;
}

@Override
public SockJsServiceRegistration withSockJS() {
this.registration = new SockJsServiceRegistration();
Expand All @@ -112,13 +123,22 @@ public SockJsServiceRegistration withSockJS() {
if (!this.allowedOrigins.isEmpty()) {
this.registration.setAllowedOrigins(StringUtils.toStringArray(this.allowedOrigins));
}
if (!this.allowedOriginPatterns.isEmpty()) {
this.registration.setAllowedOriginPatterns(StringUtils.toStringArray(this.allowedOriginPatterns));
}
return this.registration;
}

protected HandshakeInterceptor[] getInterceptors() {
List<HandshakeInterceptor> interceptors = new ArrayList<>(this.interceptors.size() + 1);
interceptors.addAll(this.interceptors);
interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins));
OriginHandshakeInterceptor originHandshakeInterceptor = new OriginHandshakeInterceptor(this.allowedOrigins);
interceptors.add(originHandshakeInterceptor);

if (!ObjectUtils.isEmpty(this.allowedOriginPatterns)) {
originHandshakeInterceptor.setAllowedOriginPatterns(this.allowedOriginPatterns);
}

return interceptors.toArray(new HandshakeInterceptor[0]);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

package org.springframework.web.socket.server.support;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand All @@ -30,6 +31,7 @@
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.WebUtils;
Expand All @@ -45,7 +47,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {

protected final Log logger = LogFactory.getLog(getClass());

private final Set<String> allowedOrigins = new LinkedHashSet<>();
private final CorsConfiguration corsConfiguration = new CorsConfiguration();


/**
Expand Down Expand Up @@ -74,8 +76,7 @@ public OriginHandshakeInterceptor(Collection<String> allowedOrigins) {
*/
public void setAllowedOrigins(Collection<String> allowedOrigins) {
Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null");
this.allowedOrigins.clear();
this.allowedOrigins.addAll(allowedOrigins);
this.corsConfiguration.setAllowedOrigins(new ArrayList<>(allowedOrigins));
}

/**
Expand All @@ -84,15 +85,41 @@ public void setAllowedOrigins(Collection<String> allowedOrigins) {
* @see #setAllowedOrigins
*/
public Collection<String> getAllowedOrigins() {
return Collections.unmodifiableSet(this.allowedOrigins);
if (this.corsConfiguration.getAllowedOrigins() == null) {
return Collections.emptyList();
}
return Collections.unmodifiableSet(new HashSet<>(this.corsConfiguration.getAllowedOrigins()));
}

/**
* Configure allowed {@code Origin} pattern header values.
*
* @see CorsConfiguration#setAllowedOriginPatterns(List)
*/
public void setAllowedOriginPatterns(Collection<String> allowedOriginPatterns) {
Assert.notNull(allowedOriginPatterns, "Allowed origin patterns Collection must not be null");
this.corsConfiguration.setAllowedOriginPatterns(new ArrayList<>(allowedOriginPatterns));
}

/**
* Return the allowed {@code Origin} pattern header values.
*
* @since 5.3.2
* @see CorsConfiguration#getAllowedOriginPatterns()
*/
public Collection<String> getAllowedOriginPatterns() {
if (this.corsConfiguration.getAllowedOriginPatterns() == null) {
return Collections.emptyList();
}
return Collections.unmodifiableSet(new HashSet<>(this.corsConfiguration.getAllowedOriginPatterns()));
}


@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {

if (!WebUtils.isSameOrigin(request) && !WebUtils.isValidOrigin(request, this.allowedOrigins)) {
if (!WebUtils.isSameOrigin(request) && this.corsConfiguration.checkOrigin(request.getHeaders().getOrigin()) == null) {
response.setStatusCode(HttpStatus.FORBIDDEN);
if (logger.isDebugEnabled()) {
logger.debug("Handshake request rejected, Origin header value " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig

protected final Set<String> allowedOrigins = new LinkedHashSet<>();

protected final Set<String> allowedOriginPatterns = new LinkedHashSet<>();

private final SockJsRequestHandler infoHandler = new InfoHandler();

private final SockJsRequestHandler iframeHandler = new IframeHandler();
Expand Down Expand Up @@ -319,6 +321,17 @@ public void setAllowedOrigins(Collection<String> allowedOrigins) {
this.allowedOrigins.addAll(allowedOrigins);
}

/**
* Configure allowed {@code Origin} header values.
*
* @see org.springframework.web.cors.CorsConfiguration#setAllowedOriginPatterns(java.util.List)
*/
public void setAllowedOriginPatterns(Collection<String> allowedOriginPatterns) {
Assert.notNull(allowedOriginPatterns, "Allowed origin patterns Collection must not be null");
this.allowedOriginPatterns.clear();
this.allowedOriginPatterns.addAll(allowedOriginPatterns);
}

/**
* Return configure allowed {@code Origin} header values.
* @since 4.1.2
Expand All @@ -328,6 +341,15 @@ public Collection<String> getAllowedOrigins() {
return Collections.unmodifiableSet(this.allowedOrigins);
}

/**
* Return configure allowed {@code Origin} pattern header values.
* @since 5.3.2
* @see #setAllowedOriginPatterns
*/
public Collection<String> getAllowedOriginPatterns() {
return Collections.unmodifiableSet(this.allowedOriginPatterns);
}


/**
* This method determines the SockJS path and handles SockJS static URLs.
Expand Down Expand Up @@ -498,6 +520,7 @@ public CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
if (!this.suppressCors && (request.getHeader(HttpHeaders.ORIGIN) != null)) {
CorsConfiguration config = new CorsConfiguration();
config.setAllowedOrigins(new ArrayList<>(this.allowedOrigins));
config.setAllowedOriginPatterns(new ArrayList<>(this.allowedOriginPatterns));
config.addAllowedMethod("*");
config.setAllowCredentials(true);
config.setMaxAge(ONE_YEAR);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,32 @@ public void allowedOriginsWithSockJsService() {
assertThat(sockJsService.shouldSuppressCors()).isFalse();
}

@Test
public void allowedOriginPatterns() {
WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);

String origin = "https://*.mydomain.com";
registration.setAllowedOriginPatterns(origin).withSockJS();

MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1);
SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertThat(requestHandler.getSockJsService()).isNotNull();
DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
assertThat(sockJsService.getAllowedOriginPatterns().contains(origin)).isTrue();

registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.withSockJS().setAllowedOriginPatterns(origin);
mappings = registration.getMappings();
assertThat(mappings.size()).isEqualTo(1);
requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertThat(requestHandler.getSockJsService()).isNotNull();
sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
assertThat(sockJsService.getAllowedOriginPatterns().contains(origin)).isTrue();
}

@Test // SPR-12283
public void disableCorsWithSockJsService() {
WebMvcStompWebSocketEndpointRegistration registration =
Expand Down