Skip to content

Add support for customizing claims in JWT Client Assertion #10972

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,35 @@ tokenResponseClient.addParametersConverter(
)
----
====

=== Customizing the JWT assertion

The JWT produced by `NimbusJwtClientAuthenticationParametersConverter` contains the `iss`, `sub`, `aud`, `jti`, `iat` and `exp` claims by default. You can customize the headers and/or claims by providing a `Consumer<NimbusJwtClientAuthenticationParametersConverter.JwtClientAuthenticationContext<T>>` to `setJwtClientAssertionCustomizer()`. The following example shows how to customize claims of the JWT:

====
.Java
[source,java,role="primary"]
----
Function<ClientRegistration, JWK> jwkResolver = ...

NimbusJwtClientAuthenticationParametersConverter<OAuth2ClientCredentialsGrantRequest> converter =
new NimbusJwtClientAuthenticationParametersConverter<>(jwkResolver);
converter.setJwtClientAssertionCustomizer((context) -> {
context.getHeaders().header("custom-header", "header-value");
context.getClaims().claim("custom-claim", "claim-value");
});
----

.Kotlin
[source,kotlin,role="secondary"]
----
val jwkResolver = ...

val converter: NimbusJwtClientAuthenticationParametersConverter<OAuth2ClientCredentialsGrantRequest> =
NimbusJwtClientAuthenticationParametersConverter(jwkResolver)
converter.setJwtClientAssertionCustomizer { context ->
context.headers.header("custom-header", "header-value")
context.claims.claim("custom-claim", "claim-value")
}
----
====
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,35 @@ val tokenResponseClient = DefaultClientCredentialsTokenResponseClient()
tokenResponseClient.setRequestEntityConverter(requestEntityConverter)
----
====

=== Customizing the JWT assertion

The JWT produced by `NimbusJwtClientAuthenticationParametersConverter` contains the `iss`, `sub`, `aud`, `jti`, `iat` and `exp` claims by default. You can customize the headers and/or claims by providing a `Consumer<NimbusJwtClientAuthenticationParametersConverter.JwtClientAuthenticationContext<T>>` to `setJwtClientAssertionCustomizer()`. The following example shows how to customize claims of the JWT:

====
.Java
[source,java,role="primary"]
----
Function<ClientRegistration, JWK> jwkResolver = ...

NimbusJwtClientAuthenticationParametersConverter<OAuth2ClientCredentialsGrantRequest> converter =
new NimbusJwtClientAuthenticationParametersConverter<>(jwkResolver);
converter.setJwtClientAssertionCustomizer((context) -> {
context.getHeaders().header("custom-header", "header-value");
context.getClaims().claim("custom-claim", "claim-value");
});
----

.Kotlin
[source,kotlin,role="secondary"]
----
val jwkResolver = ...

val converter: NimbusJwtClientAuthenticationParametersConverter<OAuth2ClientCredentialsGrantRequest> =
NimbusJwtClientAuthenticationParametersConverter(jwkResolver)
converter.setJwtClientAssertionCustomizer { context ->
context.headers.header("custom-header", "header-value")
context.claims.claim("custom-claim", "claim-value")
}
----
====
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,6 +22,7 @@
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Function;

import com.nimbusds.jose.jwk.JWK;
Expand Down Expand Up @@ -62,6 +63,7 @@
*
* @param <T> the type of {@link AbstractOAuth2AuthorizationGrantRequest}
* @author Joe Grandja
* @author Steve Riesenberg
* @since 5.5
* @see Converter
* @see com.nimbusds.jose.jwk.JWK
Expand All @@ -87,6 +89,9 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab

private final Map<String, JwsEncoderHolder> jwsEncoders = new ConcurrentHashMap<>();

private Consumer<JwtClientAuthenticationContext<T>> jwtClientAssertionCustomizer = (context) -> {
};

/**
* Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the
* provided parameters.
Expand Down Expand Up @@ -142,6 +147,10 @@ public MultiValueMap<String, String> convert(T authorizationGrantRequest) {
.expiresAt(expiresAt);
// @formatter:on

JwtClientAuthenticationContext<T> jwtClientAssertionContext = new JwtClientAuthenticationContext<>(
authorizationGrantRequest, headersBuilder, claimsBuilder);
this.jwtClientAssertionCustomizer.accept(jwtClientAssertionContext);

JwsHeader jwsHeader = headersBuilder.build();
JwtClaimsSet jwtClaimsSet = claimsBuilder.build();

Expand Down Expand Up @@ -189,6 +198,21 @@ else if (KeyType.OCT.equals(jwk.getKeyType())) {
return jwsAlgorithm;
}

/**
* Sets the {@link Consumer} to be provided the
* {@link JwtClientAuthenticationContext}, which contains the
* {@link JwsHeader.Builder} and {@link JwtClaimsSet.Builder} for further
* customization.
* @param jwtClientAssertionCustomizer the {@link Consumer} to be provided the
* {@link JwtClientAuthenticationContext}
* @since 5.7
*/
public void setJwtClientAssertionCustomizer(
Consumer<JwtClientAuthenticationContext<T>> jwtClientAssertionCustomizer) {
Assert.notNull(jwtClientAssertionCustomizer, "jwtClientAssertionCustomizer cannot be null");
this.jwtClientAssertionCustomizer = jwtClientAssertionCustomizer;
}

