diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ErrorCodes.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ErrorCodes.java index b7f4d9c7994..ecee9777f1e 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ErrorCodes.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ErrorCodes.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,13 @@ public interface Saml2ErrorCodes { */ String UNKNOWN_RESPONSE_CLASS = "unknown_response_class"; + /** + * The serialized AuthNRequest could not be deserialized correctly. + * + * @since 5.7 + */ + String MALFORMED_REQUEST_DATA = "malformed_request_data"; + /** * The response data is malformed or incomplete. An invalid XML object was received, * and XML unmarshalling failed. @@ -116,4 +123,11 @@ public interface Saml2ErrorCodes { */ String RELYING_PARTY_REGISTRATION_NOT_FOUND = "relying_party_registration_not_found"; + /** + * The InResponseTo content of the response does not match the ID of the AuthNRequest. + * + * @since 5.7 + */ + String INVALID_IN_RESPONSE_TO = "invalid_in_response_to"; + } diff --git a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java index 51d9dc1abb6..835fb616778 100644 --- a/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java @@ -37,6 +37,7 @@ import org.opensaml.core.config.ConfigurationService; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.config.XMLObjectProviderRegistry; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.schema.XSAny; import org.opensaml.core.xml.schema.XSBoolean; import org.opensaml.core.xml.schema.XSBooleanValue; @@ -57,13 +58,17 @@ import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.Attribute; import org.opensaml.saml.saml2.core.AttributeStatement; +import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.AuthnStatement; import org.opensaml.saml.saml2.core.Condition; import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.OneTimeUse; import org.opensaml.saml.saml2.core.Response; import org.opensaml.saml.saml2.core.StatusCode; +import org.opensaml.saml.saml2.core.Subject; import org.opensaml.saml.saml2.core.SubjectConfirmation; +import org.opensaml.saml.saml2.core.SubjectConfirmationData; +import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller; import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller; import org.opensaml.saml.saml2.encryption.Decrypter; import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator; @@ -85,6 +90,7 @@ import org.springframework.security.saml2.core.Saml2ErrorCodes; import org.springframework.security.saml2.core.Saml2ResponseValidatorResult; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -349,6 +355,37 @@ public void setResponseAuthenticationConverter( this.responseAuthenticationConverter = responseAuthenticationConverter; } + private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest, + String inResponseTo) { + if (!StringUtils.hasText(inResponseTo)) { + return Saml2ResponseValidatorResult.success(); + } + AuthnRequest request; + try { + request = parseRequest(storedRequest); + } + catch (Exception ex) { + String message = "The stored AuthNRequest could not be properly deserialized [" + ex.getMessage() + "]"; + return Saml2ResponseValidatorResult + .failure(new Saml2Error(Saml2ErrorCodes.MALFORMED_REQUEST_DATA, message)); + } + if (request == null) { + String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]" + + " but no saved AuthNRequest request was found"; + return Saml2ResponseValidatorResult + .failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message)); + } + else if (!request.getID().equals(inResponseTo)) { + String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the " + + "AuthNRequest [" + request.getID() + "]"; + return Saml2ResponseValidatorResult + .failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message)); + } + else { + return Saml2ResponseValidatorResult.success(); + } + } + /** * Construct a default strategy for validating the SAML 2.0 Response * @return the default response validator strategy @@ -365,6 +402,10 @@ public static Converter createDefau response.getID()); result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message)); } + + String inResponseTo = response.getInResponseTo(); + result = result.concat(validateInResponseTo(token.getAuthenticationRequest(), inResponseTo)); + String issuer = response.getIssuer().getValue(); String destination = response.getDestination(); String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation(); @@ -447,7 +488,7 @@ public Authentication authenticate(Authentication authentication) throws Authent try { Saml2AuthenticationToken token = (Saml2AuthenticationToken) authentication; String serializedResponse = token.getSaml2Response(); - Response response = parse(serializedResponse); + Response response = parseResponse(serializedResponse); process(token, response); AbstractAuthenticationToken authenticationResponse = this.responseAuthenticationConverter .convert(new ResponseToken(response, token)); @@ -469,7 +510,7 @@ public boolean supports(Class authentication) { return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication); } - private Response parse(String response) throws Saml2Exception, Saml2AuthenticationException { + private Response parseResponse(String response) throws Saml2Exception, Saml2AuthenticationException { try { Document document = this.parserPool .parse(new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8))); @@ -481,6 +522,28 @@ private Response parse(String response) throws Saml2Exception, Saml2Authenticati } } + private static AuthnRequest parseRequest(AbstractSaml2AuthenticationRequest request) throws Exception { + if (request == null) { + return null; + } + String samlRequest = request.getSamlRequest(); + if (!StringUtils.hasText(samlRequest)) { + return null; + } + if (request.getBinding() == Saml2MessageBinding.REDIRECT) { + samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest)); + } + else { + samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8); + } + Document document = XMLObjectProviderRegistrySupport.getParserPool() + .parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8))); + Element element = document.getDocumentElement(); + AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) XMLObjectProviderRegistrySupport + .getUnmarshallerFactory().getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME); + return (AuthnRequest) unmarshaller.unmarshall(element); + } + private void process(Saml2AuthenticationToken token, Response response) { String issuer = response.getIssuer().getValue(); this.logger.debug(LogMessage.format("Processing SAML response from %s", issuer)); @@ -685,6 +748,30 @@ private static Converter createAss }; } + private static boolean assertionContainsInResponseTo(Assertion assertion) { + Subject subject = (assertion != null) ? assertion.getSubject() : null; + List confirmations = (subject != null) ? subject.getSubjectConfirmations() + : new ArrayList<>(); + return confirmations.stream().filter((confirmation) -> { + SubjectConfirmationData confirmationData = confirmation.getSubjectConfirmationData(); + return confirmationData != null && StringUtils.hasText(confirmationData.getInResponseTo()); + }).findFirst().orElse(null) != null; + } + + private static void addRequestIdToValidationContext(AbstractSaml2AuthenticationRequest storedRequest, + Map context) { + String requestId = null; + try { + AuthnRequest request = parseRequest(storedRequest); + requestId = (request != null) ? request.getID() : null; + } + catch (Exception ex) { + } + if (StringUtils.hasText(requestId)) { + context.put(SAML2AssertionValidationParameters.SC_VALID_IN_RESPONSE_TO, requestId); + } + } + private static ValidationContext createValidationContext(AssertionToken assertionToken, Consumer> paramsConsumer) { RelyingPartyRegistration relyingPartyRegistration = assertionToken.token.getRelyingPartyRegistration(); @@ -692,6 +779,10 @@ private static ValidationContext createValidationContext(AssertionToken assertio String recipient = relyingPartyRegistration.getAssertionConsumerServiceLocation(); String assertingPartyEntityId = relyingPartyRegistration.getAssertingPartyDetails().getEntityId(); Map params = new HashMap<>(); + Assertion assertion = assertionToken.getAssertion(); + if (assertionContainsInResponseTo(assertion)) { + addRequestIdToValidationContext(assertionToken.token.getAuthenticationRequest(), params); + } params.put(SAML2AssertionValidationParameters.COND_VALID_AUDIENCES, Collections.singleton(audience)); params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(recipient)); params.put(SAML2AssertionValidationParameters.VALID_ISSUERS, Collections.singleton(assertingPartyEntityId)); @@ -733,13 +824,6 @@ protected ValidationResult validateAddress(SubjectConfirmation confirmation, Ass // applications should validate their own addresses - gh-7514 return ValidationResult.VALID; } - - @Override - protected ValidationResult validateInResponseTo(SubjectConfirmation confirmation, Assertion assertion, - ValidationContext context, boolean required) { - // applications should validate their own in response to - return ValidationResult.VALID; - } }); } diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java index ed3ad288f78..2cf7c62ba75 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java @@ -19,6 +19,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectOutputStream; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; import java.util.Arrays; @@ -39,12 +40,14 @@ import org.opensaml.core.xml.io.MarshallingException; import org.opensaml.core.xml.schema.XSDateTime; import org.opensaml.core.xml.schema.impl.XSDateTimeBuilder; +import org.opensaml.saml.common.SignableSAMLObject; import org.opensaml.saml.common.assertion.ValidationContext; import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters; import org.opensaml.saml.saml2.core.Assertion; import org.opensaml.saml.saml2.core.Attribute; import org.opensaml.saml.saml2.core.AttributeStatement; import org.opensaml.saml.saml2.core.AttributeValue; +import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.Conditions; import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.EncryptedAttribute; @@ -73,6 +76,7 @@ import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken; import org.springframework.security.saml2.provider.service.authentication.TestCustomOpenSamlObjects.CustomOpenSamlObject; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations; import org.springframework.util.StringUtils; @@ -144,9 +148,7 @@ public void authenticateWhenXmlErrorThenThrowAuthenticationException() { public void authenticateWhenInvalidDestinationThenThrowAuthenticationException() { Response response = response(DESTINATION + "invalid", ASSERTING_PARTY_ENTITY_ID); response.getAssertions().add(assertion()); - TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, verifying(registration())); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); assertThatExceptionOfType(Saml2AuthenticationException.class) .isThrownBy(() -> this.provider.authenticate(token)) .satisfies(errorOf(Saml2ErrorCodes.INVALID_DESTINATION)); @@ -176,9 +178,7 @@ public void authenticateWhenOpenSAMLValidationErrorThenThrowAuthenticationExcept Assertion assertion = assertion(); assertion.getSubject().getSubjectConfirmations().get(0).getSubjectConfirmationData() .setNotOnOrAfter(Instant.now().minus(Duration.ofDays(3))); - TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - response.getAssertions().add(assertion); + response.getAssertions().add(signed(assertion)); Saml2AuthenticationToken token = token(response, verifying(registration())); assertThatExceptionOfType(Saml2AuthenticationException.class) .isThrownBy(() -> this.provider.authenticate(token)) @@ -190,9 +190,7 @@ public void authenticateWhenMissingSubjectThenThrowAuthenticationException() { Response response = response(); Assertion assertion = assertion(); assertion.setSubject(null); - TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - response.getAssertions().add(assertion); + response.getAssertions().add(signed(assertion)); Saml2AuthenticationToken token = token(response, verifying(registration())); assertThatExceptionOfType(Saml2AuthenticationException.class) .isThrownBy(() -> this.provider.authenticate(token)) @@ -204,9 +202,7 @@ public void authenticateWhenUsernameMissingThenThrowAuthenticationException() { Response response = response(); Assertion assertion = assertion(); assertion.getSubject().getNameID().setValue(null); - TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - response.getAssertions().add(assertion); + response.getAssertions().add(signed(assertion)); Saml2AuthenticationToken token = token(response, verifying(registration())); assertThatExceptionOfType(Saml2AuthenticationException.class) .isThrownBy(() -> this.provider.authenticate(token)) @@ -219,10 +215,113 @@ public void authenticateWhenAssertionContainsValidationAddressThenItSucceeds() { Assertion assertion = assertion(); assertion.getSubject().getSubjectConfirmations() .forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10")); - TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - response.getAssertions().add(assertion); + response.getAssertions().add(signed(assertion)); + Saml2AuthenticationToken token = token(response, verifying(registration())); + this.provider.authenticate(token); + } + + @Test + public void evaluateInResponseToSucceedsWhenInResponseToInResponseAndAssertionsMatchRequestID() { + Response response = response(); + response.setInResponseTo("SAML2"); + response.getAssertions().add(signed(assertion("SAML2"))); + response.getAssertions().add(signed(assertion("SAML2"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, false); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + this.provider.authenticate(token); + } + + @Test + public void evaluateInResponseToSucceedsWhenInResponseToInAssertionOnlyMatchRequestID() { + Response response = response(); + response.getAssertions().add(signed(assertion())); + response.getAssertions().add(signed(assertion("SAML2"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, false); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + this.provider.authenticate(token); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndCorruptedStoredRequest() { + Response response = response(); + response.getAssertions().add(signed(assertion())); + response.getAssertions().add(signed(assertion("SAML2"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, true); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion"); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseToInAssertionMismatchWithRequestID() { + Response response = response(); + response.setInResponseTo("SAML2"); + response.getAssertions().add(signed(assertion("SAML2"))); + response.getAssertions().add(signed(assertion("BAD"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, false); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion"); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndMismatchWithRequestID() { + Response response = response(); + response.getAssertions().add(signed(assertion())); + response.getAssertions().add(signed(assertion("BAD"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, false); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion"); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseInToResponseMismatchWithRequestID() { + Response response = response(); + response.setInResponseTo("BAD"); + response.getAssertions().add(signed(assertion("SAML2"))); + response.getAssertions().add(signed(assertion("SAML2"))); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, false); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_in_response_to"); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseInToResponseAndCorruptedStoredRequest() { + Response response = response(); + response.setInResponseTo("SAML2"); + response.getAssertions().add(signed(assertion())); + response.getAssertions().add(signed(assertion())); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, true); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("malformed_request_data"); + } + + @Test + public void evaluateInResponseToFailsWhenInResponseToInResponseButNoSavedRequest() { + Response response = response(); + response.setInResponseTo("BAD"); Saml2AuthenticationToken token = token(response, verifying(registration())); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_in_response_to"); + } + + @Test + public void evaluateInResponseToSucceedsWhenNoInResponseToInResponseOrAssertions() { + Response response = response(); + response.getAssertions().add(signed(assertion())); + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2", + Saml2MessageBinding.POST, false); + Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest); this.provider.authenticate(token); } @@ -232,9 +331,7 @@ public void authenticateWhenAssertionContainsAttributesThenItSucceeds() { Assertion assertion = assertion(); List attributes = attributeStatements(); assertion.getAttributeStatements().addAll(attributes); - TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - response.getAssertions().add(assertion); + response.getAssertions().add(signed(assertion)); Saml2AuthenticationToken token = token(response, verifying(registration())); Authentication authentication = this.provider.authenticate(token); Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal(); @@ -258,9 +355,7 @@ public void authenticateWhenAssertionContainsCustomAttributesThenItSucceeds() { AttributeStatement attribute = TestOpenSamlObjects.customAttributeStatement("Address", TestCustomOpenSamlObjects.instance()); assertion.getAttributeStatements().add(attribute); - TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - response.getAssertions().add(assertion); + response.getAssertions().add(signed(assertion)); Saml2AuthenticationToken token = token(response, verifying(registration())); Authentication authentication = this.provider.authenticate(token); Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal(); @@ -291,9 +386,7 @@ public void authenticateWhenEncryptedAssertionWithSignatureThenItSucceeds() { EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion, TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); - TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, decrypting(verifying(registration()))); + Saml2AuthenticationToken token = token(signed(response), decrypting(verifying(registration()))); this.provider.authenticate(token); } @@ -303,9 +396,7 @@ public void authenticateWhenEncryptedAssertionWithResponseSignatureThenItSucceed EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion(), TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); - TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, decrypting(verifying(registration()))); + Saml2AuthenticationToken token = token(signed(response), decrypting(verifying(registration()))); this.provider.authenticate(token); } @@ -318,9 +409,7 @@ public void authenticateWhenEncryptedNameIdWithSignatureThenItSucceeds() { TestSaml2X509Credentials.assertingPartyEncryptingCredential()); assertion.getSubject().setNameID(null); assertion.getSubject().setEncryptedID(encryptedID); - response.getAssertions().add(assertion); - TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); + response.getAssertions().add(signed(assertion)); Saml2AuthenticationToken token = token(response, decrypting(verifying(registration()))); this.provider.authenticate(token); } @@ -335,9 +424,7 @@ public void authenticateWhenEncryptedAttributeThenDecrypts() { statement.getEncryptedAttributes().add(attribute); assertion.getAttributeStatements().add(statement); response.getAssertions().add(assertion); - TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, decrypting(verifying(registration()))); + Saml2AuthenticationToken token = token(signed(response), decrypting(verifying(registration()))); Saml2Authentication authentication = (Saml2Authentication) this.provider.authenticate(token); Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal(); assertThat(principal.getAttribute("name")).containsExactly("value"); @@ -349,9 +436,7 @@ public void authenticateWhenDecryptionKeysAreMissingThenThrowAuthenticationExcep EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion(), TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); - TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, verifying(registration())); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); assertThatExceptionOfType(Saml2AuthenticationException.class) .isThrownBy(() -> this.provider.authenticate(token)) .satisfies(errorOf(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData")); @@ -363,9 +448,7 @@ public void authenticateWhenDecryptionKeysAreWrongThenThrowAuthenticationExcepti EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion(), TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); - TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, registration() + Saml2AuthenticationToken token = token(signed(response), registration() .decryptionX509Credentials((c) -> c.add(TestSaml2X509Credentials.assertingPartyPrivateCredential()))); assertThatExceptionOfType(Saml2AuthenticationException.class) .isThrownBy(() -> this.provider.authenticate(token)) @@ -378,9 +461,7 @@ public void authenticateWhenAuthenticationHasDetailsThenSucceeds() { Assertion assertion = assertion(); assertion.getSubject().getSubjectConfirmations() .forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10")); - TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - response.getAssertions().add(assertion); + response.getAssertions().add(signed(assertion)); Saml2AuthenticationToken token = token(response, verifying(registration())); token.setDetails("some-details"); Authentication authentication = this.provider.authenticate(token); @@ -395,9 +476,7 @@ public void writeObjectWhenTypeIsSaml2AuthenticationThenNoException() throws IOE EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion, TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); - TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, decrypting(verifying(registration()))); + Saml2AuthenticationToken token = token(signed(response), decrypting(verifying(registration()))); Saml2Authentication authentication = (Saml2Authentication) this.provider.authenticate(token); // the following code will throw an exception if authentication isn't serializable ByteArrayOutputStream byteStream = new ByteArrayOutputStream(1024); @@ -432,9 +511,7 @@ public void authenticateWhenDelegatingToDefaultAssertionValidatorThenUses() { OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME); assertion.getConditions().getConditions().add(oneTimeUse); response.getAssertions().add(assertion); - TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), - ASSERTING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, verifying(registration())); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); // @formatter:off assertThatExceptionOfType(Saml2AuthenticationException.class) .isThrownBy(() -> provider.authenticate(token)).isInstanceOf(Saml2AuthenticationException.class) @@ -456,9 +533,7 @@ public void authenticateWhenCustomAssertionValidatorThenUses() { Response response = response(); Assertion assertion = assertion(); response.getAssertions().add(assertion); - TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), - ASSERTING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, verifying(registration())); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); given(validator.convert(any(OpenSaml4AuthenticationProvider.AssertionToken.class))) .willReturn(Saml2ResponseValidatorResult.success()); provider.authenticate(token); @@ -475,9 +550,7 @@ public void authenticateWhenDefaultConditionValidatorNotUsedThenSignatureStillCh RELYING_PARTY_ENTITY_ID); // broken // signature response.getAssertions().add(assertion); - TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), - ASSERTING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, verifying(registration())); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); // @formatter:off assertThatExceptionOfType(Saml2AuthenticationException.class) .isThrownBy(() -> provider.authenticate(token)) @@ -496,9 +569,7 @@ public void authenticateWhenValidationContextCustomizedThenUsers() { OpenSaml4AuthenticationProvider.createDefaultAssertionValidator((assertionToken) -> context)); Response response = response(); Assertion assertion = assertion(); - response.getAssertions().add(assertion); - TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), - ASSERTING_PARTY_ENTITY_ID); + response.getAssertions().add(signed(assertion)); Saml2AuthenticationToken token = token(response, verifying(registration())); // @formatter:off assertThatExceptionOfType(Saml2AuthenticationException.class) @@ -570,13 +641,12 @@ public void setAssertionElementsDecrypterWhenNullThenIllegalArgument() { public void authenticateWhenCustomResponseElementsDecrypterThenDecryptsResponse() { Response response = response(); Assertion assertion = assertion(); - TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); response.getEncryptedAssertions().add(new EncryptedAssertionBuilder().buildObject()); TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); Saml2AuthenticationToken token = token(response, verifying(registration())); - this.provider.setResponseElementsDecrypter((tuple) -> tuple.getResponse().getAssertions().add(assertion)); + this.provider + .setResponseElementsDecrypter((tuple) -> tuple.getResponse().getAssertions().add(signed(assertion))); Authentication authentication = this.provider.authenticate(token); assertThat(authentication.getName()).isEqualTo("test@saml.user"); } @@ -588,9 +658,7 @@ public void authenticateWhenCustomAssertionElementsDecrypterThenDecryptsAssertio EncryptedID id = new EncryptedIDBuilder().buildObject(); id.setEncryptedData(new EncryptedDataBuilder().buildObject()); assertion.getSubject().setEncryptedID(id); - TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(), - RELYING_PARTY_ENTITY_ID); - response.getAssertions().add(assertion); + response.getAssertions().add(signed(assertion)); Saml2AuthenticationToken token = token(response, verifying(registration())); this.provider.setAssertionElementsDecrypter((tuple) -> { NameID name = new NameIDBuilder().buildObject(); @@ -639,9 +707,7 @@ public void authenticateWhenCustomResponseValidatorThenUses() { Response response = response(); Assertion assertion = assertion(); response.getAssertions().add(assertion); - TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), - ASSERTING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, verifying(registration())); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); given(validator.convert(any(OpenSaml4AuthenticationProvider.ResponseToken.class))) .willReturn(Saml2ResponseValidatorResult.success()); provider.authenticate(token); @@ -655,9 +721,7 @@ public void authenticateWhenAssertionIssuerNotValidThenFailsWithInvalidIssuer() Assertion assertion = assertion(); assertion.setIssuer(TestOpenSamlObjects.issuer("https://invalid.idp.test/saml2/idp")); response.getAssertions().add(assertion); - TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), - ASSERTING_PARTY_ENTITY_ID); - Saml2AuthenticationToken token = token(response, verifying(registration())); + Saml2AuthenticationToken token = token(signed(response), verifying(registration())); assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> provider.authenticate(token)) .withMessageContaining("did not match any valid issuers"); } @@ -702,13 +766,27 @@ private Response response(String destination, String issuerEntityId) { return response; } - private Assertion assertion() { + private AuthnRequest request() { + AuthnRequest request = TestOpenSamlObjects.authnRequest(); + return request; + } + + private String serializedRequest(AuthnRequest request, Saml2MessageBinding binding) { + String xml = serialize(request); + return (binding == Saml2MessageBinding.POST) ? Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8)) + : Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml)); + } + + private Assertion assertion(String inResponseTo) { Assertion assertion = TestOpenSamlObjects.assertion(); assertion.setIssueInstant(Instant.now()); for (SubjectConfirmation confirmation : assertion.getSubject().getSubjectConfirmations()) { SubjectConfirmationData data = confirmation.getSubjectConfirmationData(); data.setNotBefore(Instant.now().minus(Duration.ofMillis(5 * 60 * 1000))); data.setNotOnOrAfter(Instant.now().plus(Duration.ofMillis(5 * 60 * 1000))); + if (StringUtils.hasText(inResponseTo)) { + data.setInResponseTo(inResponseTo); + } } Conditions conditions = assertion.getConditions(); conditions.setNotBefore(Instant.now().minus(Duration.ofMillis(5 * 60 * 1000))); @@ -716,6 +794,16 @@ private Assertion assertion() { return assertion; } + private Assertion assertion() { + return assertion(null); + } + + private T signed(T toSign) { + TestOpenSamlObjects.signed(toSign, TestSaml2X509Credentials.assertingPartySigningCredential(), + RELYING_PARTY_ENTITY_ID); + return toSign; + } + private List attributeStatements() { List attributeStatements = TestOpenSamlObjects.attributeStatements(); AttributeBuilder attributeBuilder = new AttributeBuilder(); @@ -739,6 +827,27 @@ private Saml2AuthenticationToken token(Response response, RelyingPartyRegistrati return new Saml2AuthenticationToken(registration.build(), serialize(response)); } + private Saml2AuthenticationToken token(Response response, RelyingPartyRegistration.Builder registration, + AbstractSaml2AuthenticationRequest authenticationRequest) { + return new Saml2AuthenticationToken(registration.build(), serialize(response), authenticationRequest); + } + + private AbstractSaml2AuthenticationRequest mockedStoredAuthenticationRequest(String requestId, + Saml2MessageBinding binding, boolean corruptRequestString) { + AuthnRequest request = request(); + if (requestId != null) { + request.setID(requestId); + } + String serializedRequest = serializedRequest(request, binding); + if (corruptRequestString) { + serializedRequest = serializedRequest.substring(2, serializedRequest.length() - 2); + } + AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class); + given(mockAuthenticationRequest.getSamlRequest()).willReturn(serializedRequest); + given(mockAuthenticationRequest.getBinding()).willReturn(binding); + return mockAuthenticationRequest; + } + private RelyingPartyRegistration.Builder registration() { return TestRelyingPartyRegistrations.noCredentials().entityId(RELYING_PARTY_ENTITY_ID) .assertionConsumerServiceLocation(DESTINATION)