diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index f79392b7..8aa24265 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -157,7 +157,10 @@ jobs: # running containers via `services` only works on linux # https://github.com/actions/runner/issues/1866 - name: Setup postgres + id: postgres uses: ikalnytskyi/action-setup-postgres@v7 + - name: Print Roles + run: psql ${{ steps.postgres.outputs.connection-uri }} -c "select rolname from pg_roles;" - name: Run tests run: cargo test --workspace diff --git a/.sqlx/query-b0504a4340264403ad43d05c60d053db65ea6a7529e2cb97b2d3432a18aff6ba.json b/.sqlx/query-b0504a4340264403ad43d05c60d053db65ea6a7529e2cb97b2d3432a18aff6ba.json new file mode 100644 index 00000000..dfc842b7 --- /dev/null +++ b/.sqlx/query-b0504a4340264403ad43d05c60d053db65ea6a7529e2cb97b2d3432a18aff6ba.json @@ -0,0 +1,20 @@ +{ + "db_name": "PostgreSQL", + "query": "select rolname from pg_catalog.pg_roles;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "rolname", + "type_info": "Name" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + true + ] + }, + "hash": "b0504a4340264403ad43d05c60d053db65ea6a7529e2cb97b2d3432a18aff6ba" +} diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index da6d23bc..b1dcbdf7 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -44,6 +44,8 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio mod tests { use std::vec; + use sqlx::{Executor, PgPool}; + use crate::{ CompletionItem, CompletionItemKind, complete, test_helper::{ @@ -66,8 +68,8 @@ mod tests { } } - #[tokio::test] - async fn completes_columns() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completes_columns(pool: PgPool) { let setup = r#" create schema private; @@ -87,6 +89,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + let queries: Vec = vec![ TestCase { message: "correctly prefers the columns of present tables", @@ -121,7 +125,7 @@ mod tests { ]; for q in queries { - let (tree, cache) = get_test_deps(setup, q.get_input_query()).await; + let (tree, cache) = get_test_deps(None, q.get_input_query(), &pool).await; let params = get_test_params(&tree, &cache, q.get_input_query()); let results = complete(params); @@ -137,8 +141,8 @@ mod tests { } } - #[tokio::test] - async fn shows_multiple_columns_if_no_relation_specified() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn shows_multiple_columns_if_no_relation_specified(pool: PgPool) { let setup = r#" create schema private; @@ -158,6 +162,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + let case = TestCase { query: format!(r#"select n{};"#, CURSOR_POS), description: "", @@ -165,11 +171,11 @@ mod tests { message: "", }; - let (tree, cache) = get_test_deps(setup, case.get_input_query()).await; + let (tree, cache) = get_test_deps(None, case.get_input_query(), &pool).await; let params = get_test_params(&tree, &cache, case.get_input_query()); let mut items = complete(params); - let _ = items.split_off(6); + let _ = items.split_off(4); #[derive(Eq, PartialEq, Debug)] struct LabelAndDesc { @@ -190,8 +196,6 @@ mod tests { ("narrator", "public.audio_books"), ("narrator_id", "private.audio_books"), ("id", "public.audio_books"), - ("name", "Schema: pg_catalog"), - ("nameconcatoid", "Schema: pg_catalog"), ] .into_iter() .map(|(label, schema)| LabelAndDesc { @@ -203,8 +207,8 @@ mod tests { assert_eq!(labels, expected); } - #[tokio::test] - async fn suggests_relevant_columns_without_letters() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_relevant_columns_without_letters(pool: PgPool) { let setup = r#" create table users ( id serial primary key, @@ -221,7 +225,7 @@ mod tests { description: "", }; - let (tree, cache) = get_test_deps(setup, test_case.get_input_query()).await; + let (tree, cache) = get_test_deps(Some(setup), test_case.get_input_query(), &pool).await; let params = get_test_params(&tree, &cache, test_case.get_input_query()); let results = complete(params); @@ -251,8 +255,8 @@ mod tests { ); } - #[tokio::test] - async fn ignores_cols_in_from_clause() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn ignores_cols_in_from_clause(pool: PgPool) { let setup = r#" create schema private; @@ -271,7 +275,7 @@ mod tests { description: "", }; - let (tree, cache) = get_test_deps(setup, test_case.get_input_query()).await; + let (tree, cache) = get_test_deps(Some(setup), test_case.get_input_query(), &pool).await; let params = get_test_params(&tree, &cache, test_case.get_input_query()); let results = complete(params); @@ -282,8 +286,8 @@ mod tests { ); } - #[tokio::test] - async fn prefers_columns_of_mentioned_tables() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn prefers_columns_of_mentioned_tables(pool: PgPool) { let setup = r#" create schema private; @@ -304,6 +308,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!(r#"select {} from users"#, CURSOR_POS).as_str(), vec![ @@ -312,7 +318,8 @@ mod tests { CompletionAssertion::Label("id2".into()), CompletionAssertion::Label("name2".into()), ], - setup, + None, + &pool, ) .await; @@ -324,7 +331,8 @@ mod tests { CompletionAssertion::Label("id1".into()), CompletionAssertion::Label("name1".into()), ], - setup, + None, + &pool, ) .await; @@ -332,13 +340,14 @@ mod tests { assert_complete_results( format!(r#"select sett{} from private.users"#, CURSOR_POS).as_str(), vec![CompletionAssertion::Label("user_settings".into())], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn filters_out_by_aliases() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn filters_out_by_aliases(pool: PgPool) { let setup = r#" create schema auth; @@ -357,6 +366,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + // test in SELECT clause assert_complete_results( format!( @@ -374,7 +385,8 @@ mod tests { CompletionAssertion::Label("title".to_string()), CompletionAssertion::Label("user_id".to_string()), ], - setup, + None, + &pool, ) .await; @@ -396,13 +408,14 @@ mod tests { CompletionAssertion::Label("title".to_string()), CompletionAssertion::Label("user_id".to_string()), ], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn does_not_complete_cols_in_join_clauses() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn does_not_complete_cols_in_join_clauses(pool: PgPool) { let setup = r#" create schema auth; @@ -435,13 +448,14 @@ mod tests { CompletionAssertion::LabelAndKind("posts".to_string(), CompletionItemKind::Table), CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table), ], - setup, + Some(setup), + &pool, ) .await; } - #[tokio::test] - async fn completes_in_join_on_clause() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completes_in_join_on_clause(pool: PgPool) { let setup = r#" create schema auth; @@ -460,6 +474,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!( "select u.id, auth.posts.content from auth.users u join auth.posts on u.{}", @@ -472,7 +488,8 @@ mod tests { CompletionAssertion::LabelAndKind("email".to_string(), CompletionItemKind::Column), CompletionAssertion::LabelAndKind("name".to_string(), CompletionItemKind::Column), ], - setup, + None, + &pool, ) .await; @@ -488,13 +505,14 @@ mod tests { CompletionAssertion::LabelAndKind("email".to_string(), CompletionItemKind::Column), CompletionAssertion::LabelAndKind("name".to_string(), CompletionItemKind::Column), ], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn prefers_not_mentioned_columns() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn prefers_not_mentioned_columns(pool: PgPool) { let setup = r#" create schema auth; @@ -513,6 +531,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!( "select {} from public.one o join public.two on o.id = t.id;", @@ -526,7 +546,8 @@ mod tests { CompletionAssertion::Label("d".to_string()), CompletionAssertion::Label("e".to_string()), ], - setup, + None, + &pool, ) .await; @@ -546,7 +567,8 @@ mod tests { CompletionAssertion::Label("z".to_string()), CompletionAssertion::Label("a".to_string()), ], - setup, + None, + &pool, ) .await; @@ -562,7 +584,8 @@ mod tests { CompletionAssertion::LabelAndDesc("id".to_string(), "public.two".to_string()), CompletionAssertion::Label("z".to_string()), ], - setup, + None, + &pool, ) .await; @@ -574,13 +597,14 @@ mod tests { ) .as_str(), vec![CompletionAssertion::Label("z".to_string())], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn suggests_columns_in_insert_clause() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_columns_in_insert_clause(pool: PgPool) { let setup = r#" create table instruments ( id bigint primary key generated always as identity, @@ -595,6 +619,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + // We should prefer the instrument columns, even though they // are lower in the alphabet @@ -605,7 +631,8 @@ mod tests { CompletionAssertion::Label("name".to_string()), CompletionAssertion::Label("z".to_string()), ], - setup, + None, + &pool, ) .await; @@ -615,14 +642,16 @@ mod tests { CompletionAssertion::Label("name".to_string()), CompletionAssertion::Label("z".to_string()), ], - setup, + None, + &pool, ) .await; assert_complete_results( format!("insert into instruments (id, {}, name)", CURSOR_POS).as_str(), vec![CompletionAssertion::Label("z".to_string())], - setup, + None, + &pool, ) .await; @@ -637,20 +666,22 @@ mod tests { CompletionAssertion::Label("id".to_string()), CompletionAssertion::Label("z".to_string()), ], - setup, + None, + &pool, ) .await; // no completions in the values list! assert_no_complete_results( format!("insert into instruments (id, name) values ({})", CURSOR_POS).as_str(), - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn suggests_columns_in_where_clause() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_columns_in_where_clause(pool: PgPool) { let setup = r#" create table instruments ( id bigint primary key generated always as identity, @@ -666,6 +697,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!("select name from instruments where {} ", CURSOR_POS).as_str(), vec![ @@ -674,7 +707,8 @@ mod tests { CompletionAssertion::Label("name".into()), CompletionAssertion::Label("z".into()), ], - setup, + None, + &pool, ) .await; @@ -689,7 +723,8 @@ mod tests { CompletionAssertion::KindNotExists(CompletionItemKind::Column), CompletionAssertion::KindNotExists(CompletionItemKind::Schema), ], - setup, + None, + &pool, ) .await; @@ -705,7 +740,8 @@ mod tests { CompletionAssertion::Label("name".into()), CompletionAssertion::Label("z".into()), ], - setup, + None, + &pool, ) .await; @@ -721,7 +757,8 @@ mod tests { CompletionAssertion::Label("id".into()), CompletionAssertion::Label("name".into()), ], - setup, + None, + &pool, ) .await; } diff --git a/crates/pgt_completions/src/providers/functions.rs b/crates/pgt_completions/src/providers/functions.rs index f1b57e8c..2bc4f331 100644 --- a/crates/pgt_completions/src/providers/functions.rs +++ b/crates/pgt_completions/src/providers/functions.rs @@ -65,13 +65,15 @@ fn get_completion_text(ctx: &CompletionContext, func: &Function) -> CompletionTe #[cfg(test)] mod tests { + use sqlx::PgPool; + use crate::{ CompletionItem, CompletionItemKind, complete, test_helper::{CURSOR_POS, get_test_deps, get_test_params}, }; - #[tokio::test] - async fn completes_fn() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completes_fn(pool: PgPool) { let setup = r#" create or replace function cool() returns trigger @@ -86,7 +88,7 @@ mod tests { let query = format!("select coo{}", CURSOR_POS); - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); @@ -98,8 +100,8 @@ mod tests { assert_eq!(label, "cool"); } - #[tokio::test] - async fn prefers_fn_if_invocation() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn prefers_fn_if_invocation(pool: PgPool) { let setup = r#" create table coos ( id serial primary key, @@ -119,7 +121,7 @@ mod tests { let query = format!(r#"select * from coo{}()"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); @@ -132,8 +134,8 @@ mod tests { assert_eq!(kind, CompletionItemKind::Function); } - #[tokio::test] - async fn prefers_fn_in_select_clause() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn prefers_fn_in_select_clause(pool: PgPool) { let setup = r#" create table coos ( id serial primary key, @@ -153,7 +155,7 @@ mod tests { let query = format!(r#"select coo{}"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); @@ -166,8 +168,8 @@ mod tests { assert_eq!(kind, CompletionItemKind::Function); } - #[tokio::test] - async fn prefers_function_in_from_clause_if_invocation() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn prefers_function_in_from_clause_if_invocation(pool: PgPool) { let setup = r#" create table coos ( id serial primary key, @@ -187,7 +189,7 @@ mod tests { let query = format!(r#"select * from coo{}()"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); diff --git a/crates/pgt_completions/src/providers/policies.rs b/crates/pgt_completions/src/providers/policies.rs index a4d3a9bb..216fcefa 100644 --- a/crates/pgt_completions/src/providers/policies.rs +++ b/crates/pgt_completions/src/providers/policies.rs @@ -59,10 +59,12 @@ pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut Completi #[cfg(test)] mod tests { + use sqlx::{Executor, PgPool}; + use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}; - #[tokio::test] - async fn completes_within_quotation_marks() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completes_within_quotation_marks(pool: PgPool) { let setup = r#" create schema private; @@ -84,13 +86,16 @@ mod tests { with check (true); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!("alter policy \"{}\" on private.users;", CURSOR_POS).as_str(), vec![ CompletionAssertion::Label("read for public users disallowed".into()), CompletionAssertion::Label("write for public users allowed".into()), ], - setup, + None, + &pool, ) .await; @@ -99,7 +104,8 @@ mod tests { vec![CompletionAssertion::Label( "write for public users allowed".into(), )], - setup, + None, + &pool, ) .await; } diff --git a/crates/pgt_completions/src/providers/schemas.rs b/crates/pgt_completions/src/providers/schemas.rs index 02d2fd0c..561da0f8 100644 --- a/crates/pgt_completions/src/providers/schemas.rs +++ b/crates/pgt_completions/src/providers/schemas.rs @@ -27,13 +27,15 @@ pub fn complete_schemas<'a>(ctx: &'a CompletionContext, builder: &mut Completion #[cfg(test)] mod tests { + use sqlx::PgPool; + use crate::{ CompletionItemKind, test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}, }; - #[tokio::test] - async fn autocompletes_schemas() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn autocompletes_schemas(pool: PgPool) { let setup = r#" create schema private; create schema auth; @@ -75,13 +77,14 @@ mod tests { CompletionItemKind::Schema, ), ], - setup, + Some(setup), + &pool, ) .await; } - #[tokio::test] - async fn suggests_tables_and_schemas_with_matching_keys() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_tables_and_schemas_with_matching_keys(pool: PgPool) { let setup = r#" create schema ultimate; @@ -99,7 +102,8 @@ mod tests { CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), CompletionAssertion::LabelAndKind("ultimate".into(), CompletionItemKind::Schema), ], - setup, + Some(setup), + &pool, ) .await; } diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index 2102d41c..3fbee8f1 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -42,6 +42,8 @@ pub fn complete_tables<'a>(ctx: &'a CompletionContext, builder: &mut CompletionB #[cfg(test)] mod tests { + use sqlx::{Executor, PgPool}; + use crate::{ CompletionItem, CompletionItemKind, complete, test_helper::{ @@ -50,8 +52,8 @@ mod tests { }, }; - #[tokio::test] - async fn autocompletes_simple_table() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn autocompletes_simple_table(pool: PgPool) { let setup = r#" create table users ( id serial primary key, @@ -62,7 +64,7 @@ mod tests { let query = format!("select * from u{}", CURSOR_POS); - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let items = complete(params); @@ -77,8 +79,8 @@ mod tests { ) } - #[tokio::test] - async fn autocompletes_table_alphanumerically() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn autocompletes_table_alphanumerically(pool: PgPool) { let setup = r#" create table addresses ( id serial primary key @@ -93,6 +95,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + let test_cases = vec![ (format!("select * from u{}", CURSOR_POS), "users"), (format!("select * from e{}", CURSOR_POS), "emails"), @@ -100,7 +104,7 @@ mod tests { ]; for (query, expected_label) in test_cases { - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(None, query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let items = complete(params); @@ -116,8 +120,8 @@ mod tests { } } - #[tokio::test] - async fn autocompletes_table_with_schema() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn autocompletes_table_with_schema(pool: PgPool) { let setup = r#" create schema customer_support; create schema private; @@ -135,6 +139,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + let test_cases = vec![ (format!("select * from u{}", CURSOR_POS), "user_y"), // user_y is preferred alphanumerically (format!("select * from private.u{}", CURSOR_POS), "user_z"), @@ -145,7 +151,7 @@ mod tests { ]; for (query, expected_label) in test_cases { - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(None, query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let items = complete(params); @@ -161,8 +167,8 @@ mod tests { } } - #[tokio::test] - async fn prefers_table_in_from_clause() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn prefers_table_in_from_clause(pool: PgPool) { let setup = r#" create table coos ( id serial primary key, @@ -182,7 +188,7 @@ mod tests { let query = format!(r#"select * from coo{}"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let items = complete(params); @@ -195,8 +201,8 @@ mod tests { assert_eq!(kind, CompletionItemKind::Table); } - #[tokio::test] - async fn suggests_tables_in_update() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_tables_in_update(pool: PgPool) { let setup = r#" create table coos ( id serial primary key, @@ -204,13 +210,16 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!("update {}", CURSOR_POS).as_str(), vec![CompletionAssertion::LabelAndKind( "public".into(), CompletionItemKind::Schema, )], - setup, + None, + &pool, ) .await; @@ -220,12 +229,17 @@ mod tests { "coos".into(), CompletionItemKind::Table, )], - setup, + None, + &pool, ) .await; - assert_no_complete_results(format!("update public.coos {}", CURSOR_POS).as_str(), setup) - .await; + assert_no_complete_results( + format!("update public.coos {}", CURSOR_POS).as_str(), + None, + &pool, + ) + .await; assert_complete_results( format!("update coos set {}", CURSOR_POS).as_str(), @@ -233,7 +247,8 @@ mod tests { CompletionAssertion::Label("id".into()), CompletionAssertion::Label("name".into()), ], - setup, + None, + &pool, ) .await; @@ -243,13 +258,14 @@ mod tests { CompletionAssertion::Label("id".into()), CompletionAssertion::Label("name".into()), ], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn suggests_tables_in_delete() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_tables_in_delete(pool: PgPool) { let setup = r#" create table coos ( id serial primary key, @@ -257,7 +273,9 @@ mod tests { ); "#; - assert_no_complete_results(format!("delete {}", CURSOR_POS).as_str(), setup).await; + pool.execute(setup).await.unwrap(); + + assert_no_complete_results(format!("delete {}", CURSOR_POS).as_str(), None, &pool).await; assert_complete_results( format!("delete from {}", CURSOR_POS).as_str(), @@ -265,14 +283,16 @@ mod tests { CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), CompletionAssertion::LabelAndKind("coos".into(), CompletionItemKind::Table), ], - setup, + None, + &pool, ) .await; assert_complete_results( format!("delete from public.{}", CURSOR_POS).as_str(), vec![CompletionAssertion::Label("coos".into())], - setup, + None, + &pool, ) .await; @@ -282,13 +302,14 @@ mod tests { CompletionAssertion::Label("id".into()), CompletionAssertion::Label("name".into()), ], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn suggests_tables_in_join() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_tables_in_join(pool: PgPool) { let setup = r#" create schema auth; @@ -315,13 +336,14 @@ mod tests { CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), // self-join CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), ], - setup, + Some(setup), + &pool, ) .await; } - #[tokio::test] - async fn suggests_tables_in_alter_and_drop_statements() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_tables_in_alter_and_drop_statements(pool: PgPool) { let setup = r#" create schema auth; @@ -340,6 +362,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!("alter table {}", CURSOR_POS).as_str(), vec![ @@ -348,7 +372,8 @@ mod tests { CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), ], - setup, + None, + &pool, ) .await; @@ -360,7 +385,8 @@ mod tests { CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), ], - setup, + None, + &pool, ) .await; @@ -372,7 +398,8 @@ mod tests { CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), ], - setup, + None, + &pool, ) .await; @@ -384,13 +411,14 @@ mod tests { CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), // self-join CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), ], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn suggests_tables_in_insert_into() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_tables_in_insert_into(pool: PgPool) { let setup = r#" create schema auth; @@ -401,6 +429,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!("insert into {}", CURSOR_POS).as_str(), vec![ @@ -408,7 +438,8 @@ mod tests { CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), ], - setup, + None, + &pool, ) .await; @@ -418,7 +449,8 @@ mod tests { "users".into(), CompletionItemKind::Table, )], - setup, + None, + &pool, ) .await; @@ -434,7 +466,8 @@ mod tests { CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), ], - setup, + None, + &pool, ) .await; } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 5323e2bc..0be9e48a 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -218,12 +218,14 @@ impl CompletionFilter<'_> { #[cfg(test)] mod tests { + use sqlx::{Executor, PgPool}; + use crate::test_helper::{ CURSOR_POS, CompletionAssertion, assert_complete_results, assert_no_complete_results, }; - #[tokio::test] - async fn completion_after_asterisk() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completion_after_asterisk(pool: PgPool) { let setup = r#" create table users ( id serial primary key, @@ -232,7 +234,9 @@ mod tests { ); "#; - assert_no_complete_results(format!("select * {}", CURSOR_POS).as_str(), setup).await; + pool.execute(setup).await.unwrap(); + + assert_no_complete_results(format!("select * {}", CURSOR_POS).as_str(), None, &pool).await; // if there s a COMMA after the asterisk, we're good assert_complete_results( @@ -242,19 +246,21 @@ mod tests { CompletionAssertion::Label("email".into()), CompletionAssertion::Label("id".into()), ], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn completion_after_create_table() { - assert_no_complete_results(format!("create table {}", CURSOR_POS).as_str(), "").await; + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completion_after_create_table(pool: PgPool) { + assert_no_complete_results(format!("create table {}", CURSOR_POS).as_str(), None, &pool) + .await; } - #[tokio::test] - async fn completion_in_column_definitions() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completion_in_column_definitions(pool: PgPool) { let query = format!(r#"create table instruments ( {} )"#, CURSOR_POS); - assert_no_complete_results(query.as_str(), "").await; + assert_no_complete_results(query.as_str(), None, &pool).await; } } diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index 2fe12511..a8c89f50 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -187,6 +187,16 @@ impl CompletionScore<'_> { } } + fn get_item_name(&self) -> &str { + match self.data { + CompletionRelevanceData::Table(t) => t.name.as_str(), + CompletionRelevanceData::Function(f) => f.name.as_str(), + CompletionRelevanceData::Column(c) => c.name.as_str(), + CompletionRelevanceData::Schema(s) => s.name.as_str(), + CompletionRelevanceData::Policy(p) => p.name.as_str(), + } + } + fn get_schema_name(&self) -> &str { match self.data { CompletionRelevanceData::Function(f) => f.schema.as_str(), @@ -234,19 +244,30 @@ impl CompletionScore<'_> { } fn check_is_user_defined(&mut self) { - let schema = self.get_schema_name().to_string(); + let schema_name = self.get_schema_name().to_string(); let system_schemas = ["pg_catalog", "information_schema", "pg_toast"]; - if system_schemas.contains(&schema.as_str()) { + if system_schemas.contains(&schema_name.as_str()) { self.score -= 20; } // "public" is the default postgres schema where users // create objects. Prefer it by a slight bit. - if schema.as_str() == "public" { + if schema_name.as_str() == "public" { self.score += 2; } + + let item_name = self.get_item_name().to_string(); + let table_name = self.get_table_name(); + + // migrations shouldn't pop up on top + if item_name.contains("migrations") + || table_name.is_some_and(|t| t.contains("migrations")) + || schema_name.contains("migrations") + { + self.score -= 15; + } } fn check_columns_in_stmt(&mut self, ctx: &CompletionContext) { diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index 937c11af..1bd5229c 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -1,8 +1,7 @@ use std::fmt::Display; use pgt_schema_cache::SchemaCache; -use pgt_test_utils::test_database::get_new_test_db; -use sqlx::Executor; +use sqlx::{Executor, PgPool}; use crate::{CompletionItem, CompletionItemKind, CompletionParams, complete}; @@ -34,17 +33,18 @@ impl Display for InputQuery { } pub(crate) async fn get_test_deps( - setup: &str, + setup: Option<&str>, input: InputQuery, + test_db: &PgPool, ) -> (tree_sitter::Tree, pgt_schema_cache::SchemaCache) { - let test_db = get_new_test_db().await; - - test_db - .execute(setup) - .await - .expect("Failed to execute setup query"); + if let Some(setup) = setup { + test_db + .execute(setup) + .await + .expect("Failed to execute setup query"); + } - let schema_cache = SchemaCache::load(&test_db) + let schema_cache = SchemaCache::load(test_db) .await .expect("Failed to load Schema Cache"); @@ -206,9 +206,10 @@ impl CompletionAssertion { pub(crate) async fn assert_complete_results( query: &str, assertions: Vec, - setup: &str, + setup: Option<&str>, + pool: &PgPool, ) { - let (tree, cache) = get_test_deps(setup, query.into()).await; + let (tree, cache) = get_test_deps(setup, query.into(), pool).await; let params = get_test_params(&tree, &cache, query.into()); let items = complete(params); @@ -241,8 +242,8 @@ pub(crate) async fn assert_complete_results( }); } -pub(crate) async fn assert_no_complete_results(query: &str, setup: &str) { - let (tree, cache) = get_test_deps(setup, query.into()).await; +pub(crate) async fn assert_no_complete_results(query: &str, setup: Option<&str>, pool: &PgPool) { + let (tree, cache) = get_test_deps(setup, query.into(), pool).await; let params = get_test_params(&tree, &cache, query.into()); let items = complete(params); diff --git a/crates/pgt_lsp/tests/server.rs b/crates/pgt_lsp/tests/server.rs index 581ea1fe..19b65b06 100644 --- a/crates/pgt_lsp/tests/server.rs +++ b/crates/pgt_lsp/tests/server.rs @@ -13,13 +13,13 @@ use pgt_configuration::database::PartialDatabaseConfiguration; use pgt_fs::MemoryFileSystem; use pgt_lsp::LSPServer; use pgt_lsp::ServerFactory; -use pgt_test_utils::test_database::get_new_test_db; use pgt_workspace::DynRef; use serde::Serialize; use serde::de::DeserializeOwned; use serde_json::Value; use serde_json::{from_value, to_value}; use sqlx::Executor; +use sqlx::PgPool; use std::any::type_name; use std::fmt::Display; use std::time::Duration; @@ -345,11 +345,10 @@ async fn basic_lifecycle() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_database_connection() -> Result<()> { +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn test_database_connection(test_db: PgPool) -> Result<()> { let factory = ServerFactory::default(); let mut fs = MemoryFileSystem::default(); - let test_db = get_new_test_db().await; let setup = r#" create table public.users ( @@ -457,11 +456,10 @@ async fn server_shutdown() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_completions() -> Result<()> { +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn test_completions(test_db: PgPool) -> Result<()> { let factory = ServerFactory::default(); let mut fs = MemoryFileSystem::default(); - let test_db = get_new_test_db().await; let setup = r#" create table public.users ( @@ -558,11 +556,10 @@ async fn test_completions() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_issue_271() -> Result<()> { +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn test_issue_271(test_db: PgPool) -> Result<()> { let factory = ServerFactory::default(); let mut fs = MemoryFileSystem::default(); - let test_db = get_new_test_db().await; let setup = r#" create table public.users ( @@ -760,11 +757,10 @@ async fn test_issue_271() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_execute_statement() -> Result<()> { +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn test_execute_statement(test_db: PgPool) -> Result<()> { let factory = ServerFactory::default(); let mut fs = MemoryFileSystem::default(); - let test_db = get_new_test_db().await; let database = test_db .connect_options() @@ -899,11 +895,10 @@ async fn test_execute_statement() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_issue_281() -> Result<()> { +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn test_issue_281(test_db: PgPool) -> Result<()> { let factory = ServerFactory::default(); let mut fs = MemoryFileSystem::default(); - let test_db = get_new_test_db().await; let setup = r#" create table public.users ( @@ -983,11 +978,10 @@ async fn test_issue_281() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_issue_303() -> Result<()> { +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn test_issue_303(test_db: PgPool) -> Result<()> { let factory = ServerFactory::default(); let mut fs = MemoryFileSystem::default(); - let test_db = get_new_test_db().await; let setup = r#" create table public.users ( diff --git a/crates/pgt_schema_cache/src/columns.rs b/crates/pgt_schema_cache/src/columns.rs index 60d422fd..01f9b41c 100644 --- a/crates/pgt_schema_cache/src/columns.rs +++ b/crates/pgt_schema_cache/src/columns.rs @@ -82,15 +82,12 @@ impl SchemaCacheItem for Column { #[cfg(test)] mod tests { - use pgt_test_utils::test_database::get_new_test_db; - use sqlx::Executor; + use sqlx::{Executor, PgPool}; use crate::{SchemaCache, columns::ColumnClassKind}; - #[tokio::test] - async fn loads_columns() { - let test_db = get_new_test_db().await; - + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn loads_columns(test_db: PgPool) { let setup = r#" create table public.users ( id serial primary key, @@ -129,7 +126,7 @@ mod tests { let public_schema_columns = cache .columns .iter() - .filter(|c| c.schema_name.as_str() == "public") + .filter(|c| c.schema_name.as_str() == "public" && !c.table_name.contains("migrations")) .count(); assert_eq!(public_schema_columns, 4); diff --git a/crates/pgt_schema_cache/src/policies.rs b/crates/pgt_schema_cache/src/policies.rs index 85cd7821..8e2ee4d7 100644 --- a/crates/pgt_schema_cache/src/policies.rs +++ b/crates/pgt_schema_cache/src/policies.rs @@ -80,27 +80,14 @@ impl SchemaCacheItem for Policy { #[cfg(test)] mod tests { - use pgt_test_utils::test_database::get_new_test_db; - use sqlx::Executor; - use crate::{SchemaCache, policies::PolicyCommand}; + use sqlx::{Executor, PgPool}; - #[tokio::test] - async fn loads_policies() { - let test_db = get_new_test_db().await; + use crate::{SchemaCache, policies::PolicyCommand}; + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn loads_policies(test_db: PgPool) { let setup = r#" - do $$ - begin - if not exists ( - select from pg_catalog.pg_roles - where rolname = 'admin' - ) then - create role admin; - end if; - end $$; - - create table public.users ( id serial primary key, name varchar(255) not null @@ -125,22 +112,12 @@ mod tests { to public with check (true); - create policy admin_policy + create policy owner_policy on public.users for all - to admin + to owner with check (true); - do $$ - begin - if not exists ( - select from pg_catalog.pg_roles - where rolname = 'owner' - ) then - create role owner; - end if; - end $$; - create schema real_estate; create table real_estate.properties ( @@ -148,10 +125,10 @@ mod tests { owner_id int not null ); - create policy owner_policy + create policy test_nologin_policy on real_estate.properties for update - to owner + to test_nologin using (owner_id = current_user::int); "#; @@ -193,29 +170,29 @@ mod tests { assert_eq!(public_policy.security_qualification, Some("true".into())); assert_eq!(public_policy.with_check, None); - let admin_policy = cache + let owner_policy = cache .policies .iter() - .find(|p| p.name == "admin_policy") + .find(|p| p.name == "owner_policy") .unwrap(); - assert_eq!(admin_policy.table_name, "users"); - assert_eq!(admin_policy.schema_name, "public"); - assert!(admin_policy.is_permissive); - assert_eq!(admin_policy.command, PolicyCommand::All); - assert_eq!(admin_policy.role_names, vec!["admin"]); - assert_eq!(admin_policy.security_qualification, None); - assert_eq!(admin_policy.with_check, Some("true".into())); + assert_eq!(owner_policy.table_name, "users"); + assert_eq!(owner_policy.schema_name, "public"); + assert!(owner_policy.is_permissive); + assert_eq!(owner_policy.command, PolicyCommand::All); + assert_eq!(owner_policy.role_names, vec!["owner"]); + assert_eq!(owner_policy.security_qualification, None); + assert_eq!(owner_policy.with_check, Some("true".into())); let owner_policy = cache .policies .iter() - .find(|p| p.name == "owner_policy") + .find(|p| p.name == "test_nologin_policy") .unwrap(); assert_eq!(owner_policy.table_name, "properties"); assert_eq!(owner_policy.schema_name, "real_estate"); assert!(owner_policy.is_permissive); assert_eq!(owner_policy.command, PolicyCommand::Update); - assert_eq!(owner_policy.role_names, vec!["owner"]); + assert_eq!(owner_policy.role_names, vec!["test_nologin"]); assert_eq!( owner_policy.security_qualification, Some("(owner_id = (CURRENT_USER)::integer)".into()) diff --git a/crates/pgt_schema_cache/src/roles.rs b/crates/pgt_schema_cache/src/roles.rs index c212b791..7ced66f9 100644 --- a/crates/pgt_schema_cache/src/roles.rs +++ b/crates/pgt_schema_cache/src/roles.rs @@ -21,50 +21,19 @@ impl SchemaCacheItem for Role { #[cfg(test)] mod tests { - use crate::SchemaCache; - use pgt_test_utils::test_database::get_new_test_db; - use sqlx::Executor; - - #[tokio::test] - async fn loads_roles() { - let test_db = get_new_test_db().await; - - let setup = r#" - do $$ - begin - if not exists ( - select from pg_catalog.pg_roles - where rolname = 'test_super' - ) then - create role test_super superuser createdb login bypassrls; - end if; - if not exists ( - select from pg_catalog.pg_roles - where rolname = 'test_nologin' - ) then - create role test_nologin; - end if; - if not exists ( - select from pg_catalog.pg_roles - where rolname = 'test_login' - ) then - create role test_login login; - end if; - end $$; - "#; + use sqlx::PgPool; - test_db - .execute(setup) - .await - .expect("Failed to setup test database"); + use crate::SchemaCache; + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn loads_roles(test_db: PgPool) { let cache = SchemaCache::load(&test_db) .await .expect("Failed to load Schema Cache"); let roles = &cache.roles; - let super_role = roles.iter().find(|r| r.name == "test_super").unwrap(); + let super_role = roles.iter().find(|r| r.name == "owner").unwrap(); assert!(super_role.is_super_user); assert!(super_role.can_create_db); assert!(super_role.can_login); diff --git a/crates/pgt_schema_cache/src/schema_cache.rs b/crates/pgt_schema_cache/src/schema_cache.rs index 516b37e6..8fb9683b 100644 --- a/crates/pgt_schema_cache/src/schema_cache.rs +++ b/crates/pgt_schema_cache/src/schema_cache.rs @@ -93,14 +93,12 @@ pub trait SchemaCacheItem { #[cfg(test)] mod tests { - use pgt_test_utils::test_database::get_new_test_db; + use sqlx::PgPool; use crate::SchemaCache; - #[tokio::test] - async fn it_loads() { - let test_db = get_new_test_db().await; - + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn it_loads(test_db: PgPool) { SchemaCache::load(&test_db) .await .expect("Couldnt' load Schema Cache"); diff --git a/crates/pgt_schema_cache/src/tables.rs b/crates/pgt_schema_cache/src/tables.rs index a0a40d6a..16b86c54 100644 --- a/crates/pgt_schema_cache/src/tables.rs +++ b/crates/pgt_schema_cache/src/tables.rs @@ -79,14 +79,12 @@ impl SchemaCacheItem for Table { #[cfg(test)] mod tests { - use crate::{SchemaCache, tables::TableKind}; - use pgt_test_utils::test_database::get_new_test_db; - use sqlx::Executor; + use sqlx::{Executor, PgPool}; - #[tokio::test] - async fn includes_views_in_query() { - let test_db = get_new_test_db().await; + use crate::{SchemaCache, tables::TableKind}; + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn includes_views_in_query(test_db: PgPool) { let setup = r#" create table public.base_table ( id serial primary key, @@ -116,10 +114,8 @@ mod tests { assert_eq!(view.schema, "public"); } - #[tokio::test] - async fn includes_materialized_views_in_query() { - let test_db = get_new_test_db().await; - + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn includes_materialized_views_in_query(test_db: PgPool) { let setup = r#" create table public.base_table ( id serial primary key, diff --git a/crates/pgt_schema_cache/src/triggers.rs b/crates/pgt_schema_cache/src/triggers.rs index 0a5241d6..2b2a3aff 100644 --- a/crates/pgt_schema_cache/src/triggers.rs +++ b/crates/pgt_schema_cache/src/triggers.rs @@ -126,18 +126,16 @@ impl SchemaCacheItem for Trigger { #[cfg(test)] mod tests { - use pgt_test_utils::test_database::get_new_test_db; - use sqlx::Executor; + + use sqlx::{Executor, PgPool}; use crate::{ SchemaCache, triggers::{TriggerAffected, TriggerEvent, TriggerTiming}, }; - #[tokio::test] - async fn loads_triggers() { - let test_db = get_new_test_db().await; - + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn loads_triggers(test_db: PgPool) { let setup = r#" create table public.users ( id serial primary key, @@ -219,10 +217,8 @@ mod tests { assert_eq!(delete_trigger.proc_name, "log_user_insert"); } - #[tokio::test] - async fn loads_instead_and_truncate_triggers() { - let test_db = get_new_test_db().await; - + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn loads_instead_and_truncate_triggers(test_db: PgPool) { let setup = r#" create table public.docs ( id serial primary key, diff --git a/crates/pgt_test_utils/src/lib.rs b/crates/pgt_test_utils/src/lib.rs index 4d6d3070..e21c6ce4 100644 --- a/crates/pgt_test_utils/src/lib.rs +++ b/crates/pgt_test_utils/src/lib.rs @@ -1 +1 @@ -pub mod test_database; +pub static MIGRATIONS: sqlx::migrate::Migrator = sqlx::migrate!("./testdb_migrations"); diff --git a/crates/pgt_test_utils/src/test_database.rs b/crates/pgt_test_utils/src/test_database.rs deleted file mode 100644 index 67415c4a..00000000 --- a/crates/pgt_test_utils/src/test_database.rs +++ /dev/null @@ -1,42 +0,0 @@ -use sqlx::{Executor, PgPool, postgres::PgConnectOptions}; -use uuid::Uuid; - -// TODO: Work with proper config objects instead of a connection_string. -// With the current implementation, we can't parse the password from the connection string. -pub async fn get_new_test_db() -> PgPool { - dotenv::dotenv().expect("Unable to load .env file for tests"); - - let connection_string = std::env::var("DATABASE_URL").expect("DATABASE_URL not set"); - let password = std::env::var("DB_PASSWORD").unwrap_or("postgres".into()); - - let options_from_conn_str: PgConnectOptions = connection_string - .parse() - .expect("Invalid Connection String"); - - let host = options_from_conn_str.get_host(); - assert!( - host == "localhost" || host == "127.0.0.1", - "Running tests against non-local database!" - ); - - let options_without_db_name = PgConnectOptions::new() - .host(host) - .port(options_from_conn_str.get_port()) - .username(options_from_conn_str.get_username()) - .password(&password); - - let postgres = sqlx::PgPool::connect_with(options_without_db_name.clone()) - .await - .expect("Unable to connect to test postgres instance"); - - let database_name = Uuid::new_v4().to_string(); - - postgres - .execute(format!(r#"create database "{}";"#, database_name).as_str()) - .await - .expect("Failed to create test database."); - - sqlx::PgPool::connect_with(options_without_db_name.database(&database_name)) - .await - .expect("Could not connect to test database") -} diff --git a/crates/pgt_test_utils/testdb_migrations/0001_setup-roles.sql b/crates/pgt_test_utils/testdb_migrations/0001_setup-roles.sql new file mode 100644 index 00000000..1f1d50b3 --- /dev/null +++ b/crates/pgt_test_utils/testdb_migrations/0001_setup-roles.sql @@ -0,0 +1,32 @@ +do $$ +begin + +begin + create role owner superuser createdb login bypassrls; +exception + when duplicate_object then + null; + when unique_violation then + null; +end; + +begin + create role test_login login; +exception + when duplicate_object then + null; + when unique_violation then + null; +end; + +begin + create role test_nologin; +exception + when duplicate_object then + null; + when unique_violation then + null; +end; + +end +$$; \ No newline at end of file diff --git a/crates/pgt_typecheck/src/typed_identifier.rs b/crates/pgt_typecheck/src/typed_identifier.rs index 5efe0421..710b2fe9 100644 --- a/crates/pgt_typecheck/src/typed_identifier.rs +++ b/crates/pgt_typecheck/src/typed_identifier.rs @@ -231,11 +231,10 @@ fn resolve_type<'a>( #[cfg(test)] mod tests { - use pgt_test_utils::test_database::get_new_test_db; - use sqlx::Executor; + use sqlx::{Executor, PgPool}; - #[tokio::test] - async fn test_apply_identifiers() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn test_apply_identifiers(test_db: PgPool) { let input = "select v_test + fn_name.custom_type.v_test2 + $3 + custom_type.v_test3 + fn_name.v_test2 + enum_type"; let identifiers = vec![ @@ -295,8 +294,6 @@ mod tests { }, ]; - let test_db = get_new_test_db().await; - let setup = r#" CREATE TYPE "public"."custom_type" AS ( v_test2 integer, diff --git a/crates/pgt_typecheck/tests/diagnostics.rs b/crates/pgt_typecheck/tests/diagnostics.rs index 9628962d..a7448503 100644 --- a/crates/pgt_typecheck/tests/diagnostics.rs +++ b/crates/pgt_typecheck/tests/diagnostics.rs @@ -3,13 +3,10 @@ use pgt_console::{ markup, }; use pgt_diagnostics::PrintDiagnostic; -use pgt_test_utils::test_database::get_new_test_db; use pgt_typecheck::{TypecheckParams, check_sql}; -use sqlx::Executor; - -async fn test(name: &str, query: &str, setup: Option<&str>) { - let test_db = get_new_test_db().await; +use sqlx::{Executor, PgPool}; +async fn test(name: &str, query: &str, setup: Option<&str>, test_db: &PgPool) { if let Some(setup) = setup { test_db .execute(setup) @@ -22,7 +19,7 @@ async fn test(name: &str, query: &str, setup: Option<&str>) { .set_language(tree_sitter_sql::language()) .expect("Error loading sql language"); - let schema_cache = pgt_schema_cache::SchemaCache::load(&test_db) + let schema_cache = pgt_schema_cache::SchemaCache::load(test_db) .await .expect("Failed to load Schema Cache"); @@ -58,8 +55,8 @@ async fn test(name: &str, query: &str, setup: Option<&str>) { }); } -#[tokio::test] -async fn invalid_column() { +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn invalid_column(pool: PgPool) { test( "invalid_column", "select id, unknown from contacts;", @@ -73,6 +70,7 @@ async fn invalid_column() { ); "#, ), + &pool, ) .await; }