diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 008158fb0..431e17748 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -53,7 +53,9 @@ jobs: steps: - uses: actions/checkout@v3 - uses: sfackler/actions/rustup@master - - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT + with: + version: 1.67.0 + - run: echo "::set-output name=version::$(rustc --version)" id: rust-version - run: rustup target add wasm32-unknown-unknown - uses: actions/cache@v3 diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index 1b5be1098..da267101c 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -72,6 +72,7 @@ impl Header { } /// An enum representing Postgres backend messages. +#[derive(Debug, PartialEq)] #[non_exhaustive] pub enum Message { AuthenticationCleartextPassword, @@ -333,6 +334,7 @@ impl Read for Buffer { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationMd5PasswordBody { salt: [u8; 4], } @@ -344,6 +346,7 @@ impl AuthenticationMd5PasswordBody { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationGssContinueBody(Bytes); impl AuthenticationGssContinueBody { @@ -353,6 +356,7 @@ impl AuthenticationGssContinueBody { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationSaslBody(Bytes); impl AuthenticationSaslBody { @@ -362,6 +366,7 @@ impl AuthenticationSaslBody { } } +#[derive(Debug, PartialEq)] pub struct SaslMechanisms<'a>(&'a [u8]); impl<'a> FallibleIterator for SaslMechanisms<'a> { @@ -387,6 +392,7 @@ impl<'a> FallibleIterator for SaslMechanisms<'a> { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationSaslContinueBody(Bytes); impl AuthenticationSaslContinueBody { @@ -396,6 +402,7 @@ impl AuthenticationSaslContinueBody { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationSaslFinalBody(Bytes); impl AuthenticationSaslFinalBody { @@ -405,6 +412,7 @@ impl AuthenticationSaslFinalBody { } } +#[derive(Debug, PartialEq)] pub struct BackendKeyDataBody { process_id: i32, secret_key: i32, @@ -422,6 +430,7 @@ impl BackendKeyDataBody { } } +#[derive(Debug, PartialEq)] pub struct CommandCompleteBody { tag: Bytes, } @@ -433,6 +442,7 @@ impl CommandCompleteBody { } } +#[derive(Debug, PartialEq)] pub struct CopyDataBody { storage: Bytes, } @@ -449,6 +459,7 @@ impl CopyDataBody { } } +#[derive(Debug, PartialEq)] pub struct CopyInResponseBody { format: u8, len: u16, @@ -470,6 +481,7 @@ impl CopyInResponseBody { } } +#[derive(Debug, PartialEq)] pub struct ColumnFormats<'a> { buf: &'a [u8], remaining: u16, @@ -503,6 +515,7 @@ impl<'a> FallibleIterator for ColumnFormats<'a> { } } +#[derive(Debug, PartialEq)] pub struct CopyOutResponseBody { format: u8, len: u16, @@ -524,7 +537,7 @@ impl CopyOutResponseBody { } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct DataRowBody { storage: Bytes, len: u16, @@ -599,6 +612,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> { } } +#[derive(Debug, PartialEq)] pub struct ErrorResponseBody { storage: Bytes, } @@ -657,6 +671,7 @@ impl<'a> ErrorField<'a> { } } +#[derive(Debug, PartialEq)] pub struct NoticeResponseBody { storage: Bytes, } @@ -668,6 +683,7 @@ impl NoticeResponseBody { } } +#[derive(Debug, PartialEq)] pub struct NotificationResponseBody { process_id: i32, channel: Bytes, @@ -691,6 +707,7 @@ impl NotificationResponseBody { } } +#[derive(Debug, PartialEq)] pub struct ParameterDescriptionBody { storage: Bytes, len: u16, @@ -706,6 +723,7 @@ impl ParameterDescriptionBody { } } +#[derive(Debug, PartialEq)] pub struct Parameters<'a> { buf: &'a [u8], remaining: u16, @@ -739,6 +757,7 @@ impl<'a> FallibleIterator for Parameters<'a> { } } +#[derive(Debug, PartialEq)] pub struct ParameterStatusBody { name: Bytes, value: Bytes, @@ -756,6 +775,7 @@ impl ParameterStatusBody { } } +#[derive(Debug, PartialEq)] pub struct ReadyForQueryBody { status: u8, } @@ -767,6 +787,7 @@ impl ReadyForQueryBody { } } +#[derive(Debug, PartialEq)] pub struct RowDescriptionBody { storage: Bytes, len: u16, diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 52b5c773a..531a9f719 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -442,6 +442,22 @@ impl WrongType { } } +/// An error indicating that a as_text conversion was attempted on a binary +/// result. +#[derive(Debug)] +pub struct WrongFormat {} + +impl Error for WrongFormat {} + +impl fmt::Display for WrongFormat { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "cannot read column as text while it is in binary format" + ) + } +} + /// A trait for types that can be created from a Postgres value. /// /// # Types @@ -893,7 +909,7 @@ pub trait ToSql: fmt::Debug { /// Supported Postgres message format types /// /// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8` -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum Format { /// Text format (UTF-8) Text, diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index ec5e3cbec..c11de2e2b 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -59,7 +59,7 @@ postgres-types = { version = "0.2.5", path = "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } rand = "0.8.5" -whoami = "1.4.1" +whoami = "1.4" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] socket2 = { version = "0.5", features = ["all"] } diff --git a/tokio-postgres/src/bind.rs b/tokio-postgres/src/bind.rs index 9c5c49218..dac1a3c06 100644 --- a/tokio-postgres/src/bind.rs +++ b/tokio-postgres/src/bind.rs @@ -31,7 +31,7 @@ where match responses.next().await? { Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } Ok(Portal::new(client, name, statement)) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 427a05049..fae26a313 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -104,7 +104,7 @@ impl InnerClient { } pub fn typeinfo(&self) -> Option { - self.cached_typeinfo.lock().typeinfo.clone() + None } pub fn set_typeinfo(&self, statement: &Statement) { @@ -112,7 +112,7 @@ impl InnerClient { } pub fn typeinfo_composite(&self) -> Option { - self.cached_typeinfo.lock().typeinfo_composite.clone() + None } pub fn set_typeinfo_composite(&self, statement: &Statement) { @@ -120,15 +120,15 @@ impl InnerClient { } pub fn typeinfo_enum(&self) -> Option { - self.cached_typeinfo.lock().typeinfo_enum.clone() + None } pub fn set_typeinfo_enum(&self, statement: &Statement) { self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone()); } - pub fn type_(&self, oid: Oid) -> Option { - self.cached_typeinfo.lock().types.get(&oid).cloned() + pub fn type_(&self, _: Oid) -> Option { + None } pub fn set_type(&self, oid: Oid, type_: &Type) { @@ -231,7 +231,11 @@ impl Client { query: &str, parameter_types: &[Type], ) -> Result { - prepare::prepare(&self.inner, query, parameter_types).await + prepare::prepare(&self.inner, query, parameter_types, false).await + } + + pub(crate) async fn prepare_unnamed(&self, query: &str) -> Result { + prepare::prepare(&self.inner, query, &[], true).await } /// Executes a statement, returning a vector of the resulting rows. @@ -368,6 +372,23 @@ impl Client { query::query(&self.inner, statement, params).await } + /// Pass text directly to the Postgres backend to allow it to sort out typing itself and + /// to save a roundtrip + pub async fn query_raw_txt<'a, T, S, I>( + &self, + statement: &T, + params: I, + ) -> Result + where + T: ?Sized + ToStatement, + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + { + let statement = statement.__convert().into_statement(self).await?; + query::query_txt(&self.inner, statement, params).await + } + /// Executes a statement, returning the number of rows modified. /// /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list diff --git a/tokio-postgres/src/codec.rs b/tokio-postgres/src/codec.rs index 9d078044b..23c371542 100644 --- a/tokio-postgres/src/codec.rs +++ b/tokio-postgres/src/codec.rs @@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages { } } -pub struct PostgresCodec; +pub struct PostgresCodec { + pub max_message_size: Option, +} impl Encoder for PostgresCodec { type Error = io::Error; @@ -64,6 +66,15 @@ impl Decoder for PostgresCodec { break; } + if let Some(max) = self.max_message_size { + if len > max { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "message too large", + )); + } + } + match header.tag() { backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index b178eac80..2547469ec 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -207,6 +207,7 @@ pub struct Config { pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, pub(crate) load_balance_hosts: LoadBalanceHosts, + pub(crate) max_backend_message_size: Option, } impl Default for Config { @@ -240,6 +241,7 @@ impl Config { target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, load_balance_hosts: LoadBalanceHosts::Disable, + max_backend_message_size: None, } } @@ -520,6 +522,17 @@ impl Config { self.load_balance_hosts } + /// Set limit for backend messages size. + pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config { + self.max_backend_message_size = Some(max_backend_message_size); + self + } + + /// Get limit for backend messages size. + pub fn get_max_backend_message_size(&self) -> Option { + self.max_backend_message_size + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -655,6 +668,14 @@ impl Config { }; self.load_balance_hosts(load_balance_hosts); } + "max_backend_message_size" => { + let limit = value.parse::().map_err(|_| { + Error::config_parse(Box::new(InvalidValue("max_backend_message_size"))) + })?; + if limit > 0 { + self.max_backend_message_size(limit); + } + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index ca57b9cdd..e697e5bc6 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -195,7 +195,7 @@ where } } Some(_) => {} - None => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), } } } diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 19be9eb01..b468c5f32 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -92,7 +92,12 @@ where let stream = connect_tls(stream, config.ssl_mode, tls, has_hostname).await?; let mut stream = StartupStream { - inner: Framed::new(stream, PostgresCodec), + inner: Framed::new( + stream, + PostgresCodec { + max_message_size: config.max_backend_message_size, + }, + ), buf: BackendMessages::empty(), delayed: VecDeque::new(), }; @@ -190,14 +195,14 @@ where )) } Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), None => return Err(Error::closed()), } match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationOk) => Ok(()), Some(Message::ErrorResponse(body)) => Err(Error::db(body)), - Some(_) => Err(Error::unexpected_message()), + Some(m) => Err(Error::unexpected_message(m)), None => Err(Error::closed()), } } @@ -291,7 +296,7 @@ where let body = match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationSaslContinue(body)) => body, Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), None => return Err(Error::closed()), }; @@ -309,7 +314,7 @@ where let body = match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationSaslFinal(body)) => body, Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), None => return Err(Error::closed()), }; @@ -348,7 +353,7 @@ where } Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)), Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), None => return Err(Error::closed()), } } diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index 414335955..652038cc0 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -139,7 +139,8 @@ where Some(response) => response, None => match messages.next().map_err(Error::parse)? { Some(Message::ErrorResponse(error)) => return Err(Error::db(error)), - _ => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), + None => return Err(Error::closed()), }, }; diff --git a/tokio-postgres/src/copy_in.rs b/tokio-postgres/src/copy_in.rs index 59e31fea6..f997e9433 100644 --- a/tokio-postgres/src/copy_in.rs +++ b/tokio-postgres/src/copy_in.rs @@ -114,7 +114,7 @@ where let rows = extract_row_affected(&body)?; return Poll::Ready(Ok(rows)); } - _ => return Poll::Ready(Err(Error::unexpected_message())), + m => return Poll::Ready(Err(Error::unexpected_message(m))), } } } @@ -206,13 +206,19 @@ where .map_err(|_| Error::closed())?; match responses.next().await? { + Message::ParseComplete => { + match responses.next().await? { + Message::BindComplete => {} + m => return Err(Error::unexpected_message(m)), + } + } Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } match responses.next().await? { Message::CopyInResponse(_) => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } Ok(CopyInSink { diff --git a/tokio-postgres/src/copy_out.rs b/tokio-postgres/src/copy_out.rs index 1e6949252..4141bee92 100644 --- a/tokio-postgres/src/copy_out.rs +++ b/tokio-postgres/src/copy_out.rs @@ -26,13 +26,17 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result { let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; match responses.next().await? { + Message::ParseComplete => match responses.next().await? { + Message::BindComplete => {} + m => return Err(Error::unexpected_message(m)), + }, Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } match responses.next().await? { Message::CopyOutResponse(_) => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } Ok(responses) @@ -56,7 +60,7 @@ impl Stream for CopyOutStream { match ready!(this.responses.poll_next(cx)?) { Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))), Message::CopyDone => Poll::Ready(None), - _ => Poll::Ready(Some(Err(Error::unexpected_message()))), + m => Poll::Ready(Some(Err(Error::unexpected_message(m)))), } } } diff --git a/tokio-postgres/src/error/mod.rs b/tokio-postgres/src/error/mod.rs index f1e2644c6..764f77f9c 100644 --- a/tokio-postgres/src/error/mod.rs +++ b/tokio-postgres/src/error/mod.rs @@ -1,7 +1,7 @@ //! Errors. use fallible_iterator::FallibleIterator; -use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody}; +use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody, Message}; use std::error::{self, Error as _Error}; use std::fmt; use std::io; @@ -339,7 +339,7 @@ pub enum ErrorPosition { #[derive(Debug, PartialEq)] enum Kind { Io, - UnexpectedMessage, + UnexpectedMessage(Message), Tls, ToSql(usize), FromSql(usize), @@ -379,7 +379,9 @@ impl fmt::Display for Error { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.0.kind { Kind::Io => fmt.write_str("error communicating with the server")?, - Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?, + Kind::UnexpectedMessage(msg) => { + write!(fmt, "unexpected message from server: {:?}", msg)? + } Kind::Tls => fmt.write_str("error performing TLS handshake")?, Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?, Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?, @@ -445,8 +447,8 @@ impl Error { Error::new(Kind::Closed, None) } - pub(crate) fn unexpected_message() -> Error { - Error::new(Kind::UnexpectedMessage, None) + pub(crate) fn unexpected_message(message: Message) -> Error { + Error::new(Kind::UnexpectedMessage(message), None) } #[allow(clippy::needless_pass_by_value)] diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index 50cff9712..a4ee4808b 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -56,6 +56,18 @@ pub trait GenericClient: private::Sealed { I: IntoIterator + Sync + Send, I::IntoIter: ExactSizeIterator; + /// Like `Client::query_raw_txt`. + async fn query_raw_txt<'a, T, S, I>( + &self, + statement: &T, + params: I, + ) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send; + /// Like `Client::prepare`. async fn prepare(&self, query: &str) -> Result; @@ -136,6 +148,16 @@ impl GenericClient for Client { self.query_raw(statement, params).await } + async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(statement, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } @@ -222,6 +244,16 @@ impl GenericClient for Transaction<'_> { self.query_raw(statement, params).await } + async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(statement, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index e3f09a7c2..9895aa0d4 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -62,25 +62,30 @@ pub async fn prepare( client: &Arc, query: &str, types: &[Type], + unnamed: bool, ) -> Result { - let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)); + let name = if unnamed { + String::new() + } else { + format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)) + }; let buf = encode(client, &name, query, types)?; let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; match responses.next().await? { Message::ParseComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } let parameter_description = match responses.next().await? { Message::ParameterDescription(body) => body, - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), }; let row_description = match responses.next().await? { Message::RowDescription(body) => Some(body), Message::NoData => None, - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), }; let mut parameters = vec![]; @@ -95,12 +100,16 @@ pub async fn prepare( let mut it = row_description.fields(); while let Some(field) = it.next().map_err(Error::parse)? { let type_ = get_type(client, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_); + let column = Column::new(field.name().to_string(), type_, field); columns.push(column); } } - Ok(Statement::new(client, name, parameters, columns)) + if unnamed { + Ok(Statement::unnamed(query.to_owned(), parameters, columns)) + } else { + Ok(Statement::named(client, name, parameters, columns)) + } } fn prepare_rec<'a>( @@ -108,7 +117,7 @@ fn prepare_rec<'a>( query: &'a str, types: &'a [Type], ) -> Pin> + 'a + Send>> { - Box::pin(prepare(client, query, types)) + Box::pin(prepare(client, query, types, false)) } fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result { @@ -126,7 +135,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu }) } -async fn get_type(client: &Arc, oid: Oid) -> Result { +pub async fn get_type(client: &Arc, oid: Oid) -> Result { if let Some(type_) = Type::from_oid(oid) { return Ok(type_); } @@ -142,7 +151,7 @@ async fn get_type(client: &Arc, oid: Oid) -> Result { let row = match rows.try_next().await? { Some(row) => row, - None => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), }; let name: String = row.try_get(0)?; diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index e6e1d00a8..8b7e048e8 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -3,15 +3,17 @@ use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::{BorrowToSql, IsNull}; use crate::{Error, Portal, Row, Statement}; -use bytes::{Bytes, BytesMut}; +use bytes::{BufMut, Bytes, BytesMut}; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; use pin_project_lite::pin_project; use postgres_protocol::message::backend::{CommandCompleteBody, Message}; use postgres_protocol::message::frontend; +use postgres_types::Format; use std::fmt; use std::marker::PhantomPinned; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; struct BorrowToSqlParamsDebug<'a, T>(&'a [T]); @@ -53,10 +55,68 @@ where statement, responses, rows_affected: None, + command_tag: None, + status: None, + output_format: Format::Binary, _p: PhantomPinned, }) } +pub async fn query_txt( + client: &Arc, + statement: Statement, + params: I, +) -> Result +where + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, +{ + let params = params.into_iter(); + + let buf = client.with_buf(|buf| { + // Bind, pass params as text, retrieve as binary + match frontend::bind( + "", // empty string selects the unnamed portal + statement.name(), // named prepared statement + std::iter::empty(), // all parameters use the default format (text) + params, + |param, buf| match param { + Some(param) => { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol::IsNull::No) + } + None => Ok(postgres_protocol::IsNull::Yes), + }, + Some(0), // all text + buf, + ) { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + }?; + + // Execute + frontend::execute("", 0, buf).map_err(Error::encode)?; + // Sync + frontend::sync(buf); + + Ok(buf.split().freeze()) + })?; + + // now read the responses + let responses = start(client, buf).await?; + Ok(RowStream { + statement, + responses, + command_tag: None, + status: None, + output_format: Format::Text, + _p: PhantomPinned, + rows_affected: None, + }) +} + pub async fn query_portal( client: &InnerClient, portal: &Portal, @@ -74,6 +134,9 @@ pub async fn query_portal( statement: portal.statement().clone(), responses, rows_affected: None, + command_tag: None, + status: None, + output_format: Format::Binary, _p: PhantomPinned, }) } @@ -123,7 +186,7 @@ where } Message::EmptyQueryResponse => rows = 0, Message::ReadyForQuery(_) => return Ok(rows), - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } } } @@ -132,8 +195,12 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result { let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; match responses.next().await? { + Message::ParseComplete => match responses.next().await? { + Message::BindComplete => {} + m => return Err(Error::unexpected_message(m)), + }, Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } Ok(responses) @@ -146,6 +213,9 @@ where I::IntoIter: ExactSizeIterator, { client.with_buf(|buf| { + if let Some(query) = statement.query() { + frontend::parse("", query, [], buf).unwrap(); + } encode_bind(statement, params, "", buf)?; frontend::execute("", 0, buf).map_err(Error::encode)?; frontend::sync(buf); @@ -208,6 +278,9 @@ pin_project! { statement: Statement, responses: Responses, rows_affected: Option, + command_tag: Option, + output_format: Format, + status: Option, #[pin] _p: PhantomPinned, } @@ -221,14 +294,25 @@ impl Stream for RowStream { loop { match ready!(this.responses.poll_next(cx)?) { Message::DataRow(body) => { - return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?))) + return Poll::Ready(Some(Ok(Row::new( + this.statement.clone(), + body, + *this.output_format, + )?))) } Message::CommandComplete(body) => { *this.rows_affected = Some(extract_row_affected(&body)?); + + if let Ok(tag) = body.tag() { + *this.command_tag = Some(tag.to_string()); + } } Message::EmptyQueryResponse | Message::PortalSuspended => {} - Message::ReadyForQuery(_) => return Poll::Ready(None), - _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + Message::ReadyForQuery(status) => { + *this.status = Some(status.status()); + return Poll::Ready(None); + } + m => return Poll::Ready(Some(Err(Error::unexpected_message(m)))), } } } @@ -241,4 +325,18 @@ impl RowStream { pub fn rows_affected(&self) -> Option { self.rows_affected } + + /// Returns the command tag of this query. + /// + /// This is only available after the stream has been exhausted. + pub fn command_tag(&self) -> Option { + self.command_tag.clone() + } + + /// Returns if the connection is ready for querying, with the status of the connection. + /// + /// This might be available only after the stream has been exhausted. + pub fn ready_status(&self) -> Option { + self.status + } } diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index db179b432..754b5f28c 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -7,6 +7,7 @@ use crate::types::{FromSql, Type, WrongType}; use crate::{Error, Statement}; use fallible_iterator::FallibleIterator; use postgres_protocol::message::backend::DataRowBody; +use postgres_types::{Format, WrongFormat}; use std::fmt; use std::ops::Range; use std::str; @@ -97,6 +98,7 @@ where /// A row of data returned from the database by a query. pub struct Row { statement: Statement, + output_format: Format, body: DataRowBody, ranges: Vec>>, } @@ -110,12 +112,17 @@ impl fmt::Debug for Row { } impl Row { - pub(crate) fn new(statement: Statement, body: DataRowBody) -> Result { + pub(crate) fn new( + statement: Statement, + body: DataRowBody, + output_format: Format, + ) -> Result { let ranges = body.ranges().collect().map_err(Error::parse)?; Ok(Row { statement, body, ranges, + output_format, }) } @@ -187,6 +194,27 @@ impl Row { let range = self.ranges[idx].to_owned()?; Some(&self.body.buffer()[range]) } + + /// Interpret the column at the given index as text + /// + /// Useful when using query_raw_txt() which sets text transfer mode + pub fn as_text(&self, idx: usize) -> Result, Error> { + if self.output_format == Format::Text { + match self.col_buffer(idx) { + Some(raw) => { + FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx)) + } + None => Ok(None), + } + } else { + Err(Error::from_sql(Box::new(WrongFormat {}), idx)) + } + } + + /// Row byte size + pub fn body_len(&self) -> usize { + self.body.buffer().len() + } } impl AsName for SimpleColumn { diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index bcc6d928b..9838b0809 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -58,7 +58,7 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro | Message::EmptyQueryResponse | Message::RowDescription(_) | Message::DataRow(_) => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } } } @@ -107,12 +107,12 @@ impl Stream for SimpleQueryStream { Message::DataRow(body) => { let row = match &this.columns { Some(columns) => SimpleQueryRow::new(columns.clone(), body)?, - None => return Poll::Ready(Some(Err(Error::unexpected_message()))), + None => return Poll::Ready(Some(Err(Error::closed()))), }; return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row)))); } Message::ReadyForQuery(_) => return Poll::Ready(None), - _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + m => return Poll::Ready(Some(Err(Error::unexpected_message(m)))), } } } diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 97561a8e4..920bd74da 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -2,28 +2,40 @@ use crate::client::InnerClient; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::Type; -use postgres_protocol::message::frontend; +use postgres_protocol::{ + message::{backend::Field, frontend}, + Oid, +}; use std::{ fmt, sync::{Arc, Weak}, }; -struct StatementInner { - client: Weak, - name: String, - params: Vec, - columns: Vec, +enum StatementInner { + Unnamed { + query: String, + params: Vec, + columns: Vec, + }, + Named { + client: Weak, + name: String, + params: Vec, + columns: Vec, + }, } impl Drop for StatementInner { fn drop(&mut self) { - if let Some(client) = self.client.upgrade() { - let buf = client.with_buf(|buf| { - frontend::close(b'S', &self.name, buf).unwrap(); - frontend::sync(buf); - buf.split().freeze() - }); - let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf))); + if let StatementInner::Named { client, name, .. } = self { + if let Some(client) = client.upgrade() { + let buf = client.with_buf(|buf| { + frontend::close(b'S', name, buf).unwrap(); + frontend::sync(buf); + buf.split().freeze() + }); + let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } } } } @@ -35,13 +47,13 @@ impl Drop for StatementInner { pub struct Statement(Arc); impl Statement { - pub(crate) fn new( + pub(crate) fn named( inner: &Arc, name: String, params: Vec, columns: Vec, ) -> Statement { - Statement(Arc::new(StatementInner { + Statement(Arc::new(StatementInner::Named { client: Arc::downgrade(inner), name, params, @@ -49,18 +61,42 @@ impl Statement { })) } + pub(crate) fn unnamed(query: String, params: Vec, columns: Vec) -> Self { + Statement(Arc::new(StatementInner::Unnamed { + query, + params, + columns, + })) + } + pub(crate) fn name(&self) -> &str { - &self.0.name + match &*self.0 { + StatementInner::Unnamed { .. } => "", + StatementInner::Named { name, .. } => name, + } + } + + pub(crate) fn query(&self) -> Option<&str> { + match &*self.0 { + StatementInner::Unnamed { query, .. } => Some(query), + StatementInner::Named { .. } => None, + } } /// Returns the expected types of the statement's parameters. pub fn params(&self) -> &[Type] { - &self.0.params + match &*self.0 { + StatementInner::Unnamed { params, .. } => params, + StatementInner::Named { params, .. } => params, + } } /// Returns information about the columns returned when the statement is queried. pub fn columns(&self) -> &[Column] { - &self.0.columns + match &*self.0 { + StatementInner::Unnamed { columns, .. } => columns, + StatementInner::Named { columns, .. } => columns, + } } } @@ -68,11 +104,30 @@ impl Statement { pub struct Column { name: String, type_: Type, + + // raw fields from RowDescription + table_oid: Oid, + column_id: i16, + format: i16, + + // that better be stored in self.type_, but that is more radical refactoring + type_oid: Oid, + type_size: i16, + type_modifier: i32, } impl Column { - pub(crate) fn new(name: String, type_: Type) -> Column { - Column { name, type_ } + pub(crate) fn new(name: String, type_: Type, raw_field: Field<'_>) -> Column { + Column { + name, + type_, + table_oid: raw_field.table_oid(), + column_id: raw_field.column_id(), + format: raw_field.format(), + type_oid: raw_field.type_oid(), + type_size: raw_field.type_size(), + type_modifier: raw_field.type_modifier(), + } } /// Returns the name of the column. @@ -84,6 +139,36 @@ impl Column { pub fn type_(&self) -> &Type { &self.type_ } + + /// Returns the table OID of the column. + pub fn table_oid(&self) -> Oid { + self.table_oid + } + + /// Returns the column ID of the column. + pub fn column_id(&self) -> i16 { + self.column_id + } + + /// Returns the format of the column. + pub fn format(&self) -> i16 { + self.format + } + + /// Returns the type OID of the column. + pub fn type_oid(&self) -> Oid { + self.type_oid + } + + /// Returns the type size of the column. + pub fn type_size(&self) -> i16 { + self.type_size + } + + /// Returns the type modifier of the column. + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } } impl fmt::Debug for Column { diff --git a/tokio-postgres/src/to_statement.rs b/tokio-postgres/src/to_statement.rs index 427f77dd7..ef1e65272 100644 --- a/tokio-postgres/src/to_statement.rs +++ b/tokio-postgres/src/to_statement.rs @@ -15,7 +15,7 @@ mod private { pub async fn into_statement(self, client: &Client) -> Result { match self { ToStatementType::Statement(s) => Ok(s.clone()), - ToStatementType::Query(s) => client.prepare(s).await, + ToStatementType::Query(s) => client.prepare_unnamed(s).await, } } } diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 96a324652..ca386974e 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -149,6 +149,17 @@ impl<'a> Transaction<'a> { self.client.query_raw(statement, params).await } + /// Like `Client::query_raw_txt`. + pub async fn query_raw_txt(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement, + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + { + self.client.query_raw_txt(statement, params).await + } + /// Like `Client::execute`. pub async fn execute( &self, diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 0ab4a7bab..565984271 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -249,6 +249,161 @@ async fn custom_array() { } } +#[tokio::test] +async fn query_raw_txt() { + let client = connect("user=postgres").await; + + let rows: Vec = client + .query_raw_txt("SELECT 55 * $1", [Some("42")]) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + let res: i32 = rows[0].as_text(0).unwrap().unwrap().parse::().unwrap(); + assert_eq!(res, 55 * 42); + + let rows: Vec = client + .query_raw_txt("SELECT $1", [Some("42")]) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, &str>(0), "42"); + assert!(rows[0].body_len() > 0); +} + +#[tokio::test] +async fn query_raw_txt_nulls() { + let client = connect("user=postgres").await; + + let rows: Vec = client + .query_raw_txt( + "SELECT $1 as str, $2 as n, 'null' as str2, null as n2", + [Some("null"), None], + ) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + + let res = rows[0].as_text(0).unwrap(); + assert_eq!(res, Some("null")); + + let res = rows[0].as_text(1).unwrap(); + assert_eq!(res, None); + + let res = rows[0].as_text(2).unwrap(); + assert_eq!(res, Some("null")); + + let res = rows[0].as_text(3).unwrap(); + assert_eq!(res, None); +} + +#[tokio::test] +async fn limit_max_backend_message_size() { + let client = connect("user=postgres max_backend_message_size=10000").await; + let small: Vec = client + .query_raw_txt("SELECT REPEAT('a', 20)", [] as [Option<&str>; 0]) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(small.len(), 1); + assert_eq!(small[0].as_text(0).unwrap().unwrap().len(), 20); + + let large: Result, Error> = client + .query_raw_txt("SELECT REPEAT('a', 2000000)", [] as [Option<&str>; 0]) + .await + .unwrap() + .try_collect() + .await; + + assert!(large.is_err()); +} + +#[tokio::test] +async fn command_tag() { + let client = connect("user=postgres").await; + + let row_stream = client + .query_raw_txt("select unnest('{1,2,3}'::int[]);", [] as [Option<&str>; 0]) + .await + .unwrap(); + + pin_mut!(row_stream); + + let mut rows: Vec = Vec::new(); + while let Some(row) = row_stream.next().await { + rows.push(row.unwrap()); + } + + assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string())); +} + +#[tokio::test] +async fn ready_for_query() { + let client = connect("user=postgres").await; + + let row_stream = client + .query_raw_txt("START TRANSACTION", [] as [Option<&str>; 0]) + .await + .unwrap(); + + pin_mut!(row_stream); + while row_stream.next().await.is_none() {} + + assert_eq!(row_stream.ready_status(), Some(b'T')); + + let row_stream = client + .query_raw_txt("ROLLBACK", [] as [Option<&str>; 0]) + .await + .unwrap(); + + pin_mut!(row_stream); + while row_stream.next().await.is_none() {} + + assert_eq!(row_stream.ready_status(), Some(b'I')); +} + +#[tokio::test] +async fn column_extras() { + let client = connect("user=postgres").await; + + let rows: Vec = client + .query_raw_txt( + "select relacl, relname from pg_class limit 1", + [] as [Option<&str>; 0], + ) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let column = rows[0].columns().get(1).unwrap(); + assert_eq!(column.name(), "relname"); + assert_eq!(column.type_(), &Type::NAME); + + assert!(column.table_oid() > 0); + assert_eq!(column.column_id(), 2); + assert_eq!(column.format(), 0); + + assert_eq!(column.type_oid(), 19); + assert_eq!(column.type_size(), 64); + assert_eq!(column.type_modifier(), -1); +} + #[tokio::test] async fn custom_composite() { let client = connect("user=postgres").await;