Skip to content
This repository was archived by the owner on Jan 14, 2025. It is now read-only.

Commit f4aa454

Browse files
kelvichpetuhovskiy
authored and
Jakub Wieczorek
committed
Add text protocol based query method (sfackler#14)
Add query_raw_txt client method It takes all the extended protocol params as text and passes them to postgres to sort out types. With that we can avoid situations when postgres derived different type compared to what was passed in arguments. There is also propare_typed method, but since we receive data in text format anyway it makes more sense to avoid dealing with types in params. This way we also can save on roundtrip and send Parse+Bind+Describe+Execute right away without waiting for params description before Bind. Use text protocol for responses -- that allows to grab postgres-provided serializations for types. Catch command tag. Expose row buffer size and add `max_backend_message_size` option to prevent handling and storing in memory large messages from the backend. Co-authored-by: Arthur Petukhovsky <[email protected]>
1 parent d6c2835 commit f4aa454

File tree

11 files changed

+293
-7
lines changed

11 files changed

+293
-7
lines changed

.github/workflows/ci.yml

+3-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ jobs:
5353
steps:
5454
- uses: actions/checkout@v3
5555
- uses: sfackler/actions/rustup@master
56-
- run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT
56+
with:
57+
version: 1.65.0
58+
- run: echo "::set-output name=version::$(rustc --version)"
5759
id: rust-version
5860
- run: rustup target add wasm32-unknown-unknown
5961
- uses: actions/cache@v3

postgres-types/src/lib.rs

+17-1
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,22 @@ impl WrongType {
449449
}
450450
}
451451

452+
/// An error indicating that a as_text conversion was attempted on a binary
453+
/// result.
454+
#[derive(Debug)]
455+
pub struct WrongFormat {}
456+
457+
impl Error for WrongFormat {}
458+
459+
impl fmt::Display for WrongFormat {
460+
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
461+
write!(
462+
fmt,
463+
"cannot read column as text while it is in binary format"
464+
)
465+
}
466+
}
467+
452468
/// A trait for types that can be created from a Postgres value.
453469
///
454470
/// # Types
@@ -900,7 +916,7 @@ pub trait ToSql: fmt::Debug {
900916
/// Supported Postgres message format types
901917
///
902918
/// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8`
903-
#[derive(Clone, Copy, Debug)]
919+
#[derive(Clone, Copy, Debug, PartialEq)]
904920
pub enum Format {
905921
/// Text format (UTF-8)
906922
Text,

tokio-postgres/src/client.rs

+84-1
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ use crate::connection::{Request, RequestMessages};
44
use crate::copy_out::CopyOutStream;
55
#[cfg(feature = "runtime")]
66
use crate::keepalive::KeepaliveConfig;
7+
use crate::prepare::get_type;
78
use crate::query::RowStream;
89
use crate::simple_query::SimpleQueryStream;
10+
use crate::statement::Column;
911
#[cfg(feature = "runtime")]
1012
use crate::tls::MakeTlsConnect;
1113
use crate::tls::TlsConnect;
@@ -16,7 +18,7 @@ use crate::{
1618
copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error,
1719
Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
1820
};
19-
use bytes::{Buf, BytesMut};
21+
use bytes::{Buf, BufMut, BytesMut};
2022
use fallible_iterator::FallibleIterator;
2123
use futures_channel::mpsc;
2224
use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt};
@@ -368,6 +370,87 @@ impl Client {
368370
query::query(&self.inner, statement, params).await
369371
}
370372

373+
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
374+
/// to save a roundtrip
375+
pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result<RowStream, Error>
376+
where
377+
S: AsRef<str>,
378+
I: IntoIterator<Item = S>,
379+
I::IntoIter: ExactSizeIterator,
380+
{
381+
let params = params.into_iter();
382+
let params_len = params.len();
383+
384+
let buf = self.inner.with_buf(|buf| {
385+
// Parse, anonymous portal
386+
frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?;
387+
// Bind, pass params as text, retrieve as binary
388+
match frontend::bind(
389+
"", // empty string selects the unnamed portal
390+
"", // empty string selects the unnamed prepared statement
391+
std::iter::empty(), // all parameters use the default format (text)
392+
params,
393+
|param, buf| {
394+
buf.put_slice(param.as_ref().as_bytes());
395+
Ok(postgres_protocol::IsNull::No)
396+
},
397+
Some(0), // all text
398+
buf,
399+
) {
400+
Ok(()) => Ok(()),
401+
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
402+
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
403+
}?;
404+
405+
// Describe portal to typecast results
406+
frontend::describe(b'P', "", buf).map_err(Error::encode)?;
407+
// Execute
408+
frontend::execute("", 0, buf).map_err(Error::encode)?;
409+
// Sync
410+
frontend::sync(buf);
411+
412+
Ok(buf.split().freeze())
413+
})?;
414+
415+
let mut responses = self
416+
.inner
417+
.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
418+
419+
// now read the responses
420+
421+
match responses.next().await? {
422+
Message::ParseComplete => {}
423+
_ => return Err(Error::unexpected_message()),
424+
}
425+
match responses.next().await? {
426+
Message::BindComplete => {}
427+
_ => return Err(Error::unexpected_message()),
428+
}
429+
let row_description = match responses.next().await? {
430+
Message::RowDescription(body) => Some(body),
431+
Message::NoData => None,
432+
_ => return Err(Error::unexpected_message()),
433+
};
434+
435+
// construct statement object
436+
437+
let parameters = vec![Type::UNKNOWN; params_len];
438+
439+
let mut columns = vec![];
440+
if let Some(row_description) = row_description {
441+
let mut it = row_description.fields();
442+
while let Some(field) = it.next().map_err(Error::parse)? {
443+
let type_ = get_type(&self.inner, field.type_oid()).await?;
444+
let column = Column::new(field.name().to_string(), type_);
445+
columns.push(column);
446+
}
447+
}
448+
449+
let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns);
450+
451+
Ok(RowStream::new(statement, responses))
452+
}
453+
371454
/// Executes a statement, returning the number of rows modified.
372455
///
373456
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list

