diff --git a/crates/pgt_completions/src/complete.rs b/crates/pgt_completions/src/complete.rs index 5bc5d41c..bd5efd19 100644 --- a/crates/pgt_completions/src/complete.rs +++ b/crates/pgt_completions/src/complete.rs @@ -5,7 +5,8 @@ use crate::{ context::CompletionContext, item::CompletionItem, providers::{ - complete_columns, complete_functions, complete_policies, complete_schemas, complete_tables, + complete_columns, complete_functions, complete_policies, complete_roles, complete_schemas, + complete_tables, }, sanitization::SanitizedCompletionParams, }; @@ -36,6 +37,7 @@ pub fn complete(params: CompletionParams) -> Vec { complete_columns(&ctx, &mut builder); complete_schemas(&ctx, &mut builder); complete_policies(&ctx, &mut builder); + complete_roles(&ctx, &mut builder); builder.finish() } diff --git a/crates/pgt_completions/src/context/base_parser.rs b/crates/pgt_completions/src/context/base_parser.rs new file mode 100644 index 00000000..93333679 --- /dev/null +++ b/crates/pgt_completions/src/context/base_parser.rs @@ -0,0 +1,206 @@ +use std::iter::Peekable; + +use pgt_text_size::{TextRange, TextSize}; + +pub(crate) struct TokenNavigator { + tokens: Peekable>, + pub previous_token: Option, + pub current_token: Option, +} + +impl TokenNavigator { + pub(crate) fn next_matches(&mut self, options: &[&str]) -> bool { + self.tokens + .peek() + .is_some_and(|c| options.contains(&c.get_word_without_quotes().as_str())) + } + + pub(crate) fn prev_matches(&self, options: &[&str]) -> bool { + self.previous_token + .as_ref() + .is_some_and(|t| options.contains(&t.get_word_without_quotes().as_str())) + } + + pub(crate) fn advance(&mut self) -> Option { + // we can't peek back n an iterator, so we'll have to keep track manually. + self.previous_token = self.current_token.take(); + self.current_token = self.tokens.next(); + self.current_token.clone() + } +} + +impl From> for TokenNavigator { + fn from(tokens: Vec) -> Self { + TokenNavigator { + tokens: tokens.into_iter().peekable(), + previous_token: None, + current_token: None, + } + } +} + +pub(crate) trait CompletionStatementParser: Sized { + type Context: Default; + const NAME: &'static str; + + fn looks_like_matching_stmt(sql: &str) -> bool; + fn parse(self) -> Self::Context; + fn make_parser(tokens: Vec, cursor_position: usize) -> Self; + + fn get_context(sql: &str, cursor_position: usize) -> Self::Context { + assert!( + Self::looks_like_matching_stmt(sql), + "Using {} for a wrong statement! Developer Error!", + Self::NAME + ); + + match sql_to_words(sql) { + Ok(tokens) => { + let parser = Self::make_parser(tokens, cursor_position); + parser.parse() + } + Err(_) => Self::Context::default(), + } + } +} + +pub(crate) fn schema_and_table_name(token: &WordWithIndex) -> (String, Option) { + let word = token.get_word_without_quotes(); + let mut parts = word.split('.'); + + ( + parts.next().unwrap().into(), + parts.next().map(|tb| tb.into()), + ) +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct WordWithIndex { + word: String, + start: usize, + end: usize, +} + +impl WordWithIndex { + pub(crate) fn is_under_cursor(&self, cursor_pos: usize) -> bool { + self.start <= cursor_pos && self.end > cursor_pos + } + + pub(crate) fn get_range(&self) -> TextRange { + let start: u32 = self.start.try_into().expect("Text too long"); + let end: u32 = self.end.try_into().expect("Text too long"); + TextRange::new(TextSize::from(start), TextSize::from(end)) + } + + pub(crate) fn get_word_without_quotes(&self) -> String { + self.word.replace('"', "") + } + + pub(crate) fn get_word(&self) -> String { + self.word.clone() + } +} + +/// Note: A policy name within quotation marks will be considered a single word. +pub(crate) fn sql_to_words(sql: &str) -> Result, String> { + let mut words = vec![]; + + let mut start_of_word: Option = None; + let mut current_word = String::new(); + let mut in_quotation_marks = false; + + for (current_position, current_char) in sql.char_indices() { + if (current_char.is_ascii_whitespace() || current_char == ';') + && !current_word.is_empty() + && start_of_word.is_some() + && !in_quotation_marks + { + words.push(WordWithIndex { + word: current_word, + start: start_of_word.unwrap(), + end: current_position, + }); + + current_word = String::new(); + start_of_word = None; + } else if (current_char.is_ascii_whitespace() || current_char == ';') + && current_word.is_empty() + { + // do nothing + } else if current_char == '"' && start_of_word.is_none() { + in_quotation_marks = true; + current_word.push(current_char); + start_of_word = Some(current_position); + } else if current_char == '"' && start_of_word.is_some() { + current_word.push(current_char); + in_quotation_marks = false; + } else if start_of_word.is_some() { + current_word.push(current_char) + } else { + start_of_word = Some(current_position); + current_word.push(current_char); + } + } + + if let Some(start_of_word) = start_of_word { + if !current_word.is_empty() { + words.push(WordWithIndex { + word: current_word, + start: start_of_word, + end: sql.len(), + }); + } + } + + if in_quotation_marks { + Err("String was not closed properly.".into()) + } else { + Ok(words) + } +} + +#[cfg(test)] +mod tests { + use crate::context::base_parser::{WordWithIndex, sql_to_words}; + + #[test] + fn determines_positions_correctly() { + let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (true);".to_string(); + + let words = sql_to_words(query.as_str()).unwrap(); + + assert_eq!(words[0], to_word("create", 1, 7)); + assert_eq!(words[1], to_word("policy", 8, 14)); + assert_eq!(words[2], to_word("\"my cool pol\"", 15, 28)); + assert_eq!(words[3], to_word("on", 30, 32)); + assert_eq!(words[4], to_word("auth.users", 33, 43)); + assert_eq!(words[5], to_word("as", 45, 47)); + assert_eq!(words[6], to_word("permissive", 48, 58)); + assert_eq!(words[7], to_word("for", 60, 63)); + assert_eq!(words[8], to_word("select", 64, 70)); + assert_eq!(words[9], to_word("to", 73, 75)); + assert_eq!(words[10], to_word("public", 78, 84)); + assert_eq!(words[11], to_word("using", 87, 92)); + assert_eq!(words[12], to_word("(true)", 93, 99)); + } + + #[test] + fn handles_schemas_in_quotation_marks() { + let query = r#"grant select on "public"."users""#.to_string(); + + let words = sql_to_words(query.as_str()).unwrap(); + + assert_eq!(words[0], to_word("grant", 0, 5)); + assert_eq!(words[1], to_word("select", 6, 12)); + assert_eq!(words[2], to_word("on", 13, 15)); + assert_eq!(words[3], to_word(r#""public"."users""#, 16, 32)); + } + + fn to_word(word: &str, start: usize, end: usize) -> WordWithIndex { + WordWithIndex { + word: word.into(), + start, + end, + } + } +} diff --git a/crates/pgt_completions/src/context/grant_parser.rs b/crates/pgt_completions/src/context/grant_parser.rs new file mode 100644 index 00000000..14ba882a --- /dev/null +++ b/crates/pgt_completions/src/context/grant_parser.rs @@ -0,0 +1,415 @@ +use pgt_text_size::{TextRange, TextSize}; + +use crate::context::base_parser::{ + CompletionStatementParser, TokenNavigator, WordWithIndex, schema_and_table_name, +}; + +#[derive(Default, Debug, PartialEq, Eq)] +pub(crate) struct GrantContext { + pub table_name: Option, + pub schema_name: Option, + pub node_text: String, + pub node_range: TextRange, + pub node_kind: String, +} + +/// Simple parser that'll turn a policy-related statement into a context object required for +/// completions. +/// The parser will only work if the (trimmed) sql starts with `create policy`, `drop policy`, or `alter policy`. +/// It can only parse policy statements. +pub(crate) struct GrantParser { + navigator: TokenNavigator, + context: GrantContext, + cursor_position: usize, + in_roles_list: bool, +} + +impl CompletionStatementParser for GrantParser { + type Context = GrantContext; + const NAME: &'static str = "GrantParser"; + + fn looks_like_matching_stmt(sql: &str) -> bool { + let lowercased = sql.to_ascii_lowercase(); + let trimmed = lowercased.trim(); + trimmed.starts_with("grant") + } + + fn parse(mut self) -> Self::Context { + while let Some(token) = self.navigator.advance() { + if token.is_under_cursor(self.cursor_position) { + self.handle_token_under_cursor(token); + } else { + self.handle_token(token); + } + } + + self.context + } + + fn make_parser(tokens: Vec, cursor_position: usize) -> Self { + Self { + navigator: tokens.into(), + context: GrantContext::default(), + cursor_position, + in_roles_list: false, + } + } +} + +impl GrantParser { + fn handle_token_under_cursor(&mut self, token: WordWithIndex) { + if self.navigator.previous_token.is_none() { + return; + } + + let previous = self.navigator.previous_token.take().unwrap(); + let current = self + .navigator + .current_token + .as_ref() + .map(|w| w.get_word_without_quotes()); + + match previous + .get_word_without_quotes() + .to_ascii_lowercase() + .as_str() + { + "grant" => { + self.context.node_range = token.get_range(); + self.context.node_kind = "grant_role".into(); + self.context.node_text = token.get_word(); + } + "on" if !matches!(current.as_deref(), Some("table")) => self.handle_table(&token), + + "table" => { + self.handle_table(&token); + } + "to" => { + self.context.node_range = token.get_range(); + self.context.node_kind = "grant_role".into(); + self.context.node_text = token.get_word(); + } + t => { + if self.in_roles_list && t.ends_with(',') { + self.context.node_kind = "grant_role".into(); + } + + self.context.node_range = token.get_range(); + self.context.node_text = token.get_word(); + } + } + } + + fn handle_table(&mut self, token: &WordWithIndex) { + if token.get_word_without_quotes().contains('.') { + let (schema_name, table_name) = schema_and_table_name(token); + + let schema_name_len = schema_name.len(); + self.context.schema_name = Some(schema_name); + + let offset: u32 = schema_name_len.try_into().expect("Text too long"); + let range_without_schema = token + .get_range() + .checked_expand_start( + TextSize::new(offset + 1), // kill the dot as well + ) + .expect("Text too long"); + + self.context.node_range = range_without_schema; + self.context.node_kind = "grant_table".into(); + + // In practice, we should always have a table name. + // The completion sanitization will add a word after a `.` if nothing follows it; + // the token_text will then look like `schema.REPLACED_TOKEN`. + self.context.node_text = table_name.unwrap_or_default(); + } else { + self.context.node_range = token.get_range(); + self.context.node_text = token.get_word(); + self.context.node_kind = "grant_table".into(); + } + } + + fn handle_token(&mut self, token: WordWithIndex) { + match token.get_word_without_quotes().as_str() { + "on" if !self.navigator.next_matches(&[ + "table", + "schema", + "foreign", + "domain", + "sequence", + "database", + "function", + "procedure", + "routine", + "language", + "large", + "parameter", + "schema", + "tablespace", + "type", + ]) => + { + self.table_with_schema() + } + "table" => self.table_with_schema(), + + "to" => { + self.in_roles_list = true; + } + + t => { + if self.in_roles_list && !t.ends_with(',') { + self.in_roles_list = false; + } + } + } + } + + fn table_with_schema(&mut self) { + if let Some(token) = self.navigator.advance() { + if token.is_under_cursor(self.cursor_position) { + self.handle_token_under_cursor(token); + } else if token.get_word_without_quotes().contains('.') { + let (schema, maybe_table) = schema_and_table_name(&token); + self.context.schema_name = Some(schema); + self.context.table_name = maybe_table; + } else { + self.context.table_name = Some(token.get_word()); + } + }; + } +} + +#[cfg(test)] +mod tests { + use pgt_text_size::{TextRange, TextSize}; + + use crate::{ + context::base_parser::CompletionStatementParser, + context::grant_parser::{GrantContext, GrantParser}, + test_helper::CURSOR_POS, + }; + + fn with_pos(query: String) -> (usize, String) { + let mut pos: Option = None; + + for (p, c) in query.char_indices() { + if c == CURSOR_POS { + pos = Some(p); + break; + } + } + + ( + pos.expect("Please add cursor position!"), + query.replace(CURSOR_POS, "REPLACED_TOKEN").to_string(), + ) + } + + #[test] + fn infers_grant_keyword() { + let (pos, query) = with_pos(format!( + r#" + grant {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: None, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(19), TextSize::new(33)), + node_kind: "grant_role".into(), + } + ); + } + + #[test] + fn infers_table_name() { + let (pos, query) = with_pos(format!( + r#" + grant select on {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: None, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(29), TextSize::new(43)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_table_name_with_keyword() { + let (pos, query) = with_pos(format!( + r#" + grant select on table {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: None, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(35), TextSize::new(49)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_schema_and_table_name() { + let (pos, query) = with_pos(format!( + r#" + grant select on public.{} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(36), TextSize::new(50)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_schema_and_table_name_with_keyword() { + let (pos, query) = with_pos(format!( + r#" + grant select on table public.{} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(42), TextSize::new(56)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_role_name() { + let (pos, query) = with_pos(format!( + r#" + grant select on public.users to {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: Some("users".into()), + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(45), TextSize::new(59)), + node_kind: "grant_role".into(), + } + ); + } + + #[test] + fn determines_table_name_after_schema() { + let (pos, query) = with_pos(format!( + r#" + grant select on public.{} to test_role + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: None, + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(36), TextSize::new(50)), + node_kind: "grant_table".into(), + } + ); + } + + #[test] + fn infers_quoted_schema_and_table() { + let (pos, query) = with_pos(format!( + r#" + grant select on "MySchema"."MyTable" to {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: Some("MyTable".into()), + schema_name: Some("MySchema".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(53), TextSize::new(67)), + node_kind: "grant_role".into(), + } + ); + } + + #[test] + fn infers_multiple_roles() { + let (pos, query) = with_pos(format!( + r#" + grant select on public.users to alice, {} + "#, + CURSOR_POS + )); + + let context = GrantParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + GrantContext { + table_name: Some("users".into()), + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(52), TextSize::new(66)), + node_kind: "grant_role".into(), + } + ); + } +} diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index 7006c5bf..996ec6be 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -2,7 +2,10 @@ use std::{ cmp, collections::{HashMap, HashSet}, }; +mod base_parser; +mod grant_parser; mod policy_parser; +mod revoke_parser; use pgt_schema_cache::SchemaCache; use pgt_text_size::TextRange; @@ -13,7 +16,12 @@ use pgt_treesitter_queries::{ use crate::{ NodeText, - context::policy_parser::{PolicyParser, PolicyStmtKind}, + context::{ + base_parser::CompletionStatementParser, + grant_parser::GrantParser, + policy_parser::{PolicyParser, PolicyStmtKind}, + revoke_parser::RevokeParser, + }, sanitization::SanitizedCompletionParams, }; @@ -36,6 +44,9 @@ pub enum WrappingClause<'a> { RenameColumn, PolicyName, ToRoleAssignment, + SetStatement, + AlterRole, + DropRole, } #[derive(PartialEq, Eq, Hash, Debug, Clone)] @@ -190,8 +201,12 @@ impl<'a> CompletionContext<'a> { // policy handling is important to Supabase, but they are a PostgreSQL specific extension, // so the tree_sitter_sql language does not support it. // We infer the context manually. - if PolicyParser::looks_like_policy_stmt(¶ms.text) { + if PolicyParser::looks_like_matching_stmt(¶ms.text) { ctx.gather_policy_context(); + } else if GrantParser::looks_like_matching_stmt(¶ms.text) { + ctx.gather_grant_context(); + } else if RevokeParser::looks_like_matching_stmt(¶ms.text) { + ctx.gather_revoke_context(); } else { ctx.gather_tree_context(); ctx.gather_info_from_ts_queries(); @@ -200,6 +215,60 @@ impl<'a> CompletionContext<'a> { ctx } + fn gather_revoke_context(&mut self) { + let revoke_context = RevokeParser::get_context(self.text, self.position); + + self.node_under_cursor = Some(NodeUnderCursor::CustomNode { + text: revoke_context.node_text.into(), + range: revoke_context.node_range, + kind: revoke_context.node_kind.clone(), + }); + + if revoke_context.node_kind == "revoke_table" { + self.schema_or_alias_name = revoke_context.schema_name.clone(); + } + + if revoke_context.table_name.is_some() { + let mut new = HashSet::new(); + new.insert(revoke_context.table_name.unwrap()); + self.mentioned_relations + .insert(revoke_context.schema_name, new); + } + + self.wrapping_clause_type = match revoke_context.node_kind.as_str() { + "revoke_role" => Some(WrappingClause::ToRoleAssignment), + "revoke_table" => Some(WrappingClause::From), + _ => None, + }; + } + + fn gather_grant_context(&mut self) { + let grant_context = GrantParser::get_context(self.text, self.position); + + self.node_under_cursor = Some(NodeUnderCursor::CustomNode { + text: grant_context.node_text.into(), + range: grant_context.node_range, + kind: grant_context.node_kind.clone(), + }); + + if grant_context.node_kind == "grant_table" { + self.schema_or_alias_name = grant_context.schema_name.clone(); + } + + if grant_context.table_name.is_some() { + let mut new = HashSet::new(); + new.insert(grant_context.table_name.unwrap()); + self.mentioned_relations + .insert(grant_context.schema_name, new); + } + + self.wrapping_clause_type = match grant_context.node_kind.as_str() { + "grant_role" => Some(WrappingClause::ToRoleAssignment), + "grant_table" => Some(WrappingClause::From), + _ => None, + }; + } + fn gather_policy_context(&mut self) { let policy_context = PolicyParser::get_context(self.text, self.position); @@ -427,7 +496,8 @@ impl<'a> CompletionContext<'a> { } "where" | "update" | "select" | "delete" | "from" | "join" | "column_definitions" - | "drop_table" | "alter_table" | "drop_column" | "alter_column" | "rename_column" => { + | "alter_role" | "drop_role" | "set_statement" | "drop_table" | "alter_table" + | "drop_column" | "alter_column" | "rename_column" => { self.wrapping_clause_type = self.get_wrapping_clause_from_current_node(current_node, &mut cursor); } @@ -662,10 +732,13 @@ impl<'a> CompletionContext<'a> { "delete" => Some(WrappingClause::Delete), "from" => Some(WrappingClause::From), "drop_table" => Some(WrappingClause::DropTable), + "alter_role" => Some(WrappingClause::AlterRole), + "drop_role" => Some(WrappingClause::DropRole), "drop_column" => Some(WrappingClause::DropColumn), "alter_column" => Some(WrappingClause::AlterColumn), "rename_column" => Some(WrappingClause::RenameColumn), "alter_table" => Some(WrappingClause::AlterTable), + "set_statement" => Some(WrappingClause::SetStatement), "column_definitions" => Some(WrappingClause::ColumnDefinitions), "insert" => Some(WrappingClause::Insert), "join" => { diff --git a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_completions/src/context/policy_parser.rs index db37a13f..58619502 100644 --- a/crates/pgt_completions/src/context/policy_parser.rs +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -1,7 +1,9 @@ -use std::iter::Peekable; - use pgt_text_size::{TextRange, TextSize}; +use crate::context::base_parser::{ + CompletionStatementParser, TokenNavigator, WordWithIndex, schema_and_table_name, +}; + #[derive(Default, Debug, PartialEq, Eq)] pub(crate) enum PolicyStmtKind { #[default] @@ -11,90 +13,6 @@ pub(crate) enum PolicyStmtKind { Drop, } -#[derive(Clone, Debug, PartialEq, Eq)] -struct WordWithIndex { - word: String, - start: usize, - end: usize, -} - -impl WordWithIndex { - fn is_under_cursor(&self, cursor_pos: usize) -> bool { - self.start <= cursor_pos && self.end > cursor_pos - } - - fn get_range(&self) -> TextRange { - let start: u32 = self.start.try_into().expect("Text too long"); - let end: u32 = self.end.try_into().expect("Text too long"); - TextRange::new(TextSize::from(start), TextSize::from(end)) - } -} - -/// Note: A policy name within quotation marks will be considered a single word. -fn sql_to_words(sql: &str) -> Result, String> { - let mut words = vec![]; - - let mut start_of_word: Option = None; - let mut current_word = String::new(); - let mut in_quotation_marks = false; - - for (current_position, current_char) in sql.char_indices() { - if (current_char.is_ascii_whitespace() || current_char == ';') - && !current_word.is_empty() - && start_of_word.is_some() - && !in_quotation_marks - { - words.push(WordWithIndex { - word: current_word, - start: start_of_word.unwrap(), - end: current_position, - }); - - current_word = String::new(); - start_of_word = None; - } else if (current_char.is_ascii_whitespace() || current_char == ';') - && current_word.is_empty() - { - // do nothing - } else if current_char == '"' && start_of_word.is_none() { - in_quotation_marks = true; - current_word.push(current_char); - start_of_word = Some(current_position); - } else if current_char == '"' && start_of_word.is_some() { - current_word.push(current_char); - words.push(WordWithIndex { - word: current_word, - start: start_of_word.unwrap(), - end: current_position + 1, - }); - in_quotation_marks = false; - start_of_word = None; - current_word = String::new() - } else if start_of_word.is_some() { - current_word.push(current_char) - } else { - start_of_word = Some(current_position); - current_word.push(current_char); - } - } - - if let Some(start_of_word) = start_of_word { - if !current_word.is_empty() { - words.push(WordWithIndex { - word: current_word, - start: start_of_word, - end: sql.len(), - }); - } - } - - if in_quotation_marks { - Err("String was not closed properly.".into()) - } else { - Ok(words) - } -} - #[derive(Default, Debug, PartialEq, Eq)] pub(crate) struct PolicyContext { pub policy_name: Option, @@ -111,15 +29,16 @@ pub(crate) struct PolicyContext { /// The parser will only work if the (trimmed) sql starts with `create policy`, `drop policy`, or `alter policy`. /// It can only parse policy statements. pub(crate) struct PolicyParser { - tokens: Peekable>, - previous_token: Option, - current_token: Option, + navigator: TokenNavigator, context: PolicyContext, cursor_position: usize, } -impl PolicyParser { - pub(crate) fn looks_like_policy_stmt(sql: &str) -> bool { +impl CompletionStatementParser for PolicyParser { + type Context = PolicyContext; + const NAME: &'static str = "PolicyParser"; + + fn looks_like_matching_stmt(sql: &str) -> bool { let lowercased = sql.to_ascii_lowercase(); let trimmed = lowercased.trim(); trimmed.starts_with("create policy") @@ -127,30 +46,8 @@ impl PolicyParser { || trimmed.starts_with("alter policy") } - pub(crate) fn get_context(sql: &str, cursor_position: usize) -> PolicyContext { - assert!( - Self::looks_like_policy_stmt(sql), - "PolicyParser should only be used for policy statements. Developer error!" - ); - - match sql_to_words(sql) { - Ok(tokens) => { - let parser = PolicyParser { - tokens: tokens.into_iter().peekable(), - context: PolicyContext::default(), - previous_token: None, - current_token: None, - cursor_position, - }; - - parser.parse() - } - Err(_) => PolicyContext::default(), - } - } - - fn parse(mut self) -> PolicyContext { - while let Some(token) = self.advance() { + fn parse(mut self) -> Self::Context { + while let Some(token) = self.navigator.advance() { if token.is_under_cursor(self.cursor_position) { self.handle_token_under_cursor(token); } else { @@ -161,22 +58,36 @@ impl PolicyParser { self.context } + fn make_parser(tokens: Vec, cursor_position: usize) -> Self { + Self { + navigator: tokens.into(), + context: PolicyContext::default(), + cursor_position, + } + } +} + +impl PolicyParser { fn handle_token_under_cursor(&mut self, token: WordWithIndex) { - if self.previous_token.is_none() { + if self.navigator.previous_token.is_none() { return; } - let previous = self.previous_token.take().unwrap(); + let previous = self.navigator.previous_token.take().unwrap(); - match previous.word.to_ascii_lowercase().as_str() { + match previous + .get_word_without_quotes() + .to_ascii_lowercase() + .as_str() + { "policy" => { self.context.node_range = token.get_range(); self.context.node_kind = "policy_name".into(); - self.context.node_text = token.word; + self.context.node_text = token.get_word(); } "on" => { - if token.word.contains('.') { - let (schema_name, table_name) = self.schema_and_table_name(&token); + if token.get_word_without_quotes().contains('.') { + let (schema_name, table_name) = schema_and_table_name(&token); let schema_name_len = schema_name.len(); self.context.schema_name = Some(schema_name); @@ -198,85 +109,65 @@ impl PolicyParser { self.context.node_text = table_name.unwrap_or_default(); } else { self.context.node_range = token.get_range(); - self.context.node_text = token.word; + self.context.node_text = token.get_word(); self.context.node_kind = "policy_table".into(); } } "to" => { self.context.node_range = token.get_range(); self.context.node_kind = "policy_role".into(); - self.context.node_text = token.word; + self.context.node_text = token.get_word(); } _ => { self.context.node_range = token.get_range(); - self.context.node_text = token.word; + self.context.node_text = token.get_word(); } } } fn handle_token(&mut self, token: WordWithIndex) { - match token.word.to_ascii_lowercase().as_str() { - "create" if self.next_matches("policy") => { + match token + .get_word_without_quotes() + .to_ascii_lowercase() + .as_str() + { + "create" if self.navigator.next_matches(&["policy"]) => { self.context.statement_kind = PolicyStmtKind::Create; } - "alter" if self.next_matches("policy") => { + "alter" if self.navigator.next_matches(&["policy"]) => { self.context.statement_kind = PolicyStmtKind::Alter; } - "drop" if self.next_matches("policy") => { + "drop" if self.navigator.next_matches(&["policy"]) => { self.context.statement_kind = PolicyStmtKind::Drop; } "on" => self.table_with_schema(), // skip the "to" so we don't parse it as the TO rolename when it's under the cursor - "rename" if self.next_matches("to") => { - self.advance(); + "rename" if self.navigator.next_matches(&["to"]) => { + self.navigator.advance(); } _ => { - if self.prev_matches("policy") { - self.context.policy_name = Some(token.word); + if self.navigator.prev_matches(&["policy"]) { + self.context.policy_name = Some(token.get_word()); } } } } - fn next_matches(&mut self, it: &str) -> bool { - self.tokens.peek().is_some_and(|c| c.word.as_str() == it) - } - - fn prev_matches(&self, it: &str) -> bool { - self.previous_token.as_ref().is_some_and(|t| t.word == it) - } - - fn advance(&mut self) -> Option { - // we can't peek back n an iterator, so we'll have to keep track manually. - self.previous_token = self.current_token.take(); - self.current_token = self.tokens.next(); - self.current_token.clone() - } - fn table_with_schema(&mut self) { - if let Some(token) = self.advance() { + if let Some(token) = self.navigator.advance() { if token.is_under_cursor(self.cursor_position) { self.handle_token_under_cursor(token); - } else if token.word.contains('.') { - let (schema, maybe_table) = self.schema_and_table_name(&token); + } else if token.get_word_without_quotes().contains('.') { + let (schema, maybe_table) = schema_and_table_name(&token); self.context.schema_name = Some(schema); self.context.table_name = maybe_table; } else { - self.context.table_name = Some(token.word); + self.context.table_name = Some(token.get_word()); } }; } - - fn schema_and_table_name(&self, token: &WordWithIndex) -> (String, Option) { - let mut parts = token.word.split('.'); - - ( - parts.next().unwrap().into(), - parts.next().map(|tb| tb.into()), - ) - } } #[cfg(test)] @@ -284,11 +175,12 @@ mod tests { use pgt_text_size::{TextRange, TextSize}; use crate::{ - context::policy_parser::{PolicyContext, PolicyStmtKind, WordWithIndex}, + context::base_parser::CompletionStatementParser, + context::policy_parser::{PolicyContext, PolicyStmtKind}, test_helper::CURSOR_POS, }; - use super::{PolicyParser, sql_to_words}; + use super::PolicyParser; fn with_pos(query: String) -> (usize, String) { let mut pos: Option = None; @@ -585,33 +477,4 @@ mod tests { assert_eq!(context, PolicyContext::default()); } - - fn to_word(word: &str, start: usize, end: usize) -> WordWithIndex { - WordWithIndex { - word: word.into(), - start, - end, - } - } - - #[test] - fn determines_positions_correctly() { - let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (true);".to_string(); - - let words = sql_to_words(query.as_str()).unwrap(); - - assert_eq!(words[0], to_word("create", 1, 7)); - assert_eq!(words[1], to_word("policy", 8, 14)); - assert_eq!(words[2], to_word("\"my cool pol\"", 15, 28)); - assert_eq!(words[3], to_word("on", 30, 32)); - assert_eq!(words[4], to_word("auth.users", 33, 43)); - assert_eq!(words[5], to_word("as", 45, 47)); - assert_eq!(words[6], to_word("permissive", 48, 58)); - assert_eq!(words[7], to_word("for", 60, 63)); - assert_eq!(words[8], to_word("select", 64, 70)); - assert_eq!(words[9], to_word("to", 73, 75)); - assert_eq!(words[10], to_word("public", 78, 84)); - assert_eq!(words[11], to_word("using", 87, 92)); - assert_eq!(words[12], to_word("(true)", 93, 99)); - } } diff --git a/crates/pgt_completions/src/context/revoke_parser.rs b/crates/pgt_completions/src/context/revoke_parser.rs new file mode 100644 index 00000000..e0c43934 --- /dev/null +++ b/crates/pgt_completions/src/context/revoke_parser.rs @@ -0,0 +1,339 @@ +use pgt_text_size::{TextRange, TextSize}; + +use crate::context::base_parser::{ + CompletionStatementParser, TokenNavigator, WordWithIndex, schema_and_table_name, +}; + +#[derive(Default, Debug, PartialEq, Eq)] +pub(crate) struct RevokeContext { + pub table_name: Option, + pub schema_name: Option, + pub node_text: String, + pub node_range: TextRange, + pub node_kind: String, +} + +/// Simple parser that'll turn a policy-related statement into a context object required for +/// completions. +/// The parser will only work if the (trimmed) sql starts with `create policy`, `drop policy`, or `alter policy`. +/// It can only parse policy statements. +pub(crate) struct RevokeParser { + navigator: TokenNavigator, + context: RevokeContext, + cursor_position: usize, + in_roles_list: bool, + is_revoking_role: bool, +} + +impl CompletionStatementParser for RevokeParser { + type Context = RevokeContext; + const NAME: &'static str = "RevokeParser"; + + fn looks_like_matching_stmt(sql: &str) -> bool { + let lowercased = sql.to_ascii_lowercase(); + let trimmed = lowercased.trim(); + trimmed.starts_with("revoke") + } + + fn parse(mut self) -> Self::Context { + while let Some(token) = self.navigator.advance() { + if token.is_under_cursor(self.cursor_position) { + self.handle_token_under_cursor(token); + } else { + self.handle_token(token); + } + } + + self.context + } + + fn make_parser(tokens: Vec, cursor_position: usize) -> Self { + Self { + navigator: tokens.into(), + context: RevokeContext::default(), + cursor_position, + in_roles_list: false, + is_revoking_role: false, + } + } +} + +impl RevokeParser { + fn handle_token_under_cursor(&mut self, token: WordWithIndex) { + if self.navigator.previous_token.is_none() { + return; + } + + let previous = self.navigator.previous_token.take().unwrap(); + let current = self + .navigator + .current_token + .as_ref() + .map(|w| w.get_word_without_quotes()); + + match previous + .get_word_without_quotes() + .to_ascii_lowercase() + .as_str() + { + "on" if !matches!(current.as_deref(), Some("table")) => self.handle_table(&token), + + "table" => { + self.handle_table(&token); + } + + "from" | "revoke" => { + self.context.node_range = token.get_range(); + self.context.node_kind = "revoke_role".into(); + self.context.node_text = token.get_word(); + } + + "for" if self.is_revoking_role => { + self.context.node_range = token.get_range(); + self.context.node_kind = "revoke_role".into(); + self.context.node_text = token.get_word(); + } + + t => { + if self.in_roles_list && t.ends_with(',') { + self.context.node_kind = "revoke_role".into(); + } + + self.context.node_range = token.get_range(); + self.context.node_text = token.get_word(); + } + } + } + + fn handle_table(&mut self, token: &WordWithIndex) { + if token.get_word_without_quotes().contains('.') { + let (schema_name, table_name) = schema_and_table_name(token); + + let schema_name_len = schema_name.len(); + self.context.schema_name = Some(schema_name); + + let offset: u32 = schema_name_len.try_into().expect("Text too long"); + let range_without_schema = token + .get_range() + .checked_expand_start( + TextSize::new(offset + 1), // kill the dot as well + ) + .expect("Text too long"); + + self.context.node_range = range_without_schema; + self.context.node_kind = "revoke_table".into(); + + // In practice, we should always have a table name. + // The completion sanitization will add a word after a `.` if nothing follows it; + // the token_text will then look like `schema.REPLACED_TOKEN`. + self.context.node_text = table_name.unwrap_or_default(); + } else { + self.context.node_range = token.get_range(); + self.context.node_text = token.get_word(); + self.context.node_kind = "revoke_table".into(); + } + } + + fn handle_token(&mut self, token: WordWithIndex) { + match token.get_word_without_quotes().as_str() { + "on" if !self.navigator.next_matches(&["table"]) => self.table_with_schema(), + + // This is the only case where there is no "GRANT" before the option: + // REVOKE [ { ADMIN | INHERIT | SET } OPTION FOR ] role_name + "option" if !self.navigator.prev_matches(&["grant"]) => { + self.is_revoking_role = true; + } + + "table" => self.table_with_schema(), + + "from" => { + self.in_roles_list = true; + } + + t => { + if self.in_roles_list && !t.ends_with(',') { + self.in_roles_list = false; + } + } + } + } + + fn table_with_schema(&mut self) { + if let Some(token) = self.navigator.advance() { + if token.is_under_cursor(self.cursor_position) { + self.handle_token_under_cursor(token); + } else if token.get_word_without_quotes().contains('.') { + let (schema, maybe_table) = schema_and_table_name(&token); + self.context.schema_name = Some(schema); + self.context.table_name = maybe_table; + } else { + self.context.table_name = Some(token.get_word()); + } + }; + } +} + +#[cfg(test)] +mod tests { + use pgt_text_size::{TextRange, TextSize}; + + use crate::{ + context::base_parser::CompletionStatementParser, + context::revoke_parser::{RevokeContext, RevokeParser}, + test_helper::CURSOR_POS, + }; + + fn with_pos(query: String) -> (usize, String) { + let mut pos: Option = None; + + for (p, c) in query.char_indices() { + if c == CURSOR_POS { + pos = Some(p); + break; + } + } + + ( + pos.expect("Please add cursor position!"), + query.replace(CURSOR_POS, "REPLACED_TOKEN").to_string(), + ) + } + + #[test] + fn infers_revoke_keyword() { + let (pos, query) = with_pos(format!( + r#" + revoke {} + "#, + CURSOR_POS + )); + + let context = RevokeParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + RevokeContext { + table_name: None, + schema_name: None, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(20), TextSize::new(34)), + node_kind: "revoke_role".into(), + } + ); + } + + #[test] + fn infers_table_name() { + let (pos, query) = with_pos(format!( + r#" + revoke select on {} + "#, + CURSOR_POS + )); + + let context = RevokeParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + RevokeContext { + table_name: None, + schema_name: None, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(30), TextSize::new(44)), + node_kind: "revoke_table".into(), + } + ); + } + + #[test] + fn infers_schema_and_table_name() { + let (pos, query) = with_pos(format!( + r#" + revoke select on public.{} + "#, + CURSOR_POS + )); + + let context = RevokeParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + RevokeContext { + table_name: None, + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(37), TextSize::new(51)), + node_kind: "revoke_table".into(), + } + ); + } + + #[test] + fn infers_role_name() { + let (pos, query) = with_pos(format!( + r#" + revoke select on public.users from {} + "#, + CURSOR_POS + )); + + let context = RevokeParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + RevokeContext { + table_name: Some("users".into()), + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(48), TextSize::new(62)), + node_kind: "revoke_role".into(), + } + ); + } + + #[test] + fn infers_multiple_roles() { + let (pos, query) = with_pos(format!( + r#" + revoke select on public.users from alice, {} + "#, + CURSOR_POS + )); + + let context = RevokeParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + RevokeContext { + table_name: Some("users".into()), + schema_name: Some("public".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(55), TextSize::new(69)), + node_kind: "revoke_role".into(), + } + ); + } + + #[test] + fn infers_quoted_schema_and_table() { + let (pos, query) = with_pos(format!( + r#" + revoke select on "MySchema"."MyTable" from {} + "#, + CURSOR_POS + )); + + let context = RevokeParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + RevokeContext { + table_name: Some("MyTable".into()), + schema_name: Some("MySchema".into()), + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(56), TextSize::new(70)), + node_kind: "revoke_role".into(), + } + ); + } +} diff --git a/crates/pgt_completions/src/item.rs b/crates/pgt_completions/src/item.rs index 73e08cc0..766e436c 100644 --- a/crates/pgt_completions/src/item.rs +++ b/crates/pgt_completions/src/item.rs @@ -12,6 +12,7 @@ pub enum CompletionItemKind { Column, Schema, Policy, + Role, } impl Display for CompletionItemKind { @@ -22,6 +23,7 @@ impl Display for CompletionItemKind { CompletionItemKind::Column => "Column", CompletionItemKind::Schema => "Schema", CompletionItemKind::Policy => "Policy", + CompletionItemKind::Role => "Role", }; write!(f, "{txt}") diff --git a/crates/pgt_completions/src/providers/mod.rs b/crates/pgt_completions/src/providers/mod.rs index 7b07cee8..ddbdf252 100644 --- a/crates/pgt_completions/src/providers/mod.rs +++ b/crates/pgt_completions/src/providers/mod.rs @@ -2,11 +2,13 @@ mod columns; mod functions; mod helper; mod policies; +mod roles; mod schemas; mod tables; pub use columns::*; pub use functions::*; pub use policies::*; +pub use roles::*; pub use schemas::*; pub use tables::*; diff --git a/crates/pgt_completions/src/providers/roles.rs b/crates/pgt_completions/src/providers/roles.rs new file mode 100644 index 00000000..01641543 --- /dev/null +++ b/crates/pgt_completions/src/providers/roles.rs @@ -0,0 +1,291 @@ +use crate::{ + CompletionItemKind, + builder::{CompletionBuilder, PossibleCompletionItem}, + context::CompletionContext, + relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, +}; + +pub fn complete_roles<'a>(ctx: &CompletionContext<'a>, builder: &mut CompletionBuilder<'a>) { + let available_roles = &ctx.schema_cache.roles; + + for role in available_roles { + let relevance = CompletionRelevanceData::Role(role); + + let item = PossibleCompletionItem { + label: role.name.chars().take(35).collect::(), + score: CompletionScore::from(relevance.clone()), + filter: CompletionFilter::from(relevance), + description: role.name.clone(), + kind: CompletionItemKind::Role, + completion_text: None, + detail: None, + }; + + builder.add_item(item); + } +} + +#[cfg(test)] +mod tests { + use sqlx::{Executor, PgPool}; + + use crate::test_helper::{CURSOR_POS, CompletionAssertion, assert_complete_results}; + + const SETUP: &str = r#" + create table users ( + id serial primary key, + email varchar, + address text + ); + "#; + + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn works_in_drop_role(pool: PgPool) { + assert_complete_results( + format!("drop role {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + Some(SETUP), + &pool, + ) + .await; + } + + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn works_in_alter_role(pool: PgPool) { + assert_complete_results( + format!("alter role {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + Some(SETUP), + &pool, + ) + .await; + } + + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn works_in_set_statement(pool: PgPool) { + pool.execute(SETUP).await.unwrap(); + + assert_complete_results( + format!("set role {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; + + assert_complete_results( + format!("set session authorization {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; + } + + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn works_in_policies(pool: PgPool) { + pool.execute(SETUP).await.unwrap(); + + assert_complete_results( + format!( + r#"create policy "my cool policy" on public.users + as restrictive + for all + to {} + using (true);"#, + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; + + assert_complete_results( + format!( + r#"create policy "my cool policy" on public.users + for select + to {}"#, + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; + } + + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn works_in_grant_statements(pool: PgPool) { + pool.execute(SETUP).await.unwrap(); + + assert_complete_results( + format!( + r#"grant select + on table public.users + to {}"#, + CURSOR_POS + ) + .as_str(), + vec![ + // recognizing already mentioned roles is not supported for now + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; + + assert_complete_results( + format!( + r#"grant select + on table public.users + to owner, {}"#, + CURSOR_POS + ) + .as_str(), + vec![ + // recognizing already mentioned roles is not supported for now + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; + + assert_complete_results( + format!(r#"grant {} to owner"#, CURSOR_POS).as_str(), + vec![ + // recognizing already mentioned roles is not supported for now + CompletionAssertion::LabelAndKind("owner".into(), crate::CompletionItemKind::Role), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; + } + + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn works_in_revoke_statements(pool: PgPool) { + pool.execute(SETUP).await.unwrap(); + + let queries = vec![ + format!("revoke {} from owner", CURSOR_POS), + format!("revoke admin option for {} from owner", CURSOR_POS), + format!("revoke owner from {}", CURSOR_POS), + format!("revoke all on schema public from {} granted by", CURSOR_POS), + format!("revoke all on schema public from owner, {}", CURSOR_POS), + format!("revoke all on table userse from owner, {}", CURSOR_POS), + ]; + + for query in queries { + assert_complete_results( + query.as_str(), + vec![ + // recognizing already mentioned roles is not supported for now + CompletionAssertion::LabelAndKind( + "owner".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_login".into(), + crate::CompletionItemKind::Role, + ), + CompletionAssertion::LabelAndKind( + "test_nologin".into(), + crate::CompletionItemKind::Role, + ), + ], + None, + &pool, + ) + .await; + } + } +} diff --git a/crates/pgt_completions/src/relevance.rs b/crates/pgt_completions/src/relevance.rs index f51c3c52..1d39d9bb 100644 --- a/crates/pgt_completions/src/relevance.rs +++ b/crates/pgt_completions/src/relevance.rs @@ -8,4 +8,5 @@ pub(crate) enum CompletionRelevanceData<'a> { Column(&'a pgt_schema_cache::Column), Schema(&'a pgt_schema_cache::Schema), Policy(&'a pgt_schema_cache::Policy), + Role(&'a pgt_schema_cache::Role), } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index ea681bd7..a020d2e8 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -180,6 +180,17 @@ impl CompletionFilter<'_> { CompletionRelevanceData::Policy(_) => { matches!(clause, WrappingClause::PolicyName) } + + CompletionRelevanceData::Role(_) => match clause { + WrappingClause::DropRole + | WrappingClause::AlterRole + | WrappingClause::ToRoleAssignment => true, + + WrappingClause::SetStatement => ctx + .before_cursor_matches_kind(&["keyword_role", "keyword_authorization"]), + + _ => false, + }, } }) .and_then(|is_ok| if is_ok { Some(()) } else { None }) @@ -215,8 +226,8 @@ impl CompletionFilter<'_> { // we should never allow schema suggestions if there already was one. CompletionRelevanceData::Schema(_) => false, - // no policy comletion if user typed a schema node first. - CompletionRelevanceData::Policy(_) => false, + // no policy or row completion if user typed a schema node first. + CompletionRelevanceData::Policy(_) | CompletionRelevanceData::Role(_) => false, }; if !matches { diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index a8c89f50..a0b5efa5 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -47,6 +47,7 @@ impl CompletionScore<'_> { CompletionRelevanceData::Column(c) => c.name.as_str().to_ascii_lowercase(), CompletionRelevanceData::Schema(s) => s.name.as_str().to_ascii_lowercase(), CompletionRelevanceData::Policy(p) => p.name.as_str().to_ascii_lowercase(), + CompletionRelevanceData::Role(r) => r.name.as_str().to_ascii_lowercase(), }; let fz_matcher = SkimMatcherV2::default(); @@ -126,6 +127,11 @@ impl CompletionScore<'_> { WrappingClause::PolicyName => 25, _ => -50, }, + + CompletionRelevanceData::Role(_) => match clause_type { + WrappingClause::DropRole | WrappingClause::AlterRole => 25, + _ => -50, + }, } } @@ -160,6 +166,7 @@ impl CompletionScore<'_> { _ => -50, }, CompletionRelevanceData::Policy(_) => 0, + CompletionRelevanceData::Role(_) => 0, } } @@ -178,7 +185,10 @@ impl CompletionScore<'_> { Some(n) => n, }; - let data_schema = self.get_schema_name(); + let data_schema = match self.get_schema_name() { + Some(s) => s, + None => return, + }; if schema_name == data_schema { self.score += 25; @@ -194,16 +204,18 @@ impl CompletionScore<'_> { CompletionRelevanceData::Column(c) => c.name.as_str(), CompletionRelevanceData::Schema(s) => s.name.as_str(), CompletionRelevanceData::Policy(p) => p.name.as_str(), + CompletionRelevanceData::Role(r) => r.name.as_str(), } } - fn get_schema_name(&self) -> &str { + fn get_schema_name(&self) -> Option<&str> { match self.data { - CompletionRelevanceData::Function(f) => f.schema.as_str(), - CompletionRelevanceData::Table(t) => t.schema.as_str(), - CompletionRelevanceData::Column(c) => c.schema_name.as_str(), - CompletionRelevanceData::Schema(s) => s.name.as_str(), - CompletionRelevanceData::Policy(p) => p.schema_name.as_str(), + CompletionRelevanceData::Function(f) => Some(f.schema.as_str()), + CompletionRelevanceData::Table(t) => Some(t.schema.as_str()), + CompletionRelevanceData::Column(c) => Some(c.schema_name.as_str()), + CompletionRelevanceData::Schema(s) => Some(s.name.as_str()), + CompletionRelevanceData::Policy(p) => Some(p.schema_name.as_str()), + CompletionRelevanceData::Role(_) => None, } } @@ -222,7 +234,10 @@ impl CompletionScore<'_> { _ => {} } - let schema = self.get_schema_name().to_string(); + let schema = match self.get_schema_name() { + Some(s) => s.to_string(), + None => return, + }; let table_name = match self.get_table_name() { Some(t) => t, None => return, @@ -244,7 +259,34 @@ impl CompletionScore<'_> { } fn check_is_user_defined(&mut self) { - let schema_name = self.get_schema_name().to_string(); + if let CompletionRelevanceData::Role(r) = self.data { + match r.name.as_str() { + "pg_read_all_data" + | "pg_write_all_data" + | "pg_read_all_settings" + | "pg_read_all_stats" + | "pg_stat_scan_tables" + | "pg_monitor" + | "pg_database_owner" + | "pg_signal_backend" + | "pg_read_server_files" + | "pg_write_server_files" + | "pg_execute_server_program" + | "pg_checkpoint" + | "pg_maintain" + | "pg_use_reserved_connections" + | "pg_create_subscription" + | "postgres" => self.score -= 20, + _ => {} + }; + + return; + } + + let schema_name = match self.get_schema_name() { + Some(s) => s.to_string(), + None => return, + }; let system_schemas = ["pg_catalog", "information_schema", "pg_toast"]; diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 154998e7..ddc9563e 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -6,6 +6,7 @@ use crate::CompletionParams; static SANITIZED_TOKEN: &str = "REPLACED_TOKEN"; +#[derive(Debug)] pub(crate) struct SanitizedCompletionParams<'a> { pub position: TextSize, pub text: String, diff --git a/crates/pgt_lsp/src/handlers/completions.rs b/crates/pgt_lsp/src/handlers/completions.rs index 7e901c79..4a035fcf 100644 --- a/crates/pgt_lsp/src/handlers/completions.rs +++ b/crates/pgt_lsp/src/handlers/completions.rs @@ -76,5 +76,6 @@ fn to_lsp_types_completion_item_kind( pgt_completions::CompletionItemKind::Column => lsp_types::CompletionItemKind::FIELD, pgt_completions::CompletionItemKind::Schema => lsp_types::CompletionItemKind::CLASS, pgt_completions::CompletionItemKind::Policy => lsp_types::CompletionItemKind::CONSTANT, + pgt_completions::CompletionItemKind::Role => lsp_types::CompletionItemKind::CONSTANT, } } diff --git a/crates/pgt_lsp/src/session.rs b/crates/pgt_lsp/src/session.rs index fd5af2da..ede0469f 100644 --- a/crates/pgt_lsp/src/session.rs +++ b/crates/pgt_lsp/src/session.rs @@ -32,9 +32,11 @@ use tower_lsp::lsp_types::{Unregistration, WorkspaceFolder}; use tracing::{error, info}; pub(crate) struct ClientInformation { + #[allow(dead_code)] /// The name of the client pub(crate) name: String, + #[allow(dead_code)] /// The version of the client pub(crate) version: Option, } @@ -76,6 +78,7 @@ pub(crate) struct Session { struct InitializeParams { /// The capabilities provided by the client as part of [`lsp_types::InitializeParams`] client_capabilities: lsp_types::ClientCapabilities, + #[allow(dead_code)] client_information: Option, root_uri: Option, #[allow(unused)] diff --git a/crates/pgt_workspace/src/settings.rs b/crates/pgt_workspace/src/settings.rs index 08854493..ac55d8a1 100644 --- a/crates/pgt_workspace/src/settings.rs +++ b/crates/pgt_workspace/src/settings.rs @@ -214,45 +214,6 @@ pub struct Settings { pub migrations: Option, } -#[derive(Debug)] -pub struct SettingsHandleMut<'a> { - inner: RwLockWriteGuard<'a, Settings>, -} - -/// Handle object holding a temporary lock on the settings -#[derive(Debug)] -pub struct SettingsHandle<'a> { - inner: RwLockReadGuard<'a, Settings>, -} - -impl<'a> SettingsHandle<'a> { - pub(crate) fn new(settings: &'a RwLock) -> Self { - Self { - inner: settings.read().unwrap(), - } - } -} - -impl AsRef for SettingsHandle<'_> { - fn as_ref(&self) -> &Settings { - &self.inner - } -} - -impl<'a> SettingsHandleMut<'a> { - pub(crate) fn new(settings: &'a RwLock) -> Self { - Self { - inner: settings.write().unwrap(), - } - } -} - -impl AsMut for SettingsHandleMut<'_> { - fn as_mut(&mut self) -> &mut Settings { - &mut self.inner - } -} - impl Settings { /// The [PartialConfiguration] is merged into the workspace #[tracing::instrument(level = "trace", skip(self), err)]