Skip to content
This repository was archived by the owner on Jan 14, 2025. It is now read-only.

Remove named statements from tokio-postgres #2

Open
wants to merge 7 commits into
base: grafbase
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -53,7 +53,9 @@ jobs:
steps:
- uses: actions/checkout@v3
- uses: sfackler/actions/rustup@master
- run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT
with:
version: 1.67.0
- run: echo "::set-output name=version::$(rustc --version)"
id: rust-version
- run: rustup target add wasm32-unknown-unknown
- uses: actions/cache@v3
23 changes: 22 additions & 1 deletion postgres-protocol/src/message/backend.rs
Original file line number Diff line number Diff line change
@@ -72,6 +72,7 @@ impl Header {
}

/// An enum representing Postgres backend messages.
#[derive(Debug, PartialEq)]
#[non_exhaustive]
pub enum Message {
AuthenticationCleartextPassword,
@@ -333,6 +334,7 @@ impl Read for Buffer {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationMd5PasswordBody {
salt: [u8; 4],
}
@@ -344,6 +346,7 @@ impl AuthenticationMd5PasswordBody {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationGssContinueBody(Bytes);

impl AuthenticationGssContinueBody {
@@ -353,6 +356,7 @@ impl AuthenticationGssContinueBody {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationSaslBody(Bytes);

impl AuthenticationSaslBody {
@@ -362,6 +366,7 @@ impl AuthenticationSaslBody {
}
}

#[derive(Debug, PartialEq)]
pub struct SaslMechanisms<'a>(&'a [u8]);

impl<'a> FallibleIterator for SaslMechanisms<'a> {
@@ -387,6 +392,7 @@ impl<'a> FallibleIterator for SaslMechanisms<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationSaslContinueBody(Bytes);

impl AuthenticationSaslContinueBody {
@@ -396,6 +402,7 @@ impl AuthenticationSaslContinueBody {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationSaslFinalBody(Bytes);

impl AuthenticationSaslFinalBody {
@@ -405,6 +412,7 @@ impl AuthenticationSaslFinalBody {
}
}

#[derive(Debug, PartialEq)]
pub struct BackendKeyDataBody {
process_id: i32,
secret_key: i32,
@@ -422,6 +430,7 @@ impl BackendKeyDataBody {
}
}

#[derive(Debug, PartialEq)]
pub struct CommandCompleteBody {
tag: Bytes,
}
@@ -433,6 +442,7 @@ impl CommandCompleteBody {
}
}

#[derive(Debug, PartialEq)]
pub struct CopyDataBody {
storage: Bytes,
}
@@ -449,6 +459,7 @@ impl CopyDataBody {
}
}

#[derive(Debug, PartialEq)]
pub struct CopyInResponseBody {
format: u8,
len: u16,
@@ -470,6 +481,7 @@ impl CopyInResponseBody {
}
}

#[derive(Debug, PartialEq)]
pub struct ColumnFormats<'a> {
buf: &'a [u8],
remaining: u16,
@@ -503,6 +515,7 @@ impl<'a> FallibleIterator for ColumnFormats<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct CopyOutResponseBody {
format: u8,
len: u16,
@@ -524,7 +537,7 @@ impl CopyOutResponseBody {
}
}

#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct DataRowBody {
storage: Bytes,
len: u16,
@@ -599,6 +612,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct ErrorResponseBody {
storage: Bytes,
}
@@ -657,6 +671,7 @@ impl<'a> ErrorField<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct NoticeResponseBody {
storage: Bytes,
}
@@ -668,6 +683,7 @@ impl NoticeResponseBody {
}
}

#[derive(Debug, PartialEq)]
pub struct NotificationResponseBody {
process_id: i32,
channel: Bytes,
@@ -691,6 +707,7 @@ impl NotificationResponseBody {
}
}

#[derive(Debug, PartialEq)]
pub struct ParameterDescriptionBody {
storage: Bytes,
len: u16,
@@ -706,6 +723,7 @@ impl ParameterDescriptionBody {
}
}

#[derive(Debug, PartialEq)]
pub struct Parameters<'a> {
buf: &'a [u8],
remaining: u16,
@@ -739,6 +757,7 @@ impl<'a> FallibleIterator for Parameters<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct ParameterStatusBody {
name: Bytes,
value: Bytes,
@@ -756,6 +775,7 @@ impl ParameterStatusBody {
}
}

#[derive(Debug, PartialEq)]
pub struct ReadyForQueryBody {
status: u8,
}
@@ -767,6 +787,7 @@ impl ReadyForQueryBody {
}
}

