Skip to content

Commit 75dea30

Browse files
author
Dmitriy Dubson
committed
Add OAuth2TokenEndpointAuthenticationSuccessHandler
Fixes gh-925
1 parent 7c19716 commit 75dea30

File tree

4 files changed

+227
-41
lines changed

4 files changed

+227
-41
lines changed

docs/modules/ROOT/pages/protocol-endpoints.adoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ The supported https://datatracker.ietf.org/doc/html/rfc6749#section-1.3[authoriz
263263

264264
* `*AuthenticationConverter*` -- A `DelegatingAuthenticationConverter` composed of `OAuth2AuthorizationCodeAuthenticationConverter`, `OAuth2RefreshTokenAuthenticationConverter`, `OAuth2ClientCredentialsAuthenticationConverter`, and `OAuth2DeviceCodeAuthenticationConverter`.
265265
* `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OAuth2AuthorizationCodeAuthenticationProvider`, `OAuth2RefreshTokenAuthenticationProvider`, `OAuth2ClientCredentialsAuthenticationProvider`, and `OAuth2DeviceCodeAuthenticationProvider`.
266-
* `*AuthenticationSuccessHandler*` -- An internal implementation that handles an `OAuth2AccessTokenAuthenticationToken` and returns the `OAuth2AccessTokenResponse`.
266+
* `*AuthenticationSuccessHandler*` -- An `OAuth2TokenEndpointAuthenticationSuccessHandler`.
267267
* `*AuthenticationFailureHandler*` -- An `OAuth2ErrorAuthenticationFailureHandler`.
268268

269269
[[oauth2-token-introspection-endpoint]]

oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,24 @@
1616
package org.springframework.security.oauth2.server.authorization.web;
1717

1818
import java.io.IOException;
19-
import java.time.temporal.ChronoUnit;
2019
import java.util.Arrays;
21-
import java.util.Map;
2220

2321
import jakarta.servlet.FilterChain;
2422
import jakarta.servlet.ServletException;
2523
import jakarta.servlet.http.HttpServletRequest;
2624
import jakarta.servlet.http.HttpServletResponse;
27-
2825
import org.springframework.core.log.LogMessage;
2926
import org.springframework.http.HttpMethod;
30-
import org.springframework.http.converter.HttpMessageConverter;
31-
import org.springframework.http.server.ServletServerHttpResponse;
3227
import org.springframework.security.authentication.AbstractAuthenticationToken;
3328
import org.springframework.security.authentication.AuthenticationDetailsSource;
3429
import org.springframework.security.authentication.AuthenticationManager;
3530
import org.springframework.security.core.Authentication;
3631
import org.springframework.security.core.context.SecurityContextHolder;
37-
import org.springframework.security.oauth2.core.OAuth2AccessToken;
3832
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
3933
import org.springframework.security.oauth2.core.OAuth2Error;
4034
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
41-
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
4235
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
4336
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
44-
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
4537
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
4638
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
4739
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationGrantAuthenticationToken;
@@ -54,14 +46,14 @@
5446
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2DeviceCodeAuthenticationConverter;
5547
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler;
5648
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2RefreshTokenAuthenticationConverter;
49+
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2TokenEndpointAuthenticationSuccessHandler;
5750
import org.springframework.security.web.authentication.AuthenticationConverter;
5851
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
5952
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
6053
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
6154
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
6255
import org.springframework.security.web.util.matcher.RequestMatcher;
6356
import org.springframework.util.Assert;
64-
import org.springframework.util.CollectionUtils;
6557
import org.springframework.web.filter.OncePerRequestFilter;
6658

