Skip to content

Add ID to Saml2 Post and Redirect Requests #11489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class Saml2PostAuthenticationRequestMixin {
Saml2PostAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest,
@JsonProperty("relayState") String relayState,
@JsonProperty("authenticationRequestUri") String authenticationRequestUri,
@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) {
@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId,
@JsonProperty("id") String id) {
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class Saml2RedirectAuthenticationRequestMixin {
@JsonProperty("sigAlg") String sigAlg, @JsonProperty("signature") String signature,
@JsonProperty("relayState") String relayState,
@JsonProperty("authenticationRequestUri") String authenticationRequestUri,
@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) {
@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId,
@JsonProperty("id") String id) {
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable

private final String relyingPartyRegistrationId;

private final String id;

/**
* Mandatory constructor for the {@link AbstractSaml2AuthenticationRequest}
* @param samlRequest - the SAMLRequest XML data, SAML encoded, cannot be empty or
Expand All @@ -58,15 +60,18 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
* send the XML message, cannot be empty or null
* @param relyingPartyRegistrationId the registration id of the relying party, may be
* null
* @param id This is the unique id used in the {@link #samlRequest}, cannot be empty
* or null
*/
AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri,
String relyingPartyRegistrationId) {
String relyingPartyRegistrationId, String id) {
Assert.hasText(samlRequest, "samlRequest cannot be null or empty");
Assert.hasText(authenticationRequestUri, "authenticationRequestUri cannot be null or empty");
this.authenticationRequestUri = authenticationRequestUri;
this.samlRequest = samlRequest;
this.relayState = relayState;
this.relyingPartyRegistrationId = relyingPartyRegistrationId;
this.id = id;
}

/**
Expand Down Expand Up @@ -106,6 +111,15 @@ public String getRelyingPartyRegistrationId() {
return this.relyingPartyRegistrationId;
}

/**
* The unique identifier for this Authentication Request
* @return the Authentication Request identifier
* @since 5.8
*/
public String getId() {
return this.id;
}

/**
* Returns the binding this AuthNRequest will be sent and encoded with. If
* {@link Saml2MessageBinding#REDIRECT} is used, the DEFLATE encoding will be
Expand All @@ -127,6 +141,8 @@ public static class Builder<T extends Builder<T>> {

String relyingPartyRegistrationId;

String id;

/**
* @deprecated Use {@link #Builder(RelyingPartyRegistration)} instead
*/
Expand Down Expand Up @@ -184,6 +200,19 @@ public T authenticationRequestUri(String authenticationRequestUri) {
return _this();
}

/**
* This is the unique id used in the {@link #samlRequest}
* @param id the SAML2 request id
* @return the {@link AbstractSaml2AuthenticationRequest.Builder} for further
* configurations
* @since 5.8
*/
public T id(String id) {
Assert.notNull(id, "id cannot be null");
this.id = id;
return _this();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationRequest {

Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri,
String relyingPartyRegistrationId) {
super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId);
String relyingPartyRegistrationId, String id) {
super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId, id);
}

/**
Expand Down Expand Up @@ -69,7 +69,7 @@ private Builder(RelyingPartyRegistration registration) {
*/
public Saml2PostAuthenticationRequest build() {
return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri,
this.relyingPartyRegistrationId);
this.relyingPartyRegistrationId, this.id);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
private final String signature;

