From 95a3c98106d3378b3994c858b3159c8fe6f9194b Mon Sep 17 00:00:00 2001 From: Jeff Davis Date: Mon, 14 Dec 2020 11:54:01 -0800 Subject: [PATCH 1/5] Make simple_query::encode() pub(crate). --- tokio-postgres/src/simple_query.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index 24473b896..a26e43e6e 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -63,7 +63,7 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro } } -fn encode(client: &InnerClient, query: &str) -> Result { +pub(crate) fn encode(client: &InnerClient, query: &str) -> Result { client.with_buf(|buf| { frontend::query(query, buf).map_err(Error::encode)?; Ok(buf.split().freeze()) From bd964377b9ee975e33f9a967aafd053fe893c0fe Mon Sep 17 00:00:00 2001 From: Jeff Davis Date: Mon, 14 Dec 2020 11:58:59 -0800 Subject: [PATCH 2/5] Connection string config for replication. Co-authored-by: Petros Angelatos --- tokio-postgres/src/config.rs | 45 +++++++++++++++++++++++++++++++ tokio-postgres/src/connect_raw.rs | 8 +++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 62b45f793..8bc2a42df 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -72,6 +72,21 @@ pub enum LoadBalanceHosts { Random, } +/// Replication mode configuration. +/// +/// It is recommended that you use a PostgreSQL server patch version +/// of at least: 14.0, 13.2, 12.6, 11.11, 10.16, 9.6.21, or +/// 9.5.25. Earlier patch levels have a bug that doesn't properly +/// handle pipelined requests after streaming has stopped. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum ReplicationMode { + /// Physical replication. + Physical, + /// Logical replication. + Logical, +} + /// A host specification. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Host { @@ -209,6 +224,7 @@ pub struct Config { pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, pub(crate) load_balance_hosts: LoadBalanceHosts, + pub(crate) replication_mode: Option, } impl Default for Config { @@ -242,6 +258,7 @@ impl Config { target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, load_balance_hosts: LoadBalanceHosts::Disable, + replication_mode: None, } } @@ -524,6 +541,22 @@ impl Config { self.load_balance_hosts } + /// Set replication mode. + /// + /// It is recommended that you use a PostgreSQL server patch version + /// of at least: 14.0, 13.2, 12.6, 11.11, 10.16, 9.6.21, or + /// 9.5.25. Earlier patch levels have a bug that doesn't properly + /// handle pipelined requests after streaming has stopped. + pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { + self.replication_mode = Some(replication_mode); + self + } + + /// Get replication mode. + pub fn get_replication_mode(&self) -> Option { + self.replication_mode + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -660,6 +693,17 @@ impl Config { }; self.load_balance_hosts(load_balance_hosts); } + "replication" => { + let mode = match value { + "off" => None, + "true" => Some(ReplicationMode::Physical), + "database" => Some(ReplicationMode::Logical), + _ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))), + }; + if let Some(mode) = mode { + self.replication_mode(mode); + } + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), @@ -744,6 +788,7 @@ impl fmt::Debug for Config { config_dbg .field("target_session_attrs", &self.target_session_attrs) .field("channel_binding", &self.channel_binding) + .field("replication", &self.replication_mode) .finish() } } diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 19be9eb01..8edf45937 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -1,5 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use crate::config::{self, Config}; +use crate::config::{self, Config, ReplicationMode}; use crate::connect_tls::connect_tls; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{TlsConnect, TlsStream}; @@ -133,6 +133,12 @@ where if let Some(application_name) = &config.application_name { params.push(("application_name", &**application_name)); } + if let Some(replication_mode) = &config.replication_mode { + match replication_mode { + ReplicationMode::Physical => params.push(("replication", "true")), + ReplicationMode::Logical => params.push(("replication", "database")), + } + } let mut buf = BytesMut::new(); frontend::startup_message(params, &mut buf).map_err(Error::encode)?; From 92899e83af0791a2374b25f66897840113c68103 Mon Sep 17 00:00:00 2001 From: Petros Angelatos Date: Fri, 28 May 2021 00:18:23 +0200 Subject: [PATCH 3/5] implement Stream for Responses Signed-off-by: Petros Angelatos --- tokio-postgres/src/client.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 92eabde36..93f7c2f7b 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -19,7 +19,7 @@ use crate::{ use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; use futures_channel::mpsc; -use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt}; +use futures_util::{future, pin_mut, ready, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; use postgres_protocol::message::backend::Message; use postgres_types::BorrowToSql; @@ -29,6 +29,7 @@ use std::fmt; use std::net::IpAddr; #[cfg(feature = "runtime")] use std::path::PathBuf; +use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; #[cfg(feature = "runtime")] @@ -61,6 +62,17 @@ impl Responses { } } +impl Stream for Responses { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match ready!((*self).poll_next(cx)) { + Err(err) if err.is_closed() => Poll::Ready(None), + msg => Poll::Ready(Some(msg)), + } + } +} + /// A cache of type info and prepared statements for fetching type info /// (corresponding to the queries in the [prepare](prepare) module). #[derive(Default)] From bed87a94dcf823cbcada2752c331c7cafdbbc26e Mon Sep 17 00:00:00 2001 From: Petros Angelatos Date: Thu, 1 Apr 2021 15:13:06 +0200 Subject: [PATCH 4/5] add copy_both_simple method Signed-off-by: Petros Angelatos --- postgres-protocol/src/message/backend.rs | 34 +++ tokio-postgres/src/client.rs | 48 ++- tokio-postgres/src/connection.rs | 20 ++ tokio-postgres/src/copy_both.rs | 358 +++++++++++++++++++++++ tokio-postgres/src/lib.rs | 2 + tokio-postgres/tests/test/copy_both.rs | 125 ++++++++ tokio-postgres/tests/test/main.rs | 1 + 7 files changed, 585 insertions(+), 3 deletions(-) create mode 100644 tokio-postgres/src/copy_both.rs create mode 100644 tokio-postgres/tests/test/copy_both.rs diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index 73b169288..fdc83fedb 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -22,6 +22,7 @@ pub const DATA_ROW_TAG: u8 = b'D'; pub const ERROR_RESPONSE_TAG: u8 = b'E'; pub const COPY_IN_RESPONSE_TAG: u8 = b'G'; pub const COPY_OUT_RESPONSE_TAG: u8 = b'H'; +pub const COPY_BOTH_RESPONSE_TAG: u8 = b'W'; pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I'; pub const BACKEND_KEY_DATA_TAG: u8 = b'K'; pub const NO_DATA_TAG: u8 = b'n'; @@ -93,6 +94,7 @@ pub enum Message { CopyDone, CopyInResponse(CopyInResponseBody), CopyOutResponse(CopyOutResponseBody), + CopyBothResponse(CopyBothResponseBody), DataRow(DataRowBody), EmptyQueryResponse, ErrorResponse(ErrorResponseBody), @@ -190,6 +192,16 @@ impl Message { storage, }) } + COPY_BOTH_RESPONSE_TAG => { + let format = buf.read_u8()?; + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::CopyBothResponse(CopyBothResponseBody { + format, + len, + storage, + }) + } EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, BACKEND_KEY_DATA_TAG => { let process_id = buf.read_i32::()?; @@ -524,6 +536,28 @@ impl CopyOutResponseBody { } } +#[derive(Debug, Clone)] +pub struct CopyBothResponseBody { + format: u8, + len: u16, + storage: Bytes, +} + +impl CopyBothResponseBody { + #[inline] + pub fn format(&self) -> u8 { + self.format + } + + #[inline] + pub fn column_formats(&self) -> ColumnFormats<'_> { + ColumnFormats { + remaining: self.len, + buf: &self.storage, + } + } +} + #[derive(Debug, Clone)] pub struct DataRowBody { storage: Bytes, diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 93f7c2f7b..4469806f9 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,6 +1,7 @@ -use crate::codec::BackendMessages; +use crate::codec::{BackendMessages, FrontendMessage}; use crate::config::SslMode; use crate::connection::{Request, RequestMessages}; +use crate::copy_both::{CopyBothDuplex, CopyBothReceiver}; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; @@ -13,8 +14,9 @@ use crate::types::{Oid, ToSql, Type}; #[cfg(feature = "runtime")] use crate::Socket; use crate::{ - copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error, - Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, + copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, + CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, + TransactionBuilder, }; use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; @@ -41,6 +43,11 @@ pub struct Responses { cur: BackendMessages, } +pub struct CopyBothHandles { + pub(crate) stream_receiver: mpsc::Receiver>, + pub(crate) sink_sender: mpsc::Sender, +} + impl Responses { pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { loop { @@ -115,6 +122,32 @@ impl InnerClient { }) } + pub fn start_copy_both(&self) -> Result { + let (sender, receiver) = mpsc::channel(16); + let (stream_sender, stream_receiver) = mpsc::channel(16); + let (sink_sender, sink_receiver) = mpsc::channel(16); + + let responses = Responses { + receiver, + cur: BackendMessages::empty(), + }; + let messages = RequestMessages::CopyBoth(CopyBothReceiver::new( + responses, + sink_receiver, + stream_sender, + )); + + let request = Request { messages, sender }; + self.sender + .unbounded_send(request) + .map_err(|_| Error::closed())?; + + Ok(CopyBothHandles { + stream_receiver, + sink_sender, + }) + } + pub fn typeinfo(&self) -> Option { self.cached_typeinfo.lock().typeinfo.clone() } @@ -505,6 +538,15 @@ impl Client { copy_out::copy_out(self.inner(), statement).await } + /// Executes a CopyBoth query, returning a combined Stream+Sink type to read and write copy + /// data. + pub async fn copy_both_simple(&self, query: &str) -> Result, Error> + where + T: Buf + 'static + Send, + { + copy_both::copy_both_simple(self.inner(), query).await + } + /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. /// /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index 414335955..a3449f88b 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -1,4 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; +use crate::copy_both::CopyBothReceiver; use crate::copy_in::CopyInReceiver; use crate::error::DbError; use crate::maybe_tls_stream::MaybeTlsStream; @@ -20,6 +21,7 @@ use tokio_util::codec::Framed; pub enum RequestMessages { Single(FrontendMessage), CopyIn(CopyInReceiver), + CopyBoth(CopyBothReceiver), } pub struct Request { @@ -258,6 +260,24 @@ where .map_err(Error::io)?; self.pending_request = Some(RequestMessages::CopyIn(receiver)); } + RequestMessages::CopyBoth(mut receiver) => { + let message = match receiver.poll_next_unpin(cx) { + Poll::Ready(Some(message)) => message, + Poll::Ready(None) => { + trace!("poll_write: finished copy_both request"); + continue; + } + Poll::Pending => { + trace!("poll_write: waiting on copy_both stream"); + self.pending_request = Some(RequestMessages::CopyBoth(receiver)); + return Ok(true); + } + }; + Pin::new(&mut self.stream) + .start_send(message) + .map_err(Error::io)?; + self.pending_request = Some(RequestMessages::CopyBoth(receiver)); + } } } } diff --git a/tokio-postgres/src/copy_both.rs b/tokio-postgres/src/copy_both.rs new file mode 100644 index 000000000..d3b46eab7 --- /dev/null +++ b/tokio-postgres/src/copy_both.rs @@ -0,0 +1,358 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::{simple_query, Error}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures_channel::mpsc; +use futures_util::{ready, Sink, SinkExt, Stream, StreamExt}; +use log::debug; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; +use postgres_protocol::message::frontend::CopyData; +use std::marker::{PhantomData, PhantomPinned}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// The state machine of CopyBothReceiver +/// +/// ```ignore +/// CopyBoth +/// / \ +/// v v +/// CopyOut CopyIn +/// \ / +/// v v +/// CopyNone +/// | +/// v +/// CopyComplete +/// | +/// v +/// CommandComplete +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CopyBothState { + /// The state before having entered the CopyBoth mode. + Setup, + /// Initial state where CopyData messages can go in both directions + CopyBoth, + /// The server->client stream is closed and we're in CopyIn mode + CopyIn, + /// The client->server stream is closed and we're in CopyOut mode + CopyOut, + /// Both directions are closed, we waiting for CommandComplete messages + CopyNone, + /// We have received the first CommandComplete message for the copy + CopyComplete, + /// We have received the final CommandComplete message for the statement + CommandComplete, +} + +/// A CopyBothReceiver is responsible for handling the CopyBoth subprotocol. It ensures that no +/// matter what the users do with their CopyBothDuplex handle we're always going to send the +/// correct messages to the backend in order to restore the connection into a usable state. +/// +/// ```ignore +/// | +/// | +/// | +/// pg -> Connection -> CopyBothReceiver ---+---> CopyBothDuplex +/// | ^ \ +/// | / v +/// | Sink Stream +/// ``` +pub struct CopyBothReceiver { + /// Receiver of backend messages from the underlying [Connection](crate::Connection) + responses: Responses, + /// Receiver of frontend messages sent by the user using + sink_receiver: mpsc::Receiver, + /// Sender of CopyData contents to be consumed by the user using + stream_sender: mpsc::Sender>, + /// The current state of the subprotocol + state: CopyBothState, + /// Holds a buffered message until we are ready to send it to the user's stream + buffered_message: Option>, +} + +impl CopyBothReceiver { + pub(crate) fn new( + responses: Responses, + sink_receiver: mpsc::Receiver, + stream_sender: mpsc::Sender>, + ) -> CopyBothReceiver { + CopyBothReceiver { + responses, + sink_receiver, + stream_sender, + state: CopyBothState::Setup, + buffered_message: None, + } + } + + /// Convenience method to set the subprotocol into an unexpected message state + fn unexpected_message(&mut self) { + self.sink_receiver.close(); + self.buffered_message = Some(Err(Error::unexpected_message())); + self.state = CopyBothState::CommandComplete; + } + + /// Processes messages from the backend, it will resolve once all backend messages have been + /// processed + fn poll_backend(&mut self, cx: &mut Context<'_>) -> Poll<()> { + use CopyBothState::*; + + loop { + // Deliver the buffered message (if any) to the user to ensure we can potentially + // buffer a new one in response to a server message + if let Some(message) = self.buffered_message.take() { + match self.stream_sender.poll_ready(cx) { + Poll::Ready(_) => { + // If the receiver has hung up we'll just drop the message + let _ = self.stream_sender.start_send(message); + } + Poll::Pending => { + // Stash the message and try again later + self.buffered_message = Some(message); + return Poll::Pending; + } + } + } + + match ready!(self.responses.poll_next_unpin(cx)) { + Some(Ok(Message::CopyBothResponse(body))) => match self.state { + Setup => { + self.buffered_message = Some(Ok(Message::CopyBothResponse(body))); + self.state = CopyBoth; + } + _ => self.unexpected_message(), + }, + Some(Ok(Message::CopyData(body))) => match self.state { + CopyBoth | CopyOut => { + self.buffered_message = Some(Ok(Message::CopyData(body))); + } + _ => self.unexpected_message(), + }, + // The server->client stream is done + Some(Ok(Message::CopyDone)) => { + match self.state { + CopyBoth => self.state = CopyIn, + CopyOut => self.state = CopyNone, + _ => self.unexpected_message(), + }; + } + Some(Ok(Message::CommandComplete(_))) => { + match self.state { + CopyNone => self.state = CopyComplete, + CopyComplete => { + self.stream_sender.close_channel(); + self.sink_receiver.close(); + self.state = CommandComplete; + } + _ => self.unexpected_message(), + }; + } + // The server indicated an error, terminate our side if we haven't already + Some(Err(err)) => { + match self.state { + Setup | CopyBoth | CopyOut | CopyIn => { + self.sink_receiver.close(); + self.buffered_message = Some(Err(err)); + self.state = CommandComplete; + } + _ => self.unexpected_message(), + }; + } + Some(Ok(Message::ReadyForQuery(_))) => match self.state { + CommandComplete => { + self.sink_receiver.close(); + self.stream_sender.close_channel(); + } + _ => self.unexpected_message(), + }, + Some(Ok(_)) => self.unexpected_message(), + None => return Poll::Ready(()), + } + } + } +} + +/// The [Connection](crate::Connection) will keep polling this stream until it is exhausted. This +/// is the mechanism that drives the CopyBoth subprotocol forward +impl Stream for CopyBothReceiver { + type Item = FrontendMessage; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use CopyBothState::*; + + match self.poll_backend(cx) { + Poll::Ready(()) => Poll::Ready(None), + Poll::Pending => match self.state { + Setup | CopyBoth | CopyIn => match ready!(self.sink_receiver.poll_next_unpin(cx)) { + Some(msg) => Poll::Ready(Some(msg)), + None => { + self.state = match self.state { + CopyBoth => CopyOut, + CopyIn => CopyNone, + _ => unreachable!(), + }; + + let mut buf = BytesMut::new(); + frontend::copy_done(&mut buf); + Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) + } + }, + _ => Poll::Pending, + }, + } + } +} + +pin_project! { + /// A duplex stream for consuming streaming replication data. + /// + /// Users should ensure that CopyBothDuplex is dropped before attempting to await on a new + /// query. This will ensure that the connection returns into normal processing mode. + /// + /// ```no_run + /// use tokio_postgres::Client; + /// + /// async fn foo(client: &Client) { + /// let duplex_stream = client.copy_both_simple::<&[u8]>("..").await; + /// + /// // ⚠️ INCORRECT ⚠️ + /// client.query("SELECT 1", &[]).await; // hangs forever + /// + /// // duplex_stream drop-ed here + /// } + /// ``` + /// + /// ```no_run + /// use tokio_postgres::Client; + /// + /// async fn foo(client: &Client) { + /// let duplex_stream = client.copy_both_simple::<&[u8]>("..").await; + /// + /// // ✅ CORRECT ✅ + /// drop(duplex_stream); + /// + /// client.query("SELECT 1", &[]).await; + /// } + /// ``` + pub struct CopyBothDuplex { + #[pin] + sink_sender: mpsc::Sender, + #[pin] + stream_receiver: mpsc::Receiver>, + buf: BytesMut, + #[pin] + _p: PhantomPinned, + _p2: PhantomData, + } +} + +impl Stream for CopyBothDuplex { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Poll::Ready(match ready!(self.project().stream_receiver.poll_next(cx)) { + Some(Ok(Message::CopyData(body))) => Some(Ok(body.into_bytes())), + Some(Ok(_)) => Some(Err(Error::unexpected_message())), + Some(Err(err)) => Some(Err(err)), + None => None, + }) + } +} + +impl Sink for CopyBothDuplex +where + T: Buf + 'static + Send, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .sink_sender + .poll_ready(cx) + .map_err(|_| Error::closed()) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> { + let this = self.project(); + + let data: Box = if item.remaining() > 4096 { + if this.buf.is_empty() { + Box::new(item) + } else { + Box::new(this.buf.split().freeze().chain(item)) + } + } else { + this.buf.put(item); + if this.buf.len() > 4096 { + Box::new(this.buf.split().freeze()) + } else { + return Ok(()); + } + }; + + let data = CopyData::new(data).map_err(Error::encode)?; + this.sink_sender + .start_send(FrontendMessage::CopyData(data)) + .map_err(|_| Error::closed()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if !this.buf.is_empty() { + ready!(this.sink_sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + let data: Box = Box::new(this.buf.split().freeze()); + let data = CopyData::new(data).map_err(Error::encode)?; + this.sink_sender + .as_mut() + .start_send(FrontendMessage::CopyData(data)) + .map_err(|_| Error::closed())?; + } + + this.sink_sender.poll_flush(cx).map_err(|_| Error::closed()) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + let mut this = self.as_mut().project(); + this.sink_sender.disconnect(); + Poll::Ready(Ok(())) + } +} + +pub async fn copy_both_simple( + client: &InnerClient, + query: &str, +) -> Result, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy both query {}", query); + + let buf = simple_query::encode(client, query)?; + + let mut handles = client.start_copy_both()?; + + handles + .sink_sender + .send(FrontendMessage::Raw(buf)) + .await + .map_err(|_| Error::closed())?; + + match handles.stream_receiver.next().await.transpose()? { + Some(Message::CopyBothResponse(_)) => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(CopyBothDuplex { + stream_receiver: handles.stream_receiver, + sink_sender: handles.sink_sender, + buf: BytesMut::new(), + _p: PhantomPinned, + _p2: PhantomData, + }) +} diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index ec843d511..31364b1ff 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -123,6 +123,7 @@ pub use crate::cancel_token::CancelToken; pub use crate::client::Client; pub use crate::config::Config; pub use crate::connection::Connection; +pub use crate::copy_both::CopyBothDuplex; pub use crate::copy_in::CopyInSink; pub use crate::copy_out::CopyOutStream; use crate::error::DbError; @@ -160,6 +161,7 @@ mod connect_raw; mod connect_socket; mod connect_tls; mod connection; +mod copy_both; mod copy_in; mod copy_out; pub mod error; diff --git a/tokio-postgres/tests/test/copy_both.rs b/tokio-postgres/tests/test/copy_both.rs new file mode 100644 index 000000000..2723928ac --- /dev/null +++ b/tokio-postgres/tests/test/copy_both.rs @@ -0,0 +1,125 @@ +use futures_util::{future, StreamExt, TryStreamExt}; +use tokio_postgres::{error::SqlState, Client, SimpleQueryMessage, SimpleQueryRow}; + +async fn q(client: &Client, query: &str) -> Vec { + let msgs = client.simple_query(query).await.unwrap(); + + msgs.into_iter() + .filter_map(|msg| match msg { + SimpleQueryMessage::Row(row) => Some(row), + _ => None, + }) + .collect() +} + +#[tokio::test] +async fn copy_both_error() { + let client = crate::connect("user=postgres replication=database").await; + + let err = client + .copy_both_simple::("START_REPLICATION SLOT undefined LOGICAL 0000/0000") + .await + .err() + .unwrap(); + + assert_eq!(err.code(), Some(&SqlState::UNDEFINED_OBJECT)); + + // Ensure we can continue issuing queries + assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1")); +} + +#[tokio::test] +async fn copy_both_stream_error() { + let client = crate::connect("user=postgres replication=true").await; + + q(&client, "CREATE_REPLICATION_SLOT err2 PHYSICAL").await; + + // This will immediately error out after entering CopyBoth mode + let duplex_stream = client + .copy_both_simple::("START_REPLICATION SLOT err2 PHYSICAL FFFF/FFFF") + .await + .unwrap(); + + let mut msgs: Vec<_> = duplex_stream.collect().await; + let result = msgs.pop().unwrap(); + assert_eq!(msgs.len(), 0); + assert!(result.unwrap_err().as_db_error().is_some()); + + // Ensure we can continue issuing queries + assert_eq!(q(&client, "DROP_REPLICATION_SLOT err2").await.len(), 0); +} + +#[tokio::test] +async fn copy_both_stream_error_sync() { + let client = crate::connect("user=postgres replication=database").await; + + q(&client, "CREATE_REPLICATION_SLOT err1 TEMPORARY PHYSICAL").await; + + // This will immediately error out after entering CopyBoth mode + let duplex_stream = client + .copy_both_simple::("START_REPLICATION SLOT err1 PHYSICAL FFFF/FFFF") + .await + .unwrap(); + + // Immediately close our sink to send a CopyDone before receiving the ErrorResponse + drop(duplex_stream); + + // Ensure we can continue issuing queries + assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1")); +} + +#[tokio::test] +async fn copy_both() { + let client = crate::connect("user=postgres replication=database").await; + + q(&client, "DROP TABLE IF EXISTS replication").await; + q(&client, "CREATE TABLE replication (i text)").await; + + let slot_query = "CREATE_REPLICATION_SLOT slot TEMPORARY LOGICAL \"test_decoding\""; + let lsn = q(&client, slot_query).await[0] + .get("consistent_point") + .unwrap() + .to_owned(); + + // We will attempt to read this from the other end + q(&client, "BEGIN").await; + let xid = q(&client, "SELECT txid_current()").await[0] + .get("txid_current") + .unwrap() + .to_owned(); + q(&client, "INSERT INTO replication VALUES ('processed')").await; + q(&client, "COMMIT").await; + + // Insert a second row to generate unprocessed messages in the stream + q(&client, "INSERT INTO replication VALUES ('ignored')").await; + + let query = format!("START_REPLICATION SLOT slot LOGICAL {}", lsn); + let duplex_stream = client + .copy_both_simple::(&query) + .await + .unwrap(); + + let expected = vec![ + format!("BEGIN {}", xid), + "table public.replication: INSERT: i[text]:'processed'".to_string(), + format!("COMMIT {}", xid), + ]; + + let actual: Vec<_> = duplex_stream + // Process only XLogData messages + .try_filter(|buf| future::ready(buf[0] == b'w')) + // Playback the stream until the first expected message + .try_skip_while(|buf| future::ready(Ok(!buf.ends_with(expected[0].as_ref())))) + // Take only the expected number of messsage, the rest will be discarded by tokio_postgres + .take(expected.len()) + .try_collect() + .await + .unwrap(); + + for (msg, ending) in actual.into_iter().zip(expected.into_iter()) { + assert!(msg.ends_with(ending.as_ref())); + } + + // Ensure we can continue issuing queries + assert_eq!(q(&client, "SELECT 1").await[0].get(0), Some("1")); +} diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 9a6aa26fe..778ddaf05 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -20,6 +20,7 @@ use tokio_postgres::{ }; mod binary_copy; +mod copy_both; mod parse; #[cfg(feature = "runtime")] mod runtime; From 88edd68e3370faf9c263c701c6300c77600a890b Mon Sep 17 00:00:00 2001 From: Petros Angelatos Date: Tue, 23 Nov 2021 15:36:00 +0100 Subject: [PATCH 5/5] ci: enable logical replication in the test image Signed-off-by: Petros Angelatos --- docker/sql_setup.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker/sql_setup.sh b/docker/sql_setup.sh index 0315ac805..051a12000 100755 --- a/docker/sql_setup.sh +++ b/docker/sql_setup.sh @@ -64,6 +64,7 @@ port = 5433 ssl = on ssl_cert_file = 'server.crt' ssl_key_file = 'server.key' +wal_level = logical EOCONF cat > "$PGDATA/pg_hba.conf" <<-EOCONF @@ -82,6 +83,7 @@ host all ssl_user ::0/0 reject # IPv4 local connections: host all postgres 0.0.0.0/0 trust +host replication postgres 0.0.0.0/0 trust # IPv6 local connections: host all postgres ::0/0 trust # Unix socket connections: