Skip to content
This repository was archived by the owner on Jan 14, 2025. It is now read-only.

Commit a773102

Browse files
devsnekJulius de Bruijn
authored and
Julius de Bruijn
committed
support unnamed statements
1 parent c5c8c9f commit a773102

File tree

15 files changed

+140
-53
lines changed

15 files changed

+140
-53
lines changed

postgres-protocol/src/message/backend.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ impl Header {
7272
}
7373

7474
/// An enum representing Postgres backend messages.
75+
#[derive(Debug, PartialEq)]
7576
#[non_exhaustive]
7677
pub enum Message {
7778
AuthenticationCleartextPassword,
@@ -333,6 +334,7 @@ impl Read for Buffer {
333334
}
334335
}
335336

337+
#[derive(Debug, PartialEq)]
336338
pub struct AuthenticationMd5PasswordBody {
337339
salt: [u8; 4],
338340
}
@@ -344,6 +346,7 @@ impl AuthenticationMd5PasswordBody {
344346
}
345347
}
346348

349+
#[derive(Debug, PartialEq)]
347350
pub struct AuthenticationGssContinueBody(Bytes);
348351

349352
impl AuthenticationGssContinueBody {
@@ -353,6 +356,7 @@ impl AuthenticationGssContinueBody {
353356
}
354357
}
355358

359+
#[derive(Debug, PartialEq)]
356360
pub struct AuthenticationSaslBody(Bytes);
357361

358362
impl AuthenticationSaslBody {
@@ -362,6 +366,7 @@ impl AuthenticationSaslBody {
362366
}
363367
}
364368

369+
#[derive(Debug, PartialEq)]
365370
pub struct SaslMechanisms<'a>(&'a [u8]);
366371

367372
impl<'a> FallibleIterator for SaslMechanisms<'a> {
@@ -387,6 +392,7 @@ impl<'a> FallibleIterator for SaslMechanisms<'a> {
387392
}
388393
}
389394

395+
#[derive(Debug, PartialEq)]
390396
pub struct AuthenticationSaslContinueBody(Bytes);
391397

392398
impl AuthenticationSaslContinueBody {
@@ -396,6 +402,7 @@ impl AuthenticationSaslContinueBody {
396402
}
397403
}
398404

405+
#[derive(Debug, PartialEq)]
399406
pub struct AuthenticationSaslFinalBody(Bytes);
400407

401408
impl AuthenticationSaslFinalBody {
@@ -405,6 +412,7 @@ impl AuthenticationSaslFinalBody {
405412
}
406413
}
407414

415+
#[derive(Debug, PartialEq)]
408416
pub struct BackendKeyDataBody {
409417
process_id: i32,
410418
secret_key: i32,
@@ -422,6 +430,7 @@ impl BackendKeyDataBody {
422430
}
423431
}
424432

433+
#[derive(Debug, PartialEq)]
425434
pub struct CommandCompleteBody {
426435
tag: Bytes,
427436
}
@@ -433,6 +442,7 @@ impl CommandCompleteBody {
433442
}
434443
}
435444

445+
#[derive(Debug, PartialEq)]
436446
pub struct CopyDataBody {
437447
storage: Bytes,
438448
}
@@ -449,6 +459,7 @@ impl CopyDataBody {
449459
}
450460
}
451461

462+
#[derive(Debug, PartialEq)]
452463
pub struct CopyInResponseBody {
453464
format: u8,
454465
len: u16,
@@ -470,6 +481,7 @@ impl CopyInResponseBody {
470481
}
471482
}
472483

484+
#[derive(Debug, PartialEq)]
473485
pub struct ColumnFormats<'a> {
474486
buf: &'a [u8],
475487
remaining: u16,
@@ -503,6 +515,7 @@ impl<'a> FallibleIterator for ColumnFormats<'a> {
503515
}
504516
}
505517

518+
#[derive(Debug, PartialEq)]
506519
pub struct CopyOutResponseBody {
507520
format: u8,
508521
len: u16,
@@ -524,7 +537,7 @@ impl CopyOutResponseBody {
524537
}
525538
}
526539

527-
#[derive(Debug)]
540+
#[derive(Debug, PartialEq)]
528541
pub struct DataRowBody {
529542
storage: Bytes,
530543
len: u16,
@@ -599,6 +612,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> {
599612
}
600613
}
601614

615+
#[derive(Debug, PartialEq)]
602616
pub struct ErrorResponseBody {
603617
storage: Bytes,
604618
}
@@ -657,6 +671,7 @@ impl<'a> ErrorField<'a> {
657671
}
658672
}
659673

