From 2d2a5dea81be7fdac54ea1ea64b7a407efb13c9b Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Thu, 27 Jun 2019 21:37:22 -0700 Subject: [PATCH] Send response messages in blocks Our codec implementation originally just parsed single messages out of the stream buffer. However, if a query returns a bunch of rows, we're spending a ton of time shipping those individual messages from the connection back to the Query stream. Instead, collect blocks of unparsed messages that are as large as possible and send those back. This cuts the processing time of the following query in half, from ~10 seconds to ~5: `SELECT s.n, 'name' || s.n FROM generate_series(0, 9999999) AS s(n)` At this point, almost all of the remainder of the time is spent parsing the rows. cc #450 --- Cargo.toml | 3 + postgres-protocol/src/message/backend.rs | 104 ++++++++++++++++++----- tokio-postgres/src/proto/bind.rs | 8 +- tokio-postgres/src/proto/client.rs | 6 +- tokio-postgres/src/proto/codec.rs | 70 ++++++++++++++- tokio-postgres/src/proto/connect_raw.rs | 74 ++++++++++++++-- tokio-postgres/src/proto/connection.rs | 51 ++++++----- tokio-postgres/src/proto/copy_in.rs | 13 +-- tokio-postgres/src/proto/copy_out.rs | 10 +-- tokio-postgres/src/proto/execute.rs | 8 +- tokio-postgres/src/proto/mod.rs | 11 +-- tokio-postgres/src/proto/prepare.rs | 14 +-- tokio-postgres/src/proto/query.rs | 6 +- tokio-postgres/src/proto/responses.rs | 42 +++++++++ tokio-postgres/src/proto/simple_query.rs | 6 +- 15 files changed, 324 insertions(+), 102 deletions(-) create mode 100644 tokio-postgres/src/proto/responses.rs diff --git a/Cargo.toml b/Cargo.toml index 40e30b1e8..37421ba99 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,3 +7,6 @@ members = [ "tokio-postgres-native-tls", "tokio-postgres-openssl", ] + +[profile.release] +debug = 2 diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index f9c37b590..909f8bfd3 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -1,6 +1,6 @@ #![allow(missing_docs)] -use byteorder::{BigEndian, ReadBytesExt}; +use byteorder::{BigEndian, ByteOrder, ReadBytesExt}; use bytes::{Bytes, BytesMut}; use fallible_iterator::FallibleIterator; use memchr::memchr; @@ -11,6 +11,66 @@ use std::str; use crate::Oid; +pub const PARSE_COMPLETE_TAG: u8 = b'1'; +pub const BIND_COMPLETE_TAG: u8 = b'2'; +pub const CLOSE_COMPLETE_TAG: u8 = b'3'; +pub const NOTIFICATION_RESPONSE_TAG: u8 = b'A'; +pub const COPY_DONE_TAG: u8 = b'c'; +pub const COMMAND_COMPLETE_TAG: u8 = b'C'; +pub const COPY_DATA_TAG: u8 = b'd'; +pub const DATA_ROW_TAG: u8 = b'D'; +pub const ERROR_RESPONSE_TAG: u8 = b'E'; +pub const COPY_IN_RESPONSE_TAG: u8 = b'G'; +pub const COPY_OUT_RESPONSE_TAG: u8 = b'H'; +pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I'; +pub const BACKEND_KEY_DATA_TAG: u8 = b'K'; +pub const NO_DATA_TAG: u8 = b'n'; +pub const NOTICE_RESPONSE_TAG: u8 = b'N'; +pub const AUTHENTICATION_TAG: u8 = b'R'; +pub const PORTAL_SUSPENDED_TAG: u8 = b's'; +pub const PARAMETER_STATUS_TAG: u8 = b'S'; +pub const PARAMETER_DESCRIPTION_TAG: u8 = b't'; +pub const ROW_DESCRIPTION_TAG: u8 = b'T'; +pub const READY_FOR_QUERY_TAG: u8 = b'Z'; + +#[derive(Debug, Copy, Clone)] +pub struct Header { + tag: u8, + len: i32, +} + +#[allow(clippy::len_without_is_empty)] +impl Header { + #[inline] + pub fn parse(buf: &[u8]) -> io::Result> { + if buf.len() < 5 { + return Ok(None); + } + + let tag = buf[0]; + let len = BigEndian::read_i32(&buf[1..]); + + if len < 4 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid message length", + )); + } + + Ok(Some(Header { tag, len })) + } + + #[inline] + pub fn tag(self) -> u8 { + self.tag + } + + #[inline] + pub fn len(self) -> i32 { + self.len + } +} + /// An enum representing Postgres backend messages. pub enum Message { AuthenticationCleartextPassword, @@ -80,10 +140,10 @@ impl Message { }; let message = match tag { - b'1' => Message::ParseComplete, - b'2' => Message::BindComplete, - b'3' => Message::CloseComplete, - b'A' => { + PARSE_COMPLETE_TAG => Message::ParseComplete, + BIND_COMPLETE_TAG => Message::BindComplete, + CLOSE_COMPLETE_TAG => Message::CloseComplete, + NOTIFICATION_RESPONSE_TAG => { let process_id = buf.read_i32::()?; let channel = buf.read_cstr()?; let message = buf.read_cstr()?; @@ -93,25 +153,25 @@ impl Message { message, }) } - b'c' => Message::CopyDone, - b'C' => { + COPY_DONE_TAG => Message::CopyDone, + COMMAND_COMPLETE_TAG => { let tag = buf.read_cstr()?; Message::CommandComplete(CommandCompleteBody { tag }) } - b'd' => { + COPY_DATA_TAG => { let storage = buf.read_all(); Message::CopyData(CopyDataBody { storage }) } - b'D' => { + DATA_ROW_TAG => { let len = buf.read_u16::()?; let storage = buf.read_all(); Message::DataRow(DataRowBody { storage, len }) } - b'E' => { + ERROR_RESPONSE_TAG => { let storage = buf.read_all(); Message::ErrorResponse(ErrorResponseBody { storage }) } - b'G' => { + COPY_IN_RESPONSE_TAG => { let format = buf.read_u8()?; let len = buf.read_u16::()?; let storage = buf.read_all(); @@ -121,7 +181,7 @@ impl Message { storage, }) } - b'H' => { + COPY_OUT_RESPONSE_TAG => { let format = buf.read_u8()?; let len = buf.read_u16::()?; let storage = buf.read_all(); @@ -131,8 +191,8 @@ impl Message { storage, }) } - b'I' => Message::EmptyQueryResponse, - b'K' => { + EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, + BACKEND_KEY_DATA_TAG => { let process_id = buf.read_i32::()?; let secret_key = buf.read_i32::()?; Message::BackendKeyData(BackendKeyDataBody { @@ -140,12 +200,12 @@ impl Message { secret_key, }) } - b'n' => Message::NoData, - b'N' => { + NO_DATA_TAG => Message::NoData, + NOTICE_RESPONSE_TAG => { let storage = buf.read_all(); Message::NoticeResponse(NoticeResponseBody { storage }) } - b'R' => match buf.read_i32::()? { + AUTHENTICATION_TAG => match buf.read_i32::()? { 0 => Message::AuthenticationOk, 2 => Message::AuthenticationKerberosV5, 3 => Message::AuthenticationCleartextPassword, @@ -180,23 +240,23 @@ impl Message { )); } }, - b's' => Message::PortalSuspended, - b'S' => { + PORTAL_SUSPENDED_TAG => Message::PortalSuspended, + PARAMETER_STATUS_TAG => { let name = buf.read_cstr()?; let value = buf.read_cstr()?; Message::ParameterStatus(ParameterStatusBody { name, value }) } - b't' => { + PARAMETER_DESCRIPTION_TAG => { let len = buf.read_u16::()?; let storage = buf.read_all(); Message::ParameterDescription(ParameterDescriptionBody { storage, len }) } - b'T' => { + ROW_DESCRIPTION_TAG => { let len = buf.read_u16::()?; let storage = buf.read_all(); Message::RowDescription(RowDescriptionBody { storage, len }) } - b'Z' => { + READY_FOR_QUERY_TAG => { let status = buf.read_u8()?; Message::ReadyForQuery(ReadyForQueryBody { status }) } diff --git a/tokio-postgres/src/proto/bind.rs b/tokio-postgres/src/proto/bind.rs index 0a8a3b5f0..944afd6ea 100644 --- a/tokio-postgres/src/proto/bind.rs +++ b/tokio-postgres/src/proto/bind.rs @@ -1,10 +1,10 @@ -use futures::sync::mpsc; -use futures::{Poll, Stream}; +use futures::{try_ready, Poll, Stream}; use postgres_protocol::message::backend::Message; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; use crate::proto::client::{Client, PendingRequest}; use crate::proto::portal::Portal; +use crate::proto::responses::Responses; use crate::proto::statement::Statement; use crate::Error; @@ -19,7 +19,7 @@ pub enum Bind { }, #[state_machine_future(transitions(Finished))] ReadBindComplete { - receiver: mpsc::Receiver, + receiver: Responses, client: Client, name: String, statement: Statement, @@ -46,7 +46,7 @@ impl PollBind for Bind { fn poll_read_bind_complete<'a>( state: &'a mut RentToOwn<'a, ReadBindComplete>, ) -> Poll { - let message = try_ready_receive!(state.receiver.poll()); + let message = try_ready!(state.receiver.poll()); let state = state.take(); match message { diff --git a/tokio-postgres/src/proto/client.rs b/tokio-postgres/src/proto/client.rs index abb2f5341..7fb070e90 100644 --- a/tokio-postgres/src/proto/client.rs +++ b/tokio-postgres/src/proto/client.rs @@ -3,7 +3,6 @@ use bytes::IntoBuf; use futures::sync::mpsc; use futures::{AsyncSink, Poll, Sink, Stream}; use postgres_protocol; -use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use std::collections::HashMap; use std::error::Error as StdError; @@ -20,6 +19,7 @@ use crate::proto::idle::{IdleGuard, IdleState}; use crate::proto::portal::Portal; use crate::proto::prepare::PrepareFuture; use crate::proto::query::QueryStream; +use crate::proto::responses::{self, Responses}; use crate::proto::simple_query::SimpleQueryStream; use crate::proto::statement::Statement; #[cfg(feature = "runtime")] @@ -130,9 +130,9 @@ impl Client { self.0.state.lock().typeinfo_composite_query = Some(statement.clone()); } - pub fn send(&self, request: PendingRequest) -> Result, Error> { + pub fn send(&self, request: PendingRequest) -> Result { let (messages, idle) = request.0?; - let (sender, receiver) = mpsc::channel(1); + let (sender, receiver) = responses::channel(); self.0 .sender .unbounded_send(Request { diff --git a/tokio-postgres/src/proto/codec.rs b/tokio-postgres/src/proto/codec.rs index c7c6d9045..4ebebd479 100644 --- a/tokio-postgres/src/proto/codec.rs +++ b/tokio-postgres/src/proto/codec.rs @@ -1,4 +1,5 @@ use bytes::{Buf, BytesMut}; +use fallible_iterator::FallibleIterator; use postgres_protocol::message::backend; use postgres_protocol::message::frontend::CopyData; use std::io; @@ -9,6 +10,31 @@ pub enum FrontendMessage { CopyData(CopyData>), } +pub enum BackendMessage { + Normal { + messages: BackendMessages, + request_complete: bool, + }, + Async(backend::Message), +} + +pub struct BackendMessages(BytesMut); + +impl BackendMessages { + pub fn empty() -> BackendMessages { + BackendMessages(BytesMut::new()) + } +} + +impl FallibleIterator for BackendMessages { + type Item = backend::Message; + type Error = io::Error; + + fn next(&mut self) -> io::Result> { + backend::Message::parse(&mut self.0) + } +} + pub struct PostgresCodec; impl Encoder for PostgresCodec { @@ -26,10 +52,48 @@ impl Encoder for PostgresCodec { } impl Decoder for PostgresCodec { - type Item = backend::Message; + type Item = BackendMessage; type Error = io::Error; - fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { - backend::Message::parse(src) + fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { + let mut idx = 0; + let mut request_complete = false; + + while let Some(header) = backend::Header::parse(&src[idx..])? { + let len = header.len() as usize + 1; + if src[idx..].len() < len { + break; + } + + match header.tag() { + backend::NOTICE_RESPONSE_TAG + | backend::NOTIFICATION_RESPONSE_TAG + | backend::PARAMETER_STATUS_TAG => { + if idx == 0 { + let message = backend::Message::parse(src)?.unwrap(); + return Ok(Some(BackendMessage::Async(message))); + } else { + break; + } + } + _ => {} + } + + idx += len; + + if header.tag() == backend::READY_FOR_QUERY_TAG { + request_complete = true; + break; + } + } + + if idx == 0 { + Ok(None) + } else { + Ok(Some(BackendMessage::Normal { + messages: BackendMessages(src.split_to(idx)), + request_complete, + })) + } } } diff --git a/tokio-postgres/src/proto/connect_raw.rs b/tokio-postgres/src/proto/connect_raw.rs index feeb4bfc4..0cb0ec033 100644 --- a/tokio-postgres/src/proto/connect_raw.rs +++ b/tokio-postgres/src/proto/connect_raw.rs @@ -1,6 +1,6 @@ use fallible_iterator::FallibleIterator; -use futures::sink; use futures::sync::mpsc; +use futures::{sink, Async, AsyncSink}; use futures::{try_ready, Future, Poll, Sink, Stream}; use postgres_protocol::authentication; use postgres_protocol::authentication::sasl::{self, ScramSha256}; @@ -8,13 +8,64 @@ use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; use std::collections::HashMap; +use std::io; use tokio_codec::Framed; use tokio_io::{AsyncRead, AsyncWrite}; +use crate::proto::codec::{BackendMessage, BackendMessages}; use crate::proto::{Client, Connection, FrontendMessage, MaybeTlsStream, PostgresCodec, TlsFuture}; use crate::tls::ChannelBinding; use crate::{Config, Error, TlsConnect}; +pub struct StartupStream { + inner: Framed, PostgresCodec>, + buf: BackendMessages, +} + +impl Sink for StartupStream +where + S: AsyncRead + AsyncWrite, + T: AsyncRead + AsyncWrite, +{ + type SinkItem = FrontendMessage; + type SinkError = io::Error; + + fn start_send(&mut self, item: FrontendMessage) -> io::Result> { + self.inner.start_send(item) + } + + fn poll_complete(&mut self) -> Poll<(), io::Error> { + self.inner.poll_complete() + } + + fn close(&mut self) -> Poll<(), io::Error> { + self.inner.close() + } +} + +impl Stream for StartupStream +where + S: AsyncRead + AsyncWrite, + T: AsyncRead + AsyncWrite, +{ + type Item = Message; + type Error = io::Error; + + fn poll(&mut self) -> Poll, io::Error> { + loop { + if let Some(message) = self.buf.next()? { + return Ok(Async::Ready(Some(message))); + } + + match try_ready!(self.inner.poll()) { + Some(BackendMessage::Async(message)) => return Ok(Async::Ready(Some(message))), + Some(BackendMessage::Normal { messages, .. }) => self.buf = messages, + None => return Ok(Async::Ready(None)), + } + } + } +} + #[derive(StateMachineFuture)] pub enum ConnectRaw where @@ -29,47 +80,47 @@ where }, #[state_machine_future(transitions(ReadingAuth))] SendingStartup { - future: sink::Send, PostgresCodec>>, + future: sink::Send>, config: Config, idx: Option, channel_binding: ChannelBinding, }, #[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))] ReadingAuth { - stream: Framed, PostgresCodec>, + stream: StartupStream, config: Config, idx: Option, channel_binding: ChannelBinding, }, #[state_machine_future(transitions(ReadingAuthCompletion))] SendingPassword { - future: sink::Send, PostgresCodec>>, + future: sink::Send>, config: Config, idx: Option, }, #[state_machine_future(transitions(ReadingSasl))] SendingSasl { - future: sink::Send, PostgresCodec>>, + future: sink::Send>, scram: ScramSha256, config: Config, idx: Option, }, #[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))] ReadingSasl { - stream: Framed, PostgresCodec>, + stream: StartupStream, scram: ScramSha256, config: Config, idx: Option, }, #[state_machine_future(transitions(ReadingInfo))] ReadingAuthCompletion { - stream: Framed, PostgresCodec>, + stream: StartupStream, config: Config, idx: Option, }, #[state_machine_future(transitions(Finished))] ReadingInfo { - stream: Framed, PostgresCodec>, + stream: StartupStream, process_id: i32, secret_key: i32, parameters: HashMap, @@ -109,6 +160,10 @@ where frontend::startup_message(params, &mut buf).map_err(Error::encode)?; let stream = Framed::new(stream, PostgresCodec); + let stream = StartupStream { + inner: stream, + buf: BackendMessages::empty(), + }; transition!(SendingStartup { future: stream.send(FrontendMessage::Raw(buf)), @@ -363,7 +418,8 @@ where state.config, state.idx, ); - let connection = Connection::new(state.stream, state.parameters, receiver); + let connection = + Connection::new(state.stream.inner, state.parameters, receiver); transition!(Finished((client, connection))) } Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), diff --git a/tokio-postgres/src/proto/connection.rs b/tokio-postgres/src/proto/connection.rs index 14559fd0a..222fd16ea 100644 --- a/tokio-postgres/src/proto/connection.rs +++ b/tokio-postgres/src/proto/connection.rs @@ -1,3 +1,4 @@ +use fallible_iterator::FallibleIterator; use futures::sync::mpsc; use futures::{try_ready, Async, AsyncSink, Future, Poll, Sink, Stream}; use log::trace; @@ -8,7 +9,7 @@ use std::io; use tokio_codec::Framed; use tokio_io::{AsyncRead, AsyncWrite}; -use crate::proto::codec::{FrontendMessage, PostgresCodec}; +use crate::proto::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; use crate::proto::copy_in::CopyInReceiver; use crate::proto::idle::IdleGuard; use crate::{AsyncMessage, Notification}; @@ -24,12 +25,12 @@ pub enum RequestMessages { pub struct Request { pub messages: RequestMessages, - pub sender: mpsc::Sender, + pub sender: mpsc::Sender, pub idle: Option, } struct Response { - sender: mpsc::Sender, + sender: mpsc::Sender, _idle: Option, } @@ -45,7 +46,7 @@ pub struct Connection { parameters: HashMap, receiver: mpsc::UnboundedReceiver, pending_request: Option, - pending_response: Option, + pending_response: Option, responses: VecDeque, state: State, } @@ -74,7 +75,7 @@ where self.parameters.get(name).map(|s| &**s) } - fn poll_response(&mut self) -> Poll, io::Error> { + fn poll_response(&mut self) -> Poll, io::Error> { if let Some(message) = self.pending_response.take() { trace!("retrying pending response"); return Ok(Async::Ready(Some(message))); @@ -101,12 +102,12 @@ where } }; - let message = match message { - Message::NoticeResponse(body) => { + let (mut messages, request_complete) = match message { + BackendMessage::Async(Message::NoticeResponse(body)) => { let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?; return Ok(Some(AsyncMessage::Notice(error))); } - Message::NotificationResponse(body) => { + BackendMessage::Async(Message::NotificationResponse(body)) => { let notification = Notification { process_id: body.process_id(), channel: body.channel().map_err(Error::parse)?.to_string(), @@ -114,30 +115,29 @@ where }; return Ok(Some(AsyncMessage::Notification(notification))); } - Message::ParameterStatus(body) => { + BackendMessage::Async(Message::ParameterStatus(body)) => { self.parameters.insert( body.name().map_err(Error::parse)?.to_string(), body.value().map_err(Error::parse)?.to_string(), ); continue; } - m => m, + BackendMessage::Async(_) => unreachable!(), + BackendMessage::Normal { + messages, + request_complete, + } => (messages, request_complete), }; let mut response = match self.responses.pop_front() { Some(response) => response, - None => match message { - Message::ErrorResponse(error) => return Err(Error::db(error)), + None => match messages.next().map_err(Error::parse)? { + Some(Message::ErrorResponse(error)) => return Err(Error::db(error)), _ => return Err(Error::unexpected_message()), }, }; - let request_complete = match message { - Message::ReadyForQuery(_) => true, - _ => false, - }; - - match response.sender.start_send(message) { + match response.sender.start_send(messages) { // if the receiver's hung up we still need to page through the rest of the messages // designated to it Ok(AsyncSink::Ready) | Err(_) => { @@ -145,9 +145,12 @@ where self.responses.push_front(response); } } - Ok(AsyncSink::NotReady(message)) => { + Ok(AsyncSink::NotReady(messages)) => { self.responses.push_front(response); - self.pending_response = Some(message); + self.pending_response = Some(BackendMessage::Normal { + messages, + request_complete, + }); trace!("poll_read: waiting on sender"); return Ok(None); } @@ -161,8 +164,8 @@ where return Ok(Async::Ready(Some(message))); } - match try_ready_receive!(self.receiver.poll()) { - Some(request) => { + match self.receiver.poll() { + Ok(Async::Ready(Some(request))) => { trace!("polled new request"); self.responses.push_back(Response { sender: request.sender, @@ -170,7 +173,9 @@ where }); Ok(Async::Ready(Some(request.messages))) } - None => Ok(Async::Ready(None)), + Ok(Async::Ready(None)) => Ok(Async::Ready(None)), + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(()) => unreachable!("mpsc::Receiver doesn't error"), } } diff --git a/tokio-postgres/src/proto/copy_in.rs b/tokio-postgres/src/proto/copy_in.rs index 80f09c520..762f1c462 100644 --- a/tokio-postgres/src/proto/copy_in.rs +++ b/tokio-postgres/src/proto/copy_in.rs @@ -10,6 +10,7 @@ use std::error::Error as StdError; use crate::proto::client::{Client, PendingRequest}; use crate::proto::codec::FrontendMessage; +use crate::proto::responses::Responses; use crate::proto::statement::Statement; use crate::Error; @@ -82,7 +83,7 @@ where ReadCopyInResponse { stream: S, sender: mpsc::Sender, - receiver: mpsc::Receiver, + receiver: Responses, }, #[state_machine_future(transitions(WriteCopyDone))] WriteCopyData { @@ -90,15 +91,15 @@ where buf: BytesMut, pending_message: Option, sender: mpsc::Sender, - receiver: mpsc::Receiver, + receiver: Responses, }, #[state_machine_future(transitions(ReadCommandComplete))] WriteCopyDone { future: sink::Send>, - receiver: mpsc::Receiver, + receiver: Responses, }, #[state_machine_future(transitions(Finished))] - ReadCommandComplete { receiver: mpsc::Receiver }, + ReadCommandComplete { receiver: Responses }, #[state_machine_future(ready)] Finished(u64), #[state_machine_future(error)] @@ -128,7 +129,7 @@ where state: &'a mut RentToOwn<'a, ReadCopyInResponse>, ) -> Poll, Error> { loop { - let message = try_ready_receive!(state.receiver.poll()); + let message = try_ready!(state.receiver.poll()); match message { Some(Message::BindComplete) => {} @@ -229,7 +230,7 @@ where fn poll_read_command_complete<'a>( state: &'a mut RentToOwn<'a, ReadCommandComplete>, ) -> Poll { - let message = try_ready_receive!(state.receiver.poll()); + let message = try_ready!(state.receiver.poll()); match message { Some(Message::CommandComplete(body)) => { diff --git a/tokio-postgres/src/proto/copy_out.rs b/tokio-postgres/src/proto/copy_out.rs index c0418222a..1ae714188 100644 --- a/tokio-postgres/src/proto/copy_out.rs +++ b/tokio-postgres/src/proto/copy_out.rs @@ -1,10 +1,10 @@ use bytes::Bytes; -use futures::sync::mpsc; use futures::{Async, Poll, Stream}; use postgres_protocol::message::backend::Message; use std::mem; use crate::proto::client::{Client, PendingRequest}; +use crate::proto::responses::Responses; use crate::proto::statement::Statement; use crate::Error; @@ -15,10 +15,10 @@ enum State { statement: Statement, }, ReadingCopyOutResponse { - receiver: mpsc::Receiver, + receiver: Responses, }, ReadingCopyData { - receiver: mpsc::Receiver, + receiver: Responses, }, Done, } @@ -49,7 +49,7 @@ impl Stream for CopyOutStream { self.0 = State::ReadingCopyOutResponse { receiver }; break Ok(Async::NotReady); } - Err(()) => unreachable!("mpsc::Receiver doesn't return errors"), + Err(e) => return Err(e), }; match message { @@ -71,7 +71,7 @@ impl Stream for CopyOutStream { self.0 = State::ReadingCopyData { receiver }; break Ok(Async::NotReady); } - Err(()) => unreachable!("mpsc::Reciever doesn't return errors"), + Err(e) => return Err(e), }; match message { diff --git a/tokio-postgres/src/proto/execute.rs b/tokio-postgres/src/proto/execute.rs index 25b1e90a0..0f8e021fe 100644 --- a/tokio-postgres/src/proto/execute.rs +++ b/tokio-postgres/src/proto/execute.rs @@ -1,9 +1,9 @@ -use futures::sync::mpsc; -use futures::{Poll, Stream}; +use futures::{try_ready, Poll, Stream}; use postgres_protocol::message::backend::Message; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; use crate::proto::client::{Client, PendingRequest}; +use crate::proto::responses::Responses; use crate::proto::statement::Statement; use crate::Error; @@ -16,7 +16,7 @@ pub enum Execute { statement: Statement, }, #[state_machine_future(transitions(Finished))] - ReadResponse { receiver: mpsc::Receiver }, + ReadResponse { receiver: Responses }, #[state_machine_future(ready)] Finished(u64), #[state_machine_future(error)] @@ -36,7 +36,7 @@ impl PollExecute for Execute { state: &'a mut RentToOwn<'a, ReadResponse>, ) -> Poll { loop { - let message = try_ready_receive!(state.receiver.poll()); + let message = try_ready!(state.receiver.poll()); match message { Some(Message::BindComplete) => {} diff --git a/tokio-postgres/src/proto/mod.rs b/tokio-postgres/src/proto/mod.rs index 7c30cda85..2979d8a30 100644 --- a/tokio-postgres/src/proto/mod.rs +++ b/tokio-postgres/src/proto/mod.rs @@ -1,13 +1,3 @@ -macro_rules! try_ready_receive { - ($e:expr) => { - match $e { - Ok(::futures::Async::Ready(v)) => v, - Ok(::futures::Async::NotReady) => return Ok(::futures::Async::NotReady), - Err(()) => unreachable!("mpsc::Receiver doesn't return errors"), - } - }; -} - macro_rules! try_ready_closed { ($e:expr) => { match $e { @@ -40,6 +30,7 @@ mod maybe_tls_stream; mod portal; mod prepare; mod query; +mod responses; mod simple_query; mod statement; mod tls; diff --git a/tokio-postgres/src/proto/prepare.rs b/tokio-postgres/src/proto/prepare.rs index 029bbb8a5..a29aca11b 100644 --- a/tokio-postgres/src/proto/prepare.rs +++ b/tokio-postgres/src/proto/prepare.rs @@ -1,7 +1,6 @@ #![allow(clippy::large_enum_variant)] use fallible_iterator::FallibleIterator; -use futures::sync::mpsc; use futures::{try_ready, Future, Poll, Stream}; use postgres_protocol::message::backend::Message; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; @@ -9,6 +8,7 @@ use std::mem; use std::vec; use crate::proto::client::{Client, PendingRequest}; +use crate::proto::responses::Responses; use crate::proto::statement::Statement; use crate::proto::typeinfo::TypeinfoFuture; use crate::types::{Oid, Type}; @@ -25,19 +25,19 @@ pub enum Prepare { #[state_machine_future(transitions(ReadParameterDescription))] ReadParseComplete { client: Client, - receiver: mpsc::Receiver, + receiver: Responses, name: String, }, #[state_machine_future(transitions(ReadRowDescription))] ReadParameterDescription { client: Client, - receiver: mpsc::Receiver, + receiver: Responses, name: String, }, #[state_machine_future(transitions(GetParameterTypes, GetColumnTypes, Finished))] ReadRowDescription { client: Client, - receiver: mpsc::Receiver, + receiver: Responses, name: String, parameters: Vec, }, @@ -79,7 +79,7 @@ impl PollPrepare for Prepare { fn poll_read_parse_complete<'a>( state: &'a mut RentToOwn<'a, ReadParseComplete>, ) -> Poll { - let message = try_ready_receive!(state.receiver.poll()); + let message = try_ready!(state.receiver.poll()); let state = state.take(); match message { @@ -97,7 +97,7 @@ impl PollPrepare for Prepare { fn poll_read_parameter_description<'a>( state: &'a mut RentToOwn<'a, ReadParameterDescription>, ) -> Poll { - let message = try_ready_receive!(state.receiver.poll()); + let message = try_ready!(state.receiver.poll()); let state = state.take(); match message { @@ -115,7 +115,7 @@ impl PollPrepare for Prepare { fn poll_read_row_description<'a>( state: &'a mut RentToOwn<'a, ReadRowDescription>, ) -> Poll { - let message = try_ready_receive!(state.receiver.poll()); + let message = try_ready!(state.receiver.poll()); let state = state.take(); let columns = match message { diff --git a/tokio-postgres/src/proto/query.rs b/tokio-postgres/src/proto/query.rs index 59877f061..2d84abdee 100644 --- a/tokio-postgres/src/proto/query.rs +++ b/tokio-postgres/src/proto/query.rs @@ -1,10 +1,10 @@ -use futures::sync::mpsc; use futures::{Async, Poll, Stream}; use postgres_protocol::message::backend::Message; use std::mem; use crate::proto::client::{Client, PendingRequest}; use crate::proto::portal::Portal; +use crate::proto::responses::Responses; use crate::proto::statement::Statement; use crate::{Error, Row}; @@ -31,7 +31,7 @@ enum State { statement: T, }, ReadingResponse { - receiver: mpsc::Receiver, + receiver: Responses, statement: T, }, Done, @@ -73,7 +73,7 @@ where }; break Ok(Async::NotReady); } - Err(()) => unreachable!("mpsc::Receiver doesn't return errors"), + Err(e) => return Err(e), }; match message { diff --git a/tokio-postgres/src/proto/responses.rs b/tokio-postgres/src/proto/responses.rs new file mode 100644 index 000000000..7cc259a83 --- /dev/null +++ b/tokio-postgres/src/proto/responses.rs @@ -0,0 +1,42 @@ +use fallible_iterator::FallibleIterator; +use futures::sync::mpsc; +use futures::{try_ready, Async, Poll, Stream}; +use postgres_protocol::message::backend; + +use crate::proto::codec::BackendMessages; +use crate::Error; + +pub fn channel() -> (mpsc::Sender, Responses) { + let (sender, receiver) = mpsc::channel(1); + + ( + sender, + Responses { + receiver, + cur: BackendMessages::empty(), + }, + ) +} + +pub struct Responses { + receiver: mpsc::Receiver, + cur: BackendMessages, +} + +impl Stream for Responses { + type Item = backend::Message; + type Error = Error; + + fn poll(&mut self) -> Poll, Error> { + loop { + if let Some(message) = self.cur.next().map_err(Error::parse)? { + return Ok(Async::Ready(Some(message))); + } + + match try_ready!(self.receiver.poll().map_err(|()| Error::closed())) { + Some(messages) => self.cur = messages, + None => return Ok(Async::Ready(None)), + } + } + } +} diff --git a/tokio-postgres/src/proto/simple_query.rs b/tokio-postgres/src/proto/simple_query.rs index 71f458a84..fdfb52270 100644 --- a/tokio-postgres/src/proto/simple_query.rs +++ b/tokio-postgres/src/proto/simple_query.rs @@ -1,11 +1,11 @@ use fallible_iterator::FallibleIterator; -use futures::sync::mpsc; use futures::{Async, Poll, Stream}; use postgres_protocol::message::backend::Message; use std::mem; use std::sync::Arc; use crate::proto::client::{Client, PendingRequest}; +use crate::proto::responses::Responses; use crate::{Error, SimpleQueryMessage, SimpleQueryRow}; pub enum State { @@ -15,7 +15,7 @@ pub enum State { }, ReadResponse { columns: Option>, - receiver: mpsc::Receiver, + receiver: Responses, }, Done, } @@ -46,7 +46,7 @@ impl Stream for SimpleQueryStream { self.0 = State::ReadResponse { columns, receiver }; return Ok(Async::NotReady); } - Err(()) => unreachable!("mpsc receiver can't panic"), + Err(e) => return Err(e), }; match message {