diff --git a/Cargo.lock b/Cargo.lock index 929f1a81..56015e4b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1035,6 +1035,7 @@ dependencies = [ "tracing", "tracing-subscriber", "trust-dns-resolver", + "uuid", "webpki-roots", ] @@ -1899,6 +1900,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "uuid" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" +dependencies = [ + "getrandom", +] + [[package]] name = "valuable" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 805a4c7a..3fd69c55 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ itertools = "0.10" clap = { version = "4.3.1", features = ["derive", "env"] } tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter", "std"]} +uuid = { version = "1.4.1", features = ["v4"] } [target.'cfg(not(target_env = "msvc"))'.dependencies] jemallocator = "0.5.0" diff --git a/README.md b/README.md index ae310cde..065d2c96 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load bal |-------------|------------|--------------| | Transaction pooling | **Stable** | Identical to PgBouncer with notable improvements for handling bad clients and abandoned transactions. | | Session pooling | **Stable** | Identical to PgBouncer. | +| Transparent pooling | **Stable** | A new pooling mechanism that enables transparent (distributed) transactions. | | Multi-threaded runtime | **Stable** | Using Tokio asynchronous runtime, the pooler takes advantage of multicore machines. | | Load balancing of read queries | **Stable** | Queries are automatically load balanced between replicas and the primary. 
| | Failover | **Stable** | Queries are automatically rerouted around broken replicas, validated by regular health checks. | @@ -145,6 +146,9 @@ In transaction mode, a client talks to one server for the duration of a single t This mode is enabled by default. +### Transparent mode +In transparent mode, a client talks to one or more servers for the duration of a single transaction; once it's over, the servers are returned to the pool. `SET SHARD` and `SET SHARDING KEY` statements **are** supported, but prepared statements, other `SET` statements and advisory locks **are not** supported. + ### Load balancing of read queries All queries are load balanced against the configured servers using either the random or least open connections algorithms. The most straightforward configuration example would be to put this pooler in front of several replicas and let it load balance all queries. diff --git a/pgcat.toml b/pgcat.toml index 772a1365..579b9bec 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -143,7 +143,7 @@ result = [ # Pool mode (see PgBouncer docs for more). # `session` one server connection per connected client # `transaction` one server connection per client transaction -pool_mode = "transaction" +pool_mode = "transparent" # Load balancing mode # `random` selects the server at random @@ -280,8 +280,6 @@ username = "sharding_user" # if `server_password` is not set. password = "sharding_user" -pool_mode = "transaction" - # PostgreSQL username used to connect to the server. # server_username = "another_user" diff --git a/src/client.rs b/src/client.rs index 98a0669c..04d9f4e2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,12 +1,15 @@ use crate::errors::{ClientIdentifier, Error}; use crate::pool::BanReason; +use crate::server_xact::TransactionState; /// Handle clients by pretending to be a PostgreSQL server. 
use bytes::{Buf, BufMut, BytesMut}; use log::{debug, error, info, trace, warn}; use once_cell::sync::Lazy; +use sqlparser::ast::Statement; use std::collections::HashMap; +use std::ops::ControlFlow; use std::sync::{atomic::AtomicUsize, Arc}; -use std::time::Instant; +use std::time::{Duration, Instant}; use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf}; use tokio::net::TcpStream; use tokio::sync::broadcast::Receiver; @@ -14,6 +17,7 @@ use tokio::sync::mpsc::Sender; use crate::admin::{generate_server_parameters_for_admin, handle_admin}; use crate::auth_passthrough::refetch_auth_hash; +use crate::client_xact::*; use crate::config::{ get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode, }; @@ -47,21 +51,21 @@ pub struct Client { /// We buffer the writes ourselves because we know the protocol /// better than a stock buffer. - write: T, + pub(crate) write: T, /// Internal buffer, where we place messages until we have to flush /// them to the backend. buffer: BytesMut, /// Address - addr: std::net::SocketAddr, + pub(crate) addr: std::net::SocketAddr, /// The client was started with the sole reason to cancel another running query. cancel_mode: bool, /// In transaction mode, the connection is released after each transaction. /// Session mode has slightly higher throughput per client, but lower capacity. - transaction_mode: bool, + client_pool_mode: PoolMode, /// For query cancellation, the client is given a random process ID and secret on startup. process_id: i32, @@ -76,7 +80,7 @@ pub struct Client { parameters: HashMap, /// Statistics related to this client - stats: Arc, + pub(crate) stats: Arc, /// Clients want to talk to admin database. admin: bool, @@ -87,6 +91,9 @@ pub struct Client { /// Last server process stats we talked to. last_server_stats: Option>, + /// Last server key we talked to. 
+ pub(crate) last_server_key: Option, + /// Connected to server connected_to_server: bool, @@ -97,13 +104,15 @@ pub struct Client { username: String, /// Server startup and session parameters that we're going to track - server_parameters: ServerParameters, + pub(crate) server_parameters: ServerParameters, /// Used to notify clients about an impending shutdown shutdown: Receiver<()>, /// Prepared statements prepared_statements: HashMap, + + pub(crate) xact_info: ClientTxnMetaData, } /// Client entrypoint. @@ -136,7 +145,7 @@ pub async fn client_entrypoint( let mut yes = BytesMut::new(); yes.put_u8(b'S'); - write_all(&mut stream, yes).await?; + write_all(&mut stream, &yes).await?; // Negotiate TLS. match startup_tls(stream, client_server_map, shutdown, admin_only).await { @@ -171,7 +180,7 @@ pub async fn client_entrypoint( // Rejecting client request for TLS. let mut no = BytesMut::new(); no.put_u8(b'N'); - write_all(&mut stream, no).await?; + write_all(&mut stream, &no).await?; // Attempting regular startup. Client can disconnect now // if they choose. @@ -519,7 +528,7 @@ where }; // Authenticate admin user. - let (transaction_mode, mut server_parameters) = if admin { + let (client_pool_mode, mut server_parameters) = if admin { let config = get_config(); // Compare server and client hashes. @@ -538,7 +547,7 @@ where return Err(error); } - (false, generate_server_parameters_for_admin()) + (PoolMode::Session, generate_server_parameters_for_admin()) } // Authenticate normal user. else { @@ -650,7 +659,7 @@ where } } - let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; + let client_pool_mode = pool.settings.pool_mode; // If the pool hasn't been validated yet, // connect to the servers and figure out what's what. 
@@ -671,7 +680,7 @@ where } } - (transaction_mode, pool.server_parameters()) + (client_pool_mode, pool.server_parameters()) }; // Update the parameters to merge what the application sent and what's originally on the server @@ -680,7 +689,7 @@ where debug!("Password authentication successful"); auth_ok(&mut write).await?; - write_all(&mut write, (&server_parameters).into()).await?; + write_all(&mut write, &(&server_parameters).into()).await?; backend_key_data(&mut write, process_id, secret_key).await?; ready_for_query(&mut write).await?; @@ -699,7 +708,7 @@ where addr, buffer: BytesMut::with_capacity(8196), cancel_mode: false, - transaction_mode, + client_pool_mode, process_id, secret_key, client_server_map, @@ -708,12 +717,14 @@ where admin, last_address_id: None, last_server_stats: None, + last_server_key: None, pool_name: pool_name.clone(), username: username.clone(), server_parameters, shutdown, connected_to_server: false, prepared_statements: HashMap::new(), + xact_info: Default::default(), }) } @@ -734,7 +745,7 @@ where addr, buffer: BytesMut::with_capacity(8196), cancel_mode: true, - transaction_mode: false, + client_pool_mode: PoolMode::Session, process_id, secret_key, client_server_map, @@ -743,42 +754,184 @@ where admin: false, last_address_id: None, last_server_stats: None, + last_server_key: None, pool_name: String::from("undefined"), username: String::from("undefined"), server_parameters: ServerParameters::new(), shutdown, connected_to_server: false, prepared_statements: HashMap::new(), + xact_info: Default::default(), }) } - /// Handle a connected and authenticated client. - pub async fn handle(&mut self) -> Result<(), Error> { - // The client wants to cancel a query it has issued previously. 
- if self.cancel_mode { - trace!("Sending CancelRequest"); + async fn handle_cancel_mode(&mut self) -> Result<(), Error> { + trace!("Sending CancelRequest"); + let (process_id, secret_key, address, port) = { + let guard = self.client_server_map.lock(); + + match guard.get(&(self.process_id, self.secret_key)) { + // Drop the mutex as soon as possible. + // We found the server the client is using for its query + // that it wants to cancel. + Some((process_id, secret_key, address, port)) => { + (*process_id, *secret_key, address.clone(), *port) + } + + // The client doesn't know / got the wrong server, + // we're closing the connection for security reasons. + None => return Ok(()), + } + }; + + // Opens a new separate connection to the server, sends the backend_id + // and secret_key and then closes it for security reasons. No other interactions + // take place. + Server::cancel(&address, port, process_id, secret_key).await + } - let (process_id, secret_key, address, port) = { - let guard = self.client_server_map.lock(); + #[allow(clippy::too_many_arguments)] + async fn handle_message_in_custom_protocol_loop( + &mut self, + mut message: BytesMut, + client_identifier: &ClientIdentifier, + query_router: &mut QueryRouter, + will_prepare: &mut bool, + prepared_statements_enabled: &mut bool, + prepared_statement: &mut Option, + plugin_output: &mut Option, + ) -> Result>), ()>, Error> { + let mut initial_parsed_ast = None; + match message[0] as char { + // Buffer extended protocol messages even if we do not have + // a server connection yet. Hopefully, when we get the S message + // we'll be able to allocate a connection. 
Also, clients do not expect + // the server to respond to these messages so even if we were not able to + // allocate a connection, we wouldn't be able to send back an error message + // to the client so we buffer them and defer the decision to error out or not + // to when we get the S message + 'D' => { + if *prepared_statements_enabled { + let name; + (name, message) = self.rewrite_describe(message).await?; + + if let Some(name) = name { + *prepared_statement = Some(name); + } + } - match guard.get(&(self.process_id, self.secret_key)) { - // Drop the mutex as soon as possible. - // We found the server the client is using for its query - // that it wants to cancel. - Some((process_id, secret_key, address, port)) => { - (*process_id, *secret_key, address.clone(), *port) + self.buffer.put(&message[..]); + return Ok(ControlFlow::Continue(())); + } + + 'E' => { + self.buffer.put(&message[..]); + return Ok(ControlFlow::Continue(())); + } + + 'Q' => { + if query_router.query_parser_enabled() { + match query_router.parse(&message) { + Ok(ast) => { + let plugin_result = query_router.execute_plugins(&ast).await; + + match plugin_result { + Ok(PluginOutput::Deny(error)) => { + error_response(&mut self.write, &error).await?; + return Ok(ControlFlow::Continue(())); + } + + Ok(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, &result).await?; + return Ok(ControlFlow::Continue(())); + } + + _ => (), + }; + + let _ = query_router.infer(&ast); + + initial_parsed_ast = Some(ast); + } + Err(error) => { + if error != Error::UnsupportedStatement { + warn!( + "Query parsing error: {} (client: {})", + error, client_identifier + ); + } + } } + } + } - // The client doesn't know / got the wrong server, - // we're closing the connection for security reasons. 
- None => return Ok(()), + 'P' => { + if *prepared_statements_enabled { + (*prepared_statement, message) = self.rewrite_parse(message)?; + *will_prepare = true; } - }; - // Opens a new separate connection to the server, sends the backend_id - // and secret_key and then closes it for security reasons. No other interactions - // take place. - return Server::cancel(&address, port, process_id, secret_key).await; + self.buffer.put(&message[..]); + + if query_router.query_parser_enabled() { + match query_router.parse(&message) { + Ok(ast) => { + if let Ok(output) = query_router.execute_plugins(&ast).await { + *plugin_output = Some(output); + } + + let _ = query_router.infer(&ast); + } + Err(error) => { + warn!( + "Query parsing error: {} (client: {})", + error, client_identifier + ); + } + }; + } + + return Ok(ControlFlow::Continue(())); + } + + 'B' => { + if *prepared_statements_enabled { + (*prepared_statement, message) = self.rewrite_bind(message).await?; + } + + self.buffer.put(&message[..]); + + if query_router.query_parser_enabled() { + query_router.infer_shard_from_bind(&message); + } + + return Ok(ControlFlow::Continue(())); + } + + // Close (F) + 'C' => { + if *prepared_statements_enabled { + let close: Close = (&message).try_into()?; + + if close.is_prepared_statement() && !close.anonymous() { + self.prepared_statements.remove(&close.name); + write_all_flush(&mut self.write, &close_complete()).await?; + return Ok(ControlFlow::Continue(())); + } + } + } + + _ => (), + } + + Ok(ControlFlow::Break((message, initial_parsed_ast))) + } + + /// Handle a connected and authenticated client. + pub async fn handle(&mut self) -> Result<(), Error> { + // The client wants to cancel a query it has issued previously. + if self.cancel_mode { + return self.handle_cancel_mode().await; } // The query router determines where the query is going to go, @@ -805,8 +958,8 @@ where // or issue commands for our sharding and server selection protocol. 
loop { trace!( - "Client idle, waiting for message, transaction mode: {}", - self.transaction_mode + "Client idle, waiting for message, pool mode: {:?}", + self.client_pool_mode ); // Should we rewrite prepared statements and bind messages? @@ -859,161 +1012,323 @@ where let mut pool = self.get_pool().await?; query_router.update_pool_settings(pool.settings.clone()); - let mut initial_parsed_ast = None; + let initial_parsed_ast; + + (message, initial_parsed_ast) = match self + .handle_message_in_custom_protocol_loop( + message, + &client_identifier, + &mut query_router, + &mut will_prepare, + &mut prepared_statements_enabled, + &mut prepared_statement, + &mut plugin_output, + ) + .await? + { + ControlFlow::Break(res) => res, - match message[0] as char { - // Buffer extended protocol messages even if we do not have - // a server connection yet. Hopefully, when we get the S message - // we'll be able to allocate a connection. Also, clients do not expect - // the server to respond to these messages so even if we were not able to - // allocate a connection, we wouldn't be able to send back an error message - // to the client so we buffer them and defer the decision to error out or not - // to when we get the S message - 'D' => { - if prepared_statements_enabled { - let name; - (name, message) = self.rewrite_describe(message).await?; + ControlFlow::Continue(()) => { + continue; + } + }; - if let Some(name) = name { - prepared_statement = Some(name); - } - } + // Check on plugin results. + if let Some(PluginOutput::Deny(error)) = plugin_output { + self.buffer.clear(); + error_response(&mut self.write, &error).await?; + plugin_output = None; + continue; + }; - self.buffer.put(&message[..]); - continue; + // Check if the pool is paused and wait until it's resumed. + if pool.wait_paused().await { + // Refresh pool information, something might have changed. 
+ pool = self.get_pool().await?; + } + + query_router.update_pool_settings(pool.settings.clone()); + + // Reset transaction state, as we are entering a new transaction loop. + self.reset_client_xact(); + + let mut all_conns: HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + > = HashMap::new(); + + match self + .handle_transaction_loop( + &client_identifier, + &mut query_router, + &pool, + &mut all_conns, + Some(message), + initial_parsed_ast, + &mut will_prepare, + &mut prepared_statements_enabled, + &mut prepared_statement, + &mut plugin_output, + ) + .await? + { + Some(ControlFlow::Break(())) => { + return Ok(()); } - 'E' => { - self.buffer.put(&message[..]); + Some(ControlFlow::Continue(())) => { continue; } - 'Q' => { - if query_router.query_parser_enabled() { - match query_router.parse(&message) { - Ok(ast) => { - let plugin_result = query_router.execute_plugins(&ast).await; + _ => (), + } - match plugin_result { - Ok(PluginOutput::Deny(error)) => { - error_response(&mut self.write, &error).await?; - continue; - } + self.cleanup_custom_protocol_loop_helper(all_conns, prepared_statements_enabled) + .await?; + } + } - Ok(PluginOutput::Intercept(result)) => { - write_all(&mut self.write, result).await?; - continue; - } + async fn cleanup_custom_protocol_loop_helper( + &mut self, + mut all_conns: HashMap< + usize, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + prepared_statements_enabled: bool, + ) -> Result<(), Error> { + self.distributed_commit_or_abort(&mut all_conns).await?; - _ => (), - }; + // Reset transaction state for safety reasons. Even if this state will be reset before + // the next transaction, this dirty state could be seen in-between here and there. 
+ self.reset_client_xact(); - let _ = query_router.infer(&ast); + debug!("Releasing servers back into the pool"); + for conn in all_conns.values_mut() { + let server = &mut *conn.0; + let address = &conn.1; - initial_parsed_ast = Some(ast); - } - Err(error) => { - warn!( - "Query parsing error: {} (client: {})", - error, client_identifier - ); - } - } - } - } + // The server is no longer bound to us, we can't cancel it's queries anymore. + debug!("Releasing server back into the pool: {}", address); - 'P' => { - if prepared_statements_enabled { - (prepared_statement, message) = self.rewrite_parse(message)?; - will_prepare = true; - } + server.checkin_cleanup().await?; - self.buffer.put(&message[..]); + if prepared_statements_enabled { + server.maintain_cache().await?; + } - if query_router.query_parser_enabled() { - match query_router.parse(&message) { - Ok(ast) => { - if let Ok(output) = query_router.execute_plugins(&ast).await { - plugin_output = Some(output); - } + server.stats().idle(); + } + self.connected_to_server = false; + self.release(); + self.stats.idle(); + Ok(()) + } + async fn read_message_helper( + &mut self, + idle_client_timeout_duration: Duration, + query_router: &mut QueryRouter, + all_conns: &mut HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + initial_message: Option, + initial_parsed_ast: &mut Option>, + ) -> Result, Error> { + match initial_message { + None => { + trace!("Waiting for message inside transaction or in session mode"); - let _ = query_router.infer(&ast); - } - Err(error) => { - warn!( - "Query parsing error: {} (client: {})", - error, client_identifier - ); - } - }; - } + // This is not an initial message so discard the initial_parsed_ast + initial_parsed_ast.take(); - continue; - } + match tokio::time::timeout( + idle_client_timeout_duration, + read_message(&mut self.read), + ) + .await + { + Ok(Ok(message)) => Ok(ControlFlow::Continue(message)), + Ok(Err(err)) => { + // Client 
disconnected inside a transaction. + // Clean up the server and re-use it. + self.stats.disconnect(); + for conn in all_conns.values_mut() { + let server = &mut *conn.0; + server.checkin_cleanup().await?; + } - 'B' => { - if prepared_statements_enabled { - (prepared_statement, message) = self.rewrite_bind(message).await?; + Err(err) } + Err(_) => { + // Client idle in transaction timeout + error_response_with_state( + &mut self.write, + "idle transaction timeout", + self.xact_info.state(), + ) + .await?; + error!( + "Client idle in transaction timeout: \ + {{ \ + pool_name: {}, \ + username: {}, \ + shard: {:?}, \ + role: \"{:?}\" \ + }}", + self.pool_name, + self.username, + query_router.shard(), + query_router.role() + ); - self.buffer.put(&message[..]); - - if query_router.query_parser_enabled() { - query_router.infer_shard_from_bind(&message); + Ok(ControlFlow::Break(())) } - - continue; } + } - // Close (F) - 'C' => { - if prepared_statements_enabled { - let close: Close = (&message).try_into()?; + Some(message) => Ok(ControlFlow::Continue(message)), + } + } - if close.is_prepared_statement() && !close.anonymous() { - self.prepared_statements.remove(&close.name); - write_all_flush(&mut self.write, &close_complete()).await?; - continue; - } + async fn handle_begin_statement( + &mut self, + code: char, + message: &BytesMut, + client_identifier: &ClientIdentifier, + query_router: &mut QueryRouter, + initial_parsed_ast: &mut Option>, + ) -> Result>, ()>, Error> { + // Query + if code == 'Q' { + // If the first message is a `BEGIN` statement, then we are starting a + // transaction. However, we might not still be on the right shard (as the + // shard might be inferred from the first query). So we parse the query and + // store the `BEGIN` statement. Upon receiving the next query (and possibly + // determining the shard), we will execute the `BEGIN` statement. 
+ if let Some(ast_vec) = initial_parsed_ast { + if Self::is_begin_statement(ast_vec) { + assert_eq!(ast_vec.len(), 1); + + let begin_stmt = &ast_vec[0]; + + if let Statement::StartTransaction { .. } = begin_stmt { + // This is the first BEGIN statement. We need to register it for later executions. + self.xact_info.set_begin_statement(Some(begin_stmt.clone())); + + self.xact_info.set_state(TransactionState::InTransaction); + } else { + panic!("Expected BEGIN statement, got {:?}", begin_stmt); } + + custom_protocol_response_ok_with_state( + &mut self.write, + "BEGIN", + self.xact_info.state(), + ) + .await?; + + return Ok(ControlFlow::Continue(())); } + } - _ => (), + if query_router.query_parser_enabled() { + return self + .parse_ast_helper(query_router, initial_parsed_ast, message, client_identifier) + .await; } + }; + Ok(ControlFlow::Break(None)) + } - // Check on plugin results. - if let Some(PluginOutput::Deny(error)) = plugin_output { - self.buffer.clear(); - error_response(&mut self.write, &error).await?; - plugin_output = None; - continue; + #[allow(clippy::too_many_arguments)] + async fn handle_transaction_loop<'a, 'b>( + &mut self, + client_identifier: &ClientIdentifier, + query_router: &mut QueryRouter, + pool: &'a ConnectionPool, + all_conns: &mut HashMap< + ServerId, + (bb8::PooledConnection<'b, crate::pool::ServerPool>, Address), + >, + mut initial_message: Option, + mut initial_parsed_ast: Option>, + will_prepare: &mut bool, + prepared_statements_enabled: &mut bool, + prepared_statement: &mut Option, + plugin_output: &mut Option, + ) -> Result>, Error> + where + 'a: 'b, + { + let idle_client_timeout_duration = match get_idle_client_in_transaction_timeout() { + 0 => tokio::time::Duration::MAX, + timeout => tokio::time::Duration::from_millis(timeout), + }; + + // Transaction loop. Multiple queries can be issued by the client here. 
+ // The connection belongs to the client until the transaction is over, + // or until the client disconnects if we are in session mode. + // + // If the client is in session or transaction modes, no more custom protocol + // commands will be accepted. However, in transparent mode, the `SET SHARD` and + // `SET SHARDING KEY` custom protocol commands are still accepted. + loop { + let is_first_message_to_server = all_conns.is_empty(); + + let mut message: BytesMut = match self + .read_message_helper( + idle_client_timeout_duration, + query_router, + all_conns, + initial_message.take(), + &mut initial_parsed_ast, + ) + .await? + { + ControlFlow::Continue(message) => message, + ControlFlow::Break(_) => { + break; + } }; - // Check if the pool is paused and wait until it's resumed. - if pool.wait_paused().await { - // Refresh pool information, something might have changed. - pool = self.get_pool().await?; - } + // Safe to unwrap because we know this message has a certain length and has the code + // This reads the first byte without advancing the internal pointer and mutating the bytes + let code = *message.first().unwrap() as char; + let mut ast = match self + .handle_begin_statement( + code, + &message, + client_identifier, + query_router, + &mut initial_parsed_ast, + ) + .await? + { + ControlFlow::Continue(()) => { + continue; + } - query_router.update_pool_settings(pool.settings.clone()); + ControlFlow::Break(ast) => ast, + }; - let current_shard = query_router.shard(); + if all_conns.is_empty() || self.is_transparent_mode() { + let current_shard = query_router.shard(); - // Handle all custom protocol commands, if any. - match query_router.try_execute_command(&message) { - // Normal query, not a custom command. - None => (), + // Handle all custom protocol commands, if any. + match query_router.try_execute_command(&message) { + // Normal query, not a custom command. 
+ None => (), - // SET SHARD TO - Some((Command::SetShard, _)) => { - match query_router.shard() { - None => (), - Some(selected_shard) => { - if selected_shard >= pool.shards() { - // Bad shard number, send error message to client. - query_router.set_shard(current_shard); + // SET SHARD TO + Some((Command::SetShard, _)) => { + match query_router.shard() { + None => (), + Some(selected_shard) => { + if selected_shard >= pool.shards() { + // Bad shard number, send error message to client. + query_router.set_shard(current_shard); - error_response( + error_response_with_state( &mut self.write, &format!( "shard {} is not configured {}, staying on shard {:?} (shard numbers start at 0)", @@ -1021,539 +1336,615 @@ where pool.shards(), current_shard, ), + self.xact_info.state(), ) .await?; - } else { - custom_protocol_response_ok(&mut self.write, "SET SHARD").await?; + } else { + custom_protocol_response_ok_with_state( + &mut self.write, + "SET SHARD", + self.xact_info.state(), + ) + .await?; + } } } + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; + } } - continue; - } - - // SET PRIMARY READS TO - Some((Command::SetPrimaryReads, _)) => { - custom_protocol_response_ok(&mut self.write, "SET PRIMARY READS").await?; - continue; - } - // SET SHARDING KEY TO - Some((Command::SetShardingKey, _)) => { - custom_protocol_response_ok(&mut self.write, "SET SHARDING KEY").await?; - continue; - } + // SET PRIMARY READS TO + Some((Command::SetPrimaryReads, _)) => { + custom_protocol_response_ok_with_state( + &mut self.write, + "SET PRIMARY READS", + self.xact_info.state(), + ) + .await?; + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; + } + } - // SET SERVER ROLE TO - Some((Command::SetServerRole, _)) => { - custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?; - continue; - } + // SET SHARDING KEY TO + Some((Command::SetShardingKey, _)) => { + 
custom_protocol_response_ok_with_state( + &mut self.write, + "SET SHARDING KEY", + self.xact_info.state(), + ) + .await?; + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; + } + } - // SHOW SERVER ROLE - Some((Command::ShowServerRole, value)) => { - show_response(&mut self.write, "server role", &value).await?; - continue; - } + // SET SERVER ROLE TO + Some((Command::SetServerRole, _)) => { + custom_protocol_response_ok_with_state( + &mut self.write, + "SET SERVER ROLE", + self.xact_info.state(), + ) + .await?; + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; + } + } - // SHOW SHARD - Some((Command::ShowShard, value)) => { - show_response(&mut self.write, "shard", &value).await?; - continue; - } + // SHOW SERVER ROLE + Some((Command::ShowServerRole, value)) => { + show_response(&mut self.write, "server role", &value).await?; + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; + } + } - // SHOW PRIMARY READS - Some((Command::ShowPrimaryReads, value)) => { - show_response(&mut self.write, "primary reads", &value).await?; - continue; - } - }; + // SHOW SHARD + Some((Command::ShowShard, value)) => { + show_response(&mut self.write, "shard", &value).await?; + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; + } + } - debug!("Waiting for connection from pool"); - if !self.admin { - self.stats.waiting(); - } + // SHOW PRIMARY READS + Some((Command::ShowPrimaryReads, value)) => { + show_response(&mut self.write, "primary reads", &value).await?; + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; + } + } + }; - // Grab a server from the pool. 
- let connection = match pool - .get(query_router.shard(), query_router.role(), &self.stats) - .await - { - Ok(conn) => { - debug!("Got connection from pool"); - conn + debug!("Waiting for connection from pool"); + if !self.admin { + self.stats.waiting(); } - Err(err) => { - // Client is attempting to get results from the server, - // but we were unable to grab a connection from the pool - // We'll send back an error message and clean the extended - // protocol buffer - self.stats.idle(); - - if message[0] as char == 'S' { - error!("Got Sync message but failed to get a connection from the pool"); - self.buffer.clear(); - } - error_response( - &mut self.write, - format!("could not get connection from the pool - {}", err).as_str(), - ) - .await?; + let query_router_shard = query_router.shard(); + let server_key = query_router_shard.unwrap_or(0); + let mut conn_opt = all_conns.get_mut(&server_key); + if conn_opt.is_none() { + // Grab a server from the pool. + let connection = match pool + .get(query_router_shard, query_router.role(), &self.stats) + .await + { + Ok(conn) => { + debug!("Got connection from pool"); + conn + } + Err(err) => { + // Client is attempting to get results from the server, + // but we were unable to grab a connection from the pool + // We'll send back an error message and clean the extended + // protocol buffer + self.stats.idle(); - error!( - "Could not get connection from pool: \ - {{ \ - pool_name: {:?}, \ - username: {:?}, \ - shard: {:?}, \ - role: \"{:?}\", \ - error: \"{:?}\" \ - }}", - self.pool_name, - self.username, - query_router.shard(), - query_router.role(), - err - ); + if message[0] as char == 'S' { + error!( + "Got Sync message but failed to get a connection from the pool" + ); + self.buffer.clear(); + } - continue; - } - }; + error_response_with_state( + &mut self.write, + format!("could not get connection from the pool - {}", err) + .as_str(), + self.xact_info.state(), + ) + .await?; - let mut reference = connection.0; - let 
address = connection.1; - let server = &mut *reference; + error!( + "Could not get connection from pool: \ + {{ \ + pool_name: {:?}, \ + username: {:?}, \ + shard: {:?}, \ + role: \"{:?}\", \ + error: \"{:?}\" \ + }}", + self.pool_name, + self.username, + query_router.shard(), + query_router.role(), + err + ); + + if self.is_in_idle_transaction() { + return Ok(Some(ControlFlow::Continue(()))); + } else { + continue; + } + } + }; - // Server is assigned to the client in case the client wants to - // cancel a query later. - server.claim(self.process_id, self.secret_key); - self.connected_to_server = true; + // Before inserting this new connection, if we had only a single connection + // before, then it means that we have started a distributed transaction. + // At this point, we need to do some prep on the first server. + if all_conns.len() == 1 { + let (first_server_key, first_conn) = all_conns.iter_mut().next().unwrap(); + let first_server = &mut *first_conn.0; - // Update statistics - self.stats.active(); + if !self.acquire_gid(*first_server_key, first_server).await? { + break; + } + } - self.last_address_id = Some(address.id); - self.last_server_stats = Some(server.stats()); + all_conns.insert(server_key, connection); + + let is_distributed_xact = all_conns.len() > 1; + conn_opt = if is_distributed_xact { + let conn_opt = all_conns.get_mut(&server_key); + let conn = conn_opt.unwrap(); + let server = &mut *conn.0; + let address = &conn.1; + + debug!( + "Sending implicit BEGIN statement to server {} (in transparent mode with distributed transaction)", + address + ); + if !self.begin_distributed_xact(server_key, server).await? { + break; + } + Some(conn) + } else { + all_conns.get_mut(&server_key) + } + } + let conn = conn_opt.unwrap(); + let address = &conn.1; + let server = &mut *conn.0; - debug!( - "Client {:?} talking to server {:?}", - self.addr, - server.address() - ); + // Server is assigned to the client in case the client wants to + // cancel a query later. 
+ server.claim(self.process_id, self.secret_key); + self.connected_to_server = true; - server.sync_parameters(&self.server_parameters).await?; + // Update statistics + self.stats.active(); - let mut initial_message = Some(message); + self.last_address_id = Some(address.id); + self.last_server_stats = Some(server.stats()); + self.last_server_key = Some(server_key); - let idle_client_timeout_duration = match get_idle_client_in_transaction_timeout() { - 0 => tokio::time::Duration::MAX, - timeout => tokio::time::Duration::from_millis(timeout), - }; + debug!( + "Client {:?} talking to server {:?}", + self.addr, + server.address() + ); - // Transaction loop. Multiple queries can be issued by the client here. - // The connection belongs to the client until the transaction is over, - // or until the client disconnects if we are in session mode. - // - // If the client is in session mode, no more custom protocol - // commands will be accepted. - loop { - // Only check if we should rewrite prepared statements - // in session mode. In transaction mode, we check at the beginning of - // each transaction. - if !self.transaction_mode { - prepared_statements_enabled = get_prepared_statements(); - } + server.sync_parameters(&self.server_parameters).await?; + } - debug!("Prepared statement active: {:?}", prepared_statement); + self.assign_client_transaction_state(all_conns); - // We are processing a prepared statement. - if let Some(ref name) = prepared_statement { - debug!("Checking prepared statement is on server"); - // Get the prepared statement the server expects to see. 
- let statement = match self.prepared_statements.get(name) { - Some(statement) => { - debug!("Prepared statement `{}` found in cache", name); - statement - } - None => { - return Err(Error::ClientError(format!( - "prepared statement `{}` not found", - name - ))) - } - }; + let is_distributed_xact = all_conns.len() > 1; + let server_key = query_router.shard().unwrap_or(0); + let conn = all_conns.get_mut(&server_key).unwrap(); + let server = &mut *conn.0; + let address = &conn.1; - // Since it's already in the buffer, we don't need to prepare it on this server. - if will_prepare { - server.will_prepare(&statement.name); - will_prepare = false; - } else { - // The statement is not prepared on the server, so we need to prepare it. - if server.should_prepare(&statement.name) { - match server.prepare(statement).await { - Ok(_) => (), - Err(err) => { - pool.ban( - &address, - BanReason::MessageSendFailed, - Some(&self.stats), - ); - return Err(err); - } - } - } - } + // Only check if we should rewrite prepared statements + // in session mode. In transaction mode, we check at the beginning of + // each transaction. + if !self.is_transaction_mode() { + *prepared_statements_enabled = get_prepared_statements(); + } - // Done processing the prepared statement. - prepared_statement = None; - } + debug!("Prepared statement active: {:?}", *prepared_statement); - let mut message = match initial_message { + // We are processing a prepared statement. + if let Some(ref name) = *prepared_statement { + debug!("Checking prepared statement is on server"); + // Get the prepared statement the server expects to see. 
+ let statement = match self.prepared_statements.get(name) { + Some(statement) => { + debug!("Prepared statement `{}` found in cache", name); + statement + } None => { - trace!("Waiting for message inside transaction or in session mode"); - - // This is not an initial message so discard the initial_parsed_ast - initial_parsed_ast.take(); - - match tokio::time::timeout( - idle_client_timeout_duration, - read_message(&mut self.read), - ) - .await - { - Ok(Ok(message)) => message, - Ok(Err(err)) => { - // Client disconnected inside a transaction. - // Clean up the server and re-use it. - self.stats.disconnect(); - server.checkin_cleanup().await?; + return Err(Error::ClientError(format!( + "prepared statement `{}` not found", + name + ))) + } + }; + // Since it's already in the buffer, we don't need to prepare it on this server. + if *will_prepare { + server.will_prepare(&statement.name); + *will_prepare = false; + } else { + // The statement is not prepared on the server, so we need to prepare it. + if server.should_prepare(&statement.name) { + match server.prepare(statement).await { + Ok(_) => (), + Err(err) => { + pool.ban(address, BanReason::MessageSendFailed, Some(&self.stats)); return Err(err); } - Err(_) => { - // Client idle in transaction timeout - error_response(&mut self.write, "idle transaction timeout").await?; - error!( - "Client idle in transaction timeout: \ - {{ \ - pool_name: {}, \ - username: {}, \ - shard: {:?}, \ - role: \"{:?}\" \ - }}", - self.pool_name, - self.username, - query_router.shard(), - query_router.role() - ); - - break; - } } } + } - Some(message) => { - initial_message = None; - message - } - }; - - // The message will be forwarded to the server intact. We still would like to - // parse it below to figure out what to do with it. 
- - // Safe to unwrap because we know this message has a certain length and has the code - // This reads the first byte without advancing the internal pointer and mutating the bytes - let code = *message.first().unwrap() as char; - - trace!("Message: {}", code); - - match code { - // Query - 'Q' => { - if query_router.query_parser_enabled() { - // We don't want to parse again if we already parsed it as the initial message - let ast = match initial_parsed_ast { - Some(_) => Some(initial_parsed_ast.take().unwrap()), - None => match query_router.parse(&message) { - Ok(ast) => Some(ast), - Err(error) => { - warn!( - "Query parsing error: {} (client: {})", - error, client_identifier - ); - None - } - }, - }; + // Done processing the prepared statement. + *prepared_statement = None; + } - if let Some(ast) = ast { - let plugin_result = query_router.execute_plugins(&ast).await; + // The message will be forwarded to the server intact. We still would like to + // parse it below to figure out what to do with it. - match plugin_result { - Ok(PluginOutput::Deny(error)) => { - error_response(&mut self.write, &error).await?; - continue; - } + trace!("Message: {}", code); - Ok(PluginOutput::Intercept(result)) => { - write_all(&mut self.write, result).await?; - continue; - } + match code { + // Query + 'Q' => { + if is_distributed_xact { + // if we are in a distributed transaction, we need to parse the query + // to figure out if it's a COMMIT or ABORT statement. + // If query parsing is disabled, we need to parse it here. Otherwise, + // it's already parsed. + if !query_router.query_parser_enabled() { + assert_eq!(ast, None); + ast = match self + .parse_ast_helper( + query_router, + &mut initial_parsed_ast, + &message, + client_identifier, + ) + .await? 
+ { + ControlFlow::Continue(()) => { + continue; + } - _ => (), - }; + ControlFlow::Break(ast) => ast, } } - debug!("Sending query to server"); - - self.send_and_receive_loop( - code, - Some(&message), - server, - &address, - &pool, - &self.stats.clone(), - ) - .await?; - - if !server.in_transaction() { - // Report transaction executed statistics. - self.stats.transaction(); - server - .stats() - .transaction(self.server_parameters.get_application_name()); - - // Release server back to the pool if we are in transaction mode. - // If we are in session mode, we keep the server until the client disconnects. - if self.transaction_mode && !server.in_copy_mode() { - self.stats.idle(); + if let Some(ast) = &ast { + if is_distributed_xact && self.set_commit_or_abort_statement(ast) { break; } } } + debug!("Sending query to server"); + + // If this is the first message that we're actually sending to the server, + // we need to send the 'BEGIN' statement first if it was issued before this. + // This is the case when the client is in transparent mode and A 'BEGIN' + // statement was issued before the first query. + if is_first_message_to_server { + if let Some(begin_stmt) = self.xact_info.get_begin_statement() { + let begin_stmt = begin_stmt.clone(); + let res = server.query(&begin_stmt.to_string()).await; + + if self.post_query_processing(server, res).await?.is_none() { + break; + } - // Terminate - 'X' => { - server.checkin_cleanup().await?; - self.stats.disconnect(); - self.release(); + self.initialize_xact_params(server, &begin_stmt); - if prepared_statements_enabled { - server.maintain_cache().await?; + assert!(server.in_transaction()); } - - return Ok(()); } - // Parse - // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. 
- 'P' => { - if prepared_statements_enabled { - (prepared_statement, message) = self.rewrite_parse(message)?; - will_prepare = true; - } + self.send_and_receive_loop( + code, + Some(&message), + server, + address, + pool, + &self.stats.clone(), + ) + .await?; - if query_router.query_parser_enabled() { - if let Ok(ast) = query_router.parse(&message) { - if let Ok(output) = query_router.execute_plugins(&ast).await { - plugin_output = Some(output); - } - } - } + if !server.in_transaction() { + // Report transaction executed statistics. + self.stats.transaction(); + server + .stats() + .transaction(self.server_parameters.get_application_name()); + + // Release server back to the pool if we are in transaction or transparent modes. + // If we are in session mode, we keep the server until the client disconnects. + if (self.is_transaction_mode() || self.is_transparent_mode()) + && !server.in_copy_mode() + { + self.stats.idle(); - self.buffer.put(&message[..]); + break; + } } + } - // Bind - // The placeholder's replacements are here, e.g. 'user@email.com' and 'true' - 'B' => { - if prepared_statements_enabled { - (prepared_statement, message) = self.rewrite_bind(message).await?; - } + // Terminate + 'X' => { + server.checkin_cleanup().await?; + self.stats.disconnect(); + self.release(); - self.buffer.put(&message[..]); + if *prepared_statements_enabled { + server.maintain_cache().await?; } - // Describe - // Command a client can issue to describe a previously prepared named statement. - 'D' => { - if prepared_statements_enabled { - let name; - (name, message) = self.rewrite_describe(message).await?; + return Ok(Some(ControlFlow::Break(()))); + } + + // Parse + // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. 
+ 'P' => { + if *prepared_statements_enabled { + (*prepared_statement, message) = self.rewrite_parse(message)?; + *will_prepare = true; + } - if let Some(name) = name { - prepared_statement = Some(name); + if query_router.query_parser_enabled() { + if let Ok(ast) = query_router.parse(&message) { + if let Ok(output) = query_router.execute_plugins(&ast).await { + *plugin_output = Some(output); } } + } + + self.buffer.put(&message[..]); + } - self.buffer.put(&message[..]); + // Bind + // The placeholder's replacements are here, e.g. 'user@email.com' and 'true' + 'B' => { + if *prepared_statements_enabled { + (*prepared_statement, message) = self.rewrite_bind(message).await?; } - // Close the prepared statement. - 'C' => { - if prepared_statements_enabled { - let close: Close = (&message).try_into()?; + self.buffer.put(&message[..]); + } - if close.is_prepared_statement() && !close.anonymous() { - if let Some(parse) = self.prepared_statements.get(&close.name) { - server.will_close(&parse.generated_name); - } else { - // A prepared statement slipped through? Not impossible, since we don't support PREPARE yet. - }; - } - } + // Describe + // Command a client can issue to describe a previously prepared named statement. + 'D' => { + if *prepared_statements_enabled { + let name; + (name, message) = self.rewrite_describe(message).await?; - self.buffer.put(&message[..]); + if let Some(name) = name { + *prepared_statement = Some(name); + } } - // Execute - // Execute a prepared statement prepared in `P` and bound in `B`. - 'E' => { - self.buffer.put(&message[..]); + self.buffer.put(&message[..]); + } + + // Close the prepared statement. + 'C' => { + if *prepared_statements_enabled { + let close: Close = (&message).try_into()?; + + if close.is_prepared_statement() && !close.anonymous() { + if let Some(parse) = self.prepared_statements.get(&close.name) { + server.will_close(&parse.generated_name); + } else { + // A prepared statement slipped through? 
Not impossible, since we don't support PREPARE yet. + }; + } } - // Sync - // Frontend (client) is asking for the query result now. - 'S' => { - debug!("Sending query to server"); + self.buffer.put(&message[..]); + } - match plugin_output { - Some(PluginOutput::Deny(error)) => { - error_response(&mut self.write, &error).await?; - plugin_output = None; - self.buffer.clear(); - continue; - } + // Execute + // Execute a prepared statement prepared in `P` and bound in `B`. + 'E' => { + self.buffer.put(&message[..]); + } - Some(PluginOutput::Intercept(result)) => { - write_all(&mut self.write, result).await?; - plugin_output = None; - self.buffer.clear(); - continue; - } + // Sync + // Frontend (client) is asking for the query result now. + 'S' => { + debug!("Sending query to server"); + + match plugin_output { + Some(PluginOutput::Deny(error)) => { + error_response_with_state( + &mut self.write, + error, + self.xact_info.state(), + ) + .await?; + *plugin_output = None; + self.buffer.clear(); + continue; + } - _ => (), - }; + Some(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, result).await?; + *plugin_output = None; + self.buffer.clear(); + continue; + } - self.buffer.put(&message[..]); + _ => (), + }; - let first_message_code = (*self.buffer.first().unwrap_or(&0)) as char; + self.buffer.put(&message[..]); - // Almost certainly true - if first_message_code == 'P' && !prepared_statements_enabled { - // Message layout - // P followed by 32 int followed by null-terminated statement name - // So message code should be in offset 0 of the buffer, first character - // in prepared statement name would be index 5 - let first_char_in_name = *self.buffer.get(5).unwrap_or(&0); - if first_char_in_name != 0 { - // This is a named prepared statement - // Server connection state will need to be cleared at checkin - server.mark_dirty(); - } + let first_message_code = (*self.buffer.first().unwrap_or(&0)) as char; + + // Almost certainly true + if first_message_code 
== 'P' && !*prepared_statements_enabled { + // Message layout + // P followed by 32 int followed by null-terminated statement name + // So message code should be in offset 0 of the buffer, first character + // in prepared statement name would be index 5 + let first_char_in_name = *self.buffer.get(5).unwrap_or(&0); + if first_char_in_name != 0 { + // This is a named prepared statement + // Server connection state will need to be cleared at checkin + server.mark_dirty(); } + } - self.send_and_receive_loop( - code, - None, - server, - &address, - &pool, - &self.stats.clone(), - ) - .await?; + self.send_and_receive_loop( + code, + None, + server, + address, + pool, + &self.stats.clone(), + ) + .await?; - self.buffer.clear(); + self.buffer.clear(); - if !server.in_transaction() { - self.stats.transaction(); - server - .stats() - .transaction(self.server_parameters.get_application_name()); + if !server.in_transaction() { + self.stats.transaction(); + server + .stats() + .transaction(self.server_parameters.get_application_name()); - // Release server back to the pool if we are in transaction mode. - // If we are in session mode, we keep the server until the client disconnects. - if self.transaction_mode && !server.in_copy_mode() { - break; - } + // Release server back to the pool if we are in transaction or transparent modes. + // If we are in session mode, we keep the server until the client disconnects. 
+ if (self.is_transaction_mode() || self.is_transparent_mode()) + && !server.in_copy_mode() + { + break; } } + } - // CopyData - 'd' => { - self.buffer.put(&message[..]); + // CopyData + 'd' => { + self.buffer.put(&message[..]); - // Want to limit buffer size - if self.buffer.len() > 8196 { - // Forward the data to the server, - self.send_server_message(server, &self.buffer, &address, &pool) - .await?; - self.buffer.clear(); - } + // Want to limit buffer size + if self.buffer.len() > 8196 { + // Forward the data to the server, + self.send_server_message(server, &self.buffer, address, pool) + .await?; + self.buffer.clear(); } + } - // CopyDone or CopyFail - // Copy is done, successfully or not. - 'c' | 'f' => { - // We may already have some copy data in the buffer, add this message to buffer - self.buffer.put(&message[..]); + // CopyDone or CopyFail + // Copy is done, successfully or not. + 'c' | 'f' => { + // We may already have some copy data in the buffer, add this message to buffer + self.buffer.put(&message[..]); - self.send_server_message(server, &self.buffer, &address, &pool) - .await?; + self.send_server_message(server, &self.buffer, address, pool) + .await?; - // Clear the buffer - self.buffer.clear(); + // Clear the buffer + self.buffer.clear(); - let response = self - .receive_server_message(server, &address, &pool, &self.stats.clone()) - .await?; + let response = self + .receive_server_message(server, address, pool, &self.stats.clone()) + .await?; - match write_all_flush(&mut self.write, &response).await { - Ok(_) => (), - Err(err) => { - server.mark_bad(); - return Err(err); - } - }; + match write_all_flush(&mut self.write, &response).await { + Ok(_) => (), + Err(err) => { + server.mark_bad(); + return Err(err); + } + }; - if !server.in_transaction() { - self.stats.transaction(); - server - .stats() - .transaction(self.server_parameters.get_application_name()); + if !server.in_transaction() { + self.stats.transaction(); + server + .stats() + 
.transaction(self.server_parameters.get_application_name()); - // Release server back to the pool if we are in transaction mode. - // If we are in session mode, we keep the server until the client disconnects. - if self.transaction_mode { - break; - } + // Release server back to the pool if we are in transaction or transparent modes. + // If we are in session mode, we keep the server until the client disconnects. + if self.is_transaction_mode() || self.is_transparent_mode() { + break; } } + } - // Some unexpected message. We either did not implement the protocol correctly - // or this is not a Postgres client we're talking to. - _ => { - error!("Unexpected code: {}", code); - } + // Some unexpected message. We either did not implement the protocol correctly + // or this is not a Postgres client we're talking to. + _ => { + error!("Unexpected code: {}", code); } } + } - // The server is no longer bound to us, we can't cancel it's queries anymore. - debug!("Releasing server back into the pool"); - - server.checkin_cleanup().await?; + Ok(None) + } - if prepared_statements_enabled { - server.maintain_cache().await?; - } + async fn parse_ast_helper( + &mut self, + query_router: &mut QueryRouter, + initial_parsed_ast: &mut Option>, + message: &BytesMut, + client_identifier: &ClientIdentifier, + ) -> Result>>, Error> { + Ok({ + // We don't want to parse again if we already parsed it as the initial message + let ast = parse_ast(initial_parsed_ast, query_router, message, client_identifier); + + if let Some(ast_ref) = &ast { + let plugin_result = query_router.execute_plugins(ast_ref).await; + + match plugin_result { + Ok(PluginOutput::Deny(error)) => { + error_response_with_state(&mut self.write, &error, self.xact_info.state()) + .await?; + ControlFlow::Continue(()) + } - server.stats().idle(); - self.connected_to_server = false; + Ok(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, &result).await?; + ControlFlow::Continue(()) + } - self.release(); - 
self.stats.idle(); - } + _ => ControlFlow::Break(ast), + } + } else { + ControlFlow::Break(None) + } + }) } /// Retrieve connection pool, if it exists. @@ -1731,7 +2122,7 @@ where // Report query executed statistics. client_stats.query(); server.stats().query( - Instant::now().duration_since(query_start).as_millis() as u64, + query_start.elapsed().as_millis() as u64, self.server_parameters.get_application_name(), ); @@ -1796,6 +2187,58 @@ where } } } + + pub fn is_transaction_mode(&self) -> bool { + self.client_pool_mode == PoolMode::Transaction + } + + pub fn is_transparent_mode(&self) -> bool { + self.client_pool_mode == PoolMode::Transparent + } + + pub fn is_in_idle_transaction(&self) -> bool { + self.xact_info.is_idle() + } + + fn is_begin_statement(ast_vec: &[Statement]) -> bool { + ast_vec.len() == 1 && matches!(ast_vec[0], Statement::StartTransaction { .. }) + } +} + +fn parse_ast( + initial_parsed_ast: &mut Option>, + query_router: &mut QueryRouter, + message: &BytesMut, + client_identifier: &ClientIdentifier, +) -> Option> { + // We don't want to parse again if we already parsed it as the initial message + match *initial_parsed_ast { + Some(_) => { + let parsed_ast = initial_parsed_ast.take().unwrap(); + // if 'parsed_ast' is empty, it means that there was a failed + // attempt to parse the query as a custom command, earlier above. 
+ if parsed_ast.is_empty() { + None + } else { + Some(parsed_ast) + } + } + None => match query_router.parse(message) { + Ok(ast) => { + let _ = query_router.infer(&ast); + Some(ast) + } + Err(error) => { + if error != Error::UnsupportedStatement { + warn!( + "Query parsing error: {} (client: {})", + error, client_identifier + ); + } + None + } + }, + } } impl Drop for Client { diff --git a/src/client_xact.rs b/src/client_xact.rs new file mode 100644 index 00000000..36e7737d --- /dev/null +++ b/src/client_xact.rs @@ -0,0 +1,544 @@ +use crate::client::Client; +use crate::errors::Error; +use crate::query_messages::{ErrorInfo, ErrorResponse, Message}; +use bytes::BytesMut; + +use core::panic; +use futures::future::join_all; +use log::{debug, warn}; +use sqlparser::ast::{Statement, TransactionAccessMode, TransactionMode}; +use std::collections::HashMap; +use uuid::Uuid; + +use crate::config::Address; +use crate::messages::*; +use crate::server::Server; +use crate::server_xact::*; + +pub type ServerId = usize; + +/// The metadata of a server transaction. 
+#[derive(Default, Debug, Clone)] +pub struct ClientTxnMetaData { + begin_statement: Option, + commit_statement: Option, + abort_statement: Option, + + pub params: CommonTxnParams, +} + +impl ClientTxnMetaData { + pub fn set_state(&mut self, state: TransactionState) { + match self.params.state { + TransactionState::Idle => { + self.params.state = state; + } + TransactionState::InTransaction => match state { + TransactionState::Idle => { + warn!("Cannot go back to idle from a transaction."); + } + _ => { + self.params.state = state; + } + }, + TransactionState::InFailedTransaction => match state { + TransactionState::Idle => { + warn!("Cannot go back to idle from a failed transaction."); + } + TransactionState::InTransaction => { + warn!("Cannot go back to a transaction from a failed transaction.") + } + _ => { + self.params.state = state; + } + }, + } + } + + pub fn state(&self) -> TransactionState { + self.params.state + } + + pub fn is_idle(&self) -> bool { + self.params.state == TransactionState::Idle + } + + pub fn set_xact_gid(&mut self, xact_gid: Option) { + self.params.xact_gid = xact_gid; + } + + pub fn get_xact_gid(&self) -> Option { + self.params.xact_gid.clone() + } + + pub fn set_begin_statement(&mut self, begin_statement: Option) { + self.begin_statement = begin_statement; + } + + pub fn get_begin_statement(&self) -> Option<&Statement> { + self.begin_statement.as_ref() + } + + pub fn set_commit_statement(&mut self, commit_statement: Option) { + self.commit_statement = commit_statement; + } + + pub fn get_commit_statement(&self) -> Option<&Statement> { + self.commit_statement.as_ref() + } + + pub fn set_abort_statement(&mut self, abort_statement: Option) { + self.abort_statement = abort_statement; + } + + pub fn get_abort_statement(&self) -> Option<&Statement> { + self.abort_statement.as_ref() + } +} + +impl Client +where + S: tokio::io::AsyncRead + std::marker::Unpin, + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + /// This function starts a 
distributed transaction by sending a BEGIN statement to the first server. + /// It is called on the first server, as soon as client wants to interact with another server, + /// which hints that the client wants to start a distributed transaction. + pub async fn begin_distributed_xact( + &mut self, + server_key: ServerId, + server: &mut Server, + ) -> Result { + let begin_stmt = self.xact_info.get_begin_statement(); + assert!(begin_stmt.is_some()); + let res = server.query(&begin_stmt.unwrap().to_string()).await; + if self.post_query_processing(server, res).await?.is_none() { + return Ok(false); + } + + // If we are in a distributed transaction, we need to assign a GID to the transaction. + assert!(self.xact_info.get_xact_gid().is_some()); + let gid = self.xact_info.get_xact_gid().unwrap(); + + debug!("Assigning GID ('{}') to server {}", gid, server.address(),); + + let gid_res = server + .assign_xact_gid(&Self::gen_server_specific_gid(server_key, &gid)) + .await; + if self.post_query_processing(server, gid_res).await?.is_none() { + return Ok(false); + } + + Ok(true) + } + + /// This function generates a GID for the current transaction and sends it to the server. + pub async fn acquire_gid( + &mut self, + server_key: ServerId, + server: &mut Server, + ) -> Result { + assert!(self.xact_info.get_xact_gid().is_none()); + let gid = self.generate_xact_gid(); + + debug!("Acquiring GID ('{}') from server {}", gid, server.address(),); + + // If we are in a distributed transaction, we need to assign a GID to the transaction. + let gid_res = server + .assign_xact_gid(&Self::gen_server_specific_gid(server_key, &gid)) + .await; + if self.post_query_processing(server, gid_res).await?.is_none() { + return Ok(false); + } + self.xact_info.set_xact_gid(Some(gid)); + Ok(true) + } + + /// Generates a random GID (i.e., Global transaction ID) for a transaction.
+ fn generate_xact_gid(&self) -> String { + format!("txn_{}_{}", self.addr, Uuid::new_v4()) + } + + /// Generates a server-specific GID for a transaction. We need this, because it's possible that + /// multiple servers might actually be the same server (which commonly happens in testing). + fn gen_server_specific_gid(server_key: ServerId, gid: &str) -> String { + format!("{}_{}", server_key, gid) + } + + /// Assigns the transaction state based on the state of all servers. + pub fn assign_client_transaction_state( + &mut self, + all_conns: &HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + ) { + self.xact_info.set_state(if all_conns.is_empty() { + // if there's no server, we're in idle mode. + TransactionState::Idle + } else if Self::is_any_server_in_failed_xact(all_conns) { + // if any server is in failed transaction, we're in failed transaction. + TransactionState::InFailedTransaction + } else { + // if we have at least one server and it is in a transaction, we're in a transaction. + TransactionState::InTransaction + }); + } + + /// Returns true if any server is in a failed transaction. + fn is_any_server_in_failed_xact( + all_conns: &HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + ) -> bool { + all_conns + .iter() + .any(|(_, conn)| conn.0.in_failed_transaction()) + } + + /// This function initializes the transaction parameters based on the server's default. + pub fn initialize_xact_params(&mut self, server: &mut Server, begin_stmt: &Statement) { + if let Statement::StartTransaction { modes } = begin_stmt { + // Initialize transaction parameters using the server's default. 
+ self.xact_info.params = server.server_default_transaction_parameters(); + for mode in modes { + match mode { + TransactionMode::AccessMode(access_mode) => { + self.xact_info.params.set_read_only(match access_mode { + TransactionAccessMode::ReadOnly => true, + TransactionAccessMode::ReadWrite => false, + }); + } + TransactionMode::IsolationLevel(isolation_level) => { + self.xact_info.params.set_isolation_level(*isolation_level); + } + } + } + debug!( + "Transaction paramaters after the first BEGIN statement: {:?}", + self.xact_info.params + ); + + // Set the transaction parameters on the first server. + self.set_transaction_params_to_server(server); + } else { + // If it's not a BEGIN, then it's an irrecoverable error. + panic!("The statement is not a BEGIN statement."); + } + } + + fn set_transaction_params_to_server(&mut self, server: &mut Server) { + let server_params = &mut server.transaction_metadata_mut().params; + + server_params.set_isolation_level(self.xact_info.params.get_isolation_level()); + server_params.set_read_only(self.xact_info.params.is_read_only()); + server_params.set_deferrable(self.xact_info.params.is_deferrable()); + } + + /// This function performs a distributed abort/commit if necessary, and also resets the transaction + /// state. This is supposed to be called before exiting the transaction loop. At that point, if + /// either an abort or commit statement is set, we need to perform a distributed abort/commit. This + /// is based on the logic that an abort or commit statement is only set if we are in a distributed + /// transaction and we observe a commit or abort statement sent to the server. That is where we exit + /// the transaction loop and expect this function to take over and abort/commit the transaction.
+ pub async fn distributed_commit_or_abort( + &mut self, + all_conns: &mut HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + ) -> Result<(), Error> { + let dist_commit = self.xact_info.get_commit_statement(); + let dist_abort = self.xact_info.get_abort_statement(); + if dist_commit.is_some() || dist_abort.is_some() { + // if either a commit or abort statement is set, we should be in a distributed transaction. + assert!(!all_conns.is_empty()); + + let is_chained = Self::should_be_chained(dist_commit, dist_abort); + let dist_commit = dist_commit.map(|stmt| stmt.to_string()); + let mut dist_abort = dist_abort.map(|stmt| stmt.to_string()); + + // Report transaction executed statistics. + self.stats.transaction(); + + let mut is_distributed_commit_failed = false; + // We are in distributed transaction mode, and need to commit or abort on all servers. + if let Some(commit_stmt) = dist_commit { + // If two-phase commit was successful, we can send the COMMIT message to the client. + // Otherwise, we need to ROLLBACK on all servers. + let dist_commit_res = self.distributed_commit(all_conns).await; + if self + .communicate_err_response(dist_commit_res) + .await? + .is_none() + { + // Currently, if a distributed commit fails, we send a ROLLBACK to all servers. + // However, this is different from how Postgres handles it. Postgres sends an + // error response to the client, and then does not accept any more queries from + // the client until the client explicitly sends a ROLLBACK. + dist_abort = Some("ROLLBACK".to_string()); + is_distributed_commit_failed = true; + } else { + custom_protocol_response_ok_with_state( + &mut self.write, + &commit_stmt, + TransactionState::Idle, + ) + .await?; + } + } + + if let Some(abort_stmt) = dist_abort { + let distributed_abort_res = self.distributed_abort(all_conns, &abort_stmt).await; + if is_distributed_commit_failed { + // Nothing to do, as the error response is already sent before.
+ } else if self + .communicate_err_response(distributed_abort_res) + .await? + .is_some() + { + custom_protocol_response_ok_with_state( + &mut self.write, + &abort_stmt, + TransactionState::Idle, + ) + .await?; + } + } + + let is_all_servers_in_non_copy_mode = + all_conns.iter().all(|(_, conn)| !conn.0.in_copy_mode()); + + // Release server back to the pool if we are in transaction or transparent modes. + // If we are in session mode, we keep the server until the client disconnects. + if (self.is_transaction_mode() || self.is_transparent_mode()) + && is_all_servers_in_non_copy_mode + { + self.stats.idle(); + } + + if is_chained { + let last_conn = all_conns + .get_mut(self.last_server_key.as_ref().unwrap()) + .unwrap(); + let last_server = &mut *last_conn.0; + + // TODO(MD): chained transaction should be implemented. + // Here, we need to start a local transaction on the last server. However, here is + // too late to start a transaction, as we are far from the transaction loop. We need to + // rearrange the code (or add more complicated control flow) to make it possible. + warn!( + "Chained transaction is not implemented yet. \ + The last server {} will NOT be in transaction.", + last_server.address() + ); + } + } + Ok(()) + } + + pub fn reset_client_xact(&mut self) { + // Reset transaction state for safety reasons. + self.xact_info = Default::default(); + } + + async fn distributed_commit( + &mut self, + all_conns: &mut HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + ) -> Result<(), Error> { + debug!("Committing distributed transaction."); + if Self::is_any_server_in_failed_xact(all_conns) { + #[cfg(debug_assertions)] + all_conns.iter().for_each(|(server_key, conn)| { + let server = &*conn.0; + if server.in_failed_transaction() { + debug!( + "Server {} (with server_key: {:?}) is in failed transaction. 
Skipping commit.", + server.address(), + server_key, + ); + } + }); + + let err = ErrorInfo::new_brief( + "Error".to_string(), + "25P02".to_string(), + "Cannot commit a transaction that is in failed state.".to_string(), + ); + + return Err(Error::ErrorResponse(ErrorResponse::from(err))); + } + self.distributed_prepare(all_conns).await?; + + let commit_prepared_results = join_all(all_conns.iter_mut().map(|(_, conn)| { + let server = &mut *conn.0; + server.local_server_commit_prepared() + })) + .await; + + all_conns.iter_mut().for_each(|(_, conn)| { + let server = &mut *conn.0; + self.set_post_query_state(server); + }); + + for commit_prepared_res in commit_prepared_results { + // For now, we just return the first error we encounter. + commit_prepared_res?; + } + + Ok(()) + } + + /// After each interaction with the server, we need to set the transaction state based on the + /// server's state. + fn set_post_query_state(&mut self, server: &mut Server) { + self.xact_info + .set_state(server.transaction_metadata().state()); + } + + async fn distributed_abort( + &mut self, + all_conns: &mut HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + abort_stmt: &str, + ) -> Result<(), Error> { + debug!("Aborting distributed transaction"); + let abort_results = join_all(all_conns.iter_mut().map(|(_, conn)| { + let server = &mut *conn.0; + server.query(abort_stmt) + })) + .await; + + all_conns.iter_mut().for_each(|(_, conn)| { + let server = &mut *conn.0; + self.set_post_query_state(server); + server + .stats() + .transaction(self.server_parameters.get_application_name()); + }); + + for abort_res in abort_results { + // For now, we just return the first error we encounter. 
+ abort_res?; + } + Ok(()) + } + + #[allow(clippy::type_complexity)] + async fn distributed_prepare( + &mut self, + all_conns: &mut HashMap< + ServerId, + (bb8::PooledConnection<'_, crate::pool::ServerPool>, Address), + >, + ) -> Result<(), Error> { + // Apply 'PREPARE TRANSACTION' on all involved servers. + let prepare_results = join_all(all_conns.iter_mut().map(|(_, conn)| { + let server = &mut *conn.0; + self.set_post_query_state(server); + server.local_server_prepare_transaction() + })) + .await; + + // Update the client state based on the server state. + all_conns.iter_mut().for_each(|(_, conn)| { + let server = &mut *conn.0; + self.set_post_query_state(server); + }); + + // If there was any error, we need to abort the transaction. + for prepare_res in prepare_results { + prepare_res?; + } + Ok(()) + } + + /// Returns true if the statement is a commit or abort statement. Also, it sets the commit or abort + /// statement on the client. + pub fn set_commit_or_abort_statement(&mut self, ast: &Vec) -> bool { + if Self::is_commit_statement(ast) { + self.xact_info.set_commit_statement(Some(ast[0].clone())); + true + } else if Self::is_abort_statement(ast) { + self.xact_info.set_abort_statement(Some(ast[0].clone())); + true + } else { + false + } + } + + /// Returns true if the statement is a commit statement. + fn is_commit_statement(ast: &Vec) -> bool { + for statement in ast { + if let Statement::Commit { .. } = *statement { + assert_eq!(ast.len(), 1); + return true; + } + } + false + } + + /// Returns true if the statement is an abort statement. + fn is_abort_statement(ast: &Vec) -> bool { + for statement in ast { + if let Statement::Rollback { .. } = *statement { + assert_eq!(ast.len(), 1); + return true; + } + } + false + } + + /// Returns true if the commit or abort statement should be chained. 
+ fn should_be_chained(dist_commit: Option<&Statement>, dist_abort: Option<&Statement>) -> bool { + matches!( + (dist_commit, dist_abort), + (Some(Statement::Commit { chain: true }), _) + | (_, Some(Statement::Rollback { chain: true })) + ) + } + async fn communicate_err_response( + &mut self, + res: Result, + ) -> Result, Error> { + match res { + Ok(res) => Ok(Some(res)), + Err(Error::ErrorResponse(err)) => { + error_response_stmt(&mut self.write, &err, self.xact_info.state()).await?; + Ok(None) + } + Err(err) => Err(err), + } + } + + pub async fn post_query_processing( + &mut self, + server: &mut Server, + res: Result, + ) -> Result, Error> { + self.set_post_query_state(server); + self.communicate_err_response(res).await + } +} + +/// Send an error response to the client. +pub async fn error_response_stmt( + stream: &mut S, + err: &ErrorResponse, + t_state: TransactionState, +) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let mut err_bytes = BytesMut::new(); + err.encode(&mut err_bytes)?; + write_all_half(stream, &err_bytes).await?; + + ready_for_query_with_state(stream, t_state).await +} diff --git a/src/config.rs b/src/config.rs index f91e488e..c108a3f2 100644 --- a/src/config.rs +++ b/src/config.rs @@ -477,6 +477,9 @@ impl Default for General { /// - session: server is attached to the client. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Copy, Hash)] pub enum PoolMode { + #[serde(alias = "transparent", alias = "Transparent")] + Transparent, + #[serde(alias = "transaction", alias = "Transaction")] Transaction, @@ -487,6 +490,7 @@ pub enum PoolMode { impl ToString for PoolMode { fn to_string(&self) -> String { match *self { + PoolMode::Transparent => "transparent".to_string(), PoolMode::Transaction => "transaction".to_string(), PoolMode::Session => "session".to_string(), } diff --git a/src/errors.rs b/src/errors.rs index a6aebc50..3a312655 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,5 +1,7 @@ //! 
Errors. +use crate::query_messages::ErrorResponse; + /// Various errors. #[derive(Debug, PartialEq, Clone)] pub enum Error { @@ -29,6 +31,8 @@ pub enum Error { QueryRouterParserError(String), QueryRouterError(String), InvalidShardId(usize), + ErrorResponse(ErrorResponse), + IncompletePacket, } #[derive(Clone, PartialEq, Debug)] diff --git a/src/lib.rs b/src/lib.rs index 6a8a1e36..2d64b216 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ pub mod admin; pub mod auth_passthrough; pub mod client; +pub mod client_xact; pub mod cmd_args; pub mod config; pub mod constants; @@ -12,9 +13,11 @@ pub mod mirrors; pub mod plugins; pub mod pool; pub mod prometheus; +pub mod query_messages; pub mod query_router; pub mod scram; pub mod server; +pub mod server_xact; pub mod sharding; pub mod stats; pub mod tls; diff --git a/src/messages.rs b/src/messages.rs index 86036a92..c438800d 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -12,6 +12,7 @@ use crate::config::get_config; use crate::errors::Error; use crate::constants::MESSAGE_TERMINATOR; +use crate::server_xact::TransactionState; use std::collections::HashMap; use std::ffi::CString; use std::fmt::{Display, Formatter}; @@ -58,7 +59,7 @@ where auth_ok.put_i32(8); auth_ok.put_i32(0); - write_all(stream, auth_ok).await + write_all(stream, &auth_ok).await } /// Generate md5 password challenge. @@ -80,7 +81,7 @@ where res.put_i32(5); // MD5 res.put_slice(&salt[..]); - write_all(stream, res).await?; + write_all(stream, &res).await?; Ok(salt) } @@ -99,7 +100,7 @@ where key_data.put_i32(backend_id); key_data.put_i32(secret_key); - write_all(stream, key_data).await + write_all(stream, &key_data).await } /// Construct a `Q`: Query message. @@ -115,6 +116,17 @@ pub fn simple_query(query: &str) -> BytesMut { /// Tell the client we're ready for another query. 
pub async fn ready_for_query(stream: &mut S) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + ready_for_query_with_state(stream, TransactionState::Idle).await +} + +/// Tell the client we're ready for another query or not, based on the given transaction state. +pub async fn ready_for_query_with_state( + stream: &mut S, + t_state: TransactionState, +) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { @@ -124,9 +136,13 @@ where bytes.put_u8(b'Z'); bytes.put_i32(5); - bytes.put_u8(b'I'); // Idle + match t_state { + TransactionState::Idle => bytes.put_u8(b'I'), + TransactionState::InTransaction => bytes.put_u8(b'T'), + TransactionState::InFailedTransaction => bytes.put_u8(b'E'), + } - write_all(stream, bytes).await + write_all(stream, &bytes).await } /// Send the startup packet the server. We're pretending we're a Pg client. @@ -285,7 +301,7 @@ where message.put_i32(password.len() as i32 + 4); message.put_slice(&password[..]); - write_all(stream, message).await + write_all(stream, &message).await } pub async fn md5_password_with_hash(stream: &mut S, hash: &str, salt: &[u8]) -> Result<(), Error> @@ -299,7 +315,7 @@ where message.put_i32(password.len() as i32 + 4); message.put_slice(&password[..]); - write_all(stream, message).await + write_all(stream, &message).await } /// Implements a response to our custom `SET SHARDING KEY` @@ -309,6 +325,24 @@ pub async fn custom_protocol_response_ok(stream: &mut S, message: &str) -> Re where S: tokio::io::AsyncWrite + std::marker::Unpin, { + custom_protocol_response_ok_with_state(stream, message, TransactionState::Idle).await +} + +/// Implements a response to our custom `SET SHARDING KEY` +/// and `SET SERVER ROLE` commands. +/// This tells the client we're ready for the next query or not, based on the state. 
+pub async fn custom_protocol_response_ok_with_state( + stream: &mut S, + message: &str, + t_state: TransactionState, +) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + debug!( + "Sending custom protocol response: {} at {:?} state.", + message, t_state + ); let mut res = BytesMut::with_capacity(25); let set_complete = BytesMut::from(&format!("{}\0", message)[..]); @@ -320,18 +354,34 @@ where res.put_slice(&set_complete[..]); write_all_half(stream, &res).await?; - ready_for_query(stream).await + ready_for_query_with_state(stream, t_state).await } /// Send a custom error message to the client. /// Tell the client we are ready for the next query and no rollback is necessary. /// Docs on error codes: . pub async fn error_response(stream: &mut S, message: &str) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + error_response_with_state(stream, message, TransactionState::Idle).await +} + +/// Send a custom error message to the client. +/// Tell the client we are ready for the next query. Given the current transaction state, no +/// rollback is necessary if it's in the "idle" or "transaction" state (i.e., not already in the +/// rollback state). +/// Docs on error codes: . +pub async fn error_response_with_state( + stream: &mut S, + message: &str, + t_state: TransactionState, +) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { error_response_terminal(stream, message).await?; - ready_for_query(stream).await + ready_for_query_with_state(stream, t_state).await } /// Send a custom error message to the client. @@ -405,7 +455,7 @@ where res.put(error); - write_all(stream, res).await + write_all(stream, &res).await } /// Respond to a SHOW SHARD command. @@ -563,11 +613,11 @@ pub fn flush() -> BytesMut { } /// Write all data in the buffer to the TcpStream. 
-pub async fn write_all(stream: &mut S, buf: BytesMut) -> Result<(), Error> +pub async fn write_all(stream: &mut S, buf: &BytesMut) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { - match stream.write_all(&buf).await { + match stream.write_all(buf).await { Ok(_) => Ok(()), Err(err) => Err(Error::SocketError(format!( "Error writing to socket - Error: {:?}", diff --git a/src/pool.rs b/src/pool.rs index 736dc1ad..77bcbfca 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -92,7 +92,7 @@ impl From<&Address> for PoolIdentifier { /// Pool settings. #[derive(Clone, Debug)] pub struct PoolSettings { - /// Transaction or Session. + /// Transparent, Transaction or Session. pub pool_mode: PoolMode, /// Random or LeastOutstandingConnections. diff --git a/src/query_messages.rs b/src/query_messages.rs new file mode 100644 index 00000000..16575bca --- /dev/null +++ b/src/query_messages.rs @@ -0,0 +1,468 @@ +use std::fmt; + +/// Helper functions to send one-off protocol messages +/// and handle TcpStream (TCP socket). +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::errors::Error; + +pub type Oid = u32; + +#[derive(PartialEq, Eq, Debug, Default)] +pub struct FieldDescription { + // the field name + name: String, + // the object ID of table, default to 0 if not a table + table_id: i32, + // the attribute number of the column, default to 0 if not a column from table + column_id: i16, + // the object ID of the data type + type_id: Oid, + // the size of data type, negative values denote variable-width types + type_size: i16, + // the type modifier + type_modifier: i32, + // the format code being used for the filed, will be 0 or 1 for now + format_code: i16, +} + +/// Get null-terminated string, returns None when empty cstring read. +/// +/// Note that this implementation will also advance cursor by 1 after reading +/// empty cstring. 
This behaviour works for how postgres wire protocol handling +/// key-value pairs, which is ended by a single `\0` +pub(crate) fn get_cstring(buf: &mut BytesMut) -> Option { + let mut i = 0; + + // with bound check to prevent invalid format + while i < buf.remaining() && buf[i] != b'\0' { + i += 1; + } + + // i+1: include the '\0' + // move cursor to the end of cstring + let string_buf = buf.split_to(i + 1); + + if i == 0 { + None + } else { + Some(String::from_utf8_lossy(&string_buf[..i]).into_owned()) + } +} + +/// Put null-termianted string +/// +/// You can put empty string by giving `""` as input. +pub(crate) fn put_cstring(buf: &mut BytesMut, input: &str) { + buf.put_slice(input.as_bytes()); + buf.put_u8(b'\0'); +} + +/// Try to read message length from buf, without actually move the cursor +pub(crate) fn get_length(buf: &BytesMut, offset: usize) -> Result { + if buf.remaining() >= 4 + offset { + Ok((&buf[offset..4 + offset]).get_i32() as usize) + } else { + Err(Error::IncompletePacket) + } +} + +/// Check if message_length matches and move the cursor to right position then +/// call the `decode_fn` for the body +pub(crate) fn decode_packet( + buf: &mut BytesMut, + offset: usize, + decode_fn: F, +) -> Result +where + F: Fn(&mut BytesMut, usize) -> Result, +{ + let msg_len = get_length(buf, offset)?; + + if buf.remaining() >= msg_len + offset { + buf.advance(offset + 4); + return decode_fn(buf, msg_len); + } + + Err(Error::IncompletePacket) +} + +/// Define how message encode and decoded. +pub trait Message: Sized { + /// Return the type code of the message. In order to maintain backward + /// compatibility, `Startup` has no message type. + #[inline] + fn message_type() -> Option { + None + } + + /// Return the length of the message, including the length integer itself. + fn message_length(&self) -> usize; + + /// Encode body part of the message. + fn encode_body(&self, buf: &mut BytesMut) -> Result<(), Error>; + + /// Decode body part of the message. 
+ fn decode_body(buf: &mut BytesMut, full_len: usize) -> Result; + + /// Default implementation for encoding message. + /// + /// Message type and length are encoded in this implementation and it calls + /// `encode_body` for remaining parts. + fn encode(&self, buf: &mut BytesMut) -> Result<(), Error> { + if let Some(mt) = Self::message_type() { + buf.put_u8(mt); + } + + let message_length = self.message_length(); + let original_buf_len = buf.len(); + buf.put_i32(message_length as i32); + let result = self.encode_body(buf); + assert_eq!(buf.len() - original_buf_len, message_length); + + result + } + + /// Default implementation for decoding message. + /// + /// Message type and length are decoded in this implementation and it calls + /// `decode_body` for remaining parts. Return `None` if the packet is not + /// complete for parsing. + fn decode(buf: &mut BytesMut) -> Result { + let offset = Self::message_type().is_some().into(); + + decode_packet(buf, offset, |buf, full_len| { + Self::decode_body(buf, full_len) + }) + } +} + +pub const MESSAGE_TYPE_BYTE_ROW_DESCRITION: u8 = b'T'; + +#[derive(PartialEq, Eq, Debug, Default)] +pub struct RowDescription { + fields: Vec, +} + +impl RowDescription { + pub fn fields(&self) -> &[FieldDescription] { + &self.fields + } +} + +impl Message for RowDescription { + fn message_type() -> Option { + Some(MESSAGE_TYPE_BYTE_ROW_DESCRITION) + } + + fn message_length(&self) -> usize { + 4 + 2 + + self + .fields + .iter() + .map(|f| f.name.as_bytes().len() + 1 + 4 + 2 + 4 + 2 + 4 + 2) + .sum::() + } + + fn encode_body(&self, buf: &mut BytesMut) -> Result<(), Error> { + buf.put_i16(self.fields.len() as i16); + + for field in &self.fields { + put_cstring(buf, &field.name); + buf.put_i32(field.table_id); + buf.put_i16(field.column_id); + buf.put_u32(field.type_id); + buf.put_i16(field.type_size); + buf.put_i32(field.type_modifier); + buf.put_i16(field.format_code); + } + + Ok(()) + } + + fn decode_body(buf: &mut BytesMut, _: usize) -> 
Result { + let fields_len = buf.get_i16(); + let mut fields = Vec::with_capacity(fields_len as usize); + + for _ in 0..fields_len { + let field = FieldDescription { + name: get_cstring(buf).unwrap_or_else(|| "".to_owned()), + table_id: buf.get_i32(), + column_id: buf.get_i16(), + type_id: buf.get_u32(), + type_size: buf.get_i16(), + type_modifier: buf.get_i32(), + format_code: buf.get_i16(), + }; + + fields.push(field); + } + + Ok(RowDescription { fields }) + } +} + +/// Data structure for postgresql wire protocol `DataRow` message. +/// +/// Data can be represented as text or binary format as specified by format +/// codes from previous `RowDescription` message. +#[derive(PartialEq, Eq, Debug, Default, Clone)] +pub struct DataRow { + fields: Vec>, +} + +impl DataRow { + pub fn fields(&self) -> &[Option] { + &self.fields + } +} + +pub const MESSAGE_TYPE_BYTE_DATA_ROW: u8 = b'D'; + +impl Message for DataRow { + #[inline] + fn message_type() -> Option { + Some(MESSAGE_TYPE_BYTE_DATA_ROW) + } + + fn message_length(&self) -> usize { + 4 + 2 + + self + .fields + .iter() + .map(|b| b.as_ref().map(|b| b.len() + 4).unwrap_or(4)) + .sum::() + } + + fn encode_body(&self, buf: &mut BytesMut) -> Result<(), Error> { + buf.put_i16(self.fields.len() as i16); + for field in &self.fields { + if let Some(bytes) = field { + buf.put_i32(bytes.len() as i32); + buf.put_slice(bytes.as_ref()); + } else { + buf.put_i32(-1); + } + } + + Ok(()) + } + + fn decode_body(buf: &mut BytesMut, _msg_len: usize) -> Result { + let field_count = buf.get_i16() as usize; + + let mut fields = Vec::with_capacity(field_count); + for _ in 0..field_count { + let field_len = buf.get_i32(); + if field_len >= 0 { + fields.push(Some(buf.split_to(field_len as usize).freeze())); + } else { + fields.push(None); + } + } + + Ok(DataRow { fields }) + } +} + +#[derive(PartialEq, Eq, Debug, Default)] +pub struct QueryResponse { + row_desc: RowDescription, + data_rows: Vec, +} + +impl QueryResponse { + pub fn 
new(row_desc: RowDescription, data_rows: Vec) -> Self { + Self { + row_desc, + data_rows, + } + } + + pub fn row_desc(&self) -> &RowDescription { + &self.row_desc + } + + pub fn data_rows(&self) -> &[DataRow] { + &self.data_rows + } +} + +// Postgres error and notice message fields +// This part of protocol is defined in +// https://www.postgresql.org/docs/8.2/protocol-error-fields.html +#[derive(Debug, Default)] +pub struct ErrorInfo { + // severity can be one of `ERROR`, `FATAL`, or `PANIC` (in an error + // message), or `WARNING`, `NOTICE`, `DEBUG`, `INFO`, or `LOG` (in a notice + // message), or a localized translation of one of these. + severity: String, + // error code defined in + // https://www.postgresql.org/docs/current/errcodes-appendix.html + code: String, + // readable message + message: String, + // optional secondary message + detail: Option, + // optional suggestion for fixing the issue + hint: Option, + // Position: the field value is a decimal ASCII integer, indicating an error + // cursor position as an index into the original query string. + position: Option, + // Internal position: this is defined the same as the P field, but it is + // used when the cursor position refers to an internally generated command + // rather than the one submitted by the client + internal_position: Option, + // Internal query: the text of a failed internally-generated command. + internal_query: Option, + // Where: an indication of the context in which the error occurred. + where_context: Option, + // File: the file name of the source-code location where the error was + // reported. + file_name: Option, + // Line: the line number of the source-code location where the error was + // reported. + line: Option, + // Routine: the name of the source-code routine reporting the error. 
+ routine: Option, +} + +impl ErrorInfo { + #[allow(clippy::too_many_arguments)] + pub fn new( + severity: String, + code: String, + message: String, + detail: Option, + hint: Option, + position: Option, + internal_position: Option, + internal_query: Option, + where_context: Option, + file_name: Option, + line: Option, + routine: Option, + ) -> Self { + Self { + severity, + code, + message, + detail, + hint, + position, + internal_position, + internal_query, + where_context, + file_name, + line, + routine, + } + } + pub fn new_brief(severity: String, code: String, message: String) -> Self { + Self::new( + severity, code, message, None, None, None, None, None, None, None, None, None, + ) + } +} + +impl ErrorInfo { + fn into_fields(self) -> Vec<(u8, String)> { + let mut fields = Vec::with_capacity(11); + + fields.push((b'S', self.severity)); + fields.push((b'C', self.code)); + fields.push((b'M', self.message)); + if let Some(value) = self.detail { + fields.push((b'D', value)); + } + if let Some(value) = self.hint { + fields.push((b'H', value)); + } + if let Some(value) = self.position { + fields.push((b'P', value)); + } + if let Some(value) = self.internal_position { + fields.push((b'p', value)); + } + if let Some(value) = self.internal_query { + fields.push((b'q', value)); + } + if let Some(value) = self.where_context { + fields.push((b'W', value)); + } + if let Some(value) = self.file_name { + fields.push((b'F', value)); + } + if let Some(value) = self.line { + fields.push((b'L', value.to_string())); + } + if let Some(value) = self.routine { + fields.push((b'R', value)); + } + + fields + } +} + +impl From for ErrorResponse { + fn from(ei: ErrorInfo) -> ErrorResponse { + ErrorResponse { + fields: ei.into_fields(), + } + } +} + +/// postgres error response, sent from backend to frontend +#[derive(PartialEq, Eq, Debug, Clone, Default)] +pub struct ErrorResponse { + fields: Vec<(u8, String)>, +} + +impl fmt::Display for ErrorResponse { + fn fmt(&self, f: &mut 
fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ErrorResponse({:?})", self.fields) + } +} + +pub const MESSAGE_TYPE_BYTE_ERROR_RESPONSE: u8 = b'E'; + +impl Message for ErrorResponse { + #[inline] + fn message_type() -> Option { + Some(MESSAGE_TYPE_BYTE_ERROR_RESPONSE) + } + + fn message_length(&self) -> usize { + 4 + self + .fields + .iter() + .map(|f| 1 + f.1.as_bytes().len() + 1) + .sum::() + + 1 + } + + fn encode_body(&self, buf: &mut BytesMut) -> Result<(), Error> { + for (code, value) in &self.fields { + buf.put_u8(*code); + put_cstring(buf, value); + } + + buf.put_u8(b'\0'); + + Ok(()) + } + + fn decode_body(buf: &mut BytesMut, _: usize) -> Result { + let mut fields = Vec::new(); + loop { + let code = buf.get_u8(); + + if code == b'\0' { + return Ok(ErrorResponse { fields }); + } else { + let value = get_cstring(buf).unwrap_or_else(|| "".to_owned()); + fields.push((code, value)); + } + } + } +} diff --git a/src/query_router.rs b/src/query_router.rs index 8b451dd3..4b6873dc 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -338,6 +338,9 @@ impl QueryRouter { Some((command, value)) } + const UNSUPPORTED_STATEMENTS_FOR_PARSING: [&'static str; 4] = + ["COPY", "SET", "TRUNCATE", "VACUUM"]; + pub fn parse(&self, message: &BytesMut) -> Result, Error> { let mut message_cursor = Cursor::new(message); @@ -380,8 +383,22 @@ impl QueryRouter { match Parser::parse_sql(&PostgreSqlDialect {}, &query) { Ok(ast) => Ok(ast), Err(err) => { - debug!("{}: {}", err, query); - Err(Error::QueryRouterParserError(err.to_string())) + let qry_upper = query.to_ascii_uppercase(); + + // Check for unsupported statements to avoid producing a warning. + // Note 1: this is not a complete list of unsupported statements. + // Note 2: we do not check for unsupported statements before going through the + // parser, as the plugin system might be able to handle them, once sqlparser + // is able to correctly parse these (rather valid) queries. 
+ if Self::UNSUPPORTED_STATEMENTS_FOR_PARSING + .iter() + .any(|s| qry_upper.starts_with(s)) + { + Err(Error::UnsupportedStatement) + } else { + debug!("{}: {}", err, query); + Err(Error::QueryRouterParserError(err.to_string())) + } } } } diff --git a/src/server.rs b/src/server.rs index 3394cda7..22e58257 100644 --- a/src/server.rs +++ b/src/server.rs @@ -24,7 +24,9 @@ use crate::messages::BytesMutReader; use crate::messages::*; use crate::mirrors::MirroringManager; use crate::pool::ClientServerMap; +use crate::query_messages::{ErrorResponse, Message}; use crate::scram::ScramSha256; +use crate::server_xact::*; use crate::stats::ServerStats; use std::io::Write; @@ -106,7 +108,7 @@ impl StreamInner { } #[derive(Copy, Clone)] -struct CleanupState { +pub(crate) struct CleanupState { /// If server connection requires RESET ALL before checkin because of set statement needs_cleanup_set: bool, @@ -131,7 +133,7 @@ impl CleanupState { self.needs_cleanup_prepare = true; } - fn reset(&mut self) { + pub(crate) fn reset(&mut self) { self.needs_cleanup_set = false; self.needs_cleanup_prepare = false; } @@ -154,12 +156,15 @@ static TRACKED_PARAMETERS: Lazy> = Lazy::new(|| { set.insert("TimeZone".to_string()); set.insert("standard_conforming_strings".to_string()); set.insert("application_name".to_string()); + for param in TRANSACTION_PARAMETERS.iter() { + set.insert(param.clone()); + } set }); #[derive(Debug, Clone)] pub struct ServerParameters { - parameters: HashMap, + pub(crate) parameters: HashMap, } impl Default for ServerParameters { @@ -183,6 +188,7 @@ impl ServerParameters { false, ); server_parameters.set_param("application_name".to_string(), "pgcat".to_string(), false); + CommonTxnParams::set_default_server_parameters(&mut server_parameters); server_parameters } @@ -265,7 +271,7 @@ impl From<&ServerParameters> for BytesMut { pub struct Server { /// Server host, e.g. localhost, /// port, e.g. 5432, and role, e.g. primary or replica. 
- address: Address, + pub(crate) address: Address, /// Server TCP connection. stream: BufStream, @@ -274,14 +280,14 @@ pub struct Server { buffer: BytesMut, /// Server information the server sent us over on startup. - server_parameters: ServerParameters, + pub(crate) server_parameters: ServerParameters, /// Backend id and secret key used for query cancellation. process_id: i32, secret_key: i32, /// Is the server inside a transaction or idle. - in_transaction: bool, + pub(crate) transaction_metadata: ServerTxnMetaData, /// Is there more data for the client to read. data_available: bool, @@ -293,7 +299,7 @@ pub struct Server { bad: bool, /// If server connection requires reset statements before checkin - cleanup_state: CleanupState, + pub(crate) cleanup_state: CleanupState, /// Mapping of clients and servers used for query cancellation. client_server_map: ClientServerMap, @@ -493,7 +499,7 @@ impl Server { } }; - trace!("Message: {}", code); + trace!("server Message: {}", code); match code { // Authentication @@ -790,14 +796,14 @@ impl Server { } }; - let server = Server { + let mut server = Server { address: address.clone(), stream: BufStream::new(stream), buffer: BytesMut::with_capacity(8196), server_parameters, process_id, secret_key, - in_transaction: false, + transaction_metadata: Default::default(), in_copy_mode: false, data_available: false, bad: false, @@ -821,6 +827,11 @@ impl Server { prepared_statements: BTreeSet::new(), }; + // We want to make sure that all servers are operating on the same isolation level. + server + .sync_given_parameter_keys(&TRANSACTION_PARAMETERS) + .await?; + return Ok(server); } @@ -920,20 +931,21 @@ impl Server { 'Z' => { let transaction_state = message.get_u8() as char; + let metadata = &mut self.transaction_metadata; match transaction_state { // In transaction. 'T' => { - self.in_transaction = true; + metadata.set_state(TransactionState::InTransaction); } // Idle, transaction over. 
'I' => { - self.in_transaction = false; + metadata.set_state(TransactionState::Idle); } // Some error occurred, the transaction was rolled back. 'E' => { - self.in_transaction = true; + metadata.set_state(TransactionState::InFailedTransaction); } // Something totally unexpected, this is not a Postgres server we know. @@ -976,7 +988,7 @@ impl Server { // No great way to differentiate between set and set local // As a result, we will miss cases when set statements are used in transactions // This will reduce amount of reset statements sent - if !self.in_transaction { + if !self.in_transaction() { debug!("Server connection marked for clean up"); self.cleanup_state.needs_cleanup_set = true; } @@ -1183,8 +1195,8 @@ impl Server { /// If the server is still inside a transaction. /// If the client disconnects while the server is in a transaction, we will clean it up. pub fn in_transaction(&self) -> bool { - debug!("Server in transaction: {}", self.in_transaction); - self.in_transaction + debug!("Server in transaction: {:?}", self.transaction_metadata); + !self.transaction_metadata.is_idle() } pub fn in_copy_mode(&self) -> bool { @@ -1227,21 +1239,7 @@ impl Server { pub async fn sync_parameters(&mut self, parameters: &ServerParameters) -> Result<(), Error> { let parameter_diff = self.server_parameters.compare_params(parameters); - if parameter_diff.is_empty() { - return Ok(()); - } - - let mut query = String::from(""); - - for (key, value) in parameter_diff { - query.push_str(&format!("SET {} TO '{}';", key, value)); - } - - let res = self.query(&query).await; - - self.cleanup_state.reset(); - - res + self.sync_given_parameter_key_values(¶meter_diff).await } /// Indicate that this server connection cannot be re-used and must be discarded. @@ -1268,20 +1266,29 @@ impl Server { /// It will use the simple query protocol. /// Result will not be returned, so this is useful for things like `SET` or `ROLLBACK`. 
pub async fn query(&mut self, query: &str) -> Result<(), Error> { - debug!("Running `{}` on server {:?}", query, self.address); + debug!("Running `{}` on server {}", query, self.address); let query = simple_query(query); self.send(&query).await?; + let mut err = None; loop { - let _ = self.recv(None).await?; + let mut response = self.recv(None).await?; + + if response[0] == b'E' && err.is_none() { + err = Some(ErrorResponse::decode(&mut response)); + } if !self.data_available { break; } } + if let Some(err) = err { + return Err(Error::ErrorResponse(err?)); + } + Ok(()) } @@ -1322,6 +1329,8 @@ impl Server { warn!(target: "pgcat::server::cleanup", "Server returned while still in copy-mode"); } + self.transaction_metadata = Default::default(); + Ok(()) } @@ -1385,6 +1394,14 @@ impl Server { parse_query_message(&mut message).await } + + pub fn transaction_metadata(&self) -> &ServerTxnMetaData { + &self.transaction_metadata + } + + pub fn transaction_metadata_mut(&mut self) -> &mut ServerTxnMetaData { + &mut self.transaction_metadata + } } async fn parse_query_message(message: &mut BytesMut) -> Result, Error> { diff --git a/src/server_xact.rs b/src/server_xact.rs new file mode 100644 index 00000000..0927af79 --- /dev/null +++ b/src/server_xact.rs @@ -0,0 +1,307 @@ +/// Implementation of the PostgreSQL server (database) protocol. +/// Here we are pretending to the a Postgres client. +use log::{debug, warn}; +use once_cell::sync::Lazy; +use sqlparser::ast::TransactionIsolationLevel; +use std::collections::HashMap; + +use crate::errors::Error; +use crate::server::{Server, ServerParameters}; + +/// The default transaction parameters that might be configured on the server. 
+pub static TRANSACTION_PARAMETERS: Lazy<Vec<String>> = Lazy::new(|| { + vec![ + "default_transaction_isolation".to_string(), + "default_transaction_read_only".to_string(), + "default_transaction_deferrable".to_string(), + ] +}); + +/// The default transaction parameters that are either configured on the server or set by the +/// BEGIN statement. +#[derive(Debug, Clone)] +pub struct CommonTxnParams { + pub(crate) state: TransactionState, + + pub(crate) xact_gid: Option<String>, + + isolation_level: TransactionIsolationLevel, + read_only: bool, + deferrable: bool, +} + +impl CommonTxnParams { + pub fn new( + isolation_level: TransactionIsolationLevel, + read_only: bool, + deferrable: bool, + ) -> Self { + Self { + state: TransactionState::Idle, + xact_gid: None, + isolation_level, + read_only, + deferrable, + } + } + + pub fn get_isolation_level(&self) -> TransactionIsolationLevel { + self.isolation_level + } + + pub fn is_read_only(&self) -> bool { + self.read_only + } + + pub fn is_deferrable(&self) -> bool { + self.deferrable + } + + pub fn set_isolation_level(&mut self, isolation_level: TransactionIsolationLevel) { + self.isolation_level = isolation_level; + } + + pub fn set_read_only(&mut self, read_only: bool) { + self.read_only = read_only; + } + + pub fn set_deferrable(&mut self, deferrable: bool) { + self.deferrable = deferrable; + } + + pub fn is_serializable(&self) -> bool { + matches!( + self.get_isolation_level(), + TransactionIsolationLevel::Serializable + ) + } + + pub fn is_repeatable_read(&self) -> bool { + matches!( + self.get_isolation_level(), + TransactionIsolationLevel::RepeatableRead + ) + } + + pub fn is_repeatable_read_or_higher(&self) -> bool { + self.is_serializable() || self.is_repeatable_read() + } + + /// Sets the default transaction parameters on the given ServerParameters instance. 
+ pub fn set_default_server_parameters(sparams: &mut ServerParameters) { + // TODO(MD): make these configurable + sparams.set_param( + "default_transaction_isolation".to_string(), + "read committed".to_string(), + false, + ); + sparams.set_param( + "default_transaction_read_only".to_string(), + "off".to_string(), + false, + ); + sparams.set_param( + "default_transaction_deferrable".to_string(), + "off".to_string(), + false, + ); + } +} + +impl Default for CommonTxnParams { + fn default() -> Self { + Self::new(TransactionIsolationLevel::ReadCommitted, false, false) + } +} + +/// The various states that a server transaction can be in. +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum TransactionState { + /// Server is idle. + Idle, + /// Server is in a transaction. + InTransaction, + /// Server is in a failed transaction. + InFailedTransaction, +} + +/// The metadata of a server transaction. +#[derive(Default, Debug, Clone)] +pub struct ServerTxnMetaData { + is_prepared: bool, + + pub params: CommonTxnParams, +} + +impl ServerTxnMetaData { + pub fn set_state(&mut self, state: TransactionState) { + self.params.state = state; + } + + pub fn state(&self) -> TransactionState { + self.params.state + } + + pub fn is_idle(&self) -> bool { + self.params.state == TransactionState::Idle + } + + pub fn is_in_transaction(&self) -> bool { + self.params.state == TransactionState::InTransaction + } + + pub fn is_in_failed_transaction(&self) -> bool { + self.params.state == TransactionState::InFailedTransaction + } + + pub fn set_xact_gid(&mut self, xact_gid: Option<String>) { + self.params.xact_gid = xact_gid; + } + + pub fn get_xact_gid(&self) -> Option<String> { + self.params.xact_gid.clone() + } + + pub fn set_prepared(&mut self, is_prepared: bool) { + self.is_prepared = is_prepared; + } + + pub fn has_done_prepare_transaction(&self) -> bool { + self.is_prepared + } +} + +impl ServerParameters { + fn get_default_transaction_isolation(&self) -> TransactionIsolationLevel { + // Can unwrap 
because we set it in the constructor + if let Some(isolation_level) = self.parameters.get("default_transaction_isolation") { + return match isolation_level.to_lowercase().as_str() { + "read committed" => TransactionIsolationLevel::ReadCommitted, + "repeatable read" => TransactionIsolationLevel::RepeatableRead, + "serializable" => TransactionIsolationLevel::Serializable, + "read uncommitted" => TransactionIsolationLevel::ReadUncommitted, + _ => TransactionIsolationLevel::ReadCommitted, + }; + } + TransactionIsolationLevel::ReadCommitted + } + + fn get_default_transaction_read_only(&self) -> bool { + if let Some(is_readonly) = self.parameters.get("default_transaction_read_only") { + return !is_readonly.to_lowercase().eq("off"); + } + false + } + + fn get_default_transaction_deferrable(&self) -> bool { + if let Some(deferrable) = self.parameters.get("default_transaction_deferrable") { + return !deferrable.to_lowercase().eq("off"); + } + false + } + + fn get_default_transaction_parameters(&self) -> CommonTxnParams { + CommonTxnParams::new( + self.get_default_transaction_isolation(), + self.get_default_transaction_read_only(), + self.get_default_transaction_deferrable(), + ) + } +} + +impl Server { + pub fn server_default_transaction_parameters(&self) -> CommonTxnParams { + self.server_parameters.get_default_transaction_parameters() + } + + /// Sends some queries to the server to sync the given parameters specified by 'keys'. + pub async fn sync_given_parameter_keys(&mut self, keys: &[String]) -> Result<(), Error> { + let mut key_values = HashMap::new(); + for key in keys { + if let Some(value) = self.server_parameters.parameters.get(key) { + key_values.insert(key.clone(), value.clone()); + } + } + self.sync_given_parameter_key_values(&key_values).await + } + + /// Sends some queries to the server to sync the given parameters specified by 'key_values'. 
+ pub async fn sync_given_parameter_key_values( + &mut self, + key_values: &HashMap<String, String>, + ) -> Result<(), Error> { + let mut query = String::from(""); + + for (key, value) in key_values { + query.push_str(&format!("SET {} TO '{}';", key, value)); + } + + let res = self.query(&query).await; + + self.cleanup_state.reset(); + + match res { + Ok(_) => Ok(()), + Err(Error::ErrorResponse(err_res)) => { + warn!( + "Error while syncing parameters (was dropped): {:?}", + err_res + ); + Ok(()) + } + Err(err) => Err(err), + } + } + + /// Returns true if the given server is in a failed transaction state. + pub fn in_failed_transaction(&self) -> bool { + self.transaction_metadata.is_in_failed_transaction() + } + + /// Sets the GID on the server. If we are in serializable mode, we need to register the GID to + /// the remote postgres instance, too. + pub async fn assign_xact_gid(&mut self, gid: &str) -> Result<(), Error> { + self.transaction_metadata + .set_xact_gid(Some(gid.to_string())); + Ok(()) + } + + pub async fn local_server_prepare_transaction(&mut self) -> Result<(), Error> { + debug!( + "Called local_server_prepare_transaction on {}", + self.address + ); + + let xact_gid = self.transaction_metadata.get_xact_gid(); + if xact_gid.is_none() { + return Err(Error::BadQuery(format!( + "There is no GID assigned to the current transaction while it's requested to be \ + prepared to commit on the server ({}).", + self.address() + ))); + } + let xact_gid = xact_gid.unwrap(); + + self.query(&format!("PREPARE TRANSACTION '{}'", xact_gid)) + .await?; + + self.transaction_metadata.set_prepared(true); + Ok(()) + } + + pub async fn local_server_commit_prepared(&mut self) -> Result<(), Error> { + debug!("Called local_server_commit_prepared on {}.", self.address); + + let xact_gid = self.transaction_metadata.get_xact_gid(); + if xact_gid.is_none() { + return Err(Error::BadQuery( + "The current connection is not attached to a \ + transaction while it's requested to be prepared to commit." 
+ .to_string(), + )); + } + let xact_gid = xact_gid.unwrap(); + + self.query(&format!("COMMIT PREPARED '{}'", xact_gid)).await + } +}