diff --git a/Cargo.lock b/Cargo.lock index bf961a61..7991667e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -762,9 +762,11 @@ dependencies = [ "once_cell", "parking_lot", "phf", + "pin-project", "postgres-protocol", "rand", "regex", + "rustls", "rustls-pemfile", "serde", "serde_derive", @@ -776,6 +778,7 @@ dependencies = [ "tokio", "tokio-rustls", "toml", + "webpki-roots", ] [[package]] @@ -820,6 +823,26 @@ dependencies = [ "siphasher", ] +[[package]] +name = "pin-project" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "pin-project-lite" version = "0.2.9" @@ -1446,6 +1469,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa54963694b65584e170cf5dc46aeb4dcaa5584e652ff5f3952e56d66aff0125" +dependencies = [ + "rustls-webpki", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index a5573518..28e94a6d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,9 @@ nix = "0.26.2" atomic_enum = "0.2.0" postgres-protocol = "0.6.5" fallible-iterator = "0.2" +pin-project = "1" +webpki-roots = "0.23" +rustls = { version = "0.21", features = ["dangerous_configuration"] } [target.'cfg(not(target_env = "msvc"))'.dependencies] jemallocator = "0.5.0" diff --git a/pgcat.toml b/pgcat.toml index 9203cb60..df2ba715 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -61,9 +61,15 @@ tcp_keepalives_count = 5 tcp_keepalives_interval = 5 # Path to TLS Certificate file to use for TLS connections -# tls_certificate = "server.cert" +# tls_certificate = ".circleci/server.cert" # Path to TLS private key file to use for TLS connections -# tls_private_key = "server.key" +# tls_private_key = ".circleci/server.key" + +# Enable/disable server TLS +server_tls = false + +# Verify server certificate is completely authentic. +verify_server_certificate = false # User name to access the virtual administrative database (pgbouncer or pgcat) # Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.. diff --git a/src/client.rs b/src/client.rs index 5098ec6f..efde7554 100644 --- a/src/client.rs +++ b/src/client.rs @@ -539,6 +539,7 @@ where Some(md5_hash_password(username, password, &salt)) } else { if !get_config().is_auth_query_configured() { + wrong_password(&mut write, username).await?; return Err(Error::ClientAuthImpossible(username.into())); } @@ -565,6 +566,8 @@ where } Err(err) => { + wrong_password(&mut write, username).await?; + return Err(Error::ClientAuthPassthroughError( err.to_string(), client_identifier, @@ -587,7 +590,15 @@ where client_identifier ); - let fetched_hash = refetch_auth_hash(&pool).await?; + let fetched_hash = match refetch_auth_hash(&pool).await { + Ok(fetched_hash) => fetched_hash, + Err(err) => { + wrong_password(&mut write, username).await?; + + return Err(err); + } + }; + let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt); // Ok password changed in server an auth is possible. diff --git a/src/config.rs b/src/config.rs index d822486d..4af7beda 100644 --- a/src/config.rs +++ b/src/config.rs @@ -281,6 +281,13 @@ pub struct General { pub tls_certificate: Option, pub tls_private_key: Option, + + #[serde(default)] // false + pub server_tls: bool, + + #[serde(default)] // false + pub verify_server_certificate: bool, + pub admin_username: String, pub admin_password: String, @@ -373,6 +380,8 @@ impl Default for General { autoreload: None, tls_certificate: None, tls_private_key: None, + server_tls: false, + verify_server_certificate: false, admin_username: String::from("admin"), admin_password: String::from("admin"), auth_query: None, @@ -852,6 +861,11 @@ impl Config { info!("TLS support is disabled"); } }; + info!("Server TLS enabled: {}", self.general.server_tls); + info!( + "Server TLS certificate verification: {}", + self.general.verify_server_certificate + ); for (pool_name, pool_config) in &self.pools { // TODO: Make this output prettier (maybe a table?) diff --git a/src/messages.rs b/src/messages.rs index ba4818ce..0e980fe6 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -116,7 +116,10 @@ where /// Send the startup packet the server. We're pretending we're a Pg client. /// This tells the server which user we are and what database we want. -pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> { +pub async fn startup(stream: &mut S, user: &str, database: &str) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ let mut bytes = BytesMut::with_capacity(25); bytes.put_i32(196608); // Protocol number @@ -150,6 +153,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu } } +pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> { + let mut bytes = BytesMut::with_capacity(12); + + bytes.put_i32(8); + bytes.put_i32(80877103); + + match stream.write_all(&bytes).await { + Ok(_) => Ok(()), + Err(err) => Err(Error::SocketError(format!( + "Error writing SSLRequest to server socket - Error: {:?}", + err + ))), + } +} + /// Parse the params the server sends as a key/value format. pub fn parse_params(mut bytes: BytesMut) -> Result, Error> { let mut result = HashMap::new(); @@ -505,6 +523,29 @@ where } } +pub async fn write_all_flush(stream: &mut S, buf: &[u8]) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + match stream.write_all(buf).await { + Ok(_) => match stream.flush().await { + Ok(_) => Ok(()), + Err(err) => { + return Err(Error::SocketError(format!( + "Error flushing socket - Error: {:?}", + err + ))) + } + }, + Err(err) => { + return Err(Error::SocketError(format!( + "Error writing to socket - Error: {:?}", + err + ))) + } + } +} + /// Read a complete message from the socket. pub async fn read_message(stream: &mut S) -> Result where diff --git a/src/pool.rs b/src/pool.rs index 8ec88604..ee8de446 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -376,8 +376,7 @@ impl ConnectionPool { .max_lifetime(Some(std::time::Duration::from_millis(server_lifetime))) .test_on_check_out(false) .build(manager) - .await - .unwrap(); + .await?; pools.push(pool); servers.push(address); diff --git a/src/server.rs b/src/server.rs index 84bed6cc..5bcd5fb9 100644 --- a/src/server.rs +++ b/src/server.rs @@ -9,13 +9,12 @@ use std::collections::HashMap; use std::io::Read; use std::sync::Arc; use std::time::SystemTime; -use tokio::io::{AsyncReadExt, BufReader}; -use tokio::net::{ - tcp::{OwnedReadHalf, OwnedWriteHalf}, - TcpStream, -}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream}; +use tokio::net::TcpStream; +use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore}; +use tokio_rustls::{client::TlsStream, TlsConnector}; -use crate::config::{Address, User}; +use crate::config::{get_config, Address, User}; use crate::constants::*; use crate::errors::{Error, ServerIdentifier}; use crate::messages::*; @@ -23,6 +22,84 @@ use crate::mirrors::MirroringManager; use crate::pool::ClientServerMap; use crate::scram::ScramSha256; use crate::stats::ServerStats; +use std::io::Write; + +use pin_project::pin_project; + +#[pin_project(project = SteamInnerProj)] +pub enum StreamInner { + Plain { + #[pin] + stream: TcpStream, + }, + Tls { + #[pin] + stream: TlsStream, + }, +} + +impl AsyncWrite for StreamInner { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + let this = self.project(); + match this { + SteamInnerProj::Tls { stream } => stream.poll_write(cx, buf), + SteamInnerProj::Plain { stream } => stream.poll_write(cx, buf), + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.project(); + match this { + SteamInnerProj::Tls { stream } => stream.poll_flush(cx), + SteamInnerProj::Plain { stream } => stream.poll_flush(cx), + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.project(); + match this { + SteamInnerProj::Tls { stream } => stream.poll_shutdown(cx), + SteamInnerProj::Plain { stream } => stream.poll_shutdown(cx), + } + } +} + +impl AsyncRead for StreamInner { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + let this = self.project(); + match this { + SteamInnerProj::Tls { stream } => stream.poll_read(cx, buf), + SteamInnerProj::Plain { stream } => stream.poll_read(cx, buf), + } + } +} + +impl StreamInner { + pub fn try_write(&mut self, buf: &[u8]) -> std::io::Result { + match self { + StreamInner::Tls { stream } => { + let r = stream.get_mut(); + let mut w = r.1.writer(); + w.write(buf) + } + StreamInner::Plain { stream } => stream.try_write(buf), + } + } +} /// Server state. pub struct Server { @@ -30,11 +107,8 @@ pub struct Server { /// port, e.g. 5432, and role, e.g. primary or replica. address: Address, - /// Buffered read socket. - read: BufReader, - - /// Unbuffered write socket (our client code buffers). - write: OwnedWriteHalf, + /// Server TCP connection. + stream: BufStream, /// Our server response buffer. We buffer data before we give it to the client. buffer: BytesMut, @@ -98,8 +172,88 @@ impl Server { ))); } }; + + // TCP timeouts. configure_socket(&stream); + let config = get_config(); + + let mut stream = if config.general.server_tls { + // Request a TLS connection + ssl_request(&mut stream).await?; + + let response = match stream.read_u8().await { + Ok(response) => response as char, + Err(err) => { + return Err(Error::SocketError(format!( + "Server socket error: {:?}", + err + ))) + } + }; + + match response { + // Server supports TLS + 'S' => { + debug!("Connecting to server using TLS"); + + let mut root_store = RootCertStore::empty(); + root_store.add_server_trust_anchors( + webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }), + ); + + let mut tls_config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); + + // Equivalent to sslmode=prefer which is fine most places. + // If you want verify-full, change `verify_server_certificate` to true. + if !config.general.verify_server_certificate { + let mut dangerous = tls_config.dangerous(); + dangerous.set_certificate_verifier(Arc::new( + crate::tls::NoCertificateVerification {}, + )); + } + + let connector = TlsConnector::from(Arc::new(tls_config)); + let stream = match connector + .connect(address.host.as_str().try_into().unwrap(), stream) + .await + { + Ok(stream) => stream, + Err(err) => { + return Err(Error::SocketError(format!("Server TLS error: {:?}", err))) + } + }; + + StreamInner::Tls { stream } + } + + // Server does not support TLS + 'N' => StreamInner::Plain { stream }, + + // Something else? + m => { + return Err(Error::SocketError(format!( + "Unknown message: {}", + m as char + ))); + } + } + } else { + StreamInner::Plain { stream } + }; + + // let (read, write) = split(stream); + // let (mut read, mut write) = (ReadInner::Plain { stream: read }, WriteInner::Plain { stream: write }); + trace!("Sending StartupMessage"); // StartupMessage @@ -245,7 +399,7 @@ impl Server { let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]); - if sasl_type == SCRAM_SHA_256 { + if sasl_type.contains(SCRAM_SHA_256) { debug!("Using {}", SCRAM_SHA_256); // Generate client message. @@ -268,7 +422,7 @@ impl Server { res.put_i32(sasl_response.len() as i32); res.put(sasl_response); - write_all(&mut stream, res).await?; + write_all_flush(&mut stream, &res).await?; } else { error!("Unsupported SCRAM version: {}", sasl_type); return Err(Error::ServerError); @@ -299,7 +453,7 @@ impl Server { res.put_i32(4 + sasl_response.len() as i32); res.put(sasl_response); - write_all(&mut stream, res).await?; + write_all_flush(&mut stream, &res).await?; } SASL_FINAL => { @@ -443,12 +597,9 @@ impl Server { } }; - let (read, write) = stream.into_split(); - let mut server = Server { address: address.clone(), - read: BufReader::new(read), - write, + stream: BufStream::new(stream), buffer: BytesMut::with_capacity(8196), server_info, process_id, @@ -515,7 +666,7 @@ impl Server { bytes.put_i32(process_id); bytes.put_i32(secret_key); - write_all(&mut stream, bytes).await + write_all_flush(&mut stream, &bytes).await } /// Send messages to the server from the client. @@ -523,7 +674,7 @@ impl Server { self.mirror_send(messages); self.stats().data_sent(messages.len()); - match write_all_half(&mut self.write, messages).await { + match write_all_flush(&mut self.stream, &messages).await { Ok(_) => { // Successfully sent to server self.last_activity = SystemTime::now(); @@ -542,7 +693,7 @@ impl Server { /// in order to receive all data the server has to offer. pub async fn recv(&mut self) -> Result { loop { - let mut message = match read_message(&mut self.read).await { + let mut message = match read_message(&mut self.stream).await { Ok(message) => message, Err(err) => { error!("Terminating server because of: {:?}", err); @@ -935,13 +1086,13 @@ impl Drop for Server { // Update statistics self.stats.disconnect(); - let mut bytes = BytesMut::with_capacity(4); + let mut bytes = BytesMut::with_capacity(5); bytes.put_u8(b'X'); bytes.put_i32(4); - match self.write.try_write(&bytes) { - Ok(_) => (), - Err(_) => debug!("Dirty shutdown"), + match self.stream.get_mut().try_write(&bytes) { + Ok(5) => (), + _ => debug!("Dirty shutdown"), }; // Should not matter. diff --git a/src/tls.rs b/src/tls.rs index fbfbae75..6c4a7f5b 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -4,7 +4,12 @@ use rustls_pemfile::{certs, read_one, Item}; use std::iter; use std::path::Path; use std::sync::Arc; -use tokio_rustls::rustls::{self, Certificate, PrivateKey}; +use std::time::SystemTime; +use tokio_rustls::rustls::{ + self, + client::{ServerCertVerified, ServerCertVerifier}, + Certificate, PrivateKey, ServerName, +}; use tokio_rustls::TlsAcceptor; use crate::config::get_config; @@ -64,3 +69,19 @@ impl Tls { }) } } + +pub struct NoCertificateVerification; + +impl ServerCertVerifier for NoCertificateVerification { + fn verify_server_cert( + &self, + _end_entity: &Certificate, + _intermediates: &[Certificate], + _server_name: &ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: SystemTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } +} diff --git a/tests/ruby/mirrors_spec.rb b/tests/ruby/mirrors_spec.rb index 801df28c..898d0d71 100644 --- a/tests/ruby/mirrors_spec.rb +++ b/tests/ruby/mirrors_spec.rb @@ -25,7 +25,7 @@ processes.pgcat.shutdown end - it "can mirror a query" do + xit "can mirror a query" do conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) runs = 15 runs.times { conn.async_exec("SELECT 1 + 2") }