6759
/**
@@ -103,12 +95,10 @@ public final class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
10395
private static final String DEFAULT_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
10496
private final AuthenticationManager authenticationManager;
10597
private final RequestMatcher tokenEndpointMatcher;
106-
private final HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
107-
new OAuth2AccessTokenResponseHttpMessageConverter();
10898
private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource =
10999
new WebAuthenticationDetailsSource();
110100
private AuthenticationConverter authenticationConverter;
111-
private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAccessTokenResponse;
101+
private AuthenticationSuccessHandler authenticationSuccessHandler = new OAuth2TokenEndpointAuthenticationSuccessHandler();
112102
private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler();
113103

114104
/**
@@ -218,34 +208,6 @@ public void setAuthenticationFailureHandler(AuthenticationFailureHandler authent
218208
this.authenticationFailureHandler = authenticationFailureHandler;
219209
}
220210

221-
private void sendAccessTokenResponse(HttpServletRequest request, HttpServletResponse response,
222-
Authentication authentication) throws IOException {
223-
224-
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
225-
(OAuth2AccessTokenAuthenticationToken) authentication;
226-
227-
OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken();
228-
OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken();
229-
Map<String, Object> additionalParameters = accessTokenAuthentication.getAdditionalParameters();
230-
231-
OAuth2AccessTokenResponse.Builder builder =
232-
OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue())
233-
.tokenType(accessToken.getTokenType())
234-
.scopes(accessToken.getScopes());
235-
if (accessToken.getIssuedAt() != null && accessToken.getExpiresAt() != null) {
236-
builder.expiresIn(ChronoUnit.SECONDS.between(accessToken.getIssuedAt(), accessToken.getExpiresAt()));
237-
}
238-
if (refreshToken != null) {
239-
builder.refreshToken(refreshToken.getTokenValue());
240-
}
241-
if (!CollectionUtils.isEmpty(additionalParameters)) {
242-
builder.additionalParameters(additionalParameters);
243-
}
244-
OAuth2AccessTokenResponse accessTokenResponse = builder.build();
245-
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
246-
this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse);
247-
}
248-
249211
private static void throwError(String errorCode, String parameterName) {
250212
OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, DEFAULT_ERROR_URI);
251213
throw new OAuth2AuthenticationException(error);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright 2020-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.server.authorization.web.authentication;
17+
18+
import java.io.IOException;
19+
import java.time.temporal.ChronoUnit;
20+
import java.util.Map;
21+
22+
import jakarta.servlet.FilterChain;
23+
import jakarta.servlet.ServletException;
24+
import jakarta.servlet.http.HttpServletRequest;
25+
import jakarta.servlet.http.HttpServletResponse;
26+
import org.apache.commons.logging.Log;
27+
import org.apache.commons.logging.LogFactory;
28+
import org.springframework.http.converter.HttpMessageConverter;
29+
import org.springframework.http.server.ServletServerHttpResponse;
30+
import org.springframework.security.core.Authentication;
31+
import org.springframework.security.oauth2.core.OAuth2AccessToken;
32+
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
33+
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
34+
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
35+
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
36+
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
37+
import org.springframework.util.Assert;
38+
import org.springframework.util.CollectionUtils;
39+
40+
/**
41+
* An implementation of an {@link AuthenticationSuccessHandler} used for handling an {@link OAuth2AccessTokenAuthenticationToken}
42+
* and returning the {@link OAuth2AccessTokenResponse Access Token Response}.
43+
*
44+
* @author Dmitriy Dubson
45+
* @see AuthenticationSuccessHandler
46+
* @see OAuth2AccessTokenResponseHttpMessageConverter
47+
* @since 1.2
48+
*/
49+
public class OAuth2TokenEndpointAuthenticationSuccessHandler implements AuthenticationSuccessHandler {
50+
private final Log logger = LogFactory.getLog(getClass());
51+
52+
private HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
53+
new OAuth2AccessTokenResponseHttpMessageConverter();
54+
55+
@Override
56+
public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, FilterChain chain, Authentication authentication) throws IOException, ServletException {
57+
AuthenticationSuccessHandler.super.onAuthenticationSuccess(request, response, chain, authentication);
58+
}
59+
60+
@Override
61+
public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException, ServletException {
62+
if (!(authentication instanceof OAuth2AccessTokenAuthenticationToken accessTokenAuthentication)) {
63+
// TODO: determine what should the HTTP response be if authentication type is incorrect.
64+
if (this.logger.isWarnEnabled()) {
65+
this.logger.warn(Authentication.class.getSimpleName() + " must be of type " +
66+
OAuth2AccessTokenAuthenticationToken.class.getName() +
67+
" but was " + authentication.getClass().getName());
68+
}
69+
return;
70+
}
71+
72+
OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken();
73+
OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken();
74+
Map<String, Object> additionalParameters = accessTokenAuthentication.getAdditionalParameters();
75+
76+
OAuth2AccessTokenResponse.Builder builder =
77+
OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue())
78+
.tokenType(accessToken.getTokenType())
79+
.scopes(accessToken.getScopes());
80+
if (accessToken.getIssuedAt() != null && accessToken.getExpiresAt() != null) {
81+
builder.expiresIn(ChronoUnit.SECONDS.between(accessToken.getIssuedAt(), accessToken.getExpiresAt()));
82+
}
83+
if (refreshToken != null) {
84+
builder.refreshToken(refreshToken.getTokenValue());
85+
}
86+
if (!CollectionUtils.isEmpty(additionalParameters)) {
87+
builder.additionalParameters(additionalParameters);
88+
}
89+
90+
OAuth2AccessTokenResponse accessTokenResponse = builder.build();
91+
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
92+
this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse);
93+
}
94+
95+
/**
96+
* Sets the {@link HttpMessageConverter} used for converting an {@link OAuth2AccessTokenResponse} to an HTTP response.
97+
*
98+
* @param accessTokenHttpResponseConverter the {@link HttpMessageConverter} used for converting an {@link OAuth2AccessTokenResponse} to an HTTP response
99+
*/
100+
public void setAccessTokenHttpResponseConverter(HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter) {
101+
Assert.notNull(accessTokenHttpResponseConverter, "accessTokenHttpResponseConverter cannot be null");
102+
this.accessTokenHttpResponseConverter = accessTokenHttpResponseConverter;
103+
}
104+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Copyright 2020-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.server.authorization.web.authentication;
17+
18+
import java.io.IOException;
19+
import java.time.Instant;
20+
import java.time.temporal.ChronoUnit;
21+
import java.util.Collections;
22+
import java.util.Map;
23+
import java.util.Set;
24+
25+
import jakarta.servlet.ServletException;
26+
import org.junit.jupiter.api.Test;
27+
import org.mockito.ArgumentCaptor;
28+
import org.springframework.http.converter.HttpMessageConverter;
29+
import org.springframework.http.server.ServletServerHttpResponse;
30+
import org.springframework.mock.web.MockHttpServletRequest;
31+
import org.springframework.mock.web.MockHttpServletResponse;
32+
import org.springframework.security.core.Authentication;
33+
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
34+
import org.springframework.security.oauth2.core.OAuth2AccessToken;
35+
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
36+
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
37+
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
38+
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
39+
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken;
40+
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
41+
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
42+
43+
import static org.assertj.core.api.Assertions.assertThat;
44+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
45+
import static org.assertj.core.api.Assertions.within;
46+
import static org.mockito.ArgumentMatchers.isNull;
47+
import static org.mockito.Mockito.mock;
48+
import static org.mockito.Mockito.verify;
49+
import static org.mockito.Mockito.verifyNoInteractions;
50+
51+
/**
52+
* Tests for {@link OAuth2TokenEndpointAuthenticationSuccessHandler}.
53+
*
54+
* @author Dmitriy Dubson
55+
*/
56+
public class OAuth2TokenEndpointAuthenticationSuccessHandlerTests {
57+
private final OAuth2TokenEndpointAuthenticationSuccessHandler authenticationSuccessHandler = new OAuth2TokenEndpointAuthenticationSuccessHandler();
58+
59+
@Test
60+
public void setAccessTokenHttpResponseConverterWhenNullThenThrowIllegalArgumentException() {
61+
// @formatter:off
62+
assertThatThrownBy(() -> this.authenticationSuccessHandler.setAccessTokenHttpResponseConverter(null))
63+
.isInstanceOf(IllegalArgumentException.class)
64+
.hasMessage("accessTokenHttpResponseConverter cannot be null");
65+
// @formatter:on
66+
}
67+
68+
@Test
69+
public void onAuthenticationSuccessWritesAccessTokenToHttpResponse() throws ServletException, IOException {
70+
MockHttpServletRequest request = new MockHttpServletRequest();
71+
MockHttpServletResponse response = new MockHttpServletResponse();
72+
73+
RegisteredClient testClient = TestRegisteredClients.registeredClient().build();
74+
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
75+
testClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, testClient.getClientSecret());
76+
Instant issuedAt = Instant.now();
77+
Instant expiresAt = Instant.now().plusSeconds(300);
78+
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
79+
"access-token-value", issuedAt, expiresAt, Set.of("scope1"));
80+
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token-value", issuedAt);
81+
Map<String, Object> additionalParameters = Collections.singletonMap("param1", "value1");
82+
Authentication authentication = new OAuth2AccessTokenAuthenticationToken(testClient, clientPrincipal, accessToken, refreshToken, additionalParameters);
83+
84+
HttpMessageConverter<OAuth2AccessTokenResponse> responseConverter = mock(HttpMessageConverter.class);
85+
authenticationSuccessHandler.setAccessTokenHttpResponseConverter(responseConverter);
86+
authenticationSuccessHandler.onAuthenticationSuccess(request, response, authentication);
87+
88+
ArgumentCaptor<OAuth2AccessTokenResponse> accessTokenResponseCaptor = ArgumentCaptor.forClass(OAuth2AccessTokenResponse.class);
89+
ArgumentCaptor<ServletServerHttpResponse> servletServerHttpResponseArgumentCaptor = ArgumentCaptor.forClass(ServletServerHttpResponse.class);
90+
verify(responseConverter).write(accessTokenResponseCaptor.capture(), isNull(), servletServerHttpResponseArgumentCaptor.capture());
91+
92+
OAuth2AccessTokenResponse actualAccessTokenResponse = accessTokenResponseCaptor.getValue();
93+
assertThat(actualAccessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-value");
94+
assertThat(actualAccessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER);
95+
assertThat(actualAccessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("scope1");
96+
assertThat(actualAccessTokenResponse.getAccessToken().getIssuedAt()).isCloseTo(issuedAt, within(2, ChronoUnit.SECONDS));
97+
assertThat(actualAccessTokenResponse.getAccessToken().getExpiresAt()).isCloseTo(expiresAt, within(2, ChronoUnit.SECONDS));
98+
assertThat(actualAccessTokenResponse.getRefreshToken()).isNotNull();
99+
assertThat(actualAccessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo("refresh-token-value");
100+
assertThat(actualAccessTokenResponse.getAdditionalParameters()).containsExactlyEntriesOf(additionalParameters);
101+
102+
assertThat(servletServerHttpResponseArgumentCaptor.getValue().getServletResponse()).isEqualTo(response);
103+
}
104+
105+
@Test
106+
public void onAuthenticationSuccessWhenInvaliAuthenticationThenNoResponse() throws ServletException, IOException {
107+
MockHttpServletRequest request = new MockHttpServletRequest();
108+
MockHttpServletResponse response = new MockHttpServletResponse();
109+
110+
RegisteredClient testClient = TestRegisteredClients.registeredClient().build();
111+
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
112+
testClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, testClient.getClientSecret());
113+
114+
HttpMessageConverter<OAuth2AccessTokenResponse> responseConverter = mock(HttpMessageConverter.class);
115+
authenticationSuccessHandler.setAccessTokenHttpResponseConverter(responseConverter);
116+
authenticationSuccessHandler.onAuthenticationSuccess(request, response, new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, Set.of(), Map.of()));
117+
118+
verifyNoInteractions(responseConverter);
119+
}
120+
}

0 commit comments

Comments
 (0)