diff --git a/Cargo.lock b/Cargo.lock index f6d392fc1b..6afba201e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2720,8 +2720,10 @@ dependencies = [ "anyhow", "async-std", "dotenvy", + "either", "env_logger", "futures", + "futures-util", "hex", "libsqlite3-sys", "paste", @@ -2738,6 +2740,7 @@ dependencies = [ "tokio", "trybuild", "url", + "uuid", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 41d05245c6..28d2977994 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -149,6 +149,10 @@ hex = "0.4.3" tempdir = "0.3.7" # Needed to test SQLCipher libsqlite3-sys = { version = "*", features = ["bundled-sqlcipher"] } +# Used to test PgExtendedQueryPipeline +uuid = "1" +futures-util = "0.3" +either = "1.6.1" # # Any @@ -275,6 +279,11 @@ name = "postgres-test-attr" path = "tests/postgres/test-attr.rs" required-features = ["postgres", "macros", "migrate"] +[[test]] +name = "postgres-pipeline" +path = "tests/postgres/pipeline.rs" +required-features = ["postgres", "macros", "migrate", "uuid"] + # # Microsoft SQL Server (MSSQL) # diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 6af7fe11cf..0bf90419d9 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -160,7 +160,7 @@ sqlformat = "0.2.0" thiserror = "1.0.30" time = { version = "0.3.2", features = ["macros", "formatting", "parsing"], optional = true } tokio-stream = { version = "0.1.8", features = ["fs"], optional = true } -smallvec = "1.7.0" +smallvec = { version = "1.7.0", features = ["const_generics"] } url = { version = "2.2.2", default-features = false } uuid = { version = "1.0", default-features = false, optional = true, features = ["std"] } webpki-roots = { version = "0.22.0", optional = true } @@ -179,4 +179,4 @@ dotenvy = "0.15" [dev-dependencies] sqlx = { version = "0.6.1", path = "..", features = ["postgres", "sqlite", "mysql"] } -tokio = { version = "1", features = ["rt"] } +tokio = { version = "1", features = ["rt", "macros"] } diff --git a/sqlx-core/src/postgres/connection/describe.rs b/sqlx-core/src/postgres/connection/describe.rs index dda7ada4a8..5927341abb 100644 --- a/sqlx-core/src/postgres/connection/describe.rs +++ b/sqlx-core/src/postgres/connection/describe.rs @@ -92,7 +92,7 @@ impl TryFrom for TypCategory { } impl PgConnection { - pub(super) async fn handle_row_description( + pub(in crate::postgres) async fn handle_row_description( &mut self, desc: Option, should_fetch: bool, diff --git a/sqlx-core/src/postgres/connection/executor.rs b/sqlx-core/src/postgres/connection/executor.rs index 80a7e9e121..a927d17e61 100644 --- a/sqlx-core/src/postgres/connection/executor.rs +++ b/sqlx-core/src/postgres/connection/executor.rs @@ -154,14 +154,14 @@ impl PgConnection { Ok(()) } - pub(crate) fn write_sync(&mut self) { + pub(in crate::postgres) fn write_sync(&mut self) { self.stream.write(message::Sync); // all SYNC messages will return a ReadyForQuery self.pending_ready_for_query_count += 1; } - async fn get_or_prepare<'a>( + pub(in crate::postgres) async fn get_or_prepare<'a>( &mut self, sql: &str, parameters: &[PgTypeInfo], diff --git a/sqlx-core/src/postgres/connection/mod.rs b/sqlx-core/src/postgres/connection/mod.rs index 325b565c3b..8c9df70099 100644 --- a/sqlx-core/src/postgres/connection/mod.rs +++ b/sqlx-core/src/postgres/connection/mod.rs @@ -32,7 +32,7 @@ pub struct PgConnection { // underlying TCP or UDS stream, // wrapped in a potentially TLS stream, // wrapped in a buffered stream - pub(crate) stream: PgStream, + pub(in crate::postgres) stream: PgStream, // process id of this backend // used to send 
cancel requests @@ -56,13 +56,13 @@ pub struct PgConnection { cache_type_oid: HashMap, // number of ReadyForQuery messages that we are currently expecting - pub(crate) pending_ready_for_query_count: usize, + pub(in crate::postgres) pending_ready_for_query_count: usize, // current transaction status transaction_status: TransactionStatus, - pub(crate) transaction_depth: usize, + pub(in crate::postgres) transaction_depth: usize, - log_settings: LogSettings, + pub(in crate::postgres) log_settings: LogSettings, } impl PgConnection { @@ -100,7 +100,10 @@ impl PgConnection { Ok(()) } - fn handle_ready_for_query(&mut self, message: Message) -> Result<(), Error> { + pub(in crate::postgres) fn handle_ready_for_query( + &mut self, + message: Message, + ) -> Result<(), Error> { self.pending_ready_for_query_count -= 1; self.transaction_status = ReadyForQuery::decode(message.contents)?.transaction_status; @@ -110,7 +113,7 @@ impl PgConnection { /// Queue a simple query (not prepared) to execute the next time this connection is used. /// /// Used for rolling back transactions and releasing advisory locks. - pub(crate) fn queue_simple_query(&mut self, query: &str) { + pub(in crate::postgres) fn queue_simple_query(&mut self, query: &str) { self.pending_ready_for_query_count += 1; self.stream.write(Query(query)); } diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index 00abc9c967..15358db023 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -13,6 +13,7 @@ mod io; mod listener; mod message; mod options; +mod pipeline; mod query_result; mod row; mod statement; @@ -37,6 +38,7 @@ pub use error::{PgDatabaseError, PgErrorPosition}; pub use listener::{PgListener, PgNotification}; pub use message::PgSeverity; pub use options::{PgConnectOptions, PgSslMode}; +pub use pipeline::PgExtendedQueryPipeline; pub use query_result::PgQueryResult; pub use row::PgRow; pub use statement::PgStatement; @@ -51,6 +53,8 @@ pub type PgPool = crate::pool::Pool; /// An alias for [`PoolOptions`][crate::pool::PoolOptions], specialized for Postgres. pub type PgPoolOptions = crate::pool::PoolOptions; +/// An alias for [`Query`][crate::query::Query], specialized for Postgres. +pub type PgQuery<'q> = crate::query::Query<'q, Postgres, PgArguments>; /// An alias for [`Executor<'_, Database = Postgres>`][Executor]. pub trait PgExecutor<'c>: Executor<'c, Database = Postgres> {} impl<'c, T: Executor<'c, Database = Postgres>> PgExecutor<'c> for T {} diff --git a/sqlx-core/src/postgres/pipeline.rs b/sqlx-core/src/postgres/pipeline.rs new file mode 100644 index 0000000000..e5128cc92c --- /dev/null +++ b/sqlx-core/src/postgres/pipeline.rs @@ -0,0 +1,415 @@ +use crate::error::Error; +use crate::executor::Execute; +use crate::logger::QueryLogger; +use crate::postgres::message::{self, Bind, CommandComplete, DataRow, MessageFormat}; +use crate::postgres::statement::PgStatementMetadata; +use crate::postgres::types::Oid; +use crate::postgres::{ + PgArguments, PgConnection, PgPool, PgQueryResult, PgRow, PgValueFormat, Postgres, +}; +use either::Either; +use futures_core::Stream; +use futures_util::{stream, StreamExt, TryStreamExt}; +use smallvec::SmallVec; +use std::sync::Arc; + +// tuple that contains the data required to run a query +// +// (sql, arguments, persistent, metadata_options) +type QueryContext<'q> = ( + &'q str, + Option, + bool, + Option>, +); + +/// Pipeline of independent queries. 
+///
+/// A query pipeline issues multiple independent queries via the extended
+/// query protocol in a single batch, writes a single Sync command, and waits for the result sets of all queries.
+///
+/// Pipeline queries run on the same physical database connection.
+///
+/// If there is no explicit transaction, the queries run in an implicit
+/// transaction with the shortest possible duration.
+///
+/// It is assumed that the queries produce result sets small enough to fit
+/// together in the client's memory. The pipeline does not use server-side cursors.
+///
+/// Simple queries are not supported, since the focus is on efficient execution of the
+/// same queries with different parameters,
+/// though technically the implicit transaction could be committed by a single simple
+/// Query instead of the final Sync.
+///
+/// CockroachDB specifics: the transaction may be automatically retried by
+/// the database gateway node on contention with other transactions, as long as it can
+/// buffer all result sets (see
+/// https://www.cockroachlabs.com/docs/stable/transactions.html#automatic-retries).
+///
+/// [PgExtendedQueryPipeline] has a const generic parameter `N` that defines the expected
+/// maximum number of pipeline queries. This number is used for stack
+/// allocations.
+///
+#[cfg_attr(
+    feature = "_rt-tokio",
+    doc = r##"
+# Example usage
+
+```no_run
+use sqlx::postgres::PgExtendedQueryPipeline;
+use sqlx::PgPool;
+use uuid::{uuid, Uuid};
+
+#[tokio::main]
+async fn main() -> sqlx::Result<()> {
+    let pool = PgPool::connect("postgres://user@postgres/db").await?;
+
+    let user_id = uuid!("6592b7c0-b531-4613-ace5-94246b7ce0c3");
+    let post_id = uuid!("252c1d98-a9b0-4f18-8298-e59058bdfe16");
+    let comment_id = uuid!("fbbbb7dc-dc6f-4649-b663-8d3636035164");
+
+    let user_insert_query = sqlx::query(
+        "
+    INSERT INTO \"user\" (user_id, username)
+    VALUES
+    ($1, $2)
+    ",
+    )
+    .bind(user_id)
+    .bind("alice");
+
+    const EXPECTED_QUERIES_IN_PIPELINE: usize = 3;
+    let mut pipeline =
+        PgExtendedQueryPipeline::<EXPECTED_QUERIES_IN_PIPELINE>::from(user_insert_query);
+
+    // query without parameters
+    let post_insert_query = sqlx::query(
+        "
+    INSERT INTO post (post_id, user_id, content)
+    VALUES
+    ('252c1d98-a9b0-4f18-8298-e59058bdfe16', '6592b7c0-b531-4613-ace5-94246b7ce0c3', 'test post')
+    ",
+    );
+
+    pipeline.push(post_insert_query);
+
+    let comment_insert_query = sqlx::query(
+        "
+    INSERT INTO comment (comment_id, post_id, user_id, content)
+    VALUES
+    ($1, $2, $3, $4)
+    ",
+    )
+    .bind(comment_id)
+    .bind(post_id)
+    .bind(user_id)
+    .bind("test comment");
+
+    pipeline.push(comment_insert_query);
+    let _ = pipeline.execute(&pool).await?;
+    Ok(())
+}
+```
+"##
+)]
+/// # Operations
+/// There are two public operations available on pipelines:
+///
+/// * Execute
+/// * Fetch
+///
+/// `Execute` filters out any returned data rows and returns only a vector of
+/// [PgQueryResult] structures.
+/// `Execute` is available as the [PgExtendedQueryPipeline::execute] method and is also implemented as an `execute_pipeline` method
+/// for the following:
+///
+/// * [`&PgPool`](super::PgPool)
+/// * [`&mut PgConnection`](super::connection::PgConnection)
+///
+/// A `Transaction` instance proxies the `execute_pipeline` method to the underlying `PgConnection`.
+///
+/// `Fetch` returns a stream of either [PgQueryResult] or [PgRow] structures.
+/// `Fetch` is implemented as a `fetch_pipeline` method for [`&mut PgConnection`](super::connection::PgConnection).
+///
+/// A `Transaction` instance proxies the `fetch_pipeline` method to the underlying [PgConnection].
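+///
+/// A minimal sketch of consuming the `Fetch` stream (this assumes an already-built
+/// `pipeline` and an acquired pool connection `conn`; the printed output is illustrative only):
+///
+/// ```rust,ignore
+/// use either::Either;
+/// use futures_util::TryStreamExt;
+///
+/// conn.fetch_pipeline(pipeline)
+///     .await?
+///     .try_for_each(|result_or_row| async move {
+///         match result_or_row {
+///             // one PgQueryResult is yielded per completed statement
+///             Either::Left(query_result) => {
+///                 println!("rows affected: {}", query_result.rows_affected());
+///             }
+///             // data rows are yielded by statements that return rows
+///             Either::Right(_row) => println!("received a data row"),
+///         }
+///         Ok(())
+///     })
+///     .await?;
+/// ```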
+///
+
+// public interface section; private section is below
+pub struct PgExtendedQueryPipeline<'q, const N: usize> {
+    queries: SmallVec<[QueryContext<'q>; N]>,
+}
+
+impl<'q, const N: usize> PgExtendedQueryPipeline<'q, N> {
+    pub fn push(&mut self, mut query: impl Execute<'q, Postgres>) {
+        self.queries.push((
+            query.sql(),
+            query.take_arguments(),
+            query.persistent(),
+            query.statement().map(|s| Arc::clone(&s.metadata)),
+        ))
+    }
+
+    pub async fn execute(
+        self: PgExtendedQueryPipeline<'q, N>,
+        pool: &PgPool,
+    ) -> Result<Vec<PgQueryResult>, Error> {
+        pool.execute_pipeline(self).await
+    }
+}
+
+impl<'q, E, const N: usize> From<E> for PgExtendedQueryPipeline<'q, N>
+where
+    E: Execute<'q, Postgres>,
+{
+    /// A query pipeline always contains at least one query.
+    fn from(query: E) -> Self {
+        let mut pipeline = Self {
+            queries: SmallVec::new(),
+        };
+        pipeline.push(query);
+        pipeline
+    }
+}
+
+impl PgPool {
+    pub async fn execute_pipeline<'q, const N: usize>(
+        &self,
+        pipeline: PgExtendedQueryPipeline<'q, N>,
+    ) -> Result<Vec<PgQueryResult>, Error> {
+        let mut conn = self.acquire().await?;
+        conn.execute_pipeline(pipeline).await
+    }
+}
+
+impl PgConnection {
+    pub async fn execute_pipeline<'q, const N: usize>(
+        &mut self,
+        pipeline: PgExtendedQueryPipeline<'q, N>,
+    ) -> Result<Vec<PgQueryResult>, Error> {
+        let pgresults = self
+            .fetch_pipeline(pipeline)
+            .await?
+            .filter_map(|pgresult_or_row_result| async move {
+                match pgresult_or_row_result {
+                    Ok(Either::Left(pgresult)) => Some(Ok(pgresult)),
+                    // filter out data rows
+                    Ok(Either::Right(_)) => None,
+                    Err(e) => Some(Err(e)),
+                }
+            })
+            .try_collect()
+            .await?;
+        Ok(pgresults)
+    }
+
+    pub async fn fetch_pipeline<'e, 'c: 'e, 'q: 'e, const N: usize>(
+        &'c mut self,
+        pipeline: PgExtendedQueryPipeline<'q, N>,
+    ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
+        self.run_pipeline(pipeline).await
+    }
+}
+
+// Private interface section
+
+impl<'q, const N: usize> PgExtendedQueryPipeline<'q, N> {
+    fn len(&self) -> usize {
+        self.queries.len()
+    }
+
+    fn queries(&self) -> &SmallVec<[QueryContext<'q>; N]> {
+        &self.queries
+    }
+
+    fn into_querycontext_stream(self) -> impl Stream<Item = QueryContext<'q>> {
+        stream::iter(self.queries)
+    }
+}
+
+impl PgConnection {
+    async fn get_or_prepare_pipeline<'q, const N: usize>(
+        &mut self,
+        pipeline: PgExtendedQueryPipeline<'q, N>,
+    ) -> Result<SmallVec<[(Oid, PgArguments); N]>, Error> {
+        let prepared_statements = SmallVec::<[(Oid, PgArguments); N]>::new();
+
+        pipeline
+            .into_querycontext_stream()
+            .map(|v| Ok(v))
+            .try_fold(
+                (self, prepared_statements),
+                |(conn, mut prepared), (sql, maybe_arguments, persistent, maybe_metadata)| {
+                    async move {
+                        let mut arguments = maybe_arguments.unwrap_or_default();
+                        // prepare the statement if this is our first time executing it
+                        // always return the statement ID here
+                        let (statement, metadata) = conn
+                            .get_or_prepare(sql, &arguments.types, persistent, maybe_metadata)
+                            .await?;
+
+                        // patch holes created during encoding
+                        arguments.apply_patches(conn, &metadata.parameters).await?;
+
+                        // apply_patches uses fetch_optional, which may produce a `PortalSuspended` message;
+                        // consume messages until `ReadyForQuery` before bind and execute
+                        conn.wait_until_ready().await?;
+                        prepared.push((statement, arguments));
+                        Ok((conn, prepared))
+                    }
+                },
+            )
+            .await
+            .map(|(_, prepared)| prepared)
+    }
+
+    async fn run_pipeline<'e, 'c: 'e, 'q: 'e, const N: usize>(
+        &'c mut self,
+        pipeline: PgExtendedQueryPipeline<'q, N>,
+    ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
+        // the loggers stack is in reversed query order
+        let mut loggers_stack = self.query_loggers_stack(&pipeline);
+
+        // before we continue, 
wait until we are "ready" to accept more queries + self.wait_until_ready().await?; + + let pipeline_length = pipeline.len(); + let prepared_statements = self.get_or_prepare_pipeline(pipeline).await?; + + prepared_statements + .into_iter() + .for_each(|(statement, arguments)| { + // bind to attach the arguments to the statement and create a portal + self.stream.write(Bind { + portal: None, + statement, + formats: &[PgValueFormat::Binary], + num_params: arguments.types.len() as i16, + params: &arguments.buffer, + result_formats: &[PgValueFormat::Binary], + }); + + self.stream.write(message::Execute { + portal: None, + // result set is expected to be small enough to buffer on client side + // don't use server-side cursors + limit: 0, + }); + }); + + // finally, [Sync] asks postgres to process the messages that we sent and respond with + // a [ReadyForQuery] message when it's completely done. + self.write_sync(); + // send all commands in batch + self.stream.flush().await?; + + Ok(try_stream! { + let mut metadata = Arc::new(PgStatementMetadata::default()); + // prepared statements are binary + let format = PgValueFormat::Binary; + + loop { + let message = self.stream.recv().await?; + + match message.format { + MessageFormat::BindComplete + | MessageFormat::ParseComplete + | MessageFormat::ParameterDescription + | MessageFormat::NoData => { + // harmless messages to ignore + } + + MessageFormat::CommandComplete => { + // a SQL command completed normally + let cc: CommandComplete = message.decode()?; + + let rows_affected = cc.rows_affected(); + if let Some(logger) = loggers_stack.last_mut() { + logger.increase_rows_affected(rows_affected); + // drop and finish current logger + loggers_stack.pop(); + } + else { + return Err(err_protocol!( + "execute: received more CommandComplete messages than expected; expected: {}", + pipeline_length + )); + + } + + r#yield!(Either::Left(PgQueryResult { + rows_affected, + })); + } + + MessageFormat::EmptyQueryResponse => { + // empty query string passed to an unprepared execute + } + + MessageFormat::RowDescription => { + // indicates that a *new* set of rows are about to be returned + let (columns, column_names) = self + .handle_row_description(Some(message.decode()?), false) + .await?; + + metadata = Arc::new(PgStatementMetadata { + column_names, + columns, + parameters: Vec::default(), + }); + } + + MessageFormat::DataRow => { + if let Some(logger) = loggers_stack.last_mut() { + logger.increment_rows_returned(); + } + else { + return Err(err_protocol!( + "execute: received a data row after receiving the expected {} CommandComplete messages", + pipeline_length + )); + + } + + // one of the set of rows returned by a SELECT, FETCH, etc query + let data: DataRow = message.decode()?; + let row = PgRow { + data, + format, + metadata: Arc::clone(&metadata), + }; + + r#yield!(Either::Right(row)); + } + + MessageFormat::ReadyForQuery => { + // processing of the query string is complete + self.handle_ready_for_query(message)?; + break; + } + + _ => { + return Err(err_protocol!( + "execute: unexpected message: {:?}", + message.format + )); + } + } + } + + Ok(()) + }) + } + + fn query_loggers_stack<'q, const N: usize>( + &self, + pipeline: &PgExtendedQueryPipeline<'q, N>, + ) -> SmallVec<[QueryLogger<'q>; N]> { + pipeline + .queries() + .iter() + .rev() + .map(|(q_ref, _, _, _)| QueryLogger::new(*q_ref, self.log_settings.clone())) + .collect() + } +} diff --git a/src/macros/test.md b/src/macros/test.md index dbcb7646bc..a34ae9d935 100644 --- a/src/macros/test.md +++ 
b/src/macros/test.md @@ -1,11 +1,11 @@ Mark an `async fn` as a test with SQLx support. -The test will automatically be executed in the async runtime according to the chosen +The test will automatically be executed in the async runtime according to the chosen `runtime-{async-std, tokio}-{native-tls, rustls}` feature. By default, this behaves identically to `#[tokio::test]`1 or `#[async_std::test]`: -```rust,norun +```rust,no_run # // Note if reading these examples directly in `test.md`: # // lines prefixed with `#` are not meant to be shown; # // they are supporting code to help the examples to compile successfully. @@ -13,7 +13,7 @@ By default, this behaves identically to `#[tokio::test]`1 or `#[async #[sqlx::test] async fn test_async_fn() { tokio::task::yield_now().await; -} +} ``` However, several advanced features are also supported as shown in the next section. @@ -36,20 +36,20 @@ This feature is activated by changing the signature of your test function. The f * `PoolConnection`, etc. * `async fn(PoolOptions, impl ConnectOptions) -> Ret` * Where `impl ConnectOptions` is, e.g, `PgConnectOptions`, `MySqlConnectOptions`, etc. - * If your test wants to create its own `Pool` (for example, to set pool callbacks or to modify `ConnectOptions`), + * If your test wants to create its own `Pool` (for example, to set pool callbacks or to modify `ConnectOptions`), you can use this signature. Where `DB` is a supported `Database` type and `Ret` is `()` or `Result<_, _>`. ##### Supported Databases -Most of these will require you to set `DATABASE_URL` as an environment variable +Most of these will require you to set `DATABASE_URL` as an environment variable or in a `.env` file like `sqlx::query!()` _et al_, to give the test driver a superuser connection with which to manage test databases. | Database | Requires `DATABASE_URL` | -| --- | --- | +| --- | --- | | Postgres | Yes | | MySQL | Yes | | SQLite | No2 | @@ -58,12 +58,12 @@ Test databases are automatically cleaned up as tests succeed, but failed tests w to facilitate debugging. Note that to simplify the implementation, panics are _always_ considered to be failures, even for `#[should_panic]` tests. -To limit disk space usage, any previously created test databases will be deleted the next time a test binary using +To limit disk space usage, any previously created test databases will be deleted the next time a test binary using `#[sqlx::test]` is run. ```rust,no_run # #[cfg(all(feature = "migrate", feature = "postgres"))] -# mod example { +# mod example { use sqlx::PgPool; #[sqlx::test] @@ -73,12 +73,12 @@ async fn basic_test(pool: PgPool) -> sqlx::Result<()> { sqlx::query("SELECT * FROM foo") .fetch_one(&mut conn) .await?; - + assert_eq!(foo.get::("bar"), "foobar!"); - + Ok(()) } -# } +# } ``` 2 SQLite defaults to `target/sqlx/test-dbs/.sqlite` where `` is the path of the test function @@ -86,8 +86,8 @@ converted to a filesystem path (`::` replaced with `/`). ### Automatic Migrations (requires `migrate` feature) -To ensure a straightforward test implementation against a fresh test database, migrations are automatically applied if a -`migrations` folder is found in the same directory as `CARGO_MANIFEST_DIR` (the directory where the current crate's +To ensure a straightforward test implementation against a fresh test database, migrations are automatically applied if a +`migrations` folder is found in the same directory as `CARGO_MANIFEST_DIR` (the directory where the current crate's `Cargo.toml` resides). 
You can override the resolved path relative to `CARGO_MANIFEST_DIR` in the attribute (global overrides are not currently @@ -95,7 +95,7 @@ supported): ```rust,ignore # #[cfg(all(feature = "migrate", feature = "postgres"))] -# mod example { +# mod example { use sqlx::PgPool; #[sqlx::test(migrations = "foo_migrations")] @@ -105,9 +105,9 @@ async fn basic_test(pool: PgPool) -> sqlx::Result<()> { sqlx::query("SELECT * FROM foo") .fetch_one(&mut conn) .await?; - + assert_eq!(foo.get::("bar"), "foobar!"); - + Ok(()) } # } @@ -123,17 +123,17 @@ pub static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("foo_migrations"); `foo_crate/tests/foo_test.rs` ```rust,no_run # #[cfg(all(feature = "migrate", feature = "postgres"))] -# mod example { +# mod example { use sqlx::PgPool; # // This is standing in for the main crate since doc examples don't support multiple crates. -# mod foo_crate { +# mod foo_crate { # use std::borrow::Cow; # static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate::Migrator { # migrations: Cow::Borrowed(&[]), # ignore_missing: false, # }; -# } +# } // You could also do `use foo_crate::MIGRATOR` and just refer to it as `MIGRATOR` here. #[sqlx::test(migrator = "foo_crate::MIGRATOR")] @@ -143,9 +143,9 @@ async fn basic_test(pool: PgPool) -> sqlx::Result<()> { sqlx::query("SELECT * FROM foo") .fetch_one(&mut conn) .await?; - + assert_eq!(foo.get::("bar"), "foobar!"); - + Ok(()) } # } @@ -155,21 +155,21 @@ Or disable migrations processing entirely: ```rust,no_run # #[cfg(all(feature = "migrate", feature = "postgres"))] -# mod example { +# mod example { use sqlx::PgPool; #[sqlx::test(migrations = false)] async fn basic_test(pool: PgPool) -> sqlx::Result<()> { let mut conn = pool.acquire().await?; - + conn.execute("CREATE TABLE foo(bar text)").await?; sqlx::query("SELECT * FROM foo") .fetch_one(&mut conn) .await?; - + assert_eq!(foo.get::("bar"), "foobar!"); - + Ok(()) } # } @@ -188,7 +188,7 @@ You can pass a list of fixture names to the attribute like so, and they will be ```rust,no_run # #[cfg(all(feature = "migrate", feature = "postgres"))] -# mod example { +# mod example { # struct App {} # fn create_app(pool: PgPool) -> App { App {} } use sqlx::PgPool; @@ -196,15 +196,15 @@ use serde_json::json; #[sqlx::test(fixtures("users", "posts"))] async fn test_create_comment(pool: PgPool) -> sqlx::Result<()> { - // See examples/postgres/social-axum-with-tests for a more in-depth example. - let mut app = create_app(pool); - + // See examples/postgres/social-axum-with-tests for a more in-depth example. + let mut app = create_app(pool); + let comment = test_request( &mut app, "POST", "/v1/comment", json! { "postId": "1234" } ).await?; - + assert_eq!(comment["postId"], "1234"); - + Ok(()) } # } @@ -213,6 +213,6 @@ async fn test_create_comment(pool: PgPool) -> sqlx::Result<()> { Fixtures are resolved relative to the current file as `./fixtures/{name}.sql`. 3Ordering for test fixtures is entirely up to the application, and each test may choose which fixtures to -apply and which to omit. However, since each fixture is applied separately (sent as a single command string, so wrapped -in an implicit `BEGIN` and `COMMIT`), you will want to make sure to order the fixtures such that foreign key -requirements are always satisfied, or else you might get errors. +apply and which to omit. 
However, since each fixture is applied separately (sent as a single command string, so wrapped +in an implicit `BEGIN` and `COMMIT`), you will want to make sure to order the fixtures such that foreign key +requirements are always satisfied, or else you might get errors. diff --git a/tests/postgres/pipeline.rs b/tests/postgres/pipeline.rs new file mode 100644 index 0000000000..99ae9afa1f --- /dev/null +++ b/tests/postgres/pipeline.rs @@ -0,0 +1,227 @@ +// Test PgExtendedQueryPipeline + +use either::Either; +use futures_util::TryStreamExt; +use sqlx::postgres::PgExtendedQueryPipeline; +use sqlx::PgPool; +use uuid::{uuid, Uuid}; + +const MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("tests/postgres/migrations"); + +async fn cleanup_test_data( + pool: &PgPool, + user_id: Uuid, + post_id: Uuid, + comment_id: Uuid, +) -> sqlx::Result<()> { + sqlx::query("DELETE FROM comment WHERE comment_id = $1") + .bind(comment_id) + .execute(pool) + .await?; + sqlx::query("DELETE FROM post WHERE post_id = $1") + .bind(post_id) + .execute(pool) + .await?; + sqlx::query("DELETE FROM \"user\" WHERE user_id = $1") + .bind(user_id) + .execute(pool) + .await?; + Ok(()) +} + +// Ensure the test data exists or not +// +// not_exists == true => the test data shouldn't exist +// not_exists == false => the test data is expected to exist +async fn ensure_test_data( + not_exists: bool, + user_id: Uuid, + post_id: Uuid, + comment_id: Uuid, + pool: &PgPool, +) -> sqlx::Result<()> { + let user_exists_query = + sqlx::query_scalar("SELECT exists(SELECT 1 FROM \"user\" WHERE user_id = $1)") + .bind(user_id); + let post_exists_query = + sqlx::query_scalar("SELECT exists(SELECT 1 FROM post WHERE post_id = $1)").bind(post_id); + let comment_exists_query = + sqlx::query_scalar("SELECT exists(SELECT 1 FROM comment WHERE comment_id = $1)") + .bind(comment_id); + + let user_exists: bool = user_exists_query.fetch_one(pool).await?; + assert!(not_exists ^ user_exists); + + let post_exists: bool = post_exists_query.fetch_one(pool).await?; + assert!(not_exists ^ post_exists); + + let comment_exists: bool = comment_exists_query.fetch_one(pool).await?; + assert!(not_exists ^ comment_exists); + Ok(()) +} + +const EXPECTED_QUERIES_IN_PIPELINE: usize = 3; + +fn construct_test_pipeline( + user_id: Uuid, + post_id: Uuid, + comment_id: Uuid, +) -> PgExtendedQueryPipeline<'static, EXPECTED_QUERIES_IN_PIPELINE> { + // query with parameters + let user_insert_query = sqlx::query( + " + INSERT INTO \"user\" (user_id, username) + VALUES + ($1, $2) + ", + ) + .bind(user_id) + .bind("alice"); + + let mut pipeline = + PgExtendedQueryPipeline::::from(user_insert_query); + + // query without parameters + let post_insert_query = sqlx::query( + " + INSERT INTO post (post_id, user_id, content) + VALUES + ('252c1d98-a9b0-4f18-8298-e59058bdfe16', '6592b7c0-b531-4613-ace5-94246b7ce0c3', 'test post') + ", + ); + + pipeline.push(post_insert_query); + + let comment_insert_query = sqlx::query( + " + INSERT INTO comment (comment_id, post_id, user_id, content) + VALUES + ($1, $2, $3, $4) + ", + ) + .bind(comment_id) + .bind(post_id) + .bind(user_id) + .bind("test comment"); + + pipeline.push(comment_insert_query); + pipeline +} + +// test execute/execute_pipeline methods +#[sqlx::test(migrations = "tests/postgres/migrations")] +async fn it_executes_pipeline(pool: PgPool) -> sqlx::Result<()> { + // 0. 
ensure the clean state + + let user_id = uuid!("6592b7c0-b531-4613-ace5-94246b7ce0c3"); + let post_id = uuid!("252c1d98-a9b0-4f18-8298-e59058bdfe16"); + let comment_id = uuid!("fbbbb7dc-dc6f-4649-b663-8d3636035164"); + + cleanup_test_data(&pool, user_id, post_id, comment_id).await?; + ensure_test_data(true, user_id, post_id, comment_id, &pool).await?; + + // 1. construct pipeline of 3 inserts + let pipeline = construct_test_pipeline(user_id, post_id, comment_id); + + // 2. execute pipeline via connection pool and validate PgQueryResult values + let query_results = pipeline.execute(&pool).await?; + + for result in query_results { + // each insert created a row + assert_eq!(result.rows_affected(), 1); + } + + // 3. assert the data was inserted + ensure_test_data(false, user_id, post_id, comment_id, &pool).await?; + + // 4. cleanup + cleanup_test_data(&pool, user_id, post_id, comment_id).await?; + + // 5. construct pipeline of 3 inserts + let pipeline = construct_test_pipeline(user_id, post_id, comment_id); + + // 6. execute pipeline in an explicit transaction and validate PgQueryResult values + let mut tx = pool.begin().await?; + + let query_results = tx.execute_pipeline(pipeline).await?; + + tx.commit().await?; + + for result in query_results { + // each insert created a row + assert_eq!(result.rows_affected(), 1); + } + // 7. assert the data was inserted + ensure_test_data(false, user_id, post_id, comment_id, &pool).await?; + + // 8. cleanup + cleanup_test_data(&pool, user_id, post_id, comment_id).await?; + + Ok(()) +} + +// test fetch_pipeline methods +#[sqlx::test(migrations = "tests/postgres/migrations")] +async fn it_fetches_pipeline(pool: PgPool) -> sqlx::Result<()> { + // 0. ensure the clean state + + let user_id = uuid!("6592b7c0-b531-4613-ace5-94246b7ce0c3"); + let post_id = uuid!("252c1d98-a9b0-4f18-8298-e59058bdfe16"); + let comment_id = uuid!("fbbbb7dc-dc6f-4649-b663-8d3636035164"); + + cleanup_test_data(&pool, user_id, post_id, comment_id).await?; + ensure_test_data(true, user_id, post_id, comment_id, &pool).await?; + + // 1. construct pipeline of 3 inserts + let pipeline = construct_test_pipeline(user_id, post_id, comment_id); + + // 2. fetch pipeline via a pool connection and validate PgQueryResult values + let mut conn = pool.acquire().await?; + conn.fetch_pipeline(pipeline) + .await? + .try_for_each(|pg_result_or_row| async { + match pg_result_or_row { + // each insert created a row + Either::Left(pg_result) => assert_eq!(pg_result.rows_affected(), 1), + // inserts shouldn't return data rows + Either::Right(_) => unreachable!(), + } + Ok(()) + }) + .await?; + drop(conn); + + // 3. assert the data was inserted + ensure_test_data(false, user_id, post_id, comment_id, &pool).await?; + + // 4. cleanup + cleanup_test_data(&pool, user_id, post_id, comment_id).await?; + + // 5. construct pipeline of 3 inserts + let pipeline = construct_test_pipeline(user_id, post_id, comment_id); + + // 6. fetch pipeline in an explicit transaction and validate PgQueryResult values + + let mut tx = pool.begin().await?; + tx.fetch_pipeline(pipeline) + .await? + .try_for_each(|pg_result_or_row| async { + match pg_result_or_row { + // each insert created a row + Either::Left(pg_result) => assert_eq!(pg_result.rows_affected(), 1), + // inserts shouldn't return data rows + Either::Right(_) => unreachable!(), + } + Ok(()) + }) + .await?; + tx.commit().await?; + + // 7. assert the data was inserted + ensure_test_data(false, user_id, post_id, comment_id, &pool).await?; + + // 8. 
cleanup + cleanup_test_data(&pool, user_id, post_id, comment_id).await?; + + Ok(()) +}
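A usage note beyond the included tests: because `fetch_pipeline` yields `Either<PgQueryResult, PgRow>` items, a pipeline may mix statements that return rows with statements that do not, and the caller can keep only the rows. The sketch below is illustrative only; the `report` table, its columns, and the helper function name are hypothetical, and it assumes the API introduced by this diff:

```rust
use either::Either;
use futures_util::TryStreamExt;
use sqlx::postgres::{PgExtendedQueryPipeline, PgRow};
use sqlx::PgPool;

// Hypothetical helper: run one INSERT and one SELECT in a single pipeline,
// keeping only the rows returned by the SELECT.
async fn insert_and_report(pool: &PgPool) -> sqlx::Result<Vec<PgRow>> {
    let insert = sqlx::query("INSERT INTO report (name) VALUES ($1)").bind("weekly");
    let select = sqlx::query("SELECT report_id, name FROM report ORDER BY report_id");

    let mut pipeline = PgExtendedQueryPipeline::<2>::from(insert);
    pipeline.push(select);

    let mut conn = pool.acquire().await?;
    let rows = conn
        .fetch_pipeline(pipeline)
        .await?
        // keep the data rows, drop the per-statement PgQueryResults
        .try_filter_map(|result_or_row| async move {
            Ok(match result_or_row {
                Either::Left(_query_result) => None,
                Either::Right(row) => Some(row),
            })
        })
        .try_collect()
        .await?;

    Ok(rows)
}
```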