diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/TestRelyingPartyRegistrations.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/TestRelyingPartyRegistrations.java index a9103f27e70..b69456e2bcd 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/TestRelyingPartyRegistrations.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/TestRelyingPartyRegistrations.java @@ -43,8 +43,8 @@ static RelyingPartyRegistration saml2AuthenticationConfiguration() { Saml2X509Credential idpVerificationCertificate = verificationCertificate(); String acsUrlTemplate = "{baseUrl}" + Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI; return RelyingPartyRegistration.withRegistrationId(registrationId) - .remoteIdpEntityId(idpEntityId) - .idpWebSsoUrl(webSsoEndpoint) + .providerDetails(c -> c.entityId(idpEntityId)) + .providerDetails(c -> c.webSsoUrl(webSsoEndpoint)) .credentials(c -> c.add(signingCredential)) .credentials(c -> c.add(idpVerificationCertificate)) .localEntityIdTemplate(localEntityIdTemplate) diff --git a/config/src/test/kotlin/org/springframework/security/config/web/servlet/Saml2DslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/servlet/Saml2DslTests.kt index 519aa14b4e1..bf6dd2f85ef 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/servlet/Saml2DslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/servlet/Saml2DslTests.kt @@ -85,10 +85,10 @@ class Saml2DslTests { relyingPartyRegistrationRepository = InMemoryRelyingPartyRegistrationRepository( RelyingPartyRegistration.withRegistrationId("samlId") - .remoteIdpEntityId("entityId") .assertionConsumerServiceUrlTemplate("{baseUrl}" + Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI) .credentials { c -> c.add(Saml2X509Credential(loadCert("rod.cer"), VERIFICATION)) } - .idpWebSsoUrl("ssoUrl") + .providerDetails { c -> c.webSsoUrl("ssoUrl") } + .providerDetails { c -> c.entityId("entityId") } .build() ) } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java index 67983a12f90..3aee2258bbf 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java @@ -54,7 +54,11 @@ public String createAuthenticationRequest(Saml2AuthenticationRequest request) { */ @Override public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) { - String xml = createAuthenticationRequest(context, context.getRelyingPartyRegistration().getSigningCredentials()); + List signingCredentials = context.getRelyingPartyRegistration().getProviderDetails().isSignAuthNRequest() ? + context.getRelyingPartyRegistration().getSigningCredentials() : + emptyList(); + + String xml = createAuthenticationRequest(context, signingCredentials); return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context) .samlRequest(samlEncode(xml.getBytes(UTF_8))) .build(); @@ -66,19 +70,24 @@ public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2Authe @Override public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) { String xml = createAuthenticationRequest(context, emptyList()); - List signingCredentials = context.getRelyingPartyRegistration().getSigningCredentials(); Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context); - String deflatedAndEncoded = samlEncode(samlDeflate(xml)); - Map signedParams = this.saml.signQueryParameters( - signingCredentials, - deflatedAndEncoded, - context.getRelayState() - ); - result.samlRequest(signedParams.get("SAMLRequest")) - .relayState(signedParams.get("RelayState")) - .sigAlg(signedParams.get("SigAlg")) - .signature(signedParams.get("Signature")); + result.samlRequest(deflatedAndEncoded) + .relayState(context.getRelayState()); + + if (context.getRelyingPartyRegistration().getProviderDetails().isSignAuthNRequest()) { + List signingCredentials = context.getRelyingPartyRegistration().getSigningCredentials(); + Map signedParams = this.saml.signQueryParameters( + signingCredentials, + deflatedAndEncoded, + context.getRelayState() + ); + result.samlRequest(signedParams.get("SAMLRequest")) + .relayState(signedParams.get("RelayState")) + .sigAlg(signedParams.get("SigAlg")) + .signature(signedParams.get("Signature")); + } + return result.build(); } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestContext.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestContext.java index e769b4354df..3e80e69ae88 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestContext.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestContext.java @@ -91,7 +91,7 @@ public String getRelayState() { * @return the Destination value */ public String getDestination() { - return this.getRelyingPartyRegistration().getIdpWebSsoUrl(); + return this.getRelyingPartyRegistration().getProviderDetails().getWebSsoUrl(); } /** diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java index c0327ab1046..f96d492ed94 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java @@ -51,8 +51,8 @@ * //IDP certificate for verification of incoming messages * Saml2X509Credential idpVerificationCertificate = getVerificationCertificate(); * RelyingPartyRegistration rp = RelyingPartyRegistration.withRegistrationId(registrationId) - * .remoteIdpEntityId(idpEntityId) - * .idpWebSsoUrl(webSsoEndpoint) + * .providerDetails(config -> config.entityId(idpEntityId)); + * .providerDetails(config -> config.webSsoUrl(url)); * .credentials(c -> c.add(signingCredential)) * .credentials(c -> c.add(idpVerificationCertificate)) * .localEntityIdTemplate(localEntityIdTemplate) @@ -64,37 +64,41 @@ public class RelyingPartyRegistration { private final String registrationId; - private final String remoteIdpEntityId; private final String assertionConsumerServiceUrlTemplate; - private final String idpWebSsoUrl; private final List credentials; private final String localEntityIdTemplate; - - private RelyingPartyRegistration(String idpEntityId, String registrationId, String assertionConsumerServiceUrlTemplate, - String idpWebSsoUri, List credentials, String localEntityIdTemplate) { - hasText(idpEntityId, "idpEntityId cannot be empty"); + private final ProviderDetails providerDetails; + + private RelyingPartyRegistration( + String registrationId, + String assertionConsumerServiceUrlTemplate, + ProviderDetails providerDetails, + List credentials, + String localEntityIdTemplate) { hasText(registrationId, "registrationId cannot be empty"); hasText(assertionConsumerServiceUrlTemplate, "assertionConsumerServiceUrlTemplate cannot be empty"); hasText(localEntityIdTemplate, "localEntityIdTemplate cannot be empty"); notEmpty(credentials, "credentials cannot be empty"); - notNull(idpWebSsoUri, "idpWebSsoUri cannot be empty"); + notNull(providerDetails, "providerDetails cannot be null"); + hasText(providerDetails.webSsoUrl, "providerDetails.webSsoUrl cannot be empty"); for (Saml2X509Credential c : credentials) { notNull(c, "credentials cannot contain null elements"); } this.registrationId = registrationId; - this.remoteIdpEntityId = idpEntityId; this.assertionConsumerServiceUrlTemplate = assertionConsumerServiceUrlTemplate; this.credentials = unmodifiableList(new LinkedList<>(credentials)); - this.idpWebSsoUrl = idpWebSsoUri; + this.providerDetails = providerDetails; this.localEntityIdTemplate = localEntityIdTemplate; } /** * Returns the entity ID of the IDP, the asserting party. * @return entity ID of the asserting party + * @deprecated use {@link ProviderDetails#getEntityId()} from {@link #getProviderDetails()} */ + @Deprecated public String getRemoteIdpEntityId() { - return this.remoteIdpEntityId; + return this.providerDetails.getEntityId(); } /** @@ -119,9 +123,20 @@ public String getAssertionConsumerServiceUrlTemplate() { * Contains the URL for which to send the SAML 2 Authentication Request to initiate * a single sign on flow. * @return a IDP URL that accepts REDIRECT or POST binding for authentication requests + * @deprecated use {@link ProviderDetails#getWebSsoUrl()} from {@link #getProviderDetails()} */ + @Deprecated public String getIdpWebSsoUrl() { - return this.idpWebSsoUrl; + return this.getProviderDetails().webSsoUrl; + } + + /** + * Returns specific configuration around the Identity Provider SSO endpoint + * @return the IDP SSO endpoint configuration + * @since 5.3 + */ + public ProviderDetails getProviderDetails() { + return this.providerDetails; } /** @@ -200,13 +215,158 @@ public static Builder withRegistrationId(String registrationId) { return new Builder(registrationId); } - public static class Builder { + /** + * Creates a {@code RelyingPartyRegistration} {@link Builder} based on an existing object + * @param registration the {@code RelyingPartyRegistration} + * @return {@code Builder} to create a {@code RelyingPartyRegistration} object + */ + public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) { + Assert.notNull(registration, "registration cannot be null"); + return withRegistrationId(registration.getRegistrationId()) + .providerDetails(c -> { + c.webSsoUrl(registration.getProviderDetails().getWebSsoUrl()); + c.binding(registration.getProviderDetails().getBinding()); + c.signAuthNRequest(registration.getProviderDetails().isSignAuthNRequest()); + c.entityId(registration.getProviderDetails().getEntityId()); + }) + .credentials(c -> c.addAll(registration.getCredentials())) + .localEntityIdTemplate(registration.getLocalEntityIdTemplate()) + .assertionConsumerServiceUrlTemplate(registration.getAssertionConsumerServiceUrlTemplate()) + ; + } + + /** + * Configuration for IDP SSO endpoint configuration + * @since 5.3 + */ + public final static class ProviderDetails { + private final String entityId; + private final String webSsoUrl; + private final boolean signAuthNRequest; + private final Saml2MessageBinding binding; + + private ProviderDetails( + String entityId, + String webSsoUrl, + boolean signAuthNRequest, + Saml2MessageBinding binding) { + hasText(entityId, "entityId cannot be null or empty"); + notNull(webSsoUrl, "webSsoUrl cannot be null"); + notNull(binding, "binding cannot be null"); + this.entityId = entityId; + this.webSsoUrl = webSsoUrl; + this.signAuthNRequest = signAuthNRequest; + this.binding = binding; + } + + /** + * Returns the entity ID of the Identity Provider + * @return the entity ID of the IDP + */ + public String getEntityId() { + return entityId; + } + + /** + * Contains the URL for which to send the SAML 2 Authentication Request to initiate + * a single sign on flow. + * @return a IDP URL that accepts REDIRECT or POST binding for authentication requests + */ + public String getWebSsoUrl() { + return webSsoUrl; + } + + /** + * @return {@code true} if AuthNRequests from this relying party to the IDP should be signed + * {@code false} if no signature is required. + */ + public boolean isSignAuthNRequest() { + return signAuthNRequest; + } + + /** + * @return the type of SAML 2 Binding the AuthNRequest should be sent on + */ + public Saml2MessageBinding getBinding() { + return binding; + } + + /** + * Builder for IDP SSO endpoint configuration + * @since 5.3 + */ + public final static class Builder { + private String entityId; + private String webSsoUrl; + private boolean signAuthNRequest = true; + private Saml2MessageBinding binding = Saml2MessageBinding.REDIRECT; + + /** + * Sets the {@code EntityID} for the remote asserting party, the Identity Provider. + * + * @param entityId - the EntityID of the IDP. May be a URL. + * @return this object + */ + public Builder entityId(String entityId) { + this.entityId = entityId; + return this; + } + + /** + * Sets the {@code SSO URL} for the remote asserting party, the Identity Provider. + * + * @param url - a URL that accepts authentication requests via REDIRECT or POST bindings + * @return this object + */ + public Builder webSsoUrl(String url) { + this.webSsoUrl = url; + return this; + } + + /** + * Set to true if the AuthNRequest message should be signed + * + * @param signAuthNRequest true if the message should be signed + * @return this object + */ + public Builder signAuthNRequest(boolean signAuthNRequest) { + this.signAuthNRequest = signAuthNRequest; + return this; + } + + + /** + * Sets the message binding to be used when sending an AuthNRequest message + * + * @param binding either {@link Saml2MessageBinding#POST} or {@link Saml2MessageBinding#REDIRECT} + * @return this object + */ + public Builder binding(Saml2MessageBinding binding) { + this.binding = binding; + return this; + } + + /** + * Creates an immutable ProviderDetails object representing the configuration for an Identity Provider, IDP + * @return immutable ProviderDetails object + */ + public ProviderDetails build() { + return new ProviderDetails( + this.entityId, + this.webSsoUrl, + this.signAuthNRequest, + this.binding + ); + } + } + } + + public final static class Builder { private String registrationId; - private String remoteIdpEntityId; - private String idpWebSsoUrl; private String assertionConsumerServiceUrlTemplate; private List credentials = new LinkedList<>(); private String localEntityIdTemplate = "{baseUrl}/saml2/service-provider-metadata/{registrationId}"; + private ProviderDetails.Builder providerDetails = new ProviderDetails.Builder(); private Builder(String registrationId) { this.registrationId = registrationId; @@ -227,9 +387,11 @@ public Builder registrationId(String id) { * Sets the {@code entityId} for the remote asserting party, the Identity Provider. * @param entityId the IDP entityId * @return this object + * @deprecated use {@link #providerDetails(Consumer< ProviderDetails.Builder >)} */ + @Deprecated public Builder remoteIdpEntityId(String entityId) { - this.remoteIdpEntityId = entityId; + this.providerDetails(idp -> idp.entityId(entityId)); return this; } @@ -250,9 +412,21 @@ public Builder assertionConsumerServiceUrlTemplate(String assertionConsumerServi * Sets the {@code SSO URL} for the remote asserting party, the Identity Provider. * @param url - a URL that accepts authentication requests via REDIRECT or POST bindings * @return this object + * @deprecated use {@link #providerDetails(Consumer< ProviderDetails.Builder >)} */ + @Deprecated public Builder idpWebSsoUrl(String url) { - this.idpWebSsoUrl = url; + providerDetails(config -> config.webSsoUrl(url)); + return this; + } + + /** + * Configures the IDP SSO endpoint + * @param providerDetails a consumer that configures the IDP SSO endpoint + * @return this object + */ + public Builder providerDetails(Consumer providerDetails) { + providerDetails.accept(this.providerDetails); return this; } @@ -288,17 +462,19 @@ public Builder localEntityIdTemplate(String template) { return this; } + /** + * Constructs a RelyingPartyRegistration object based on the builder configurations + * @return a RelyingPartyRegistration instance + */ public RelyingPartyRegistration build() { return new RelyingPartyRegistration( - remoteIdpEntityId, - registrationId, - assertionConsumerServiceUrlTemplate, - idpWebSsoUrl, - credentials, - localEntityIdTemplate + this.registrationId, + this.assertionConsumerServiceUrlTemplate, + this.providerDetails.build(), + this.credentials, + this.localEntityIdTemplate ); } } - } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2ServletUtils.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2ServletUtils.java index e78017184e9..2088f55cf2c 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2ServletUtils.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2ServletUtils.java @@ -39,7 +39,7 @@ static String getServiceProviderEntityId(RelyingPartyRegistration rp, HttpServle return resolveUrlTemplate( rp.getLocalEntityIdTemplate(), getApplicationUri(request), - rp.getRemoteIdpEntityId(), + rp.getProviderDetails().getEntityId(), rp.getRegistrationId() ); } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java index 666013f1cc7..a332664be25 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java @@ -101,7 +101,7 @@ public Authentication attemptAuthentication(HttpServletRequest request, HttpServ final Saml2AuthenticationToken authentication = new Saml2AuthenticationToken( responseXml, request.getRequestURL().toString(), - rp.getRemoteIdpEntityId(), + rp.getProviderDetails().getEntityId(), localSpEntityId, rp.getCredentials() ); diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java index 881afda820a..03c628d2d60 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java @@ -16,17 +16,21 @@ package org.springframework.security.saml2.provider.service.servlet.filter; +import org.springframework.http.MediaType; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher.MatchResult; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; +import org.springframework.web.util.HtmlUtils; import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriUtils; @@ -74,23 +78,42 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse } String registrationId = matcher.getVariables().get("registrationId"); - sendRedirect(request, response, registrationId); + RelyingPartyRegistration relyingParty = this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId); + if (relyingParty == null) { + response.sendError(HttpServletResponse.SC_UNAUTHORIZED); + return; + } + if (this.logger.isDebugEnabled()) { + this.logger.debug(format("Creating SAML2 SP Authentication Request for IDP[%s]", relyingParty.getRegistrationId())); + } + Saml2AuthenticationRequestContext authnRequestCtx = createRedirectAuthenticationRequestContext(relyingParty, request); + if (relyingParty.getProviderDetails().getBinding() == Saml2MessageBinding.REDIRECT) { + sendRedirect(response, authnRequestCtx); + } + else { + sendPost(response, authnRequestCtx); + } } - private void sendRedirect(HttpServletRequest request, HttpServletResponse response, String registrationId) + private void sendRedirect(HttpServletResponse response, Saml2AuthenticationRequestContext authnRequestCtx) throws IOException { - if (this.logger.isDebugEnabled()) { - this.logger.debug(format("Creating SAML2 SP Authentication Request for IDP[%s]", registrationId)); - } - RelyingPartyRegistration relyingParty = this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId); - String redirectUrl = createSamlRequestRedirectUrl(request, relyingParty); + String redirectUrl = createSamlRequestRedirectUrl(authnRequestCtx); response.sendRedirect(redirectUrl); } - private String createSamlRequestRedirectUrl(HttpServletRequest request, RelyingPartyRegistration relyingParty) { - Saml2AuthenticationRequestContext authnRequest = createRedirectAuthenticationRequestContext(relyingParty, request); + private void sendPost(HttpServletResponse response, Saml2AuthenticationRequestContext authnRequestCtx) + throws IOException { + Saml2PostAuthenticationRequest authNData = + this.authenticationRequestFactory.createPostAuthenticationRequest(authnRequestCtx); + String html = createSamlPostRequestFormData(authNData); + response.setContentType(MediaType.TEXT_HTML_VALUE); + response.getWriter().write(html); + } + + private String createSamlRequestRedirectUrl(Saml2AuthenticationRequestContext authnRequestCtx) { + Saml2RedirectAuthenticationRequest authNData = - this.authenticationRequestFactory.createRedirectAuthenticationRequest(authnRequest); + this.authenticationRequestFactory.createRedirectAuthenticationRequest(authnRequestCtx); UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(authNData.getAuthenticationRequestUri()); addParameter("SAMLRequest", authNData.getSamlRequest(), uriBuilder); addParameter("RelayState", authNData.getRelayState(), uriBuilder); @@ -123,7 +146,7 @@ private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestCon Saml2ServletUtils.resolveUrlTemplate( relyingParty.getAssertionConsumerServiceUrlTemplate(), Saml2ServletUtils.getApplicationUri(request), - relyingParty.getRemoteIdpEntityId(), + relyingParty.getProviderDetails().getEntityId(), relyingParty.getRegistrationId() ) ) @@ -131,4 +154,56 @@ private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestCon .build() ; } + + private String htmlEscape(String value) { + if (hasText(value)) { + return HtmlUtils.htmlEscape(value); + } + return value; + } + + private String createSamlPostRequestFormData(Saml2PostAuthenticationRequest request) { + String destination = request.getAuthenticationRequestUri(); + String relayState = htmlEscape(request.getRelayState()); + String samlRequest = htmlEscape(request.getSamlRequest()); + StringBuilder postHtml = new StringBuilder() + .append("\n") + .append("\n") + .append(" \n") + .append(" \n") + .append(" \n") + .append(" \n") + .append(" \n") + .append(" \n") + .append("
\n") + .append("
\n") + .append(" \n") + ; + if (hasText(relayState)) { + postHtml + .append(" \n"); + } + postHtml + .append("
\n") + .append(" \n") + .append("
\n") + .append(" \n") + .append(" \n") + .append("") + ; + return postHtml.toString(); + } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java index 5834b5176de..a795538f501 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java @@ -26,12 +26,12 @@ import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; -import java.nio.charset.StandardCharsets; - +import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode; import static org.springframework.security.saml2.provider.service.authentication.TestSaml2X509Credentials.relyingPartyCredentials; +import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT; @@ -46,19 +46,20 @@ public class OpenSamlAuthenticationRequestFactoryTests { @Rule public ExpectedException exception = ExpectedException.none(); + private RelyingPartyRegistration relyingPartyRegistration; @Before public void setUp() { - RelyingPartyRegistration registration = RelyingPartyRegistration.withRegistrationId("id") + relyingPartyRegistration = RelyingPartyRegistration.withRegistrationId("id") .assertionConsumerServiceUrlTemplate("template") - .idpWebSsoUrl("https://destination/sso") - .remoteIdpEntityId("remote-entity-id") + .providerDetails(c -> c.webSsoUrl("https://destination/sso")) + .providerDetails(c -> c.entityId("remote-entity-id")) .localEntityIdTemplate("local-entity-id") .credentials(c -> c.addAll(relyingPartyCredentials())) .build(); contextBuilder = Saml2AuthenticationRequestContext.builder() .issuer("https://issuer") - .relyingPartyRegistration(registration) + .relyingPartyRegistration(relyingPartyRegistration) .assertionConsumerServiceUrl("https://issuer/sso"); context = contextBuilder.build(); factory = new OpenSamlAuthenticationRequestFactory(); @@ -84,6 +85,60 @@ public void createRedirectAuthenticationRequestWhenUsingContextThenAllValuesAreS assertThat(result.getBinding()).isEqualTo(REDIRECT); } + @Test + public void createRedirectAuthenticationRequestWhenNotSignRequestThenNoSignatureIsPresent() { + + context = contextBuilder + .relayState("Relay State Value") + .relyingPartyRegistration( + withRelyingPartyRegistration(relyingPartyRegistration) + .providerDetails(c -> c.signAuthNRequest(false)) + .build() + ) + .build(); + Saml2RedirectAuthenticationRequest result = factory.createRedirectAuthenticationRequest(context); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isEqualTo("Relay State Value"); + assertThat(result.getSigAlg()).isNull(); + assertThat(result.getSignature()).isNull(); + assertThat(result.getBinding()).isEqualTo(REDIRECT); + } + + @Test + public void createPostAuthenticationRequestWhenNotSignRequestThenNoSignatureIsPresent() { + context = contextBuilder + .relayState("Relay State Value") + .relyingPartyRegistration( + withRelyingPartyRegistration(relyingPartyRegistration) + .providerDetails(c -> c.signAuthNRequest(false)) + .build() + ) + .build(); + Saml2PostAuthenticationRequest result = factory.createPostAuthenticationRequest(context); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isEqualTo("Relay State Value"); + assertThat(result.getBinding()).isEqualTo(POST); + assertThat(new String(samlDecode(result.getSamlRequest()), UTF_8)) + .doesNotContain("ds:Signature"); + } + + @Test + public void createPostAuthenticationRequestWhenSignRequestThenSignatureIsPresent() { + context = contextBuilder + .relayState("Relay State Value") + .relyingPartyRegistration( + withRelyingPartyRegistration(relyingPartyRegistration) + .build() + ) + .build(); + Saml2PostAuthenticationRequest result = factory.createPostAuthenticationRequest(context); + assertThat(result.getSamlRequest()).isNotEmpty(); + assertThat(result.getRelayState()).isEqualTo("Relay State Value"); + assertThat(result.getBinding()).isEqualTo(POST); + assertThat(new String(samlDecode(result.getSamlRequest()), UTF_8)) + .contains("ds:Signature"); + } + @Test public void createAuthenticationRequestWhenDefaultThenReturnsPostBinding() { AuthnRequest authn = getAuthNRequest(POST); @@ -114,7 +169,7 @@ private AuthnRequest getAuthNRequest(Saml2MessageBinding binding) { samlRequest = Saml2Utils.samlInflate(samlDecode(samlRequest)); } else { - samlRequest = new String(samlDecode(samlRequest), StandardCharsets.UTF_8); + samlRequest = new String(samlDecode(samlRequest), UTF_8); } return (AuthnRequest) OpenSamlImplementation.getInstance().resolve(samlRequest); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestFactoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestFactoryTests.java index 6c79ba187d8..7b66577fbf9 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestFactoryTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationRequestFactoryTests.java @@ -33,8 +33,8 @@ public class Saml2AuthenticationRequestFactoryTests { private RelyingPartyRegistration registration = RelyingPartyRegistration.withRegistrationId("id") .assertionConsumerServiceUrlTemplate("template") - .idpWebSsoUrl("https://example.com/destination") - .remoteIdpEntityId("remote-entity-id") + .providerDetails(c -> c.webSsoUrl("https://example.com/destination")) + .providerDetails(c -> c.entityId("remote-entity-id")) .localEntityIdTemplate("local-entity-id") .credentials(c -> c.addAll(relyingPartyCredentials())) .build(); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java new file mode 100644 index 00000000000..b060599fd81 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java @@ -0,0 +1,92 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.registration; + +import org.junit.Test; +import org.springframework.security.saml2.credentials.Saml2X509Credential; +import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; + +import static org.assertj.core.api.Assertions.assertThat; + +public class RelyingPartyRegistrationTests { + + @Test + public void withRelyingPartyRegistrationWorks() { + RelyingPartyRegistration registration = relyingPartyRegistration(); + RelyingPartyRegistration copy = RelyingPartyRegistration.withRelyingPartyRegistration(registration).build(); + compareRegistrations(registration, copy); + } + + private void compareRegistrations(RelyingPartyRegistration registration, RelyingPartyRegistration copy) { + assertThat(copy.getRegistrationId()) + .isEqualTo(registration.getRegistrationId()) + .isEqualTo("simplesamlphp"); + assertThat(copy.getProviderDetails().getEntityId()) + .isEqualTo(registration.getProviderDetails().getEntityId()) + .isEqualTo("https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/metadata.php"); + assertThat(copy.getAssertionConsumerServiceUrlTemplate()) + .isEqualTo(registration.getAssertionConsumerServiceUrlTemplate()) + .isEqualTo("{baseUrl}" + Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI); + assertThat(copy.getCredentials()) + .containsAll(registration.getCredentials()) + .containsExactly( + registration.getCredentials().get(0), + registration.getCredentials().get(1) + ); + assertThat(copy.getLocalEntityIdTemplate()) + .isEqualTo(registration.getLocalEntityIdTemplate()) + .isEqualTo("{baseUrl}/saml2/service-provider-metadata/{registrationId}"); + assertThat(copy.getProviderDetails().getWebSsoUrl()) + .isEqualTo(registration.getProviderDetails().getWebSsoUrl()) + .isEqualTo("https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/SSOService.php"); + assertThat(copy.getProviderDetails().getBinding()) + .isEqualTo(registration.getProviderDetails().getBinding()) + .isEqualTo(Saml2MessageBinding.POST); + assertThat(copy.getProviderDetails().isSignAuthNRequest()) + .isEqualTo(registration.getProviderDetails().isSignAuthNRequest()) + .isFalse(); + } + + + private RelyingPartyRegistration relyingPartyRegistration() { + //remote IDP entity ID + String idpEntityId = "https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/metadata.php"; + //remote WebSSO Endpoint - Where to Send AuthNRequests to + String webSsoEndpoint = "https://simplesaml-for-spring-saml.cfapps.io/saml2/idp/SSOService.php"; + //local registration ID + String registrationId = "simplesamlphp"; + //local entity ID - autogenerated based on URL + String localEntityIdTemplate = "{baseUrl}/saml2/service-provider-metadata/{registrationId}"; + //local signing (and decryption key) + Saml2X509Credential signingCredential = TestSaml2X509Credentials.relyingPartyCredentials().get(0); + //IDP certificate for verification of incoming messages + Saml2X509Credential idpVerificationCertificate = TestSaml2X509Credentials.relyingPartyCredentials().get(1); + String acsUrlTemplate = "{baseUrl}" + Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI; + return RelyingPartyRegistration.withRegistrationId(registrationId) + .providerDetails(c -> { + c.webSsoUrl(webSsoEndpoint); + c.binding(Saml2MessageBinding.POST); + c.signAuthNRequest(false); + c.entityId(idpEntityId); + }) + .credentials(c -> c.add(signingCredential)) + .credentials(c -> c.add(idpVerificationCertificate)) + .localEntityIdTemplate(localEntityIdTemplate) + .assertionConsumerServiceUrlTemplate(acsUrlTemplate) + .build(); + } +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/TestSaml2X509Credentials.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/TestSaml2X509Credentials.java new file mode 100644 index 00000000000..b49c32eb90e --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/TestSaml2X509Credentials.java @@ -0,0 +1,143 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.registration; + +import org.opensaml.security.crypto.KeySupport; +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.credentials.Saml2X509Credential; + +import java.io.ByteArrayInputStream; +import java.security.KeyException; +import java.security.PrivateKey; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.List; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION; +import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION; +import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.SIGNING; +import static org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION; + +final class TestSaml2X509Credentials { + static List relyingPartyCredentials() { + return Arrays.asList( + new Saml2X509Credential( + spPrivateKey(), + spCertificate(), + SIGNING, + DECRYPTION + ), + new Saml2X509Credential( + idpCertificate(), + ENCRYPTION, + VERIFICATION + ) + ); + } + + private static X509Certificate certificate(String cert) { + ByteArrayInputStream certBytes = new ByteArrayInputStream(cert.getBytes()); + try { + return (X509Certificate) CertificateFactory + .getInstance("X.509") + .generateCertificate(certBytes); + } + catch (CertificateException e) { + throw new Saml2Exception(e); + } + } + + private static PrivateKey privateKey(String key) { + try { + return KeySupport.decodePrivateKey(key.getBytes(UTF_8), new char[0]); + } + catch (KeyException e) { + throw new Saml2Exception(e); + } + } + + private static X509Certificate idpCertificate() { + return certificate("-----BEGIN CERTIFICATE-----\n" + + "MIIEEzCCAvugAwIBAgIJAIc1qzLrv+5nMA0GCSqGSIb3DQEBCwUAMIGfMQswCQYD\n" + + "VQQGEwJVUzELMAkGA1UECAwCQ08xFDASBgNVBAcMC0Nhc3RsZSBSb2NrMRwwGgYD\n" + + "VQQKDBNTYW1sIFRlc3RpbmcgU2VydmVyMQswCQYDVQQLDAJJVDEgMB4GA1UEAwwX\n" + + "c2ltcGxlc2FtbHBocC5jZmFwcHMuaW8xIDAeBgkqhkiG9w0BCQEWEWZoYW5pa0Bw\n" + + "aXZvdGFsLmlvMB4XDTE1MDIyMzIyNDUwM1oXDTI1MDIyMjIyNDUwM1owgZ8xCzAJ\n" + + "BgNVBAYTAlVTMQswCQYDVQQIDAJDTzEUMBIGA1UEBwwLQ2FzdGxlIFJvY2sxHDAa\n" + + "BgNVBAoME1NhbWwgVGVzdGluZyBTZXJ2ZXIxCzAJBgNVBAsMAklUMSAwHgYDVQQD\n" + + "DBdzaW1wbGVzYW1scGhwLmNmYXBwcy5pbzEgMB4GCSqGSIb3DQEJARYRZmhhbmlr\n" + + "QHBpdm90YWwuaW8wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC4cn62\n" + + "E1xLqpN34PmbrKBbkOXFjzWgJ9b+pXuaRft6A339uuIQeoeH5qeSKRVTl32L0gdz\n" + + "2ZivLwZXW+cqvftVW1tvEHvzJFyxeTW3fCUeCQsebLnA2qRa07RkxTo6Nf244mWW\n" + + "RDodcoHEfDUSbxfTZ6IExSojSIU2RnD6WllYWFdD1GFpBJOmQB8rAc8wJIBdHFdQ\n" + + "nX8Ttl7hZ6rtgqEYMzYVMuJ2F2r1HSU1zSAvwpdYP6rRGFRJEfdA9mm3WKfNLSc5\n" + + "cljz0X/TXy0vVlAV95l9qcfFzPmrkNIst9FZSwpvB49LyAVke04FQPPwLgVH4gph\n" + + "iJH3jvZ7I+J5lS8VAgMBAAGjUDBOMB0GA1UdDgQWBBTTyP6Cc5HlBJ5+ucVCwGc5\n" + + "ogKNGzAfBgNVHSMEGDAWgBTTyP6Cc5HlBJ5+ucVCwGc5ogKNGzAMBgNVHRMEBTAD\n" + + "AQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAvMS4EQeP/ipV4jOG5lO6/tYCb/iJeAduO\n" + + "nRhkJk0DbX329lDLZhTTL/x/w/9muCVcvLrzEp6PN+VWfw5E5FWtZN0yhGtP9R+v\n" + + "ZnrV+oc2zGD+no1/ySFOe3EiJCO5dehxKjYEmBRv5sU/LZFKZpozKN/BMEa6CqLu\n" + + "xbzb7ykxVr7EVFXwltPxzE9TmL9OACNNyF5eJHWMRMllarUvkcXlh4pux4ks9e6z\n" + + "V9DQBy2zds9f1I3qxg0eX6JnGrXi/ZiCT+lJgVe3ZFXiejiLAiKB04sXW3ti0LW3\n" + + "lx13Y1YlQ4/tlpgTgfIJxKV6nyPiLoK0nywbMd+vpAirDt2Oc+hk\n" + + "-----END CERTIFICATE-----\n"); + } + + + private static X509Certificate spCertificate() { + + return certificate("-----BEGIN CERTIFICATE-----\n" + + "MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBhMC\n" + + "VVMxEzARBgNVBAgMCldhc2hpbmd0b24xEjAQBgNVBAcMCVZhbmNvdXZlcjEdMBsG\n" + + "A1UECgwUU3ByaW5nIFNlY3VyaXR5IFNBTUwxCzAJBgNVBAsMAnNwMSAwHgYDVQQD\n" + + "DBdzcC5zcHJpbmcuc2VjdXJpdHkuc2FtbDAeFw0xODA1MTQxNDMwNDRaFw0yODA1\n" + + "MTExNDMwNDRaMIGEMQswCQYDVQQGEwJVUzETMBEGA1UECAwKV2FzaGluZ3RvbjES\n" + + "MBAGA1UEBwwJVmFuY291dmVyMR0wGwYDVQQKDBRTcHJpbmcgU2VjdXJpdHkgU0FN\n" + + "TDELMAkGA1UECwwCc3AxIDAeBgNVBAMMF3NwLnNwcmluZy5zZWN1cml0eS5zYW1s\n" + + "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDRu7/EI0BlNzMEBFVAcbx+lLos\n" + + "vzIWU+01dGTY8gBdhMQNYKZ92lMceo2CuVJ66cUURPym3i7nGGzoSnAxAre+0YIM\n" + + "+U0razrWtAUE735bkcqELZkOTZLelaoOztmWqRbe5OuEmpewH7cx+kNgcVjdctOG\n" + + "y3Q6x+I4qakY/9qhBQIDAQABMA0GCSqGSIb3DQEBCwUAA4GBAAeViTvHOyQopWEi\n" + + "XOfI2Z9eukwrSknDwq/zscR0YxwwqDBMt/QdAODfSwAfnciiYLkmEjlozWRtOeN+\n" + + "qK7UFgP1bRl5qksrYX5S0z2iGJh0GvonLUt3e20Ssfl5tTEDDnAEUMLfBkyaxEHD\n" + + "RZ/nbTJ7VTeZOSyRoVn5XHhpuJ0B\n" + + "-----END CERTIFICATE-----"); + } + + private static PrivateKey spPrivateKey() { + return privateKey("-----BEGIN PRIVATE KEY-----\n" + + "MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBANG7v8QjQGU3MwQE\n" + + "VUBxvH6Uuiy/MhZT7TV0ZNjyAF2ExA1gpn3aUxx6jYK5UnrpxRRE/KbeLucYbOhK\n" + + "cDECt77Rggz5TStrOta0BQTvfluRyoQtmQ5Nkt6Vqg7O2ZapFt7k64Sal7AftzH6\n" + + "Q2BxWN1y04bLdDrH4jipqRj/2qEFAgMBAAECgYEAj4ExY1jjdN3iEDuOwXuRB+Nn\n" + + "x7pC4TgntE2huzdKvLJdGvIouTArce8A6JM5NlTBvm69mMepvAHgcsiMH1zGr5J5\n" + + "wJz23mGOyhM1veON41/DJTVG+cxq4soUZhdYy3bpOuXGMAaJ8QLMbQQoivllNihd\n" + + "vwH0rNSK8LTYWWPZYIECQQDxct+TFX1VsQ1eo41K0T4fu2rWUaxlvjUGhK6HxTmY\n" + + "8OMJptunGRJL1CUjIb45Uz7SP8TPz5FwhXWsLfS182kRAkEA3l+Qd9C9gdpUh1uX\n" + + "oPSNIxn5hFUrSTW1EwP9QH9vhwb5Vr8Jrd5ei678WYDLjUcx648RjkjhU9jSMzIx\n" + + "EGvYtQJBAMm/i9NR7IVyyNIgZUpz5q4LI21rl1r4gUQuD8vA36zM81i4ROeuCly0\n" + + "KkfdxR4PUfnKcQCX11YnHjk9uTFj75ECQEFY/gBnxDjzqyF35hAzrYIiMPQVfznt\n" + + "YX/sDTE2AdVBVGaMj1Cb51bPHnNC6Q5kXKQnj/YrLqRQND09Q7ParX0CQQC5NxZr\n" + + "9jKqhHj8yQD6PlXTsY4Occ7DH6/IoDenfdEVD5qlet0zmd50HatN2Jiqm5ubN7CM\n" + + "INrtuLp4YHbgk1mi\n" + + "-----END PRIVATE KEY-----"); + } + +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java index 3e2af6e65d3..4b5bb37bb4b 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java @@ -16,10 +16,6 @@ package org.springframework.security.saml2.provider.service.servlet.filter; -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import javax.servlet.ServletException; - import org.junit.Before; import org.junit.Test; import org.springframework.mock.web.MockFilterChain; @@ -27,11 +23,17 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.web.util.HtmlUtils; import org.springframework.web.util.UriUtils; +import javax.servlet.ServletException; +import java.io.IOException; +import java.nio.charset.StandardCharsets; + import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; import static org.springframework.security.saml2.provider.service.servlet.filter.TestSaml2SigningCredentials.signingCredential; public class Saml2WebSsoAuthenticationRequestFilterTests { @@ -55,8 +57,8 @@ public void setup() { rpBuilder = RelyingPartyRegistration .withRegistrationId("registration-id") - .remoteIdpEntityId("idp-entity-id") - .idpWebSsoUrl(IDP_SSO_URL) + .providerDetails(c -> c.entityId("idp-entity-id")) + .providerDetails(c -> c.webSsoUrl(IDP_SSO_URL)) .assertionConsumerServiceUrlTemplate("template") .credentials(c -> c.add(signingCredential())); } @@ -109,4 +111,40 @@ public void doFilterWhenSimpleSignatureSpecifiedThenSignatureParametersAreInTheR .startsWith(IDP_SSO_URL); } + @Test + public void doFilterWhenSignatureIsDisabledThenSignatureParametersAreNotInTheRedirectURL() throws Exception { + when(repository.findByRegistrationId("registration-id")).thenReturn( + rpBuilder + .providerDetails(c -> c.signAuthNRequest(false)) + .build() + ); + final String relayStateValue = "https://my-relay-state.example.com?with=param&other=param"; + final String relayStateEncoded = UriUtils.encode(relayStateValue, StandardCharsets.ISO_8859_1); + request.setParameter("RelayState", relayStateValue); + filter.doFilterInternal(request, response, filterChain); + assertThat(response.getHeader("Location")) + .contains("RelayState="+relayStateEncoded) + .doesNotContain("SigAlg=") + .doesNotContain("Signature=") + .startsWith(IDP_SSO_URL); + } + + @Test + public void doFilterWhenPostFormDataIsPresent() throws Exception { + when(repository.findByRegistrationId("registration-id")).thenReturn( + rpBuilder + .providerDetails(c -> c.binding(POST)) + .build() + ); + final String relayStateValue = "https://my-relay-state.example.com?with=param&other=param&javascript{alert('1');}"; + final String relayStateEncoded = HtmlUtils.htmlEscape(relayStateValue); + request.setParameter("RelayState", relayStateValue); + filter.doFilterInternal(request, response, filterChain); + assertThat(response.getHeader("Location")).isNull(); + assertThat(response.getContentAsString()) + .contains("
") + .contains(" config.entityId(idpEntityId)) + .providerDetails(config -> config.webSsoUrl(webSsoEndpoint)) .credentials(c -> c.add(signingCredential)) .credentials(c -> c.add(idpVerificationCertificate)) .localEntityIdTemplate(localEntityIdTemplate)