#[derive(Debug, PartialEq)]
pub struct RowDescriptionBody {
storage: Bytes,
len: u16,
2 changes: 1 addition & 1 deletion postgres-types/src/chrono_04.rs
Original file line number Diff line number Diff line change
@@ -40,7 +40,7 @@ impl ToSql for NaiveDateTime {
impl<'a> FromSql<'a> for DateTime<Utc> {
fn from_sql(type_: &Type, raw: &[u8]) -> Result<DateTime<Utc>, Box<dyn Error + Sync + Send>> {
let naive = NaiveDateTime::from_sql(type_, raw)?;
Ok(DateTime::from_utc(naive, Utc))
Ok(DateTime::from_naive_utc_and_offset(naive, Utc))
}

accepts!(TIMESTAMPTZ);
18 changes: 17 additions & 1 deletion postgres-types/src/lib.rs
Original file line number Diff line number Diff line change
@@ -442,6 +442,22 @@ impl WrongType {
}
}

/// An error indicating that a as_text conversion was attempted on a binary
/// result.
#[derive(Debug)]
pub struct WrongFormat {}

impl Error for WrongFormat {}

impl fmt::Display for WrongFormat {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
fmt,
"cannot read column as text while it is in binary format"
)
}
}

/// A trait for types that can be created from a Postgres value.
///
/// # Types
@@ -893,7 +909,7 @@ pub trait ToSql: fmt::Debug {
/// Supported Postgres message format types
///
/// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8`
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Format {
/// Text format (UTF-8)
Text,
2 changes: 1 addition & 1 deletion tokio-postgres/Cargo.toml
Original file line number Diff line number Diff line change
@@ -59,7 +59,7 @@ postgres-types = { version = "0.2.5", path = "../postgres-types" }
tokio = { version = "1.27", features = ["io-util"] }
tokio-util = { version = "0.7", features = ["codec"] }
rand = "0.8.5"
whoami = "1.4.1"
whoami = "1.4"

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
socket2 = { version = "0.5", features = ["all"] }
2 changes: 1 addition & 1 deletion tokio-postgres/src/bind.rs
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@ where

match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
m => return Err(Error::unexpected_message(m)),
}

Ok(Portal::new(client, name, statement))
21 changes: 21 additions & 0 deletions tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
@@ -234,6 +234,10 @@ impl Client {
prepare::prepare(&self.inner, query, parameter_types).await
}

pub(crate) async fn prepare_unnamed(&self, query: &str) -> Result<Statement, Error> {
prepare::prepare(&self.inner, query, &[]).await
}

/// Executes a statement, returning a vector of the resulting rows.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
@@ -368,6 +372,23 @@ impl Client {
query::query(&self.inner, statement, params).await
}

/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<'a, T, S, I>(
&self,
statement: &T,
params: I,
) -> Result<RowStream, Error>
where
T: ?Sized + ToStatement,
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
let statement = statement.__convert().into_statement(self).await?;
query::query_txt(&self.inner, statement, params).await
}

/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
13 changes: 12 additions & 1 deletion tokio-postgres/src/codec.rs
Original file line number Diff line number Diff line change
@@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages {
}
}

pub struct PostgresCodec;
pub struct PostgresCodec {
pub max_message_size: Option<usize>,
}

impl Encoder<FrontendMessage> for PostgresCodec {
type Error = io::Error;
@@ -64,6 +66,15 @@ impl Decoder for PostgresCodec {
break;
}

if let Some(max) = self.max_message_size {
if len > max {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"message too large",
));
}
}

match header.tag() {
backend::NOTICE_RESPONSE_TAG
| backend::NOTIFICATION_RESPONSE_TAG
21 changes: 21 additions & 0 deletions tokio-postgres/src/config.rs
Original file line number Diff line number Diff line change
@@ -207,6 +207,7 @@ pub struct Config {
pub(crate) target_session_attrs: TargetSessionAttrs,
pub(crate) channel_binding: ChannelBinding,
pub(crate) load_balance_hosts: LoadBalanceHosts,
pub(crate) max_backend_message_size: Option<usize>,
}

impl Default for Config {
@@ -240,6 +241,7 @@ impl Config {
target_session_attrs: TargetSessionAttrs::Any,
channel_binding: ChannelBinding::Prefer,
load_balance_hosts: LoadBalanceHosts::Disable,
max_backend_message_size: None,
}
}

@@ -520,6 +522,17 @@ impl Config {
self.load_balance_hosts
}

/// Set limit for backend messages size.
pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config {
self.max_backend_message_size = Some(max_backend_message_size);
self
}

/// Get limit for backend messages size.
pub fn get_max_backend_message_size(&self) -> Option<usize> {
self.max_backend_message_size
}

fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
match key {
"user" => {
@@ -655,6 +668,14 @@ impl Config {
};
self.load_balance_hosts(load_balance_hosts);
}
"max_backend_message_size" => {
let limit = value.parse::<usize>().map_err(|_| {
Error::config_parse(Box::new(InvalidValue("max_backend_message_size")))
})?;
if limit > 0 {
self.max_backend_message_size(limit);
}
}
key => {
return Err(Error::config_parse(Box::new(UnknownOption(
key.to_string(),
2 changes: 1 addition & 1 deletion tokio-postgres/src/connect.rs
Original file line number Diff line number Diff line change
@@ -195,7 +195,7 @@ where
}
}
Some(_) => {}
None => return Err(Error::unexpected_message()),
None => return Err(Error::closed()),
}
}
}
Loading