Skip to content

Allow reader/writer endpoints for pools #98

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 58 additions & 5 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ enum ClientConnectionType {
CancelQuery,
}

#[derive(Clone, Copy, Debug)]
pub enum ClientRoutingMode {
Default,
Reader,
Writer,
}

/// The client state. One of these is created per client.
pub struct Client<S, T> {
/// The reads are buffered (8K by default).
Expand Down Expand Up @@ -73,6 +80,8 @@ pub struct Client<S, T> {
last_server_id: Option<i32>,

target_pool: ConnectionPool,

routing_mode: ClientRoutingMode,
}

/// Client entrypoint.
Expand Down Expand Up @@ -264,10 +273,50 @@ where

trace!("Got StartupMessage");
let parameters = parse_startup(bytes.clone())?;
let database = match parameters.get("database") {

let database_param = match parameters.get("database") {
Some(db) => db,
None => return Err(Error::ClientError),
};
let database_name_parts = database_param.split("/").collect::<Vec<&str>>();
let (database_name, routing_mode) = match database_name_parts.len() {
1 => (
database_name_parts[0].to_string(),
ClientRoutingMode::Default,
),
2 => match database_name_parts[1] {
"reader" => {
info!("Client connected in force reader mode");
(
database_name_parts[0].to_string(),
ClientRoutingMode::Reader,
)
}
"writer" => {
info!("Client connected in force writer mode");
(
database_name_parts[0].to_string(),
ClientRoutingMode::Writer,
)
}
_ => {
error_response(
&mut write,
&format!("Invalid database mode {}", database_name_parts[1]),
)
.await?;
return Err(Error::ClientError);
}
},
_ => {
error_response(
&mut write,
&format!("Invalid database name {}", database_param),
)
.await?;
return Err(Error::ClientError);
}
};

let user = match parameters.get("user") {
Some(user) => user,
Expand All @@ -276,7 +325,7 @@ where

let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &database)
.filter(|db| *db == &database_name)
.count()
== 1;

Expand Down Expand Up @@ -328,14 +377,14 @@ where
generate_server_info_for_admin(),
)
} else {
let target_pool = match get_pool(database.clone(), user.clone()) {
let target_pool = match get_pool(database_name.clone(), user.clone()) {
Some(pool) => pool,
None => {
error_response(
&mut write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
database, user
database_name, user
),
)
.await?;
Expand Down Expand Up @@ -375,6 +424,7 @@ where
buffer: BytesMut::with_capacity(8196),
cancel_mode: false,
transaction_mode: transaction_mode,
routing_mode: routing_mode,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
Expand Down Expand Up @@ -404,6 +454,7 @@ where
buffer: BytesMut::with_capacity(8196),
cancel_mode: true,
transaction_mode: false,
routing_mode: ClientRoutingMode::Default,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
Expand Down Expand Up @@ -450,7 +501,9 @@ where

// The query router determines where the query is going to go,
// e.g. primary, replica, which shard.
let mut query_router = QueryRouter::new(self.target_pool.clone());
let mut query_router =
QueryRouter::new(self.target_pool.clone(), self.routing_mode.clone());

let mut round_robin = 0;

// Our custom protocol loop.
Expand Down
55 changes: 46 additions & 9 deletions src/query_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use sqlparser::ast::Statement::{Query, StartTransaction};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;

use crate::client::ClientRoutingMode;
use crate::config::Role;
use crate::pool::{ConnectionPool, PoolSettings};
use crate::sharding::{Sharder, ShardingFunction};
Expand Down Expand Up @@ -56,6 +57,8 @@ pub struct QueryRouter {
primary_reads_enabled: bool,

pool_settings: PoolSettings,

client_routing_mode: ClientRoutingMode,
}

impl QueryRouter {
Expand Down Expand Up @@ -91,13 +94,14 @@ impl QueryRouter {
}

/// Create a new instance of the query router. Each client gets its own.
pub fn new(target_pool: ConnectionPool) -> QueryRouter {
pub fn new(target_pool: ConnectionPool, client_routing_mode: ClientRoutingMode) -> QueryRouter {
QueryRouter {
active_shard: None,
active_role: None,
query_parser_enabled: target_pool.settings.query_parser_enabled,
primary_reads_enabled: target_pool.settings.primary_reads_enabled,
pool_settings: target_pool.settings,
client_routing_mode: client_routing_mode,
}
}

Expand Down Expand Up @@ -339,7 +343,11 @@ impl QueryRouter {

/// Get the current desired server role we should be talking to.
pub fn role(&self) -> Option<Role> {
self.active_role
match self.client_routing_mode {
ClientRoutingMode::Default => self.active_role,
ClientRoutingMode::Reader => Some(Role::Replica),
ClientRoutingMode::Writer => Some(Role::Primary),
}
}

/// Get desired shard we should be talking to.
Expand Down Expand Up @@ -370,15 +378,15 @@ mod test {
#[test]
fn test_defaults() {
QueryRouter::setup();
let qr = QueryRouter::new(ConnectionPool::default());
let qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Default);

assert_eq!(qr.role(), None);
}

#[test]
fn test_infer_role_replica() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Default);
assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None);
assert_eq!(qr.query_parser_enabled(), true);

Expand All @@ -402,7 +410,7 @@ mod test {
#[test]
fn test_infer_role_primary() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Default);

let queries = vec![
simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"),
Expand All @@ -421,7 +429,7 @@ mod test {
#[test]
fn test_infer_role_primary_reads_enabled() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Default);
let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None);

Expand All @@ -432,7 +440,7 @@ mod test {
#[test]
fn test_infer_role_parse_prepared() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Default);
qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'"));
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);

Expand Down Expand Up @@ -523,7 +531,7 @@ mod test {
#[test]
fn test_try_execute_command() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Default);

// SetShardingKey
let query = simple_query("SET SHARDING KEY TO 13");
Expand Down Expand Up @@ -600,7 +608,7 @@ mod test {
#[test]
fn test_enable_query_parser() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default());
let mut qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Default);
let query = simple_query("SET SERVER ROLE TO 'auto'");
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);

Expand All @@ -621,4 +629,33 @@ mod test {
assert!(qr.try_execute_command(query) != None);
assert!(qr.query_parser_enabled());
}

#[test]
fn test_client_routing_mode() {
QueryRouter::setup();
let mut qr = QueryRouter::new(ConnectionPool::default(), ClientRoutingMode::Reader);
let query = simple_query("SET SERVER ROLE TO 'auto'");
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);

assert!(qr.try_execute_command(query) != None);
assert!(qr.query_parser_enabled());
assert_eq!(qr.role(), Some(Role::Replica));

let query = simple_query("BEGIN");
assert_eq!(qr.infer_role(query), true);
assert_eq!(qr.role(), Some(Role::Replica));

let query = simple_query("INSERT INTO test_table VALUES (1)");
assert_eq!(qr.infer_role(query), true);
assert_eq!(qr.role(), Some(Role::Replica));

let query = simple_query("SELECT * FROM test_table");
assert_eq!(qr.infer_role(query), true);
assert_eq!(qr.role(), Some(Role::Replica));

assert!(qr.query_parser_enabled());
let query = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(query) != None);
assert!(qr.query_parser_enabled());
}
}
21 changes: 21 additions & 0 deletions tests/ruby/tests.rb
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,25 @@ def poorly_behaved_client
end


# Test reader/writer endpoints
def test_reader_writer_endpoints
conn = PG::connect("postgres://sharding_user:[email protected]:6432/sharded_db/reader?application_name=testing_pgcat")
conn.async_exec 'BEGIN'
conn.async_exec 'SELECT 1'
conn.async_exec 'COMMIT'
conn.close

conn = PG::connect("postgres://sharding_user:[email protected]:6432/sharded_db/writer?application_name=testing_pgcat")
conn.async_exec 'BEGIN'
conn.async_exec 'SELECT 1'
conn.async_exec 'COMMIT'
conn.close

puts 'Reader/Writer clients ok'
end

test_reader_writer_endpoints

def test_server_parameters
server_conn = PG::connect("postgres://sharding_user:[email protected]:6432/sharded_db?application_name=testing_pgcat")
raise StandardError, "Bad server version" if server_conn.server_version == 0
Expand All @@ -141,3 +160,5 @@ def test_server_parameters

puts 'Server parameters ok'
end

test_server_parameters