Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dev/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM rust:bullseye
FROM rust:1.70-bullseye

# Dependencies
RUN apt-get update -y \
Expand Down
3 changes: 3 additions & 0 deletions pgcat.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ tcp_keepalives_count = 5
# Number of seconds between keepalive packets.
tcp_keepalives_interval = 5

# Handle prepared statements.
prepared_statements = true

# Path to TLS Certificate file to use for TLS connections
# tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections
Expand Down
10 changes: 10 additions & 0 deletions src/admin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,8 @@ where
("bytes_sent", DataType::Numeric),
("bytes_received", DataType::Numeric),
("age_seconds", DataType::Numeric),
("prepare_cache_hit", DataType::Numeric),
("prepare_cache_miss", DataType::Numeric),
];

let new_map = get_server_stats();
Expand All @@ -722,6 +724,14 @@ where
.duration_since(server.connect_time())
.as_secs()
.to_string(),
server
.prepared_hit_count
.load(Ordering::Relaxed)
.to_string(),
server
.prepared_miss_count
.load(Ordering::Relaxed)
.to_string(),
];

res.put(data_row(&row));
Expand Down
215 changes: 209 additions & 6 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ use crate::pool::BanReason;
/// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error, info, trace, warn};
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{atomic::AtomicUsize, Arc};
use std::time::Instant;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
Expand All @@ -13,7 +14,9 @@ use tokio::sync::mpsc::Sender;

use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::auth_passthrough::refetch_auth_hash;
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
use crate::config::{
get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode,
};
use crate::constants::*;
use crate::messages::*;
use crate::plugins::PluginOutput;
Expand All @@ -25,6 +28,11 @@ use crate::tls::Tls;

use tokio_rustls::server::TlsStream;

/// Incrementally count prepared statements
/// to avoid random conflicts in places where the random number generator is weak.
pub static PREPARED_STATEMENT_COUNTER: Lazy<Arc<AtomicUsize>> =
Lazy::new(|| Arc::new(AtomicUsize::new(0)));

/// Type of connection received from client.
enum ClientConnectionType {
Startup,
Expand Down Expand Up @@ -93,6 +101,9 @@ pub struct Client<S, T> {

/// Used to notify clients about an impending shutdown
shutdown: Receiver<()>,

/// Prepared statements
prepared_statements: HashMap<String, Parse>,
}

/// Client entrypoint.
Expand Down Expand Up @@ -682,6 +693,7 @@ where
application_name: application_name.to_string(),
shutdown,
connected_to_server: false,
prepared_statements: HashMap::new(),
})
}

Expand Down Expand Up @@ -716,6 +728,7 @@ where
application_name: String::from("undefined"),
shutdown,
connected_to_server: false,
prepared_statements: HashMap::new(),
})
}

Expand Down Expand Up @@ -757,6 +770,10 @@ where
// Result returned by one of the plugins.
let mut plugin_output = None;

// Prepared statement being executed
let mut prepared_statement = None;
let mut will_prepare = false;

// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
// or issue commands for our sharding and server selection protocol.
Expand All @@ -766,13 +783,16 @@ where
self.transaction_mode
);

// Should we rewrite prepared statements and bind messages?
let mut prepared_statements_enabled = get_prepared_statements();

// Read a complete message from the client, which normally would be
// either a `Q` (query) or `P` (prepare, extended protocol).
// We can parse it here before grabbing a server from the pool,
// in case the client is sending some custom protocol messages, e.g.
// SET SHARDING KEY TO 'bigint';

