16
16
package org .springframework .security .oauth2 .server .authorization ;
17
17
18
18
import java .nio .charset .StandardCharsets ;
19
+ import java .sql .DatabaseMetaData ;
19
20
import java .sql .PreparedStatement ;
20
21
import java .sql .ResultSet ;
21
22
import java .sql .SQLException ;
35
36
36
37
import org .springframework .dao .DataRetrievalFailureException ;
37
38
import org .springframework .jdbc .core .ArgumentPreparedStatementSetter ;
39
+ import org .springframework .jdbc .core .ConnectionCallback ;
38
40
import org .springframework .jdbc .core .JdbcOperations ;
39
41
import org .springframework .jdbc .core .PreparedStatementSetter ;
40
42
import org .springframework .jdbc .core .RowMapper ;
@@ -141,6 +143,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
141
143
142
144
private final JdbcOperations jdbcOperations ;
143
145
private final LobHandler lobHandler ;
146
+ private static int tokenColumnType ;
144
147
private RowMapper <OAuth2Authorization > authorizationRowMapper ;
145
148
private Function <OAuth2Authorization , List <SqlParameterValue >> authorizationParametersMapper ;
146
149
@@ -169,12 +172,15 @@ public JdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations,
169
172
Assert .notNull (lobHandler , "lobHandler cannot be null" );
170
173
this .jdbcOperations = jdbcOperations ;
171
174
this .lobHandler = lobHandler ;
175
+ tokenColumnType = getColumnDataType (jdbcOperations , "access_token_value" );
172
176
OAuth2AuthorizationRowMapper authorizationRowMapper = new OAuth2AuthorizationRowMapper (registeredClientRepository );
173
177
authorizationRowMapper .setLobHandler (lobHandler );
174
178
this .authorizationRowMapper = authorizationRowMapper ;
175
- this .authorizationParametersMapper = new OAuth2AuthorizationParametersMapper ();
179
+ OAuth2AuthorizationParametersMapper authorizationParametersMapper = new OAuth2AuthorizationParametersMapper ();
180
+ this .authorizationParametersMapper = authorizationParametersMapper ;
176
181
}
177
182
183
+
178
184
@ Override
179
185
public void save (OAuth2Authorization authorization ) {
180
186
Assert .notNull (authorization , "authorization cannot be null" );
@@ -232,26 +238,33 @@ public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType t
232
238
List <SqlParameterValue > parameters = new ArrayList <>();
233
239
if (tokenType == null ) {
234
240
parameters .add (new SqlParameterValue (Types .VARCHAR , token ));
235
- parameters .add (new SqlParameterValue ( Types . BLOB , token . getBytes ( StandardCharsets . UTF_8 ) ));
236
- parameters .add (new SqlParameterValue ( Types . BLOB , token . getBytes ( StandardCharsets . UTF_8 ) ));
237
- parameters .add (new SqlParameterValue ( Types . BLOB , token . getBytes ( StandardCharsets . UTF_8 ) ));
241
+ parameters .add (mapTokenToSqlParameter ( token ));
242
+ parameters .add (mapTokenToSqlParameter ( token ));
243
+ parameters .add (mapTokenToSqlParameter ( token ));
238
244
return findBy (UNKNOWN_TOKEN_TYPE_FILTER , parameters );
239
245
} else if (OAuth2ParameterNames .STATE .equals (tokenType .getValue ())) {
240
246
parameters .add (new SqlParameterValue (Types .VARCHAR , token ));
241
247
return findBy (STATE_FILTER , parameters );
242
248
} else if (OAuth2ParameterNames .CODE .equals (tokenType .getValue ())) {
243
- parameters .add (new SqlParameterValue ( Types . BLOB , token . getBytes ( StandardCharsets . UTF_8 ) ));
249
+ parameters .add (mapTokenToSqlParameter ( token ));
244
250
return findBy (AUTHORIZATION_CODE_FILTER , parameters );
245
251
} else if (OAuth2TokenType .ACCESS_TOKEN .equals (tokenType )) {
246
- parameters .add (new SqlParameterValue ( Types . BLOB , token . getBytes ( StandardCharsets . UTF_8 ) ));
252
+ parameters .add (mapTokenToSqlParameter ( token ));
247
253
return findBy (ACCESS_TOKEN_FILTER , parameters );
248
254
} else if (OAuth2TokenType .REFRESH_TOKEN .equals (tokenType )) {
249
- parameters .add (new SqlParameterValue ( Types . BLOB , token . getBytes ( StandardCharsets . UTF_8 ) ));
255
+ parameters .add (mapTokenToSqlParameter ( token ));
250
256
return findBy (REFRESH_TOKEN_FILTER , parameters );
251
257
}
252
258
return null ;
253
259
}
254
260
261
+ private SqlParameterValue mapTokenToSqlParameter (String token ) {
262
+ if (Types .BLOB == tokenColumnType ) {
263
+ return new SqlParameterValue (Types .BLOB , token .getBytes (StandardCharsets .UTF_8 ));
264
+ }
265
+ return new SqlParameterValue (tokenColumnType , token );
266
+ }
267
+
255
268
private OAuth2Authorization findBy (String filter , List <SqlParameterValue > parameters ) {
256
269
try (LobCreator lobCreator = getLobHandler ().getLobCreator ()) {
257
270
PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter (lobCreator ,
@@ -349,25 +362,22 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException
349
362
builder .attribute (OAuth2ParameterNames .STATE , state );
350
363
}
351
364
352
- String tokenValue ;
353
365
Instant tokenIssuedAt ;
354
366
Instant tokenExpiresAt ;
355
- byte [] authorizationCodeValue = this . lobHandler . getBlobAsBytes (rs , "authorization_code_value" );
367
+ String authorizationCodeValue = getTokenValue (rs , "authorization_code_value" );
356
368
357
- if (authorizationCodeValue != null ) {
358
- tokenValue = new String (authorizationCodeValue , StandardCharsets .UTF_8 );
369
+ if (StringUtils .hasText (authorizationCodeValue )) {
359
370
tokenIssuedAt = rs .getTimestamp ("authorization_code_issued_at" ).toInstant ();
360
371
tokenExpiresAt = rs .getTimestamp ("authorization_code_expires_at" ).toInstant ();
361
372
Map <String , Object > authorizationCodeMetadata = parseMap (rs .getString ("authorization_code_metadata" ));
362
373
363
374
OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode (
364
- tokenValue , tokenIssuedAt , tokenExpiresAt );
375
+ authorizationCodeValue , tokenIssuedAt , tokenExpiresAt );
365
376
builder .token (authorizationCode , (metadata ) -> metadata .putAll (authorizationCodeMetadata ));
366
377
}
367
378
368
- byte [] accessTokenValue = this .lobHandler .getBlobAsBytes (rs , "access_token_value" );
369
- if (accessTokenValue != null ) {
370
- tokenValue = new String (accessTokenValue , StandardCharsets .UTF_8 );
379
+ String accessTokenValue = getTokenValue (rs , "access_token_value" );
380
+ if (StringUtils .hasText (accessTokenValue )) {
371
381
tokenIssuedAt = rs .getTimestamp ("access_token_issued_at" ).toInstant ();
372
382
tokenExpiresAt = rs .getTimestamp ("access_token_expires_at" ).toInstant ();
373
383
Map <String , Object > accessTokenMetadata = parseMap (rs .getString ("access_token_metadata" ));
@@ -381,25 +391,23 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException
381
391
if (accessTokenScopes != null ) {
382
392
scopes = StringUtils .commaDelimitedListToSet (accessTokenScopes );
383
393
}
384
- OAuth2AccessToken accessToken = new OAuth2AccessToken (tokenType , tokenValue , tokenIssuedAt , tokenExpiresAt , scopes );
394
+ OAuth2AccessToken accessToken = new OAuth2AccessToken (tokenType , accessTokenValue , tokenIssuedAt , tokenExpiresAt , scopes );
385
395
builder .token (accessToken , (metadata ) -> metadata .putAll (accessTokenMetadata ));
386
396
}
387
397
388
- byte [] oidcIdTokenValue = this .lobHandler .getBlobAsBytes (rs , "oidc_id_token_value" );
389
- if (oidcIdTokenValue != null ) {
390
- tokenValue = new String (oidcIdTokenValue , StandardCharsets .UTF_8 );
398
+ String oidcIdTokenValue = getTokenValue (rs , "oidc_id_token_value" );
399
+ if (StringUtils .hasText (oidcIdTokenValue )) {
391
400
tokenIssuedAt = rs .getTimestamp ("oidc_id_token_issued_at" ).toInstant ();
392
401
tokenExpiresAt = rs .getTimestamp ("oidc_id_token_expires_at" ).toInstant ();
393
402
Map <String , Object > oidcTokenMetadata = parseMap (rs .getString ("oidc_id_token_metadata" ));
394
403
395
404
OidcIdToken oidcToken = new OidcIdToken (
396
- tokenValue , tokenIssuedAt , tokenExpiresAt , (Map <String , Object >) oidcTokenMetadata .get (OAuth2Authorization .Token .CLAIMS_METADATA_NAME ));
405
+ oidcIdTokenValue , tokenIssuedAt , tokenExpiresAt , (Map <String , Object >) oidcTokenMetadata .get (OAuth2Authorization .Token .CLAIMS_METADATA_NAME ));
397
406
builder .token (oidcToken , (metadata ) -> metadata .putAll (oidcTokenMetadata ));
398
407
}
399
408
400
- byte [] refreshTokenValue = this .lobHandler .getBlobAsBytes (rs , "refresh_token_value" );
401
- if (refreshTokenValue != null ) {
402
- tokenValue = new String (refreshTokenValue , StandardCharsets .UTF_8 );
409
+ String refreshTokenValue = getTokenValue (rs , "refresh_token_value" );
410
+ if (StringUtils .hasText (refreshTokenValue )) {
403
411
tokenIssuedAt = rs .getTimestamp ("refresh_token_issued_at" ).toInstant ();
404
412
tokenExpiresAt = null ;
405
413
Timestamp refreshTokenExpiresAt = rs .getTimestamp ("refresh_token_expires_at" );
@@ -409,12 +417,29 @@ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException
409
417
Map <String , Object > refreshTokenMetadata = parseMap (rs .getString ("refresh_token_metadata" ));
410
418
411
419
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken (
412
- tokenValue , tokenIssuedAt , tokenExpiresAt );
420
+ refreshTokenValue , tokenIssuedAt , tokenExpiresAt );
413
421
builder .token (refreshToken , (metadata ) -> metadata .putAll (refreshTokenMetadata ));
414
422
}
415
423
return builder .build ();
416
424
}
417
425
426
+ private String getTokenValue (ResultSet rs , String tokenColumn ) throws SQLException {
427
+ String tokenValue = null ;
428
+ if (Types .CLOB == tokenColumnType ) {
429
+ tokenValue = this .lobHandler .getClobAsString (rs , tokenColumn );
430
+ }
431
+ if (Types .VARCHAR == tokenColumnType ) {
432
+ tokenValue = rs .getString (tokenColumn );
433
+ }
434
+ if (Types .BLOB == tokenColumnType ) {
435
+ byte [] tokenValueByte = this .lobHandler .getBlobAsBytes (rs , tokenColumn );
436
+ if (tokenValueByte != null ) {
437
+ tokenValue = new String (tokenValueByte , StandardCharsets .UTF_8 );
438
+ }
439
+ }
440
+ return tokenValue ;
441
+ }
442
+
418
443
public final void setLobHandler (LobHandler lobHandler ) {
419
444
Assert .notNull (lobHandler , "lobHandler cannot be null" );
420
445
this .lobHandler = lobHandler ;
@@ -520,12 +545,12 @@ protected final ObjectMapper getObjectMapper() {
520
545
521
546
private <T extends AbstractOAuth2Token > List <SqlParameterValue > toSqlParameterList (OAuth2Authorization .Token <T > token ) {
522
547
List <SqlParameterValue > parameters = new ArrayList <>();
523
- byte [] tokenValue = null ;
548
+ String tokenValue = null ;
524
549
Timestamp tokenIssuedAt = null ;
525
550
Timestamp tokenExpiresAt = null ;
526
551
String metadata = null ;
527
552
if (token != null ) {
528
- tokenValue = token .getToken ().getTokenValue (). getBytes ( StandardCharsets . UTF_8 ) ;
553
+ tokenValue = token .getToken ().getTokenValue ();
529
554
if (token .getToken ().getIssuedAt () != null ) {
530
555
tokenIssuedAt = Timestamp .from (token .getToken ().getIssuedAt ());
531
556
}
@@ -534,7 +559,13 @@ private <T extends AbstractOAuth2Token> List<SqlParameterValue> toSqlParameterLi
534
559
}
535
560
metadata = writeMap (token .getMetadata ());
536
561
}
537
- parameters .add (new SqlParameterValue (Types .BLOB , tokenValue ));
562
+ if (Types .BLOB == tokenColumnType && StringUtils .hasText (tokenValue )) {
563
+ byte [] tokenValueAsBytes = tokenValue .getBytes (StandardCharsets .UTF_8 );
564
+ parameters .add (new SqlParameterValue (tokenColumnType , tokenValueAsBytes ));
565
+ } else {
566
+ parameters .add (new SqlParameterValue (tokenColumnType , tokenValue ));
567
+ }
568
+
538
569
parameters .add (new SqlParameterValue (Types .TIMESTAMP , tokenIssuedAt ));
539
570
parameters .add (new SqlParameterValue (Types .TIMESTAMP , tokenExpiresAt ));
540
571
parameters .add (new SqlParameterValue (Types .VARCHAR , metadata ));
@@ -551,6 +582,23 @@ private String writeMap(Map<String, Object> data) {
551
582
552
583
}
553
584
585
+ private static int getColumnDataType (JdbcOperations jdbcOperations , String columnName ){
586
+ return jdbcOperations .execute ((ConnectionCallback <Integer >) con -> {
587
+ DatabaseMetaData databaseMetaData = con .getMetaData ();
588
+ ResultSet rs = databaseMetaData .getColumns (null , null , TABLE_NAME , columnName );
589
+ if (rs .next ()) {
590
+ return rs .getInt ("DATA_TYPE" );
591
+ }
592
+ // NOTE: When using HSQL: When a database object is created with one of the CREATE statements if the name is enclosed in double quotes, the exact name is used as the case-normal form.
593
+ // But if it is not enclosed in double quotes, the name is converted to uppercase and this uppercase version is stored in the database as the case-normal form
594
+ rs = databaseMetaData .getColumns (null , null , TABLE_NAME .toUpperCase (), columnName .toUpperCase ());
595
+ if (rs .next ()) {
596
+ return rs .getInt ("DATA_TYPE" );
597
+ }
598
+ return Types .NULL ;
599
+ });
600
+ }
601
+
554
602
private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
555
603
private final LobCreator lobCreator ;
556
604
@@ -572,6 +620,15 @@ protected void doSetValue(PreparedStatement ps, int parameterPosition, Object ar
572
620
this .lobCreator .setBlobAsBytes (ps , parameterPosition , valueBytes );
573
621
return ;
574
622
}
623
+ if (paramValue .getSqlType () == Types .CLOB ) {
624
+ if (paramValue .getValue () != null ) {
625
+ Assert .isInstanceOf (String .class , paramValue .getValue (),
626
+ "Value of clob parameter must be String" );
627
+ }
628
+ String valueString = (String ) paramValue .getValue ();
629
+ this .lobCreator .setClobAsString (ps , parameterPosition , valueString );
630
+ return ;
631
+ }
575
632
}
576
633
super .doSetValue (ps , parameterPosition , argValue );
577
634
}
0 commit comments