tokio-postgres/src/codec.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages {
3535
}
3636
}
3737

38-
pub struct PostgresCodec;
38+
pub struct PostgresCodec {
39+
pub max_message_size: Option<usize>,
40+
}
3941

4042
impl Encoder<FrontendMessage> for PostgresCodec {
4143
type Error = io::Error;
@@ -64,6 +66,15 @@ impl Decoder for PostgresCodec {
6466
break;
6567
}
6668

69+
if let Some(max) = self.max_message_size {
70+
if len > max {
71+
return Err(io::Error::new(
72+
io::ErrorKind::InvalidInput,
73+
"message too large",
74+
));
75+
}
76+
}
77+
6778
match header.tag() {
6879
backend::NOTICE_RESPONSE_TAG
6980
| backend::NOTIFICATION_RESPONSE_TAG

tokio-postgres/src/config.rs

+23
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ pub struct Config {
209209
pub(crate) target_session_attrs: TargetSessionAttrs,
210210
pub(crate) channel_binding: ChannelBinding,
211211
pub(crate) load_balance_hosts: LoadBalanceHosts,
212+
pub(crate) replication_mode: Option<ReplicationMode>,
213+
pub(crate) max_backend_message_size: Option<usize>,
212214
}
213215

214216
impl Default for Config {
@@ -242,6 +244,8 @@ impl Config {
242244
target_session_attrs: TargetSessionAttrs::Any,
243245
channel_binding: ChannelBinding::Prefer,
244246
load_balance_hosts: LoadBalanceHosts::Disable,
247+
replication_mode: None,
248+
max_backend_message_size: None,
245249
}
246250
}
247251

@@ -522,6 +526,17 @@ impl Config {
522526
self.load_balance_hosts
523527
}
524528

529+
/// Set limit for backend messages size.
530+
pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config {
531+
self.max_backend_message_size = Some(max_backend_message_size);
532+
self
533+
}
534+
535+
/// Get limit for backend messages size.
536+
pub fn get_max_backend_message_size(&self) -> Option<usize> {
537+
self.max_backend_message_size
538+
}
539+
525540
fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
526541
match key {
527542
"user" => {
@@ -658,6 +673,14 @@ impl Config {
658673
};
659674
self.load_balance_hosts(load_balance_hosts);
660675
}
676+
"max_backend_message_size" => {
677+
let limit = value.parse::<usize>().map_err(|_| {
678+
Error::config_parse(Box::new(InvalidValue("max_backend_message_size")))
679+
})?;
680+
if limit > 0 {
681+
self.max_backend_message_size(limit);
682+
}
683+
}
661684
key => {
662685
return Err(Error::config_parse(Box::new(UnknownOption(
663686
key.to_string(),

tokio-postgres/src/connect_raw.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,12 @@ where
9292
let stream = connect_tls(stream, config.ssl_mode, tls, has_hostname).await?;
9393

9494
let mut stream = StartupStream {
95-
inner: Framed::new(stream, PostgresCodec),
95+
inner: Framed::new(
96+
stream,
97+
PostgresCodec {
98+
max_message_size: config.max_backend_message_size,
99+
},
100+
),
96101
buf: BackendMessages::empty(),
97102
delayed: VecDeque::new(),
98103
};

tokio-postgres/src/prepare.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
131131
})
132132
}
133133

134-
async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
134+
pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
135135
if let Some(type_) = Type::from_oid(oid) {
136136
return Ok(type_);
137137
}

tokio-postgres/src/query.rs

+26
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ where
5353
statement,
5454
responses,
5555
rows_affected: None,
56+
command_tag: None,
5657
_p: PhantomPinned,
5758
})
5859
}
@@ -74,6 +75,7 @@ pub async fn query_portal(
7475
statement: portal.statement().clone(),
7576
responses,
7677
rows_affected: None,
78+
command_tag: None,
7779
_p: PhantomPinned,
7880
})
7981
}
@@ -208,11 +210,24 @@ pin_project! {
208210
statement: Statement,
209211
responses: Responses,
210212
rows_affected: Option<u64>,
213+
command_tag: Option<String>,
211214
#[pin]
212215
_p: PhantomPinned,
213216
}
214217
}
215218

