diff --git a/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextPropagationChannelInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextPropagationChannelInterceptor.java new file mode 100644 index 00000000000..be682659483 --- /dev/null +++ b/messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextPropagationChannelInterceptor.java @@ -0,0 +1,164 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.context; + +import java.util.Stack; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.support.ExecutorChannelInterceptor; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.util.Assert; + +/** + * An {@link ExecutorChannelInterceptor} that takes an {@link Authentication} from the + * current {@link SecurityContext} (if any) in the + * {@link #preSend(Message, MessageChannel)} callback and stores it into an + * {@link #authenticationHeaderName} message header. Then sets the context from this + * header in the {@link #beforeHandle(Message, MessageChannel, MessageHandler)} and + * {@link #postReceive(Message, MessageChannel)} both of which typically happen on a + * different thread. + *

+ * Note: cannot be used in combination with a {@link SecurityContextChannelInterceptor} on + * the same channel since both these interceptors modify a security context on a handling + * and receiving operations. + * + * @author Artem Bilan + * @since 6.2 + * @see SecurityContextChannelInterceptor + */ +public final class SecurityContextPropagationChannelInterceptor implements ExecutorChannelInterceptor { + + private static final ThreadLocal> originalContext = new ThreadLocal<>(); + + private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder + .getContextHolderStrategy(); + + private SecurityContext empty = this.securityContextHolderStrategy.createEmptyContext(); + + private final String authenticationHeaderName; + + private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous", + AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + + /** + * Create a new instance using the header of the name + * {@link SimpMessageHeaderAccessor#USER_HEADER}. + */ + public SecurityContextPropagationChannelInterceptor() { + this(SimpMessageHeaderAccessor.USER_HEADER); + } + + /** + * Create a new instance that uses the specified header to populate the + * {@link Authentication}. + * @param authenticationHeaderName the header name to populate the + * {@link Authentication}. Cannot be null. + */ + public SecurityContextPropagationChannelInterceptor(String authenticationHeaderName) { + Assert.notNull(authenticationHeaderName, "authenticationHeaderName cannot be null"); + this.authenticationHeaderName = authenticationHeaderName; + } + + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy strategy) { + this.securityContextHolderStrategy = strategy; + this.empty = this.securityContextHolderStrategy.createEmptyContext(); + } + + /** + * Configure an Authentication used for anonymous authentication. Default is:

+	 * new AnonymousAuthenticationToken("key", "anonymous",
+	 * 		AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
+	 * 
+ * @param authentication the Authentication used for anonymous authentication. Cannot + * be null. + */ + public void setAnonymousAuthentication(Authentication authentication) { + Assert.notNull(authentication, "authentication cannot be null"); + this.anonymous = authentication; + } + + @Override + public Message preSend(Message message, MessageChannel channel) { + Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication(); + if (authentication == null) { + authentication = this.anonymous; + } + return MessageBuilder.fromMessage(message).setHeader(this.authenticationHeaderName, authentication).build(); + } + + @Override + public Message beforeHandle(Message message, MessageChannel channel, MessageHandler handler) { + return postReceive(message, channel); + } + + @Override + public Message postReceive(Message message, MessageChannel channel) { + setup(message); + return message; + } + + @Override + public void afterMessageHandled(Message message, MessageChannel channel, MessageHandler handler, Exception ex) { + cleanup(); + } + + private void setup(Message message) { + Authentication authentication = message.getHeaders().get(this.authenticationHeaderName, Authentication.class); + SecurityContext currentContext = this.securityContextHolderStrategy.getContext(); + Stack contextStack = originalContext.get(); + if (contextStack == null) { + contextStack = new Stack<>(); + originalContext.set(contextStack); + } + contextStack.push(currentContext); + SecurityContext context = this.securityContextHolderStrategy.createEmptyContext(); + context.setAuthentication(authentication); + this.securityContextHolderStrategy.setContext(context); + } + + private void cleanup() { + Stack contextStack = originalContext.get(); + if (contextStack == null || contextStack.isEmpty()) { + this.securityContextHolderStrategy.clearContext(); + originalContext.remove(); + return; + } + SecurityContext context = contextStack.pop(); + try { + if (this.empty.equals(context)) { + this.securityContextHolderStrategy.clearContext(); + originalContext.remove(); + } + else { + this.securityContextHolderStrategy.setContext(context); + } + } + catch (Throwable ex) { + this.securityContextHolderStrategy.clearContext(); + } + } + +} diff --git a/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextPropagationChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextPropagationChannelInterceptorTests.java new file mode 100644 index 00000000000..1682b96200a --- /dev/null +++ b/messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextPropagationChannelInterceptorTests.java @@ -0,0 +1,160 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.context; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessageHandler; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolderStrategy; +import org.springframework.security.core.context.SecurityContextImpl; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +@ExtendWith(MockitoExtension.class) +public class SecurityContextPropagationChannelInterceptorTests { + + @Mock + MessageChannel channel; + + @Mock + MessageHandler handler; + + MessageBuilder messageBuilder; + + Authentication authentication; + + SecurityContextPropagationChannelInterceptor interceptor; + + @BeforeEach + public void setup() { + this.authentication = new TestingAuthenticationToken("user", "pass", "ROLE_USER"); + this.messageBuilder = MessageBuilder.withPayload("payload"); + this.interceptor = new SecurityContextPropagationChannelInterceptor(); + } + + @AfterEach + public void cleanup() { + this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); + SecurityContextHolder.clearContext(); + } + + @Test + public void preSendDefaultHeader() { + SecurityContextHolder.getContext().setAuthentication(this.authentication); + Message message = this.interceptor.preSend(this.messageBuilder.build(), this.channel); + assertThat(message.getHeaders()).containsEntry(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); + } + + @Test + public void preSendCustomHeader() { + SecurityContextHolder.getContext().setAuthentication(this.authentication); + String headerName = "header"; + this.interceptor = new SecurityContextPropagationChannelInterceptor(headerName); + Message message = this.interceptor.preSend(this.messageBuilder.build(), this.channel); + assertThat(message.getHeaders()).containsEntry(headerName, this.authentication); + } + + @Test + public void preSendWhenCustomSecurityContextHolderStrategyThenUserSet() { + SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy()); + strategy.setContext(new SecurityContextImpl(this.authentication)); + this.interceptor.setSecurityContextHolderStrategy(strategy); + Message message = this.interceptor.preSend(this.messageBuilder.build(), this.channel); + this.interceptor.beforeHandle(message, this.channel, this.handler); + verify(strategy, times(2)).getContext(); + assertThat(strategy.getContext().getAuthentication()).isSameAs(this.authentication); + } + + @Test + public void preSendUserNoContext() { + Message message = this.interceptor.preSend(this.messageBuilder.build(), this.channel); + assertThat(message.getHeaders()).containsKey(SimpMessageHeaderAccessor.USER_HEADER); + assertThat(message.getHeaders().get(SimpMessageHeaderAccessor.USER_HEADER)) + .isInstanceOf(AnonymousAuthenticationToken.class); + } + + @Test + public void beforeHandleUserSet() { + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); + this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); + } + + @Test + public void postReceiveUserSet() { + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); + this.interceptor.postReceive(this.messageBuilder.build(), this.channel); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); + } + + @Test + public void authenticationIsPropagatedFromPreSendToPostReceive() { + SecurityContextHolder.getContext().setAuthentication(this.authentication); + Message message = this.interceptor.preSend(this.messageBuilder.build(), this.channel); + assertThat(message.getHeaders().get(SimpMessageHeaderAccessor.USER_HEADER)).isSameAs(this.authentication); + this.interceptor.postReceive(message, this.channel); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); + } + + @Test + public void beforeHandleUserNotSet() { + this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + } + + @Test + public void afterMessageHandledUserNotSet() { + this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + } + + @Test + public void afterMessageHandled() { + SecurityContextHolder.getContext().setAuthentication(this.authentication); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); + this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + } + + @Test + public void restoresOriginalContext() { + TestingAuthenticationToken original = new TestingAuthenticationToken("original", "original", "ROLE_USER"); + SecurityContextHolder.getContext().setAuthentication(original); + this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); + this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); + this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(original); + } + +}