diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index db21e498..6ace55b6 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -13,6 +13,7 @@ pub enum ClauseType { Select, Where, From, + Join, Update, Delete, } @@ -33,6 +34,7 @@ impl TryFrom<&str> for ClauseType { "from" => Ok(Self::From), "update" => Ok(Self::Update), "delete" => Ok(Self::Delete), + "join" => Ok(Self::Join), _ => { let message = format!("Unimplemented ClauseType: {}", value); @@ -106,7 +108,25 @@ pub(crate) struct CompletionContext<'a> { pub schema_cache: &'a SchemaCache, pub position: usize, - pub schema_name: Option, + /// If the cursor is on a node that uses dot notation + /// to specify an alias or schema, this will hold the schema's or + /// alias's name. + /// + /// Here, `auth` is a schema name: + /// ```sql + /// select * from auth.users; + /// ``` + /// + /// Here, `u` is an alias name: + /// ```sql + /// select + /// * + /// from + /// auth.users u + /// left join identities i + /// on u.id = i.user_id; + /// ``` + pub schema_or_alias_name: Option, pub wrapping_clause_type: Option, pub wrapping_node_kind: Option, @@ -114,6 +134,9 @@ pub(crate) struct CompletionContext<'a> { pub is_invocation: bool, pub wrapping_statement_range: Option, + /// Some incomplete statements can't be correctly parsed by TreeSitter. + pub is_in_error_node: bool, + pub mentioned_relations: HashMap, HashSet>, pub mentioned_table_aliases: HashMap, @@ -127,13 +150,14 @@ impl<'a> CompletionContext<'a> { schema_cache: params.schema, position: usize::from(params.position), node_under_cursor: None, - schema_name: None, + schema_or_alias_name: None, wrapping_clause_type: None, wrapping_node_kind: None, wrapping_statement_range: None, is_invocation: false, mentioned_relations: HashMap::new(), mentioned_table_aliases: HashMap::new(), + is_in_error_node: false, }; ctx.gather_tree_context(); @@ -246,19 +270,58 @@ impl<'a> CompletionContext<'a> { self.wrapping_statement_range = Some(parent_node.range()); } "invocation" => self.is_invocation = true, - _ => {} } + // try to gather context from the siblings if we're within an error node. + if self.is_in_error_node { + let mut next_sibling = current_node.next_named_sibling(); + while let Some(n) = next_sibling { + if n.kind().starts_with("keyword_") { + if let Some(txt) = self.get_ts_node_content(n).and_then(|txt| match txt { + NodeText::Original(txt) => Some(txt), + NodeText::Replaced => None, + }) { + match txt { + "where" | "update" | "select" | "delete" | "from" | "join" => { + self.wrapping_clause_type = txt.try_into().ok(); + break; + } + _ => {} + } + }; + } + next_sibling = n.next_named_sibling(); + } + let mut prev_sibling = current_node.prev_named_sibling(); + while let Some(n) = prev_sibling { + if n.kind().starts_with("keyword_") { + if let Some(txt) = self.get_ts_node_content(n).and_then(|txt| match txt { + NodeText::Original(txt) => Some(txt), + NodeText::Replaced => None, + }) { + match txt { + "where" | "update" | "select" | "delete" | "from" | "join" => { + self.wrapping_clause_type = txt.try_into().ok(); + break; + } + _ => {} + } + }; + } + prev_sibling = n.prev_named_sibling(); + } + } + match current_node_kind { - "object_reference" => { + "object_reference" | "field" => { let content = self.get_ts_node_content(current_node); if let Some(node_txt) = content { match node_txt { NodeText::Original(txt) => { let parts: Vec<&str> = txt.split('.').collect(); if parts.len() == 2 { - self.schema_name = Some(parts[0].to_string()); + self.schema_or_alias_name = Some(parts[0].to_string()); } } NodeText::Replaced => {} @@ -266,7 +329,7 @@ impl<'a> CompletionContext<'a> { } } - "where" | "update" | "select" | "delete" | "from" => { + "where" | "update" | "select" | "delete" | "from" | "join" => { self.wrapping_clause_type = current_node_kind.try_into().ok(); } @@ -274,6 +337,10 @@ impl<'a> CompletionContext<'a> { self.wrapping_node_kind = current_node_kind.try_into().ok(); } + "ERROR" => { + self.is_in_error_node = true; + } + _ => {} } @@ -380,7 +447,10 @@ mod tests { let ctx = CompletionContext::new(¶ms); - assert_eq!(ctx.schema_name, expected_schema.map(|f| f.to_string())); + assert_eq!( + ctx.schema_or_alias_name, + expected_schema.map(|f| f.to_string()) + ); } } diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index 6ac3c989..770a2b61 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -28,7 +28,10 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio mod tests { use crate::{ CompletionItem, CompletionItemKind, complete, - test_helper::{CURSOR_POS, InputQuery, get_test_deps, get_test_params}, + test_helper::{ + CURSOR_POS, CompletionAssertion, InputQuery, assert_complete_results, get_test_deps, + get_test_params, + }, }; struct TestCase { @@ -168,9 +171,9 @@ mod tests { ("name", "Table: public.users"), ("narrator", "Table: public.audio_books"), ("narrator_id", "Table: private.audio_books"), + ("id", "Table: public.audio_books"), ("name", "Schema: pg_catalog"), ("nameconcatoid", "Schema: pg_catalog"), - ("nameeq", "Schema: pg_catalog"), ] .into_iter() .map(|(label, schema)| LabelAndDesc { @@ -325,4 +328,107 @@ mod tests { ); } } + + #[tokio::test] + async fn filters_out_by_aliases() { + let setup = r#" + create schema auth; + + create table auth.users ( + uid serial primary key, + name text not null, + email text unique not null + ); + + create table auth.posts ( + pid serial primary key, + user_id int not null references auth.users(uid), + title text not null, + content text, + created_at timestamp default now() + ); + "#; + + // test in SELECT clause + assert_complete_results( + format!( + "select u.id, p.{} from auth.users u join auth.posts p on u.id = p.user_id;", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::LabelNotExists("uid".to_string()), + CompletionAssertion::LabelNotExists("name".to_string()), + CompletionAssertion::LabelNotExists("email".to_string()), + CompletionAssertion::Label("content".to_string()), + CompletionAssertion::Label("created_at".to_string()), + CompletionAssertion::Label("pid".to_string()), + CompletionAssertion::Label("title".to_string()), + CompletionAssertion::Label("user_id".to_string()), + ], + setup, + ) + .await; + + // test in JOIN clause + assert_complete_results( + format!( + "select u.id, p.content from auth.users u join auth.posts p on u.id = p.{};", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::LabelNotExists("uid".to_string()), + CompletionAssertion::LabelNotExists("name".to_string()), + CompletionAssertion::LabelNotExists("email".to_string()), + // primary keys are preferred + CompletionAssertion::Label("pid".to_string()), + CompletionAssertion::Label("content".to_string()), + CompletionAssertion::Label("created_at".to_string()), + CompletionAssertion::Label("title".to_string()), + CompletionAssertion::Label("user_id".to_string()), + ], + setup, + ) + .await; + } + + #[tokio::test] + async fn does_not_complete_cols_in_join_clauses() { + let setup = r#" + create schema auth; + + create table auth.users ( + uid serial primary key, + name text not null, + email text unique not null + ); + + create table auth.posts ( + pid serial primary key, + user_id int not null references auth.users(uid), + title text not null, + content text, + created_at timestamp default now() + ); + "#; + + /* + * We are not in the "ON" part of the JOIN clause, so we should not complete columns. + */ + assert_complete_results( + format!( + "select u.id, p.content from auth.users u join auth.{}", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::KindNotExists(CompletionItemKind::Column), + CompletionAssertion::LabelAndKind("posts".to_string(), CompletionItemKind::Table), + CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table), + ], + setup, + ) + .await; + } } diff --git a/crates/pgt_completions/src/providers/helper.rs b/crates/pgt_completions/src/providers/helper.rs index 2e4ef8a9..c0fe5869 100644 --- a/crates/pgt_completions/src/providers/helper.rs +++ b/crates/pgt_completions/src/providers/helper.rs @@ -7,7 +7,7 @@ pub(crate) fn get_completion_text_with_schema( item_name: &str, item_schema_name: &str, ) -> Option { - if item_schema_name == "public" || ctx.schema_name.is_some() { + if item_schema_name == "public" || ctx.schema_or_alias_name.is_some() { None } else { let node = ctx.node_under_cursor.unwrap(); diff --git a/crates/pgt_completions/src/providers/schemas.rs b/crates/pgt_completions/src/providers/schemas.rs index c28f831e..aaa5ebe6 100644 --- a/crates/pgt_completions/src/providers/schemas.rs +++ b/crates/pgt_completions/src/providers/schemas.rs @@ -59,6 +59,8 @@ mod tests { "private".to_string(), CompletionItemKind::Schema, ), + // users table still preferred over system schemas + CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table), CompletionAssertion::LabelAndKind( "information_schema".to_string(), CompletionItemKind::Schema, @@ -71,7 +73,6 @@ mod tests { "pg_toast".to_string(), CompletionItemKind::Schema, ), - CompletionAssertion::LabelAndKind("users".to_string(), CompletionItemKind::Table), ], setup, ) diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index f9f922d1..cbedc55b 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -273,4 +273,37 @@ mod tests { ) .await; } + + #[tokio::test] + async fn suggests_tables_in_join() { + let setup = r#" + create schema auth; + + create table auth.users ( + uid serial primary key, + name text not null, + email text unique not null + ); + + create table auth.posts ( + pid serial primary key, + user_id int not null references auth.users(uid), + title text not null, + content text, + created_at timestamp default now() + ); + "#; + + assert_complete_results( + format!("select * from auth.users u join {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), // self-join + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 69939e0b..2658216b 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -1,4 +1,4 @@ -use crate::context::{ClauseType, CompletionContext}; +use crate::context::{ClauseType, CompletionContext, WrappingNode}; use super::CompletionRelevanceData; @@ -18,7 +18,7 @@ impl CompletionFilter<'_> { self.completable_context(ctx)?; self.check_clause(ctx)?; self.check_invocation(ctx)?; - self.check_mentioned_schema(ctx)?; + self.check_mentioned_schema_or_alias(ctx)?; Some(()) } @@ -50,6 +50,7 @@ impl CompletionFilter<'_> { fn check_clause(&self, ctx: &CompletionContext) -> Option<()> { let clause = ctx.wrapping_clause_type.as_ref(); + let wrapping_node = ctx.wrapping_node_kind.as_ref(); match self.data { CompletionRelevanceData::Table(_) => { @@ -62,10 +63,20 @@ impl CompletionFilter<'_> { } CompletionRelevanceData::Column(_) => { let in_from_clause = clause.is_some_and(|c| c == &ClauseType::From); - if in_from_clause { return None; } + + // We can complete columns in JOIN cluases, but only if we are in the + // "ON u.id = posts.user_id" part. + let in_join_clause = clause.is_some_and(|c| c == &ClauseType::Join); + + let in_comparison_clause = + wrapping_node.is_some_and(|n| n == &WrappingNode::BinaryExpression); + + if in_join_clause && !in_comparison_clause { + return None; + } } _ => {} } @@ -86,27 +97,28 @@ impl CompletionFilter<'_> { Some(()) } - fn check_mentioned_schema(&self, ctx: &CompletionContext) -> Option<()> { - if ctx.schema_name.is_none() { + fn check_mentioned_schema_or_alias(&self, ctx: &CompletionContext) -> Option<()> { + if ctx.schema_or_alias_name.is_none() { return Some(()); } - let name = ctx.schema_name.as_ref().unwrap(); + let schema_or_alias = ctx.schema_or_alias_name.as_ref().unwrap(); + + let matches = match self.data { + CompletionRelevanceData::Table(table) => &table.schema == schema_or_alias, + CompletionRelevanceData::Function(f) => &f.schema == schema_or_alias, + CompletionRelevanceData::Column(col) => ctx + .mentioned_table_aliases + .get(schema_or_alias) + .is_some_and(|t| t == &col.table_name), - let does_not_match = match self.data { - CompletionRelevanceData::Table(table) => &table.schema != name, - CompletionRelevanceData::Function(f) => &f.schema != name, - CompletionRelevanceData::Column(_) => { - // columns belong to tables, not schemas - true - } CompletionRelevanceData::Schema(_) => { // we should never allow schema suggestions if there already was one. - true + false } }; - if does_not_match { + if !matches { return None; } diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index 2ef8edb6..e67df658 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -62,13 +62,19 @@ impl CompletionScore<'_> { }; let has_mentioned_tables = !ctx.mentioned_relations.is_empty(); - let has_mentioned_schema = ctx.schema_name.is_some(); + let has_mentioned_schema = ctx.schema_or_alias_name.is_some(); + + let is_binary_exp = ctx + .wrapping_node_kind + .as_ref() + .is_some_and(|wn| wn == &WrappingNode::BinaryExpression); self.score += match self.data { CompletionRelevanceData::Table(_) => match clause_type { - ClauseType::From => 5, ClauseType::Update => 10, ClauseType::Delete => 10, + ClauseType::From => 5, + ClauseType::Join if !is_binary_exp => 5, _ => -50, }, CompletionRelevanceData::Function(_) => match clause_type { @@ -77,14 +83,19 @@ impl CompletionScore<'_> { ClauseType::From => 0, _ => -50, }, - CompletionRelevanceData::Column(_) => match clause_type { + CompletionRelevanceData::Column(col) => match clause_type { ClauseType::Select if has_mentioned_tables => 10, ClauseType::Select if !has_mentioned_tables => 0, ClauseType::Where => 10, + ClauseType::Join if is_binary_exp => { + // Users will probably join on primary keys + if col.is_primary_key { 20 } else { 10 } + } _ => -15, }, CompletionRelevanceData::Schema(_) => match clause_type { ClauseType::From if !has_mentioned_schema => 15, + ClauseType::Join if !has_mentioned_schema => 15, ClauseType::Update if !has_mentioned_schema => 15, ClauseType::Delete if !has_mentioned_schema => 15, _ => -50, @@ -98,7 +109,7 @@ impl CompletionScore<'_> { Some(wn) => wn, }; - let has_mentioned_schema = ctx.schema_name.is_some(); + let has_mentioned_schema = ctx.schema_or_alias_name.is_some(); let has_node_text = ctx.get_node_under_cursor_content().is_some(); self.score += match self.data { @@ -135,7 +146,7 @@ impl CompletionScore<'_> { } fn check_matches_schema(&mut self, ctx: &CompletionContext) { - let schema_name = match ctx.schema_name.as_ref() { + let schema_name = match ctx.schema_or_alias_name.as_ref() { None => return, Some(n) => n, }; @@ -199,7 +210,7 @@ impl CompletionScore<'_> { let system_schemas = ["pg_catalog", "information_schema", "pg_toast"]; if system_schemas.contains(&schema.as_str()) { - self.score -= 10; + self.score -= 20; } // "public" is the default postgres schema where users diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index 5eb5f53f..a6b57c55 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -146,17 +146,45 @@ mod tests { pub(crate) enum CompletionAssertion { Label(String), LabelAndKind(String, CompletionItemKind), + LabelNotExists(String), + KindNotExists(CompletionItemKind), } impl CompletionAssertion { - fn assert_eq(self, item: CompletionItem) { + fn assert(&self, item: &CompletionItem) { match self { CompletionAssertion::Label(label) => { - assert_eq!(item.label, label); + assert_eq!( + &item.label, label, + "Expected label to be {}, but got {}", + label, &item.label + ); } CompletionAssertion::LabelAndKind(label, kind) => { - assert_eq!(item.label, label); - assert_eq!(item.kind, kind); + assert_eq!( + &item.label, label, + "Expected label to be {}, but got {}", + label, &item.label + ); + assert_eq!( + &item.kind, kind, + "Expected kind to be {:?}, but got {:?}", + kind, &item.kind + ); + } + CompletionAssertion::LabelNotExists(label) => { + assert_ne!( + &item.label, label, + "Expected label {} not to exist, but found it", + label + ); + } + CompletionAssertion::KindNotExists(kind) => { + assert_ne!( + &item.kind, kind, + "Expected kind {:?} not to exist, but found it", + kind + ); } } } @@ -171,11 +199,30 @@ pub(crate) async fn assert_complete_results( let params = get_test_params(&tree, &cache, query.into()); let items = complete(params); - assertions + let (not_existing, existing): (Vec, Vec) = + assertions.into_iter().partition(|a| match a { + CompletionAssertion::LabelNotExists(_) | CompletionAssertion::KindNotExists(_) => true, + CompletionAssertion::Label(_) | CompletionAssertion::LabelAndKind(_, _) => false, + }); + + assert!( + items.len() >= existing.len(), + "Not enough items returned. Expected at least {} items, but got {}", + existing.len(), + items.len() + ); + + for item in &items { + for assertion in ¬_existing { + assertion.assert(item); + } + } + + existing .into_iter() .zip(items.into_iter()) .for_each(|(assertion, result)| { - assertion.assert_eq(result); + assertion.assert(&result); }); }