219+
impl RowStream {
220+
/// Creates a new `RowStream`.
221+
pub fn new(statement: Statement, responses: Responses) -> Self {
222+
RowStream {
223+
statement,
224+
responses,
225+
command_tag: None,
226+
_p: PhantomPinned,
227+
}
228+
}
229+
}
230+
216231
impl Stream for RowStream {
217232
type Item = Result<Row, Error>;
218233

@@ -225,6 +240,10 @@ impl Stream for RowStream {
225240
}
226241
Message::CommandComplete(body) => {
227242
*this.rows_affected = Some(extract_row_affected(&body)?);
243+
244+
if let Ok(tag) = body.tag() {
245+
*this.command_tag = Some(tag.to_string());
246+
}
228247
}
229248
Message::EmptyQueryResponse | Message::PortalSuspended => {}
230249
Message::ReadyForQuery(_) => return Poll::Ready(None),
@@ -241,4 +260,11 @@ impl RowStream {
241260
pub fn rows_affected(&self) -> Option<u64> {
242261
self.rows_affected
243262
}
263+
264+
/// Returns the command tag of this query.
265+
///
266+
/// This is only available after the stream has been exhausted.
267+
pub fn command_tag(&self) -> Option<String> {
268+
self.command_tag.clone()
269+
}
244270
}

tokio-postgres/src/row.rs

+22
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::types::{FromSql, Type, WrongType};
77
use crate::{Error, Statement};
88
use fallible_iterator::FallibleIterator;
99
use postgres_protocol::message::backend::DataRowBody;
10+
use postgres_types::{Format, WrongFormat};
1011
use std::fmt;
1112
use std::ops::Range;
1213
use std::str;
@@ -188,6 +189,27 @@ impl Row {
188189
let range = self.ranges[idx].to_owned()?;
189190
Some(&self.body.buffer()[range])
190191
}
192+
193+
/// Interpret the column at the given index as text
194+
///
195+
/// Useful when using query_raw_txt() which sets text transfer mode
196+
pub fn as_text(&self, idx: usize) -> Result<Option<&str>, Error> {
197+
if self.statement.output_format() == Format::Text {
198+
match self.col_buffer(idx) {
199+
Some(raw) => {
200+
FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx))
201+
}
202+
None => Ok(None),
203+
}
204+
} else {
205+
Err(Error::from_sql(Box::new(WrongFormat {}), idx))
206+
}
207+
}
208+
209+
/// Row byte size
210+
pub fn body_len(&self) -> usize {
211+
self.body.buffer().len()
212+
}
191213
}
192214

193215
impl AsName for SimpleColumn {

0 commit comments

Comments
 (0)