Skip to content

Allow the ability to configure AuthoritiesMapper in Reactive OAuth2Login #8361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.function.Function;
import java.util.function.Supplier;

import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;

Expand Down Expand Up @@ -1056,8 +1057,11 @@ private ReactiveAuthenticationManager getAuthenticationManager() {

private ReactiveAuthenticationManager createDefault() {
ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> client = getAccessTokenResponseClient();
ReactiveAuthenticationManager result = new OAuth2LoginReactiveAuthenticationManager(client, getOauth2UserService());

OAuth2LoginReactiveAuthenticationManager oauth2Manager = new OAuth2LoginReactiveAuthenticationManager(client, getOauth2UserService());
GrantedAuthoritiesMapper authoritiesMapper = getBeanOrNull(GrantedAuthoritiesMapper.class);
if (authoritiesMapper != null) {
oauth2Manager.setAuthoritiesMapper(authoritiesMapper);
}
boolean oidcAuthenticationProviderEnabled = ClassUtils.isPresent(
"org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader());
if (oidcAuthenticationProviderEnabled) {
Expand All @@ -1069,9 +1073,12 @@ private ReactiveAuthenticationManager createDefault() {
if (jwtDecoderFactory != null) {
oidc.setJwtDecoderFactory(jwtDecoderFactory);
}
result = new DelegatingReactiveAuthenticationManager(oidc, result);
if (authoritiesMapper != null) {
oidc.setAuthoritiesMapper(authoritiesMapper);
}
return new DelegatingReactiveAuthenticationManager(oidc, oauth2Manager);
}
return result;
return oauth2Manager;
}

/**
Expand Down
18 changes: 18 additions & 0 deletions docs/manual/src/docs/asciidoc/_includes/reactive/oauth2/login.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,21 @@ SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) {
return http.build();
}
----

You may register a `GrantedAuthoritiesMapper` `@Bean` to have it automatically applied to the default configuration, as shown in the following example:

[source,java]
----
@Bean
public GrantedAuthoritiesMapper userAuthoritiesMapper() {
...
}

@Bean
SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) {
http
// ...
.oauth2Login(withDefaults());
return http.build();
}
----
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -95,6 +95,18 @@ public Mono<Authentication> authenticate(Authentication authentication) {
});
}

/**
* Sets the {@link GrantedAuthoritiesMapper} used for mapping {@link OAuth2User#getAuthorities()}
* to a new set of authorities which will be associated to the {@link OAuth2LoginAuthenticationToken}.
*
* @since 5.4
* @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the user's authorities
*/
public final void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null");
this.authoritiesMapper = authoritiesMapper;
}

private Mono<OAuth2LoginAuthenticationToken> onSuccess(OAuth2AuthorizationCodeAuthenticationToken authentication) {
OAuth2AccessToken accessToken = authentication.getAccessToken();
Map<String, Object> additionalParameters = authentication.getAdditionalParameters();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -156,6 +156,18 @@ public final void setJwtDecoderFactory(ReactiveJwtDecoderFactory<ClientRegistrat
this.jwtDecoderFactory = jwtDecoderFactory;
}

/**
* Sets the {@link GrantedAuthoritiesMapper} used for mapping {@link OAuth2User#getAuthorities()}
* to a new set of authorities which will be associated to the {@link OAuth2LoginAuthenticationToken}.
*
* @since 5.4
* @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the user's authorities
*/
public final void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null");
this.authoritiesMapper = authoritiesMapper;
}

private Mono<OAuth2LoginAuthenticationToken> authenticationResult(OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyCollection;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Before;
Expand All @@ -33,8 +36,11 @@
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
Expand Down Expand Up @@ -96,6 +102,12 @@ public void constructorWhenNullUserServiceThenIllegalArgumentException() {
.isInstanceOf(IllegalArgumentException.class);
}

@Test
public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.manager.setAuthoritiesMapper(null))
.isInstanceOf(IllegalArgumentException.class);
}

@Test
public void authenticateWhenNoSubscriptionThenDoesNothing() {
// we didn't do anything because it should cause a ClassCastException (as verified below)
Expand Down Expand Up @@ -178,6 +190,24 @@ public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToU
.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
}

@Test
public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() {
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.build();
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections.singletonMap("user", "rob"), "user");
when(this.userService.loadUser(any())).thenReturn(Mono.just(user));
List<GrantedAuthority> mappedAuthorities = AuthorityUtils.createAuthorityList("ROLE_OAUTH_USER");
GrantedAuthoritiesMapper authoritiesMapper = mock(GrantedAuthoritiesMapper.class);
when(authoritiesMapper.mapAuthorities(anyCollection())).thenAnswer((Answer<List<GrantedAuthority>>) invocation -> mappedAuthorities);
manager.setAuthoritiesMapper(authoritiesMapper);

OAuth2LoginAuthenticationToken result = (OAuth2LoginAuthenticationToken) this.manager.authenticate(loginToken()).block();

assertThat(result.getAuthorities()).isEqualTo(mappedAuthorities);
}

private OAuth2AuthorizationCodeAuthenticationToken loginToken() {
ClientRegistration clientRegistration = this.registration.build();
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Before;
Expand All @@ -29,6 +30,10 @@
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import reactor.core.publisher.Mono;

import org.springframework.security.authentication.TestingAuthenticationToken;
Expand Down Expand Up @@ -63,6 +68,8 @@
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyCollection;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeReactiveAuthenticationManager.createHash;
import static org.springframework.security.oauth2.jwt.TestJwts.jwt;
Expand Down Expand Up @@ -123,6 +130,12 @@ public void setJwtDecoderFactoryWhenNullThenIllegalArgumentException() {
.isInstanceOf(IllegalArgumentException.class);
}

@Test
public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.manager.setAuthoritiesMapper(null))
.isInstanceOf(IllegalArgumentException.class);
}

@Test
public void authenticateWhenNoSubscriptionThenDoesNothing() {
// we didn't do anything because it should cause a ClassCastException (as verified below)
Expand Down Expand Up @@ -316,6 +329,42 @@ public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToU
.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
}

@Test
public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() {
ClientRegistration clientRegistration = this.registration.build();
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.additionalParameters(Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue()))
.build();

OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken();

Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
claims.put(IdTokenClaimNames.SUB, "rob");
claims.put(IdTokenClaimNames.AUD, Collections.singletonList(clientRegistration.getClientId()));
claims.put(IdTokenClaimNames.NONCE, this.nonceHash);
Jwt idToken = jwt().claims(c -> c.putAll(claims)).build();


when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
DefaultOidcUser user = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), this.idToken);
ArgumentCaptor<OidcUserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class);
when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user));

List<GrantedAuthority> mappedAuthorities = AuthorityUtils.createAuthorityList("ROLE_OIDC_USER");
GrantedAuthoritiesMapper authoritiesMapper = mock(GrantedAuthoritiesMapper.class);
when(authoritiesMapper.mapAuthorities(anyCollection())).thenAnswer(
(Answer<List<GrantedAuthority>>) invocation -> mappedAuthorities);
when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken));
this.manager.setJwtDecoderFactory(c -> this.jwtDecoder);
this.manager.setAuthoritiesMapper(authoritiesMapper);

Authentication result = this.manager.authenticate(authorizationCodeAuthentication).block();

assertThat(result.getAuthorities()).isEqualTo(mappedAuthorities);
}

private OAuth2AuthorizationCodeAuthenticationToken loginToken() {
ClientRegistration clientRegistration = this.registration.build();
Map<String, Object> attributes = new HashMap<>();
Expand Down