private static final class JwsEncoderHolder {

private final JwtEncoder jwsEncoder;
Expand All @@ -210,4 +234,59 @@ private JWK getJwk() {

}

/**
* A context that holds client authentication-specific state and is used by
* {@link NimbusJwtClientAuthenticationParametersConverter} when attempting to
* customize the JSON Web Token (JWS) client assertion.
*
* @param <T> the type of {@link AbstractOAuth2AuthorizationGrantRequest}
* @since 5.7
*/
public static final class JwtClientAuthenticationContext<T extends AbstractOAuth2AuthorizationGrantRequest> {

private final T authorizationGrantRequest;

private final JwsHeader.Builder headers;

private final JwtClaimsSet.Builder claims;

private JwtClientAuthenticationContext(T authorizationGrantRequest, JwsHeader.Builder headers,
JwtClaimsSet.Builder claims) {
this.authorizationGrantRequest = authorizationGrantRequest;
this.headers = headers;
this.claims = claims;
}

/**
* Returns the {@link AbstractOAuth2AuthorizationGrantRequest authorization grant
* request}.
* @return the {@link AbstractOAuth2AuthorizationGrantRequest authorization grant
* request}
*/
public T getAuthorizationGrantRequest() {
return this.authorizationGrantRequest;
}

/**
* Returns the {@link JwsHeader.Builder} to be used to customize headers of the
* JSON Web Token (JWS).
* @return the {@link JwsHeader.Builder} to be used to customize headers of the
* JSON Web Token (JWS)
*/
public JwsHeader.Builder getHeaders() {
return this.headers;
}

/**
* Returns the {@link JwtClaimsSet.Builder} to be used to customize claims of the
* JSON Web Token (JWS).
* @return the {@link JwtClaimsSet.Builder} to be used to customize claims of the
* JSON Web Token (JWS)
*/
public JwtClaimsSet.Builder getClaims() {
return this.claims;
}

}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -83,6 +83,12 @@ public void convertWhenAuthorizationGrantRequestNullThenThrowIllegalArgumentExce
.withMessage("authorizationGrantRequest cannot be null");
}

@Test
public void setJwtClientAssertionCustomizerWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setJwtClientAssertionCustomizer(null))
.withMessage("jwtClientAssertionCustomizer cannot be null");
}

@Test
public void convertWhenOtherClientAuthenticationMethodThenNotCustomized() {
// @formatter:off
Expand Down Expand Up @@ -179,6 +185,51 @@ public void convertWhenClientSecretJwtClientAuthenticationMethodThenCustomized()
assertThat(jws.getExpiresAt()).isNotNull();
}

@Test
public void convertWhenJwtClientAssertionCustomizerSetThenUsed() {
OctetSequenceKey secretJwk = TestJwks.DEFAULT_SECRET_JWK;
given(this.jwkResolver.apply(any())).willReturn(secretJwk);

String headerName = "custom-header";
String headerValue = "header-value";
String claimName = "custom-claim";
String claimValue = "claim-value";
this.converter.setJwtClientAssertionCustomizer((context) -> {
context.getHeaders().header(headerName, headerValue);
context.getClaims().claim(claimName, claimValue);
});

// @formatter:off
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
.build();
// @formatter:on

OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
MultiValueMap<String, String> parameters = this.converter.convert(clientCredentialsGrantRequest);

assertThat(parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE))
.isEqualTo("urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
assertThat(encodedJws).isNotNull();

NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withSecretKey(secretJwk.toSecretKey()).build();
Jwt jws = jwtDecoder.decode(encodedJws);

assertThat(jws.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(MacAlgorithm.HS256.getName());
assertThat(jws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(secretJwk.getKeyID());
assertThat(jws.getHeaders().get(headerName)).isEqualTo(headerValue);
assertThat(jws.<String>getClaim(JwtClaimNames.ISS)).isEqualTo(clientRegistration.getClientId());
assertThat(jws.getSubject()).isEqualTo(clientRegistration.getClientId());
assertThat(jws.getAudience())
.isEqualTo(Collections.singletonList(clientRegistration.getProviderDetails().getTokenUri()));
assertThat(jws.getId()).isNotNull();
assertThat(jws.getIssuedAt()).isNotNull();
assertThat(jws.getExpiresAt()).isNotNull();
assertThat(jws.getClaimAsString(claimName)).isEqualTo(claimValue);
}

// gh-9814
@Test
public void convertWhenClientKeyChangesThenNewKeyUsed() throws Exception {
Expand Down