Skip to content
This repository was archived by the owner on Jan 14, 2025. It is now read-only.

Do not store and re-use typeinfo statements #3

Open
wants to merge 7 commits into
base: grafbase
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ jobs:
steps:
- uses: actions/checkout@v3
- uses: sfackler/actions/rustup@master
- run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT
with:
version: 1.67.0
- run: echo "::set-output name=version::$(rustc --version)"
id: rust-version
- run: rustup target add wasm32-unknown-unknown
- uses: actions/cache@v3
Expand Down
23 changes: 22 additions & 1 deletion postgres-protocol/src/message/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ impl Header {
}

/// An enum representing Postgres backend messages.
#[derive(Debug, PartialEq)]
#[non_exhaustive]
pub enum Message {
AuthenticationCleartextPassword,
Expand Down Expand Up @@ -333,6 +334,7 @@ impl Read for Buffer {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationMd5PasswordBody {
salt: [u8; 4],
}
Expand All @@ -344,6 +346,7 @@ impl AuthenticationMd5PasswordBody {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationGssContinueBody(Bytes);

impl AuthenticationGssContinueBody {
Expand All @@ -353,6 +356,7 @@ impl AuthenticationGssContinueBody {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationSaslBody(Bytes);

impl AuthenticationSaslBody {
Expand All @@ -362,6 +366,7 @@ impl AuthenticationSaslBody {
}
}

#[derive(Debug, PartialEq)]
pub struct SaslMechanisms<'a>(&'a [u8]);

impl<'a> FallibleIterator for SaslMechanisms<'a> {
Expand All @@ -387,6 +392,7 @@ impl<'a> FallibleIterator for SaslMechanisms<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationSaslContinueBody(Bytes);

impl AuthenticationSaslContinueBody {
Expand All @@ -396,6 +402,7 @@ impl AuthenticationSaslContinueBody {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationSaslFinalBody(Bytes);

impl AuthenticationSaslFinalBody {
Expand All @@ -405,6 +412,7 @@ impl AuthenticationSaslFinalBody {
}
}

#[derive(Debug, PartialEq)]
pub struct BackendKeyDataBody {
process_id: i32,
secret_key: i32,
Expand All @@ -422,6 +430,7 @@ impl BackendKeyDataBody {
}
}

#[derive(Debug, PartialEq)]
pub struct CommandCompleteBody {
tag: Bytes,
}
Expand All @@ -433,6 +442,7 @@ impl CommandCompleteBody {
}
}

#[derive(Debug, PartialEq)]
pub struct CopyDataBody {
storage: Bytes,
}
Expand All @@ -449,6 +459,7 @@ impl CopyDataBody {
}
}

#[derive(Debug, PartialEq)]
pub struct CopyInResponseBody {
format: u8,
len: u16,
Expand All @@ -470,6 +481,7 @@ impl CopyInResponseBody {
}
}

#[derive(Debug, PartialEq)]
pub struct ColumnFormats<'a> {
buf: &'a [u8],
remaining: u16,
Expand Down Expand Up @@ -503,6 +515,7 @@ impl<'a> FallibleIterator for ColumnFormats<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct CopyOutResponseBody {
format: u8,
len: u16,
Expand All @@ -524,7 +537,7 @@ impl CopyOutResponseBody {
}
}

#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct DataRowBody {
storage: Bytes,
len: u16,
Expand Down Expand Up @@ -599,6 +612,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct ErrorResponseBody {
storage: Bytes,
}
Expand Down Expand Up @@ -657,6 +671,7 @@ impl<'a> ErrorField<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct NoticeResponseBody {
storage: Bytes,
}
Expand All @@ -668,6 +683,7 @@ impl NoticeResponseBody {
}
}

#[derive(Debug, PartialEq)]
pub struct NotificationResponseBody {
process_id: i32,
channel: Bytes,
Expand All @@ -691,6 +707,7 @@ impl NotificationResponseBody {
}
}

#[derive(Debug, PartialEq)]
pub struct ParameterDescriptionBody {
storage: Bytes,
len: u16,
Expand All @@ -706,6 +723,7 @@ impl ParameterDescriptionBody {
}
}

#[derive(Debug, PartialEq)]
pub struct Parameters<'a> {
buf: &'a [u8],
remaining: u16,
Expand Down Expand Up @@ -739,6 +757,7 @@ impl<'a> FallibleIterator for Parameters<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct ParameterStatusBody {
name: Bytes,
value: Bytes,
Expand All @@ -756,6 +775,7 @@ impl ParameterStatusBody {
}
}

#[derive(Debug, PartialEq)]
pub struct ReadyForQueryBody {
status: u8,
}
Expand All @@ -767,6 +787,7 @@ impl ReadyForQueryBody {
}
}

#[derive(Debug, PartialEq)]
pub struct RowDescriptionBody {
storage: Bytes,
len: u16,
Expand Down
18 changes: 17 additions & 1 deletion postgres-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,22 @@ impl WrongType {
}
}