private Saml2RedirectAuthenticationRequest(String samlRequest, String sigAlg, String signature, String relayState,
String authenticationRequestUri, String relyingPartyRegistrationId) {
super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId);
String authenticationRequestUri, String relyingPartyRegistrationId, String id) {
super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId, id);
this.sigAlg = sigAlg;
this.signature = signature;
}
Expand Down Expand Up @@ -116,7 +116,7 @@ public Builder signature(String signature) {
*/
public Saml2RedirectAuthenticationRequest build() {
return new Saml2RedirectAuthenticationRequest(this.samlRequest, this.sigAlg, this.signature,
this.relayState, this.authenticationRequestUri, this.relyingPartyRegistrationId);
this.relayState, this.authenticationRequestUri, this.relyingPartyRegistrationId, this.id);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,14 @@ <T extends AbstractSaml2AuthenticationRequest> T resolve(HttpServletRequest requ
String xml = serialize(authnRequest);
String encoded = Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8));
return (T) Saml2PostAuthenticationRequest.withRelyingPartyRegistration(registration).samlRequest(encoded)
.relayState(relayState).build();
.relayState(relayState).id(authnRequest.getID()).build();
}
else {
String xml = serialize(authnRequest);
String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
Saml2RedirectAuthenticationRequest.Builder builder = Saml2RedirectAuthenticationRequest
.withRelyingPartyRegistration(registration).samlRequest(deflatedAndEncoded).relayState(relayState);
.withRelyingPartyRegistration(registration).samlRequest(deflatedAndEncoded).relayState(relayState)
.id(authnRequest.getID());
if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
Map<String, String> parameters = OpenSamlSigningUtils.sign(registration)
.param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ void shouldDeserialize() throws Exception {
.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
assertThat(authRequest.getRelyingPartyRegistrationId())
.isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID);
assertThat(authRequest.getId()).isEqualTo(TestSaml2JsonPayloads.ID);
}

@Test
Expand All @@ -73,6 +74,24 @@ void shouldDeserializeWithNoRegistrationId() throws Exception {
assertThat(authRequest.getAuthenticationRequestUri())
.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
assertThat(authRequest.getRelyingPartyRegistrationId()).isNull();
assertThat(authRequest.getId()).isEqualTo(TestSaml2JsonPayloads.ID);
}

@Test
void shouldDeserializeWithNoId() throws Exception {
String json = TestSaml2JsonPayloads.DEFAULT_POST_AUTH_REQUEST_JSON
.replace(", \"id\": \"" + TestSaml2JsonPayloads.ID + "\"", "");

Saml2PostAuthenticationRequest authRequest = this.mapper.readValue(json, Saml2PostAuthenticationRequest.class);

assertThat(authRequest).isNotNull();
assertThat(authRequest.getSamlRequest()).isEqualTo(TestSaml2JsonPayloads.SAML_REQUEST);
assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE);
assertThat(authRequest.getAuthenticationRequestUri())
.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
assertThat(authRequest.getRelyingPartyRegistrationId())
.isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID);
assertThat(authRequest.getId()).isNull();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ static DefaultSaml2AuthenticatedPrincipal createDefaultPrincipal() {
static final String RELYINGPARTY_REGISTRATION_ID = "registrationIdValue";
static final String SIG_ALG = "sigAlgValue";
static final String SIGNATURE = "signatureValue";
static final String ID = "idValue";

// @formatter:off
static final String DEFAULT_REDIRECT_AUTH_REQUEST_JSON = "{"
Expand All @@ -106,7 +107,8 @@ static DefaultSaml2AuthenticatedPrincipal createDefaultPrincipal() {
+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\","
+ " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\","
+ " \"sigAlg\": \"" + SIG_ALG + "\","
+ " \"signature\": \"" + SIGNATURE + "\""
+ " \"signature\": \"" + SIGNATURE + "\","
+ " \"id\": \"" + ID + "\""
+ "}";
// @formatter:on

Expand All @@ -116,11 +118,11 @@ static DefaultSaml2AuthenticatedPrincipal createDefaultPrincipal() {
+ " \"samlRequest\": \"" + SAML_REQUEST + "\","
+ " \"relayState\": \"" + RELAY_STATE + "\","
+ " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\","
+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\""
+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\","
+ " \"id\": \"" + ID + "\""
+ "}";
// @formatter:on

static final String ID = "idValue";
static final String LOCATION = "locationValue";
static final String BINDNG = "REDIRECT";
static final String ADDITIONAL_PARAM = "additionalParamValue";
Expand All @@ -146,7 +148,7 @@ static Saml2PostAuthenticationRequest createDefaultSaml2PostAuthenticationReques
TestRelyingPartyRegistrations.full().registrationId(RELYINGPARTY_REGISTRATION_ID)
.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
.build())
.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).build();
.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).id(ID).build();
}

