Skip to content
This repository was archived by the owner on Jan 14, 2025. It is now read-only.
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c5b1d75

Browse files
kelvichpetuhovskiy
authored and
Julius de Bruijn
committedOct 13, 2023
Add text protocol based query method (sfackler#14)
Add query_raw_txt client method It takes all the extended protocol params as text and passes them to postgres to sort out types. With that we can avoid situations when postgres derived different type compared to what was passed in arguments. There is also propare_typed method, but since we receive data in text format anyway it makes more sense to avoid dealing with types in params. This way we also can save on roundtrip and send Parse+Bind+Describe+Execute right away without waiting for params description before Bind. Use text protocol for responses -- that allows to grab postgres-provided serializations for types. Catch command tag. Expose row buffer size and add `max_backend_message_size` option to prevent handling and storing in memory large messages from the backend. Co-authored-by: Arthur Petukhovsky <[email protected]>
1 parent c5ff8cf commit c5b1d75

File tree

11 files changed

+289
-6
lines changed

11 files changed

+289
-6
lines changed
 

‎.github/workflows/ci.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ jobs:
5353
steps:
5454
- uses: actions/checkout@v3
5555
- uses: sfackler/actions/rustup@master
56-
- run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT
56+
with:
57+
version: 1.65.0
58+
- run: echo "::set-output name=version::$(rustc --version)"
5759
id: rust-version
5860
- run: rustup target add wasm32-unknown-unknown
5961
- uses: actions/cache@v3

‎postgres-types/src/lib.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,22 @@ impl WrongType {
442442
}
443443
}
444444

445+
/// An error indicating that a as_text conversion was attempted on a binary
446+
/// result.
447+
#[derive(Debug)]
448+
pub struct WrongFormat {}
449+
450+
impl Error for WrongFormat {}
451+
452+
impl fmt::Display for WrongFormat {
453+
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
454+
write!(
455+
fmt,
456+
"cannot read column as text while it is in binary format"
457+
)
458+
}
459+
}
460+
445461
/// A trait for types that can be created from a Postgres value.
446462
///
447463
/// # Types
@@ -893,7 +909,7 @@ pub trait ToSql: fmt::Debug {
893909
/// Supported Postgres message format types
894910
///
895911
/// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8`
896-
#[derive(Clone, Copy, Debug)]
912+
#[derive(Clone, Copy, Debug, PartialEq)]
897913
pub enum Format {
898914
/// Text format (UTF-8)
899915
Text,

‎tokio-postgres/src/client.rs

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ use crate::connection::{Request, RequestMessages};
44
use crate::copy_out::CopyOutStream;
55
#[cfg(feature = "runtime")]
66
use crate::keepalive::KeepaliveConfig;
7+
use crate::prepare::get_type;
78
use crate::query::RowStream;
89
use crate::simple_query::SimpleQueryStream;
10+
use crate::statement::Column;
911
#[cfg(feature = "runtime")]
1012
use crate::tls::MakeTlsConnect;
1113
use crate::tls::TlsConnect;
@@ -16,7 +18,7 @@ use crate::{
1618
copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error,
1719
Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
1820
};
19-
use bytes::{Buf, BytesMut};
21+
use bytes::{Buf, BufMut, BytesMut};
2022
use fallible_iterator::FallibleIterator;
2123
use futures_channel::mpsc;
2224
use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt};
@@ -368,6 +370,87 @@ impl Client {
368370
query::query(&self.inner, statement, params).await
369371
}
370372

373+
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
374+
/// to save a roundtrip
375+
pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result<RowStream, Error>
376+
where
377+
S: AsRef<str>,
378+
I: IntoIterator<Item = S>,
379+
I::IntoIter: ExactSizeIterator,
380+
{
381+
let params = params.into_iter();
382+
let params_len = params.len();
383+
384+
let buf = self.inner.with_buf(|buf| {
385+
// Parse, anonymous portal
386+
frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?;
387+
// Bind, pass params as text, retrieve as binary
388+
match frontend::bind(
389+
"", // empty string selects the unnamed portal
390+
"", // empty string selects the unnamed prepared statement
391+
std::iter::empty(), // all parameters use the default format (text)
392+
params,
393+
|param, buf| {
394+
buf.put_slice(param.as_ref().as_bytes());
395+
Ok(postgres_protocol::IsNull::No)
396+
},
397+
Some(0), // all text
398+
buf,
399+
) {
400+
Ok(()) => Ok(()),
401+
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
402+
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
403+
}?;
404+
405+
// Describe portal to typecast results
406+
frontend::describe(b'P', "", buf).map_err(Error::encode)?;
407+
// Execute
408+
frontend::execute("", 0, buf).map_err(Error::encode)?;
409+
// Sync
410+
frontend::sync(buf);
411+
412+
Ok(buf.split().freeze())
413+
})?;
414+
415+
let mut responses = self
416+
.inner
417+
.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
418+
419+
// now read the responses
420+
421+
match responses.next().await? {
422+
Message::ParseComplete => {}
423+
_ => return Err(Error::unexpected_message()),
424+
}
425+
match responses.next().await? {
426+
Message::BindComplete => {}
427+
_ => return Err(Error::unexpected_message()),
428+
}
429+
let row_description = match responses.next().await? {
430+
Message::RowDescription(body) => Some(body),
431+
Message::NoData => None,
432+
_ => return Err(Error::unexpected_message()),
433+
};
434+
435+
// construct statement object
436+
437+
let parameters = vec![Type::UNKNOWN; params_len];
438+
439+
let mut columns = vec![];
440+
if let Some(row_description) = row_description {
441+
let mut it = row_description.fields();
442+
while let Some(field) = it.next().map_err(Error::parse)? {
443+
let type_ = get_type(&self.inner, field.type_oid()).await?;
444+
let column = Column::new(field.name().to_string(), type_);
445+
columns.push(column);
446+
}
447+
}
448+
449+
let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns);
450+
451+
Ok(RowStream::new(statement, responses))
452+
}
453+
371454
/// Executes a statement, returning the number of rows modified.
372455
///
373456
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list

