diff --git a/crates/pgt_lexer/src/lexer.rs b/crates/pgt_lexer/src/lexer.rs
index 3e6912295..03d5693b1 100644
--- a/crates/pgt_lexer/src/lexer.rs
+++ b/crates/pgt_lexer/src/lexer.rs
@@ -143,7 +143,7 @@ impl<'a> Lexer<'a> {
                     }
                     _ => {}
                 };
-                SyntaxKind::POSITIONAL_PARAM
+                SyntaxKind::NAMED_PARAM
             }
             pgt_tokenizer::TokenKind::QuotedIdent { terminated } => {
                 if !terminated {
diff --git a/crates/pgt_lexer_codegen/src/syntax_kind.rs b/crates/pgt_lexer_codegen/src/syntax_kind.rs
index 3a0054374..6e79d0b95 100644
--- a/crates/pgt_lexer_codegen/src/syntax_kind.rs
+++ b/crates/pgt_lexer_codegen/src/syntax_kind.rs
@@ -65,6 +65,9 @@ pub fn syntax_kind_mod() -> proc_macro2::TokenStream {
 
     let mut enum_variants: Vec<proc_macro2::TokenStream> = Vec::new();
     let mut from_kw_match_arms: Vec<proc_macro2::TokenStream> = Vec::new();
+    let mut is_kw_match_arms: Vec<proc_macro2::TokenStream> = Vec::new();
+
+    let mut is_trivia_match_arms: Vec<proc_macro2::TokenStream> = Vec::new();
 
     // collect keywords
     for kw in &all_keywords {
@@ -78,18 +81,30 @@ pub fn syntax_kind_mod() -> proc_macro2::TokenStream {
         from_kw_match_arms.push(quote! {
             #kw => Some(SyntaxKind::#kind_ident)
         });
+        is_kw_match_arms.push(quote! {
+            SyntaxKind::#kind_ident => true
+        });
     }
 
     // collect extra keywords
     EXTRA.iter().for_each(|&name| {
         let variant_name = format_ident!("{}", name);
         enum_variants.push(quote! { #variant_name });
+
+        if name == "COMMENT" {
+            is_trivia_match_arms.push(quote! {
+                SyntaxKind::#variant_name => true
+            });
+        }
     });
 
     // collect whitespace variants
    WHITESPACE.iter().for_each(|&name| {
         let variant_name = format_ident!("{}", name);
         enum_variants.push(quote! { #variant_name });
+        is_trivia_match_arms.push(quote! {
+            SyntaxKind::#variant_name => true
+        });
     });
 
     // collect punctuations
@@ -119,6 +134,20 @@ pub fn syntax_kind_mod() -> proc_macro2::TokenStream {
                     _ => None
                 }
            }
+
+            pub fn is_keyword(&self) -> bool {
+                match self {
+                    #(#is_kw_match_arms),*,
+                    _ => false
+                }
+            }
+
+            pub fn is_trivia(&self) -> bool {
+                match self {
+                    #(#is_trivia_match_arms),*,
+                    _ => false
+                }
+            }
         }
     }
 }
diff --git a/crates/pgt_tokenizer/src/lib.rs b/crates/pgt_tokenizer/src/lib.rs
index 16093db8f..14c36f091 100644
--- a/crates/pgt_tokenizer/src/lib.rs
+++ b/crates/pgt_tokenizer/src/lib.rs
@@ -668,6 +668,13 @@ mod tests {
         assert_debug_snapshot!(result);
     }
 
+    #[test]
+    fn graphile_named_param() {
+        let result =
+            lex("grant usage on schema public, app_public, app_hidden to :DATABASE_VISITOR;");
+        assert_debug_snapshot!(result);
+    }
+
     #[test]
     fn named_param_dollar_raw() {
         let result = lex("select 1 from c where id = $id;");
diff --git a/crates/pgt_tokenizer/src/snapshots/pgt_tokenizer__tests__graphile_named_param.snap b/crates/pgt_tokenizer/src/snapshots/pgt_tokenizer__tests__graphile_named_param.snap
new file mode 100644
index 000000000..a2ddd008f
--- /dev/null
+++ b/crates/pgt_tokenizer/src/snapshots/pgt_tokenizer__tests__graphile_named_param.snap
@@ -0,0 +1,27 @@
+---
+source: crates/pgt_tokenizer/src/lib.rs
+expression: result
+snapshot_kind: text
+---
+[
+    "grant" @ Ident,
+    " " @ Space,
+    "usage" @ Ident,
+    " " @ Space,
+    "on" @ Ident,
+    " " @ Space,
+    "schema" @ Ident,
+    " " @ Space,
+    "public" @ Ident,
+    "," @ Comma,
+    " " @ Space,
+    "app_public" @ Ident,
+    "," @ Comma,
+    " " @ Space,
+    "app_hidden" @ Ident,
+    " " @ Space,
+    "to" @ Ident,
+    " " @ Space,
+    ":DATABASE_VISITOR" @ NamedParam { kind: ColonRaw },
+    ";" @ Semi,
+]
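Note (illustration, not part of the patch): the codegen change above adds is_keyword and is_trivia predicates to the generated SyntaxKind impl. A rough, abridged sketch of the expansion and its downstream use follows; the keyword arms and the COMMENT arm are taken from the diff, while the concrete whitespace arms depend on the WHITESPACE list and are only indicated by a comment.

// Rough shape of the code emitted by syntax_kind_mod() after this change (abridged):
//
//     impl SyntaxKind {
//         pub fn is_keyword(&self) -> bool {
//             match self {
//                 SyntaxKind::TO_KW => true,
//                 // ...one arm per keyword in all_keywords...
//                 _ => false
//             }
//         }
//
//         pub fn is_trivia(&self) -> bool {
//             match self {
//                 SyntaxKind::COMMENT => true,
//                 // ...one arm per variant in WHITESPACE...
//                 _ => false
//             }
//         }
//     }
//
// Downstream usage, as in the pg_query.rs change below:
use pgt_lexer::SyntaxKind;

fn example() {
    assert!(SyntaxKind::TO_KW.is_keyword());
    assert!(SyntaxKind::COMMENT.is_trivia());
    assert!(!SyntaxKind::TO_KW.is_trivia());
}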
diff --git a/crates/pgt_workspace/src/workspace/server/pg_query.rs b/crates/pgt_workspace/src/workspace/server/pg_query.rs
index 0494c5219..a198ae3f7 100644
--- a/crates/pgt_workspace/src/workspace/server/pg_query.rs
+++ b/crates/pgt_workspace/src/workspace/server/pg_query.rs
@@ -3,9 +3,9 @@ use std::num::NonZeroUsize;
 use std::sync::{Arc, LazyLock, Mutex};
 
 use lru::LruCache;
+use pgt_lexer::lex;
 use pgt_query_ext::diagnostics::*;
 use pgt_text_size::TextRange;
-use pgt_tokenizer::tokenize;
 use regex::Regex;
 
 use super::statement_identifier::StatementId;
@@ -104,6 +104,27 @@ fn is_composite_type_error(err: &str) -> bool {
     COMPOSITE_TYPE_ERROR_RE.is_match(err)
 }
 
+// Keywords that, when preceding a named parameter, indicate that the parameter should be treated
+// as an identifier rather than a positional parameter.
+const IDENTIFIER_CONTEXT: [pgt_lexer::SyntaxKind; 15] = [
+    pgt_lexer::SyntaxKind::TO_KW,
+    pgt_lexer::SyntaxKind::FROM_KW,
+    pgt_lexer::SyntaxKind::SCHEMA_KW,
+    pgt_lexer::SyntaxKind::TABLE_KW,
+    pgt_lexer::SyntaxKind::INDEX_KW,
+    pgt_lexer::SyntaxKind::CONSTRAINT_KW,
+    pgt_lexer::SyntaxKind::OWNER_KW,
+    pgt_lexer::SyntaxKind::ROLE_KW,
+    pgt_lexer::SyntaxKind::USER_KW,
+    pgt_lexer::SyntaxKind::DATABASE_KW,
+    pgt_lexer::SyntaxKind::TYPE_KW,
+    pgt_lexer::SyntaxKind::CAST_KW,
+    pgt_lexer::SyntaxKind::ALTER_KW,
+    pgt_lexer::SyntaxKind::DROP_KW,
+    // for schema.table style identifiers
+    pgt_lexer::SyntaxKind::DOT,
+];
+
 /// Converts named parameters in a SQL query string to positional parameters.
 ///
 /// This function scans the input SQL string for named parameters (e.g., `@param`, `:param`, `:'param'`)
@@ -116,13 +137,16 @@ pub fn convert_to_positional_params(text: &str) -> String {
     let mut result = String::with_capacity(text.len());
     let mut param_mapping: HashMap<&str, usize> = HashMap::new();
     let mut param_index = 1;
-    let mut position = 0;
 
-    for token in tokenize(text) {
-        let token_len = token.len as usize;
-        let token_text = &text[position..position + token_len];
+    let lexed = lex(text);
+    for (token_idx, kind) in lexed.tokens().enumerate() {
+        if kind == pgt_lexer::SyntaxKind::EOF {
+            break;
+        }
+
+        let token_text = lexed.text(token_idx);
 
-        if matches!(token.kind, pgt_tokenizer::TokenKind::NamedParam { .. }) {
+        if matches!(kind, pgt_lexer::SyntaxKind::NAMED_PARAM) {
             let idx = match param_mapping.get(token_text) {
                 Some(&index) => index,
                 None => {
@@ -133,7 +157,16 @@ pub fn convert_to_positional_params(text: &str) -> String {
                 }
             };
 
-            let replacement = format!("${}", idx);
+            // find previous non-trivia token
+            let prev_token = (0..token_idx)
+                .rev()
+                .map(|i| lexed.kind(i))
+                .find(|kind| !kind.is_trivia());
+
+            let replacement = match prev_token {
+                Some(k) if IDENTIFIER_CONTEXT.contains(&k) => deterministic_identifier(idx - 1),
+                _ => format!("${}", idx),
+            };
 
             let original_len = token_text.len();
             let replacement_len = replacement.len();
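Note (illustration only, not part of the patch): the branch above picks between a deterministic identifier and a positional parameter based on the previous non-trivia token. A minimal test-style sketch of that behaviour, assuming the replacement is padded to the original token length by the existing original_len/replacement_len handling:

#[cfg(test)]
mod conversion_sketch {
    use super::*;

    #[test]
    fn identifier_context_vs_positional() {
        // `=` precedes `@one` and is not in IDENTIFIER_CONTEXT,
        // so the named parameter becomes a positional parameter.
        let positional = convert_to_positional_params("select 1 from users where id = @one;");
        assert!(positional.contains("$1"));

        // TO_KW precedes `:visitor` and is in IDENTIFIER_CONTEXT, so the
        // parameter is rewritten as the identifier "a" (padded with spaces so
        // later byte offsets stay stable).
        let identifier = convert_to_positional_params("grant usage on schema public to :visitor;");
        assert!(identifier.contains(" to a"));
    }
}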
@@ -146,17 +179,45 @@ pub fn convert_to_positional_params(text: &str) -> String {
         } else {
             result.push_str(token_text);
         }
-
-        position += token_len;
     }
 
     result
 }
 
+const ALPHABET: [char; 26] = [
+    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
+    't', 'u', 'v', 'w', 'x', 'y', 'z',
+];
+
+/// Generates a deterministic identifier based on the given index.
+fn deterministic_identifier(idx: usize) -> String {
+    let iteration = idx / ALPHABET.len();
+    let pos = idx % ALPHABET.len();
+
+    format!(
+        "{}{}",
+        ALPHABET[pos],
+        if iteration > 0 {
+            deterministic_identifier(iteration - 1)
+        } else {
+            "".to_string()
+        }
+    )
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
 
+    #[test]
+    fn test_deterministic_identifier() {
+        assert_eq!(deterministic_identifier(0), "a");
+        assert_eq!(deterministic_identifier(25), "z");
+        assert_eq!(deterministic_identifier(26), "aa");
+        assert_eq!(deterministic_identifier(27), "ba");
+        assert_eq!(deterministic_identifier(51), "za");
+    }
+
     #[test]
     fn test_convert_to_positional_params() {
         let input = "select * from users where id = @one and name = :two and email = :'three';";
@@ -177,6 +238,24 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_positional_params_in_grant() {
+        let input = "grant usage on schema public, app_public, app_hidden to :DB_ROLE;";
+
+        let result = convert_to_positional_params(input);
+
+        assert_eq!(
+            result,
+            "grant usage on schema public, app_public, app_hidden to a       ;"
+        );
+
+        let store = PgQueryStore::new();
+
+        let res = store.get_or_cache_ast(&StatementId::new(input));
+
+        assert!(res.is_ok());
+    }
+
     #[test]
     fn test_plpgsql_syntax_error() {
         let input = "