674+
#[derive(Debug, PartialEq)]
660675
pub struct NoticeResponseBody {
661676
storage: Bytes,
662677
}
@@ -668,6 +683,7 @@ impl NoticeResponseBody {
668683
}
669684
}
670685

686+
#[derive(Debug, PartialEq)]
671687
pub struct NotificationResponseBody {
672688
process_id: i32,
673689
channel: Bytes,
@@ -691,6 +707,7 @@ impl NotificationResponseBody {
691707
}
692708
}
693709

710+
#[derive(Debug, PartialEq)]
694711
pub struct ParameterDescriptionBody {
695712
storage: Bytes,
696713
len: u16,
@@ -706,6 +723,7 @@ impl ParameterDescriptionBody {
706723
}
707724
}
708725

726+
#[derive(Debug, PartialEq)]
709727
pub struct Parameters<'a> {
710728
buf: &'a [u8],
711729
remaining: u16,
@@ -739,6 +757,7 @@ impl<'a> FallibleIterator for Parameters<'a> {
739757
}
740758
}
741759

760+
#[derive(Debug, PartialEq)]
742761
pub struct ParameterStatusBody {
743762
name: Bytes,
744763
value: Bytes,
@@ -756,6 +775,7 @@ impl ParameterStatusBody {
756775
}
757776
}
758777

778+
#[derive(Debug, PartialEq)]
759779
pub struct ReadyForQueryBody {
760780
status: u8,
761781
}
@@ -767,6 +787,7 @@ impl ReadyForQueryBody {
767787
}
768788
}
769789

790+
#[derive(Debug, PartialEq)]
770791
pub struct RowDescriptionBody {
771792
storage: Bytes,
772793
len: u16,

tokio-postgres/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ postgres-types = { version = "0.2.5", path = "../postgres-types" }
5959
tokio = { version = "1.27", features = ["io-util"] }
6060
tokio-util = { version = "0.7", features = ["codec"] }
6161
rand = "0.8.5"
62-
whoami = "1.4.1"
62+
whoami = "1.4"
6363

6464
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
6565
socket2 = { version = "0.5", features = ["all"] }

tokio-postgres/src/bind.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ where
3131

3232
match responses.next().await? {
3333
Message::BindComplete => {}
34-
_ => return Err(Error::unexpected_message()),
34+
m => return Err(Error::unexpected_message(m)),
3535
}
3636

3737
Ok(Portal::new(client, name, statement))

tokio-postgres/src/client.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,11 @@ impl Client {
231231
query: &str,
232232
parameter_types: &[Type],
233233
) -> Result<Statement, Error> {
234-
prepare::prepare(&self.inner, query, parameter_types).await
234+
prepare::prepare(&self.inner, query, parameter_types, false).await
235+
}
236+
237+
pub(crate) async fn prepare_unnamed(&self, query: &str) -> Result<Statement, Error> {
238+
prepare::prepare(&self.inner, query, &[], true).await
235239
}
236240

237241
/// Executes a statement, returning a vector of the resulting rows.

tokio-postgres/src/connect.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ where
195195
}
196196
}
197197
Some(_) => {}
198-
None => return Err(Error::unexpected_message()),
198+
None => return Err(Error::closed()),
199199
}
200200
}
201201
}

tokio-postgres/src/connect_raw.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,14 @@ where
195195
))
196196
}
197197
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
198-
Some(_) => return Err(Error::unexpected_message()),
198+
Some(m) => return Err(Error::unexpected_message(m)),
199199
None => return Err(Error::closed()),
200200
}
201201

202202
match stream.try_next().await.map_err(Error::io)? {
203203
Some(Message::AuthenticationOk) => Ok(()),
204204
Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
205-
Some(_) => Err(Error::unexpected_message()),
205+
Some(m) => Err(Error::unexpected_message(m)),
206206
None => Err(Error::closed()),
207207
}
208208
}
@@ -296,7 +296,7 @@ where
296296
let body = match stream.try_next().await.map_err(Error::io)? {
297297
Some(Message::AuthenticationSaslContinue(body)) => body,
298298
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
299-
Some(_) => return Err(Error::unexpected_message()),
299+
Some(m) => return Err(Error::unexpected_message(m)),
300300
None => return Err(Error::closed()),
301301
};
302302

@@ -314,7 +314,7 @@ where
314314
let body = match stream.try_next().await.map_err(Error::io)? {
315315
Some(Message::AuthenticationSaslFinal(body)) => body,
316316
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
317-
Some(_) => return Err(Error::unexpected_message()),
317+
Some(m) => return Err(Error::unexpected_message(m)),
318318
None => return Err(Error::closed()),
319319
};
320320

