diff --git a/spring-social-web/src/main/java/org/springframework/social/connect/signin/web/ProviderSignInController.java b/spring-social-web/src/main/java/org/springframework/social/connect/signin/web/ProviderSignInController.java index 1dd54ab89..918156dea 100644 --- a/spring-social-web/src/main/java/org/springframework/social/connect/signin/web/ProviderSignInController.java +++ b/spring-social-web/src/main/java/org/springframework/social/connect/signin/web/ProviderSignInController.java @@ -70,6 +70,10 @@ public class ProviderSignInController { private String signupUrl = "/signup"; private String postLoginUrl = "/"; + + private final boolean absoluteCallback; + + private final String trailingChar; /** * Creates a new provider sign-in controller. @@ -90,6 +94,8 @@ public ProviderSignInController(String applicationUrl, Provider connectionFactory = getConnectionFactoryLocator().getConnectionFactory(providerId); if (connectionFactory instanceof OAuth1ConnectionFactory) { OAuth1Operations oauth1Ops = ((OAuth1ConnectionFactory) connectionFactory).getOAuthOperations(); - OAuthToken requestToken = oauth1Ops.fetchRequestToken(callbackUrl(providerId), null); + OAuthToken requestToken = oauth1Ops.fetchRequestToken(callbackUrl(providerId, request), null); request.setAttribute(OAUTH_TOKEN_ATTRIBUTE, requestToken, WebRequest.SCOPE_SESSION); - String authenticateUrl = oauth1Ops.buildAuthenticateUrl(requestToken.getValue(), oauth1Ops.getVersion() == OAuth1Version.CORE_10 ? new OAuth1Parameters(callbackUrl(providerId)) : OAuth1Parameters.NONE); + String authenticateUrl = oauth1Ops.buildAuthenticateUrl(requestToken.getValue(), oauth1Ops.getVersion() == OAuth1Version.CORE_10 ? new OAuth1Parameters(callbackUrl(providerId, request)) : OAuth1Parameters.NONE); return new RedirectView(authenticateUrl); } else if (connectionFactory instanceof OAuth2ConnectionFactory) { OAuth2Operations oauth2Ops = ((OAuth2ConnectionFactory) connectionFactory).getOAuthOperations(); - String authenticateUrl = oauth2Ops.buildAuthenticateUrl(GrantType.AUTHORIZATION_CODE, new OAuth2Parameters(callbackUrl(providerId), request.getParameter("scope"))); + String authenticateUrl = oauth2Ops.buildAuthenticateUrl(GrantType.AUTHORIZATION_CODE, new OAuth2Parameters(callbackUrl(providerId, request), request.getParameter("scope"))); return new RedirectView(authenticateUrl); } else { return handleSignInWithConnectionFactory(connectionFactory, request); @@ -162,7 +168,7 @@ public RedirectView oauth1Callback(@PathVariable String providerId, @RequestPara @RequestMapping(value="/{providerId}", method=RequestMethod.GET, params="code") public RedirectView oauth2Callback(@PathVariable String providerId, @RequestParam("code") String code, WebRequest request) { OAuth2ConnectionFactory connectionFactory = (OAuth2ConnectionFactory) getConnectionFactoryLocator().getConnectionFactory(providerId); - AccessGrant accessGrant = connectionFactory.getOAuthOperations().exchangeForAccess(code, callbackUrl(providerId), null); + AccessGrant accessGrant = connectionFactory.getOAuthOperations().exchangeForAccess(code, callbackUrl(providerId, request), null); Connection connection = connectionFactory.createConnection(accessGrant); return handleSignIn(connection, request); } @@ -183,8 +189,10 @@ private ConnectionFactoryLocator getConnectionFactoryLocator() { return connectionFactoryLocatorProvider.get(); } - private String callbackUrl(String providerId) { - return baseCallbackUrl + "/" + providerId; + protected String callbackUrl(String providerId, WebRequest request) { + if (absoluteCallback) return baseCallbackUrl + "/" + providerId; + String proto = request.isSecure() ? "https://" : "http://"; + return proto + request.getHeader("Host") + trailingChar + baseCallbackUrl + "/" + providerId; } private OAuthToken extractCachedRequestToken(WebRequest request) { diff --git a/spring-social-web/src/main/java/org/springframework/social/connect/web/ConnectController.java b/spring-social-web/src/main/java/org/springframework/social/connect/web/ConnectController.java index 4600a3db7..4eddc3e50 100644 --- a/spring-social-web/src/main/java/org/springframework/social/connect/web/ConnectController.java +++ b/spring-social-web/src/main/java/org/springframework/social/connect/web/ConnectController.java @@ -75,6 +75,10 @@ public class ConnectController { private Provider connectionRepositoryProvider; + private final boolean absoluteCallback; + + private final String trailingChar; + /** * Constructs a ConnectController. * @param applicationUrl the base secure URL for this application, used to construct the callback URL passed to the service providers at the beginning of the connection process. @@ -87,6 +91,8 @@ public ConnectController(String applicationUrl, ConnectionFactoryLocator connect this.connectionFactoryLocator = connectionFactoryLocator; this.connectionRepositoryProvider = connectionRepositoryProvider; this.interceptors = new LinkedMultiValueMap, ConnectInterceptor>(); + this.absoluteCallback = applicationUrl.matches("(?i)https?://.*"); + this.trailingChar = (!absoluteCallback && applicationUrl.indexOf("/") == 0) ? "" : "/"; } /** @@ -134,13 +140,13 @@ public RedirectView connect(@PathVariable String providerId, WebRequest request) preConnect(connectionFactory, request); if (connectionFactory instanceof OAuth1ConnectionFactory) { OAuth1Operations oauth1Ops = ((OAuth1ConnectionFactory) connectionFactory).getOAuthOperations(); - OAuthToken requestToken = oauth1Ops.fetchRequestToken(callbackUrl(providerId), null); + OAuthToken requestToken = oauth1Ops.fetchRequestToken(callbackUrl(providerId, request), null); request.setAttribute(OAUTH_TOKEN_ATTRIBUTE, requestToken, WebRequest.SCOPE_SESSION); - String authorizeUrl = oauth1Ops.buildAuthorizeUrl(requestToken.getValue(), oauth1Ops.getVersion() == OAuth1Version.CORE_10 ? new OAuth1Parameters(callbackUrl(providerId)) : OAuth1Parameters.NONE); + String authorizeUrl = oauth1Ops.buildAuthorizeUrl(requestToken.getValue(), oauth1Ops.getVersion() == OAuth1Version.CORE_10 ? new OAuth1Parameters(callbackUrl(providerId, request)) : OAuth1Parameters.NONE); return new RedirectView(authorizeUrl); } else if (connectionFactory instanceof OAuth2ConnectionFactory) { OAuth2Operations oauth2Ops = ((OAuth2ConnectionFactory) connectionFactory).getOAuthOperations(); - String authorizeUrl = oauth2Ops.buildAuthorizeUrl(GrantType.AUTHORIZATION_CODE, new OAuth2Parameters(callbackUrl(providerId), request.getParameter("scope"))); + String authorizeUrl = oauth2Ops.buildAuthorizeUrl(GrantType.AUTHORIZATION_CODE, new OAuth2Parameters(callbackUrl(providerId, request), request.getParameter("scope"))); return new RedirectView(authorizeUrl); } else { return handleConnectToCustomConnectionFactory(connectionFactory, request); @@ -170,7 +176,7 @@ public RedirectView oauth1Callback(@PathVariable String providerId, @RequestPara @RequestMapping(value="/{providerId}", method=RequestMethod.GET, params="code") public RedirectView oauth2Callback(@PathVariable String providerId, @RequestParam("code") String code, WebRequest request) { OAuth2ConnectionFactory connectionFactory = (OAuth2ConnectionFactory) connectionFactoryLocator.getConnectionFactory(providerId); - AccessGrant accessGrant = connectionFactory.getOAuthOperations().exchangeForAccess(code, callbackUrl(providerId), null); + AccessGrant accessGrant = connectionFactory.getOAuthOperations().exchangeForAccess(code, callbackUrl(providerId, request), null); Connection connection = connectionFactory.createConnection(accessGrant); addConnection(request, connectionFactory, connection); return redirectToProvider(providerId); @@ -243,8 +249,10 @@ private String baseViewPath(String providerId) { return "connect/" + providerId; } - private String callbackUrl(String providerId) { - return baseCallbackUrl + "/" + providerId; + protected String callbackUrl(String providerId, WebRequest request) { + if (absoluteCallback) return baseCallbackUrl + "/" + providerId; + String proto = request.isSecure() ? "https://" : "http://"; + return proto + request.getHeader("Host") + trailingChar + baseCallbackUrl + "/" + providerId; } private OAuthToken extractCachedRequestToken(WebRequest request) {