Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/pgt_lexer/src/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl<'a> Lexer<'a> {
}
_ => {}
};
SyntaxKind::POSITIONAL_PARAM
SyntaxKind::NAMED_PARAM
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

das war n bug nehm ich an?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ja, allerdings kein wichtiger für die Statement Splitter Logik deshalb kam es nie raus

}
pgt_tokenizer::TokenKind::QuotedIdent { terminated } => {
if !terminated {
Expand Down
29 changes: 29 additions & 0 deletions crates/pgt_lexer_codegen/src/syntax_kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ pub fn syntax_kind_mod() -> proc_macro2::TokenStream {

let mut enum_variants: Vec<TokenStream> = Vec::new();
let mut from_kw_match_arms: Vec<TokenStream> = Vec::new();
let mut is_kw_match_arms: Vec<TokenStream> = Vec::new();

let mut is_trivia_match_arms: Vec<TokenStream> = Vec::new();

// collect keywords
for kw in &all_keywords {
Expand All @@ -78,18 +81,30 @@ pub fn syntax_kind_mod() -> proc_macro2::TokenStream {
from_kw_match_arms.push(quote! {
#kw => Some(SyntaxKind::#kind_ident)
});
is_kw_match_arms.push(quote! {
SyntaxKind::#kind_ident => true
});
}

// collect extra keywords
EXTRA.iter().for_each(|&name| {
let variant_name = format_ident!("{}", name);
enum_variants.push(quote! { #variant_name });

if name == "COMMENT" {
is_trivia_match_arms.push(quote! {
SyntaxKind::#variant_name => true
});
}
});

// collect whitespace variants
WHITESPACE.iter().for_each(|&name| {
let variant_name = format_ident!("{}", name);
enum_variants.push(quote! { #variant_name });
is_trivia_match_arms.push(quote! {
SyntaxKind::#variant_name => true
});
});

// collect punctuations
Expand Down Expand Up @@ -119,6 +134,20 @@ pub fn syntax_kind_mod() -> proc_macro2::TokenStream {
_ => None
}
}

pub fn is_keyword(&self) -> bool {
match self {
#(#is_kw_match_arms),*,
_ => false
}
}

pub fn is_trivia(&self) -> bool {
match self {
#(#is_trivia_match_arms),*,
_ => false
}
}
}
}
}
7 changes: 7 additions & 0 deletions crates/pgt_tokenizer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,13 @@ mod tests {
assert_debug_snapshot!(result);
}

#[test]
fn graphile_named_param() {
    // A colon-prefixed placeholder sits where a role name would normally appear in a GRANT.
    let sql = "grant usage on schema public, app_public, app_hidden to :DATABASE_VISITOR;";

    // The token stream is pinned by a debug snapshot rather than by inline assertions.
    let result = lex(sql);
    assert_debug_snapshot!(result);
}

#[test]
fn named_param_dollar_raw() {
let result = lex("select 1 from c where id = $id;");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
---
source: crates/pgt_tokenizer/src/lib.rs
expression: result
snapshot_kind: text
---
[
"grant" @ Ident,
" " @ Space,
"usage" @ Ident,
" " @ Space,
"on" @ Ident,
" " @ Space,
"schema" @ Ident,
" " @ Space,
"public" @ Ident,
"," @ Comma,
" " @ Space,
"app_public" @ Ident,
"," @ Comma,
" " @ Space,
"app_hidden" @ Ident,
" " @ Space,
"to" @ Ident,
" " @ Space,
":DATABASE_VISITOR" @ NamedParam { kind: ColonRaw },
";" @ Semi,
]
97 changes: 88 additions & 9 deletions crates/pgt_workspace/src/workspace/server/pg_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use std::num::NonZeroUsize;
use std::sync::{Arc, LazyLock, Mutex};

use lru::LruCache;
use pgt_lexer::lex;
use pgt_query_ext::diagnostics::*;
use pgt_text_size::TextRange;
use pgt_tokenizer::tokenize;
use regex::Regex;

use super::statement_identifier::StatementId;
Expand Down Expand Up @@ -104,6 +104,27 @@ fn is_composite_type_error(err: &str) -> bool {
COMPOSITE_TYPE_ERROR_RE.is_match(err)
}

/// Token kinds that, when they directly precede a named parameter (ignoring trivia), indicate
/// that the parameter stands in identifier position.
///
/// In that case the parameter is replaced with a generated identifier instead of a `$n`
/// positional placeholder.
const IDENTIFIER_CONTEXT: [pgt_lexer::SyntaxKind; 15] = [
pgt_lexer::SyntaxKind::TO_KW,
pgt_lexer::SyntaxKind::FROM_KW,
pgt_lexer::SyntaxKind::SCHEMA_KW,
pgt_lexer::SyntaxKind::TABLE_KW,
pgt_lexer::SyntaxKind::INDEX_KW,
pgt_lexer::SyntaxKind::CONSTRAINT_KW,
pgt_lexer::SyntaxKind::OWNER_KW,
pgt_lexer::SyntaxKind::ROLE_KW,
pgt_lexer::SyntaxKind::USER_KW,
pgt_lexer::SyntaxKind::DATABASE_KW,
pgt_lexer::SyntaxKind::TYPE_KW,
pgt_lexer::SyntaxKind::CAST_KW,
pgt_lexer::SyntaxKind::ALTER_KW,
pgt_lexer::SyntaxKind::DROP_KW,
// Not a keyword: covers qualified names such as `schema.table`, where the parameter follows
// the dot.
pgt_lexer::SyntaxKind::DOT,
];

/// Converts named parameters in a SQL query string to positional parameters.
///
/// This function scans the input SQL string for named parameters (e.g., `@param`, `:param`, `:'param'`)
Expand All @@ -116,13 +137,16 @@ pub fn convert_to_positional_params(text: &str) -> String {
let mut result = String::with_capacity(text.len());
let mut param_mapping: HashMap<&str, usize> = HashMap::new();
let mut param_index = 1;
let mut position = 0;

for token in tokenize(text) {
let token_len = token.len as usize;
let token_text = &text[position..position + token_len];
let lexed = lex(text);
for (token_idx, kind) in lexed.tokens().enumerate() {
if kind == pgt_lexer::SyntaxKind::EOF {
break;
}

let token_text = lexed.text(token_idx);

if matches!(token.kind, pgt_tokenizer::TokenKind::NamedParam { .. }) {
if matches!(kind, pgt_lexer::SyntaxKind::NAMED_PARAM) {
let idx = match param_mapping.get(token_text) {
Some(&index) => index,
None => {
Expand All @@ -133,7 +157,16 @@ pub fn convert_to_positional_params(text: &str) -> String {
}
};

let replacement = format!("${}", idx);
// find previous non-trivia token
let prev_token = (0..token_idx)
.rev()
.map(|i| lexed.kind(i))
.find(|kind| !kind.is_trivia());

let replacement = match prev_token {
Some(k) if IDENTIFIER_CONTEXT.contains(&k) => deterministic_identifier(idx - 1),
_ => format!("${}", idx),
};
let original_len = token_text.len();
let replacement_len = replacement.len();

Expand All @@ -146,17 +179,45 @@ pub fn convert_to_positional_params(text: &str) -> String {
} else {
result.push_str(token_text);
}

position += token_len;
}

result
}

/// The lowercase ASCII letters `a` through `z`, in order.
const ALPHABET: [char; 26] = {
    let mut letters = ['a'; 26];
    let mut offset = 0;
    while offset < letters.len() {
        // `offset` is always below 26, so the narrowing cast cannot truncate.
        letters[offset] = (b'a' + offset as u8) as char;
        offset += 1;
    }
    letters
};

/// Generates a deterministic identifier based on the given index.
///
/// Indices `0..=25` map to the single letters `a` through `z`. Larger indices append further
/// letters by recursing on `idx / 26 - 1`, so the first letter is the least significant one:
/// `26` yields `aa`, `27` yields `ba` and `51` yields `za`.
fn deterministic_identifier(idx: usize) -> String {
// Number of complete passes through the alphabet that precede `idx`.
let iteration = idx / ALPHABET.len();
// Position of the letter emitted first for this index.
let pos = idx % ALPHABET.len();

format!(
"{}{}",
ALPHABET[pos],
// Beyond the first 26 indices the remaining letters encode `iteration - 1`.
if iteration > 0 {
deterministic_identifier(iteration - 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i love me a good recursion

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

das kam sogar ohne Claude 😂

} else {
"".to_string()
}
)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_deterministic_identifier() {
    // (index, expected identifier): both ends of the single-letter range, then the first,
    // second and last two-letter identifiers.
    let cases = [(0, "a"), (25, "z"), (26, "aa"), (27, "ba"), (51, "za")];

    for &(idx, expected) in &cases {
        assert_eq!(deterministic_identifier(idx), expected);
    }
}

#[test]
fn test_convert_to_positional_params() {
let input = "select * from users where id = @one and name = :two and email = :'three';";
Expand All @@ -177,6 +238,24 @@ mod tests {
);
}

#[test]
fn test_positional_params_in_grant() {
    let sql = "grant usage on schema public, app_public, app_hidden to :DB_ROLE;";
    // NOTE(review): the spacing after `a` presumably reflects padding of the replacement to the
    // original token length; confirm against the conversion code.
    let expected = "grant usage on schema public, app_public, app_hidden to a ;";

    // `:DB_ROLE` follows `to`, so it should become a generated identifier rather than `$1`.
    let converted = convert_to_positional_params(sql);
    assert_eq!(converted, expected);

    // The store must return `Ok` when asked for the AST of the original statement.
    let cache = PgQueryStore::new();
    let parsed = cache.get_or_cache_ast(&StatementId::new(sql));
    assert!(parsed.is_ok());
}

#[test]
fn test_plpgsql_syntax_error() {
let input = "
Expand Down