From 1433393a5d0fde8895de2d0e8211ea9f60367ed6 Mon Sep 17 00:00:00 2001 From: MrJovanovic13 <34819606+MrJovanovic13@users.noreply.github.com> Date: Thu, 29 Feb 2024 20:15:44 +0100 Subject: [PATCH] Provide more flexibility on when to display consent page --- ...ationCodeRequestAuthenticationContext.java | 45 ++++++++++ ...tionCodeRequestAuthenticationProvider.java | 80 ++++++++++++----- ...odeRequestAuthenticationProviderTests.java | 85 +++++++++++++++++++ 3 files changed, 188 insertions(+), 22 deletions(-) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationContext.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationContext.java index c158d9408..b59b02539 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationContext.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationContext.java @@ -21,6 +21,8 @@ import java.util.function.Consumer; import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.util.Assert; @@ -63,6 +65,27 @@ public RegisteredClient getRegisteredClient() { return get(RegisteredClient.class); } + /** + * Returns the {@link OAuth2AuthorizationRequest oauth2 authorization request}. + * + * @return the {@link OAuth2AuthorizationRequest} + */ + @Nullable + public OAuth2AuthorizationRequest getOAuth2AuthorizationRequest() { + return get(OAuth2AuthorizationRequest.class); + } + + /** + * Returns the {@link OAuth2AuthorizationConsent oauth2 authorization consent}. + * + * @return the {@link OAuth2AuthorizationConsent} + */ + @Nullable + public OAuth2AuthorizationConsent getOAuth2AuthorizationConsent() { + return get(OAuth2AuthorizationConsent.class); + } + + /** * Constructs a new {@link Builder} with the provided {@link OAuth2AuthorizationCodeRequestAuthenticationToken}. * @@ -92,6 +115,28 @@ public Builder registeredClient(RegisteredClient registeredClient) { return put(RegisteredClient.class, registeredClient); } + /** + * Sets the {@link OAuth2AuthorizationRequest oauth2 authorization request}. + * + * @param authorizationRequest the {@link OAuth2AuthorizationRequest} + * @return the {@link Builder} for further configuration + * @since 1.3.0 + */ + public Builder authorizationRequest(OAuth2AuthorizationRequest authorizationRequest) { + return put(OAuth2AuthorizationRequest.class, authorizationRequest); + } + + /** + * Sets the {@link OAuth2AuthorizationConsent oauth2 authorization consent}. + * + * @param authorizationConsent the {@link OAuth2AuthorizationConsent} + * @return the {@link Builder} for further configuration + * @since 1.3.0 + */ + public Builder authorizationConsent(OAuth2AuthorizationConsent authorizationConsent) { + return put(OAuth2AuthorizationConsent.class, authorizationConsent); + } + /** * Builds a new {@link OAuth2AuthorizationCodeRequestAuthenticationContext}. * diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java index 21b2f072f..ca2197de7 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java @@ -19,6 +19,7 @@ import java.util.Base64; import java.util.Set; import java.util.function.Consumer; +import java.util.function.Predicate; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -80,6 +81,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen private OAuth2TokenGenerator authorizationCodeGenerator = new OAuth2AuthorizationCodeGenerator(); private Consumer authenticationValidator = new OAuth2AuthorizationCodeRequestAuthenticationValidator(); + private Predicate requiresAuthorizationConsent; /** * Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationProvider} using the provided parameters. @@ -96,6 +98,7 @@ public OAuth2AuthorizationCodeRequestAuthenticationProvider(RegisteredClientRepo this.registeredClientRepository = registeredClientRepository; this.authorizationService = authorizationService; this.authorizationConsentService = authorizationConsentService; + this.requiresAuthorizationConsent = this::requireAuthorizationConsent; } @Override @@ -171,7 +174,19 @@ public Authentication authenticate(Authentication authentication) throws Authent OAuth2AuthorizationConsent currentAuthorizationConsent = this.authorizationConsentService.findById( registeredClient.getId(), principal.getName()); - if (requireAuthorizationConsent(registeredClient, authorizationRequest, currentAuthorizationConsent)) { + OAuth2AuthorizationCodeRequestAuthenticationContext.Builder authenticationContextBuilder = + OAuth2AuthorizationCodeRequestAuthenticationContext.with(authorizationCodeRequestAuthentication) + .registeredClient(registeredClient) + .authorizationRequest(authorizationRequest); + + if (currentAuthorizationConsent != null) { + authenticationContextBuilder.authorizationConsent(currentAuthorizationConsent); + } + + OAuth2AuthorizationCodeRequestAuthenticationContext contextWithAuthorizationRequestAndAuthorizationConsent = + authenticationContextBuilder.build(); + + if (requiresAuthorizationConsent.test(contextWithAuthorizationRequestAndAuthorizationConsent)) { String state = DEFAULT_STATE_GENERATOR.generateKey(); OAuth2Authorization authorization = authorizationBuilder(registeredClient, principal, authorizationRequest) .attribute(OAuth2ParameterNames.STATE, state) @@ -264,7 +279,48 @@ public void setAuthenticationValidator(Consumer + * The {@link OAuth2AuthorizationCodeRequestAuthenticationContext} gives the predicate access to the {@link OAuth2AuthorizationCodeRequestAuthenticationToken}, + * as well as, the following context attributes: + * {@link OAuth2AuthorizationCodeRequestAuthenticationContext#getRegisteredClient()} containing {@link RegisteredClient} used to make the request. + * {@link OAuth2AuthorizationCodeRequestAuthenticationContext#getOAuth2AuthorizationRequest()} containing {@link OAuth2AuthorizationRequest}. + * {@link OAuth2AuthorizationCodeRequestAuthenticationContext#getOAuth2AuthorizationConsent()} containing {@link OAuth2AuthorizationConsent} granted in the request. + * + * @param requiresAuthorizationConsent the {@link Predicate} that determines if authorization consent is required. + * @since 1.3.0 + */ + public void setRequiresAuthorizationConsent(Predicate requiresAuthorizationConsent) { + Assert.notNull(requiresAuthorizationConsent, "requiresAuthorizationConsent cannot be null"); + this.requiresAuthorizationConsent = requiresAuthorizationConsent; + } + + private boolean requireAuthorizationConsent(OAuth2AuthorizationCodeRequestAuthenticationContext context) { + RegisteredClient registeredClient = context.getRegisteredClient(); + if (!registeredClient.getClientSettings().isRequireAuthorizationConsent()) { + return false; + } + + OAuth2AuthorizationRequest authorizationRequest = context.getOAuth2AuthorizationRequest(); + // 'openid' scope does not require consent + if (authorizationRequest.getScopes().contains(OidcScopes.OPENID) && + authorizationRequest.getScopes().size() == 1) { + return false; + } + + OAuth2AuthorizationConsent authorizationConsent = context.getOAuth2AuthorizationConsent(); + if (authorizationConsent != null && + authorizationConsent.getScopes().containsAll(authorizationRequest.getScopes())) { + return false; + } + + return true; + } + + private static OAuth2Authorization.Builder authorizationBuilder(RegisteredClient registeredClient, + Authentication principal, OAuth2AuthorizationRequest authorizationRequest) { return OAuth2Authorization.withRegisteredClient(registeredClient) .principalName(principal.getName()) @@ -295,26 +351,6 @@ private static OAuth2TokenContext createAuthorizationCodeTokenContext( return tokenContextBuilder.build(); } - private static boolean requireAuthorizationConsent(RegisteredClient registeredClient, - OAuth2AuthorizationRequest authorizationRequest, OAuth2AuthorizationConsent authorizationConsent) { - - if (!registeredClient.getClientSettings().isRequireAuthorizationConsent()) { - return false; - } - // 'openid' scope does not require consent - if (authorizationRequest.getScopes().contains(OidcScopes.OPENID) && - authorizationRequest.getScopes().size() == 1) { - return false; - } - - if (authorizationConsent != null && - authorizationConsent.getScopes().containsAll(authorizationRequest.getScopes())) { - return false; - } - - return true; - } - private static boolean isPrincipalAuthenticated(Authentication principal) { return principal != null && !AnonymousAuthenticationToken.class.isAssignableFrom(principal.getClass()) && diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java index b9e96da43..042e2a6b8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java @@ -21,6 +21,7 @@ import java.util.Map; import java.util.Set; import java.util.function.Consumer; +import java.util.function.Predicate; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -72,6 +73,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests { private OAuth2AuthorizationConsentService authorizationConsentService; private OAuth2AuthorizationCodeRequestAuthenticationProvider authenticationProvider; private TestingAuthenticationToken principal; + private Predicate requiresAuthorizationConsent; @BeforeEach public void setUp() { @@ -129,6 +131,13 @@ public void setAuthenticationValidatorWhenNullThenThrowIllegalArgumentException( .hasMessage("authenticationValidator cannot be null"); } + @Test + public void setRequiresAuthorizationConsentWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authenticationProvider.setRequiresAuthorizationConsent(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("requiresAuthorizationConsent cannot be null"); + } + @Test public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); @@ -443,6 +452,82 @@ public void authenticateWhenRequireAuthorizationConsentThenReturnAuthorizationCo assertThat(authenticationResult.isAuthenticated()).isTrue(); } + @Test + public void authenticateWhenRequireAuthorizationConsentAndRequiresAuthorizationConsentPredicateTrueThenReturnAuthorizationConsent() { + this.authenticationProvider.setRequiresAuthorizationConsent((authenticationContext) -> true); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .clientSettings(ClientSettings.builder().requireAuthorizationConsent(true).build()) + .build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + String redirectUri = registeredClient.getRedirectUris().toArray(new String[0])[0]; + OAuth2AuthorizationCodeRequestAuthenticationToken authentication = + new OAuth2AuthorizationCodeRequestAuthenticationToken( + AUTHORIZATION_URI, registeredClient.getClientId(), principal, + redirectUri, STATE, registeredClient.getScopes(), null); + + OAuth2AuthorizationConsentAuthenticationToken authenticationResult = + (OAuth2AuthorizationConsentAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization authorization = authorizationCaptor.getValue(); + + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName()); + assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE); + assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(authentication.getAuthorizationUri()); + assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId()); + assertThat(authorizationRequest.getRedirectUri()).isEqualTo(authentication.getRedirectUri()); + assertThat(authorizationRequest.getScopes()).isEqualTo(authentication.getScopes()); + assertThat(authorizationRequest.getState()).isEqualTo(authentication.getState()); + assertThat(authorizationRequest.getAdditionalParameters()).isEqualTo(authentication.getAdditionalParameters()); + + assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId()); + assertThat(authorization.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + assertThat(authorization.getAttribute(Principal.class.getName())).isEqualTo(this.principal); + String state = authorization.getAttribute(OAuth2ParameterNames.STATE); + assertThat(state).isNotNull(); + assertThat(state).isNotEqualTo(authentication.getState()); + + assertThat(authenticationResult.getClientId()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getPrincipal()).isEqualTo(this.principal); + assertThat(authenticationResult.getAuthorizationUri()).isEqualTo(authorizationRequest.getAuthorizationUri()); + assertThat(authenticationResult.getScopes()).isEmpty(); + assertThat(authenticationResult.getState()).isEqualTo(state); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + } + + @Test + public void authenticateWhenRequireAuthorizationConsentAndRequiresAuthorizationConsentPredicateFalseThenAuthorizationConsentNotRequired() { + this.authenticationProvider.setRequiresAuthorizationConsent((authenticationContext) -> false); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient() + .clientSettings(ClientSettings.builder().requireAuthorizationConsent(true).build()) + .scopes(scopes -> { + scopes.clear(); + scopes.add(OidcScopes.OPENID); + scopes.add(OidcScopes.EMAIL); + }) + .build(); + when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) + .thenReturn(registeredClient); + + String redirectUri = registeredClient.getRedirectUris().toArray(new String[0])[1]; + OAuth2AuthorizationCodeRequestAuthenticationToken authentication = + new OAuth2AuthorizationCodeRequestAuthenticationToken( + AUTHORIZATION_URI, registeredClient.getClientId(), principal, + redirectUri, STATE, registeredClient.getScopes(), null); + + OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult = + (OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + assertAuthorizationCodeRequestWithAuthorizationCodeResult(registeredClient, authentication, authenticationResult); + } + @Test public void authenticateWhenRequireAuthorizationConsentAndOnlyOpenidScopeRequestedThenAuthorizationConsentNotRequired() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient()