Skip to content

✨ JDBC device_code authorization #1143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@
import org.springframework.jdbc.support.lob.LobHandler;
import org.springframework.lang.Nullable;
import org.springframework.security.jackson2.SecurityJackson2Modules;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.*;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
Expand Down Expand Up @@ -106,20 +103,31 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
+ "refresh_token_value,"
+ "refresh_token_issued_at,"
+ "refresh_token_expires_at,"
+ "refresh_token_metadata";
+ "refresh_token_metadata,"
+ "user_code_value,"
+ "user_code_issued_at,"
+ "user_code_expires_at,"
+ "user_code_metadata,"
+ "device_code_value,"
+ "device_code_issued_at,"
+ "device_code_expires_at,"
+ "device_code_metadata";
// @formatter:on

private static final String TABLE_NAME = "oauth2_authorization";

private static final String PK_FILTER = "id = ?";
private static final String UNKNOWN_TOKEN_TYPE_FILTER = "state = ? OR authorization_code_value = ? OR " +
"access_token_value = ? OR oidc_id_token_value = ? OR refresh_token_value = ?";
private static final String UNKNOWN_TOKEN_TYPE_FILTER = "state = ? OR authorization_code_value = ? OR "
+ "access_token_value = ? OR oidc_id_token_value = ? OR refresh_token_value = ? OR "
+ "user_code_value = ? OR device_code_value = ?";

private static final String STATE_FILTER = "state = ?";
private static final String AUTHORIZATION_CODE_FILTER = "authorization_code_value = ?";
private static final String ACCESS_TOKEN_FILTER = "access_token_value = ?";
private static final String ID_TOKEN_FILTER = "oidc_id_token_value = ?";
private static final String REFRESH_TOKEN_FILTER = "refresh_token_value = ?";
private static final String USER_CODE_FILTER = "user_code_value = ?";
private static final String DEVICE_CODE_FILTER = "device_code_value = ?";

// @formatter:off
private static final String LOAD_AUTHORIZATION_SQL = "SELECT " + COLUMN_NAMES
Expand All @@ -129,7 +137,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic

// @formatter:off
private static final String SAVE_AUTHORIZATION_SQL = "INSERT INTO " + TABLE_NAME
+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
// @formatter:on

// @formatter:off
Expand All @@ -138,7 +146,9 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
+ " authorization_code_value = ?, authorization_code_issued_at = ?, authorization_code_expires_at = ?, authorization_code_metadata = ?,"
+ " access_token_value = ?, access_token_issued_at = ?, access_token_expires_at = ?, access_token_metadata = ?, access_token_type = ?, access_token_scopes = ?,"
+ " oidc_id_token_value = ?, oidc_id_token_issued_at = ?, oidc_id_token_expires_at = ?, oidc_id_token_metadata = ?,"
+ " refresh_token_value = ?, refresh_token_issued_at = ?, refresh_token_expires_at = ?, refresh_token_metadata = ?"
+ " refresh_token_value = ?, refresh_token_issued_at = ?, refresh_token_expires_at = ?, refresh_token_metadata = ?,"
+ " user_code_value = ?, user_code_issued_at = ?, user_code_expires_at = ?, user_code_metadata = ?,"
+ " device_code_value = ?, device_code_issued_at = ?, device_code_expires_at = ?, device_code_metadata = ?"
+ " WHERE " + PK_FILTER;
// @formatter:on

Expand Down Expand Up @@ -244,6 +254,8 @@ public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType t
parameters.add(mapToSqlParameter("access_token_value", token));
parameters.add(mapToSqlParameter("oidc_id_token_value", token));
parameters.add(mapToSqlParameter("refresh_token_value", token));
parameters.add(mapToSqlParameter("user_code_value", token));
parameters.add(mapToSqlParameter("device_code_value", token));
return findBy(UNKNOWN_TOKEN_TYPE_FILTER, parameters);
} else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) {
parameters.add(new SqlParameterValue(Types.VARCHAR, token));
Expand All @@ -260,6 +272,12 @@ public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType t
} else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) {
parameters.add(mapToSqlParameter("refresh_token_value", token));
return findBy(REFRESH_TOKEN_FILTER, parameters);
} else if (OAuth2TokenType.USER_CODE.equals(tokenType)) {
parameters.add(mapToSqlParameter("user_code_value", token));
return findBy(USER_CODE_FILTER, parameters);
} else if (OAuth2TokenType.DEVICE_CODE.equals(tokenType)) {
parameters.add(mapToSqlParameter("device_code_value", token));
return findBy(DEVICE_CODE_FILTER, parameters);
}
return null;
}
Expand Down Expand Up @@ -425,6 +443,35 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException
refreshTokenValue, tokenIssuedAt, tokenExpiresAt);
builder.token(refreshToken, (metadata) -> metadata.putAll(refreshTokenMetadata));
}