‎tokio-postgres/src/codec.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages {
3535
}
3636
}
3737

38-
pub struct PostgresCodec;
38+
pub struct PostgresCodec {
39+
pub max_message_size: Option<usize>,
40+
}
3941

4042
impl Encoder<FrontendMessage> for PostgresCodec {
4143
type Error = io::Error;
@@ -64,6 +66,15 @@ impl Decoder for PostgresCodec {
6466
break;
6567
}
6668

69+
if let Some(max) = self.max_message_size {
70+
if len > max {
71+
return Err(io::Error::new(
72+
io::ErrorKind::InvalidInput,
73+
"message too large",
74+
));
75+
}
76+
}
77+
6778
match header.tag() {
6879
backend::NOTICE_RESPONSE_TAG
6980
| backend::NOTIFICATION_RESPONSE_TAG

‎tokio-postgres/src/config.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ pub struct Config {
207207
pub(crate) target_session_attrs: TargetSessionAttrs,
208208
pub(crate) channel_binding: ChannelBinding,
209209
pub(crate) load_balance_hosts: LoadBalanceHosts,
210+
pub(crate) replication_mode: Option<ReplicationMode>,
211+
pub(crate) max_backend_message_size: Option<usize>,
210212
}
211213

212214
impl Default for Config {
@@ -240,6 +242,8 @@ impl Config {
240242
target_session_attrs: TargetSessionAttrs::Any,
241243
channel_binding: ChannelBinding::Prefer,
242244
load_balance_hosts: LoadBalanceHosts::Disable,
245+
replication_mode: None,
246+
max_backend_message_size: None,
243247
}
244248
}
245249

@@ -520,6 +524,17 @@ impl Config {
520524
self.load_balance_hosts
521525
}
522526

527+
/// Set limit for backend messages size.
528+
pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config {
529+
self.max_backend_message_size = Some(max_backend_message_size);
530+
self
531+
}
532+
533+
/// Get limit for backend messages size.
534+
pub fn get_max_backend_message_size(&self) -> Option<usize> {
535+
self.max_backend_message_size
536+
}
537+
523538
fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
524539
match key {
525540
"user" => {
@@ -655,6 +670,14 @@ impl Config {
655670
};
656671
self.load_balance_hosts(load_balance_hosts);
657672
}
673+
"max_backend_message_size" => {
674+
let limit = value.parse::<usize>().map_err(|_| {
675+
Error::config_parse(Box::new(InvalidValue("max_backend_message_size")))
676+
})?;
677+
if limit > 0 {
678+
self.max_backend_message_size(limit);
679+
}
680+
}
658681
key => {
659682
return Err(Error::config_parse(Box::new(UnknownOption(
660683
key.to_string(),

‎tokio-postgres/src/connect_raw.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,12 @@ where
9292
let stream = connect_tls(stream, config.ssl_mode, tls, has_hostname).await?;
9393

9494
let mut stream = StartupStream {
95-
inner: Framed::new(stream, PostgresCodec),
95+
inner: Framed::new(
96+
stream,
97+
PostgresCodec {
98+
max_message_size: config.max_backend_message_size,
99+
},
100+
),
96101
buf: BackendMessages::empty(),
97102
delayed: VecDeque::new(),
98103
};

‎tokio-postgres/src/prepare.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
126126
})
127127
}
128128

129-
async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
129+
pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
130130
if let Some(type_) = Type::from_oid(oid) {
131131
return Ok(type_);
132132
}

‎tokio-postgres/src/query.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ where
5353
statement,
5454
responses,
5555
rows_affected: None,
56+
command_tag: None,
5657
_p: PhantomPinned,
5758
})
5859
}
@@ -74,6 +75,7 @@ pub async fn query_portal(
7475
statement: portal.statement().clone(),
7576
responses,
7677
rows_affected: None,
78+
command_tag: None,
7779
_p: PhantomPinned,
7880
})
7981
}
@@ -208,11 +210,24 @@ pin_project! {
208210
statement: Statement,
209211
responses: Responses,
210212
rows_affected: Option<u64>,
213+
command_tag: Option<String>,
211214
#[pin]
212215
_p: PhantomPinned,
213216
}
214217
}
215218

