Skip to content

Adds details to errors and fixes error propagation bug #239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/admin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ where
let code = query.get_u8() as char;

if code != 'Q' {
return Err(Error::ProtocolSyncError);
return Err(Error::ProtocolSyncError(format!(
"Invalid code, expected 'Q' but got '{}'",
code
)));
}

let len = query.get_i32() as usize;
Expand Down
46 changes: 31 additions & 15 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,12 @@ pub async fn client_entrypoint(
}

// Client probably disconnected rejecting our plain text connection.
_ => Err(Error::ProtocolSyncError),
Ok((ClientConnectionType::Tls, _))
| Ok((ClientConnectionType::CancelQuery, _)) => Err(Error::ProtocolSyncError(
format!("Bad postgres client (plain)"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessary plain because this handles the TLS use case too.

)),

Err(err) => Err(err),
}
}
}
Expand Down Expand Up @@ -297,7 +302,10 @@ where

// Something else, probably something is wrong and it's not our fault,
// e.g. badly implemented Postgres client.
_ => Err(Error::ProtocolSyncError),
_ => Err(Error::ProtocolSyncError(format!(
"Unexpected startup code: {}",
code
))),
}
}

Expand Down Expand Up @@ -343,7 +351,11 @@ pub async fn startup_tls(
}

// Bad Postgres client.
_ => Err(Error::ProtocolSyncError),
Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => Err(
Error::ProtocolSyncError(format!("Bad postgres client (tls)")),
),

Err(err) => Err(err),
}
}

Expand Down Expand Up @@ -374,7 +386,11 @@ where
// This parameter is mandatory by the protocol.
let username = match parameters.get("user") {
Some(user) => user,
None => return Err(Error::ClientError),
None => {
return Err(Error::ClientError(
"Missing user parameter on client startup".to_string(),
))
}
};

let pool_name = match parameters.get("database") {
Expand Down Expand Up @@ -417,25 +433,27 @@ where

let code = match read.read_u8().await {
Ok(p) => p,
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))),
};

// PasswordMessage
if code as char != 'p' {
debug!("Expected p, got {}", code as char);
return Err(Error::ProtocolSyncError);
return Err(Error::ProtocolSyncError(format!(
"Expected p, got {}",
code as char
)));
}

let len = match read.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))),
};

let mut password_response = vec![0u8; (len - 4) as usize];

match read.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))),
};

// Authenticate admin user.
Expand All @@ -451,7 +469,7 @@ where
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
wrong_password(&mut write, username).await?;

return Err(Error::ClientError);
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name)));
}

(false, generate_server_info_for_admin())
Expand All @@ -470,8 +488,7 @@ where
)
.await?;

warn!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
return Err(Error::ClientError);
return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name)));
}
};

Expand All @@ -482,7 +499,7 @@ where
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
wrong_password(&mut write, username).await?;

return Err(Error::ClientError);
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name)));
}

let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
Expand Down Expand Up @@ -669,8 +686,7 @@ where
)
.await?;

warn!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.pool_name, self.username, self.application_name);
return Err(Error::ClientError);
return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.pool_name, self.username, self.application_name)));
}
};
query_router.update_pool_settings(pool.settings.clone());
Expand Down
6 changes: 3 additions & 3 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
/// Various errors.
#[derive(Debug, PartialEq)]
pub enum Error {
SocketError,
SocketError(String),
ClientBadStartup,
ProtocolSyncError,
ProtocolSyncError(String),
ServerError,
BadConfig,
AllServersDown,
ClientError,
ClientError(String),
TlsError,
StatementTimeout,
ShuttingDown,
Expand Down
30 changes: 24 additions & 6 deletions src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,11 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu

match stream.write_all(&startup).await {
Ok(_) => Ok(()),
Err(_) => Err(Error::SocketError),
Err(_) => {
return Err(Error::SocketError(format!(
"Error writing startup to server socket"
)))
}
}
}

Expand Down Expand Up @@ -450,7 +454,7 @@ where
{
match stream.write_all(&buf).await {
Ok(_) => Ok(()),
Err(_) => Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))),
}
}

Expand All @@ -461,7 +465,7 @@ where
{
match stream.write_all(&buf).await {
Ok(_) => Ok(()),
Err(_) => Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))),
}
}

Expand All @@ -472,19 +476,33 @@ where
{
let code = match stream.read_u8().await {
Ok(code) => code,
Err(_) => return Err(Error::SocketError),
Err(_) => {
return Err(Error::SocketError(format!(
"Error reading message code from socket"
)))
}
};

let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError),
Err(_) => {
return Err(Error::SocketError(format!(
"Error reading message len from socket, code: {:?}",
code
)))
}
};

let mut buf = vec![0u8; len as usize - 4];

match stream.read_exact(&mut buf).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
Err(_) => {
return Err(Error::SocketError(format!(
"Error reading message from socket, code: {:?}",
code
)))
}
};

let mut bytes = BytesMut::with_capacity(len as usize + 1);
Expand Down
12 changes: 6 additions & 6 deletions src/scram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ impl ScramSha256 {
let server_message = Message::parse(message)?;

if !server_message.nonce.starts_with(&self.nonce) {
return Err(Error::ProtocolSyncError);
return Err(Error::ProtocolSyncError(format!("SCRAM")));
}

let salt = match base64::decode(&server_message.salt) {
Ok(salt) => salt,
Err(_) => return Err(Error::ProtocolSyncError),
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
};

let salted_password = Self::hi(
Expand Down Expand Up @@ -163,7 +163,7 @@ impl ScramSha256 {

let verifier = match base64::decode(&final_message.value) {
Ok(verifier) => verifier,
Err(_) => return Err(Error::ProtocolSyncError),
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
};

let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
Expand Down Expand Up @@ -225,14 +225,14 @@ impl Message {
.collect::<Vec<String>>();

if parts.len() != 3 {
return Err(Error::ProtocolSyncError);
return Err(Error::ProtocolSyncError(format!("SCRAM")));
}

let nonce = str::replace(&parts[0], "r=", "");
let salt = str::replace(&parts[1], "s=", "");
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
Ok(iterations) => iterations,
Err(_) => return Err(Error::ProtocolSyncError),
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
};

Ok(Message {
Expand All @@ -252,7 +252,7 @@ impl FinalMessage {
/// Parse the server final validation message.
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
if !message.starts_with(b"v=") || message.len() < 4 {
return Err(Error::ProtocolSyncError);
return Err(Error::ProtocolSyncError(format!("SCRAM")));
}

Ok(FinalMessage {
Expand Down
Loading