String userCodeValue = getLobValue(rs, "user_code_value");
if (StringUtils.hasText(userCodeValue)) {
tokenIssuedAt = rs.getTimestamp("user_code_issued_at").toInstant();
tokenExpiresAt = null;
Timestamp userCodeExpiresAt = rs.getTimestamp("user_code_expires_at");
if (userCodeExpiresAt != null) {
tokenExpiresAt = userCodeExpiresAt.toInstant();
}
Map<String, Object> userCodeMetadata = parseMap(getLobValue(rs, "user_code_metadata"));

OAuth2UserCode userCode = new OAuth2UserCode(userCodeValue, tokenIssuedAt, tokenExpiresAt);
builder.token(userCode, (metadata) -> metadata.putAll(userCodeMetadata));
}

String deviceCodeValue = getLobValue(rs, "device_code_value");
if (StringUtils.hasText(deviceCodeValue)) {
tokenIssuedAt = rs.getTimestamp("device_code_issued_at").toInstant();
tokenExpiresAt = null;
Timestamp deviceCodeExpiresAt = rs.getTimestamp("device_code_expires_at");
if (deviceCodeExpiresAt != null) {
tokenExpiresAt = deviceCodeExpiresAt.toInstant();
}
Map<String, Object> deviceCodeMetadata = parseMap(getLobValue(rs, "device_code_metadata"));

OAuth2DeviceCode deviceCode = new OAuth2DeviceCode(deviceCodeValue, tokenIssuedAt, tokenExpiresAt);
builder.token(deviceCode, (metadata) -> metadata.putAll(deviceCodeMetadata));
}

return builder.build();
}

Expand Down Expand Up @@ -545,6 +592,17 @@ public List<SqlParameterValue> apply(OAuth2Authorization authorization) {
List<SqlParameterValue> refreshTokenSqlParameters = toSqlParameterList(
"refresh_token_value", "refresh_token_metadata", refreshToken);
parameters.addAll(refreshTokenSqlParameters);

OAuth2Authorization.Token<OAuth2UserCode> userCode = authorization.getToken(OAuth2UserCode.class);
List<SqlParameterValue> userCodeSqlParameters = toSqlParameterList(
"user_code_value", "user_code_metadata", userCode);
parameters.addAll(userCodeSqlParameters);

OAuth2Authorization.Token<OAuth2DeviceCode> deviceCode = authorization.getToken(OAuth2DeviceCode.class);
List<SqlParameterValue> deviceCodeSqlParameters = toSqlParameterList(
"device_code_value", "device_code_metadata", deviceCode);
parameters.addAll(deviceCodeSqlParameters);

return parameters;
}

Expand Down Expand Up @@ -670,6 +728,14 @@ private static void initColumnMetadata(JdbcOperations jdbcOperations) {
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "refresh_token_metadata", Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "user_code_value", Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "user_code_metadata", Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "device_code_value", Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
columnMetadata = getColumnMetadata(jdbcOperations, "device_code_metadata", Types.BLOB);
columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
}