219+
impl RowStream {
220+
/// Creates a new `RowStream`.
221+
pub fn new(statement: Statement, responses: Responses) -> Self {
222+
RowStream {
223+
statement,
224+
responses,
225+
command_tag: None,
226+
_p: PhantomPinned,
227+
}
228+
}
229+
}
230+
216231
impl Stream for RowStream {
217232
type Item = Result<Row, Error>;
218233

@@ -225,6 +240,10 @@ impl Stream for RowStream {
225240
}
226241
Message::CommandComplete(body) => {
227242
*this.rows_affected = Some(extract_row_affected(&body)?);
243+
244+
if let Ok(tag) = body.tag() {
245+
*this.command_tag = Some(tag.to_string());
246+
}
228247
}
229248
Message::EmptyQueryResponse | Message::PortalSuspended => {}
230249
Message::ReadyForQuery(_) => return Poll::Ready(None),
@@ -241,4 +260,11 @@ impl RowStream {
241260
pub fn rows_affected(&self) -> Option<u64> {
242261
self.rows_affected
243262
}
263+
264+
/// Returns the command tag of this query.
265+
///
266+
/// This is only available after the stream has been exhausted.
267+
pub fn command_tag(&self) -> Option<String> {
268+
self.command_tag.clone()
269+
}
244270
}

‎tokio-postgres/src/row.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::types::{FromSql, Type, WrongType};
77
use crate::{Error, Statement};
88
use fallible_iterator::FallibleIterator;
99
use postgres_protocol::message::backend::DataRowBody;
10+
use postgres_types::{Format, WrongFormat};
1011
use std::fmt;
1112
use std::ops::Range;
1213
use std::str;
@@ -187,6 +188,27 @@ impl Row {
187188
let range = self.ranges[idx].to_owned()?;
188189
Some(&self.body.buffer()[range])
189190
}
191+
192+
/// Interpret the column at the given index as text
193+
///
194+
/// Useful when using query_raw_txt() which sets text transfer mode
195+
pub fn as_text(&self, idx: usize) -> Result<Option<&str>, Error> {
196+
if self.statement.output_format() == Format::Text {
197+
match self.col_buffer(idx) {
198+
Some(raw) => {
199+
FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx))
200+
}
201+
None => Ok(None),
202+
}
203+
} else {
204+
Err(Error::from_sql(Box::new(WrongFormat {}), idx))
205+
}
206+
}
207+
208+
/// Row byte size
209+
pub fn body_len(&self) -> usize {
210+
self.body.buffer().len()
211+
}
190212
}
191213

