diff --git a/Cargo.lock b/Cargo.lock index 9166c1b9d5..d6fad3b452 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3449,6 +3449,7 @@ dependencies = [ "crc", "dotenvy", "etcetera", + "flume", "futures-channel", "futures-core", "futures-io", @@ -3471,6 +3472,7 @@ dependencies = [ "serde_json", "sha2", "smallvec", + "sqlx", "sqlx-core", "stringprep", "thiserror", diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index 77cc0e2403..a6e9339a6a 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -50,6 +50,7 @@ base64 = { version = "0.21.0", default-features = false, features = ["std"] } bitflags = { version = "2", default-features = false } byteorder = { version = "1.4.3", default-features = false, features = ["std"] } dotenvy = { workspace = true } +flume = { version = "0.11.0", default-features = false, features = ["async"] } hex = "0.4.3" home = "0.5.5" itoa = "1.0.1" @@ -71,5 +72,9 @@ workspace = true # We use JSON in the driver implementation itself so there's no reason not to enable it here. features = ["json"] + +[dev-dependencies] +sqlx = { workspace = true, features = ["postgres"] } + [target.'cfg(target_os = "windows")'.dependencies] etcetera = "0.8.0" diff --git a/sqlx-postgres/src/advisory_lock.rs b/sqlx-postgres/src/advisory_lock.rs index 928524edea..fe6fb0180a 100644 --- a/sqlx-postgres/src/advisory_lock.rs +++ b/sqlx-postgres/src/advisory_lock.rs @@ -98,7 +98,6 @@ impl PgAdvisoryLock { /// [hkdf]: https://datatracker.ietf.org/doc/html/rfc5869 /// ### Example /// ```rust - /// # extern crate sqlx_core as sqlx; /// use sqlx::postgres::{PgAdvisoryLock, PgAdvisoryLockKey}; /// /// let lock = PgAdvisoryLock::new("my first Postgres advisory lock!"); diff --git a/sqlx-postgres/src/connection/establish.rs b/sqlx-postgres/src/connection/establish.rs index 9f5008f9ed..763d76d2a7 100644 --- a/sqlx-postgres/src/connection/establish.rs +++ b/sqlx-postgres/src/connection/establish.rs @@ -8,7 +8,7 @@ use crate::message::{ Authentication, BackendKeyData, MessageFormat, Password, ReadyForQuery, Startup, }; use crate::types::Oid; -use crate::{PgConnectOptions, PgConnection}; +use crate::{PgConnectOptions, PgConnection, PgReplicationMode}; // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3 // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.11 @@ -44,6 +44,13 @@ impl PgConnection { params.push(("options", options)); } + if let Some(replication_mode) = options.replication_mode { + match replication_mode { + PgReplicationMode::Physical => params.push(("replication", "true")), + PgReplicationMode::Logical => params.push(("replication", "database")), + } + } + stream .send(Startup { username: Some(&options.username), diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs index 5b9d5804b2..5a149b1f3b 100644 --- a/sqlx-postgres/src/lib.rs +++ b/sqlx-postgres/src/lib.rs @@ -17,6 +17,7 @@ mod listener; mod message; mod options; mod query_result; +mod replication; mod row; mod statement; mod transaction; @@ -44,8 +45,9 @@ pub use database::Postgres; pub use error::{PgDatabaseError, PgErrorPosition}; pub use listener::{PgListener, PgNotification}; pub use message::PgSeverity; -pub use options::{PgConnectOptions, PgSslMode}; +pub use options::{PgConnectOptions, PgReplicationMode, PgSslMode}; pub use query_result::PgQueryResult; +pub use replication::{PgCopyBothReceiver, PgCopyBothSender, PgReplication, PgReplicationPool}; pub use row::PgRow; pub use statement::PgStatement; pub use transaction::PgTransactionManager; diff --git a/sqlx-postgres/src/listener.rs b/sqlx-postgres/src/listener.rs index 242823688f..e103d460ec 100644 --- a/sqlx-postgres/src/listener.rs +++ b/sqlx-postgres/src/listener.rs @@ -188,8 +188,8 @@ impl PgListener { /// # Example /// /// ```rust,no_run - /// # use sqlx_core::postgres::PgListener; - /// # use sqlx_core::error::Error; + /// # use sqlx::postgres::PgListener; + /// # use sqlx::error::Error; /// # /// # #[cfg(feature = "_rt")] /// # sqlx::__rt::test_block_on(async move { @@ -219,8 +219,8 @@ impl PgListener { /// # Example /// /// ```rust,no_run - /// # use sqlx_core::postgres::PgListener; - /// # use sqlx_core::error::Error; + /// # use sqlx::postgres::PgListener; + /// # use sqlx::error::Error; /// # /// # #[cfg(feature = "_rt")] /// # sqlx::__rt::test_block_on(async move { diff --git a/sqlx-postgres/src/message/copy.rs b/sqlx-postgres/src/message/copy.rs index db0e7398cf..9b308f22cc 100644 --- a/sqlx-postgres/src/message/copy.rs +++ b/sqlx-postgres/src/message/copy.rs @@ -3,7 +3,7 @@ use crate::io::{BufExt, BufMutExt, Decode, Encode}; use sqlx_core::bytes::{Buf, BufMut, Bytes}; use std::ops::Deref; -/// The same structure is sent for both `CopyInResponse` and `CopyOutResponse` +/// The same structure is sent for `CopyInResponse`, `CopyOutResponse` and `CopyBothResponse`. pub struct CopyResponse { pub format: i8, pub num_columns: i16, diff --git a/sqlx-postgres/src/message/mod.rs b/sqlx-postgres/src/message/mod.rs index ef1dbfabf0..49a276c2c6 100644 --- a/sqlx-postgres/src/message/mod.rs +++ b/sqlx-postgres/src/message/mod.rs @@ -64,6 +64,7 @@ pub enum MessageFormat { CommandComplete, CopyData, CopyDone, + CopyBothResponse, CopyInResponse, CopyOutResponse, DataRow, @@ -118,6 +119,7 @@ impl MessageFormat { b'R' => MessageFormat::Authentication, b'S' => MessageFormat::ParameterStatus, b'T' => MessageFormat::RowDescription, + b'W' => MessageFormat::CopyBothResponse, b'Z' => MessageFormat::ReadyForQuery, b'n' => MessageFormat::NoData, b's' => MessageFormat::PortalSuspended, diff --git a/sqlx-postgres/src/options/mod.rs b/sqlx-postgres/src/options/mod.rs index 73e08bc7c2..e4084820b5 100644 --- a/sqlx-postgres/src/options/mod.rs +++ b/sqlx-postgres/src/options/mod.rs @@ -82,7 +82,8 @@ mod ssl_mode; /// // Information about SQL queries is logged at `DEBUG` level by default. /// opts = opts.log_statements(log::LevelFilter::Trace); /// -/// let pool = PgPool::connect_with(&opts).await?; +/// let pool = PgPool::connect_with(opts).await?; +/// # Ok(()) /// # } /// ``` #[derive(Debug, Clone)] @@ -101,9 +102,20 @@ pub struct PgConnectOptions { pub(crate) application_name: Option, pub(crate) log_settings: LogSettings, pub(crate) extra_float_digits: Option>, + pub(crate) replication_mode: Option, pub(crate) options: Option, } +/// Replication mode configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum PgReplicationMode { + /// Physical replication. + Physical, + /// Logical replication. + Logical, +} + impl Default for PgConnectOptions { fn default() -> Self { Self::new_without_pgpass().apply_pgpass() @@ -130,7 +142,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new(); /// ``` pub fn new() -> Self { @@ -167,6 +179,7 @@ impl PgConnectOptions { application_name: var("PGAPPNAME").ok(), extra_float_digits: Some("2".into()), log_settings: Default::default(), + replication_mode: None, options: var("PGOPTIONS").ok(), } } @@ -196,7 +209,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .host("localhost"); /// ``` @@ -212,7 +225,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .port(5432); /// ``` @@ -238,7 +251,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .username("postgres"); /// ``` @@ -252,7 +265,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .username("root") /// .password("safe-and-secure"); @@ -267,7 +280,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .database("postgres"); /// ``` @@ -287,7 +300,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::{PgSslMode, PgConnectOptions}; + /// # use sqlx::postgres::{PgSslMode, PgConnectOptions}; /// let options = PgConnectOptions::new() /// .ssl_mode(PgSslMode::Require); /// ``` @@ -303,7 +316,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::{PgSslMode, PgConnectOptions}; + /// # use sqlx::postgres::{PgSslMode, PgConnectOptions}; /// let options = PgConnectOptions::new() /// // Providing a CA certificate with less than VerifyCa is pointless /// .ssl_mode(PgSslMode::VerifyCa) @@ -319,7 +332,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::{PgSslMode, PgConnectOptions}; + /// # use sqlx::postgres::{PgSslMode, PgConnectOptions}; /// let options = PgConnectOptions::new() /// // Providing a CA certificate with less than VerifyCa is pointless /// .ssl_mode(PgSslMode::VerifyCa) @@ -339,7 +352,7 @@ impl PgConnectOptions { /// This is for illustration purposes only. /// /// ```rust - /// # use sqlx_postgres::{PgSslMode, PgConnectOptions}; + /// # use sqlx::postgres::{PgSslMode, PgConnectOptions}; /// /// const CERT: &[u8] = b"\ /// -----BEGIN CERTIFICATE----- @@ -361,7 +374,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::{PgSslMode, PgConnectOptions}; + /// # use sqlx::postgres::{PgSslMode, PgConnectOptions}; /// let options = PgConnectOptions::new() /// // Providing a CA certificate with less than VerifyCa is pointless /// .ssl_mode(PgSslMode::VerifyCa) @@ -381,7 +394,7 @@ impl PgConnectOptions { /// This is for illustration purposes only. /// /// ```rust - /// # use sqlx_postgres::{PgSslMode, PgConnectOptions}; + /// # use sqlx::postgres::{PgSslMode, PgConnectOptions}; /// /// const KEY: &[u8] = b"\ /// -----BEGIN PRIVATE KEY----- @@ -403,7 +416,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::{PgSslMode, PgConnectOptions}; + /// # use sqlx::postgres::{PgSslMode, PgConnectOptions}; /// let options = PgConnectOptions::new() /// // Providing a CA certificate with less than VerifyCa is pointless /// .ssl_mode(PgSslMode::VerifyCa) @@ -430,7 +443,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .application_name("my-app"); /// ``` @@ -477,7 +490,7 @@ impl PgConnectOptions { /// /// ### Examples /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// /// let mut options = PgConnectOptions::new() /// // for Redshift and Postgres 10 @@ -492,12 +505,26 @@ impl PgConnectOptions { self } + /// Sets the replication mode. + /// + /// This option determines whether the connection should use the replication + /// protocol instead of the normal protocol. + /// + /// In physical or logical replication mode, only the simple query protocol + /// can be used. + /// + /// The default behavior is to disable the replication mode. + pub(crate) fn replication_mode(mut self, replication_mode: PgReplicationMode) -> Self { + self.replication_mode = Some(replication_mode); + self + } + /// Set additional startup options for the connection as a list of key-value pairs. /// /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .options([("geqo", "off"), ("statement_timeout", "5min")]); /// ``` @@ -542,7 +569,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .host("127.0.0.1"); /// assert_eq!(options.get_host(), "127.0.0.1"); @@ -556,7 +583,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .port(6543); /// assert_eq!(options.get_port(), 6543); @@ -570,7 +597,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .socket("/tmp"); /// assert!(options.get_socket().is_some()); @@ -584,7 +611,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .username("foo"); /// assert_eq!(options.get_username(), "foo"); @@ -598,7 +625,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .database("postgres"); /// assert!(options.get_database().is_some()); @@ -612,7 +639,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::{PgConnectOptions, PgSslMode}; + /// # use sqlx::postgres::{PgConnectOptions, PgSslMode}; /// let options = PgConnectOptions::new(); /// assert!(matches!(options.get_ssl_mode(), PgSslMode::Prefer)); /// ``` @@ -625,7 +652,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .application_name("service"); /// assert!(options.get_application_name().is_some()); @@ -639,7 +666,7 @@ impl PgConnectOptions { /// # Example /// /// ```rust - /// # use sqlx_postgres::PgConnectOptions; + /// # use sqlx::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .options([("foo", "bar")]); /// assert!(options.get_options().is_some()); diff --git a/sqlx-postgres/src/replication.rs b/sqlx-postgres/src/replication.rs new file mode 100644 index 0000000000..2cdffac1cb --- /dev/null +++ b/sqlx-postgres/src/replication.rs @@ -0,0 +1,247 @@ +use crate::{ + error::Error, + io::Encode, + message::{CopyData, CopyDone, CopyResponse, MessageFormat, Query}, + PgConnectOptions, PgPool, PgPoolOptions, PgReplicationMode, Postgres, Result, +}; +use futures_util::stream::Stream; +use futures_util::{future::Either, FutureExt as _}; +use sqlx_core::{bytes::Bytes, pool::PoolConnection}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[derive(Debug, Clone)] +pub struct PgReplicationPool(PgPool); + +impl PgReplicationPool { + pub async fn connect(url: &str, mode: PgReplicationMode) -> Result { + let pool = PgPoolOptions::new() + .max_connections(1) + .max_lifetime(None) + .idle_timeout(None) + .connect(url) + .await?; + Ok(Self::from_pool(pool, mode)) + } + + pub fn from_pool(pool: PgPool, mode: PgReplicationMode) -> Self { + let pool_options = pool.options().clone(); + let connect_options = + ::clone(&pool.connect_options()).replication_mode(mode); + + Self(pool_options.parent(pool).connect_lazy_with(connect_options)) + } + + pub async fn acquire(&self) -> Result, Error> { + self.0.acquire().await + } + + /// Open a duplex connection allowing high-speed bulk data transfer to and from the server. + /// + /// # Example + /// + /// ```rust,no_run + /// use sqlx::postgres::{ + /// PgReplicationPool, PgReplicationMode, PgReplication + /// }; + /// use futures_util::stream::StreamExt; + /// # #[cfg(feature = "_rt")] + /// # sqlx::__rt::test_block_on(async move { + /// let pool = PgReplicationPool::connect("0.0.0.0", PgReplicationMode::Logical) + /// .await + /// .expect("failed to connect to postgres"); + /// + /// let query = format!( + /// r#"START_REPLICATION SLOT "{}" LOGICAL {} ("proto_version" '1', "publication_names" '{}')"#, + /// "test_slot", "0/1573178", "test_publication", + /// ); + /// let PgReplication {sender, receiver} = pool.start_replication(query.as_str()) + /// .await + /// .expect("start replication"); + /// // Read data from the server. + /// while let Some(data) = receiver.next().await { + /// println!("data: {:?}", data); + /// // And send some back (e.g. keepalive). + /// sender.send(Vec::new()).await?; + /// } + /// // Connection closed. + /// # Result::<(), Error>::Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn start_replication(&self, statement: &str) -> Result { + // Setup upstream/downstream channels. + let (recv_tx, recv_rx) = flume::bounded(1); + let (send_tx, send_rx) = flume::bounded(1); + + crate::rt::spawn({ + let pool = self.clone(); + async move { + if let Err(err) = copy_both_handler(pool, recv_tx.clone(), send_rx).await { + let _ignored = recv_tx.send_async(Err(err)).await; + } + } + }); + + // Execute the given statement to switch into CopyBoth mode. + let mut buf = Vec::new(); + Query(statement).encode(&mut buf); + send_tx + .send_async(PgCopyBothCommand::Begin(buf)) + .await + .map_err(|_err| Error::WorkerCrashed)?; + + Ok(PgReplication { + sender: PgCopyBothSender(send_tx), + receiver: PgCopyBothReceiver(recv_rx.into_stream()), + }) + } +} + +enum PgCopyBothCommand { + Begin(Vec), + CopyData(Vec), + CopyDone { from_client: bool }, +} + +pub struct PgCopyBothSender(flume::Sender); +pub struct PgCopyBothReceiver(flume::r#async::RecvStream<'static, Result>); + +pub struct PgReplication { + pub receiver: PgCopyBothReceiver, + pub sender: PgCopyBothSender, +} + +impl PgCopyBothSender { + /// Send a chunk of `COPY` data. + pub async fn send(&self, data: impl Into>) -> Result<()> { + self.0 + .send_async(PgCopyBothCommand::CopyData(data.into())) + .await + .map_err(|_err| Error::WorkerCrashed)?; + + Ok(()) + } + + /// Signal that the CopyBoth mode is complete. + pub async fn finish(self) -> Result<()> { + self.0 + .send_async(PgCopyBothCommand::CopyDone { from_client: true }) + .await + .map_err(|_err| Error::WorkerCrashed)?; + + Ok(()) + } +} + +impl Stream for PgCopyBothReceiver { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_next(cx) + } +} + +async fn copy_both_handler( + pool: PgReplicationPool, + recv_tx: flume::Sender>, + send_rx: flume::Receiver, +) -> Result<()> { + let mut has_started = false; + let mut conn = pool.0.acquire().await?; + conn.wait_until_ready().await?; + + loop { + // Wait for either incoming data or a message to send. + let command = match futures_util::future::select( + // send_rx.recv_async() should be the first parameter because select is biased, + // it only selects the first argument if it is always ready. If conn.stream.recv() + // was first we could fail to respond to keep alives. + std::pin::pin!(send_rx.recv_async()), + std::pin::pin!(conn.stream.recv()), + ) + .await + { + Either::Left((command, _)) => { + Some(command.or_else(|err| match err { + flume::RecvError::Disconnected => { + // This only errors if the consumer has been dropped. + // There is no reason to continue. + Ok(PgCopyBothCommand::CopyDone { from_client: true }) + } + _ => Err(Error::WorkerCrashed), + })?) + } + Either::Right((data, _)) => { + let message = data?; + match message.format { + MessageFormat::CopyData => { + recv_tx + .send_async(message.decode::>().map(|x| x.0)) + .await + .map_err(|_err| Error::WorkerCrashed)?; + None + } + // Server is done sending data, close our side. + MessageFormat::CopyDone => { + let _ = message.decode::()?; + Some(PgCopyBothCommand::CopyDone { from_client: false }) + } + _ => { + return Err(err_protocol!( + "unexpected message format during copy out: {:?}", + message.format + )) + } + } + } + }; + + if let Some(command) = command { + match command { + // Start the stream. + PgCopyBothCommand::Begin(buf) => { + if has_started { + return Err(err_protocol!("Copy-Both mode already initiated")); + } + conn.stream.send(buf.as_slice()).await?; + // Consume the server response. + conn.stream + .recv_expect::(MessageFormat::CopyBothResponse) + .await?; + has_started = true; + } + // Send data to the server. + PgCopyBothCommand::CopyData(data) => { + if !has_started { + return Err(err_protocol!("connection hasn't been started")); + } + conn.stream.send(CopyData(data)).await?; + } + + // Grafeceful shutdown of the stream. + PgCopyBothCommand::CopyDone { from_client } => { + if !has_started { + return Err(err_protocol!("connection hasn't been started")); + } + conn.stream.send(CopyDone).await?; + // If we are the first to send CopyDone, wait for the server to send his own. + if from_client { + conn.stream.recv_expect(MessageFormat::CopyDone).await?; + } + conn.stream + .recv_expect(MessageFormat::CommandComplete) + .await?; // Content: "START_REPLICATION" + conn.stream + .recv_expect(MessageFormat::CommandComplete) + .await?; // Content: "COPY0/0" + conn.stream + .recv_expect(MessageFormat::ReadyForQuery) + .await?; + break; + } + } + } + } + + Ok(()) +}