From 74af0e9f49fe6a3291958a7cf94111097494fbce Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Mon, 10 Apr 2023 13:57:09 -0700 Subject: [PATCH 1/3] Format cleanup --- .circleci/pgcat.toml | 2 +- .editorconfig | 14 ++++ Cargo.lock | 2 +- Cargo.toml | 2 +- examples/docker/pgcat.toml | 3 - pgcat.toml | 2 +- src/auth_passthrough.rs | 48 +++++++++---- src/client.rs | 144 +++++++++++++++++++++++++------------ src/config.rs | 6 +- src/errors.rs | 65 ++++++++++++++++- src/main.rs | 25 ++++--- src/stats/server.rs | 1 - 12 files changed, 233 insertions(+), 81 deletions(-) create mode 100644 .editorconfig diff --git a/.circleci/pgcat.toml b/.circleci/pgcat.toml index 0d47ed72..377680a0 100644 --- a/.circleci/pgcat.toml +++ b/.circleci/pgcat.toml @@ -39,7 +39,7 @@ log_client_connections = false log_client_disconnections = false # Reload config automatically if it changes. -autoreload = true +autoreload = 15000 # TLS tls_certificate = ".circleci/server.cert" diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..d7a2758d --- /dev/null +++ b/.editorconfig @@ -0,0 +1,14 @@ +root = true + +[*] +trim_trailing_whitespace = true +insert_final_newline = true + +[*.rs] +indent_style = space +indent_size = 4 +max_line_length = 120 + +[*.toml] +indent_style = space +indent_size = 2 diff --git a/Cargo.lock b/Cargo.lock index be4cb3ec..5d6c56f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -739,7 +739,7 @@ dependencies = [ [[package]] name = "pgcat" -version = "1.0.0" +version = "1.0.1" dependencies = [ "arc-swap", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index 02c80881..4e33c645 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgcat" -version = "1.0.0" +version = "1.0.1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/examples/docker/pgcat.toml b/examples/docker/pgcat.toml index c41c8cdd..bfc4c2e2 100644 --- a/examples/docker/pgcat.toml +++ b/examples/docker/pgcat.toml @@ -38,9 +38,6 @@ log_client_connections = false # If we should log client disconnections log_client_disconnections = false -# Reload config automatically if it changes. -autoreload = false - # TLS # tls_certificate = "server.cert" # tls_private_key = "server.key" diff --git a/pgcat.toml b/pgcat.toml index a6a4af50..18339197 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -45,7 +45,7 @@ log_client_connections = false log_client_disconnections = false # When set to true, PgCat reloads configs if it detects a change in the config file. -autoreload = false +autoreload = 15000 # Number of worker threads the Runtime will use (4 by default). worker_threads = 5 diff --git a/src/auth_passthrough.rs b/src/auth_passthrough.rs index 76483e59..f313dead 100644 --- a/src/auth_passthrough.rs +++ b/src/auth_passthrough.rs @@ -1,4 +1,5 @@ use crate::errors::Error; +use crate::pool::ConnectionPool; use crate::server::Server; use log::debug; @@ -78,19 +79,25 @@ impl AuthPassthrough { let user = &address.username; - debug!("Connecting to server to obtain auth hashes."); + debug!("Connecting to server to obtain auth hashes"); + let auth_query = self.query.replace("$1", user); + match Server::exec_simple_query(address, &auth_user, &auth_query).await { Ok(password_data) => { if password_data.len() == 2 && password_data.first().unwrap() == user { - if let Some(stripped_hash) = password_data.last().unwrap().to_string().strip_prefix("md5") { - Ok(stripped_hash.to_string()) - } - else { - Err(Error::AuthPassthroughError( - "Obtained hash from auth_query does not seem to be in md5 format.".to_string(), - )) - } + if let Some(stripped_hash) = password_data + .last() + .unwrap() + .to_string() + .strip_prefix("md5") { + Ok(stripped_hash.to_string()) + } + else { + Err(Error::AuthPassthroughError( + "Obtained hash from auth_query does not seem to be in md5 format.".to_string(), + )) + } } else { Err(Error::AuthPassthroughError( "Data obtained from query does not follow the scheme 'user','hash'." @@ -99,10 +106,25 @@ impl AuthPassthrough { } } Err(err) => { - Err(Error::AuthPassthroughError( - format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}", - user, err))) + Err(Error::AuthPassthroughError( + format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}", + user, err)) + ) } - } + } + } +} + +pub async fn refetch_auth_hash(pool: &ConnectionPool) -> Result { + let address = pool.address(0, 0); + if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) { + let hash = apt.fetch_hash(address).await?; + + return Ok(hash); } + + Err(Error::ClientError(format!( + "Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.", + address.username, address.database + ))) } diff --git a/src/client.rs b/src/client.rs index d75c069d..688e1934 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,4 @@ -use crate::errors::Error; +use crate::errors::{ClientIdentifier, Error}; use crate::pool::BanReason; /// Handle clients by pretending to be a PostgreSQL server. use bytes::{Buf, BufMut, BytesMut}; @@ -12,7 +12,7 @@ use tokio::sync::broadcast::Receiver; use tokio::sync::mpsc::Sender; use crate::admin::{generate_server_info_for_admin, handle_admin}; -use crate::auth_passthrough::AuthPassthrough; +use crate::auth_passthrough::refetch_auth_hash; use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode}; use crate::constants::*; use crate::messages::*; @@ -369,28 +369,14 @@ pub async fn startup_tls( } // Bad Postgres client. - Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => Err( - Error::ProtocolSyncError(format!("Bad postgres client (tls)")), - ), + Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => { + Err(Error::ProtocolSyncError("Bad postgres client (tls)".into())) + } Err(err) => Err(err), } } -async fn refetch_auth_hash(pool: &ConnectionPool) -> Result { - let address = pool.address(0, 0); - if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) { - let hash = apt.fetch_hash(address).await?; - - return Ok(hash); - } - - Err(Error::ClientError(format!( - "Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.", - address.username, address.database - ))) -} - impl Client where S: tokio::io::AsyncRead + std::marker::Unpin, @@ -418,7 +404,7 @@ where Some(user) => user, None => { return Err(Error::ClientError( - "Missing user parameter on client startup".to_string(), + "Missing user parameter on client startup".into(), )) } }; @@ -433,6 +419,8 @@ where None => "pgcat", }; + let client_identifier = ClientIdentifier::new(&application_name, &username, &pool_name); + let admin = ["pgcat", "pgbouncer"] .iter() .filter(|db| *db == pool_name) @@ -463,7 +451,12 @@ where let code = match read.read_u8().await { Ok(p) => p, - Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), + Err(_) => { + return Err(Error::ClientSocketError( + "password code".into(), + client_identifier, + )) + } }; // PasswordMessage @@ -476,19 +469,30 @@ where let len = match read.read_i32().await { Ok(len) => len, - Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), + Err(_) => { + return Err(Error::ClientSocketError( + "password message length".into(), + client_identifier, + )) + } }; let mut password_response = vec![0u8; (len - 4) as usize]; match read.read_exact(&mut password_response).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), + Err(_) => { + return Err(Error::ClientSocketError( + "password message".into(), + client_identifier, + )) + } }; // Authenticate admin user. let (transaction_mode, server_info) = if admin { let config = get_config(); + // Compare server and client hashes. let password_hash = md5_hash_password( &config.general.admin_username, @@ -497,10 +501,12 @@ where ); if password_hash != password_response { - warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name); + let error = Error::ClientGeneralError("Invalid password".into(), client_identifier); + + warn!("{}", error); wrong_password(&mut write, username).await?; - return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); + return Err(error); } (false, generate_server_info_for_admin()) @@ -519,7 +525,10 @@ where ) .await?; - return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); + return Err(Error::ClientGeneralError( + "Invalid pool name".into(), + client_identifier, + )); } }; @@ -530,16 +539,23 @@ where Some(md5_hash_password(username, password, &salt)) } else { if !get_config().is_auth_query_configured() { - return Err(Error::ClientError(format!("Client auth not possible, no cleartext password set for username: {:?} in config and auth passthrough (query_auth) is not set up.", username))); + return Err(Error::ClientAuthImpossible(username.into())); } let mut hash = (*pool.auth_hash.read()).clone(); if hash.is_none() { - warn!("Query auth configured but no hash password found for pool {}. Will try to refetch it.", pool_name); + warn!( + "Query auth configured \ + but no hash password found \ + for pool {}. Will try to refetch it.", + pool_name + ); + match refetch_auth_hash(&pool).await { Ok(fetched_hash) => { - warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, obtained. Updating.", username, pool_name, application_name); + warn!("Password for {}, obtained. Updating.", client_identifier); + { let mut pool_auth_hash = pool.auth_hash.write(); *pool_auth_hash = Some(fetched_hash.clone()); @@ -547,16 +563,12 @@ where hash = Some(fetched_hash); } + Err(err) => { - return Err( - Error::ClientError( - format!("No cleartext password set, and no auth passthrough could not obtain the hash from server for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, the error was: {:?}", - username, - pool_name, - application_name, - err) - ) - ); + return Err(Error::ClientAuthPassthroughError( + err.to_string(), + client_identifier, + )); } } }; @@ -570,20 +582,31 @@ where // // @TODO: we could end up fetching again the same password twice (see above). if password_hash.unwrap() != password_response { - warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, will try to refetch it.", username, pool_name, application_name); + warn!( + "Invalid password {}, will try to refetch it.", + client_identifier + ); + let fetched_hash = refetch_auth_hash(&pool).await?; let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt); // Ok password changed in server an auth is possible. if new_password_hash == password_response { - warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, changed in server. Updating.", username, pool_name, application_name); + warn!( + "Password for {}, changed in server. Updating.", + client_identifier + ); + { let mut pool_auth_hash = pool.auth_hash.write(); *pool_auth_hash = Some(fetched_hash); } } else { wrong_password(&mut write, username).await?; - return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); + return Err(Error::ClientGeneralError( + "Invalid password".into(), + client_identifier, + )); } } @@ -753,9 +776,9 @@ where &mut self.write, "terminating connection due to administrator command" ).await?; - self.stats.disconnect(); - return Ok(()) + self.stats.disconnect(); + return Ok(()); } // Admin clients ignore shutdown. @@ -928,11 +951,26 @@ where error!("Got Sync message but failed to get a connection from the pool"); self.buffer.clear(); } + error_response(&mut self.write, "could not get connection from the pool") .await?; - error!("Could not get connection from pool: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\", error: \"{:?}\" }}", - self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role(), err); + error!( + "Could not get connection from pool: \ + {{ \ + pool_name: {:?}, \ + username: {:?}, \ + shard: {:?}, \ + role: \"{:?}\", \ + error: \"{:?}\" \ + }}", + self.pool_name, + self.username, + query_router.shard(), + query_router.role(), + err + ); + continue; } }; @@ -999,11 +1037,25 @@ where Err(_) => { // Client idle in transaction timeout error_response(&mut self.write, "idle transaction timeout").await?; - error!("Client idle in transaction timeout: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\"}}", self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role()); + error!( + "Client idle in transaction timeout: \ + {{ \ + pool_name: {}, \ + username: {}, \ + shard: {}, \ + role: \"{:?}\" \ + }}", + self.pool_name, + self.username, + query_router.shard(), + query_router.role() + ); + break; } } } + Some(message) => { initial_message = None; message diff --git a/src/config.rs b/src/config.rs index 37494f67..13528af2 100644 --- a/src/config.rs +++ b/src/config.rs @@ -245,8 +245,8 @@ pub struct General { #[serde(default = "General::default_worker_threads")] pub worker_threads: usize, - #[serde(default)] // False - pub autoreload: bool, + #[serde(default)] // None + pub autoreload: Option, pub tls_certificate: Option, pub tls_private_key: Option, @@ -335,7 +335,7 @@ impl Default for General { tcp_keepalives_interval: Self::default_tcp_keepalives_interval(), log_client_connections: false, log_client_disconnections: false, - autoreload: false, + autoreload: None, tls_certificate: None, tls_private_key: None, admin_username: String::from("admin"), diff --git a/src/errors.rs b/src/errors.rs index 58fc088b..537d782b 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,9 +1,13 @@ -/// Errors. +//! Errors. /// Various errors. #[derive(Debug, PartialEq)] pub enum Error { SocketError(String), + ClientSocketError(String, ClientIdentifier), + ClientGeneralError(String, ClientIdentifier), + ClientAuthImpossible(String), + ClientAuthPassthroughError(String, ClientIdentifier), ClientBadStartup, ProtocolSyncError(String), BadQuery(String), @@ -18,3 +22,62 @@ pub enum Error { AuthError(String), AuthPassthroughError(String), } + +#[derive(Clone, PartialEq, Debug)] +pub struct ClientIdentifier { + pub application_name: String, + pub username: String, + pub pool_name: String, +} + +impl ClientIdentifier { + pub fn new(application_name: &str, username: &str, pool_name: &str) -> ClientIdentifier { + ClientIdentifier { + application_name: application_name.into(), + username: username.into(), + pool_name: pool_name.into(), + } + } +} + +impl std::fmt::Display for ClientIdentifier { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "{{ application_name: {}, username: {}, pool_name: {} }}", + self.application_name, self.username, self.pool_name + ) + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match &self { + &Error::ClientSocketError(error, client_identifier) => write!( + f, + "Error reading {} from client {}", + error, client_identifier + ), + &Error::ClientGeneralError(error, client_identifier) => { + write!(f, "{} {}", error, client_identifier) + } + &Error::ClientAuthImpossible(username) => write!( + f, + "Client auth not possible, \ + no cleartext password set for username: {} \ + in config and auth passthrough (query_auth) \ + is not set up.", + username + ), + &Error::ClientAuthPassthroughError(error, client_identifier) => write!( + f, + "No cleartext password set, \ + and no auth passthrough could not \ + obtain the hash from server for {}, \ + the error was: {}", + client_identifier, error + ), + _ => todo!(), + } + } +} diff --git a/src/main.rs b/src/main.rs index 4c8987f1..39c67c2e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -179,16 +179,19 @@ fn main() -> Result<(), Box> { stats_collector.collect().await; }); - info!("Config autoreloader: {}", config.general.autoreload); + info!("Config autoreloader: {}", match config.general.autoreload { + Some(interval) => format!("{} ms", interval), + None => "disabled".into(), + }); - let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000)); - let autoreload_client_server_map = client_server_map.clone(); + if let Some(interval) = config.general.autoreload { + let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(interval)); + let autoreload_client_server_map = client_server_map.clone(); - tokio::task::spawn(async move { - loop { - autoreload_interval.tick().await; - if config.general.autoreload { - info!("Automatically reloading config"); + tokio::task::spawn(async move { + loop { + autoreload_interval.tick().await; + debug!("Automatically reloading config"); if let Ok(changed) = reload_config(autoreload_client_server_map.clone()).await { if changed { @@ -196,8 +199,10 @@ fn main() -> Result<(), Box> { } }; } - } - }); + }); + }; + + #[cfg(windows)] let mut term_signal = win_signal::ctrl_close().unwrap(); diff --git a/src/stats/server.rs b/src/stats/server.rs index 009e9b57..08968a12 100644 --- a/src/stats/server.rs +++ b/src/stats/server.rs @@ -100,7 +100,6 @@ impl ServerStats { .server_idle(self.state.load(Ordering::Relaxed)); self.state.store(ServerState::Idle, Ordering::Relaxed); - self.set_undefined_application(); } /// Reports a server connection is disconecting from the pooler. From 0647b84f5d0b13d171efa04e857c706334b6b32f Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Mon, 10 Apr 2023 13:58:38 -0700 Subject: [PATCH 2/3] fmt --- src/client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/client.rs b/src/client.rs index 688e1934..4114e4bb 100644 --- a/src/client.rs +++ b/src/client.rs @@ -202,7 +202,7 @@ pub async fn client_entrypoint( // Client probably disconnected rejecting our plain text connection. Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => Err(Error::ProtocolSyncError( - format!("Bad postgres client (plain)"), + "Bad postgres client (plain)".into(), )), Err(err) => Err(err), From d5664401674453db848a5e95dde78302d4d5917e Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Mon, 10 Apr 2023 14:21:11 -0700 Subject: [PATCH 3/3] finally --- src/errors.rs | 39 ++++++++++++++++- src/server.rs | 113 ++++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 133 insertions(+), 19 deletions(-) diff --git a/src/errors.rs b/src/errors.rs index 537d782b..0930ab8b 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -12,6 +12,8 @@ pub enum Error { ProtocolSyncError(String), BadQuery(String), ServerError, + ServerStartupError(String, ServerIdentifier), + ServerAuthError(String, ServerIdentifier), BadConfig, AllServersDown, ClientError(String), @@ -50,6 +52,31 @@ impl std::fmt::Display for ClientIdentifier { } } +#[derive(Clone, PartialEq, Debug)] +pub struct ServerIdentifier { + pub username: String, + pub database: String, +} + +impl ServerIdentifier { + pub fn new(username: &str, database: &str) -> ServerIdentifier { + ServerIdentifier { + username: username.into(), + database: database.into(), + } + } +} + +impl std::fmt::Display for ServerIdentifier { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "{{ username: {}, database: {} }}", + self.username, self.database + ) + } +} + impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match &self { @@ -77,7 +104,17 @@ impl std::fmt::Display for Error { the error was: {}", client_identifier, error ), - _ => todo!(), + &Error::ServerStartupError(error, server_identifier) => write!( + f, + "Error reading {} on server startup {}", + error, server_identifier, + ), + &Error::ServerAuthError(error, server_identifier) => { + write!(f, "{} for {}", error, server_identifier,) + } + + // The rest can use Debug. + err => write!(f, "{:?}", err), } } } diff --git a/src/server.rs b/src/server.rs index 37f0e0c7..14862bd0 100644 --- a/src/server.rs +++ b/src/server.rs @@ -17,7 +17,7 @@ use tokio::net::{ use crate::config::{Address, User}; use crate::constants::*; -use crate::errors::Error; +use crate::errors::{Error, ServerIdentifier}; use crate::messages::*; use crate::mirrors::MirroringManager; use crate::pool::ClientServerMap; @@ -108,6 +108,7 @@ impl Server { let mut server_info = BytesMut::new(); let mut process_id: i32 = 0; let mut secret_key: i32 = 0; + let server_identifier = ServerIdentifier::new(&user.username, &database); // We'll be handling multiple packets, but they will all be structured the same. // We'll loop here until this exchange is complete. @@ -119,12 +120,22 @@ impl Server { loop { let code = match stream.read_u8().await { Ok(code) => code as char, - Err(_) => return Err(Error::SocketError(format!("Error reading message code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), + Err(_) => { + return Err(Error::ServerStartupError( + "message code".into(), + server_identifier, + )) + } }; let len = match stream.read_i32().await { Ok(len) => len, - Err(_) => return Err(Error::SocketError(format!("Error reading message len on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), + Err(_) => { + return Err(Error::ServerStartupError( + "message len".into(), + server_identifier, + )) + } }; trace!("Message: {}", code); @@ -135,7 +146,12 @@ impl Server { // Determine which kind of authentication is required, if any. let auth_code = match stream.read_i32().await { Ok(auth_code) => auth_code, - Err(_) => return Err(Error::SocketError(format!("Error reading auth code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), + Err(_) => { + return Err(Error::ServerStartupError( + "auth code".into(), + server_identifier, + )) + } }; trace!("Auth: {}", auth_code); @@ -148,7 +164,12 @@ impl Server { match stream.read_exact(&mut salt).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError(format!("Error reading salt on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), + Err(_) => { + return Err(Error::ServerStartupError( + "salt".into(), + server_identifier, + )) + } }; match &user.password { @@ -171,8 +192,12 @@ impl Server { &salt[..], ) .await?, - None => - return Err(Error::AuthError(format!("Auth passthrough (auth_query) failed and no user password is set in cleartext for {{ username: {:?}, database: {:?} }}", user.username, database))) + None => return Err( + Error::ServerAuthError( + "Auth passthrough (auth_query) failed and no user password is set in cleartext".into(), + server_identifier + ) + ), } } } @@ -182,16 +207,28 @@ impl Server { SASL => { if scram.is_none() { - return Err(Error::AuthError(format!("SASL auth required and not password specified, auth passthrough (auth_query) method is currently unsupported for SASL auth {{ username: {:?}, database: {:?} }}", user.username, database))); + return Err(Error::ServerAuthError( + "SASL auth required and no password specified. \ + Auth passthrough (auth_query) method is currently \ + unsupported for SASL auth" + .into(), + server_identifier, + )); } debug!("Starting SASL authentication"); + let sasl_len = (len - 8) as usize; let mut sasl_auth = vec![0u8; sasl_len]; match stream.read_exact(&mut sasl_auth).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError(format!("Error reading sasl message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), + Err(_) => { + return Err(Error::ServerStartupError( + "sasl message".into(), + server_identifier, + )) + } }; let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]); @@ -233,7 +270,12 @@ impl Server { match stream.read_exact(&mut sasl_data).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError(format!("Error reading sasl cont message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), + Err(_) => { + return Err(Error::ServerStartupError( + "sasl cont message".into(), + server_identifier, + )) + } }; let msg = BytesMut::from(&sasl_data[..]); @@ -254,7 +296,12 @@ impl Server { let mut sasl_final = vec![0u8; len as usize - 8]; match stream.read_exact(&mut sasl_final).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError(format!("Error reading sasl final message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), + Err(_) => { + return Err(Error::ServerStartupError( + "sasl final message".into(), + server_identifier, + )) + } }; match scram @@ -284,7 +331,12 @@ impl Server { 'E' => { let error_code = match stream.read_u8().await { Ok(error_code) => error_code, - Err(_) => return Err(Error::SocketError(format!("Error reading error code message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), + Err(_) => { + return Err(Error::ServerStartupError( + "error code message".into(), + server_identifier, + )) + } }; trace!("Error: {}", error_code); @@ -300,7 +352,12 @@ impl Server { match stream.read_exact(&mut error).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError(format!("Error reading error message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), + Err(_) => { + return Err(Error::ServerStartupError( + "error message".into(), + server_identifier, + )) + } }; // TODO: the error message contains multiple fields; we can decode them and @@ -319,7 +376,12 @@ impl Server { match stream.read_exact(&mut param).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError(format!("Error reading parameter status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), + Err(_) => { + return Err(Error::ServerStartupError( + "parameter status message".into(), + server_identifier, + )) + } }; // Save the parameter so we can pass it to the client later. @@ -336,12 +398,22 @@ impl Server { // See: . process_id = match stream.read_i32().await { Ok(id) => id, - Err(_) => return Err(Error::SocketError(format!("Error reading process id message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), + Err(_) => { + return Err(Error::ServerStartupError( + "process id message".into(), + server_identifier, + )) + } }; secret_key = match stream.read_i32().await { Ok(id) => id, - Err(_) => return Err(Error::SocketError(format!("Error reading secret key message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), + Err(_) => { + return Err(Error::ServerStartupError( + "secret key message".into(), + server_identifier, + )) + } }; } @@ -351,7 +423,12 @@ impl Server { match stream.read_exact(&mut idle).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError(format!("Error reading transaction status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), + Err(_) => { + return Err(Error::ServerStartupError( + "transaction status message".into(), + server_identifier, + )) + } }; let (read, write) = stream.into_split(); @@ -413,7 +490,7 @@ impl Server { Ok(stream) => stream, Err(err) => { error!("Could not connect to server: {}", err); - return Err(Error::SocketError(format!("Error reading cancel message"))); + return Err(Error::SocketError("Error reading cancel message".into())); } }; configure_socket(&stream);