From daf8efd83e45fdf4079700e2d601cbd94d153aa3 Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 27 May 2025 09:13:05 +0200 Subject: [PATCH 01/27] sure --- crates/pgt_completions/src/complete.rs | 4 ++- crates/pgt_completions/src/item.rs | 2 ++ crates/pgt_completions/src/providers/mod.rs | 2 ++ crates/pgt_completions/src/providers/roles.rs | 26 +++++++++++++++ crates/pgt_completions/src/relevance.rs | 1 + .../src/relevance/filtering.rs | 6 ++-- .../pgt_completions/src/relevance/scoring.rs | 32 +++++++++++++------ crates/pgt_lsp/src/handlers/completions.rs | 1 + 8 files changed, 62 insertions(+), 12 deletions(-) create mode 100644 crates/pgt_completions/src/providers/roles.rs diff --git a/crates/pgt_completions/src/complete.rs b/crates/pgt_completions/src/complete.rs index 5bc5d41c..bd5efd19 100644 --- a/crates/pgt_completions/src/complete.rs +++ b/crates/pgt_completions/src/complete.rs @@ -5,7 +5,8 @@ use crate::{ context::CompletionContext, item::CompletionItem, providers::{ - complete_columns, complete_functions, complete_policies, complete_schemas, complete_tables, + complete_columns, complete_functions, complete_policies, complete_roles, complete_schemas, + complete_tables, }, sanitization::SanitizedCompletionParams, }; @@ -36,6 +37,7 @@ pub fn complete(params: CompletionParams) -> Vec { complete_columns(&ctx, &mut builder); complete_schemas(&ctx, &mut builder); complete_policies(&ctx, &mut builder); + complete_roles(&ctx, &mut builder); builder.finish() } diff --git a/crates/pgt_completions/src/item.rs b/crates/pgt_completions/src/item.rs index 73e08cc0..766e436c 100644 --- a/crates/pgt_completions/src/item.rs +++ b/crates/pgt_completions/src/item.rs @@ -12,6 +12,7 @@ pub enum CompletionItemKind { Column, Schema, Policy, + Role, } impl Display for CompletionItemKind { @@ -22,6 +23,7 @@ impl Display for CompletionItemKind { CompletionItemKind::Column => "Column", CompletionItemKind::Schema => "Schema", CompletionItemKind::Policy => "Policy", + CompletionItemKind::Role => "Role", }; write!(f, "{txt}") diff --git a/crates/pgt_completions/src/providers/mod.rs b/crates/pgt_completions/src/providers/mod.rs index 7b07cee8..ddbdf252 100644 --- a/crates/pgt_completions/src/providers/mod.rs +++ b/crates/pgt_completions/src/providers/mod.rs @@ -2,11 +2,13 @@ mod columns; mod functions; mod helper; mod policies; +mod roles; mod schemas; mod tables; pub use columns::*; pub use functions::*; pub use policies::*; +pub use roles::*; pub use schemas::*; pub use tables::*; diff --git a/crates/pgt_completions/src/providers/roles.rs b/crates/pgt_completions/src/providers/roles.rs new file mode 100644 index 00000000..76471662 --- /dev/null +++ b/crates/pgt_completions/src/providers/roles.rs @@ -0,0 +1,26 @@ +use crate::{ + CompletionItemKind, + builder::{CompletionBuilder, PossibleCompletionItem}, + context::CompletionContext, + relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, +}; + +pub fn complete_roles<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionBuilder<'a>) { + let available_roles = &ctx.schema_cache.roles; + + for role in available_roles { + let relevance = CompletionRelevanceData::Role(role); + + let item = PossibleCompletionItem { + label: role.name.chars().take(35).collect::(), + score: CompletionScore::from(relevance.clone()), + filter: CompletionFilter::from(relevance), + description: role.name.clone(), + kind: CompletionItemKind::Role, + completion_text: None, + detail: None, + }; + + builder.add_item(item); + } +} diff --git a/crates/pgt_completions/src/relevance.rs b/crates/pgt_completions/src/relevance.rs index f51c3c52..1d39d9bb 100644 --- a/crates/pgt_completions/src/relevance.rs +++ b/crates/pgt_completions/src/relevance.rs @@ -8,4 +8,5 @@ pub(crate) enum CompletionRelevanceData<'a> { Column(&'a pgt_schema_cache::Column), Schema(&'a pgt_schema_cache::Schema), Policy(&'a pgt_schema_cache::Policy), + Role(&'a pgt_schema_cache::Role), } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 5323e2bc..26e938fe 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -169,6 +169,8 @@ impl CompletionFilter<'_> { CompletionRelevanceData::Policy(_) => { matches!(clause, WrappingClause::PolicyName) } + + CompletionRelevanceData::Role(_) => false, } }) .and_then(|is_ok| if is_ok { Some(()) } else { None }) @@ -204,8 +206,8 @@ impl CompletionFilter<'_> { // we should never allow schema suggestions if there already was one. CompletionRelevanceData::Schema(_) => false, - // no policy comletion if user typed a schema node first. - CompletionRelevanceData::Policy(_) => false, + // no policy or row completion if user typed a schema node first. + CompletionRelevanceData::Policy(_) | CompletionRelevanceData::Role(_) => false, }; if !matches { diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index 2fe12511..d65f23e9 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -47,6 +47,7 @@ impl CompletionScore<'_> { CompletionRelevanceData::Column(c) => c.name.as_str().to_ascii_lowercase(), CompletionRelevanceData::Schema(s) => s.name.as_str().to_ascii_lowercase(), CompletionRelevanceData::Policy(p) => p.name.as_str().to_ascii_lowercase(), + CompletionRelevanceData::Role(r) => r.name.as_str().to_ascii_lowercase(), }; let fz_matcher = SkimMatcherV2::default(); @@ -126,6 +127,8 @@ impl CompletionScore<'_> { WrappingClause::PolicyName => 25, _ => -50, }, + + CompletionRelevanceData::Role(_) => 0, } } @@ -160,6 +163,7 @@ impl CompletionScore<'_> { _ => -50, }, CompletionRelevanceData::Policy(_) => 0, + CompletionRelevanceData::Role(_) => 0, } } @@ -178,7 +182,10 @@ impl CompletionScore<'_> { Some(n) => n, }; - let data_schema = self.get_schema_name(); + let data_schema = match self.get_schema_name() { + Some(s) => s, + None => return, + }; if schema_name == data_schema { self.score += 25; @@ -187,13 +194,14 @@ impl CompletionScore<'_> { } } - fn get_schema_name(&self) -> &str { + fn get_schema_name(&self) -> Option<&str> { match self.data { - CompletionRelevanceData::Function(f) => f.schema.as_str(), - CompletionRelevanceData::Table(t) => t.schema.as_str(), - CompletionRelevanceData::Column(c) => c.schema_name.as_str(), - CompletionRelevanceData::Schema(s) => s.name.as_str(), - CompletionRelevanceData::Policy(p) => p.schema_name.as_str(), + CompletionRelevanceData::Function(f) => Some(f.schema.as_str()), + CompletionRelevanceData::Table(t) => Some(t.schema.as_str()), + CompletionRelevanceData::Column(c) => Some(c.schema_name.as_str()), + CompletionRelevanceData::Schema(s) => Some(s.name.as_str()), + CompletionRelevanceData::Policy(p) => Some(p.schema_name.as_str()), + CompletionRelevanceData::Role(p) => None, } } @@ -212,7 +220,10 @@ impl CompletionScore<'_> { _ => {} } - let schema = self.get_schema_name().to_string(); + let schema = match self.get_schema_name() { + Some(s) => s.to_string(), + None => return, + }; let table_name = match self.get_table_name() { Some(t) => t, None => return, @@ -234,7 +245,10 @@ impl CompletionScore<'_> { } fn check_is_user_defined(&mut self) { - let schema = self.get_schema_name().to_string(); + let schema = match self.get_schema_name() { + Some(s) => s.to_string(), + None => return, + }; let system_schemas = ["pg_catalog", "information_schema", "pg_toast"]; diff --git a/crates/pgt_lsp/src/handlers/completions.rs b/crates/pgt_lsp/src/handlers/completions.rs index 7e901c79..4a035fcf 100644 --- a/crates/pgt_lsp/src/handlers/completions.rs +++ b/crates/pgt_lsp/src/handlers/completions.rs @@ -76,5 +76,6 @@ fn to_lsp_types_completion_item_kind( pgt_completions::CompletionItemKind::Column => lsp_types::CompletionItemKind::FIELD, pgt_completions::CompletionItemKind::Schema => lsp_types::CompletionItemKind::CLASS, pgt_completions::CompletionItemKind::Policy => lsp_types::CompletionItemKind::CONSTANT, + pgt_completions::CompletionItemKind::Role => lsp_types::CompletionItemKind::CONSTANT, } } From dcf24393e692c26cc5b5e29a67922ba1d67181e1 Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 27 May 2025 10:02:19 +0200 Subject: [PATCH 02/27] so far! --- crates/pgt_completions/src/context/mod.rs | 8 ++- crates/pgt_completions/src/providers/roles.rs | 67 +++++++++++++++++++ .../src/relevance/filtering.rs | 4 +- .../pgt_completions/src/relevance/scoring.rs | 31 ++++++++- crates/pgt_schema_cache/src/roles.rs | 4 +- 5 files changed, 108 insertions(+), 6 deletions(-) diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index 0bb190a9..6e0a952c 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -33,6 +33,9 @@ pub enum WrappingClause<'a> { DropTable, PolicyName, ToRoleAssignment, + SetStatement, + AlterRole, + DropRole, } #[derive(PartialEq, Eq, Hash, Debug, Clone)] @@ -424,7 +427,7 @@ impl<'a> CompletionContext<'a> { } "where" | "update" | "select" | "delete" | "from" | "join" | "column_definitions" - | "drop_table" | "alter_table" => { + | "drop_table" | "alter_table" | "alter_role" | "drop_role" | "set_statement" => { self.wrapping_clause_type = self.get_wrapping_clause_from_current_node(current_node, &mut cursor); } @@ -628,7 +631,10 @@ impl<'a> CompletionContext<'a> { "delete" => Some(WrappingClause::Delete), "from" => Some(WrappingClause::From), "drop_table" => Some(WrappingClause::DropTable), + "alter_role" => Some(WrappingClause::AlterRole), + "drop_role" => Some(WrappingClause::DropRole), "alter_table" => Some(WrappingClause::AlterTable), + "set_statement" => Some(WrappingClause::SetStatement), "column_definitions" => Some(WrappingClause::ColumnDefinitions), "insert" => Some(WrappingClause::Insert), "join" => { diff --git a/crates/pgt_completions/src/providers/roles.rs b/crates/pgt_completions/src/providers/roles.rs index 76471662..66d8eefb 100644 --- a/crates/pgt_completions/src/providers/roles.rs +++ b/crates/pgt_completions/src/providers/roles.rs @@ -24,3 +24,70 @@ pub fn complete_roles<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionB builder.add_item(item); } } + +#[cfg(test)] +mod tests { + use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}; + + const SETUP: &'static str = r#" + do $$ + begin + if not exists ( + select from pg_catalog.pg_roles + where rolname = 'test' + ) then + create role test; + end if; + end $$; + + create table users ( + id serial primary key, + email varchar, + address text + ); + "#; + + #[tokio::test] + async fn works_in_drop_role() { + assert_complete_results( + format!("drop role {}", CURSOR_POS).as_str(), + vec![CompletionAssertion::LabelAndKind( + "test".into(), + crate::CompletionItemKind::Role, + )], + SETUP, + ) + .await; + } + + #[tokio::test] + async fn works_in_alter_role() { + assert_complete_results( + format!("alter role {}", CURSOR_POS).as_str(), + vec![CompletionAssertion::LabelAndKind( + "test".into(), + crate::CompletionItemKind::Role, + )], + SETUP, + ) + .await; + } + + async fn works_in_set_statement() { + // set role ROLE; + // set session authorization ROLE; + } + + async fn works_in_policies() {} + + async fn works_in_grant_statements() { + // grant select on my_table to ROLE; + // grant ROLE to OTHER_ROLE with admin option; + } + + async fn works_in_revoke_statements() { + // revoke select on my_table from ROLE; + // revoke ROLE from OTHER_ROLE; + // revoke admin option for ROLE from OTHER_ROLE; + } +} diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 26e938fe..8ea0f1ec 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -170,7 +170,9 @@ impl CompletionFilter<'_> { matches!(clause, WrappingClause::PolicyName) } - CompletionRelevanceData::Role(_) => false, + CompletionRelevanceData::Role(_) => { + matches!(clause, WrappingClause::DropRole | WrappingClause::AlterRole) + } } }) .and_then(|is_ok| if is_ok { Some(()) } else { None }) diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index d65f23e9..05441239 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -128,7 +128,10 @@ impl CompletionScore<'_> { _ => -50, }, - CompletionRelevanceData::Role(_) => 0, + CompletionRelevanceData::Role(_) => match clause_type { + WrappingClause::DropRole | WrappingClause::AlterRole => 25, + _ => -50, + }, } } @@ -201,7 +204,7 @@ impl CompletionScore<'_> { CompletionRelevanceData::Column(c) => Some(c.schema_name.as_str()), CompletionRelevanceData::Schema(s) => Some(s.name.as_str()), CompletionRelevanceData::Policy(p) => Some(p.schema_name.as_str()), - CompletionRelevanceData::Role(p) => None, + CompletionRelevanceData::Role(_) => None, } } @@ -245,6 +248,30 @@ impl CompletionScore<'_> { } fn check_is_user_defined(&mut self) { + if let CompletionRelevanceData::Role(r) = self.data { + match r.name.as_str() { + "pg_read_all_data" + | "pg_write_all_data" + | "pg_read_all_settings" + | "pg_read_all_stats" + | "pg_stat_scan_tables" + | "pg_monitor" + | "pg_database_owner" + | "pg_signal_backend" + | "pg_read_server_files" + | "pg_write_server_files" + | "pg_execute_server_program" + | "pg_checkpoint" + | "pg_maintain" + | "pg_use_reserved_connections" + | "pg_create_subscription" + | "postgres" => self.score -= 20, + _ => {} + }; + + return; + } + let schema = match self.get_schema_name() { Some(s) => s.to_string(), None => return, diff --git a/crates/pgt_schema_cache/src/roles.rs b/crates/pgt_schema_cache/src/roles.rs index c212b791..2d00ab8b 100644 --- a/crates/pgt_schema_cache/src/roles.rs +++ b/crates/pgt_schema_cache/src/roles.rs @@ -27,8 +27,6 @@ mod tests { #[tokio::test] async fn loads_roles() { - let test_db = get_new_test_db().await; - let setup = r#" do $$ begin @@ -53,6 +51,8 @@ mod tests { end $$; "#; + let test_db = get_new_test_db().await; + test_db .execute(setup) .await From 9abf45921cfe9226d55b7ba26ff34350dfa73569 Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 28 May 2025 08:48:39 +0200 Subject: [PATCH 03/27] so far --- crates/pgt_schema_cache/src/columns.rs | 1 - crates/pgt_schema_cache/src/policies.rs | 1 - crates/pgt_test_utils/src/test_database.rs | 89 +++++++++++++++++++++- 3 files changed, 85 insertions(+), 6 deletions(-) diff --git a/crates/pgt_schema_cache/src/columns.rs b/crates/pgt_schema_cache/src/columns.rs index 60d422fd..943cf9ca 100644 --- a/crates/pgt_schema_cache/src/columns.rs +++ b/crates/pgt_schema_cache/src/columns.rs @@ -83,7 +83,6 @@ impl SchemaCacheItem for Column { #[cfg(test)] mod tests { use pgt_test_utils::test_database::get_new_test_db; - use sqlx::Executor; use crate::{SchemaCache, columns::ColumnClassKind}; diff --git a/crates/pgt_schema_cache/src/policies.rs b/crates/pgt_schema_cache/src/policies.rs index 85cd7821..770d1a7b 100644 --- a/crates/pgt_schema_cache/src/policies.rs +++ b/crates/pgt_schema_cache/src/policies.rs @@ -81,7 +81,6 @@ impl SchemaCacheItem for Policy { #[cfg(test)] mod tests { use pgt_test_utils::test_database::get_new_test_db; - use sqlx::Executor; use crate::{SchemaCache, policies::PolicyCommand}; diff --git a/crates/pgt_test_utils/src/test_database.rs b/crates/pgt_test_utils/src/test_database.rs index 67415c4a..a8ed9a4d 100644 --- a/crates/pgt_test_utils/src/test_database.rs +++ b/crates/pgt_test_utils/src/test_database.rs @@ -1,9 +1,85 @@ -use sqlx::{Executor, PgPool, postgres::PgConnectOptions}; +use std::ops::Deref; + +use sqlx::{ + Executor, PgPool, + postgres::{PgConnectOptions, PgQueryResult}, +}; use uuid::Uuid; +#[derive(Debug)] +pub struct TestDb { + pool: PgPool, + roles: Vec, +} + +#[derive(Debug)] +pub struct RoleWithArgs { + role: String, + args: Vec, +} + +impl TestDb { + pub async fn execute(&self, sql: &str) -> Result { + if sql.to_ascii_lowercase().contains("create role") { + panic!("Please setup roles via the `setup_roles` method.") + } + self.pool.execute(sql).await + } + + pub async fn setup_roles( + &mut self, + roles: Vec, + ) -> Result { + self.roles = roles.iter().map(|r| &r.role).cloned().collect(); + + let role_statements: Vec = roles + .into_iter() + .map(|r| { + format!( + r#" + if not exists ( + select from pg_catalog.pg_roles + where rolname = '{0}' + ) then + create role {0} {1}; + end if; + "#, + r.role, + r.args.join(" ") + ) + }) + .collect(); + + let query = format!( + r#" + do $$ + begin + {} + end $$; + "#, + role_statements.join("\n") + ); + + println!("{}", query); + + self.execute(&query).await + } + + pub fn get_roles(&self) -> &[String] { + &self.roles + } +} + +impl Deref for TestDb { + type Target = PgPool; + fn deref(&self) -> &Self::Target { + &self.pool + } +} + // TODO: Work with proper config objects instead of a connection_string. // With the current implementation, we can't parse the password from the connection string. -pub async fn get_new_test_db() -> PgPool { +pub async fn get_new_test_db() -> TestDb { dotenv::dotenv().expect("Unable to load .env file for tests"); let connection_string = std::env::var("DATABASE_URL").expect("DATABASE_URL not set"); @@ -36,7 +112,12 @@ pub async fn get_new_test_db() -> PgPool { .await .expect("Failed to create test database."); - sqlx::PgPool::connect_with(options_without_db_name.database(&database_name)) + let pool = sqlx::PgPool::connect_with(options_without_db_name.database(&database_name)) .await - .expect("Could not connect to test database") + .expect("Could not connect to test database"); + + TestDb { + pool, + roles: vec![], + } } From 8d4836e60389cc929f9a9266ea54e176534fd272 Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 28 May 2025 09:00:18 +0200 Subject: [PATCH 04/27] yayyy --- crates/pgt_schema_cache/src/policies.rs | 39 +++++++---------- crates/pgt_schema_cache/src/roles.rs | 51 ++++++++++------------ crates/pgt_test_utils/src/test_database.rs | 8 ++-- 3 files changed, 41 insertions(+), 57 deletions(-) diff --git a/crates/pgt_schema_cache/src/policies.rs b/crates/pgt_schema_cache/src/policies.rs index 770d1a7b..808a2b7e 100644 --- a/crates/pgt_schema_cache/src/policies.rs +++ b/crates/pgt_schema_cache/src/policies.rs @@ -80,26 +80,29 @@ impl SchemaCacheItem for Policy { #[cfg(test)] mod tests { - use pgt_test_utils::test_database::get_new_test_db; + use pgt_test_utils::test_database::{RoleWithArgs, get_new_test_db}; use crate::{SchemaCache, policies::PolicyCommand}; #[tokio::test] async fn loads_policies() { - let test_db = get_new_test_db().await; - - let setup = r#" - do $$ - begin - if not exists ( - select from pg_catalog.pg_roles - where rolname = 'admin' - ) then - create role admin; - end if; - end $$; + let mut test_db = get_new_test_db().await; + test_db + .setup_roles(vec![ + RoleWithArgs { + role: "admin".into(), + args: vec![], + }, + RoleWithArgs { + role: "owner".into(), + args: vec![], + }, + ]) + .await + .expect("Unable to setup admin roles"); + let setup = r#" create table public.users ( id serial primary key, name varchar(255) not null @@ -130,16 +133,6 @@ mod tests { to admin with check (true); - do $$ - begin - if not exists ( - select from pg_catalog.pg_roles - where rolname = 'owner' - ) then - create role owner; - end if; - end $$; - create schema real_estate; create table real_estate.properties ( diff --git a/crates/pgt_schema_cache/src/roles.rs b/crates/pgt_schema_cache/src/roles.rs index c212b791..eef02c96 100644 --- a/crates/pgt_schema_cache/src/roles.rs +++ b/crates/pgt_schema_cache/src/roles.rs @@ -22,41 +22,34 @@ impl SchemaCacheItem for Role { #[cfg(test)] mod tests { use crate::SchemaCache; - use pgt_test_utils::test_database::get_new_test_db; - use sqlx::Executor; + use pgt_test_utils::test_database::{RoleWithArgs, get_new_test_db}; #[tokio::test] async fn loads_roles() { - let test_db = get_new_test_db().await; - - let setup = r#" - do $$ - begin - if not exists ( - select from pg_catalog.pg_roles - where rolname = 'test_super' - ) then - create role test_super superuser createdb login bypassrls; - end if; - if not exists ( - select from pg_catalog.pg_roles - where rolname = 'test_nologin' - ) then - create role test_nologin; - end if; - if not exists ( - select from pg_catalog.pg_roles - where rolname = 'test_login' - ) then - create role test_login login; - end if; - end $$; - "#; + let mut test_db = get_new_test_db().await; test_db - .execute(setup) + .setup_roles(vec![ + RoleWithArgs { + role: "test_super".into(), + args: vec![ + "superuser".into(), + "createdb".into(), + "login".into(), + "bypassrls".into(), + ], + }, + RoleWithArgs { + role: "test_nologin".into(), + args: vec![], + }, + RoleWithArgs { + role: "test_login".into(), + args: vec!["login".into()], + }, + ]) .await - .expect("Failed to setup test database"); + .expect("Unable to set up roles."); let cache = SchemaCache::load(&test_db) .await diff --git a/crates/pgt_test_utils/src/test_database.rs b/crates/pgt_test_utils/src/test_database.rs index a8ed9a4d..b52b4167 100644 --- a/crates/pgt_test_utils/src/test_database.rs +++ b/crates/pgt_test_utils/src/test_database.rs @@ -14,8 +14,8 @@ pub struct TestDb { #[derive(Debug)] pub struct RoleWithArgs { - role: String, - args: Vec, + pub role: String, + pub args: Vec, } impl TestDb { @@ -60,9 +60,7 @@ impl TestDb { role_statements.join("\n") ); - println!("{}", query); - - self.execute(&query).await + self.pool.execute(query.as_str()).await } pub fn get_roles(&self) -> &[String] { From 363c2ea2831ae8a16d63fe61881cf22ecee0139a Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 28 May 2025 09:11:33 +0200 Subject: [PATCH 05/27] setup roles --- crates/pgt_test_utils/src/test_database.rs | 35 ++++++++++++++++------ 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/crates/pgt_test_utils/src/test_database.rs b/crates/pgt_test_utils/src/test_database.rs index b52b4167..3101d6ae 100644 --- a/crates/pgt_test_utils/src/test_database.rs +++ b/crates/pgt_test_utils/src/test_database.rs @@ -1,4 +1,8 @@ -use std::ops::Deref; +use std::{ + collections::HashSet, + ops::Deref, + sync::{LazyLock, Mutex}, +}; use sqlx::{ Executor, PgPool, @@ -6,10 +10,11 @@ use sqlx::{ }; use uuid::Uuid; +static DB_ROLES: LazyLock>> = LazyLock::new(|| Mutex::new(HashSet::new())); + #[derive(Debug)] pub struct TestDb { pool: PgPool, - roles: Vec, } #[derive(Debug)] @@ -30,7 +35,13 @@ impl TestDb { &mut self, roles: Vec, ) -> Result { - self.roles = roles.iter().map(|r| &r.role).cloned().collect(); + { + let roles: Vec = roles.iter().map(|r| &r.role).cloned().collect(); + let mut set = DB_ROLES.lock().unwrap(); + for role in roles { + set.insert(role); + } + } let role_statements: Vec = roles .into_iter() @@ -63,8 +74,17 @@ impl TestDb { self.pool.execute(query.as_str()).await } - pub fn get_roles(&self) -> &[String] { - &self.roles + pub fn get_roles(&self) -> Vec { + let mut roles = vec![]; + + { + let set = DB_ROLES.lock().unwrap(); + for role in set.iter() { + roles.push(role.clone()); + } + } + + roles } } @@ -114,8 +134,5 @@ pub async fn get_new_test_db() -> TestDb { .await .expect("Could not connect to test database"); - TestDb { - pool, - roles: vec![], - } + TestDb { pool } } From 0e1f0a806816ba512f16a9d751430292cc92c800 Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 28 May 2025 09:56:29 +0200 Subject: [PATCH 06/27] use distinct method --- crates/pgt_test_utils/src/test_database.rs | 39 +++++++++++++++++----- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/crates/pgt_test_utils/src/test_database.rs b/crates/pgt_test_utils/src/test_database.rs index 3101d6ae..8600af83 100644 --- a/crates/pgt_test_utils/src/test_database.rs +++ b/crates/pgt_test_utils/src/test_database.rs @@ -12,6 +12,13 @@ use uuid::Uuid; static DB_ROLES: LazyLock>> = LazyLock::new(|| Mutex::new(HashSet::new())); +fn add_roles(roles: Vec) { + let mut set = DB_ROLES.lock().unwrap(); + for role in roles { + set.insert(role); + } +} + #[derive(Debug)] pub struct TestDb { pool: PgPool, @@ -35,13 +42,8 @@ impl TestDb { &mut self, roles: Vec, ) -> Result { - { - let roles: Vec = roles.iter().map(|r| &r.role).cloned().collect(); - let mut set = DB_ROLES.lock().unwrap(); - for role in roles { - set.insert(role); - } - } + let role_names: Vec = roles.iter().map(|r| &r.role).cloned().collect(); + add_roles(role_names); let role_statements: Vec = roles .into_iter() @@ -84,8 +86,25 @@ impl TestDb { } } + roles.sort(); + roles } + + async fn init_roles(&self) { + let results = sqlx::query!("select rolname from pg_catalog.pg_roles;") + .fetch_all(&self.pool) + .await + .unwrap(); + + let roles: Vec = results + .iter() + .filter_map(|r| r.rolname.as_ref()) + .cloned() + .collect(); + + add_roles(roles); + } } impl Deref for TestDb { @@ -134,5 +153,9 @@ pub async fn get_new_test_db() -> TestDb { .await .expect("Could not connect to test database"); - TestDb { pool } + let db = TestDb { pool }; + + db.init_roles().await; + + db } From 8e4d17cef3be76bc91ca60a15d0950a10d9b134b Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 28 May 2025 09:59:53 +0200 Subject: [PATCH 07/27] =?UTF-8?q?better=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/pgt_test_utils/src/test_database.rs | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/crates/pgt_test_utils/src/test_database.rs b/crates/pgt_test_utils/src/test_database.rs index 8600af83..7d7791d2 100644 --- a/crates/pgt_test_utils/src/test_database.rs +++ b/crates/pgt_test_utils/src/test_database.rs @@ -1,7 +1,7 @@ use std::{ collections::HashSet, ops::Deref, - sync::{LazyLock, Mutex}, + sync::{LazyLock, Mutex, MutexGuard}, }; use sqlx::{ @@ -76,19 +76,8 @@ impl TestDb { self.pool.execute(query.as_str()).await } - pub fn get_roles(&self) -> Vec { - let mut roles = vec![]; - - { - let set = DB_ROLES.lock().unwrap(); - for role in set.iter() { - roles.push(role.clone()); - } - } - - roles.sort(); - - roles + pub fn get_roles(&self) -> MutexGuard<'_, HashSet> { + DB_ROLES.lock().unwrap() } async fn init_roles(&self) { From d1d8453432559acb0a20e05a8bd649f8fe33b5c9 Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 29 May 2025 08:04:35 +0200 Subject: [PATCH 08/27] sqlx prepare --- ...053db65ea6a7529e2cb97b2d3432a18aff6ba.json | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 .sqlx/query-b0504a4340264403ad43d05c60d053db65ea6a7529e2cb97b2d3432a18aff6ba.json diff --git a/.sqlx/query-b0504a4340264403ad43d05c60d053db65ea6a7529e2cb97b2d3432a18aff6ba.json b/.sqlx/query-b0504a4340264403ad43d05c60d053db65ea6a7529e2cb97b2d3432a18aff6ba.json new file mode 100644 index 00000000..dfc842b7 --- /dev/null +++ b/.sqlx/query-b0504a4340264403ad43d05c60d053db65ea6a7529e2cb97b2d3432a18aff6ba.json @@ -0,0 +1,20 @@ +{ + "db_name": "PostgreSQL", + "query": "select rolname from pg_catalog.pg_roles;", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "rolname", + "type_info": "Name" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + true + ] + }, + "hash": "b0504a4340264403ad43d05c60d053db65ea6a7529e2cb97b2d3432a18aff6ba" +} From 9021bc02b457758d09511bafd4a70710433ba59a Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 29 May 2025 08:10:48 +0200 Subject: [PATCH 09/27] ok --- crates/pgt_completions/src/test_helper.rs | 1 - crates/pgt_lsp/tests/server.rs | 1 - crates/pgt_schema_cache/src/tables.rs | 1 - crates/pgt_schema_cache/src/triggers.rs | 1 - crates/pgt_typecheck/src/typed_identifier.rs | 1 - crates/pgt_typecheck/tests/diagnostics.rs | 1 - 6 files changed, 6 deletions(-) diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index 937c11af..19c8e966 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -2,7 +2,6 @@ use std::fmt::Display; use pgt_schema_cache::SchemaCache; use pgt_test_utils::test_database::get_new_test_db; -use sqlx::Executor; use crate::{CompletionItem, CompletionItemKind, CompletionParams, complete}; diff --git a/crates/pgt_lsp/tests/server.rs b/crates/pgt_lsp/tests/server.rs index 581ea1fe..746e829f 100644 --- a/crates/pgt_lsp/tests/server.rs +++ b/crates/pgt_lsp/tests/server.rs @@ -19,7 +19,6 @@ use serde::Serialize; use serde::de::DeserializeOwned; use serde_json::Value; use serde_json::{from_value, to_value}; -use sqlx::Executor; use std::any::type_name; use std::fmt::Display; use std::time::Duration; diff --git a/crates/pgt_schema_cache/src/tables.rs b/crates/pgt_schema_cache/src/tables.rs index a0a40d6a..98b0be3e 100644 --- a/crates/pgt_schema_cache/src/tables.rs +++ b/crates/pgt_schema_cache/src/tables.rs @@ -81,7 +81,6 @@ impl SchemaCacheItem for Table { mod tests { use crate::{SchemaCache, tables::TableKind}; use pgt_test_utils::test_database::get_new_test_db; - use sqlx::Executor; #[tokio::test] async fn includes_views_in_query() { diff --git a/crates/pgt_schema_cache/src/triggers.rs b/crates/pgt_schema_cache/src/triggers.rs index 0a5241d6..80660008 100644 --- a/crates/pgt_schema_cache/src/triggers.rs +++ b/crates/pgt_schema_cache/src/triggers.rs @@ -127,7 +127,6 @@ impl SchemaCacheItem for Trigger { #[cfg(test)] mod tests { use pgt_test_utils::test_database::get_new_test_db; - use sqlx::Executor; use crate::{ SchemaCache, diff --git a/crates/pgt_typecheck/src/typed_identifier.rs b/crates/pgt_typecheck/src/typed_identifier.rs index 5efe0421..49152c2e 100644 --- a/crates/pgt_typecheck/src/typed_identifier.rs +++ b/crates/pgt_typecheck/src/typed_identifier.rs @@ -232,7 +232,6 @@ fn resolve_type<'a>( #[cfg(test)] mod tests { use pgt_test_utils::test_database::get_new_test_db; - use sqlx::Executor; #[tokio::test] async fn test_apply_identifiers() { diff --git a/crates/pgt_typecheck/tests/diagnostics.rs b/crates/pgt_typecheck/tests/diagnostics.rs index 9628962d..740669e7 100644 --- a/crates/pgt_typecheck/tests/diagnostics.rs +++ b/crates/pgt_typecheck/tests/diagnostics.rs @@ -5,7 +5,6 @@ use pgt_console::{ use pgt_diagnostics::PrintDiagnostic; use pgt_test_utils::test_database::get_new_test_db; use pgt_typecheck::{TypecheckParams, check_sql}; -use sqlx::Executor; async fn test(name: &str, query: &str, setup: Option<&str>) { let test_db = get_new_test_db().await; From 0753b50f0a7318a4b502447467f71aefbb802606 Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 29 May 2025 08:35:35 +0200 Subject: [PATCH 10/27] better --- crates/pgt_completions/src/test_helper.rs | 12 +++---- crates/pgt_lsp/tests/server.rs | 33 ++++++++----------- crates/pgt_test_utils/src/lib.rs | 2 ++ .../testdb_migrations/0001-setup-roles.sql | 23 +++++++++++++ 4 files changed, 45 insertions(+), 25 deletions(-) create mode 100644 crates/pgt_test_utils/testdb_migrations/0001-setup-roles.sql diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index 19c8e966..c8516108 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -1,7 +1,7 @@ use std::fmt::Display; use pgt_schema_cache::SchemaCache; -use pgt_test_utils::test_database::get_new_test_db; +use sqlx::{Executor, PgPool}; use crate::{CompletionItem, CompletionItemKind, CompletionParams, complete}; @@ -35,9 +35,8 @@ impl Display for InputQuery { pub(crate) async fn get_test_deps( setup: &str, input: InputQuery, + test_db: &PgPool, ) -> (tree_sitter::Tree, pgt_schema_cache::SchemaCache) { - let test_db = get_new_test_db().await; - test_db .execute(setup) .await @@ -206,8 +205,9 @@ pub(crate) async fn assert_complete_results( query: &str, assertions: Vec, setup: &str, + pool: &PgPool, ) { - let (tree, cache) = get_test_deps(setup, query.into()).await; + let (tree, cache) = get_test_deps(setup, query.into(), pool).await; let params = get_test_params(&tree, &cache, query.into()); let items = complete(params); @@ -240,8 +240,8 @@ pub(crate) async fn assert_complete_results( }); } -pub(crate) async fn assert_no_complete_results(query: &str, setup: &str) { - let (tree, cache) = get_test_deps(setup, query.into()).await; +pub(crate) async fn assert_no_complete_results(query: &str, setup: &str, pool: &PgPool) { + let (tree, cache) = get_test_deps(setup, query.into(), pool).await; let params = get_test_params(&tree, &cache, query.into()); let items = complete(params); diff --git a/crates/pgt_lsp/tests/server.rs b/crates/pgt_lsp/tests/server.rs index 746e829f..19b65b06 100644 --- a/crates/pgt_lsp/tests/server.rs +++ b/crates/pgt_lsp/tests/server.rs @@ -13,12 +13,13 @@ use pgt_configuration::database::PartialDatabaseConfiguration; use pgt_fs::MemoryFileSystem; use pgt_lsp::LSPServer; use pgt_lsp::ServerFactory; -use pgt_test_utils::test_database::get_new_test_db; use pgt_workspace::DynRef; use serde::Serialize; use serde::de::DeserializeOwned; use serde_json::Value; use serde_json::{from_value, to_value}; +use sqlx::Executor; +use sqlx::PgPool; use std::any::type_name; use std::fmt::Display; use std::time::Duration; @@ -344,11 +345,10 @@ async fn basic_lifecycle() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_database_connection() -> Result<()> { +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn test_database_connection(test_db: PgPool) -> Result<()> { let factory = ServerFactory::default(); let mut fs = MemoryFileSystem::default(); - let test_db = get_new_test_db().await; let setup = r#" create table public.users ( @@ -456,11 +456,10 @@ async fn server_shutdown() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_completions() -> Result<()> { +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn test_completions(test_db: PgPool) -> Result<()> { let factory = ServerFactory::default(); let mut fs = MemoryFileSystem::default(); - let test_db = get_new_test_db().await; let setup = r#" create table public.users ( @@ -557,11 +556,10 @@ async fn test_completions() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_issue_271() -> Result<()> { +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn test_issue_271(test_db: PgPool) -> Result<()> { let factory = ServerFactory::default(); let mut fs = MemoryFileSystem::default(); - let test_db = get_new_test_db().await; let setup = r#" create table public.users ( @@ -759,11 +757,10 @@ async fn test_issue_271() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_execute_statement() -> Result<()> { +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn test_execute_statement(test_db: PgPool) -> Result<()> { let factory = ServerFactory::default(); let mut fs = MemoryFileSystem::default(); - let test_db = get_new_test_db().await; let database = test_db .connect_options() @@ -898,11 +895,10 @@ async fn test_execute_statement() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_issue_281() -> Result<()> { +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn test_issue_281(test_db: PgPool) -> Result<()> { let factory = ServerFactory::default(); let mut fs = MemoryFileSystem::default(); - let test_db = get_new_test_db().await; let setup = r#" create table public.users ( @@ -982,11 +978,10 @@ async fn test_issue_281() -> Result<()> { Ok(()) } -#[tokio::test] -async fn test_issue_303() -> Result<()> { +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn test_issue_303(test_db: PgPool) -> Result<()> { let factory = ServerFactory::default(); let mut fs = MemoryFileSystem::default(); - let test_db = get_new_test_db().await; let setup = r#" create table public.users ( diff --git a/crates/pgt_test_utils/src/lib.rs b/crates/pgt_test_utils/src/lib.rs index 4d6d3070..935975ca 100644 --- a/crates/pgt_test_utils/src/lib.rs +++ b/crates/pgt_test_utils/src/lib.rs @@ -1 +1,3 @@ pub mod test_database; + +pub static MIGRATIONS: sqlx::migrate::Migrator = sqlx::migrate!("./testdb_migrations"); diff --git a/crates/pgt_test_utils/testdb_migrations/0001-setup-roles.sql b/crates/pgt_test_utils/testdb_migrations/0001-setup-roles.sql new file mode 100644 index 00000000..67e6dfec --- /dev/null +++ b/crates/pgt_test_utils/testdb_migrations/0001-setup-roles.sql @@ -0,0 +1,23 @@ +do $$ +begin +if not exists ( + select from pg_catalog.pg_roles + where rolname = 'admin' +) then + create role admin superuser createdb login bypassrls; +end if; + +if not exists ( + select from pg_catalog.pg_roles + where rolname = 'test_login' +) then + create role test_login login; +end if; + +if not exists ( + select from pg_catalog.pg_roles + where rolname = 'test_nologin' +) then + create role test_nologin; +end if; +end; \ No newline at end of file From 1be61b4efdbbfbbb58578cb5ebbee096a4ac3b20 Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 29 May 2025 10:07:03 +0200 Subject: [PATCH 11/27] ok --- .../pgt_completions/src/providers/columns.rs | 137 ++++++++++------ .../src/providers/functions.rs | 26 +-- .../pgt_completions/src/providers/policies.rs | 14 +- .../pgt_completions/src/providers/schemas.rs | 16 +- .../pgt_completions/src/providers/tables.rs | 113 ++++++++----- .../src/relevance/filtering.rs | 26 +-- .../pgt_completions/src/relevance/scoring.rs | 27 +++- crates/pgt_completions/src/test_helper.rs | 16 +- crates/pgt_schema_cache/src/columns.rs | 8 +- crates/pgt_test_utils/src/lib.rs | 2 - crates/pgt_test_utils/src/test_database.rs | 150 ------------------ 11 files changed, 246 insertions(+), 289 deletions(-) delete mode 100644 crates/pgt_test_utils/src/test_database.rs diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index da6d23bc..b1dcbdf7 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -44,6 +44,8 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio mod tests { use std::vec; + use sqlx::{Executor, PgPool}; + use crate::{ CompletionItem, CompletionItemKind, complete, test_helper::{ @@ -66,8 +68,8 @@ mod tests { } } - #[tokio::test] - async fn completes_columns() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completes_columns(pool: PgPool) { let setup = r#" create schema private; @@ -87,6 +89,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + let queries: Vec = vec![ TestCase { message: "correctly prefers the columns of present tables", @@ -121,7 +125,7 @@ mod tests { ]; for q in queries { - let (tree, cache) = get_test_deps(setup, q.get_input_query()).await; + let (tree, cache) = get_test_deps(None, q.get_input_query(), &pool).await; let params = get_test_params(&tree, &cache, q.get_input_query()); let results = complete(params); @@ -137,8 +141,8 @@ mod tests { } } - #[tokio::test] - async fn shows_multiple_columns_if_no_relation_specified() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn shows_multiple_columns_if_no_relation_specified(pool: PgPool) { let setup = r#" create schema private; @@ -158,6 +162,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + let case = TestCase { query: format!(r#"select n{};"#, CURSOR_POS), description: "", @@ -165,11 +171,11 @@ mod tests { message: "", }; - let (tree, cache) = get_test_deps(setup, case.get_input_query()).await; + let (tree, cache) = get_test_deps(None, case.get_input_query(), &pool).await; let params = get_test_params(&tree, &cache, case.get_input_query()); let mut items = complete(params); - let _ = items.split_off(6); + let _ = items.split_off(4); #[derive(Eq, PartialEq, Debug)] struct LabelAndDesc { @@ -190,8 +196,6 @@ mod tests { ("narrator", "public.audio_books"), ("narrator_id", "private.audio_books"), ("id", "public.audio_books"), - ("name", "Schema: pg_catalog"), - ("nameconcatoid", "Schema: pg_catalog"), ] .into_iter() .map(|(label, schema)| LabelAndDesc { @@ -203,8 +207,8 @@ mod tests { assert_eq!(labels, expected); } - #[tokio::test] - async fn suggests_relevant_columns_without_letters() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_relevant_columns_without_letters(pool: PgPool) { let setup = r#" create table users ( id serial primary key, @@ -221,7 +225,7 @@ mod tests { description: "", }; - let (tree, cache) = get_test_deps(setup, test_case.get_input_query()).await; + let (tree, cache) = get_test_deps(Some(setup), test_case.get_input_query(), &pool).await; let params = get_test_params(&tree, &cache, test_case.get_input_query()); let results = complete(params); @@ -251,8 +255,8 @@ mod tests { ); } - #[tokio::test] - async fn ignores_cols_in_from_clause() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn ignores_cols_in_from_clause(pool: PgPool) { let setup = r#" create schema private; @@ -271,7 +275,7 @@ mod tests { description: "", }; - let (tree, cache) = get_test_deps(setup, test_case.get_input_query()).await; + let (tree, cache) = get_test_deps(Some(setup), test_case.get_input_query(), &pool).await; let params = get_test_params(&tree, &cache, test_case.get_input_query()); let results = complete(params); @@ -282,8 +286,8 @@ mod tests { ); } - #[tokio::test] - async fn prefers_columns_of_mentioned_tables() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn prefers_columns_of_mentioned_tables(pool: PgPool) { let setup = r#" create schema private; @@ -304,6 +308,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!(r#"select {} from users"#, CURSOR_POS).as_str(), vec![ @@ -312,7 +318,8 @@ mod tests { CompletionAssertion::Label("id2".into()), CompletionAssertion::Label("name2".into()), ], - setup, + None, + &pool, ) .await; @@ -324,7 +331,8 @@ mod tests { CompletionAssertion::Label("id1".into()), CompletionAssertion::Label("name1".into()), ], - setup, + None, + &pool, ) .await; @@ -332,13 +340,14 @@ mod tests { assert_complete_results( format!(r#"select sett{} from private.users"#, CURSOR_POS).as_str(), vec![CompletionAssertion::Label("user_settings".into())], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn filters_out_by_aliases() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn filters_out_by_aliases(pool: PgPool) { let setup = r#" create schema auth; @@ -357,6 +366,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + // test in SELECT clause assert_complete_results( format!( @@ -374,7 +385,8 @@ mod tests { CompletionAssertion::Label("title".to_string()), CompletionAssertion::Label("user_id".to_string()), ], - setup, + None, + &pool, ) .await; @@ -396,13 +408,14 @@ mod tests { CompletionAssertion::Label("title".to_string()), CompletionAssertion::Label("user_id".to_string()), ], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn does_not_complete_cols_in_join_clauses() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn does_not_complete_cols_in_join_clauses(pool: PgPool) { let setup = r#" create schema auth; @@ -435,13 +448,14 @@ mod tests { CompletionAssertion::LabelAndKind("posts".to_string(), CompletionItemKind::Table), CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table), ], - setup, + Some(setup), + &pool, ) .await; } - #[tokio::test] - async fn completes_in_join_on_clause() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completes_in_join_on_clause(pool: PgPool) { let setup = r#" create schema auth; @@ -460,6 +474,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!( "select u.id, auth.posts.content from auth.users u join auth.posts on u.{}", @@ -472,7 +488,8 @@ mod tests { CompletionAssertion::LabelAndKind("email".to_string(), CompletionItemKind::Column), CompletionAssertion::LabelAndKind("name".to_string(), CompletionItemKind::Column), ], - setup, + None, + &pool, ) .await; @@ -488,13 +505,14 @@ mod tests { CompletionAssertion::LabelAndKind("email".to_string(), CompletionItemKind::Column), CompletionAssertion::LabelAndKind("name".to_string(), CompletionItemKind::Column), ], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn prefers_not_mentioned_columns() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn prefers_not_mentioned_columns(pool: PgPool) { let setup = r#" create schema auth; @@ -513,6 +531,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!( "select {} from public.one o join public.two on o.id = t.id;", @@ -526,7 +546,8 @@ mod tests { CompletionAssertion::Label("d".to_string()), CompletionAssertion::Label("e".to_string()), ], - setup, + None, + &pool, ) .await; @@ -546,7 +567,8 @@ mod tests { CompletionAssertion::Label("z".to_string()), CompletionAssertion::Label("a".to_string()), ], - setup, + None, + &pool, ) .await; @@ -562,7 +584,8 @@ mod tests { CompletionAssertion::LabelAndDesc("id".to_string(), "public.two".to_string()), CompletionAssertion::Label("z".to_string()), ], - setup, + None, + &pool, ) .await; @@ -574,13 +597,14 @@ mod tests { ) .as_str(), vec![CompletionAssertion::Label("z".to_string())], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn suggests_columns_in_insert_clause() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_columns_in_insert_clause(pool: PgPool) { let setup = r#" create table instruments ( id bigint primary key generated always as identity, @@ -595,6 +619,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + // We should prefer the instrument columns, even though they // are lower in the alphabet @@ -605,7 +631,8 @@ mod tests { CompletionAssertion::Label("name".to_string()), CompletionAssertion::Label("z".to_string()), ], - setup, + None, + &pool, ) .await; @@ -615,14 +642,16 @@ mod tests { CompletionAssertion::Label("name".to_string()), CompletionAssertion::Label("z".to_string()), ], - setup, + None, + &pool, ) .await; assert_complete_results( format!("insert into instruments (id, {}, name)", CURSOR_POS).as_str(), vec![CompletionAssertion::Label("z".to_string())], - setup, + None, + &pool, ) .await; @@ -637,20 +666,22 @@ mod tests { CompletionAssertion::Label("id".to_string()), CompletionAssertion::Label("z".to_string()), ], - setup, + None, + &pool, ) .await; // no completions in the values list! assert_no_complete_results( format!("insert into instruments (id, name) values ({})", CURSOR_POS).as_str(), - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn suggests_columns_in_where_clause() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_columns_in_where_clause(pool: PgPool) { let setup = r#" create table instruments ( id bigint primary key generated always as identity, @@ -666,6 +697,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!("select name from instruments where {} ", CURSOR_POS).as_str(), vec![ @@ -674,7 +707,8 @@ mod tests { CompletionAssertion::Label("name".into()), CompletionAssertion::Label("z".into()), ], - setup, + None, + &pool, ) .await; @@ -689,7 +723,8 @@ mod tests { CompletionAssertion::KindNotExists(CompletionItemKind::Column), CompletionAssertion::KindNotExists(CompletionItemKind::Schema), ], - setup, + None, + &pool, ) .await; @@ -705,7 +740,8 @@ mod tests { CompletionAssertion::Label("name".into()), CompletionAssertion::Label("z".into()), ], - setup, + None, + &pool, ) .await; @@ -721,7 +757,8 @@ mod tests { CompletionAssertion::Label("id".into()), CompletionAssertion::Label("name".into()), ], - setup, + None, + &pool, ) .await; } diff --git a/crates/pgt_completions/src/providers/functions.rs b/crates/pgt_completions/src/providers/functions.rs index f1b57e8c..2bc4f331 100644 --- a/crates/pgt_completions/src/providers/functions.rs +++ b/crates/pgt_completions/src/providers/functions.rs @@ -65,13 +65,15 @@ fn get_completion_text(ctx: &CompletionContext, func: &Function) -> CompletionTe #[cfg(test)] mod tests { + use sqlx::PgPool; + use crate::{ CompletionItem, CompletionItemKind, complete, test_helper::{CURSOR_POS, get_test_deps, get_test_params}, }; - #[tokio::test] - async fn completes_fn() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completes_fn(pool: PgPool) { let setup = r#" create or replace function cool() returns trigger @@ -86,7 +88,7 @@ mod tests { let query = format!("select coo{}", CURSOR_POS); - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); @@ -98,8 +100,8 @@ mod tests { assert_eq!(label, "cool"); } - #[tokio::test] - async fn prefers_fn_if_invocation() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn prefers_fn_if_invocation(pool: PgPool) { let setup = r#" create table coos ( id serial primary key, @@ -119,7 +121,7 @@ mod tests { let query = format!(r#"select * from coo{}()"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); @@ -132,8 +134,8 @@ mod tests { assert_eq!(kind, CompletionItemKind::Function); } - #[tokio::test] - async fn prefers_fn_in_select_clause() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn prefers_fn_in_select_clause(pool: PgPool) { let setup = r#" create table coos ( id serial primary key, @@ -153,7 +155,7 @@ mod tests { let query = format!(r#"select coo{}"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); @@ -166,8 +168,8 @@ mod tests { assert_eq!(kind, CompletionItemKind::Function); } - #[tokio::test] - async fn prefers_function_in_from_clause_if_invocation() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn prefers_function_in_from_clause_if_invocation(pool: PgPool) { let setup = r#" create table coos ( id serial primary key, @@ -187,7 +189,7 @@ mod tests { let query = format!(r#"select * from coo{}()"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); diff --git a/crates/pgt_completions/src/providers/policies.rs b/crates/pgt_completions/src/providers/policies.rs index a4d3a9bb..216fcefa 100644 --- a/crates/pgt_completions/src/providers/policies.rs +++ b/crates/pgt_completions/src/providers/policies.rs @@ -59,10 +59,12 @@ pub fn complete_policies<'a>(ctx: &CompletionContext<'a>, builder: &mut Completi #[cfg(test)] mod tests { + use sqlx::{Executor, PgPool}; + use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}; - #[tokio::test] - async fn completes_within_quotation_marks() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completes_within_quotation_marks(pool: PgPool) { let setup = r#" create schema private; @@ -84,13 +86,16 @@ mod tests { with check (true); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!("alter policy \"{}\" on private.users;", CURSOR_POS).as_str(), vec![ CompletionAssertion::Label("read for public users disallowed".into()), CompletionAssertion::Label("write for public users allowed".into()), ], - setup, + None, + &pool, ) .await; @@ -99,7 +104,8 @@ mod tests { vec![CompletionAssertion::Label( "write for public users allowed".into(), )], - setup, + None, + &pool, ) .await; } diff --git a/crates/pgt_completions/src/providers/schemas.rs b/crates/pgt_completions/src/providers/schemas.rs index 02d2fd0c..561da0f8 100644 --- a/crates/pgt_completions/src/providers/schemas.rs +++ b/crates/pgt_completions/src/providers/schemas.rs @@ -27,13 +27,15 @@ pub fn complete_schemas<'a>(ctx: &'a CompletionContext, builder: &mut Completion #[cfg(test)] mod tests { + use sqlx::PgPool; + use crate::{ CompletionItemKind, test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}, }; - #[tokio::test] - async fn autocompletes_schemas() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn autocompletes_schemas(pool: PgPool) { let setup = r#" create schema private; create schema auth; @@ -75,13 +77,14 @@ mod tests { CompletionItemKind::Schema, ), ], - setup, + Some(setup), + &pool, ) .await; } - #[tokio::test] - async fn suggests_tables_and_schemas_with_matching_keys() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_tables_and_schemas_with_matching_keys(pool: PgPool) { let setup = r#" create schema ultimate; @@ -99,7 +102,8 @@ mod tests { CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), CompletionAssertion::LabelAndKind("ultimate".into(), CompletionItemKind::Schema), ], - setup, + Some(setup), + &pool, ) .await; } diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index 2102d41c..3fbee8f1 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -42,6 +42,8 @@ pub fn complete_tables<'a>(ctx: &'a CompletionContext, builder: &mut CompletionB #[cfg(test)] mod tests { + use sqlx::{Executor, PgPool}; + use crate::{ CompletionItem, CompletionItemKind, complete, test_helper::{ @@ -50,8 +52,8 @@ mod tests { }, }; - #[tokio::test] - async fn autocompletes_simple_table() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn autocompletes_simple_table(pool: PgPool) { let setup = r#" create table users ( id serial primary key, @@ -62,7 +64,7 @@ mod tests { let query = format!("select * from u{}", CURSOR_POS); - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let items = complete(params); @@ -77,8 +79,8 @@ mod tests { ) } - #[tokio::test] - async fn autocompletes_table_alphanumerically() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn autocompletes_table_alphanumerically(pool: PgPool) { let setup = r#" create table addresses ( id serial primary key @@ -93,6 +95,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + let test_cases = vec![ (format!("select * from u{}", CURSOR_POS), "users"), (format!("select * from e{}", CURSOR_POS), "emails"), @@ -100,7 +104,7 @@ mod tests { ]; for (query, expected_label) in test_cases { - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(None, query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let items = complete(params); @@ -116,8 +120,8 @@ mod tests { } } - #[tokio::test] - async fn autocompletes_table_with_schema() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn autocompletes_table_with_schema(pool: PgPool) { let setup = r#" create schema customer_support; create schema private; @@ -135,6 +139,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + let test_cases = vec![ (format!("select * from u{}", CURSOR_POS), "user_y"), // user_y is preferred alphanumerically (format!("select * from private.u{}", CURSOR_POS), "user_z"), @@ -145,7 +151,7 @@ mod tests { ]; for (query, expected_label) in test_cases { - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(None, query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let items = complete(params); @@ -161,8 +167,8 @@ mod tests { } } - #[tokio::test] - async fn prefers_table_in_from_clause() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn prefers_table_in_from_clause(pool: PgPool) { let setup = r#" create table coos ( id serial primary key, @@ -182,7 +188,7 @@ mod tests { let query = format!(r#"select * from coo{}"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let (tree, cache) = get_test_deps(Some(setup), query.as_str().into(), &pool).await; let params = get_test_params(&tree, &cache, query.as_str().into()); let items = complete(params); @@ -195,8 +201,8 @@ mod tests { assert_eq!(kind, CompletionItemKind::Table); } - #[tokio::test] - async fn suggests_tables_in_update() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_tables_in_update(pool: PgPool) { let setup = r#" create table coos ( id serial primary key, @@ -204,13 +210,16 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!("update {}", CURSOR_POS).as_str(), vec![CompletionAssertion::LabelAndKind( "public".into(), CompletionItemKind::Schema, )], - setup, + None, + &pool, ) .await; @@ -220,12 +229,17 @@ mod tests { "coos".into(), CompletionItemKind::Table, )], - setup, + None, + &pool, ) .await; - assert_no_complete_results(format!("update public.coos {}", CURSOR_POS).as_str(), setup) - .await; + assert_no_complete_results( + format!("update public.coos {}", CURSOR_POS).as_str(), + None, + &pool, + ) + .await; assert_complete_results( format!("update coos set {}", CURSOR_POS).as_str(), @@ -233,7 +247,8 @@ mod tests { CompletionAssertion::Label("id".into()), CompletionAssertion::Label("name".into()), ], - setup, + None, + &pool, ) .await; @@ -243,13 +258,14 @@ mod tests { CompletionAssertion::Label("id".into()), CompletionAssertion::Label("name".into()), ], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn suggests_tables_in_delete() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_tables_in_delete(pool: PgPool) { let setup = r#" create table coos ( id serial primary key, @@ -257,7 +273,9 @@ mod tests { ); "#; - assert_no_complete_results(format!("delete {}", CURSOR_POS).as_str(), setup).await; + pool.execute(setup).await.unwrap(); + + assert_no_complete_results(format!("delete {}", CURSOR_POS).as_str(), None, &pool).await; assert_complete_results( format!("delete from {}", CURSOR_POS).as_str(), @@ -265,14 +283,16 @@ mod tests { CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), CompletionAssertion::LabelAndKind("coos".into(), CompletionItemKind::Table), ], - setup, + None, + &pool, ) .await; assert_complete_results( format!("delete from public.{}", CURSOR_POS).as_str(), vec![CompletionAssertion::Label("coos".into())], - setup, + None, + &pool, ) .await; @@ -282,13 +302,14 @@ mod tests { CompletionAssertion::Label("id".into()), CompletionAssertion::Label("name".into()), ], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn suggests_tables_in_join() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_tables_in_join(pool: PgPool) { let setup = r#" create schema auth; @@ -315,13 +336,14 @@ mod tests { CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), // self-join CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), ], - setup, + Some(setup), + &pool, ) .await; } - #[tokio::test] - async fn suggests_tables_in_alter_and_drop_statements() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_tables_in_alter_and_drop_statements(pool: PgPool) { let setup = r#" create schema auth; @@ -340,6 +362,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!("alter table {}", CURSOR_POS).as_str(), vec![ @@ -348,7 +372,8 @@ mod tests { CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), ], - setup, + None, + &pool, ) .await; @@ -360,7 +385,8 @@ mod tests { CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), ], - setup, + None, + &pool, ) .await; @@ -372,7 +398,8 @@ mod tests { CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), ], - setup, + None, + &pool, ) .await; @@ -384,13 +411,14 @@ mod tests { CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), // self-join CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), ], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn suggests_tables_in_insert_into() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_tables_in_insert_into(pool: PgPool) { let setup = r#" create schema auth; @@ -401,6 +429,8 @@ mod tests { ); "#; + pool.execute(setup).await.unwrap(); + assert_complete_results( format!("insert into {}", CURSOR_POS).as_str(), vec![ @@ -408,7 +438,8 @@ mod tests { CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), ], - setup, + None, + &pool, ) .await; @@ -418,7 +449,8 @@ mod tests { "users".into(), CompletionItemKind::Table, )], - setup, + None, + &pool, ) .await; @@ -434,7 +466,8 @@ mod tests { CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), ], - setup, + None, + &pool, ) .await; } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 5323e2bc..0be9e48a 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -218,12 +218,14 @@ impl CompletionFilter<'_> { #[cfg(test)] mod tests { + use sqlx::{Executor, PgPool}; + use crate::test_helper::{ CURSOR_POS, CompletionAssertion, assert_complete_results, assert_no_complete_results, }; - #[tokio::test] - async fn completion_after_asterisk() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completion_after_asterisk(pool: PgPool) { let setup = r#" create table users ( id serial primary key, @@ -232,7 +234,9 @@ mod tests { ); "#; - assert_no_complete_results(format!("select * {}", CURSOR_POS).as_str(), setup).await; + pool.execute(setup).await.unwrap(); + + assert_no_complete_results(format!("select * {}", CURSOR_POS).as_str(), None, &pool).await; // if there s a COMMA after the asterisk, we're good assert_complete_results( @@ -242,19 +246,21 @@ mod tests { CompletionAssertion::Label("email".into()), CompletionAssertion::Label("id".into()), ], - setup, + None, + &pool, ) .await; } - #[tokio::test] - async fn completion_after_create_table() { - assert_no_complete_results(format!("create table {}", CURSOR_POS).as_str(), "").await; + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completion_after_create_table(pool: PgPool) { + assert_no_complete_results(format!("create table {}", CURSOR_POS).as_str(), None, &pool) + .await; } - #[tokio::test] - async fn completion_in_column_definitions() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completion_in_column_definitions(pool: PgPool) { let query = format!(r#"create table instruments ( {} )"#, CURSOR_POS); - assert_no_complete_results(query.as_str(), "").await; + assert_no_complete_results(query.as_str(), None, &pool).await; } } diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index 2fe12511..a8c89f50 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -187,6 +187,16 @@ impl CompletionScore<'_> { } } + fn get_item_name(&self) -> &str { + match self.data { + CompletionRelevanceData::Table(t) => t.name.as_str(), + CompletionRelevanceData::Function(f) => f.name.as_str(), + CompletionRelevanceData::Column(c) => c.name.as_str(), + CompletionRelevanceData::Schema(s) => s.name.as_str(), + CompletionRelevanceData::Policy(p) => p.name.as_str(), + } + } + fn get_schema_name(&self) -> &str { match self.data { CompletionRelevanceData::Function(f) => f.schema.as_str(), @@ -234,19 +244,30 @@ impl CompletionScore<'_> { } fn check_is_user_defined(&mut self) { - let schema = self.get_schema_name().to_string(); + let schema_name = self.get_schema_name().to_string(); let system_schemas = ["pg_catalog", "information_schema", "pg_toast"]; - if system_schemas.contains(&schema.as_str()) { + if system_schemas.contains(&schema_name.as_str()) { self.score -= 20; } // "public" is the default postgres schema where users // create objects. Prefer it by a slight bit. - if schema.as_str() == "public" { + if schema_name.as_str() == "public" { self.score += 2; } + + let item_name = self.get_item_name().to_string(); + let table_name = self.get_table_name(); + + // migrations shouldn't pop up on top + if item_name.contains("migrations") + || table_name.is_some_and(|t| t.contains("migrations")) + || schema_name.contains("migrations") + { + self.score -= 15; + } } fn check_columns_in_stmt(&mut self, ctx: &CompletionContext) { diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index c8516108..61decce0 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -33,14 +33,16 @@ impl Display for InputQuery { } pub(crate) async fn get_test_deps( - setup: &str, + setup: Option<&str>, input: InputQuery, test_db: &PgPool, ) -> (tree_sitter::Tree, pgt_schema_cache::SchemaCache) { - test_db - .execute(setup) - .await - .expect("Failed to execute setup query"); + if let Some(setup) = setup { + test_db + .execute(setup) + .await + .expect("Failed to execute setup query"); + } let schema_cache = SchemaCache::load(&test_db) .await @@ -204,7 +206,7 @@ impl CompletionAssertion { pub(crate) async fn assert_complete_results( query: &str, assertions: Vec, - setup: &str, + setup: Option<&str>, pool: &PgPool, ) { let (tree, cache) = get_test_deps(setup, query.into(), pool).await; @@ -240,7 +242,7 @@ pub(crate) async fn assert_complete_results( }); } -pub(crate) async fn assert_no_complete_results(query: &str, setup: &str, pool: &PgPool) { +pub(crate) async fn assert_no_complete_results(query: &str, setup: Option<&str>, pool: &PgPool) { let (tree, cache) = get_test_deps(setup, query.into(), pool).await; let params = get_test_params(&tree, &cache, query.into()); let items = complete(params); diff --git a/crates/pgt_schema_cache/src/columns.rs b/crates/pgt_schema_cache/src/columns.rs index 943cf9ca..786ca5d4 100644 --- a/crates/pgt_schema_cache/src/columns.rs +++ b/crates/pgt_schema_cache/src/columns.rs @@ -82,14 +82,12 @@ impl SchemaCacheItem for Column { #[cfg(test)] mod tests { - use pgt_test_utils::test_database::get_new_test_db; + use sqlx::{Executor, PgPool}; use crate::{SchemaCache, columns::ColumnClassKind}; - #[tokio::test] - async fn loads_columns() { - let test_db = get_new_test_db().await; - + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn loads_columns(test_db: PgPool) { let setup = r#" create table public.users ( id serial primary key, diff --git a/crates/pgt_test_utils/src/lib.rs b/crates/pgt_test_utils/src/lib.rs index 935975ca..e21c6ce4 100644 --- a/crates/pgt_test_utils/src/lib.rs +++ b/crates/pgt_test_utils/src/lib.rs @@ -1,3 +1 @@ -pub mod test_database; - pub static MIGRATIONS: sqlx::migrate::Migrator = sqlx::migrate!("./testdb_migrations"); diff --git a/crates/pgt_test_utils/src/test_database.rs b/crates/pgt_test_utils/src/test_database.rs deleted file mode 100644 index 7d7791d2..00000000 --- a/crates/pgt_test_utils/src/test_database.rs +++ /dev/null @@ -1,150 +0,0 @@ -use std::{ - collections::HashSet, - ops::Deref, - sync::{LazyLock, Mutex, MutexGuard}, -}; - -use sqlx::{ - Executor, PgPool, - postgres::{PgConnectOptions, PgQueryResult}, -}; -use uuid::Uuid; - -static DB_ROLES: LazyLock>> = LazyLock::new(|| Mutex::new(HashSet::new())); - -fn add_roles(roles: Vec) { - let mut set = DB_ROLES.lock().unwrap(); - for role in roles { - set.insert(role); - } -} - -#[derive(Debug)] -pub struct TestDb { - pool: PgPool, -} - -#[derive(Debug)] -pub struct RoleWithArgs { - pub role: String, - pub args: Vec, -} - -impl TestDb { - pub async fn execute(&self, sql: &str) -> Result { - if sql.to_ascii_lowercase().contains("create role") { - panic!("Please setup roles via the `setup_roles` method.") - } - self.pool.execute(sql).await - } - - pub async fn setup_roles( - &mut self, - roles: Vec, - ) -> Result { - let role_names: Vec = roles.iter().map(|r| &r.role).cloned().collect(); - add_roles(role_names); - - let role_statements: Vec = roles - .into_iter() - .map(|r| { - format!( - r#" - if not exists ( - select from pg_catalog.pg_roles - where rolname = '{0}' - ) then - create role {0} {1}; - end if; - "#, - r.role, - r.args.join(" ") - ) - }) - .collect(); - - let query = format!( - r#" - do $$ - begin - {} - end $$; - "#, - role_statements.join("\n") - ); - - self.pool.execute(query.as_str()).await - } - - pub fn get_roles(&self) -> MutexGuard<'_, HashSet> { - DB_ROLES.lock().unwrap() - } - - async fn init_roles(&self) { - let results = sqlx::query!("select rolname from pg_catalog.pg_roles;") - .fetch_all(&self.pool) - .await - .unwrap(); - - let roles: Vec = results - .iter() - .filter_map(|r| r.rolname.as_ref()) - .cloned() - .collect(); - - add_roles(roles); - } -} - -impl Deref for TestDb { - type Target = PgPool; - fn deref(&self) -> &Self::Target { - &self.pool - } -} - -// TODO: Work with proper config objects instead of a connection_string. -// With the current implementation, we can't parse the password from the connection string. -pub async fn get_new_test_db() -> TestDb { - dotenv::dotenv().expect("Unable to load .env file for tests"); - - let connection_string = std::env::var("DATABASE_URL").expect("DATABASE_URL not set"); - let password = std::env::var("DB_PASSWORD").unwrap_or("postgres".into()); - - let options_from_conn_str: PgConnectOptions = connection_string - .parse() - .expect("Invalid Connection String"); - - let host = options_from_conn_str.get_host(); - assert!( - host == "localhost" || host == "127.0.0.1", - "Running tests against non-local database!" - ); - - let options_without_db_name = PgConnectOptions::new() - .host(host) - .port(options_from_conn_str.get_port()) - .username(options_from_conn_str.get_username()) - .password(&password); - - let postgres = sqlx::PgPool::connect_with(options_without_db_name.clone()) - .await - .expect("Unable to connect to test postgres instance"); - - let database_name = Uuid::new_v4().to_string(); - - postgres - .execute(format!(r#"create database "{}";"#, database_name).as_str()) - .await - .expect("Failed to create test database."); - - let pool = sqlx::PgPool::connect_with(options_without_db_name.database(&database_name)) - .await - .expect("Could not connect to test database"); - - let db = TestDb { pool }; - - db.init_roles().await; - - db -} From e55271f804a91bb1f88f06e4ca4ca718fa90474d Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 29 May 2025 10:17:09 +0200 Subject: [PATCH 12/27] ok --- crates/pgt_schema_cache/src/columns.rs | 2 +- crates/pgt_schema_cache/src/policies.rs | 23 +++----------- crates/pgt_schema_cache/src/roles.rs | 32 +++----------------- crates/pgt_schema_cache/src/schema_cache.rs | 8 ++--- crates/pgt_schema_cache/src/tables.rs | 15 ++++----- crates/pgt_schema_cache/src/triggers.rs | 15 ++++----- crates/pgt_typecheck/src/typed_identifier.rs | 8 ++--- crates/pgt_typecheck/tests/diagnostics.rs | 11 +++---- 8 files changed, 32 insertions(+), 82 deletions(-) diff --git a/crates/pgt_schema_cache/src/columns.rs b/crates/pgt_schema_cache/src/columns.rs index 786ca5d4..01f9b41c 100644 --- a/crates/pgt_schema_cache/src/columns.rs +++ b/crates/pgt_schema_cache/src/columns.rs @@ -126,7 +126,7 @@ mod tests { let public_schema_columns = cache .columns .iter() - .filter(|c| c.schema_name.as_str() == "public") + .filter(|c| c.schema_name.as_str() == "public" && !c.table_name.contains("migrations")) .count(); assert_eq!(public_schema_columns, 4); diff --git a/crates/pgt_schema_cache/src/policies.rs b/crates/pgt_schema_cache/src/policies.rs index 808a2b7e..b754725b 100644 --- a/crates/pgt_schema_cache/src/policies.rs +++ b/crates/pgt_schema_cache/src/policies.rs @@ -80,28 +80,13 @@ impl SchemaCacheItem for Policy { #[cfg(test)] mod tests { - use pgt_test_utils::test_database::{RoleWithArgs, get_new_test_db}; - use crate::{SchemaCache, policies::PolicyCommand}; - - #[tokio::test] - async fn loads_policies() { - let mut test_db = get_new_test_db().await; + use sqlx::{Executor, PgPool}; - test_db - .setup_roles(vec![ - RoleWithArgs { - role: "admin".into(), - args: vec![], - }, - RoleWithArgs { - role: "owner".into(), - args: vec![], - }, - ]) - .await - .expect("Unable to setup admin roles"); + use crate::{SchemaCache, policies::PolicyCommand}; + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn loads_policies(test_db: PgPool) { let setup = r#" create table public.users ( id serial primary key, diff --git a/crates/pgt_schema_cache/src/roles.rs b/crates/pgt_schema_cache/src/roles.rs index eef02c96..636af34a 100644 --- a/crates/pgt_schema_cache/src/roles.rs +++ b/crates/pgt_schema_cache/src/roles.rs @@ -21,36 +21,12 @@ impl SchemaCacheItem for Role { #[cfg(test)] mod tests { - use crate::SchemaCache; - use pgt_test_utils::test_database::{RoleWithArgs, get_new_test_db}; - - #[tokio::test] - async fn loads_roles() { - let mut test_db = get_new_test_db().await; + use sqlx::PgPool; - test_db - .setup_roles(vec![ - RoleWithArgs { - role: "test_super".into(), - args: vec![ - "superuser".into(), - "createdb".into(), - "login".into(), - "bypassrls".into(), - ], - }, - RoleWithArgs { - role: "test_nologin".into(), - args: vec![], - }, - RoleWithArgs { - role: "test_login".into(), - args: vec!["login".into()], - }, - ]) - .await - .expect("Unable to set up roles."); + use crate::SchemaCache; + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn loads_roles(test_db: PgPool) { let cache = SchemaCache::load(&test_db) .await .expect("Failed to load Schema Cache"); diff --git a/crates/pgt_schema_cache/src/schema_cache.rs b/crates/pgt_schema_cache/src/schema_cache.rs index 516b37e6..8fb9683b 100644 --- a/crates/pgt_schema_cache/src/schema_cache.rs +++ b/crates/pgt_schema_cache/src/schema_cache.rs @@ -93,14 +93,12 @@ pub trait SchemaCacheItem { #[cfg(test)] mod tests { - use pgt_test_utils::test_database::get_new_test_db; + use sqlx::PgPool; use crate::SchemaCache; - #[tokio::test] - async fn it_loads() { - let test_db = get_new_test_db().await; - + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn it_loads(test_db: PgPool) { SchemaCache::load(&test_db) .await .expect("Couldnt' load Schema Cache"); diff --git a/crates/pgt_schema_cache/src/tables.rs b/crates/pgt_schema_cache/src/tables.rs index 98b0be3e..16b86c54 100644 --- a/crates/pgt_schema_cache/src/tables.rs +++ b/crates/pgt_schema_cache/src/tables.rs @@ -79,13 +79,12 @@ impl SchemaCacheItem for Table { #[cfg(test)] mod tests { - use crate::{SchemaCache, tables::TableKind}; - use pgt_test_utils::test_database::get_new_test_db; + use sqlx::{Executor, PgPool}; - #[tokio::test] - async fn includes_views_in_query() { - let test_db = get_new_test_db().await; + use crate::{SchemaCache, tables::TableKind}; + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn includes_views_in_query(test_db: PgPool) { let setup = r#" create table public.base_table ( id serial primary key, @@ -115,10 +114,8 @@ mod tests { assert_eq!(view.schema, "public"); } - #[tokio::test] - async fn includes_materialized_views_in_query() { - let test_db = get_new_test_db().await; - + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn includes_materialized_views_in_query(test_db: PgPool) { let setup = r#" create table public.base_table ( id serial primary key, diff --git a/crates/pgt_schema_cache/src/triggers.rs b/crates/pgt_schema_cache/src/triggers.rs index 80660008..2b2a3aff 100644 --- a/crates/pgt_schema_cache/src/triggers.rs +++ b/crates/pgt_schema_cache/src/triggers.rs @@ -126,17 +126,16 @@ impl SchemaCacheItem for Trigger { #[cfg(test)] mod tests { - use pgt_test_utils::test_database::get_new_test_db; + + use sqlx::{Executor, PgPool}; use crate::{ SchemaCache, triggers::{TriggerAffected, TriggerEvent, TriggerTiming}, }; - #[tokio::test] - async fn loads_triggers() { - let test_db = get_new_test_db().await; - + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn loads_triggers(test_db: PgPool) { let setup = r#" create table public.users ( id serial primary key, @@ -218,10 +217,8 @@ mod tests { assert_eq!(delete_trigger.proc_name, "log_user_insert"); } - #[tokio::test] - async fn loads_instead_and_truncate_triggers() { - let test_db = get_new_test_db().await; - + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn loads_instead_and_truncate_triggers(test_db: PgPool) { let setup = r#" create table public.docs ( id serial primary key, diff --git a/crates/pgt_typecheck/src/typed_identifier.rs b/crates/pgt_typecheck/src/typed_identifier.rs index 49152c2e..710b2fe9 100644 --- a/crates/pgt_typecheck/src/typed_identifier.rs +++ b/crates/pgt_typecheck/src/typed_identifier.rs @@ -231,10 +231,10 @@ fn resolve_type<'a>( #[cfg(test)] mod tests { - use pgt_test_utils::test_database::get_new_test_db; + use sqlx::{Executor, PgPool}; - #[tokio::test] - async fn test_apply_identifiers() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn test_apply_identifiers(test_db: PgPool) { let input = "select v_test + fn_name.custom_type.v_test2 + $3 + custom_type.v_test3 + fn_name.v_test2 + enum_type"; let identifiers = vec![ @@ -294,8 +294,6 @@ mod tests { }, ]; - let test_db = get_new_test_db().await; - let setup = r#" CREATE TYPE "public"."custom_type" AS ( v_test2 integer, diff --git a/crates/pgt_typecheck/tests/diagnostics.rs b/crates/pgt_typecheck/tests/diagnostics.rs index 740669e7..9bf5d786 100644 --- a/crates/pgt_typecheck/tests/diagnostics.rs +++ b/crates/pgt_typecheck/tests/diagnostics.rs @@ -3,12 +3,10 @@ use pgt_console::{ markup, }; use pgt_diagnostics::PrintDiagnostic; -use pgt_test_utils::test_database::get_new_test_db; use pgt_typecheck::{TypecheckParams, check_sql}; +use sqlx::{Executor, PgPool}; -async fn test(name: &str, query: &str, setup: Option<&str>) { - let test_db = get_new_test_db().await; - +async fn test(name: &str, query: &str, setup: Option<&str>, test_db: &PgPool) { if let Some(setup) = setup { test_db .execute(setup) @@ -57,8 +55,8 @@ async fn test(name: &str, query: &str, setup: Option<&str>) { }); } -#[tokio::test] -async fn invalid_column() { +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn invalid_column(pool: PgPool) { test( "invalid_column", "select id, unknown from contacts;", @@ -72,6 +70,7 @@ async fn invalid_column() { ); "#, ), + &pool, ) .await; } From 9cd04ddb11571770ce7b4e0f925e7035b7bd1c56 Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 29 May 2025 10:19:02 +0200 Subject: [PATCH 13/27] ok --- crates/pgt_completions/src/test_helper.rs | 2 +- crates/pgt_typecheck/tests/diagnostics.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index 61decce0..1bd5229c 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -44,7 +44,7 @@ pub(crate) async fn get_test_deps( .expect("Failed to execute setup query"); } - let schema_cache = SchemaCache::load(&test_db) + let schema_cache = SchemaCache::load(test_db) .await .expect("Failed to load Schema Cache"); diff --git a/crates/pgt_typecheck/tests/diagnostics.rs b/crates/pgt_typecheck/tests/diagnostics.rs index 9bf5d786..a7448503 100644 --- a/crates/pgt_typecheck/tests/diagnostics.rs +++ b/crates/pgt_typecheck/tests/diagnostics.rs @@ -19,7 +19,7 @@ async fn test(name: &str, query: &str, setup: Option<&str>, test_db: &PgPool) { .set_language(tree_sitter_sql::language()) .expect("Error loading sql language"); - let schema_cache = pgt_schema_cache::SchemaCache::load(&test_db) + let schema_cache = pgt_schema_cache::SchemaCache::load(test_db) .await .expect("Failed to load Schema Cache"); From b37adda9cd10d442be15b485d2f79f2a1e5039a6 Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 29 May 2025 10:56:05 +0200 Subject: [PATCH 14/27] ok --- crates/pgt_completions/src/providers/roles.rs | 56 +++++++++++-------- .../pgt_completions/src/relevance/scoring.rs | 4 +- crates/pgt_schema_cache/src/roles.rs | 2 +- ...1-setup-roles.sql => 0001_setup-roles.sql} | 5 +- 4 files changed, 39 insertions(+), 28 deletions(-) rename crates/pgt_test_utils/testdb_migrations/{0001-setup-roles.sql => 0001_setup-roles.sql} (97%) diff --git a/crates/pgt_completions/src/providers/roles.rs b/crates/pgt_completions/src/providers/roles.rs index 66d8eefb..5ddf8dc3 100644 --- a/crates/pgt_completions/src/providers/roles.rs +++ b/crates/pgt_completions/src/providers/roles.rs @@ -27,19 +27,11 @@ pub fn complete_roles<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionB #[cfg(test)] mod tests { + use sqlx::PgPool; + use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}; const SETUP: &'static str = r#" - do $$ - begin - if not exists ( - select from pg_catalog.pg_roles - where rolname = 'test' - ) then - create role test; - end if; - end $$; - create table users ( id serial primary key, email varchar, @@ -47,28 +39,44 @@ mod tests { ); "#; - #[tokio::test] - async fn works_in_drop_role() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn works_in_drop_role(pool: PgPool) { assert_complete_results( format!("drop role {}", CURSOR_POS).as_str(), - vec![CompletionAssertion::LabelAndKind( - "test".into(), - crate::CompletionItemKind::Role, - )], - SETUP, + vec![ + CompletionAssertion::LabelAndKind("admin".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + Some(SETUP), + &pool, ) .await; } - #[tokio::test] - async fn works_in_alter_role() { + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn works_in_alter_role(pool: PgPool) { assert_complete_results( format!("alter role {}", CURSOR_POS).as_str(), - vec![CompletionAssertion::LabelAndKind( - "test".into(), - crate::CompletionItemKind::Role, - )], - SETUP, + vec![ + CompletionAssertion::LabelAndKind("admin".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + Some(SETUP), + &pool, ) .await; } diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index 93fb0762..a0b5efa5 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -208,7 +208,7 @@ impl CompletionScore<'_> { } } - fn get_schema_name(&self) -> &str { + fn get_schema_name(&self) -> Option<&str> { match self.data { CompletionRelevanceData::Function(f) => Some(f.schema.as_str()), CompletionRelevanceData::Table(t) => Some(t.schema.as_str()), @@ -283,7 +283,7 @@ impl CompletionScore<'_> { return; } - let schema = match self.get_schema_name() { + let schema_name = match self.get_schema_name() { Some(s) => s.to_string(), None => return, }; diff --git a/crates/pgt_schema_cache/src/roles.rs b/crates/pgt_schema_cache/src/roles.rs index 636af34a..33187078 100644 --- a/crates/pgt_schema_cache/src/roles.rs +++ b/crates/pgt_schema_cache/src/roles.rs @@ -33,7 +33,7 @@ mod tests { let roles = &cache.roles; - let super_role = roles.iter().find(|r| r.name == "test_super").unwrap(); + let super_role = roles.iter().find(|r| r.name == "admin").unwrap(); assert!(super_role.is_super_user); assert!(super_role.can_create_db); assert!(super_role.can_login); diff --git a/crates/pgt_test_utils/testdb_migrations/0001-setup-roles.sql b/crates/pgt_test_utils/testdb_migrations/0001_setup-roles.sql similarity index 97% rename from crates/pgt_test_utils/testdb_migrations/0001-setup-roles.sql rename to crates/pgt_test_utils/testdb_migrations/0001_setup-roles.sql index 67e6dfec..f247a599 100644 --- a/crates/pgt_test_utils/testdb_migrations/0001-setup-roles.sql +++ b/crates/pgt_test_utils/testdb_migrations/0001_setup-roles.sql @@ -1,5 +1,6 @@ do $$ begin + if not exists ( select from pg_catalog.pg_roles where rolname = 'admin' @@ -20,4 +21,6 @@ if not exists ( ) then create role test_nologin; end if; -end; \ No newline at end of file + +end +$$; \ No newline at end of file From cc387572641d475fedaf9eb693cba6e60b5aca21 Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 29 May 2025 11:00:48 +0200 Subject: [PATCH 15/27] adjust test --- crates/pgt_schema_cache/src/policies.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/pgt_schema_cache/src/policies.rs b/crates/pgt_schema_cache/src/policies.rs index b754725b..af5e72d4 100644 --- a/crates/pgt_schema_cache/src/policies.rs +++ b/crates/pgt_schema_cache/src/policies.rs @@ -125,10 +125,10 @@ mod tests { owner_id int not null ); - create policy owner_policy + create policy test_nologin_policy on real_estate.properties for update - to owner + to test_nologin using (owner_id = current_user::int); "#; @@ -186,13 +186,13 @@ mod tests { let owner_policy = cache .policies .iter() - .find(|p| p.name == "owner_policy") + .find(|p| p.name == "test_nologin_policy") .unwrap(); assert_eq!(owner_policy.table_name, "properties"); assert_eq!(owner_policy.schema_name, "real_estate"); assert!(owner_policy.is_permissive); assert_eq!(owner_policy.command, PolicyCommand::Update); - assert_eq!(owner_policy.role_names, vec!["owner"]); + assert_eq!(owner_policy.role_names, vec!["test_nologin"]); assert_eq!( owner_policy.security_qualification, Some("(owner_id = (CURRENT_USER)::integer)".into()) From a1e1a9cf4b6eafcec101dccc372daf2362312ff4 Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 29 May 2025 11:13:17 +0200 Subject: [PATCH 16/27] ok --- crates/pgt_completions/src/providers/roles.rs | 44 +++++++++++++++++-- .../src/relevance/filtering.rs | 13 ++++-- 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/crates/pgt_completions/src/providers/roles.rs b/crates/pgt_completions/src/providers/roles.rs index 5ddf8dc3..01a905a1 100644 --- a/crates/pgt_completions/src/providers/roles.rs +++ b/crates/pgt_completions/src/providers/roles.rs @@ -27,7 +27,7 @@ pub fn complete_roles<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionB #[cfg(test)] mod tests { - use sqlx::PgPool; + use sqlx::{Executor, PgPool}; use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}; @@ -81,9 +81,45 @@ mod tests { .await; } - async fn works_in_set_statement() { - // set role ROLE; - // set session authorization ROLE; + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn works_in_set_statement(pool: PgPool) { + pool.execute(SETUP).await.unwrap(); + + assert_complete_results( + format!("set role {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("admin".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; + + assert_complete_results( + format!("set session authorization {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("admin".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; } async fn works_in_policies() {} diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index e19db507..ba9b5256 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -76,7 +76,8 @@ impl CompletionFilter<'_> { CompletionRelevanceData::Table(_) => match clause { WrappingClause::Select | WrappingClause::Where - | WrappingClause::ColumnDefinitions => false, + | WrappingClause::ColumnDefinitions + | WrappingClause::SetStatement => false, WrappingClause::Insert => { ctx.wrapping_node_kind @@ -101,6 +102,7 @@ impl CompletionFilter<'_> { match clause { WrappingClause::From | WrappingClause::ColumnDefinitions + | WrappingClause::SetStatement | WrappingClause::AlterTable | WrappingClause::DropTable => false, @@ -170,9 +172,12 @@ impl CompletionFilter<'_> { matches!(clause, WrappingClause::PolicyName) } - CompletionRelevanceData::Role(_) => { - matches!(clause, WrappingClause::DropRole | WrappingClause::AlterRole) - } + CompletionRelevanceData::Role(_) => match clause { + WrappingClause::DropRole | WrappingClause::AlterRole => true, + WrappingClause::SetStatement => ctx + .before_cursor_matches_kind(&["keyword_role", "keyword_authorization"]), + _ => false, + }, } }) .and_then(|is_ok| if is_ok { Some(()) } else { None }) From 9c3184eb49eec62c1a82af532cf7496cb4f24de6 Mon Sep 17 00:00:00 2001 From: Julian Date: Mon, 2 Jun 2025 08:55:52 +0200 Subject: [PATCH 17/27] ok --- crates/pgt_completions/src/relevance/scoring.rs | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index 178d98a6..a0b5efa5 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -208,17 +208,6 @@ impl CompletionScore<'_> { } } - fn get_schema_name(&self) -> &str { - match self.data { - CompletionRelevanceData::Table(t) => t.name.as_str(), - CompletionRelevanceData::Function(f) => f.name.as_str(), - CompletionRelevanceData::Column(c) => c.name.as_str(), - CompletionRelevanceData::Schema(s) => s.name.as_str(), - CompletionRelevanceData::Policy(p) => p.name.as_str(), - CompletionRelevanceData::Role(r) => r.name.as_str(), - } - } - fn get_schema_name(&self) -> Option<&str> { match self.data { CompletionRelevanceData::Function(f) => Some(f.schema.as_str()), From 299e46960279abb4b77de919e62dd8ccddaa01eb Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 3 Jun 2025 08:42:05 +0200 Subject: [PATCH 18/27] ok --- .../src/context/grant_parser.rs | 430 ++++++++++++++++++ crates/pgt_completions/src/context/mod.rs | 42 +- .../src/context/parser_helper.rs | 132 ++++++ .../src/context/policy_parser.rs | 160 ++----- crates/pgt_completions/src/providers/roles.rs | 136 +++++- .../src/relevance/filtering.rs | 13 +- crates/pgt_lsp/src/session.rs | 3 + crates/pgt_workspace/src/settings.rs | 39 -- 8 files changed, 778 insertions(+), 177 deletions(-) create mode 100644 crates/pgt_completions/src/context/grant_parser.rs create mode 100644 crates/pgt_completions/src/context/parser_helper.rs diff --git a/crates/pgt_completions/src/context/grant_parser.rs b/crates/pgt_completions/src/context/grant_parser.rs new file mode 100644 index 00000000..f92c71d3 --- /dev/null +++ b/crates/pgt_completions/src/context/grant_parser.rs @@ -0,0 +1,430 @@ +use std::iter::Peekable; + +use pgt_text_size::{TextRange, TextSize}; + +use crate::context::parser_helper::{WordWithIndex, sql_to_words}; + +#[derive(Default, Debug, PartialEq, Eq)] +pub(crate) struct GrantContext { + pub table_name: Option, + pub schema_name: Option, + pub node_text: String, + pub node_range: TextRange, + pub node_kind: String, +} + +/// Simple parser that'll turn a policy-related statement into a context object required for +/// completions. +/// The parser will only work if the (trimmed) sql starts with `create policy`, `drop policy`, or `alter policy`. +/// It can only parse policy statements. +pub(crate) struct GrantParser { + tokens: Peekable>, + previous_token: Option, + current_token: Option, + context: GrantContext, + cursor_position: usize, + in_roles_list: bool, +} + +impl GrantParser { + pub(crate) fn looks_like_grant_stmt(sql: &str) -> bool { + let lowercased = sql.to_ascii_lowercase(); + let trimmed = lowercased.trim(); + trimmed.starts_with("grant") + } + + pub(crate) fn get_context(sql: &str, cursor_position: usize) -> GrantContext { + assert!( + Self::looks_like_grant_stmt(sql), + "GrantParser should only be used for GRANT statements. Developer error!" + ); + + match sql_to_words(sql) { + Ok(tokens) => { + let parser = GrantParser { + tokens: tokens.into_iter().peekable(), + context: GrantContext::default(), + previous_token: None, + current_token: None, + cursor_position, + in_roles_list: false, + }; + + parser.parse() + } + Err(_) => GrantContext::default(), + } + } + + fn parse(mut self) -> GrantContext { + while let Some(token) = self.advance() { + if token.is_under_cursor(self.cursor_position) { + self.handle_token_under_cursor(token); + } else { + self.handle_token(token); + } + } + + self.context + } + + fn handle_token_under_cursor(&mut self, token: WordWithIndex) { + if self.previous_token.is_none() { + return; + } + + let previous = self.previous_token.take().unwrap(); + let current = self + .current_token + .as_ref() + .map(|w| w.get_word_without_quotes()); + + match previous + .get_word_without_quotes() + .to_ascii_lowercase() + .as_str() + { + "grant" => { + self.context.node_range = token.get_range(); + self.context.node_kind = "keyword_grant".into(); + self.context.node_text = token.get_word(); + } + "on" if !matches!(current.as_ref().map(|c| c.as_str()), Some("table")) => { + self.handle_table(&token) + } + + "table" => { + self.handle_table(&token); + } + "to" => { + self.context.node_range = token.get_range(); + self.context.node_kind = "grant_role".into(); + self.context.node_text = token.get_word(); + } + p => { + if self.in_roles_list && p.ends_with(',') { + self.context.node_kind = "grant_role".into(); + } + + self.context.node_range = token.get_range(); + self.context.node_text = token.get_word(); + } + } + } + + fn handle_table(&mut self, token: &WordWithIndex) { + if token.get_word_without_quotes().contains('.') { + let (schema_name, table_name) = self.schema_and_table_name(token); + + let schema_name_len = schema_name.len(); + self.context.schema_name = Some(schema_name); + + let offset: u32 = schema_name_len.try_into().expect("Text too long"); + let range_without_schema = token + .get_range() + .checked_expand_start( + TextSize::new(offset + 1), // kill the dot as well + ) + .expect("Text too long"); + + self.context.node_range = range_without_schema; + self.context.node_kind = "grant_table".into(); + + // In practice, we should always have a table name. + // The completion sanitization will add a word after a `.` if nothing follows it; + // the token_text will then look like `schema.REPLACED_TOKEN`. + self.context.node_text = table_name.unwrap_or_default(); + } else { + self.context.node_range = token.get_range(); + self.context.node_text = token.get_word(); + self.context.node_kind = "grant_table".into(); + } + } + + fn handle_token(&mut self, token: WordWithIndex) { + match token.get_word_without_quotes().as_str() { + "on" if !self.next_matches("table") => self.table_with_schema(), + "table" => self.table_with_schema(), + + "to" => { + self.in_roles_list = true; + } + + t => { + if self.in_roles_list && !t.ends_with(',') { + self.in_roles_list = false; + } + } + } + } + + fn next_matches(&mut self, it: &str) -> bool { + self.tokens + .peek() + .is_some_and(|c| c.get_word_without_quotes().as_str() == it) + } + + fn advance(&mut self) -> Option { + // we can't peek back n an iterator, so we'll have to keep track manually. + self.previous_token = self.current_token.take(); + self.current_token = self.tokens.next(); + self.current_token.clone() + } + + fn table_with_schema(&mut self) { + if let Some(token) = self.advance() { + if token.is_under_cursor(self.cursor_position) { + self.handle_token_under_cursor(token); + } else if token.get_word_without_quotes().contains('.') { + let (schema, maybe_table) = self.schema_and_table_name(&token); + self.context.schema_name = Some(schema); + self.context.table_name = maybe_table; + } else { + self.context.table_name = Some(token.get_word()); + } + }; + } + + fn schema_and_table_name(&self, token: &WordWithIndex) -> (String, Option) { + let word = token.get_word_without_quotes(); + let mut parts = word.split('.'); + + ( + parts.next().unwrap().into(), + parts.next().map(|tb| tb.into()), + ) + } +} + +#[cfg(test)] +mod tests { + use pgt_text_size::{TextRange, TextSize}; + + use crate::{ + context::grant_parser::{GrantContext, GrantParser}, + test_helper::CURSOR_POS, + }; + + fn with_pos(query: String) -> (usize, String) { + let mut pos: Option = None; + + for (p, c) in query.char_indices() { + if c == CURSOR_POS { + pos = Some(p); + break; + } + } + + ( + pos.expect("Please add cursor position!"), + query.replace(CURSOR_POS, "REPLACED_TOKEN").to_string(), + ) + } + + #[test] + fn infers_grant_keyword() { + let (pos, query) = with_pos(format!( + r#" + grant {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: None, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(19), TextSize::new(33)), + node_kind: "keyword_grant".into(), + } + ); + } + + #[test] + fn infers_table_name() { + let (pos, query) = with_pos(format!( + r#" + grant select on {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: None, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(29), TextSize::new(43)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_table_name_with_keyword() { + let (pos, query) = with_pos(format!( + r#" + grant select on table {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: None, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(35), TextSize::new(49)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_schema_and_table_name() { + let (pos, query) = with_pos(format!( + r#" + grant select on public.{} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(36), TextSize::new(50)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_schema_and_table_name_with_keyword() { + let (pos, query) = with_pos(format!( + r#" + grant select on table public.{} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(42), TextSize::new(56)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_role_name() { + let (pos, query) = with_pos(format!( + r#" + grant select on public.users to {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: Some("users".into()), + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(45), TextSize::new(59)), + node_kind: "grant_role".into(), + } + ); + } + + #[test] + fn determines_table_name_after_schema() { + let (pos, query) = with_pos(format!( + r#" + grant select on public.{} to test_role + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(36), TextSize::new(50)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_quoted_schema_and_table() { + let (pos, query) = with_pos(format!( + r#" + grant select on "MySchema"."MyTable" to {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: Some("MyTable".into()), + schema_name: Some("MySchema".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(53), TextSize::new(67)), + node_kind: "grant_role".into(), + } + ); + } + + #[test] + fn infers_multiple_roles() { + let (pos, query) = with_pos(format!( + r#" + grant select on public.users to alice, {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: Some("users".into()), + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(52), TextSize::new(66)), + node_kind: "grant_role".into(), + } + ); + } +} diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index 6e0a952c..940f90f2 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -2,6 +2,8 @@ use std::{ cmp, collections::{HashMap, HashSet}, }; +mod grant_parser; +mod parser_helper; mod policy_parser; use pgt_schema_cache::SchemaCache; @@ -13,7 +15,10 @@ use pgt_treesitter_queries::{ use crate::{ NodeText, - context::policy_parser::{PolicyParser, PolicyStmtKind}, + context::{ + grant_parser::GrantParser, + policy_parser::{PolicyParser, PolicyStmtKind}, + }, sanitization::SanitizedCompletionParams, }; @@ -36,6 +41,7 @@ pub enum WrappingClause<'a> { SetStatement, AlterRole, DropRole, + Grant, } #[derive(PartialEq, Eq, Hash, Debug, Clone)] @@ -192,14 +198,48 @@ impl<'a> CompletionContext<'a> { // We infer the context manually. if PolicyParser::looks_like_policy_stmt(¶ms.text) { ctx.gather_policy_context(); + } else if GrantParser::looks_like_grant_stmt(¶ms.text) { + ctx.gather_grant_context(); } else { ctx.gather_tree_context(); ctx.gather_info_from_ts_queries(); } + println!("{:#?}", ctx.text); + println!("{:#?}", ctx.wrapping_clause_type); + println!("{:#?}", ctx.node_under_cursor); + ctx } + fn gather_grant_context(&mut self) { + let grant_context = GrantParser::get_context(self.text, self.position); + + self.node_under_cursor = Some(NodeUnderCursor::CustomNode { + text: grant_context.node_text.into(), + range: grant_context.node_range, + kind: grant_context.node_kind.clone(), + }); + + if grant_context.node_kind == "grant_table" { + self.schema_or_alias_name = grant_context.schema_name.clone(); + } + + if grant_context.table_name.is_some() { + let mut new = HashSet::new(); + new.insert(grant_context.table_name.unwrap()); + self.mentioned_relations + .insert(grant_context.schema_name, new); + } + + self.wrapping_clause_type = match grant_context.node_kind.as_str() { + "keyword_grant" => Some(WrappingClause::Grant), + "grant_role" => Some(WrappingClause::ToRoleAssignment), + "grant_table" => Some(WrappingClause::From), + _ => None, + }; + } + fn gather_policy_context(&mut self) { let policy_context = PolicyParser::get_context(self.text, self.position); diff --git a/crates/pgt_completions/src/context/parser_helper.rs b/crates/pgt_completions/src/context/parser_helper.rs new file mode 100644 index 00000000..11a5dbec --- /dev/null +++ b/crates/pgt_completions/src/context/parser_helper.rs @@ -0,0 +1,132 @@ +use pgt_text_size::{TextRange, TextSize}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct WordWithIndex { + word: String, + start: usize, + end: usize, +} + +impl WordWithIndex { + pub(crate) fn is_under_cursor(&self, cursor_pos: usize) -> bool { + self.start <= cursor_pos && self.end > cursor_pos + } + + pub(crate) fn get_range(&self) -> TextRange { + let start: u32 = self.start.try_into().expect("Text too long"); + let end: u32 = self.end.try_into().expect("Text too long"); + TextRange::new(TextSize::from(start), TextSize::from(end)) + } + + pub(crate) fn get_word_without_quotes(&self) -> String { + self.word.replace('"', "") + } + + pub(crate) fn get_word(&self) -> String { + self.word.clone() + } +} + +/// Note: A policy name within quotation marks will be considered a single word. +pub(crate) fn sql_to_words(sql: &str) -> Result, String> { + let mut words = vec![]; + + let mut start_of_word: Option = None; + let mut current_word = String::new(); + let mut in_quotation_marks = false; + + for (current_position, current_char) in sql.char_indices() { + if (current_char.is_ascii_whitespace() || current_char == ';') + && !current_word.is_empty() + && start_of_word.is_some() + && !in_quotation_marks + { + words.push(WordWithIndex { + word: current_word, + start: start_of_word.unwrap(), + end: current_position, + }); + + current_word = String::new(); + start_of_word = None; + } else if (current_char.is_ascii_whitespace() || current_char == ';') + && current_word.is_empty() + { + // do nothing + } else if current_char == '"' && start_of_word.is_none() { + in_quotation_marks = true; + current_word.push(current_char); + start_of_word = Some(current_position); + } else if current_char == '"' && start_of_word.is_some() { + current_word.push(current_char); + in_quotation_marks = false; + } else if start_of_word.is_some() { + current_word.push(current_char) + } else { + start_of_word = Some(current_position); + current_word.push(current_char); + } + } + + if let Some(start_of_word) = start_of_word { + if !current_word.is_empty() { + words.push(WordWithIndex { + word: current_word, + start: start_of_word, + end: sql.len(), + }); + } + } + + if in_quotation_marks { + Err("String was not closed properly.".into()) + } else { + Ok(words) + } +} + +#[cfg(test)] +mod tests { + use crate::context::parser_helper::{WordWithIndex, sql_to_words}; + + #[test] + fn determines_positions_correctly() { + let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (true);".to_string(); + + let words = sql_to_words(query.as_str()).unwrap(); + + assert_eq!(words[0], to_word("create", 1, 7)); + assert_eq!(words[1], to_word("policy", 8, 14)); + assert_eq!(words[2], to_word("\"my cool pol\"", 15, 28)); + assert_eq!(words[3], to_word("on", 30, 32)); + assert_eq!(words[4], to_word("auth.users", 33, 43)); + assert_eq!(words[5], to_word("as", 45, 47)); + assert_eq!(words[6], to_word("permissive", 48, 58)); + assert_eq!(words[7], to_word("for", 60, 63)); + assert_eq!(words[8], to_word("select", 64, 70)); + assert_eq!(words[9], to_word("to", 73, 75)); + assert_eq!(words[10], to_word("public", 78, 84)); + assert_eq!(words[11], to_word("using", 87, 92)); + assert_eq!(words[12], to_word("(true)", 93, 99)); + } + + #[test] + fn handles_schemas_in_quotation_marks() { + let query = r#"grant select on "public"."users""#.to_string(); + + let words = sql_to_words(query.as_str()).unwrap(); + + assert_eq!(words[0], to_word("grant", 0, 5)); + assert_eq!(words[1], to_word("select", 6, 12)); + assert_eq!(words[2], to_word("on", 13, 15)); + assert_eq!(words[3], to_word(r#""public"."users""#, 16, 32)); + } + + fn to_word(word: &str, start: usize, end: usize) -> WordWithIndex { + WordWithIndex { + word: word.into(), + start, + end, + } + } +} diff --git a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_completions/src/context/policy_parser.rs index db37a13f..3bfb343c 100644 --- a/crates/pgt_completions/src/context/policy_parser.rs +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -2,6 +2,8 @@ use std::iter::Peekable; use pgt_text_size::{TextRange, TextSize}; +use crate::context::parser_helper::{WordWithIndex, sql_to_words}; + #[derive(Default, Debug, PartialEq, Eq)] pub(crate) enum PolicyStmtKind { #[default] @@ -11,90 +13,6 @@ pub(crate) enum PolicyStmtKind { Drop, } -#[derive(Clone, Debug, PartialEq, Eq)] -struct WordWithIndex { - word: String, - start: usize, - end: usize, -} - -impl WordWithIndex { - fn is_under_cursor(&self, cursor_pos: usize) -> bool { - self.start <= cursor_pos && self.end > cursor_pos - } - - fn get_range(&self) -> TextRange { - let start: u32 = self.start.try_into().expect("Text too long"); - let end: u32 = self.end.try_into().expect("Text too long"); - TextRange::new(TextSize::from(start), TextSize::from(end)) - } -} - -/// Note: A policy name within quotation marks will be considered a single word. -fn sql_to_words(sql: &str) -> Result, String> { - let mut words = vec![]; - - let mut start_of_word: Option = None; - let mut current_word = String::new(); - let mut in_quotation_marks = false; - - for (current_position, current_char) in sql.char_indices() { - if (current_char.is_ascii_whitespace() || current_char == ';') - && !current_word.is_empty() - && start_of_word.is_some() - && !in_quotation_marks - { - words.push(WordWithIndex { - word: current_word, - start: start_of_word.unwrap(), - end: current_position, - }); - - current_word = String::new(); - start_of_word = None; - } else if (current_char.is_ascii_whitespace() || current_char == ';') - && current_word.is_empty() - { - // do nothing - } else if current_char == '"' && start_of_word.is_none() { - in_quotation_marks = true; - current_word.push(current_char); - start_of_word = Some(current_position); - } else if current_char == '"' && start_of_word.is_some() { - current_word.push(current_char); - words.push(WordWithIndex { - word: current_word, - start: start_of_word.unwrap(), - end: current_position + 1, - }); - in_quotation_marks = false; - start_of_word = None; - current_word = String::new() - } else if start_of_word.is_some() { - current_word.push(current_char) - } else { - start_of_word = Some(current_position); - current_word.push(current_char); - } - } - - if let Some(start_of_word) = start_of_word { - if !current_word.is_empty() { - words.push(WordWithIndex { - word: current_word, - start: start_of_word, - end: sql.len(), - }); - } - } - - if in_quotation_marks { - Err("String was not closed properly.".into()) - } else { - Ok(words) - } -} - #[derive(Default, Debug, PartialEq, Eq)] pub(crate) struct PolicyContext { pub policy_name: Option, @@ -168,14 +86,18 @@ impl PolicyParser { let previous = self.previous_token.take().unwrap(); - match previous.word.to_ascii_lowercase().as_str() { + match previous + .get_word_without_quotes() + .to_ascii_lowercase() + .as_str() + { "policy" => { self.context.node_range = token.get_range(); self.context.node_kind = "policy_name".into(); - self.context.node_text = token.word; + self.context.node_text = token.get_word(); } "on" => { - if token.word.contains('.') { + if token.get_word_without_quotes().contains('.') { let (schema_name, table_name) = self.schema_and_table_name(&token); let schema_name_len = schema_name.len(); @@ -198,24 +120,28 @@ impl PolicyParser { self.context.node_text = table_name.unwrap_or_default(); } else { self.context.node_range = token.get_range(); - self.context.node_text = token.word; + self.context.node_text = token.get_word(); self.context.node_kind = "policy_table".into(); } } "to" => { self.context.node_range = token.get_range(); self.context.node_kind = "policy_role".into(); - self.context.node_text = token.word; + self.context.node_text = token.get_word(); } _ => { self.context.node_range = token.get_range(); - self.context.node_text = token.word; + self.context.node_text = token.get_word(); } } } fn handle_token(&mut self, token: WordWithIndex) { - match token.word.to_ascii_lowercase().as_str() { + match token + .get_word_without_quotes() + .to_ascii_lowercase() + .as_str() + { "create" if self.next_matches("policy") => { self.context.statement_kind = PolicyStmtKind::Create; } @@ -234,18 +160,22 @@ impl PolicyParser { _ => { if self.prev_matches("policy") { - self.context.policy_name = Some(token.word); + self.context.policy_name = Some(token.get_word()); } } } } fn next_matches(&mut self, it: &str) -> bool { - self.tokens.peek().is_some_and(|c| c.word.as_str() == it) + self.tokens + .peek() + .is_some_and(|c| c.get_word_without_quotes().as_str() == it) } fn prev_matches(&self, it: &str) -> bool { - self.previous_token.as_ref().is_some_and(|t| t.word == it) + self.previous_token + .as_ref() + .is_some_and(|t| t.get_word_without_quotes() == it) } fn advance(&mut self) -> Option { @@ -259,18 +189,19 @@ impl PolicyParser { if let Some(token) = self.advance() { if token.is_under_cursor(self.cursor_position) { self.handle_token_under_cursor(token); - } else if token.word.contains('.') { + } else if token.get_word_without_quotes().contains('.') { let (schema, maybe_table) = self.schema_and_table_name(&token); self.context.schema_name = Some(schema); self.context.table_name = maybe_table; } else { - self.context.table_name = Some(token.word); + self.context.table_name = Some(token.get_word()); } }; } fn schema_and_table_name(&self, token: &WordWithIndex) -> (String, Option) { - let mut parts = token.word.split('.'); + let word = token.get_word_without_quotes(); + let mut parts = word.split('.'); ( parts.next().unwrap().into(), @@ -284,11 +215,11 @@ mod tests { use pgt_text_size::{TextRange, TextSize}; use crate::{ - context::policy_parser::{PolicyContext, PolicyStmtKind, WordWithIndex}, + context::policy_parser::{PolicyContext, PolicyStmtKind}, test_helper::CURSOR_POS, }; - use super::{PolicyParser, sql_to_words}; + use super::PolicyParser; fn with_pos(query: String) -> (usize, String) { let mut pos: Option = None; @@ -508,6 +439,8 @@ mod tests { CURSOR_POS )); + println!("{}", query); + let context = PolicyParser::get_context(query.as_str(), pos); assert_eq!( @@ -585,33 +518,4 @@ mod tests { assert_eq!(context, PolicyContext::default()); } - - fn to_word(word: &str, start: usize, end: usize) -> WordWithIndex { - WordWithIndex { - word: word.into(), - start, - end, - } - } - - #[test] - fn determines_positions_correctly() { - let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (true);".to_string(); - - let words = sql_to_words(query.as_str()).unwrap(); - - assert_eq!(words[0], to_word("create", 1, 7)); - assert_eq!(words[1], to_word("policy", 8, 14)); - assert_eq!(words[2], to_word("\"my cool pol\"", 15, 28)); - assert_eq!(words[3], to_word("on", 30, 32)); - assert_eq!(words[4], to_word("auth.users", 33, 43)); - assert_eq!(words[5], to_word("as", 45, 47)); - assert_eq!(words[6], to_word("permissive", 48, 58)); - assert_eq!(words[7], to_word("for", 60, 63)); - assert_eq!(words[8], to_word("select", 64, 70)); - assert_eq!(words[9], to_word("to", 73, 75)); - assert_eq!(words[10], to_word("public", 78, 84)); - assert_eq!(words[11], to_word("using", 87, 92)); - assert_eq!(words[12], to_word("(true)", 93, 99)); - } } diff --git a/crates/pgt_completions/src/providers/roles.rs b/crates/pgt_completions/src/providers/roles.rs index 01a905a1..0eae1fc5 100644 --- a/crates/pgt_completions/src/providers/roles.rs +++ b/crates/pgt_completions/src/providers/roles.rs @@ -44,7 +44,7 @@ mod tests { assert_complete_results( format!("drop role {}", CURSOR_POS).as_str(), vec![ - CompletionAssertion::LabelAndKind("admin".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), CompletionAssertion::LabelAndKind( "test_login".into(), crate::CompletionItemKind::Role, @@ -65,7 +65,7 @@ mod tests { assert_complete_results( format!("alter role {}", CURSOR_POS).as_str(), vec![ - CompletionAssertion::LabelAndKind("admin".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), CompletionAssertion::LabelAndKind( "test_login".into(), crate::CompletionItemKind::Role, @@ -88,7 +88,7 @@ mod tests { assert_complete_results( format!("set role {}", CURSOR_POS).as_str(), vec![ - CompletionAssertion::LabelAndKind("admin".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), CompletionAssertion::LabelAndKind( "test_login".into(), crate::CompletionItemKind::Role, @@ -106,7 +106,7 @@ mod tests { assert_complete_results( format!("set session authorization {}", CURSOR_POS).as_str(), vec![ - CompletionAssertion::LabelAndKind("admin".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), CompletionAssertion::LabelAndKind( "test_login".into(), crate::CompletionItemKind::Role, @@ -122,11 +122,131 @@ mod tests { .await; } - async fn works_in_policies() {} + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn works_in_policies(pool: PgPool) { + pool.execute(SETUP).await.unwrap(); - async fn works_in_grant_statements() { - // grant select on my_table to ROLE; - // grant ROLE to OTHER_ROLE with admin option; + assert_complete_results( + format!( + r#"create policy "my cool policy" on public.users + as restrictive + for all + to {} + using (true);"#, + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; + + assert_complete_results( + format!( + r#"create policy "my cool policy" on public.users + for select + to {}"#, + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; + } + + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn works_in_grant_statements(pool: PgPool) { + // assert_complete_results( + // format!( + // r#"grant select + // on table public.users + // to {}"#, + // CURSOR_POS + // ) + // .as_str(), + // vec![ + // // recognizing already mentioned roles is not supported for now + // CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + // CompletionAssertion::LabelAndKind( + // "test_login".into(), + // crate::CompletionItemKind::Role, + // ), + // CompletionAssertion::LabelAndKind( + // "test_nologin".into(), + // crate::CompletionItemKind::Role, + // ), + // ], + // None, + // &pool, + // ) + // .await; + + // assert_complete_results( + // format!( + // r#"grant select + // on table public.users + // to owner, {}"#, + // CURSOR_POS + // ) + // .as_str(), + // vec![ + // // recognizing already mentioned roles is not supported for now + // CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + // CompletionAssertion::LabelAndKind( + // "test_login".into(), + // crate::CompletionItemKind::Role, + // ), + // CompletionAssertion::LabelAndKind( + // "test_nologin".into(), + // crate::CompletionItemKind::Role, + // ), + // ], + // None, + // &pool, + // ) + // .await; + + assert_complete_results( + format!(r#"grant {} to owner"#, CURSOR_POS).as_str(), + vec![ + // recognizing already mentioned roles is not supported for now + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; } async fn works_in_revoke_statements() { diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index ba9b5256..6ecdd76b 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -77,6 +77,7 @@ impl CompletionFilter<'_> { WrappingClause::Select | WrappingClause::Where | WrappingClause::ColumnDefinitions + | WrappingClause::ToRoleAssignment | WrappingClause::SetStatement => false, WrappingClause::Insert => { @@ -104,6 +105,7 @@ impl CompletionFilter<'_> { | WrappingClause::ColumnDefinitions | WrappingClause::SetStatement | WrappingClause::AlterTable + | WrappingClause::ToRoleAssignment | WrappingClause::DropTable => false, // We can complete columns in JOIN cluases, but only if we are after the @@ -173,9 +175,18 @@ impl CompletionFilter<'_> { } CompletionRelevanceData::Role(_) => match clause { - WrappingClause::DropRole | WrappingClause::AlterRole => true, + WrappingClause::DropRole + | WrappingClause::AlterRole + | WrappingClause::ToRoleAssignment => true, + + WrappingClause::Grant => ctx + .node_under_cursor + .as_ref() + .is_some_and(|n| n.kind() == "keyword_grant"), + WrappingClause::SetStatement => ctx .before_cursor_matches_kind(&["keyword_role", "keyword_authorization"]), + _ => false, }, } diff --git a/crates/pgt_lsp/src/session.rs b/crates/pgt_lsp/src/session.rs index fd5af2da..ede0469f 100644 --- a/crates/pgt_lsp/src/session.rs +++ b/crates/pgt_lsp/src/session.rs @@ -32,9 +32,11 @@ use tower_lsp::lsp_types::{Unregistration, WorkspaceFolder}; use tracing::{error, info}; pub(crate) struct ClientInformation { + #[allow(dead_code)] /// The name of the client pub(crate) name: String, + #[allow(dead_code)] /// The version of the client pub(crate) version: Option, } @@ -76,6 +78,7 @@ pub(crate) struct Session { struct InitializeParams { /// The capabilities provided by the client as part of [`lsp_types::InitializeParams`] client_capabilities: lsp_types::ClientCapabilities, + #[allow(dead_code)] client_information: Option, root_uri: Option, #[allow(unused)] diff --git a/crates/pgt_workspace/src/settings.rs b/crates/pgt_workspace/src/settings.rs index 08854493..ac55d8a1 100644 --- a/crates/pgt_workspace/src/settings.rs +++ b/crates/pgt_workspace/src/settings.rs @@ -214,45 +214,6 @@ pub struct Settings { pub migrations: Option, } -#[derive(Debug)] -pub struct SettingsHandleMut<'a> { - inner: RwLockWriteGuard<'a, Settings>, -} - -/// Handle object holding a temporary lock on the settings -#[derive(Debug)] -pub struct SettingsHandle<'a> { - inner: RwLockReadGuard<'a, Settings>, -} - -impl<'a> SettingsHandle<'a> { - pub(crate) fn new(settings: &'a RwLock) -> Self { - Self { - inner: settings.read().unwrap(), - } - } -} - -impl AsRef for SettingsHandle<'_> { - fn as_ref(&self) -> &Settings { - &self.inner - } -} - -impl<'a> SettingsHandleMut<'a> { - pub(crate) fn new(settings: &'a RwLock) -> Self { - Self { - inner: settings.write().unwrap(), - } - } -} - -impl AsMut for SettingsHandleMut<'_> { - fn as_mut(&mut self) -> &mut Settings { - &mut self.inner - } -} - impl Settings { /// The [PartialConfiguration] is merged into the workspace #[tracing::instrument(level = "trace", skip(self), err)] From 2d37803e94ff4d15de582594b8d26c320c60b87d Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 3 Jun 2025 08:56:49 +0200 Subject: [PATCH 19/27] quicksave --- .../src/context/grant_parser.rs | 4 +- crates/pgt_completions/src/context/mod.rs | 2 - crates/pgt_completions/src/providers/roles.rs | 96 +++++++++---------- .../src/relevance/filtering.rs | 5 - 4 files changed, 50 insertions(+), 57 deletions(-) diff --git a/crates/pgt_completions/src/context/grant_parser.rs b/crates/pgt_completions/src/context/grant_parser.rs index f92c71d3..d79d81f2 100644 --- a/crates/pgt_completions/src/context/grant_parser.rs +++ b/crates/pgt_completions/src/context/grant_parser.rs @@ -86,7 +86,7 @@ impl GrantParser { { "grant" => { self.context.node_range = token.get_range(); - self.context.node_kind = "keyword_grant".into(); + self.context.node_kind = "grant_role".into(); self.context.node_text = token.get_word(); } "on" if !matches!(current.as_ref().map(|c| c.as_str()), Some("table")) => { @@ -239,7 +239,7 @@ mod tests { schema_name: None, node_text: "REPLACED_TOKEN".into(), node_range: TextRange::new(TextSize::new(19), TextSize::new(33)), - node_kind: "keyword_grant".into(), + node_kind: "grant_role".into(), } ); } diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index 940f90f2..48c0efaa 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -41,7 +41,6 @@ pub enum WrappingClause<'a> { SetStatement, AlterRole, DropRole, - Grant, } #[derive(PartialEq, Eq, Hash, Debug, Clone)] @@ -233,7 +232,6 @@ impl<'a> CompletionContext<'a> { } self.wrapping_clause_type = match grant_context.node_kind.as_str() { - "keyword_grant" => Some(WrappingClause::Grant), "grant_role" => Some(WrappingClause::ToRoleAssignment), "grant_table" => Some(WrappingClause::From), _ => None, diff --git a/crates/pgt_completions/src/providers/roles.rs b/crates/pgt_completions/src/providers/roles.rs index 0eae1fc5..214be493 100644 --- a/crates/pgt_completions/src/providers/roles.rs +++ b/crates/pgt_completions/src/providers/roles.rs @@ -179,55 +179,55 @@ mod tests { #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] async fn works_in_grant_statements(pool: PgPool) { - // assert_complete_results( - // format!( - // r#"grant select - // on table public.users - // to {}"#, - // CURSOR_POS - // ) - // .as_str(), - // vec![ - // // recognizing already mentioned roles is not supported for now - // CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), - // CompletionAssertion::LabelAndKind( - // "test_login".into(), - // crate::CompletionItemKind::Role, - // ), - // CompletionAssertion::LabelAndKind( - // "test_nologin".into(), - // crate::CompletionItemKind::Role, - // ), - // ], - // None, - // &pool, - // ) - // .await; + assert_complete_results( + format!( + r#"grant select + on table public.users + to {}"#, + CURSOR_POS + ) + .as_str(), + vec![ + // recognizing already mentioned roles is not supported for now + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; - // assert_complete_results( - // format!( - // r#"grant select - // on table public.users - // to owner, {}"#, - // CURSOR_POS - // ) - // .as_str(), - // vec![ - // // recognizing already mentioned roles is not supported for now - // CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), - // CompletionAssertion::LabelAndKind( - // "test_login".into(), - // crate::CompletionItemKind::Role, - // ), - // CompletionAssertion::LabelAndKind( - // "test_nologin".into(), - // crate::CompletionItemKind::Role, - // ), - // ], - // None, - // &pool, - // ) - // .await; + assert_complete_results( + format!( + r#"grant select + on table public.users + to owner, {}"#, + CURSOR_POS + ) + .as_str(), + vec![ + // recognizing already mentioned roles is not supported for now + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; assert_complete_results( format!(r#"grant {} to owner"#, CURSOR_POS).as_str(), diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 6ecdd76b..ff50f3b8 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -179,11 +179,6 @@ impl CompletionFilter<'_> { | WrappingClause::AlterRole | WrappingClause::ToRoleAssignment => true, - WrappingClause::Grant => ctx - .node_under_cursor - .as_ref() - .is_some_and(|n| n.kind() == "keyword_grant"), - WrappingClause::SetStatement => ctx .before_cursor_matches_kind(&["keyword_role", "keyword_authorization"]), From 027324f5f684d239582bed676b3db2a812e32367 Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 3 Jun 2025 09:04:04 +0200 Subject: [PATCH 20/27] =?UTF-8?q?reading=20the=20card=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/pgt_schema_cache/src/policies.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/pgt_schema_cache/src/policies.rs b/crates/pgt_schema_cache/src/policies.rs index 798027cf..8e2ee4d7 100644 --- a/crates/pgt_schema_cache/src/policies.rs +++ b/crates/pgt_schema_cache/src/policies.rs @@ -173,7 +173,7 @@ mod tests { let owner_policy = cache .policies .iter() - .find(|p| p.name == "test_nologin_policy") + .find(|p| p.name == "owner_policy") .unwrap(); assert_eq!(owner_policy.table_name, "users"); assert_eq!(owner_policy.schema_name, "public"); From 0dd285f200251abf87e1abdb1c21a39eba5a6f25 Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 3 Jun 2025 09:50:33 +0200 Subject: [PATCH 21/27] wowa wiwa --- .../{parser_helper.rs => base_parser.rs} | 79 +++- .../src/context/grant_parser.rs | 111 +++-- crates/pgt_completions/src/context/mod.rs | 12 +- .../src/context/policy_parser.rs | 105 ++--- .../src/context/revoke_parser.rs | 385 ++++++++++++++++++ 5 files changed, 549 insertions(+), 143 deletions(-) rename crates/pgt_completions/src/context/{parser_helper.rs => base_parser.rs} (64%) create mode 100644 crates/pgt_completions/src/context/revoke_parser.rs diff --git a/crates/pgt_completions/src/context/parser_helper.rs b/crates/pgt_completions/src/context/base_parser.rs similarity index 64% rename from crates/pgt_completions/src/context/parser_helper.rs rename to crates/pgt_completions/src/context/base_parser.rs index 11a5dbec..a89d9610 100644 --- a/crates/pgt_completions/src/context/parser_helper.rs +++ b/crates/pgt_completions/src/context/base_parser.rs @@ -1,5 +1,79 @@ +use std::iter::Peekable; + use pgt_text_size::{TextRange, TextSize}; +pub(crate) struct TokenNavigator { + tokens: Peekable>, + pub previous_token: Option, + pub current_token: Option, +} + +impl TokenNavigator { + pub(crate) fn next_matches(&mut self, options: &[&str]) -> bool { + self.tokens + .peek() + .is_some_and(|c| options.contains(&c.get_word_without_quotes().as_str())) + } + + pub(crate) fn prev_matches(&self, it: &str) -> bool { + self.previous_token + .as_ref() + .is_some_and(|t| t.get_word_without_quotes() == it) + } + + pub(crate) fn advance(&mut self) -> Option { + // we can't peek back n an iterator, so we'll have to keep track manually. + self.previous_token = self.current_token.take(); + self.current_token = self.tokens.next(); + self.current_token.clone() + } +} + +impl From> for TokenNavigator { + fn from(tokens: Vec) -> Self { + TokenNavigator { + tokens: tokens.into_iter().peekable(), + previous_token: None, + current_token: None, + } + } +} + +pub(crate) trait CompletionStatementParser: Sized { + type Context: Default; + const NAME: &'static str; + + fn looks_like_matching_stmt(sql: &str) -> bool; + fn parse(self) -> Self::Context; + fn make_parser(tokens: Vec, cursor_position: usize) -> Self; + + fn get_context(sql: &str, cursor_position: usize) -> Self::Context { + assert!( + Self::looks_like_matching_stmt(sql), + "Using {} for a wrong statement! Developer Error!", + Self::NAME + ); + + match sql_to_words(sql) { + Ok(tokens) => { + let parser = Self::make_parser(tokens, cursor_position); + parser.parse() + } + Err(_) => Self::Context::default(), + } + } +} + +pub(crate) fn schema_and_table_name(token: &WordWithIndex) -> (String, Option) { + let word = token.get_word_without_quotes(); + let mut parts = word.split('.'); + + ( + parts.next().unwrap().into(), + parts.next().map(|tb| tb.into()), + ) +} + #[derive(Clone, Debug, PartialEq, Eq)] pub(crate) struct WordWithIndex { word: String, @@ -29,13 +103,14 @@ impl WordWithIndex { /// Note: A policy name within quotation marks will be considered a single word. pub(crate) fn sql_to_words(sql: &str) -> Result, String> { + let lowercased = sql.to_ascii_lowercase(); let mut words = vec![]; let mut start_of_word: Option = None; let mut current_word = String::new(); let mut in_quotation_marks = false; - for (current_position, current_char) in sql.char_indices() { + for (current_position, current_char) in lowercased.char_indices() { if (current_char.is_ascii_whitespace() || current_char == ';') && !current_word.is_empty() && start_of_word.is_some() @@ -87,7 +162,7 @@ pub(crate) fn sql_to_words(sql: &str) -> Result, String> { #[cfg(test)] mod tests { - use crate::context::parser_helper::{WordWithIndex, sql_to_words}; + use crate::context::base_parser::{WordWithIndex, sql_to_words}; #[test] fn determines_positions_correctly() { diff --git a/crates/pgt_completions/src/context/grant_parser.rs b/crates/pgt_completions/src/context/grant_parser.rs index d79d81f2..82905775 100644 --- a/crates/pgt_completions/src/context/grant_parser.rs +++ b/crates/pgt_completions/src/context/grant_parser.rs @@ -1,8 +1,8 @@ -use std::iter::Peekable; - use pgt_text_size::{TextRange, TextSize}; -use crate::context::parser_helper::{WordWithIndex, sql_to_words}; +use crate::context::base_parser::{ + CompletionStatementParser, TokenNavigator, WordWithIndex, schema_and_table_name, +}; #[derive(Default, Debug, PartialEq, Eq)] pub(crate) struct GrantContext { @@ -18,46 +18,24 @@ pub(crate) struct GrantContext { /// The parser will only work if the (trimmed) sql starts with `create policy`, `drop policy`, or `alter policy`. /// It can only parse policy statements. pub(crate) struct GrantParser { - tokens: Peekable>, - previous_token: Option, - current_token: Option, + navigator: TokenNavigator, context: GrantContext, cursor_position: usize, in_roles_list: bool, } -impl GrantParser { - pub(crate) fn looks_like_grant_stmt(sql: &str) -> bool { +impl CompletionStatementParser for GrantParser { + type Context = GrantContext; + const NAME: &'static str = "GrantParser"; + + fn looks_like_matching_stmt(sql: &str) -> bool { let lowercased = sql.to_ascii_lowercase(); let trimmed = lowercased.trim(); trimmed.starts_with("grant") } - pub(crate) fn get_context(sql: &str, cursor_position: usize) -> GrantContext { - assert!( - Self::looks_like_grant_stmt(sql), - "GrantParser should only be used for GRANT statements. Developer error!" - ); - - match sql_to_words(sql) { - Ok(tokens) => { - let parser = GrantParser { - tokens: tokens.into_iter().peekable(), - context: GrantContext::default(), - previous_token: None, - current_token: None, - cursor_position, - in_roles_list: false, - }; - - parser.parse() - } - Err(_) => GrantContext::default(), - } - } - - fn parse(mut self) -> GrantContext { - while let Some(token) = self.advance() { + fn parse(mut self) -> Self::Context { + while let Some(token) = self.navigator.advance() { if token.is_under_cursor(self.cursor_position) { self.handle_token_under_cursor(token); } else { @@ -68,13 +46,25 @@ impl GrantParser { self.context } + fn make_parser(tokens: Vec, cursor_position: usize) -> Self { + Self { + navigator: tokens.into(), + context: GrantContext::default(), + cursor_position, + in_roles_list: false, + } + } +} + +impl GrantParser { fn handle_token_under_cursor(&mut self, token: WordWithIndex) { - if self.previous_token.is_none() { + if self.navigator.previous_token.is_none() { return; } - let previous = self.previous_token.take().unwrap(); + let previous = self.navigator.previous_token.take().unwrap(); let current = self + .navigator .current_token .as_ref() .map(|w| w.get_word_without_quotes()); @@ -114,7 +104,7 @@ impl GrantParser { fn handle_table(&mut self, token: &WordWithIndex) { if token.get_word_without_quotes().contains('.') { - let (schema_name, table_name) = self.schema_and_table_name(token); + let (schema_name, table_name) = schema_and_table_name(token); let schema_name_len = schema_name.len(); self.context.schema_name = Some(schema_name); @@ -143,7 +133,26 @@ impl GrantParser { fn handle_token(&mut self, token: WordWithIndex) { match token.get_word_without_quotes().as_str() { - "on" if !self.next_matches("table") => self.table_with_schema(), + "on" if !self.navigator.next_matches(&[ + "table", + "schema", + "foreign", + "domain", + "sequence", + "database", + "function", + "procedure", + "routine", + "language", + "large", + "parameter", + "schema", + "tablespace", + "type", + ]) => + { + self.table_with_schema() + } "table" => self.table_with_schema(), "to" => { @@ -158,25 +167,12 @@ impl GrantParser { } } - fn next_matches(&mut self, it: &str) -> bool { - self.tokens - .peek() - .is_some_and(|c| c.get_word_without_quotes().as_str() == it) - } - - fn advance(&mut self) -> Option { - // we can't peek back n an iterator, so we'll have to keep track manually. - self.previous_token = self.current_token.take(); - self.current_token = self.tokens.next(); - self.current_token.clone() - } - fn table_with_schema(&mut self) { - if let Some(token) = self.advance() { + if let Some(token) = self.navigator.advance() { if token.is_under_cursor(self.cursor_position) { self.handle_token_under_cursor(token); } else if token.get_word_without_quotes().contains('.') { - let (schema, maybe_table) = self.schema_and_table_name(&token); + let (schema, maybe_table) = schema_and_table_name(&token); self.context.schema_name = Some(schema); self.context.table_name = maybe_table; } else { @@ -184,16 +180,6 @@ impl GrantParser { } }; } - - fn schema_and_table_name(&self, token: &WordWithIndex) -> (String, Option) { - let word = token.get_word_without_quotes(); - let mut parts = word.split('.'); - - ( - parts.next().unwrap().into(), - parts.next().map(|tb| tb.into()), - ) - } } #[cfg(test)] @@ -201,6 +187,7 @@ mod tests { use pgt_text_size::{TextRange, TextSize}; use crate::{ + context::base_parser::CompletionStatementParser, context::grant_parser::{GrantContext, GrantParser}, test_helper::CURSOR_POS, }; diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index 48c0efaa..bcdbc6f8 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -2,9 +2,10 @@ use std::{ cmp, collections::{HashMap, HashSet}, }; +mod base_parser; mod grant_parser; -mod parser_helper; mod policy_parser; +mod revoke_parser; use pgt_schema_cache::SchemaCache; use pgt_text_size::TextRange; @@ -16,6 +17,7 @@ use pgt_treesitter_queries::{ use crate::{ NodeText, context::{ + base_parser::CompletionStatementParser, grant_parser::GrantParser, policy_parser::{PolicyParser, PolicyStmtKind}, }, @@ -195,19 +197,15 @@ impl<'a> CompletionContext<'a> { // policy handling is important to Supabase, but they are a PostgreSQL specific extension, // so the tree_sitter_sql language does not support it. // We infer the context manually. - if PolicyParser::looks_like_policy_stmt(¶ms.text) { + if PolicyParser::looks_like_matching_stmt(¶ms.text) { ctx.gather_policy_context(); - } else if GrantParser::looks_like_grant_stmt(¶ms.text) { + } else if GrantParser::looks_like_matching_stmt(¶ms.text) { ctx.gather_grant_context(); } else { ctx.gather_tree_context(); ctx.gather_info_from_ts_queries(); } - println!("{:#?}", ctx.text); - println!("{:#?}", ctx.wrapping_clause_type); - println!("{:#?}", ctx.node_under_cursor); - ctx } diff --git a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_completions/src/context/policy_parser.rs index 3bfb343c..1c1b0533 100644 --- a/crates/pgt_completions/src/context/policy_parser.rs +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -1,8 +1,8 @@ -use std::iter::Peekable; - use pgt_text_size::{TextRange, TextSize}; -use crate::context::parser_helper::{WordWithIndex, sql_to_words}; +use crate::context::base_parser::{ + CompletionStatementParser, TokenNavigator, WordWithIndex, schema_and_table_name, sql_to_words, +}; #[derive(Default, Debug, PartialEq, Eq)] pub(crate) enum PolicyStmtKind { @@ -29,15 +29,16 @@ pub(crate) struct PolicyContext { /// The parser will only work if the (trimmed) sql starts with `create policy`, `drop policy`, or `alter policy`. /// It can only parse policy statements. pub(crate) struct PolicyParser { - tokens: Peekable>, - previous_token: Option, - current_token: Option, + navigator: TokenNavigator, context: PolicyContext, cursor_position: usize, } -impl PolicyParser { - pub(crate) fn looks_like_policy_stmt(sql: &str) -> bool { +impl CompletionStatementParser for PolicyParser { + type Context = PolicyContext; + const NAME: &'static str = "PolicyParser"; + + fn looks_like_matching_stmt(sql: &str) -> bool { let lowercased = sql.to_ascii_lowercase(); let trimmed = lowercased.trim(); trimmed.starts_with("create policy") @@ -45,30 +46,8 @@ impl PolicyParser { || trimmed.starts_with("alter policy") } - pub(crate) fn get_context(sql: &str, cursor_position: usize) -> PolicyContext { - assert!( - Self::looks_like_policy_stmt(sql), - "PolicyParser should only be used for policy statements. Developer error!" - ); - - match sql_to_words(sql) { - Ok(tokens) => { - let parser = PolicyParser { - tokens: tokens.into_iter().peekable(), - context: PolicyContext::default(), - previous_token: None, - current_token: None, - cursor_position, - }; - - parser.parse() - } - Err(_) => PolicyContext::default(), - } - } - - fn parse(mut self) -> PolicyContext { - while let Some(token) = self.advance() { + fn parse(mut self) -> Self::Context { + while let Some(token) = self.navigator.advance() { if token.is_under_cursor(self.cursor_position) { self.handle_token_under_cursor(token); } else { @@ -79,12 +58,22 @@ impl PolicyParser { self.context } + fn make_parser(tokens: Vec, cursor_position: usize) -> Self { + Self { + navigator: tokens.into(), + context: PolicyContext::default(), + cursor_position, + } + } +} + +impl PolicyParser { fn handle_token_under_cursor(&mut self, token: WordWithIndex) { - if self.previous_token.is_none() { + if self.navigator.previous_token.is_none() { return; } - let previous = self.previous_token.take().unwrap(); + let previous = self.navigator.previous_token.take().unwrap(); match previous .get_word_without_quotes() @@ -98,7 +87,7 @@ impl PolicyParser { } "on" => { if token.get_word_without_quotes().contains('.') { - let (schema_name, table_name) = self.schema_and_table_name(&token); + let (schema_name, table_name) = schema_and_table_name(&token); let schema_name_len = schema_name.len(); self.context.schema_name = Some(schema_name); @@ -142,55 +131,36 @@ impl PolicyParser { .to_ascii_lowercase() .as_str() { - "create" if self.next_matches("policy") => { + "create" if self.navigator.next_matches(&["policy"]) => { self.context.statement_kind = PolicyStmtKind::Create; } - "alter" if self.next_matches("policy") => { + "alter" if self.navigator.next_matches(&["policy"]) => { self.context.statement_kind = PolicyStmtKind::Alter; } - "drop" if self.next_matches("policy") => { + "drop" if self.navigator.next_matches(&["policy"]) => { self.context.statement_kind = PolicyStmtKind::Drop; } "on" => self.table_with_schema(), // skip the "to" so we don't parse it as the TO rolename when it's under the cursor - "rename" if self.next_matches("to") => { - self.advance(); + "rename" if self.navigator.next_matches(&["to"]) => { + self.navigator.advance(); } _ => { - if self.prev_matches("policy") { + if self.navigator.prev_matches("policy") { self.context.policy_name = Some(token.get_word()); } } } } - fn next_matches(&mut self, it: &str) -> bool { - self.tokens - .peek() - .is_some_and(|c| c.get_word_without_quotes().as_str() == it) - } - - fn prev_matches(&self, it: &str) -> bool { - self.previous_token - .as_ref() - .is_some_and(|t| t.get_word_without_quotes() == it) - } - - fn advance(&mut self) -> Option { - // we can't peek back n an iterator, so we'll have to keep track manually. - self.previous_token = self.current_token.take(); - self.current_token = self.tokens.next(); - self.current_token.clone() - } - fn table_with_schema(&mut self) { - if let Some(token) = self.advance() { + if let Some(token) = self.navigator.advance() { if token.is_under_cursor(self.cursor_position) { self.handle_token_under_cursor(token); } else if token.get_word_without_quotes().contains('.') { - let (schema, maybe_table) = self.schema_and_table_name(&token); + let (schema, maybe_table) = schema_and_table_name(&token); self.context.schema_name = Some(schema); self.context.table_name = maybe_table; } else { @@ -198,16 +168,6 @@ impl PolicyParser { } }; } - - fn schema_and_table_name(&self, token: &WordWithIndex) -> (String, Option) { - let word = token.get_word_without_quotes(); - let mut parts = word.split('.'); - - ( - parts.next().unwrap().into(), - parts.next().map(|tb| tb.into()), - ) - } } #[cfg(test)] @@ -215,6 +175,7 @@ mod tests { use pgt_text_size::{TextRange, TextSize}; use crate::{ + context::base_parser::CompletionStatementParser, context::policy_parser::{PolicyContext, PolicyStmtKind}, test_helper::CURSOR_POS, }; diff --git a/crates/pgt_completions/src/context/revoke_parser.rs b/crates/pgt_completions/src/context/revoke_parser.rs new file mode 100644 index 00000000..9df22a11 --- /dev/null +++ b/crates/pgt_completions/src/context/revoke_parser.rs @@ -0,0 +1,385 @@ +use pgt_text_size::{TextRange, TextSize}; + +use crate::context::base_parser::{ + CompletionStatementParser, TokenNavigator, WordWithIndex, schema_and_table_name, +}; + +#[derive(Default, Debug, PartialEq, Eq)] +pub(crate) struct RevokeContext { + pub table_name: Option, + pub schema_name: Option, + pub node_text: String, + pub node_range: TextRange, + pub node_kind: String, +} + +/// Simple parser that'll turn a policy-related statement into a context object required for +/// completions. +/// The parser will only work if the (trimmed) sql starts with `create policy`, `drop policy`, or `alter policy`. +/// It can only parse policy statements. +pub(crate) struct RevokeParser { + navigator: TokenNavigator, + context: RevokeContext, + cursor_position: usize, +} + +impl CompletionStatementParser for RevokeParser { + type Context = RevokeContext; + const NAME: &'static str = "GrantParser"; + + fn looks_like_matching_stmt(sql: &str) -> bool { + let lowercased = sql.to_ascii_lowercase(); + let trimmed = lowercased.trim(); + trimmed.starts_with("revoke") + } + + fn parse(mut self) -> Self::Context { + while let Some(token) = self.navigator.advance() { + if token.is_under_cursor(self.cursor_position) { + self.handle_token_under_cursor(token); + } else { + self.handle_token(token); + } + } + + self.context + } + + fn make_parser(tokens: Vec, cursor_position: usize) -> Self { + Self { + navigator: tokens.into(), + context: RevokeContext::default(), + cursor_position, + } + } +} + +impl RevokeParser { + fn handle_token_under_cursor(&mut self, token: WordWithIndex) { + if self.navigator.previous_token.is_none() { + return; + } + + let previous = self.navigator.previous_token.take().unwrap(); + let current = self + .navigator + .current_token + .as_ref() + .map(|w| w.get_word_without_quotes()); + + match previous + .get_word_without_quotes() + .to_ascii_lowercase() + .as_str() + { + "grant" => { + self.context.node_range = token.get_range(); + self.context.node_kind = "grant_role".into(); + self.context.node_text = token.get_word(); + } + "on" if !matches!(current.as_ref().map(|c| c.as_str()), Some("table")) => { + self.handle_table(&token) + } + + "table" => { + self.handle_table(&token); + } + + "to" => { + self.context.node_range = token.get_range(); + self.context.node_kind = "grant_role".into(); + self.context.node_text = token.get_word(); + } + + _ => { + self.context.node_range = token.get_range(); + self.context.node_text = token.get_word(); + } + } + } + + fn handle_table(&mut self, token: &WordWithIndex) { + if token.get_word_without_quotes().contains('.') { + let (schema_name, table_name) = schema_and_table_name(token); + + let schema_name_len = schema_name.len(); + self.context.schema_name = Some(schema_name); + + let offset: u32 = schema_name_len.try_into().expect("Text too long"); + let range_without_schema = token + .get_range() + .checked_expand_start( + TextSize::new(offset + 1), // kill the dot as well + ) + .expect("Text too long"); + + self.context.node_range = range_without_schema; + self.context.node_kind = "grant_table".into(); + + // In practice, we should always have a table name. + // The completion sanitization will add a word after a `.` if nothing follows it; + // the token_text will then look like `schema.REPLACED_TOKEN`. + self.context.node_text = table_name.unwrap_or_default(); + } else { + self.context.node_range = token.get_range(); + self.context.node_text = token.get_word(); + self.context.node_kind = "grant_table".into(); + } + } + + fn handle_token(&mut self, token: WordWithIndex) { + match token.get_word_without_quotes().as_str() { + "on" if !self.navigator.next_matches(&["table"]) => self.table_with_schema(), + "table" => self.table_with_schema(), + _ => {} + } + } + + fn table_with_schema(&mut self) { + if let Some(token) = self.navigator.advance() { + if token.is_under_cursor(self.cursor_position) { + self.handle_token_under_cursor(token); + } else if token.get_word_without_quotes().contains('.') { + let (schema, maybe_table) = schema_and_table_name(&token); + self.context.schema_name = Some(schema); + self.context.table_name = maybe_table; + } else { + self.context.table_name = Some(token.get_word()); + } + }; + } +} + +#[cfg(test)] +mod tests { + use pgt_text_size::{TextRange, TextSize}; + + use crate::{ + context::base_parser::CompletionStatementParser, + context::grant_parser::{GrantContext, GrantParser}, + test_helper::CURSOR_POS, + }; + + fn with_pos(query: String) -> (usize, String) { + let mut pos: Option = None; + + for (p, c) in query.char_indices() { + if c == CURSOR_POS { + pos = Some(p); + break; + } + } + + ( + pos.expect("Please add cursor position!"), + query.replace(CURSOR_POS, "REPLACED_TOKEN").to_string(), + ) + } + + #[test] + fn infers_grant_keyword() { + let (pos, query) = with_pos(format!( + r#" + grant {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: None, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(19), TextSize::new(33)), + node_kind: "grant_role".into(), + } + ); + } + + #[test] + fn infers_table_name() { + let (pos, query) = with_pos(format!( + r#" + grant select on {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: None, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(29), TextSize::new(43)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_table_name_with_keyword() { + let (pos, query) = with_pos(format!( + r#" + grant select on table {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: None, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(35), TextSize::new(49)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_schema_and_table_name() { + let (pos, query) = with_pos(format!( + r#" + grant select on public.{} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(36), TextSize::new(50)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_schema_and_table_name_with_keyword() { + let (pos, query) = with_pos(format!( + r#" + grant select on table public.{} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(42), TextSize::new(56)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_role_name() { + let (pos, query) = with_pos(format!( + r#" + grant select on public.users to {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: Some("users".into()), + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(45), TextSize::new(59)), + node_kind: "grant_role".into(), + } + ); + } + + #[test] + fn determines_table_name_after_schema() { + let (pos, query) = with_pos(format!( + r#" + grant select on public.{} to test_role + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(36), TextSize::new(50)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_quoted_schema_and_table() { + let (pos, query) = with_pos(format!( + r#" + grant select on "MySchema"."MyTable" to {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: Some("MyTable".into()), + schema_name: Some("MySchema".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(53), TextSize::new(67)), + node_kind: "grant_role".into(), + } + ); + } + + #[test] + fn infers_multiple_roles() { + let (pos, query) = with_pos(format!( + r#" + grant select on public.users to alice, {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: Some("users".into()), + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(52), TextSize::new(66)), + node_kind: "grant_role".into(), + } + ); + } +} From f72297a5e8247eb81b309eb3302d19e6790d43ed Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 5 Jun 2025 08:37:05 +0200 Subject: [PATCH 22/27] ok --- .../src/context/base_parser.rs | 7 +- .../src/context/grant_parser.rs | 4 +- crates/pgt_completions/src/context/mod.rs | 3 + .../src/context/policy_parser.rs | 4 +- .../src/context/revoke_parser.rs | 281 +++--------------- 5 files changed, 46 insertions(+), 253 deletions(-) diff --git a/crates/pgt_completions/src/context/base_parser.rs b/crates/pgt_completions/src/context/base_parser.rs index a89d9610..93333679 100644 --- a/crates/pgt_completions/src/context/base_parser.rs +++ b/crates/pgt_completions/src/context/base_parser.rs @@ -15,10 +15,10 @@ impl TokenNavigator { .is_some_and(|c| options.contains(&c.get_word_without_quotes().as_str())) } - pub(crate) fn prev_matches(&self, it: &str) -> bool { + pub(crate) fn prev_matches(&self, options: &[&str]) -> bool { self.previous_token .as_ref() - .is_some_and(|t| t.get_word_without_quotes() == it) + .is_some_and(|t| options.contains(&t.get_word_without_quotes().as_str())) } pub(crate) fn advance(&mut self) -> Option { @@ -103,14 +103,13 @@ impl WordWithIndex { /// Note: A policy name within quotation marks will be considered a single word. pub(crate) fn sql_to_words(sql: &str) -> Result, String> { - let lowercased = sql.to_ascii_lowercase(); let mut words = vec![]; let mut start_of_word: Option = None; let mut current_word = String::new(); let mut in_quotation_marks = false; - for (current_position, current_char) in lowercased.char_indices() { + for (current_position, current_char) in sql.char_indices() { if (current_char.is_ascii_whitespace() || current_char == ';') && !current_word.is_empty() && start_of_word.is_some() diff --git a/crates/pgt_completions/src/context/grant_parser.rs b/crates/pgt_completions/src/context/grant_parser.rs index 82905775..8a0bcd1e 100644 --- a/crates/pgt_completions/src/context/grant_parser.rs +++ b/crates/pgt_completions/src/context/grant_parser.rs @@ -91,8 +91,8 @@ impl GrantParser { self.context.node_kind = "grant_role".into(); self.context.node_text = token.get_word(); } - p => { - if self.in_roles_list && p.ends_with(',') { + t => { + if self.in_roles_list && t.ends_with(',') { self.context.node_kind = "grant_role".into(); } diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index bcdbc6f8..0fbded30 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -20,6 +20,7 @@ use crate::{ base_parser::CompletionStatementParser, grant_parser::GrantParser, policy_parser::{PolicyParser, PolicyStmtKind}, + revoke_parser::RevokeParser, }, sanitization::SanitizedCompletionParams, }; @@ -201,6 +202,8 @@ impl<'a> CompletionContext<'a> { ctx.gather_policy_context(); } else if GrantParser::looks_like_matching_stmt(¶ms.text) { ctx.gather_grant_context(); + } else if RevokeParser::looks_like_matching_stmt(¶ms.text) { + todo!() } else { ctx.gather_tree_context(); ctx.gather_info_from_ts_queries(); diff --git a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_completions/src/context/policy_parser.rs index 1c1b0533..465b8ff3 100644 --- a/crates/pgt_completions/src/context/policy_parser.rs +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -1,7 +1,7 @@ use pgt_text_size::{TextRange, TextSize}; use crate::context::base_parser::{ - CompletionStatementParser, TokenNavigator, WordWithIndex, schema_and_table_name, sql_to_words, + CompletionStatementParser, TokenNavigator, WordWithIndex, schema_and_table_name, }; #[derive(Default, Debug, PartialEq, Eq)] @@ -148,7 +148,7 @@ impl PolicyParser { } _ => { - if self.navigator.prev_matches("policy") { + if self.navigator.prev_matches(&["policy"]) { self.context.policy_name = Some(token.get_word()); } } diff --git a/crates/pgt_completions/src/context/revoke_parser.rs b/crates/pgt_completions/src/context/revoke_parser.rs index 9df22a11..8afe73ff 100644 --- a/crates/pgt_completions/src/context/revoke_parser.rs +++ b/crates/pgt_completions/src/context/revoke_parser.rs @@ -21,6 +21,8 @@ pub(crate) struct RevokeParser { navigator: TokenNavigator, context: RevokeContext, cursor_position: usize, + in_roles_list: bool, + is_revoking_role: bool, } impl CompletionStatementParser for RevokeParser { @@ -50,6 +52,8 @@ impl CompletionStatementParser for RevokeParser { navigator: tokens.into(), context: RevokeContext::default(), cursor_position, + in_roles_list: false, + is_revoking_role: false, } } } @@ -72,11 +76,6 @@ impl RevokeParser { .to_ascii_lowercase() .as_str() { - "grant" => { - self.context.node_range = token.get_range(); - self.context.node_kind = "grant_role".into(); - self.context.node_text = token.get_word(); - } "on" if !matches!(current.as_ref().map(|c| c.as_str()), Some("table")) => { self.handle_table(&token) } @@ -85,13 +84,23 @@ impl RevokeParser { self.handle_table(&token); } - "to" => { + "from" | "revoke" => { self.context.node_range = token.get_range(); - self.context.node_kind = "grant_role".into(); + self.context.node_kind = "revoke_role".into(); self.context.node_text = token.get_word(); } - _ => { + "for" if self.is_revoking_role => { + self.context.node_range = token.get_range(); + self.context.node_kind = "revoke_role".into(); + self.context.node_text = token.get_word(); + } + + t => { + if self.in_roles_list && t.ends_with(',') { + self.context.node_kind = "grant_role".into(); + } + self.context.node_range = token.get_range(); self.context.node_text = token.get_word(); } @@ -114,7 +123,7 @@ impl RevokeParser { .expect("Text too long"); self.context.node_range = range_without_schema; - self.context.node_kind = "grant_table".into(); + self.context.node_kind = "revoke_table".into(); // In practice, we should always have a table name. // The completion sanitization will add a word after a `.` if nothing follows it; @@ -123,15 +132,31 @@ impl RevokeParser { } else { self.context.node_range = token.get_range(); self.context.node_text = token.get_word(); - self.context.node_kind = "grant_table".into(); + self.context.node_kind = "revoke_table".into(); } } fn handle_token(&mut self, token: WordWithIndex) { match token.get_word_without_quotes().as_str() { "on" if !self.navigator.next_matches(&["table"]) => self.table_with_schema(), + + // This is the only case where there is no "GRANT" before the option: + // REVOKE [ { ADMIN | INHERIT | SET } OPTION FOR ] role_name + "option" if !self.navigator.prev_matches(&["grant"]) => { + self.is_revoking_role = true; + } + "table" => self.table_with_schema(), - _ => {} + + "from" => { + self.in_roles_list = true; + } + + t => { + if self.in_roles_list && !t.ends_with(',') { + self.in_roles_list = false; + } + } } } @@ -149,237 +174,3 @@ impl RevokeParser { }; } } - -#[cfg(test)] -mod tests { - use pgt_text_size::{TextRange, TextSize}; - - use crate::{ - context::base_parser::CompletionStatementParser, - context::grant_parser::{GrantContext, GrantParser}, - test_helper::CURSOR_POS, - }; - - fn with_pos(query: String) -> (usize, String) { - let mut pos: Option = None; - - for (p, c) in query.char_indices() { - if c == CURSOR_POS { - pos = Some(p); - break; - } - } - - ( - pos.expect("Please add cursor position!"), - query.replace(CURSOR_POS, "REPLACED_TOKEN").to_string(), - ) - } - - #[test] - fn infers_grant_keyword() { - let (pos, query) = with_pos(format!( - r#" - grant {} - "#, - CURSOR_POS - )); - - let context = GrantParser::get_context(query.as_str(), pos); - - assert_eq!( - context, - GrantContext { - table_name: None, - schema_name: None, - node_text: "REPLACED_TOKEN".into(), - node_range: TextRange::new(TextSize::new(19), TextSize::new(33)), - node_kind: "grant_role".into(), - } - ); - } - - #[test] - fn infers_table_name() { - let (pos, query) = with_pos(format!( - r#" - grant select on {} - "#, - CURSOR_POS - )); - - let context = GrantParser::get_context(query.as_str(), pos); - - assert_eq!( - context, - GrantContext { - table_name: None, - schema_name: None, - node_text: "REPLACED_TOKEN".into(), - node_range: TextRange::new(TextSize::new(29), TextSize::new(43)), - node_kind: "grant_table".into(), - } - ); - } - - #[test] - fn infers_table_name_with_keyword() { - let (pos, query) = with_pos(format!( - r#" - grant select on table {} - "#, - CURSOR_POS - )); - - let context = GrantParser::get_context(query.as_str(), pos); - - assert_eq!( - context, - GrantContext { - table_name: None, - schema_name: None, - node_text: "REPLACED_TOKEN".into(), - node_range: TextRange::new(TextSize::new(35), TextSize::new(49)), - node_kind: "grant_table".into(), - } - ); - } - - #[test] - fn infers_schema_and_table_name() { - let (pos, query) = with_pos(format!( - r#" - grant select on public.{} - "#, - CURSOR_POS - )); - - let context = GrantParser::get_context(query.as_str(), pos); - - assert_eq!( - context, - GrantContext { - table_name: None, - schema_name: Some("public".into()), - node_text: "REPLACED_TOKEN".into(), - node_range: TextRange::new(TextSize::new(36), TextSize::new(50)), - node_kind: "grant_table".into(), - } - ); - } - - #[test] - fn infers_schema_and_table_name_with_keyword() { - let (pos, query) = with_pos(format!( - r#" - grant select on table public.{} - "#, - CURSOR_POS - )); - - let context = GrantParser::get_context(query.as_str(), pos); - - assert_eq!( - context, - GrantContext { - table_name: None, - schema_name: Some("public".into()), - node_text: "REPLACED_TOKEN".into(), - node_range: TextRange::new(TextSize::new(42), TextSize::new(56)), - node_kind: "grant_table".into(), - } - ); - } - - #[test] - fn infers_role_name() { - let (pos, query) = with_pos(format!( - r#" - grant select on public.users to {} - "#, - CURSOR_POS - )); - - let context = GrantParser::get_context(query.as_str(), pos); - - assert_eq!( - context, - GrantContext { - table_name: Some("users".into()), - schema_name: Some("public".into()), - node_text: "REPLACED_TOKEN".into(), - node_range: TextRange::new(TextSize::new(45), TextSize::new(59)), - node_kind: "grant_role".into(), - } - ); - } - - #[test] - fn determines_table_name_after_schema() { - let (pos, query) = with_pos(format!( - r#" - grant select on public.{} to test_role - "#, - CURSOR_POS - )); - - let context = GrantParser::get_context(query.as_str(), pos); - - assert_eq!( - context, - GrantContext { - table_name: None, - schema_name: Some("public".into()), - node_text: "REPLACED_TOKEN".into(), - node_range: TextRange::new(TextSize::new(36), TextSize::new(50)), - node_kind: "grant_table".into(), - } - ); - } - - #[test] - fn infers_quoted_schema_and_table() { - let (pos, query) = with_pos(format!( - r#" - grant select on "MySchema"."MyTable" to {} - "#, - CURSOR_POS - )); - - let context = GrantParser::get_context(query.as_str(), pos); - - assert_eq!( - context, - GrantContext { - table_name: Some("MyTable".into()), - schema_name: Some("MySchema".into()), - node_text: "REPLACED_TOKEN".into(), - node_range: TextRange::new(TextSize::new(53), TextSize::new(67)), - node_kind: "grant_role".into(), - } - ); - } - - #[test] - fn infers_multiple_roles() { - let (pos, query) = with_pos(format!( - r#" - grant select on public.users to alice, {} - "#, - CURSOR_POS - )); - - let context = GrantParser::get_context(query.as_str(), pos); - - assert_eq!( - context, - GrantContext { - table_name: Some("users".into()), - schema_name: Some("public".into()), - node_text: "REPLACED_TOKEN".into(), - node_range: TextRange::new(TextSize::new(52), TextSize::new(66)), - node_kind: "grant_role".into(), - } - ); - } -} From 578741eea19465dede1c7371d9f95c602ab83712 Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 5 Jun 2025 08:43:38 +0200 Subject: [PATCH 23/27] =?UTF-8?q?lowercase=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/pgt_completions/src/context/policy_parser.rs | 2 -- crates/pgt_completions/src/sanitization.rs | 4 +++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_completions/src/context/policy_parser.rs index 465b8ff3..58619502 100644 --- a/crates/pgt_completions/src/context/policy_parser.rs +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -400,8 +400,6 @@ mod tests { CURSOR_POS )); - println!("{}", query); - let context = PolicyParser::get_context(query.as_str(), pos); assert_eq!( diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 40dea7e6..69dadcbb 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -6,6 +6,7 @@ use crate::CompletionParams; static SANITIZED_TOKEN: &str = "REPLACED_TOKEN"; +#[derive(Debug)] pub(crate) struct SanitizedCompletionParams<'a> { pub position: TextSize, pub text: String, @@ -48,7 +49,8 @@ impl<'larger, 'smaller> From> for SanitizedCompletionP where 'larger: 'smaller, { - fn from(params: CompletionParams<'larger>) -> Self { + fn from(mut params: CompletionParams<'larger>) -> Self { + params.text = params.text.to_ascii_lowercase(); if cursor_inbetween_nodes(¶ms.text, params.position) || cursor_prepared_to_write_token_after_last_node(¶ms.text, params.position) || cursor_before_semicolon(params.tree, params.position) From ef5cb98bbc679c2183358204c88674548a5d7f46 Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 5 Jun 2025 09:07:25 +0200 Subject: [PATCH 24/27] wowa wiwa --- crates/pgt_completions/src/context/mod.rs | 29 ++++++++++++- .../src/context/revoke_parser.rs | 4 +- crates/pgt_completions/src/providers/roles.rs | 42 +++++++++++++++++-- 3 files changed, 68 insertions(+), 7 deletions(-) diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index 0fbded30..2db9ea1c 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -203,7 +203,7 @@ impl<'a> CompletionContext<'a> { } else if GrantParser::looks_like_matching_stmt(¶ms.text) { ctx.gather_grant_context(); } else if RevokeParser::looks_like_matching_stmt(¶ms.text) { - todo!() + ctx.gather_revoke_context(); } else { ctx.gather_tree_context(); ctx.gather_info_from_ts_queries(); @@ -212,6 +212,33 @@ impl<'a> CompletionContext<'a> { ctx } + fn gather_revoke_context(&mut self) { + let revoke_context = RevokeParser::get_context(self.text, self.position); + + self.node_under_cursor = Some(NodeUnderCursor::CustomNode { + text: revoke_context.node_text.into(), + range: revoke_context.node_range, + kind: revoke_context.node_kind.clone(), + }); + + if revoke_context.node_kind == "revoke_table" { + self.schema_or_alias_name = revoke_context.schema_name.clone(); + } + + if revoke_context.table_name.is_some() { + let mut new = HashSet::new(); + new.insert(revoke_context.table_name.unwrap()); + self.mentioned_relations + .insert(revoke_context.schema_name, new); + } + + self.wrapping_clause_type = match revoke_context.node_kind.as_str() { + "revoke_role" => Some(WrappingClause::ToRoleAssignment), + "revoke_table" => Some(WrappingClause::From), + _ => None, + }; + } + fn gather_grant_context(&mut self) { let grant_context = GrantParser::get_context(self.text, self.position); diff --git a/crates/pgt_completions/src/context/revoke_parser.rs b/crates/pgt_completions/src/context/revoke_parser.rs index 8afe73ff..206321b0 100644 --- a/crates/pgt_completions/src/context/revoke_parser.rs +++ b/crates/pgt_completions/src/context/revoke_parser.rs @@ -27,7 +27,7 @@ pub(crate) struct RevokeParser { impl CompletionStatementParser for RevokeParser { type Context = RevokeContext; - const NAME: &'static str = "GrantParser"; + const NAME: &'static str = "RevokeParser"; fn looks_like_matching_stmt(sql: &str) -> bool { let lowercased = sql.to_ascii_lowercase(); @@ -98,7 +98,7 @@ impl RevokeParser { t => { if self.in_roles_list && t.ends_with(',') { - self.context.node_kind = "grant_role".into(); + self.context.node_kind = "revoke_role".into(); } self.context.node_range = token.get_range(); diff --git a/crates/pgt_completions/src/providers/roles.rs b/crates/pgt_completions/src/providers/roles.rs index 214be493..d2014ba1 100644 --- a/crates/pgt_completions/src/providers/roles.rs +++ b/crates/pgt_completions/src/providers/roles.rs @@ -179,6 +179,8 @@ mod tests { #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] async fn works_in_grant_statements(pool: PgPool) { + pool.execute(SETUP).await.unwrap(); + assert_complete_results( format!( r#"grant select @@ -249,9 +251,41 @@ mod tests { .await; } - async fn works_in_revoke_statements() { - // revoke select on my_table from ROLE; - // revoke ROLE from OTHER_ROLE; - // revoke admin option for ROLE from OTHER_ROLE; + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn works_in_revoke_statements(pool: PgPool) { + pool.execute(SETUP).await.unwrap(); + + let queries = vec![ + format!("revoke {} from owner", CURSOR_POS), + format!("revoke admin option for {} from owner", CURSOR_POS), + format!("revoke owner from {}", CURSOR_POS), + format!("revoke all on schema public from {} granted by", CURSOR_POS), + format!("revoke all on schema public from owner, {}", CURSOR_POS), + format!("revoke all on table userse from owner, {}", CURSOR_POS), + ]; + + for query in queries { + assert_complete_results( + query.as_str(), + vec![ + // recognizing already mentioned roles is not supported for now + CompletionAssertion::LabelAndKind( + "owner".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; + } } } From 9214c917627acd50b8caf77fb443f7b0cf843089 Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 5 Jun 2025 09:23:55 +0200 Subject: [PATCH 25/27] add tests --- .../src/context/revoke_parser.rs | 165 ++++++++++++++++++ 1 file changed, 165 insertions(+) diff --git a/crates/pgt_completions/src/context/revoke_parser.rs b/crates/pgt_completions/src/context/revoke_parser.rs index 206321b0..76451277 100644 --- a/crates/pgt_completions/src/context/revoke_parser.rs +++ b/crates/pgt_completions/src/context/revoke_parser.rs @@ -174,3 +174,168 @@ impl RevokeParser { }; } } + +#[cfg(test)] +mod tests { + use pgt_text_size::{TextRange, TextSize}; + + use crate::{ + context::base_parser::CompletionStatementParser, + context::revoke_parser::{RevokeContext, RevokeParser}, + test_helper::CURSOR_POS, + }; + + fn with_pos(query: String) -> (usize, String) { + let mut pos: Option = None; + + for (p, c) in query.char_indices() { + if c == CURSOR_POS { + pos = Some(p); + break; + } + } + + ( + pos.expect("Please add cursor position!"), + query.replace(CURSOR_POS, "REPLACED_TOKEN").to_string(), + ) + } + + #[test] + fn infers_revoke_keyword() { + let (pos, query) = with_pos(format!( + r#" + revoke {} + "#, + CURSOR_POS + )); + + let context = RevokeParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + RevokeContext { + table_name: None, + schema_name: None, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(20), TextSize::new(34)), + node_kind: "revoke_role".into(), + } + ); + } + + #[test] + fn infers_table_name() { + let (pos, query) = with_pos(format!( + r#" + revoke select on {} + "#, + CURSOR_POS + )); + + let context = RevokeParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + RevokeContext { + table_name: None, + schema_name: None, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(30), TextSize::new(44)), + node_kind: "revoke_table".into(), + } + ); + } + + #[test] + fn infers_schema_and_table_name() { + let (pos, query) = with_pos(format!( + r#" + revoke select on public.{} + "#, + CURSOR_POS + )); + + let context = RevokeParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + RevokeContext { + table_name: None, + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(37), TextSize::new(51)), + node_kind: "revoke_table".into(), + } + ); + } + + #[test] + fn infers_role_name() { + let (pos, query) = with_pos(format!( + r#" + revoke select on public.users from {} + "#, + CURSOR_POS + )); + + let context = RevokeParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + RevokeContext { + table_name: Some("users".into()), + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(48), TextSize::new(62)), + node_kind: "revoke_role".into(), + } + ); + } + + #[test] + fn infers_multiple_roles() { + let (pos, query) = with_pos(format!( + r#" + revoke select on public.users from alice, {} + "#, + CURSOR_POS + )); + + let context = RevokeParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + RevokeContext { + table_name: Some("users".into()), + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(55), TextSize::new(69)), + node_kind: "revoke_role".into(), + } + ); + } + + #[test] + fn infers_quoted_schema_and_table() { + let (pos, query) = with_pos(format!( + r#" + revoke select on "MySchema"."MyTable" from {} + "#, + CURSOR_POS + )); + + let context = RevokeParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + RevokeContext { + table_name: Some("MyTable".into()), + schema_name: Some("MySchema".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(56), TextSize::new(70)), + node_kind: "revoke_role".into(), + } + ); + } +} From f2b4b4405a648c46dc896ef3744c142ddd045b11 Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 5 Jun 2025 09:36:09 +0200 Subject: [PATCH 26/27] linty --- crates/pgt_completions/src/context/grant_parser.rs | 2 +- crates/pgt_completions/src/context/revoke_parser.rs | 2 +- crates/pgt_completions/src/providers/roles.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/pgt_completions/src/context/grant_parser.rs b/crates/pgt_completions/src/context/grant_parser.rs index 8a0bcd1e..29b0d9b3 100644 --- a/crates/pgt_completions/src/context/grant_parser.rs +++ b/crates/pgt_completions/src/context/grant_parser.rs @@ -79,7 +79,7 @@ impl GrantParser { self.context.node_kind = "grant_role".into(); self.context.node_text = token.get_word(); } - "on" if !matches!(current.as_ref().map(|c| c.as_str()), Some("table")) => { + "on" if !matches!(current.as_deref(), Some("table")) => { self.handle_table(&token) } diff --git a/crates/pgt_completions/src/context/revoke_parser.rs b/crates/pgt_completions/src/context/revoke_parser.rs index 76451277..42f3b41f 100644 --- a/crates/pgt_completions/src/context/revoke_parser.rs +++ b/crates/pgt_completions/src/context/revoke_parser.rs @@ -76,7 +76,7 @@ impl RevokeParser { .to_ascii_lowercase() .as_str() { - "on" if !matches!(current.as_ref().map(|c| c.as_str()), Some("table")) => { + "on" if !matches!(current.as_deref(), Some("table")) => { self.handle_table(&token) } diff --git a/crates/pgt_completions/src/providers/roles.rs b/crates/pgt_completions/src/providers/roles.rs index d2014ba1..01641543 100644 --- a/crates/pgt_completions/src/providers/roles.rs +++ b/crates/pgt_completions/src/providers/roles.rs @@ -31,7 +31,7 @@ mod tests { use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}; - const SETUP: &'static str = r#" + const SETUP: &str = r#" create table users ( id serial primary key, email varchar, From b5e82edfb7e73150cf9dab49f07d12b6d30614d0 Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 5 Jun 2025 09:36:55 +0200 Subject: [PATCH 27/27] format --- crates/pgt_completions/src/context/grant_parser.rs | 4 +--- crates/pgt_completions/src/context/revoke_parser.rs | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/crates/pgt_completions/src/context/grant_parser.rs b/crates/pgt_completions/src/context/grant_parser.rs index 29b0d9b3..14ba882a 100644 --- a/crates/pgt_completions/src/context/grant_parser.rs +++ b/crates/pgt_completions/src/context/grant_parser.rs @@ -79,9 +79,7 @@ impl GrantParser { self.context.node_kind = "grant_role".into(); self.context.node_text = token.get_word(); } - "on" if !matches!(current.as_deref(), Some("table")) => { - self.handle_table(&token) - } + "on" if !matches!(current.as_deref(), Some("table")) => self.handle_table(&token), "table" => { self.handle_table(&token); diff --git a/crates/pgt_completions/src/context/revoke_parser.rs b/crates/pgt_completions/src/context/revoke_parser.rs index 42f3b41f..e0c43934 100644 --- a/crates/pgt_completions/src/context/revoke_parser.rs +++ b/crates/pgt_completions/src/context/revoke_parser.rs @@ -76,9 +76,7 @@ impl RevokeParser { .to_ascii_lowercase() .as_str() { - "on" if !matches!(current.as_deref(), Some("table")) => { - self.handle_table(&token) - } + "on" if !matches!(current.as_deref(), Some("table")) => self.handle_table(&token), "table" => { self.handle_table(&token);