192214
impl AsName for SimpleColumn {

‎tokio-postgres/src/statement.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::codec::FrontendMessage;
33
use crate::connection::RequestMessages;
44
use crate::types::Type;
55
use postgres_protocol::message::frontend;
6+
use postgres_types::Format;
67
use std::{
78
fmt,
89
sync::{Arc, Weak},
@@ -13,6 +14,7 @@ struct StatementInner {
1314
name: String,
1415
params: Vec<Type>,
1516
columns: Vec<Column>,
17+
output_format: Format,
1618
}
1719

1820
impl Drop for StatementInner {
@@ -46,6 +48,22 @@ impl Statement {
4648
name,
4749
params,
4850
columns,
51+
output_format: Format::Binary,
52+
}))
53+
}
54+
55+
pub(crate) fn new_text(
56+
inner: &Arc<InnerClient>,
57+
name: String,
58+
params: Vec<Type>,
59+
columns: Vec<Column>,
60+
) -> Statement {
61+
Statement(Arc::new(StatementInner {
62+
client: Arc::downgrade(inner),
63+
name,
64+
params,
65+
columns,
66+
output_format: Format::Text,
4967
}))
5068
}
5169

@@ -62,6 +80,11 @@ impl Statement {
6280
pub fn columns(&self) -> &[Column] {
6381
&self.0.columns
6482
}
83+
84+
/// Returns output format for the statement.
85+
pub fn output_format(&self) -> Format {
86+
self.0.output_format
87+
}
6588
}
6689

6790
/// Information about a column of a query.

‎tokio-postgres/tests/test/main.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,78 @@ async fn custom_array() {
249249
}
250250
}
251251

252+
#[tokio::test]
253+
async fn query_raw_txt() {
254+
let client = connect("user=postgres").await;
255+
256+
let rows: Vec<tokio_postgres::Row> = client
257+
.query_raw_txt("SELECT 55 * $1", ["42"])
258+
.await
259+
.unwrap()
260+
.try_collect()
261+
.await
262+
.unwrap();
263+
264+
assert_eq!(rows.len(), 1);
265+
let res: i32 = rows[0].as_text(0).unwrap().unwrap().parse::<i32>().unwrap();
266+
assert_eq!(res, 55 * 42);
267+
268+
let rows: Vec<tokio_postgres::Row> = client
269+
.query_raw_txt("SELECT $1", ["42"])
270+
.await
271+
.unwrap()
272+
.try_collect()
273+
.await
274+
.unwrap();
275+
276+
assert_eq!(rows.len(), 1);
277+
assert_eq!(rows[0].get::<_, &str>(0), "42");
278+
assert!(rows[0].body_len() > 0);
279+
}
280+
281+
#[tokio::test]
282+
async fn limit_max_backend_message_size() {
283+
let client = connect("user=postgres max_backend_message_size=10000").await;
284+
let small: Vec<tokio_postgres::Row> = client
285+
.query_raw_txt("SELECT REPEAT('a', 20)", [])
286+
.await
287+
.unwrap()
288+
.try_collect()
289+
.await
290+
.unwrap();
291+
292+
assert_eq!(small.len(), 1);
293+
assert_eq!(small[0].as_text(0).unwrap().unwrap().len(), 20);
294+
295+
let large: Result<Vec<tokio_postgres::Row>, Error> = client
296+
.query_raw_txt("SELECT REPEAT('a', 2000000)", [])
297+
.await
298+
.unwrap()
299+
.try_collect()
300+
.await;
301+
302+
assert!(large.is_err());
303+
}
304+
305+
#[tokio::test]
306+
async fn command_tag() {
307+
let client = connect("user=postgres").await;
308+
309+
let row_stream = client
310+
.query_raw_txt("select unnest('{1,2,3}'::int[]);", [])
311+
.await
312+
.unwrap();
313+
314+
pin_mut!(row_stream);
315+
316+
let mut rows: Vec<tokio_postgres::Row> = Vec::new();
317+
while let Some(row) = row_stream.next().await {
318+
rows.push(row.unwrap());
319+
}
320+
321+
assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string()));
322+
}
323+
252324
#[tokio::test]
253325
async fn custom_composite() {
254326
let client = connect("user=postgres").await;

0 commit comments

Comments
 (0)
This repository has been archived.