@@ -5,13 +5,16 @@ use log::{debug, error};
5
5
use once_cell:: sync:: OnceCell ;
6
6
use regex:: { Regex , RegexSet } ;
7
7
use sqlparser:: ast:: Statement :: { Query , StartTransaction } ;
8
+ use sqlparser:: ast:: { BinaryOperator , Expr , SetExpr , Value } ;
8
9
use sqlparser:: dialect:: PostgreSqlDialect ;
9
10
use sqlparser:: parser:: Parser ;
10
11
11
12
use crate :: config:: Role ;
12
13
use crate :: pool:: PoolSettings ;
13
14
use crate :: sharding:: Sharder ;
14
15
16
+ use std:: collections:: BTreeSet ;
17
+
15
18
/// Regexes used to parse custom commands.
16
19
const CUSTOM_SQL_REGEXES : [ & str ; 7 ] = [
17
20
r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$" ,
@@ -256,7 +259,7 @@ impl QueryRouter {
256
259
}
257
260
258
261
/// Try to infer which server to connect to based on the contents of the query.
259
- pub fn infer_role ( & mut self , mut buf : BytesMut ) -> bool {
262
+ pub fn infer ( & mut self , mut buf : BytesMut ) -> bool {
260
263
debug ! ( "Inferring role" ) ;
261
264
262
265
let code = buf. get_u8 ( ) as char ;
@@ -324,7 +327,21 @@ impl QueryRouter {
324
327
}
325
328
326
329
// Likely a read-only query
327
- Query { .. } => {
330
+ Query ( query) => {
331
+ match & self . pool_settings . automatic_sharding_key {
332
+ Some ( _) => {
333
+ // TODO: if we have multiple queries in the same message,
334
+ // we can either split them and execute them individually
335
+ // or discard shard selection. If they point to the same shard though,
336
+ // we can let them through as-is.
337
+ // This is basically building a database now :)
338
+ self . active_shard = self . infer_shard ( query) ;
339
+ debug ! ( "Automatically using shard: {:?}" , self . active_shard) ;
340
+ }
341
+
342
+ None => ( ) ,
343
+ } ;
344
+
328
345
self . active_role = match self . primary_reads_enabled ( ) {
329
346
false => Some ( Role :: Replica ) , // If primary should not be receiving reads, use a replica.
330
347
true => None , // Any server role is fine in this case.
@@ -342,6 +359,118 @@ impl QueryRouter {
342
359
true
343
360
}
344
361
362
+ /// A `selection` is the `WHERE` clause. This parses
363
+ /// the clause and extracts the sharding key, if present.
364
+ fn selection_parser ( & self , expr : & Expr ) -> Vec < i64 > {
365
+ let mut result = Vec :: new ( ) ;
366
+ let mut found = false ;
367
+
368
+ match expr {
369
+ // This parses `sharding_key = 5`. But it's technically
370
+ // legal to write `5 = sharding_key`. I don't judge the people
371
+ // who do that, but I think ORMs will still use the first variant,
372
+ // so we can leave the second as a TODO.
373
+ Expr :: BinaryOp { left, op, right } => {
374
+ match & * * left {
375
+ Expr :: BinaryOp { .. } => result. extend ( self . selection_parser ( & left) ) ,
376
+ Expr :: Identifier ( ident) => {
377
+ found = ident. value
378
+ == * self . pool_settings . automatic_sharding_key . as_ref ( ) . unwrap ( ) ;
379
+ }
380
+ _ => ( ) ,
381
+ } ;
382
+
383
+ match op {
384
+ BinaryOperator :: Eq => ( ) ,
385
+ BinaryOperator :: Or => ( ) ,
386
+ BinaryOperator :: And => ( ) ,
387
+ _ => {
388
+ // TODO: support other operators than equality.
389
+ debug ! ( "Unsupported operation: {:?}" , op) ;
390
+ return Vec :: new ( ) ;
391
+ }
392
+ } ;
393
+
394
+ match & * * right {
395
+ Expr :: BinaryOp { .. } => result. extend ( self . selection_parser ( & right) ) ,
396
+ Expr :: Value ( Value :: Number ( value, ..) ) => {
397
+ if found {
398
+ match value. parse :: < i64 > ( ) {
399
+ Ok ( value) => result. push ( value) ,
400
+ Err ( _) => {
401
+ debug ! ( "Sharding key was not an integer: {}" , value) ;
402
+ }
403
+ } ;
404
+ }
405
+ }
406
+ _ => ( ) ,
407
+ } ;
408
+ }
409
+
410
+ _ => ( ) ,
411
+ } ;
412
+
413
+ debug ! ( "Sharding keys found: {:?}" , result) ;
414
+
415
+ result
416
+ }
417
+
418
+ /// Try to figure out which shard the query should go to.
419
+ fn infer_shard ( & self , query : & sqlparser:: ast:: Query ) -> Option < usize > {
420
+ let mut shards = BTreeSet :: new ( ) ;
421
+
422
+ match & * query. body {
423
+ SetExpr :: Query ( query) => {
424
+ match self . infer_shard ( & * query) {
425
+ Some ( shard) => {
426
+ shards. insert ( shard) ;
427
+ }
428
+ None => ( ) ,
429
+ } ;
430
+ }
431
+
432
+ SetExpr :: Select ( select) => {
433
+ match & select. selection {
434
+ Some ( selection) => {
435
+ let sharding_keys = self . selection_parser ( & selection) ;
436
+
437
+ // TODO: Add support for prepared statements here.
438
+ // This should just give us the position of the value in the `B` message.
439
+
440
+ let sharder = Sharder :: new (
441
+ self . pool_settings . shards ,
442
+ self . pool_settings . sharding_function ,
443
+ ) ;
444
+
445
+ for value in sharding_keys {
446
+ let shard = sharder. shard ( value) ;
447
+ shards. insert ( shard) ;
448
+ }
449
+ }
450
+
451
+ None => ( ) ,
452
+ } ;
453
+ }
454
+ _ => ( ) ,
455
+ } ;
456
+
457
+ match shards. len ( ) {
458
+ // Didn't find a sharding key, you're on your own.
459
+ 0 => {
460
+ debug ! ( "No sharding keys found" ) ;
461
+ None
462
+ }
463
+
464
+ 1 => Some ( shards. into_iter ( ) . last ( ) . unwrap ( ) ) ,
465
+
466
+ // TODO: support querying multiple shards (some day...)
467
+ _ => {
468
+ debug ! ( "More than one sharding key found" ) ;
469
+ None
470
+ }
471
+ }
472
+ }
473
+
345
474
/// Get the current desired server role we should be talking to.
346
475
pub fn role ( & self ) -> Option < Role > {
347
476
self . active_role
@@ -392,7 +521,7 @@ mod test {
392
521
}
393
522
394
523
#[ test]
395
- fn test_infer_role_replica ( ) {
524
+ fn test_infer_replica ( ) {
396
525
QueryRouter :: setup ( ) ;
397
526
let mut qr = QueryRouter :: new ( ) ;
398
527
assert ! ( qr. try_execute_command( simple_query( "SET SERVER ROLE TO 'auto'" ) ) != None ) ;
@@ -410,13 +539,13 @@ mod test {
410
539
411
540
for query in queries {
412
541
// It's a recognized query
413
- assert ! ( qr. infer_role ( query) ) ;
542
+ assert ! ( qr. infer ( query) ) ;
414
543
assert_eq ! ( qr. role( ) , Some ( Role :: Replica ) ) ;
415
544
}
416
545
}
417
546
418
547
#[ test]
419
- fn test_infer_role_primary ( ) {
548
+ fn test_infer_primary ( ) {
420
549
QueryRouter :: setup ( ) ;
421
550
let mut qr = QueryRouter :: new ( ) ;
422
551
@@ -429,24 +558,24 @@ mod test {
429
558
430
559
for query in queries {
431
560
// It's a recognized query
432
- assert ! ( qr. infer_role ( query) ) ;
561
+ assert ! ( qr. infer ( query) ) ;
433
562
assert_eq ! ( qr. role( ) , Some ( Role :: Primary ) ) ;
434
563
}
435
564
}
436
565
437
566
/// With primary reads switched on, a read-only query must not pin the
/// router to any particular server role.
#[test]
fn test_infer_primary_reads_enabled() {
    QueryRouter::setup();

    let mut router = QueryRouter::new();
    let select = simple_query("SELECT * FROM items WHERE id = 5");

    // The custom command has to be recognised and executed.
    let command = simple_query("SET PRIMARY READS TO on");
    assert!(router.try_execute_command(command).is_some());

    // The query is parsed successfully and any role is acceptable.
    assert!(router.infer(select));
    assert_eq!(router.role(), None);
}
447
576
448
577
#[ test]
449
- fn test_infer_role_parse_prepared ( ) {
578
+ fn test_infer_parse_prepared ( ) {
450
579
QueryRouter :: setup ( ) ;
451
580
let mut qr = QueryRouter :: new ( ) ;
452
581
qr. try_execute_command ( simple_query ( "SET SERVER ROLE TO 'auto'" ) ) ;
@@ -461,7 +590,7 @@ mod test {
461
590
res. put ( prepared_stmt) ;
462
591
res. put_i16 ( 0 ) ;
463
592
464
- assert ! ( qr. infer_role ( res) ) ;
593
+ assert ! ( qr. infer ( res) ) ;
465
594
assert_eq ! ( qr. role( ) , Some ( Role :: Replica ) ) ;
466
595
}
467
596
@@ -625,11 +754,11 @@ mod test {
625
754
assert_eq ! ( qr. role( ) , None ) ;
626
755
627
756
let query = simple_query ( "INSERT INTO test_table VALUES (1)" ) ;
628
- assert_eq ! ( qr. infer_role ( query) , true ) ;
757
+ assert_eq ! ( qr. infer ( query) , true ) ;
629
758
assert_eq ! ( qr. role( ) , Some ( Role :: Primary ) ) ;
630
759
631
760
let query = simple_query ( "SELECT * FROM test_table" ) ;
632
- assert_eq ! ( qr. infer_role ( query) , true ) ;
761
+ assert_eq ! ( qr. infer ( query) , true ) ;
633
762
assert_eq ! ( qr. role( ) , Some ( Role :: Replica ) ) ;
634
763
635
764
assert ! ( qr. query_parser_enabled( ) ) ;
@@ -644,12 +773,13 @@ mod test {
644
773
645
774
let pool_settings = PoolSettings {
646
775
pool_mode : PoolMode :: Transaction ,
647
- shards : 0 ,
776
+ shards : 2 ,
648
777
user : crate :: config:: User :: default ( ) ,
649
778
default_role : Some ( Role :: Replica ) ,
650
779
query_parser_enabled : true ,
651
780
primary_reads_enabled : false ,
652
781
sharding_function : ShardingFunction :: PgBigintHash ,
782
+ automatic_sharding_key : Some ( String :: from ( "id" ) ) ,
653
783
} ;
654
784
let mut qr = QueryRouter :: new ( ) ;
655
785
assert_eq ! ( qr. active_role, None ) ;
@@ -672,20 +802,25 @@ mod test {
672
802
let q2 = simple_query ( "SET SERVER ROLE TO 'default'" ) ;
673
803
assert ! ( qr. try_execute_command( q2) != None ) ;
674
804
assert_eq ! ( qr. active_role. unwrap( ) , pool_settings. clone( ) . default_role) ;
805
+
806
+ // Here we go :)
807
+ let q3 = simple_query ( "SELECT * FROM test WHERE id = 5 AND values IN (1, 2, 3)" ) ;
808
+ assert ! ( qr. infer( q3) ) ;
809
+ assert_eq ! ( qr. shard( ) , 1 ) ;
675
810
}
676
811
677
812
#[ test]
678
813
fn test_parse_multiple_queries ( ) {
679
814
QueryRouter :: setup ( ) ;
680
815
681
816
let mut qr = QueryRouter :: new ( ) ;
682
- assert ! ( qr. infer_role ( simple_query( "BEGIN; SELECT 1; COMMIT;" ) ) ) ;
817
+ assert ! ( qr. infer ( simple_query( "BEGIN; SELECT 1; COMMIT;" ) ) ) ;
683
818
assert_eq ! ( qr. role( ) , Role :: Primary ) ;
684
819
685
- assert ! ( qr. infer_role ( simple_query( "SELECT 1; SELECT 2;" ) ) ) ;
820
+ assert ! ( qr. infer ( simple_query( "SELECT 1; SELECT 2;" ) ) ) ;
686
821
assert_eq ! ( qr. role( ) , Role :: Replica ) ;
687
822
688
- assert ! ( qr. infer_role ( simple_query(
823
+ assert ! ( qr. infer ( simple_query(
689
824
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
690
825
) ) ) ;
691
826
assert_eq ! ( qr. role( ) , Role :: Primary ) ;
0 commit comments