/// An error indicating that a as_text conversion was attempted on a binary
/// result.
#[derive(Debug)]
pub struct WrongFormat {}

impl Error for WrongFormat {}

impl fmt::Display for WrongFormat {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
fmt,
"cannot read column as text while it is in binary format"
)
}
}

/// A trait for types that can be created from a Postgres value.
///
/// # Types
Expand Down Expand Up @@ -893,7 +909,7 @@ pub trait ToSql: fmt::Debug {
/// Supported Postgres message format types
///
/// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8`
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Format {
/// Text format (UTF-8)
Text,
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ postgres-types = { version = "0.2.5", path = "../postgres-types" }
tokio = { version = "1.27", features = ["io-util"] }
tokio-util = { version = "0.7", features = ["codec"] }
rand = "0.8.5"
whoami = "1.4.1"
whoami = "1.4"

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
socket2 = { version = "0.5", features = ["all"] }
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/src/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ where

match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
m => return Err(Error::unexpected_message(m)),
}

Ok(Portal::new(client, name, statement))
Expand Down
33 changes: 27 additions & 6 deletions tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,31 +104,31 @@ impl InnerClient {
}

pub fn typeinfo(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo.clone()
None
}

pub fn set_typeinfo(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo = Some(statement.clone());
}

pub fn typeinfo_composite(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo_composite.clone()
None
}

pub fn set_typeinfo_composite(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone());
}

pub fn typeinfo_enum(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo_enum.clone()
None
}

pub fn set_typeinfo_enum(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone());
}

pub fn type_(&self, oid: Oid) -> Option<Type> {
self.cached_typeinfo.lock().types.get(&oid).cloned()
pub fn type_(&self, _: Oid) -> Option<Type> {
None
}

pub fn set_type(&self, oid: Oid, type_: &Type) {
Expand Down Expand Up @@ -231,7 +231,11 @@ impl Client {
query: &str,
parameter_types: &[Type],
) -> Result<Statement, Error> {
prepare::prepare(&self.inner, query, parameter_types).await
prepare::prepare(&self.inner, query, parameter_types, false).await
}

pub(crate) async fn prepare_unnamed(&self, query: &str) -> Result<Statement, Error> {
prepare::prepare(&self.inner, query, &[], true).await
}

/// Executes a statement, returning a vector of the resulting rows.
Expand Down Expand Up @@ -368,6 +372,23 @@ impl Client {
query::query(&self.inner, statement, params).await
}

/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<'a, T, S, I>(
&self,
statement: &T,
params: I,
) -> Result<RowStream, Error>
where
T: ?Sized + ToStatement,
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
let statement = statement.__convert().into_statement(self).await?;
query::query_txt(&self.inner, statement, params).await
}

/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
Expand Down
13 changes: 12 additions & 1 deletion tokio-postgres/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages {
}
}

pub struct PostgresCodec;
pub struct PostgresCodec {
pub max_message_size: Option<usize>,
}

impl Encoder<FrontendMessage> for PostgresCodec {
type Error = io::Error;
Expand Down Expand Up @@ -64,6 +66,15 @@ impl Decoder for PostgresCodec {
break;
}

if let Some(max) = self.max_message_size {
if len > max {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"message too large",
));
}
}

match header.tag() {
backend::NOTICE_RESPONSE_TAG
| backend::NOTIFICATION_RESPONSE_TAG
Expand Down
21 changes: 21 additions & 0 deletions tokio-postgres/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ pub struct Config {
pub(crate) target_session_attrs: TargetSessionAttrs,
pub(crate) channel_binding: ChannelBinding,
pub(crate) load_balance_hosts: LoadBalanceHosts,
pub(crate) max_backend_message_size: Option<usize>,
}

impl Default for Config {
Expand Down Expand Up @@ -240,6 +241,7 @@ impl Config {
target_session_attrs: TargetSessionAttrs::Any,
channel_binding: ChannelBinding::Prefer,
load_balance_hosts: LoadBalanceHosts::Disable,
max_backend_message_size: None,
}
}

Expand Down Expand Up @@ -520,6 +522,17 @@ impl Config {
self.load_balance_hosts
}

/// Set limit for backend messages size.
pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config {
self.max_backend_message_size = Some(max_backend_message_size);
self
}

/// Get limit for backend messages size.
pub fn get_max_backend_message_size(&self) -> Option<usize> {
self.max_backend_message_size
}

fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
match key {
"user" => {
Expand Down Expand Up @@ -655,6 +668,14 @@ impl Config {
};
self.load_balance_hosts(load_balance_hosts);
}
"max_backend_message_size" => {
let limit = value.parse::<usize>().map_err(|_| {
Error::config_parse(Box::new(InvalidValue("max_backend_message_size")))
})?;
if limit > 0 {
self.max_backend_message_size(limit);
}
}
key => {
return Err(Error::config_parse(Box::new(UnknownOption(
key.to_string(),
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ where
}
}
Some(_) => {}
None => return Err(Error::unexpected_message()),
None => return Err(Error::closed()),
}
}
}
Expand Down
Loading