diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index 0bb190a9..7006c5bf 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -31,6 +31,9 @@ pub enum WrappingClause<'a> { Insert, AlterTable, DropTable, + DropColumn, + AlterColumn, + RenameColumn, PolicyName, ToRoleAssignment, } @@ -424,7 +427,7 @@ impl<'a> CompletionContext<'a> { } "where" | "update" | "select" | "delete" | "from" | "join" | "column_definitions" - | "drop_table" | "alter_table" => { + | "drop_table" | "alter_table" | "drop_column" | "alter_column" | "rename_column" => { self.wrapping_clause_type = self.get_wrapping_clause_from_current_node(current_node, &mut cursor); } @@ -515,6 +518,8 @@ impl<'a> CompletionContext<'a> { (WrappingClause::From, &["from"]), (WrappingClause::Join { on_node: None }, &["join"]), (WrappingClause::AlterTable, &["alter", "table"]), + (WrappingClause::AlterColumn, &["alter", "table", "alter"]), + (WrappingClause::RenameColumn, &["alter", "table", "rename"]), ( WrappingClause::AlterTable, &["alter", "table", "if", "exists"], @@ -575,10 +580,54 @@ impl<'a> CompletionContext<'a> { let mut first_sibling = self.get_first_sibling(node); if let Some(clause) = self.wrapping_clause_type.as_ref() { - if clause == &WrappingClause::Insert { - while let Some(sib) = first_sibling.next_sibling() { - match sib.kind() { - "object_reference" => { + match *clause { + WrappingClause::Insert => { + while let Some(sib) = first_sibling.next_sibling() { + match sib.kind() { + "object_reference" => { + if let Some(NodeText::Original(txt)) = + self.get_ts_node_content(&sib) + { + let mut iter = txt.split('.').rev(); + let table = iter.next().unwrap().to_string(); + let schema = iter.next().map(|s| s.to_string()); + self.mentioned_relations + .entry(schema) + .and_modify(|s| { + s.insert(table.clone()); + }) + .or_insert(HashSet::from([table])); + } + } + + "column" => { + if let Some(NodeText::Original(txt)) = + self.get_ts_node_content(&sib) + { + let entry = MentionedColumn { + column: txt, + alias: None, + }; + + self.mentioned_columns + .entry(Some(WrappingClause::Insert)) + .and_modify(|s| { + s.insert(entry.clone()); + }) + .or_insert(HashSet::from([entry])); + } + } + + _ => {} + } + + first_sibling = sib; + } + } + + WrappingClause::AlterColumn => { + while let Some(sib) = first_sibling.next_sibling() { + if sib.kind() == "object_reference" { if let Some(NodeText::Original(txt)) = self.get_ts_node_content(&sib) { let mut iter = txt.split('.').rev(); let table = iter.next().unwrap().to_string(); @@ -591,27 +640,12 @@ impl<'a> CompletionContext<'a> { .or_insert(HashSet::from([table])); } } - "column" => { - if let Some(NodeText::Original(txt)) = self.get_ts_node_content(&sib) { - let entry = MentionedColumn { - column: txt, - alias: None, - }; - self.mentioned_columns - .entry(Some(WrappingClause::Insert)) - .and_modify(|s| { - s.insert(entry.clone()); - }) - .or_insert(HashSet::from([entry])); - } - } - - _ => {} + first_sibling = sib; } - - first_sibling = sib; } + + _ => {} } } } @@ -628,6 +662,9 @@ impl<'a> CompletionContext<'a> { "delete" => Some(WrappingClause::Delete), "from" => Some(WrappingClause::From), "drop_table" => Some(WrappingClause::DropTable), + "drop_column" => Some(WrappingClause::DropColumn), + "alter_column" => Some(WrappingClause::AlterColumn), + "rename_column" => Some(WrappingClause::RenameColumn), "alter_table" => Some(WrappingClause::AlterTable), "column_definitions" => Some(WrappingClause::ColumnDefinitions), "insert" => Some(WrappingClause::Insert), diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index b1dcbdf7..4299973b 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -762,4 +762,59 @@ mod tests { ) .await; } + + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_columns_in_alter_table_and_drop_table(pool: PgPool) { + let setup = r#" + create table instruments ( + id bigint primary key generated always as identity, + name text not null, + z text, + created_at timestamp with time zone default now() + ); + + create table others ( + a text, + b text, + c text + ); + "#; + + pool.execute(setup).await.unwrap(); + + let queries = vec![ + format!("alter table instruments drop column {}", CURSOR_POS), + format!( + "alter table instruments drop column if exists {}", + CURSOR_POS + ), + format!( + "alter table instruments alter column {} set default", + CURSOR_POS + ), + format!("alter table instruments alter {} set default", CURSOR_POS), + format!("alter table public.instruments alter column {}", CURSOR_POS), + format!("alter table instruments alter {}", CURSOR_POS), + format!("alter table instruments rename {} to new_col", CURSOR_POS), + format!( + "alter table public.instruments rename column {} to new_col", + CURSOR_POS + ), + ]; + + for query in queries { + assert_complete_results( + query.as_str(), + vec![ + CompletionAssertion::Label("created_at".into()), + CompletionAssertion::Label("id".into()), + CompletionAssertion::Label("name".into()), + CompletionAssertion::Label("z".into()), + ], + None, + &pool, + ) + .await; + } + } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 0be9e48a..ea681bd7 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -74,9 +74,13 @@ impl CompletionFilter<'_> { .map(|clause| { match self.data { CompletionRelevanceData::Table(_) => match clause { - WrappingClause::Select - | WrappingClause::Where - | WrappingClause::ColumnDefinitions => false, + WrappingClause::From | WrappingClause::Update => true, + + WrappingClause::Join { on_node: None } => true, + WrappingClause::Join { on_node: Some(on) } => ctx + .node_under_cursor + .as_ref() + .is_some_and(|cn| cn.start_byte() < on.end_byte()), WrappingClause::Insert => { ctx.wrapping_node_kind @@ -94,15 +98,22 @@ impl CompletionFilter<'_> { "keyword_table", ]), - _ => true, + _ => false, }, CompletionRelevanceData::Column(_) => { match clause { - WrappingClause::From - | WrappingClause::ColumnDefinitions - | WrappingClause::AlterTable - | WrappingClause::DropTable => false, + WrappingClause::Select + | WrappingClause::Update + | WrappingClause::Delete + | WrappingClause::DropColumn => true, + + WrappingClause::RenameColumn => ctx + .before_cursor_matches_kind(&["keyword_rename", "keyword_column"]), + + WrappingClause::AlterColumn => { + ctx.before_cursor_matches_kind(&["keyword_alter", "keyword_column"]) + } // We can complete columns in JOIN cluases, but only if we are after the // ON node in the "ON u.id = posts.user_id" part. @@ -126,7 +137,7 @@ impl CompletionFilter<'_> { && ctx.parent_matches_one_of_kind(&["field"])) } - _ => true, + _ => false, } } diff --git a/crates/pgt_treesitter_queries/src/queries/relations.rs b/crates/pgt_treesitter_queries/src/queries/relations.rs index 38fd0513..2d7e4431 100644 --- a/crates/pgt_treesitter_queries/src/queries/relations.rs +++ b/crates/pgt_treesitter_queries/src/queries/relations.rs @@ -22,6 +22,16 @@ static TS_QUERY: LazyLock = LazyLock::new(|| { (identifier)? @table )+ ) + (alter_table + (keyword_alter) + (keyword_table) + (object_reference + . + (identifier) @schema_or_table + "."? + (identifier)? @table + )+ + ) "#; tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query") }); @@ -196,4 +206,50 @@ mod tests { assert_eq!(results[0].get_schema(sql), None); assert_eq!(results[0].get_table(sql), "users"); } + + #[test] + fn finds_alter_table_with_schema() { + let sql = r#"alter table public.users alter some_col set default 15;"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&RelationMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].get_schema(sql), Some("public".into())); + assert_eq!(results[0].get_table(sql), "users"); + } + + #[test] + fn finds_alter_table_without_schema() { + let sql = r#"alter table users alter some_col set default 15;"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&RelationMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].get_schema(sql), None); + assert_eq!(results[0].get_table(sql), "users"); + } }