static Saml2RedirectAuthenticationRequest createDefaultSaml2RedirectAuthenticationRequest() {
Expand All @@ -155,7 +157,7 @@ static Saml2RedirectAuthenticationRequest createDefaultSaml2RedirectAuthenticati
.registrationId(RELYINGPARTY_REGISTRATION_ID)
.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
.build())
.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).sigAlg(SIG_ALG).signature(SIGNATURE).build();
.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).sigAlg(SIG_ALG).signature(SIGNATURE).id(ID).build();
}

static Saml2LogoutRequest createDefaultSaml2LogoutRequest() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

package org.springframework.security.saml2.provider.service.web.authentication;

import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.opensaml.xmlsec.signature.support.SignatureConstants;

import org.springframework.mock.web.MockHttpServletRequest;
Expand All @@ -40,7 +40,7 @@ public class OpenSamlAuthenticationRequestResolverTests {

private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder;

@Before
@BeforeEach
public void setUp() {
this.relyingPartyRegistrationBuilder = TestRelyingPartyRegistrations.relyingPartyRegistration();
}
Expand All @@ -65,6 +65,7 @@ public void resolveAuthenticationRequestWhenSignedRedirectThenSignsAndRedirects(
assertThat(result.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
assertThat(result.getSignature()).isNotEmpty();
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
assertThat(result.getId()).isNotEmpty();
}

@Test
Expand All @@ -88,6 +89,7 @@ public void resolveAuthenticationRequestWhenUnsignedRedirectThenRedirectsAndNoSi
assertThat(result.getSigAlg()).isNull();
assertThat(result.getSignature()).isNull();
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
assertThat(result.getId()).isNotEmpty();
}

@Test
Expand All @@ -98,7 +100,9 @@ public void resolveAuthenticationRequestWhenSignedThenCredentialIsRequired() {
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials()
.assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build();
OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> resolver.resolve(request, null));
assertThatExceptionOfType(Saml2Exception.class)
.isThrownBy(() -> resolver.resolve(request, (r, authnRequest) -> {
}));
}

@Test
Expand All @@ -122,6 +126,7 @@ public void resolveAuthenticationRequestWhenUnsignedPostThenOnlyPosts() {
assertThat(result.getRelayState()).isNotNull();
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST);
assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()))).doesNotContain("Signature");
assertThat(result.getId()).isNotEmpty();
}

@Test
Expand All @@ -144,6 +149,7 @@ public void resolveAuthenticationRequestWhenSignedPostThenSignsAndPosts() {
assertThat(result.getRelayState()).isNotNull();
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.POST);
assertThat(new String(Saml2Utils.samlDecode(result.getSamlRequest()))).contains("Signature");
assertThat(result.getId()).isNotEmpty();
}

@Test
Expand All @@ -154,12 +160,14 @@ public void resolveAuthenticationRequestWhenSHA1SignRequestThenSigns() {
(party) -> party.signingAlgorithms((algs) -> algs.add(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1)))
.build();
OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
Saml2RedirectAuthenticationRequest result = resolver.resolve(request, null);
Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
});
assertThat(result.getSamlRequest()).isNotEmpty();
assertThat(result.getRelayState()).isNotNull();
assertThat(result.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1);
assertThat(result.getSignature()).isNotNull();
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
assertThat(result.getId()).isNotEmpty();
}

private OpenSamlAuthenticationRequestResolver authenticationRequestResolver(RelyingPartyRegistration registration) {
Expand Down