Skip to content
This repository was archived by the owner on Apr 5, 2022. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ public class ProviderSignInController {
private String signupUrl = "/signup";

private String postLoginUrl = "/";

private final boolean absoluteCallback;

private final String trailingChar;

/**
* Creates a new provider sign-in controller.
Expand All @@ -90,6 +94,8 @@ public ProviderSignInController(String applicationUrl, Provider<ConnectionFactor
this.connectionRepositoryProvider = connectionRepositoryProvider;
this.signInService = signInService;
this.baseCallbackUrl = applicationUrl + AnnotationUtils.findAnnotation(getClass(), RequestMapping.class).value()[0];
this.absoluteCallback = applicationUrl.matches("(?i)https?://.*");
this.trailingChar = (!absoluteCallback && applicationUrl.indexOf("/") == 0) ? "" : "/";
}

/**
Expand Down Expand Up @@ -120,13 +126,13 @@ public RedirectView signin(@PathVariable String providerId, WebRequest request)
ConnectionFactory<?> connectionFactory = getConnectionFactoryLocator().getConnectionFactory(providerId);
if (connectionFactory instanceof OAuth1ConnectionFactory) {
OAuth1Operations oauth1Ops = ((OAuth1ConnectionFactory<?>) connectionFactory).getOAuthOperations();
OAuthToken requestToken = oauth1Ops.fetchRequestToken(callbackUrl(providerId), null);
OAuthToken requestToken = oauth1Ops.fetchRequestToken(callbackUrl(providerId, request), null);
request.setAttribute(OAUTH_TOKEN_ATTRIBUTE, requestToken, WebRequest.SCOPE_SESSION);
String authenticateUrl = oauth1Ops.buildAuthenticateUrl(requestToken.getValue(), oauth1Ops.getVersion() == OAuth1Version.CORE_10 ? new OAuth1Parameters(callbackUrl(providerId)) : OAuth1Parameters.NONE);
String authenticateUrl = oauth1Ops.buildAuthenticateUrl(requestToken.getValue(), oauth1Ops.getVersion() == OAuth1Version.CORE_10 ? new OAuth1Parameters(callbackUrl(providerId, request)) : OAuth1Parameters.NONE);
return new RedirectView(authenticateUrl);
} else if (connectionFactory instanceof OAuth2ConnectionFactory) {
OAuth2Operations oauth2Ops = ((OAuth2ConnectionFactory<?>) connectionFactory).getOAuthOperations();
String authenticateUrl = oauth2Ops.buildAuthenticateUrl(GrantType.AUTHORIZATION_CODE, new OAuth2Parameters(callbackUrl(providerId), request.getParameter("scope")));
String authenticateUrl = oauth2Ops.buildAuthenticateUrl(GrantType.AUTHORIZATION_CODE, new OAuth2Parameters(callbackUrl(providerId, request), request.getParameter("scope")));
return new RedirectView(authenticateUrl);
} else {
return handleSignInWithConnectionFactory(connectionFactory, request);
Expand Down Expand Up @@ -162,7 +168,7 @@ public RedirectView oauth1Callback(@PathVariable String providerId, @RequestPara
@RequestMapping(value="/{providerId}", method=RequestMethod.GET, params="code")
public RedirectView oauth2Callback(@PathVariable String providerId, @RequestParam("code") String code, WebRequest request) {
OAuth2ConnectionFactory<?> connectionFactory = (OAuth2ConnectionFactory<?>) getConnectionFactoryLocator().getConnectionFactory(providerId);
AccessGrant accessGrant = connectionFactory.getOAuthOperations().exchangeForAccess(code, callbackUrl(providerId), null);
AccessGrant accessGrant = connectionFactory.getOAuthOperations().exchangeForAccess(code, callbackUrl(providerId, request), null);
Connection<?> connection = connectionFactory.createConnection(accessGrant);
return handleSignIn(connection, request);
}
Expand All @@ -183,8 +189,10 @@ private ConnectionFactoryLocator getConnectionFactoryLocator() {
return connectionFactoryLocatorProvider.get();
}

private String callbackUrl(String providerId) {
return baseCallbackUrl + "/" + providerId;
protected String callbackUrl(String providerId, WebRequest request) {
if (absoluteCallback) return baseCallbackUrl + "/" + providerId;
String proto = request.isSecure() ? "https://" : "http://";
return proto + request.getHeader("Host") + trailingChar + baseCallbackUrl + "/" + providerId;
}

private OAuthToken extractCachedRequestToken(WebRequest request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ public class ConnectController {

private Provider<ConnectionRepository> connectionRepositoryProvider;

private final boolean absoluteCallback;

private final String trailingChar;

/**
* Constructs a ConnectController.
* @param applicationUrl the base secure URL for this application, used to construct the callback URL passed to the service providers at the beginning of the connection process.
Expand All @@ -87,6 +91,8 @@ public ConnectController(String applicationUrl, ConnectionFactoryLocator connect
this.connectionFactoryLocator = connectionFactoryLocator;
this.connectionRepositoryProvider = connectionRepositoryProvider;
this.interceptors = new LinkedMultiValueMap<Class<?>, ConnectInterceptor<?>>();
this.absoluteCallback = applicationUrl.matches("(?i)https?://.*");
this.trailingChar = (!absoluteCallback && applicationUrl.indexOf("/") == 0) ? "" : "/";
}

/**
Expand Down Expand Up @@ -134,13 +140,13 @@ public RedirectView connect(@PathVariable String providerId, WebRequest request)
preConnect(connectionFactory, request);
if (connectionFactory instanceof OAuth1ConnectionFactory) {
OAuth1Operations oauth1Ops = ((OAuth1ConnectionFactory<?>) connectionFactory).getOAuthOperations();
OAuthToken requestToken = oauth1Ops.fetchRequestToken(callbackUrl(providerId), null);
OAuthToken requestToken = oauth1Ops.fetchRequestToken(callbackUrl(providerId, request), null);
request.setAttribute(OAUTH_TOKEN_ATTRIBUTE, requestToken, WebRequest.SCOPE_SESSION);
String authorizeUrl = oauth1Ops.buildAuthorizeUrl(requestToken.getValue(), oauth1Ops.getVersion() == OAuth1Version.CORE_10 ? new OAuth1Parameters(callbackUrl(providerId)) : OAuth1Parameters.NONE);
String authorizeUrl = oauth1Ops.buildAuthorizeUrl(requestToken.getValue(), oauth1Ops.getVersion() == OAuth1Version.CORE_10 ? new OAuth1Parameters(callbackUrl(providerId, request)) : OAuth1Parameters.NONE);
return new RedirectView(authorizeUrl);
} else if (connectionFactory instanceof OAuth2ConnectionFactory) {
OAuth2Operations oauth2Ops = ((OAuth2ConnectionFactory<?>) connectionFactory).getOAuthOperations();
String authorizeUrl = oauth2Ops.buildAuthorizeUrl(GrantType.AUTHORIZATION_CODE, new OAuth2Parameters(callbackUrl(providerId), request.getParameter("scope")));
String authorizeUrl = oauth2Ops.buildAuthorizeUrl(GrantType.AUTHORIZATION_CODE, new OAuth2Parameters(callbackUrl(providerId, request), request.getParameter("scope")));
return new RedirectView(authorizeUrl);
} else {
return handleConnectToCustomConnectionFactory(connectionFactory, request);
Expand Down Expand Up @@ -170,7 +176,7 @@ public RedirectView oauth1Callback(@PathVariable String providerId, @RequestPara
@RequestMapping(value="/{providerId}", method=RequestMethod.GET, params="code")
public RedirectView oauth2Callback(@PathVariable String providerId, @RequestParam("code") String code, WebRequest request) {
OAuth2ConnectionFactory<?> connectionFactory = (OAuth2ConnectionFactory<?>) connectionFactoryLocator.getConnectionFactory(providerId);
AccessGrant accessGrant = connectionFactory.getOAuthOperations().exchangeForAccess(code, callbackUrl(providerId), null);
AccessGrant accessGrant = connectionFactory.getOAuthOperations().exchangeForAccess(code, callbackUrl(providerId, request), null);
Connection<?> connection = connectionFactory.createConnection(accessGrant);
addConnection(request, connectionFactory, connection);
return redirectToProvider(providerId);
Expand Down Expand Up @@ -243,8 +249,10 @@ private String baseViewPath(String providerId) {
return "connect/" + providerId;
}

private String callbackUrl(String providerId) {
return baseCallbackUrl + "/" + providerId;
protected String callbackUrl(String providerId, WebRequest request) {
if (absoluteCallback) return baseCallbackUrl + "/" + providerId;
String proto = request.isSecure() ? "https://" : "http://";
return proto + request.getHeader("Host") + trailingChar + baseCallbackUrl + "/" + providerId;
}

private OAuthToken extractCachedRequestToken(WebRequest request) {
Expand Down