diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixin.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixin.java index 53ddeb73d97..62bc1e5493d 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixin.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixin.java @@ -48,7 +48,8 @@ class Saml2PostAuthenticationRequestMixin { Saml2PostAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest, @JsonProperty("relayState") String relayState, @JsonProperty("authenticationRequestUri") String authenticationRequestUri, - @JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) { + @JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId, + @JsonProperty("id") String id) { } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixin.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixin.java index 247b52104c5..3412f962f50 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixin.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixin.java @@ -49,7 +49,8 @@ class Saml2RedirectAuthenticationRequestMixin { @JsonProperty("sigAlg") String sigAlg, @JsonProperty("signature") String signature, @JsonProperty("relayState") String relayState, @JsonProperty("authenticationRequestUri") String authenticationRequestUri, - @JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) { + @JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId, + @JsonProperty("id") String id) { } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java index 04e6a958f8e..39f6b725d51 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java @@ -49,6 +49,8 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable private final String relyingPartyRegistrationId; + private final String id; + /** * Mandatory constructor for the {@link AbstractSaml2AuthenticationRequest} * @param samlRequest - the SAMLRequest XML data, SAML encoded, cannot be empty or @@ -58,15 +60,18 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable * send the XML message, cannot be empty or null * @param relyingPartyRegistrationId the registration id of the relying party, may be * null + * @param id This is the unique id used in the {@link #samlRequest}, cannot be empty + * or null */ AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri, - String relyingPartyRegistrationId) { + String relyingPartyRegistrationId, String id) { Assert.hasText(samlRequest, "samlRequest cannot be null or empty"); Assert.hasText(authenticationRequestUri, "authenticationRequestUri cannot be null or empty"); this.authenticationRequestUri = authenticationRequestUri; this.samlRequest = samlRequest; this.relayState = relayState; this.relyingPartyRegistrationId = relyingPartyRegistrationId; + this.id = id; } /** @@ -106,6 +111,15 @@ public String getRelyingPartyRegistrationId() { return this.relyingPartyRegistrationId; } + /** + * The unique identifier for this Authentication Request + * @return the Authentication Request identifier + * @since 5.8 + */ + public String getId() { + return this.id; + } + /** * Returns the binding this AuthNRequest will be sent and encoded with. If * {@link Saml2MessageBinding#REDIRECT} is used, the DEFLATE encoding will be @@ -127,6 +141,8 @@ public static class Builder> { String relyingPartyRegistrationId; + String id; + /** * @deprecated Use {@link #Builder(RelyingPartyRegistration)} instead */ @@ -184,6 +200,19 @@ public T authenticationRequestUri(String authenticationRequestUri) { return _this(); } + /** + * This is the unique id used in the {@link #samlRequest} + * @param id the SAML2 request id + * @return the {@link AbstractSaml2AuthenticationRequest.Builder} for further + * configurations + * @since 5.8 + */ + public T id(String id) { + Assert.notNull(id, "id cannot be null"); + this.id = id; + return _this(); + } + } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java index 29dc000b392..dbf348ebe54 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java @@ -31,8 +31,8 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationRequest { Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri, - String relyingPartyRegistrationId) { - super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId); + String relyingPartyRegistrationId, String id) { + super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId, id); } /** @@ -69,7 +69,7 @@ private Builder(RelyingPartyRegistration registration) { */ public Saml2PostAuthenticationRequest build() { return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri, - this.relyingPartyRegistrationId); + this.relyingPartyRegistrationId, this.id); } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java index 600ef993c93..8dd6589962d 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java @@ -35,8 +35,8 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe private final String signature; private Saml2RedirectAuthenticationRequest(String samlRequest, String sigAlg, String signature, String relayState, - String authenticationRequestUri, String relyingPartyRegistrationId) { - super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId); + String authenticationRequestUri, String relyingPartyRegistrationId, String id) { + super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId, id); this.sigAlg = sigAlg; this.signature = signature; } @@ -116,7 +116,7 @@ public Builder signature(String signature) { */ public Saml2RedirectAuthenticationRequest build() { return new Saml2RedirectAuthenticationRequest(this.samlRequest, this.sigAlg, this.signature, - this.relayState, this.authenticationRequestUri, this.relyingPartyRegistrationId); + this.relayState, this.authenticationRequestUri, this.relyingPartyRegistrationId, this.id); } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java index 79c2881a1b8..0746f60222e 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java @@ -142,13 +142,14 @@ T resolve(HttpServletRequest requ String xml = serialize(authnRequest); String encoded = Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8)); return (T) Saml2PostAuthenticationRequest.withRelyingPartyRegistration(registration).samlRequest(encoded) - .relayState(relayState).build(); + .relayState(relayState).id(authnRequest.getID()).build(); } else { String xml = serialize(authnRequest); String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml)); Saml2RedirectAuthenticationRequest.Builder builder = Saml2RedirectAuthenticationRequest - .withRelyingPartyRegistration(registration).samlRequest(deflatedAndEncoded).relayState(relayState); + .withRelyingPartyRegistration(registration).samlRequest(deflatedAndEncoded).relayState(relayState) + .id(authnRequest.getID()); if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) { Map parameters = OpenSamlSigningUtils.sign(registration) .param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded) diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixinTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixinTests.java index 3183d274349..eea1df51dde 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixinTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixinTests.java @@ -58,6 +58,7 @@ void shouldDeserialize() throws Exception { .isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI); assertThat(authRequest.getRelyingPartyRegistrationId()) .isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID); + assertThat(authRequest.getId()).isEqualTo(TestSaml2JsonPayloads.ID); } @Test @@ -73,6 +74,24 @@ void shouldDeserializeWithNoRegistrationId() throws Exception { assertThat(authRequest.getAuthenticationRequestUri()) .isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI); assertThat(authRequest.getRelyingPartyRegistrationId()).isNull(); + assertThat(authRequest.getId()).isEqualTo(TestSaml2JsonPayloads.ID); + } + + @Test + void shouldDeserializeWithNoId() throws Exception { + String json = TestSaml2JsonPayloads.DEFAULT_POST_AUTH_REQUEST_JSON + .replace(", \"id\": \"" + TestSaml2JsonPayloads.ID + "\"", ""); + + Saml2PostAuthenticationRequest authRequest = this.mapper.readValue(json, Saml2PostAuthenticationRequest.class); + + assertThat(authRequest).isNotNull(); + assertThat(authRequest.getSamlRequest()).isEqualTo(TestSaml2JsonPayloads.SAML_REQUEST); + assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE); + assertThat(authRequest.getAuthenticationRequestUri()) + .isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI); + assertThat(authRequest.getRelyingPartyRegistrationId()) + .isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID); + assertThat(authRequest.getId()).isNull(); } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/TestSaml2JsonPayloads.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/TestSaml2JsonPayloads.java index 18f2e7deb8a..636407a596b 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/TestSaml2JsonPayloads.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/TestSaml2JsonPayloads.java @@ -97,6 +97,7 @@ static DefaultSaml2AuthenticatedPrincipal createDefaultPrincipal() { static final String RELYINGPARTY_REGISTRATION_ID = "registrationIdValue"; static final String SIG_ALG = "sigAlgValue"; static final String SIGNATURE = "signatureValue"; + static final String ID = "idValue"; // @formatter:off static final String DEFAULT_REDIRECT_AUTH_REQUEST_JSON = "{" @@ -106,7 +107,8 @@ static DefaultSaml2AuthenticatedPrincipal createDefaultPrincipal() { + " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\"," + " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\"," + " \"sigAlg\": \"" + SIG_ALG + "\"," - + " \"signature\": \"" + SIGNATURE + "\"" + + " \"signature\": \"" + SIGNATURE + "\"," + + " \"id\": \"" + ID + "\"" + "}"; // @formatter:on @@ -116,11 +118,11 @@ static DefaultSaml2AuthenticatedPrincipal createDefaultPrincipal() { + " \"samlRequest\": \"" + SAML_REQUEST + "\"," + " \"relayState\": \"" + RELAY_STATE + "\"," + " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\"," - + " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\"" + + " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\"," + + " \"id\": \"" + ID + "\"" + "}"; // @formatter:on - static final String ID = "idValue"; static final String LOCATION = "locationValue"; static final String BINDNG = "REDIRECT"; static final String ADDITIONAL_PARAM = "additionalParamValue"; @@ -146,7 +148,7 @@ static Saml2PostAuthenticationRequest createDefaultSaml2PostAuthenticationReques TestRelyingPartyRegistrations.full().registrationId(RELYINGPARTY_REGISTRATION_ID) .assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI)) .build()) - .samlRequest(SAML_REQUEST).relayState(RELAY_STATE).build(); + .samlRequest(SAML_REQUEST).relayState(RELAY_STATE).id(ID).build(); } static Saml2RedirectAuthenticationRequest createDefaultSaml2RedirectAuthenticationRequest() { @@ -155,7 +157,7 @@ static Saml2RedirectAuthenticationRequest createDefaultSaml2RedirectAuthenticati .registrationId(RELYINGPARTY_REGISTRATION_ID) .assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI)) .build()) - .samlRequest(SAML_REQUEST).relayState(RELAY_STATE).sigAlg(SIG_ALG).signature(SIGNATURE).build(); + .samlRequest(SAML_REQUEST).relayState(RELAY_STATE).sigAlg(SIG_ALG).signature(SIGNATURE).id(ID).build(); } static Saml2LogoutRequest createDefaultSaml2LogoutRequest() { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java index b81e991b83c..9bcb3620b7c 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java @@ -16,8 +16,8 @@ package org.springframework.security.saml2.provider.service.web.authentication; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.opensaml.xmlsec.signature.support.SignatureConstants; import org.springframework.mock.web.MockHttpServletRequest; @@ -40,7 +40,7 @@ public class OpenSamlAuthenticationRequestResolverTests { private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder; - @Before + @BeforeEach public void setUp() { this.relyingPartyRegistrationBuilder = TestRelyingPartyRegistrations.relyingPartyRegistration(); } @@ -65,6 +65,7 @@ public void resolveAuthenticationRequestWhenSignedRedirectThenSignsAndRedirects( assertThat(result.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256); assertThat(result.getSignature()).isNotEmpty(); assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + assertThat(result.getId()).isNotEmpty(); } @Test @@ -88,6 +89,7 @@ public void resolveAuthenticationRequestWhenUnsignedRedirectThenRedirectsAndNoSi assertThat(result.getSigAlg()).isNull(); assertThat(result.getSignature()).isNull(); assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + assertThat(result.getId()).isNotEmpty(); } @Test @@ -98,7 +100,9 @@ public void resolveAuthenticationRequestWhenSignedThenCredentialIsRequired() { RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials() .assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build(); OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); - assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> resolver.resolve(request, null)); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> resolver.resolve(request, (r, authnRequest) -> { + })); } @Test @@ -122,6 +126,7 @@ public void resolveAuthenticationRequestWhenUnsignedPostThenOnlyPosts() { assertThat(result.getRelayState()).isNotNull(); assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST); assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()))).doesNotContain("Signature"); + assertThat(result.getId()).isNotEmpty(); } @Test @@ -144,6 +149,7 @@ public void resolveAuthenticationRequestWhenSignedPostThenSignsAndPosts() { assertThat(result.getRelayState()).isNotNull(); assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST); assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()))).contains("Signature"); + assertThat(result.getId()).isNotEmpty(); } @Test @@ -154,12 +160,14 @@ public void resolveAuthenticationRequestWhenSHA1SignRequestThenSigns() { (party) -> party.signingAlgorithms((algs) -> algs.add(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1))) .build(); OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration); - Saml2RedirectAuthenticationRequest result = resolver.resolve(request, null); + Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> { + }); assertThat(result.getSamlRequest()).isNotEmpty(); assertThat(result.getRelayState()).isNotNull(); assertThat(result.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1); assertThat(result.getSignature()).isNotNull(); assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); + assertThat(result.getId()).isNotEmpty(); } private OpenSamlAuthenticationRequestResolver authenticationRequestResolver(RelyingPartyRegistration registration) {