private static ColumnMetadata getColumnMetadata(JdbcOperations jdbcOperations, String columnName, int defaultDataType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@
import java.util.function.Consumer;

import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.*;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.util.SpringAuthorizationServerVersion;
import org.springframework.util.Assert;
Expand All @@ -50,6 +47,8 @@
* @see OAuth2Token
* @see OAuth2AccessToken
* @see OAuth2RefreshToken
* @see OAuth2UserCode
* @see OAuth2DeviceCode
*/
public class OAuth2Authorization implements Serializable {
private static final long serialVersionUID = SpringAuthorizationServerVersion.SERIAL_VERSION_UID;
Expand Down Expand Up @@ -129,6 +128,28 @@ public Token<OAuth2RefreshToken> getRefreshToken() {
return getToken(OAuth2RefreshToken.class);
}

/**
* Returns the {@link Token} of type {@link OAuth2UserCode}.
*
* @return the {@link Token} of type {@link OAuth2UserCode}, or {@code null} if not
* available
*/
@Nullable
public Token<OAuth2UserCode> getUserCode() {
return getToken(OAuth2UserCode.class);
}

/**
* Returns the {@link Token} of type {@link OAuth2DeviceCode}.
*
* @return the {@link Token} of type {@link OAuth2DeviceCode}, or {@code null} if not
* available
*/
@Nullable
public Token<OAuth2DeviceCode> getDeviceCode() {
return getToken(OAuth2DeviceCode.class);
}

/**
* Returns the {@link Token} of type {@code tokenType}.
*
Expand Down Expand Up @@ -460,6 +481,26 @@ public Builder refreshToken(OAuth2RefreshToken refreshToken) {
return token(refreshToken);
}

/**
* Sets the {@link OAuth2UserCode user token}.
*
* @param userCode the {@link OAuth2UserCode}
* @return the {@link Builder}
*/
public Builder userCode(OAuth2UserCode userCode) {
return token(userCode);
}

/**
* Sets the {@link OAuth2DeviceCode device token}.
*
* @param deviceCode the {@link OAuth2DeviceCode}
* @return the {@link Builder}
*/
public Builder deviceCode(OAuth2DeviceCode deviceCode) {
return token(deviceCode);
}

/**
* Sets the {@link OAuth2Token token}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public final class OAuth2TokenType implements Serializable {
private static final long serialVersionUID = SpringAuthorizationServerVersion.SERIAL_VERSION_UID;
public static final OAuth2TokenType ACCESS_TOKEN = new OAuth2TokenType("access_token");
public static final OAuth2TokenType REFRESH_TOKEN = new OAuth2TokenType("refresh_token");
public static final OAuth2TokenType USER_CODE = new OAuth2TokenType("user_code");
public static final OAuth2TokenType DEVICE_CODE = new OAuth2TokenType("device_code");
private final String value;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,13 @@ CREATE TABLE oauth2_authorization (
refresh_token_issued_at timestamp DEFAULT NULL,
refresh_token_expires_at timestamp DEFAULT NULL,
refresh_token_metadata blob DEFAULT NULL,
user_code_value blob DEFAULT NULL,
user_code_issued_at timestamp DEFAULT NULL,
user_code_expires_at timestamp DEFAULT NULL,
user_code_metadata blob DEFAULT NULL,
device_code_value blob DEFAULT NULL,
device_code_issued_at timestamp DEFAULT NULL,
device_code_expires_at timestamp DEFAULT NULL,
device_code_metadata blob DEFAULT NULL,
PRIMARY KEY (id)
);
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.annotation.Order;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
Expand All @@ -41,6 +45,8 @@
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.client.InMemoryRegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
Expand Down Expand Up @@ -131,6 +137,12 @@ public RegisteredClientRepository registeredClientRepository() {
return new InMemoryRegisteredClientRepository(registeredClient);
}

@Bean
public OAuth2AuthorizationService authorizationService(JdbcTemplate jdbcTemplate,
RegisteredClientRepository registeredClientRepository) {
return new JdbcOAuth2AuthorizationService(jdbcTemplate, registeredClientRepository);
}

@Bean
public JWKSource<SecurityContext> jwkSource() {
KeyPair keyPair = generateRsaKey();
Expand Down Expand Up @@ -167,4 +179,18 @@ public AuthorizationServerSettings authorizationServerSettings() {
return AuthorizationServerSettings.builder().build();
}

@Bean
public EmbeddedDatabase embeddedDatabase() {
// @formatter:off
return new EmbeddedDatabaseBuilder()
.generateUniqueName(true)
.setType(EmbeddedDatabaseType.H2)
.setScriptEncoding("UTF-8")
.addScript("org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql")
.addScript("org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql")
.addScript("org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql")
.build();
// @formatter:on
}

}