@@ -353,7 +353,7 @@ where
353353
}
354354
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
355355
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
356-
Some(_) => return Err(Error::unexpected_message()),
356+
Some(m) => return Err(Error::unexpected_message(m)),
357357
None => return Err(Error::closed()),
358358
}
359359
}

tokio-postgres/src/connection.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ where
139139
Some(response) => response,
140140
None => match messages.next().map_err(Error::parse)? {
141141
Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
142-
_ => return Err(Error::unexpected_message()),
142+
Some(m) => return Err(Error::unexpected_message(m)),
143+
None => return Err(Error::closed()),
143144
},
144145
};
145146

tokio-postgres/src/copy_in.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ where
114114
let rows = extract_row_affected(&body)?;
115115
return Poll::Ready(Ok(rows));
116116
}
117-
_ => return Poll::Ready(Err(Error::unexpected_message())),
117+
m => return Poll::Ready(Err(Error::unexpected_message(m))),
118118
}
119119
}
120120
}
@@ -206,13 +206,19 @@ where
206206
.map_err(|_| Error::closed())?;
207207

208208
match responses.next().await? {
209+
Message::ParseComplete => {
210+
match responses.next().await? {
211+
Message::BindComplete => {}
212+
m => return Err(Error::unexpected_message(m)),
213+
}
214+
}
209215
Message::BindComplete => {}
210-
_ => return Err(Error::unexpected_message()),
216+
m => return Err(Error::unexpected_message(m)),
211217
}
212218

213219
match responses.next().await? {
214220
Message::CopyInResponse(_) => {}
215-
_ => return Err(Error::unexpected_message()),
221+
m => return Err(Error::unexpected_message(m)),
216222
}
217223

218224
Ok(CopyInSink {

tokio-postgres/src/copy_out.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,17 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
2626
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
2727

2828
match responses.next().await? {
29+
Message::ParseComplete => match responses.next().await? {
30+
Message::BindComplete => {}
31+
m => return Err(Error::unexpected_message(m)),
32+
},
2933
Message::BindComplete => {}
30-
_ => return Err(Error::unexpected_message()),
34+
m => return Err(Error::unexpected_message(m)),
3135
}
3236

3337
match responses.next().await? {
3438
Message::CopyOutResponse(_) => {}
35-
_ => return Err(Error::unexpected_message()),
39+
m => return Err(Error::unexpected_message(m)),
3640
}
3741

3842
Ok(responses)
@@ -56,7 +60,7 @@ impl Stream for CopyOutStream {
5660
match ready!(this.responses.poll_next(cx)?) {
5761
Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))),
5862
Message::CopyDone => Poll::Ready(None),
59-
_ => Poll::Ready(Some(Err(Error::unexpected_message()))),
63+
m => Poll::Ready(Some(Err(Error::unexpected_message(m)))),
6064
}
6165
}
6266
}

tokio-postgres/src/error/mod.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Errors.
22
33
use fallible_iterator::FallibleIterator;
4-
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
4+
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody, Message};
55
use std::error::{self, Error as _Error};
66
use std::fmt;
77
use std::io;
@@ -339,7 +339,7 @@ pub enum ErrorPosition {
339339
#[derive(Debug, PartialEq)]
340340
enum Kind {
341341
Io,
342-
UnexpectedMessage,
342+
UnexpectedMessage(Message),
343343
Tls,
344344
ToSql(usize),
345345
FromSql(usize),
@@ -379,7 +379,9 @@ impl fmt::Display for Error {
379379
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
380380
match &self.0.kind {
381381
Kind::Io => fmt.write_str("error communicating with the server")?,
382-
Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
382+
Kind::UnexpectedMessage(msg) => {
383+
write!(fmt, "unexpected message from server: {:?}", msg)?
384+
}
383385
Kind::Tls => fmt.write_str("error performing TLS handshake")?,
384386
Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?,
385387
Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
@@ -445,8 +447,8 @@ impl Error {
445447
Error::new(Kind::Closed, None)
446448
}
447449

448-
pub(crate) fn unexpected_message() -> Error {
449-
Error::new(Kind::UnexpectedMessage, None)
450+
pub(crate) fn unexpected_message(message: Message) -> Error {
451+
Error::new(Kind::UnexpectedMessage(message), None)
450452
}
451453

452454
#[allow(clippy::needless_pass_by_value)]

0 commit comments

Comments
 (0)