let message = tokio::select! {
let mut message = tokio::select! {
_ = self.shutdown.recv() => {
if !self.admin {
error_response_terminal(
Expand Down Expand Up @@ -800,7 +820,21 @@ where
// allocate a connection, we wouldn't be able to send back an error message
// to the client so we buffer them and defer the decision to error out or not
// to when we get the S message
'D' | 'E' => {
'D' => {
if prepared_statements_enabled {
let name;
(name, message) = self.rewrite_describe(message).await?;

if let Some(name) = name {
prepared_statement = Some(name);
}
}

self.buffer.put(&message[..]);
continue;
}

'E' => {
self.buffer.put(&message[..]);
continue;
}
Expand Down Expand Up @@ -830,6 +864,11 @@ where
}

'P' => {
if prepared_statements_enabled {
(prepared_statement, message) = self.rewrite_parse(message)?;
will_prepare = true;
}

self.buffer.put(&message[..]);

if query_router.query_parser_enabled() {
Expand All @@ -846,6 +885,10 @@ where
}

'B' => {
if prepared_statements_enabled {
(prepared_statement, message) = self.rewrite_bind(message).await?;
}

self.buffer.put(&message[..]);

if query_router.query_parser_enabled() {
Expand Down Expand Up @@ -1054,7 +1097,48 @@ where
// If the client is in session mode, no more custom protocol
// commands will be accepted.
loop {
let message = match initial_message {
// Only check if we should rewrite prepared statements
// in session mode. In transaction mode, we check at the beginning of
// each transaction.
if !self.transaction_mode {
prepared_statements_enabled = get_prepared_statements();
}

debug!("Prepared statement active: {:?}", prepared_statement);

// We are processing a prepared statement.
if let Some(ref name) = prepared_statement {
debug!("Checking prepared statement is on server");
// Get the prepared statement the server expects to see.
let statement = match self.prepared_statements.get(name) {
Some(statement) => {
debug!("Prepared statement `{}` found in cache", name);
statement
}
None => {
return Err(Error::ClientError(format!(
"prepared statement `{}` not found",
name
)))
}
};

// Since it's already in the buffer, we don't need to prepare it on this server.
if will_prepare {
server.will_prepare(&statement.name);
will_prepare = false;
} else {
// The statement is not prepared on the server, so we need to prepare it.
if server.should_prepare(&statement.name) {
server.prepare(statement).await?;
}
}

// Done processing the prepared statement.
prepared_statement = None;
}

let mut message = match initial_message {
None => {
trace!("Waiting for message inside transaction or in session mode");

Expand Down Expand Up @@ -1173,6 +1257,11 @@ where
// Parse
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
'P' => {
if prepared_statements_enabled {
(prepared_statement, message) = self.rewrite_parse(message)?;
will_prepare = true;
}

if query_router.query_parser_enabled() {
if let Ok(ast) = QueryRouter::parse(&message) {
if let Ok(output) = query_router.execute_plugins(&ast).await {
Expand All @@ -1187,12 +1276,25 @@ where
// Bind
// The placeholder's replacements are here, e.g. '[email protected]' and 'true'
'B' => {
if prepared_statements_enabled {
(prepared_statement, message) = self.rewrite_bind(message).await?;
}

self.buffer.put(&message[..]);
}

// Describe
// Command a client can issue to describe a previously prepared named statement.
'D' => {
if prepared_statements_enabled {
let name;
(name, message) = self.rewrite_describe(message).await?;

if let Some(name) = name {
prepared_statement = Some(name);
}
}

self.buffer.put(&message[..]);
}

Expand Down Expand Up @@ -1235,7 +1337,7 @@ where
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;

// Almost certainly true
if first_message_code == 'P' {
if first_message_code == 'P' && !prepared_statements_enabled {
// Message layout
// P followed by 32 int followed by null-terminated statement name
// So message code should be in offset 0 of the buffer, first character
Expand Down Expand Up @@ -1363,6 +1465,107 @@ where
}
}

/// Rewrite Parse (F) message to set the prepared statement name to one we control.
/// Save it into the client cache.
fn rewrite_parse(&mut self, message: BytesMut) -> Result<(Option<String>, BytesMut), Error> {
let parse: Parse = (&message).try_into()?;

let name = parse.name.clone();

// Don't rewrite anonymous prepared statements
if parse.anonymous() {
debug!("Anonymous prepared statement");
return Ok((None, message));
}

let parse = parse.rename();

debug!(
"Renamed prepared statement `{}` to `{}` and saved to cache",
name, parse.name
);

self.prepared_statements.insert(name.clone(), parse.clone());

Ok((Some(name), parse.try_into()?))
}

/// Rewrite the Bind (F) message to use the prepared statement name
/// saved in the client cache.
async fn rewrite_bind(
&mut self,
message: BytesMut,
) -> Result<(Option<String>, BytesMut), Error> {
let bind: Bind = (&message).try_into()?;
let name = bind.prepared_statement.clone();

if bind.anonymous() {
debug!("Anonymous bind message");
return Ok((None, message));
}

match self.prepared_statements.get(&name) {
Some(prepared_stmt) => {
let bind = bind.reassign(prepared_stmt);

debug!("Rewrote bind `{}` to `{}`", name, bind.prepared_statement);

Ok((Some(name), bind.try_into()?))
}
None => {
debug!("Got bind for unknown prepared statement {:?}", bind);

error_response(
&mut self.write,
&format!(
"prepared statement \"{}\" does not exist",
bind.prepared_statement
),
)
.await?;

Err(Error::ClientError(format!(
"Prepared statement `{}` doesn't exist",
name
)))
}
}
}

/// Rewrite the Describe (F) message to use the prepared statement name
/// saved in the client cache.
async fn rewrite_describe(
&mut self,
message: BytesMut,
) -> Result<(Option<String>, BytesMut), Error> {
let describe: Describe = (&message).try_into()?;
let name = describe.statement_name.clone();

if describe.anonymous() {
debug!("Anonymous describe");
return Ok((None, message));
}

match self.prepared_statements.get(&name) {
Some(prepared_stmt) => {
let describe = describe.rename(&prepared_stmt.name);

debug!(
"Rewrote describe `{}` to `{}`",
name, describe.statement_name
);

Ok((Some(name), describe.try_into()?))
}

None => {
debug!("Got describe for unknown prepared statement {:?}", describe);

Ok((None, message))
}
}
}

/// Release the server from the client: it can't cancel its queries anymore.
pub fn release(&self) {
let mut guard = self.client_server_map.lock();
Expand Down
Loading