diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml
index 8aa24265..20218378 100644
--- a/.github/workflows/pull_request.yml
+++ b/.github/workflows/pull_request.yml
@@ -79,15 +79,6 @@ jobs:
   lint:
     name: Lint Project
     runs-on: ubuntu-latest
-    services:
-      postgres:
-        image: postgres:latest
-        env:
-          POSTGRES_USER: postgres
-          POSTGRES_PASSWORD: postgres
-          POSTGRES_DB: postgres
-        ports:
-          - 5432:5432
     steps:
       - name: Checkout PR Branch
         uses: actions/checkout@v4
@@ -103,6 +94,24 @@
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
 
+      # we need to use the same database as we do locally for sqlx prepare to output the same hashes
+      - name: Build and start PostgreSQL with plpgsql_check
+        run: |
+          docker build -t postgres-plpgsql-check:latest .
+          docker run -d --name postgres \
+            -e POSTGRES_USER=postgres \
+            -e POSTGRES_PASSWORD=postgres \
+            -e POSTGRES_DB=postgres \
+            -p 5432:5432 \
+            postgres-plpgsql-check:latest
+          # Wait for postgres to be ready
+          for _ in {1..30}; do
+            if docker exec postgres pg_isready -U postgres; then
+              break
+            fi
+            sleep 1
+          done
+
       - name: Setup sqlx-cli
         run: cargo install sqlx-cli
@@ -154,13 +163,37 @@
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
 
-      # running containers via `services` only works on linux
-      # https://github.com/actions/runner/issues/1866
-      - name: Setup postgres
+      # For Linux, use custom Docker image with plpgsql_check
+      - name: Build and start PostgreSQL with plpgsql_check
+        if: runner.os == 'Linux'
+        run: |
+          docker build -t postgres-plpgsql-check:latest .
+          docker run -d --name postgres \
+            -e POSTGRES_USER=postgres \
+            -e POSTGRES_PASSWORD=postgres \
+            -e POSTGRES_DB=postgres \
+            -p 5432:5432 \
+            postgres-plpgsql-check:latest
+          # Wait for postgres to be ready
+          for _ in {1..30}; do
+            if docker exec postgres pg_isready -U postgres; then
+              break
+            fi
+            sleep 1
+          done
+
+      # For Windows, use the action since PostgreSQL Docker image doesn't support Windows containers
+      - name: Setup postgres (Windows)
+        if: runner.os == 'Windows'
         id: postgres
         uses: ikalnytskyi/action-setup-postgres@v7
 
       - name: Print Roles
-        run: psql ${{ steps.postgres.outputs.connection-uri }} -c "select rolname from pg_roles;"
+        run: |
+          if [[ "$RUNNER_OS" == "Linux" ]]; then
+            docker exec postgres psql -U postgres -c "select rolname from pg_roles;"
+          else
+            psql ${{ steps.postgres.outputs.connection-uri }} -c "select rolname from pg_roles;"
+          fi
+        shell: bash
 
       - name: Run tests
         run: cargo test --workspace
diff --git a/.sqlx/query-277e47bf46f8331549f55c8a0ebae6f3075c4f754cd379b0555c205fff95a95c.json b/.sqlx/query-277e47bf46f8331549f55c8a0ebae6f3075c4f754cd379b0555c205fff95a95c.json
new file mode 100644
index 00000000..db3f4a73
--- /dev/null
+++ b/.sqlx/query-277e47bf46f8331549f55c8a0ebae6f3075c4f754cd379b0555c205fff95a95c.json
@@ -0,0 +1,50 @@
+{
+  "db_name": "PostgreSQL",
+  "query": "-- we need to join tables from the pg_catalog since \"TRUNCATE\" triggers are\n-- not available in the information_schema.trigger table.\nselect\n    t.tgname as \"name!\",\n    c.relname as \"table_name!\",\n    p.proname as \"proc_name!\",\n    proc_ns.nspname as \"proc_schema!\",\n    table_ns.nspname as \"table_schema!\",\n    t.tgtype as \"details_bitmask!\"\nfrom\n    pg_catalog.pg_trigger t\nleft join pg_catalog.pg_proc p on t.tgfoid = p.oid\nleft join pg_catalog.pg_class c on t.tgrelid = c.oid\nleft join pg_catalog.pg_namespace table_ns on c.relnamespace = table_ns.oid\nleft join pg_catalog.pg_namespace proc_ns on p.pronamespace = proc_ns.oid\nwhere\n    t.tgisinternal = false and\n    t.tgconstraint = 0;\n",
+  "describe": {
+    "columns": [
+      {
+        "ordinal": 0,
+        "name": "name!",
+        "type_info": "Name"
+      },
+      {
+        "ordinal": 1,
+        "name": "table_name!",
+        "type_info": "Name"
+      },
+      {
+        "ordinal": 2,
+        "name": "proc_name!",
+        "type_info": "Name"
+      },
+      {
+        "ordinal": 3,
+        "name": "proc_schema!",
+        "type_info": "Name"
+      },
+      {
+        "ordinal": 4,
+        "name": "table_schema!",
+        "type_info": "Name"
+      },
+      {
+        "ordinal": 5,
+        "name": "details_bitmask!",
+        "type_info": "Int2"
+      }
+    ],
+    "parameters": {
+      "Left": []
+    },
+    "nullable": [
+      false,
+      true,
+      true,
+      true,
+      true,
+      false
+    ]
+  },
+  "hash": "277e47bf46f8331549f55c8a0ebae6f3075c4f754cd379b0555c205fff95a95c"
+}
diff --git a/.sqlx/query-4ea19fee016f1daeafdc466647d117910b19f540f19393b76aa6434e9d1d8502.json b/.sqlx/query-4ea19fee016f1daeafdc466647d117910b19f540f19393b76aa6434e9d1d8502.json
index 4980f4f3..400f031d 100644
--- a/.sqlx/query-4ea19fee016f1daeafdc466647d117910b19f540f19393b76aa6434e9d1d8502.json
+++ b/.sqlx/query-4ea19fee016f1daeafdc466647d117910b19f540f19393b76aa6434e9d1d8502.json
@@ -90,9 +90,9 @@
     "nullable": [
       null,
       true,
+      false,
       true,
-      true,
-      true,
+      false,
       null,
       null,
       null,
@@ -101,9 +101,9 @@
       null,
       null,
       null,
-      true,
+      false,
       null,
-      true
+      false
     ]
   },
   "hash": "4ea19fee016f1daeafdc466647d117910b19f540f19393b76aa6434e9d1d8502"
diff --git a/.sqlx/query-df57cc22f7d63847abce1d0d15675ba8951faa1be2ea6b2bf6714b1aa9127a6f.json b/.sqlx/query-df57cc22f7d63847abce1d0d15675ba8951faa1be2ea6b2bf6714b1aa9127a6f.json
deleted file mode 100644
index b6fd2fc8..00000000
--- a/.sqlx/query-df57cc22f7d63847abce1d0d15675ba8951faa1be2ea6b2bf6714b1aa9127a6f.json
+++ /dev/null
@@ -1,44 +0,0 @@
-{
-  "db_name": "PostgreSQL",
-  "query": "-- we need to join tables from the pg_catalog since \"TRUNCATE\" triggers are \n-- not available in the information_schema.trigger table.\nselect \n    t.tgname as \"name!\",\n    c.relname as \"table_name!\",\n    p.proname as \"proc_name!\",\n    n.nspname as \"schema_name!\",\n    t.tgtype as \"details_bitmask!\"\nfrom \n    pg_catalog.pg_trigger t \n    left join pg_catalog.pg_proc p on t.tgfoid = p.oid\n    left join pg_catalog.pg_class c on t.tgrelid = c.oid\n    left join pg_catalog.pg_namespace n on c.relnamespace = n.oid\nwhere \n    -- triggers enforcing constraints (e.g. unique fields) should not be included.\n    t.tgisinternal = false and \n    t.tgconstraint = 0;\n",
-  "describe": {
-    "columns": [
-      {
-        "ordinal": 0,
-        "name": "name!",
-        "type_info": "Name"
-      },
-      {
-        "ordinal": 1,
-        "name": "table_name!",
-        "type_info": "Name"
-      },
-      {
-        "ordinal": 2,
-        "name": "proc_name!",
-        "type_info": "Name"
-      },
-      {
-        "ordinal": 3,
-        "name": "schema_name!",
-        "type_info": "Name"
-      },
-      {
-        "ordinal": 4,
-        "name": "details_bitmask!",
-        "type_info": "Int2"
-      }
-    ],
-    "parameters": {
-      "Left": []
-    },
-    "nullable": [
-      false,
-      true,
-      true,
-      true,
-      false
-    ]
-  },
-  "hash": "df57cc22f7d63847abce1d0d15675ba8951faa1be2ea6b2bf6714b1aa9127a6f"
-}
diff --git a/Cargo.lock b/Cargo.lock
index d76baca3..49143908 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2976,6 +2976,24 @@ dependencies = [
  "quote",
 ]
 
+[[package]]
+name = "pgt_plpgsql_check"
+version = "0.0.0"
+dependencies = [
+ "pgt_console",
+ "pgt_diagnostics",
+ "pgt_query",
+ "pgt_query_ext",
+ "pgt_schema_cache",
+ "pgt_test_utils",
+ "pgt_text_size",
+ "regex",
+ "serde",
+ "serde_json",
+ "sqlx",
+ "tree-sitter",
+]
+
 [[package]]
 name = "pgt_query"
 version = "0.0.0"
@@ -3163,6 +3181,7 @@ dependencies = [
  "pgt_diagnostics",
  "pgt_fs",
  "pgt_lexer",
+ "pgt_plpgsql_check",
  "pgt_query",
  "pgt_query_ext",
  "pgt_schema_cache",
diff --git a/Cargo.toml b/Cargo.toml
index e243ab3e..d68aafe0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -76,6 +76,7 @@
 pgt_lexer = { path = "./crates/pgt_lexer", version = "0.0.0" }
 pgt_lexer_codegen = { path = "./crates/pgt_lexer_codegen", version = "0.0.0" }
 pgt_lsp = { path = "./crates/pgt_lsp", version = "0.0.0" }
 pgt_markup = { path = "./crates/pgt_markup", version = "0.0.0" }
+pgt_plpgsql_check = { path = "./crates/pgt_plpgsql_check", version = "0.0.0" }
 pgt_query = { path = "./crates/pgt_query", version = "0.0.0" }
 pgt_query_ext = { path = "./crates/pgt_query_ext", version = "0.0.0" }
 pgt_query_macros = { path = "./crates/pgt_query_macros", version = "0.0.0" }
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 00000000..10353bb2
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,16 @@
+FROM postgres:15
+
+# Install build dependencies
+RUN apt-get update && \
+    apt-get install -y postgresql-server-dev-15 gcc make git && \
+    cd /tmp && \
+    git clone https://github.com/okbob/plpgsql_check.git && \
+    cd plpgsql_check && \
+    make && \
+    make install && \
+    apt-get remove -y postgresql-server-dev-15 gcc make git && \
+    apt-get autoremove -y && \
+    rm -rf /tmp/plpgsql_check /var/lib/apt/lists/*
+
+# Add initialization script directly
+RUN echo "CREATE EXTENSION IF NOT EXISTS plpgsql_check;" > /docker-entrypoint-initdb.d/01-create-extension.sql
\ No newline at end of file
diff --git a/crates/pgt_diagnostics_categories/src/categories.rs b/crates/pgt_diagnostics_categories/src/categories.rs
index b9d29698..14df90b9 100644
--- a/crates/pgt_diagnostics_categories/src/categories.rs
+++ b/crates/pgt_diagnostics_categories/src/categories.rs
@@ -32,6 +32,7 @@ define_categories! {
     "flags/invalid",
     "project",
     "typecheck",
+    "plpgsql_check",
     "internalError/panic",
     "syntax",
     "dummy",
diff --git a/crates/pgt_plpgsql_check/Cargo.toml b/crates/pgt_plpgsql_check/Cargo.toml
new file mode 100644
index 00000000..75d1a52b
--- /dev/null
+++ b/crates/pgt_plpgsql_check/Cargo.toml
@@ -0,0 +1,30 @@
+[package]
+authors.workspace = true
+categories.workspace = true
+description = ""
+edition.workspace = true
+homepage.workspace = true
+keywords.workspace = true
+license.workspace = true
+name = "pgt_plpgsql_check"
+repository.workspace = true
+version = "0.0.0"
+
+
+[dependencies]
+pgt_console = { workspace = true }
+pgt_diagnostics = { workspace = true }
+pgt_query = { workspace = true }
+pgt_query_ext = { workspace = true }
+pgt_schema_cache = { workspace = true }
+pgt_text_size = { workspace = true }
+regex = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+sqlx = { workspace = true }
+tree-sitter = { workspace = true }
+
+[dev-dependencies]
+pgt_test_utils = { workspace = true }
+
+[lib]
diff --git a/crates/pgt_plpgsql_check/src/diagnostics.rs b/crates/pgt_plpgsql_check/src/diagnostics.rs
new file mode 100644
index 00000000..a0daec13
--- /dev/null
+++ b/crates/pgt_plpgsql_check/src/diagnostics.rs
@@ -0,0 +1,245 @@
+use std::io;
+
+use pgt_console::markup;
+use pgt_diagnostics::{Advices, Diagnostic, LogCategory, MessageAndDescription, Severity, Visit};
+use pgt_text_size::TextRange;
+
+use crate::{PlpgSqlCheckIssue, PlpgSqlCheckResult};
+
+/// Find the first occurrence of target text that is not within string literals
+fn find_text_outside_strings(text: &str, target: &str) -> Option<usize> {
+    let text_lower = text.to_lowercase();
+    let target_lower = target.to_lowercase();
+    let mut in_string = false;
+    let mut quote_char = '\0';
+    let bytes = text_lower.as_bytes();
+    let mut i = 0;
+
+    while i < bytes.len() {
+        let ch = bytes[i] as char;
+
+        if !in_string {
+            // Check if we're starting a string literal
+            if ch == '\'' || ch == '"' {
+                in_string = true;
+                quote_char = ch;
+            } else {
+                // Check if we found our target at this position
+                if text_lower[i..].starts_with(&target_lower) {
+                    // Check if this is a complete word (not part of another identifier)
+                    let is_word_start =
+                        i == 0 || !bytes[i - 1].is_ascii_alphanumeric() && bytes[i - 1] != b'_';
+                    let target_end = i + target_lower.len();
+                    let is_word_end = target_end >= bytes.len()
+                        || (!bytes[target_end].is_ascii_alphanumeric()
+                            && bytes[target_end] != b'_');
+
+                    if is_word_start && is_word_end {
+                        return Some(i);
+                    }
+                }
+            }
+        } else {
+            // We're inside a string literal
+            if ch == quote_char {
+                // Check if it's escaped (look for double quotes/apostrophes)
+                if i + 1 < bytes.len() && bytes[i + 1] as char == quote_char {
+                    // Skip the escaped quote
+                    i += 1;
+                } else {
+                    // End of string literal
+                    in_string = false;
+                    quote_char = '\0';
+                }
+            }
+        }
+
+        i += 1;
+    }
+
+    None
+}
+
+/// A specialized diagnostic for plpgsql_check.
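+///
+/// Each reported issue is mapped back onto a span inside the original
+/// `CREATE FUNCTION` statement so editors can underline the offending token.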
+#[derive(Clone, Debug, Diagnostic)]
+#[diagnostic(category = "plpgsql_check")]
+pub struct PlPgSqlCheckDiagnostic {
+    #[location(span)]
+    pub span: Option<TextRange>,
+    #[description]
+    #[message]
+    pub message: MessageAndDescription,
+    #[advice]
+    pub advices: PlPgSqlCheckAdvices,
+    #[severity]
+    pub severity: Severity,
+}
+
+#[derive(Debug, Clone)]
+pub struct PlPgSqlCheckAdvices {
+    pub code: Option<String>,
+    /// the relation (table or view) where the issue was found, if applicable
+    /// only applicable for trigger functions
+    pub relation: Option<String>,
+}
+
+impl Advices for PlPgSqlCheckAdvices {
+    fn record(&self, visitor: &mut dyn Visit) -> io::Result<()> {
+        // Show the error code if available
+        if let Some(code) = &self.code {
+            visitor.record_log(
+                LogCategory::Error,
+                &markup! { "SQL State: " {code} },
+            )?;
+        }
+
+        // Show relation information if available
+        if let Some(relation) = &self.relation {
+            visitor.record_log(
+                LogCategory::Info,
+                &markup! { "Relation: " {relation} },
+            )?;
+        }
+
+        Ok(())
+    }
+}
+
+/// Convert plpgsql_check results into diagnostics with optional relation info for triggers
+pub fn create_diagnostics_from_check_result(
+    result: &PlpgSqlCheckResult,
+    fn_body: &str,
+    offset: usize,
+    relation: Option<String>,
+) -> Vec<PlPgSqlCheckDiagnostic> {
+    result
+        .issues
+        .iter()
+        .map(|issue| {
+            let severity = match issue.level.as_str() {
+                "error" => Severity::Error,
+                "warning" => Severity::Warning,
+                "notice" => Severity::Hint,
+                _ => Severity::Information,
+            };
+
+            PlPgSqlCheckDiagnostic {
+                message: issue.message.clone().into(),
+                severity,
+                span: resolve_span(issue, fn_body, offset),
+                advices: PlPgSqlCheckAdvices {
+                    code: issue.sql_state.clone(),
+                    relation: relation.clone(),
+                },
+            }
+        })
+        .collect()
+}
+
+fn resolve_span(issue: &PlpgSqlCheckIssue, fn_body: &str, offset: usize) -> Option<TextRange> {
+    let stmt = match issue.statement.as_ref() {
+        Some(s) => s,
+        None => {
+            return Some(TextRange::new(
+                (offset as u32).into(),
+                ((offset + fn_body.len()) as u32).into(),
+            ));
+        }
+    };
+
+    let line_number = stmt
+        .line_number
+        .parse::<usize>()
+        .expect("Expected line number to be a valid usize");
+
+    let text = &stmt.text;
+
+    // calculate the offset to the target line
+    let line_offset: usize = fn_body
+        .lines()
+        .take(line_number - 1)
+        .map(|line| line.len() + 1) // +1 for newline
+        .sum();
+
+    // find the position within the target line
+    let line = fn_body.lines().nth(line_number - 1)?;
+    let start = line
+        .to_lowercase()
+        .find(&text.to_lowercase())
+        .unwrap_or_else(|| {
+            line.char_indices()
+                .find_map(|(i, c)| if !c.is_whitespace() { Some(i) } else { None })
+                .unwrap_or(0)
+        });
+
+    let stmt_offset = line_offset + start;
+
+    if let Some(q) = &issue.query {
+        // first find the query within the fn body *after* stmt_offset, ignoring string literals
+        let query_start = find_text_outside_strings(&fn_body[stmt_offset..], &q.text)
+            .map(|pos| pos + stmt_offset);
+
+        // the position is *within* the query text
+        let pos = q
+            .position
+            .parse::<usize>()
+            .expect("Expected query position to be a valid usize")
+            - 1; // -1 because the position is 1-based
+
+        let start = query_start? + pos;
+
+        // the range of the diagnostics is the token that `pos` is on
+        // Find the end of the current token by looking for whitespace or SQL delimiters
+        let remaining = &fn_body[start..];
+        let end = remaining
+            .char_indices()
+            .find(|(_, c)| {
+                c.is_whitespace() || matches!(c, ',' | ';' | ')' | '(' | '=' | '<' | '>')
+            })
+            .map(|(i, _c)| {
+                i // just the token end, don't include delimiters
+            })
+            .unwrap_or(remaining.len());
+
+        return Some(TextRange::new(
+            ((offset + start) as u32).into(),
+            ((offset + start + end) as u32).into(),
+        ));
+    }
+
+    // if no query is present, the end range covers
+    // - if text is "IF" or "ELSIF", then until the next "THEN"
+    // - TODO: check "LOOP", "CASE", "WHILE", "EXCEPTION" and others
+    // - else: until the next semicolon or end of line
+
+    if text.to_uppercase() == "IF" || text.to_uppercase() == "ELSIF" {
+        // Find the position of the next "THEN" after the statement
+        let remaining = &fn_body[stmt_offset..];
+        if let Some(then_pos) = remaining.to_uppercase().find("THEN") {
+            let end = then_pos + "THEN".len();
+            return Some(TextRange::new(
+                ((offset + stmt_offset) as u32).into(),
+                ((offset + stmt_offset + end) as u32).into(),
+            ));
+        }
+    }
+
+    // if no specific end is found, use the next semicolon or the end of the line
+    let remaining = &fn_body[stmt_offset..];
+    let end = remaining
+        .char_indices()
+        .find(|(_, c)| matches!(c, ';' | '\n' | '\r'))
+        .map(|(i, c)| {
+            if c == ';' {
+                i + 1 // include the semicolon
+            } else {
+                i // just the end of the line
+            }
+        })
+        .unwrap_or(remaining.len());
+
+    Some(TextRange::new(
+        ((offset + stmt_offset) as u32).into(),
+        ((offset + stmt_offset + end) as u32).into(),
+    ))
+}
diff --git a/crates/pgt_plpgsql_check/src/lib.rs b/crates/pgt_plpgsql_check/src/lib.rs
new file mode 100644
index 00000000..05e2f570
--- /dev/null
+++ b/crates/pgt_plpgsql_check/src/lib.rs
@@ -0,0 +1,794 @@
+mod diagnostics;
+
+pub use diagnostics::PlPgSqlCheckDiagnostic;
+use diagnostics::create_diagnostics_from_check_result;
+use pgt_query::protobuf::CreateFunctionStmt;
+use regex::Regex;
+use serde::Deserialize;
+pub use sqlx::postgres::PgSeverity;
+use sqlx::{Acquire, PgPool, Postgres, Transaction};
+
+#[derive(Debug)]
+pub struct PlPgSqlCheckParams<'a> {
+    pub conn: &'a PgPool,
+    pub sql: &'a str,
+    pub ast: &'a pgt_query::NodeEnum,
+    pub schema_cache: &'a pgt_schema_cache::SchemaCache,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct PlpgSqlCheckResult {
+    pub function: String,
+    pub issues: Vec<PlpgSqlCheckIssue>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct PlpgSqlCheckIssue {
+    pub level: String,
+    pub message: String,
+    pub statement: Option<Statement>,
+    pub query: Option<Query>,
+    #[serde(rename = "sqlState")]
+    pub sql_state: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Statement {
+    #[serde(rename = "lineNumber")]
+    pub line_number: String,
+    pub text: String,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Query {
+    pub position: String,
+    pub text: String,
+}
+
+/// check if the given node is a plpgsql function that should be checked
+fn should_check_function<'a>(
+    ast: &'a pgt_query::NodeEnum,
+    schema_cache: &pgt_schema_cache::SchemaCache,
+) -> Option<&'a CreateFunctionStmt> {
+    let create_fn = match ast {
+        pgt_query::NodeEnum::CreateFunctionStmt(stmt) => stmt,
+        _ => return None,
+    };
+
+    if pgt_query_ext::utils::find_option_value(create_fn, "language") != Some("plpgsql".to_string())
+    {
+        return None;
+    }
+
+    if !schema_cache
+        .extensions
+        .iter()
+        .any(|e| e.name == "plpgsql_check")
+    {
+        return None;
+    }
+
+    Some(create_fn)
+}
+
+/// check if a function is a trigger function
+fn is_trigger_function(create_fn: &CreateFunctionStmt) -> bool {
+    create_fn
+        .return_type
+        .as_ref()
+        .map(|n| {
+            matches!(
+                pgt_query_ext::utils::parse_name(&n.names),
+                Some((None, name)) if name == "trigger"
+            )
+        })
+        .unwrap_or(false)
+}
+
+/// build the function identifier string used by plpgsql_check
+fn build_function_identifier(
+    create_fn: &CreateFunctionStmt,
+    fn_schema: &Option<String>,
+    fn_name: &str,
+) -> String {
+    let args = create_fn
+        .parameters
+        .iter()
+        .filter_map(|arg| {
+            let node = match &arg.node {
+                Some(pgt_query::NodeEnum::FunctionParameter(n)) => n,
+                _ => return None,
+            };
+            let type_name_node = node.arg_type.as_ref()?;
+            let type_name = match pgt_query_ext::utils::parse_name(&type_name_node.names) {
+                Some((schema, name)) => match schema {
+                    Some(s) => format!("{}.{}", s, name),
+                    None => name,
+                },
+                None => return None,
+            };
+
+            if !type_name_node.array_bounds.is_empty() {
+                Some(format!("{}[]", type_name))
+            } else {
+                Some(type_name)
+            }
+        })
+        .collect::<Vec<_>>();
+
+    let fn_qualified_name = match fn_schema {
+        Some(schema) => format!("{}.{}", schema, fn_name),
+        None => fn_name.to_string(),
+    };
+
+    if args.is_empty() {
+        fn_qualified_name
+    } else {
+        format!("{}({})", fn_qualified_name, args.join(", "))
+    }
+}
+
+pub async fn check_plpgsql(
+    params: PlPgSqlCheckParams<'_>,
+) -> Result<Vec<PlPgSqlCheckDiagnostic>, sqlx::Error> {
+    let create_fn = match should_check_function(params.ast, params.schema_cache) {
+        Some(stmt) => stmt,
+        None => return Ok(vec![]),
+    };
+
+    let (fn_schema, fn_name) = match pgt_query_ext::utils::parse_name(&create_fn.funcname) {
+        Some(n) => n,
+        None => return Ok(vec![]),
+    };
+
+    let fn_identifier = build_function_identifier(create_fn, &fn_schema, &fn_name);
+
+    let fn_body = pgt_query_ext::utils::find_option_value(create_fn, "as")
+        .ok_or_else(|| sqlx::Error::Protocol("Failed to find function body".to_string()))?;
+    let offset = params
+        .sql
+        .find(&fn_body)
+        .ok_or_else(|| sqlx::Error::Protocol("Failed to find function body in SQL".to_string()))?;
+    let is_trigger = is_trigger_function(create_fn);
+
+    let mut conn = params.conn.acquire().await?;
+    conn.close_on_drop();
+
+    let mut tx: Transaction<'_, Postgres> = conn.begin().await?;
+
+    // disable function body checking to rely on plpgsql_check
+    sqlx::query("SET LOCAL check_function_bodies = off")
+        .execute(&mut *tx)
+        .await?;
+
+    // make sure we run with "or replace"
+    let sql_with_replace = if !create_fn.replace {
+        let re = Regex::new(r"(?i)\bCREATE\s+FUNCTION\b").unwrap();
+        re.replace(params.sql, "CREATE OR REPLACE FUNCTION")
+            .to_string()
+    } else {
+        params.sql.to_string()
+    };
+
+    // create the function - this should always succeed
+    sqlx::query(&sql_with_replace).execute(&mut *tx).await?;
+
+    // run plpgsql_check and collect results with their relations
+    let results_with_relations: Vec<(String, Option<String>)> = if is_trigger {
+        let mut results = Vec::new();
+
+        for trigger in params.schema_cache.triggers.iter() {
+            if trigger.proc_name == fn_name
+                && (fn_schema.is_none() || fn_schema.as_deref() == Some(&trigger.proc_schema))
+            {
+                let relation = format!("{}.{}", trigger.table_schema, trigger.table_name);
+
+                let result: Option<String> = sqlx::query_scalar(&format!(
+                    "select plpgsql_check_function('{}', '{}', format := 'json')",
+                    fn_identifier, relation
+                ))
+                .fetch_optional(&mut *tx)
+                .await?
+                .flatten();
+
+                if let Some(result) = result {
+                    results.push((result, Some(relation)));
+                }
+            }
+        }
+
+        results
+    } else {
+        let result: Option<String> = sqlx::query_scalar(&format!(
+            "select plpgsql_check_function('{}', format := 'json')",
+            fn_identifier
+        ))
+        .fetch_optional(&mut *tx)
+        .await?
+        .flatten();
+
+        if let Some(result) = result {
+            vec![(result, None)]
+        } else {
+            vec![]
+        }
+    };
+
+    tx.rollback().await?;
+
+    // Parse results and create diagnostics
+    let mut diagnostics = Vec::new();
+    for (result_json, relation) in results_with_relations {
+        let check_result: PlpgSqlCheckResult = serde_json::from_str(&result_json).map_err(|e| {
+            sqlx::Error::Protocol(format!("Failed to parse plpgsql_check result: {}", e))
+        })?;
+
+        let mut result_diagnostics =
+            create_diagnostics_from_check_result(&check_result, &fn_body, offset, relation);
+        diagnostics.append(&mut result_diagnostics);
+    }
+
+    Ok(diagnostics)
+}
+
+#[cfg(all(test, not(target_os = "windows")))]
+mod tests {
+    use sqlx::{Executor, PgPool};
+
+    /// Test helper to run plpgsql_check and return diagnostics with span text
+    async fn run_plpgsql_check_test(
+        test_db: &PgPool,
+        setup_sql: &str,
+        create_fn_sql: &str,
+    ) -> Result<(Vec<super::PlPgSqlCheckDiagnostic>, Vec<Option<String>>), Box<dyn std::error::Error>>
+    {
+        test_db.execute(setup_sql).await?;
+
+        let ast = pgt_query::parse(create_fn_sql)?
+            .into_root()
+            .ok_or("Failed to parse SQL root")?;
+        let schema_cache = pgt_schema_cache::SchemaCache::load(test_db).await?;
+
+        let diagnostics = super::check_plpgsql(super::PlPgSqlCheckParams {
+            conn: test_db,
+            sql: create_fn_sql,
+            ast: &ast,
+            schema_cache: &schema_cache,
+        })
+        .await?;
+
+        let span_texts = diagnostics
+            .iter()
+            .map(|diag| {
+                diag.span.as_ref().map(|s| {
+                    let start = usize::from(s.start());
+                    let end = usize::from(s.end());
+                    create_fn_sql[start..end].to_string()
+                })
+            })
+            .collect();
+
+        Ok((diagnostics, span_texts))
+    }
+
+    #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+    async fn test_plpgsql_check_if_expr(test_db: PgPool) {
+        let setup = r#"
+            create extension if not exists plpgsql_check;
+
+            CREATE TABLE t1(a int, b int);
+        "#;
+
+        let create_fn_sql = r#"
+            CREATE OR REPLACE FUNCTION public.f1()
+            RETURNS void
+            LANGUAGE plpgsql
+            AS $function$
+            declare r t1 := (select t1 from t1 where a = 1);
+            BEGIN
+                if r.c is null or
+                true is false
+                then -- there is bug - table t1 missing "c" column
+                RAISE NOTICE 'c is null';
+                end if;
+            END;
+            $function$;
+        "#;
+
+        let (diagnostics, span_texts) = run_plpgsql_check_test(&test_db, setup, create_fn_sql)
+            .await
+            .expect("Failed to run plpgsql_check test");
+
+        assert_eq!(diagnostics.len(), 1);
+        assert!(matches!(
+            diagnostics[0].severity,
+            pgt_diagnostics::Severity::Error
+        ));
+        assert_eq!(
+            span_texts[0].as_deref(),
+            Some("if r.c is null or\n                true is false\n                then")
+        );
+    }
+
+    #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+    async fn test_plpgsql_check_missing_var(test_db: PgPool) {
+        let setup = r#"
+            create extension if not exists plpgsql_check;
+
+            CREATE TABLE t1(a int, b int);
+        "#;
+
+        let create_fn_sql = r#"
+            CREATE OR REPLACE FUNCTION public.f1()
+            RETURNS void
+            LANGUAGE plpgsql
+            AS $function$
+            BEGIN
+                SELECT 1 from t1 where a = v_c;
+            END;
+            $function$;
+        "#;
+
+        let (diagnostics, span_texts) = run_plpgsql_check_test(&test_db, setup, create_fn_sql)
+            .await
+            .expect("Failed to run plpgsql_check test");
+        assert_eq!(diagnostics.len(), 1);
+        assert!(matches!(
+            diagnostics[0].severity,
+            pgt_diagnostics::Severity::Error
+        ));
+        assert_eq!(span_texts[0].as_deref(), Some("v_c"));
+    }
+
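+    // Each test that follows uses the same recipe: install plpgsql_check,
+    // create a deliberately broken function, then assert on the diagnostic's
+    // severity and on the exact source text its span covers.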
+    #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+    async fn test_plpgsql_check_missing_col_if_stmt(test_db: PgPool) {
+        let setup = r#"
+            create extension if not exists plpgsql_check;
+
+            CREATE TABLE t1(a int, b int);
+        "#;
+
+        let create_fn_sql = r#"
+            CREATE OR REPLACE FUNCTION public.f1()
+            RETURNS void
+            LANGUAGE plpgsql
+            AS $function$
+            BEGIN
+                if (select c from t1 where id = 1) is null then -- there is bug - table t1 missing "c" column
+                    RAISE NOTICE 'c is null';
+                end if;
+            END;
+            $function$;
+        "#;
+
+        let (diagnostics, span_texts) = run_plpgsql_check_test(&test_db, setup, create_fn_sql)
+            .await
+            .expect("Failed to run plpgsql_check test");
+        assert_eq!(diagnostics.len(), 1);
+        assert!(matches!(
+            diagnostics[0].severity,
+            pgt_diagnostics::Severity::Error
+        ));
+        assert_eq!(span_texts[0].as_deref(), Some("c"));
+    }
+
+    #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+    async fn test_plpgsql_check(test_db: PgPool) {
+        let setup = r#"
+            create extension if not exists plpgsql_check;
+
+            CREATE TABLE t1(a int, b int);
+        "#;
+
+        let create_fn_sql = r#"
+            CREATE OR REPLACE FUNCTION public.f1()
+            RETURNS void
+            LANGUAGE plpgsql
+            AS $function$
+            DECLARE r record;
+            BEGIN
+                FOR r IN SELECT * FROM t1
+                LOOP
+                    RAISE NOTICE '%', r.c; -- there is bug - table t1 missing "c" column
+                END LOOP;
+            END;
+            $function$;
+        "#;
+
+        let (diagnostics, span_texts) = run_plpgsql_check_test(&test_db, setup, create_fn_sql)
+            .await
+            .expect("Failed to run plpgsql_check test");
+
+        assert_eq!(diagnostics.len(), 1);
+        assert!(matches!(
+            diagnostics[0].severity,
+            pgt_diagnostics::Severity::Error
+        ));
+        assert_eq!(span_texts[0].as_deref(), Some("RAISE NOTICE '%', r.c;"));
+    }
+
+    #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+    async fn test_plpgsql_check_stacked_diagnostics(test_db: PgPool) {
+        let setup = r#"
+            create extension if not exists plpgsql_check;
+        "#;
+
+        let create_fn_sql = r#"
+            create or replace function fxtest()
+            returns void as $$
+            declare
+              v_sqlstate text;
+              v_message text;
+              v_context text;
+            begin
+              get stacked diagnostics
+                v_sqlstate = returned_sqlstate,
+                v_message = message_text,
+                v_context = pg_exception_context;
+            end;
+            $$ language plpgsql;
+        "#;
+
+        let (diagnostics, span_texts) = run_plpgsql_check_test(&test_db, setup, create_fn_sql)
+            .await
+            .expect("Failed to run plpgsql_check test");
+
+        assert!(!diagnostics.is_empty());
+        assert!(matches!(
+            diagnostics[0].severity,
+            pgt_diagnostics::Severity::Error
+        ));
+        assert_eq!(span_texts[0].as_deref(), Some("get stacked diagnostics"));
+    }
+
+    #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+    async fn test_plpgsql_check_constant_refcursor(test_db: PgPool) {
+        let setup = r#"
+            create extension if not exists plpgsql_check;
+            create table rc_test(a int);
+        "#;
+
+        let create_fn_sql = r#"
+            create function return_constant_refcursor() returns refcursor as $$
+            declare
+                rc constant refcursor;
+            begin
+                open rc for select a from rc_test;
+                return rc;
+            end
+            $$ language plpgsql;
+        "#;
+
+        let (diagnostics, span_texts) = run_plpgsql_check_test(&test_db, setup, create_fn_sql)
+            .await
+            .expect("Failed to run plpgsql_check test");
+
+        assert!(!diagnostics.is_empty());
+        assert!(matches!(
+            diagnostics[0].severity,
+            pgt_diagnostics::Severity::Error
+        ));
+        assert_eq!(
+            span_texts[0].as_deref(),
+            Some("open rc for select a from rc_test;")
+        );
+    }
+
+    #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+    async fn test_plpgsql_check_constant_assignment(test_db: PgPool) {
+        let setup = r#"
+            create extension if not exists plpgsql_check;
+
+            create procedure p1(a int, out b int)
+            as $$
+            begin
+              b := a + 10;
+            end;
+            $$ language plpgsql;
+        "#;
+
+        let create_fn_sql = r#"
+            create function f1()
+            returns void as $$
+            declare b constant int;
+            begin
+              call p1(10, b);
+            end;
+            $$ language plpgsql;
+        "#;
+
+        let (diagnostics, span_texts) = run_plpgsql_check_test(&test_db, setup, create_fn_sql)
+            .await
+            .expect("Failed to run plpgsql_check test");
+
+        assert!(!diagnostics.is_empty());
+        assert!(matches!(
+            diagnostics[0].severity,
+            pgt_diagnostics::Severity::Error
+        ));
+        assert_eq!(span_texts[0].as_deref(), Some("call p1(10, b);"));
+    }
+
+    #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+    async fn test_plpgsql_check_missing_procedure(test_db: PgPool) {
+        let setup = r#"
+            create extension if not exists plpgsql_check;
+        "#;
+
+        let create_fn_sql = r#"
+            create function f1()
+            returns void as $$
+            declare b constant int;
+            begin
+              call p1(10, b);
+            end;
+            $$ language plpgsql;
+        "#;
+
+        let (diagnostics, span_texts) = run_plpgsql_check_test(&test_db, setup, create_fn_sql)
+            .await
+            .expect("Failed to run plpgsql_check test");
+
+        assert!(!diagnostics.is_empty());
+        assert!(matches!(
+            diagnostics[0].severity,
+            pgt_diagnostics::Severity::Error
+        ));
+        assert_eq!(span_texts[0].as_deref(), Some("p1"));
+    }
+
+    #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+    async fn test_plpgsql_check_dml_in_stable_function(test_db: PgPool) {
+        let setup = r#"
+            create extension if not exists plpgsql_check;
+            create table t1(a int, b int);
+        "#;
+
+        let create_fn_sql = r#"
+            create function f1()
+            returns void as $$
+            begin
+              if false then
+                insert into t1 values(10,20);
+                update t1 set a = 10;
+                delete from t1;
+              end if;
+            end;
+            $$ language plpgsql stable;
+        "#;
+
+        let (diagnostics, span_texts) = run_plpgsql_check_test(&test_db, setup, create_fn_sql)
+            .await
+            .expect("Failed to run plpgsql_check test");
+
+        assert_eq!(diagnostics.len(), 1);
+        assert!(span_texts[0].is_some());
+
+        assert_eq!(diagnostics[0].advices.code.as_deref(), Some("0A000"));
+    }
+
+    #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+    async fn test_plpgsql_check_record_field_assignment(test_db: PgPool) {
+        let setup = r#"
+            create extension if not exists plpgsql_check;
+
+            create function g1() returns table(a int, b int) as $$
+            begin
+              return query select 1, 2;
+            end;
+            $$ language plpgsql;
+        "#;
+
+        let create_fn_sql = r#"
+            create or replace function f1()
+            returns void as $$
+            declare r record;
+            begin
+              for r in select * from g1()
+              loop
+                r.c := 20;
+              end loop;
+            end;
+            $$ language plpgsql;
+        "#;
+
+        let (diagnostics, span_texts) = run_plpgsql_check_test(&test_db, setup, create_fn_sql)
+            .await
+            .expect("Failed to run plpgsql_check test");
+
+        assert!(!diagnostics.is_empty());
+        assert!(matches!(
+            diagnostics[0].severity,
+            pgt_diagnostics::Severity::Error
+        ));
+        assert!(span_texts[0].is_some());
+    }
+
+    #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+    async fn test_plpgsql_check_trigger_basic(test_db: PgPool) {
+        let setup = r#"
+            create extension if not exists plpgsql_check;
+
+            CREATE TABLE users(
+                id serial primary key,
+                name text not null,
+                email text
+            );
+
+            CREATE OR REPLACE FUNCTION public.log_user_changes()
+            RETURNS trigger
+            LANGUAGE plpgsql
+            AS $function$
+            BEGIN
+                -- Intentional error: referencing non-existent column
+                INSERT INTO audit_log(table_name, changed_id, old_email, new_email)
+                VALUES ('users', NEW.id, OLD.email, NEW.email);
+                RETURN NEW;
+            END;
+            $function$;
+
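+            -- the trigger ties log_user_changes to "users"; check_plpgsql looks
+            -- this relation up in the schema cache when analyzing the function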
+            CREATE TRIGGER trg_users_audit
+            AFTER UPDATE ON users
+            FOR EACH ROW
+            EXECUTE FUNCTION public.log_user_changes();
+        "#;
+
+        let create_fn_sql = r#"
+            CREATE OR REPLACE FUNCTION public.log_user_changes()
+            RETURNS trigger
+            LANGUAGE plpgsql
+            AS $function$
+            BEGIN
+                -- Intentional error: referencing non-existent column
+                INSERT INTO audit_log(table_name, changed_id, old_email, new_email)
+                VALUES ('users', NEW.id, OLD.email, NEW.email);
+                RETURN NEW;
+            END;
+            $function$;
+        "#;
+
+        let (diagnostics, span_texts) = run_plpgsql_check_test(&test_db, setup, create_fn_sql)
+            .await
+            .expect("Failed to run plpgsql_check test");
+
+        assert!(!diagnostics.is_empty());
+        assert!(matches!(
+            diagnostics[0].severity,
+            pgt_diagnostics::Severity::Error
+        ));
+        assert!(diagnostics[0].advices.relation.is_some());
+        assert_eq!(
+            diagnostics[0].advices.relation.as_deref(),
+            Some("public.users")
+        );
+        assert_eq!(span_texts[0].as_deref(), Some("audit_log"));
+    }
+
+    #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+    async fn test_plpgsql_check_trigger_missing_column(test_db: PgPool) {
+        let setup = r#"
+            create extension if not exists plpgsql_check;
+
+            CREATE TABLE products(
+                id serial primary key,
+                name text not null,
+                price numeric(10,2)
+            );
+
+            CREATE OR REPLACE FUNCTION public.validate_product()
+            RETURNS trigger
+            LANGUAGE plpgsql
+            AS $function$
+            BEGIN
+                -- Error: referencing non-existent column
+                IF NEW.category IS NULL THEN
+                    RAISE EXCEPTION 'Category is required';
+                END IF;
+                RETURN NEW;
+            END;
+            $function$;
+
+            CREATE TRIGGER trg_product_validation
+            BEFORE INSERT OR UPDATE ON products
+            FOR EACH ROW
+            EXECUTE FUNCTION public.validate_product();
+        "#;
+
+        let create_fn_sql = r#"
+            CREATE OR REPLACE FUNCTION public.validate_product()
+            RETURNS trigger
+            LANGUAGE plpgsql
+            AS $function$
+            BEGIN
+                -- Error: referencing non-existent column
+                IF NEW.category IS NULL THEN
+                    RAISE EXCEPTION 'Category is required';
+                END IF;
+                RETURN NEW;
+            END;
+            $function$;
+        "#;
+
+        let (diagnostics, span_texts) = run_plpgsql_check_test(&test_db, setup, create_fn_sql)
+            .await
+            .expect("Failed to run plpgsql_check test");
+
+        assert!(!diagnostics.is_empty());
+        assert!(matches!(
+            diagnostics[0].severity,
+            pgt_diagnostics::Severity::Error
+        ));
+        assert!(span_texts[0].as_deref().unwrap().contains("category"));
+        assert_eq!(
+            diagnostics[0].advices.relation.as_deref(),
+            Some("public.products")
+        );
+    }
+
+    #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+    async fn test_plpgsql_check_trigger_multiple_tables(test_db: PgPool) {
+        let setup = r#"
+            create extension if not exists plpgsql_check;
+
+            CREATE TABLE table_a(
+                id serial primary key,
+                name text
+            );
+
+            CREATE TABLE table_b(
+                id serial primary key,
+                description text
+            );
+
+            CREATE OR REPLACE FUNCTION public.generic_audit()
+            RETURNS trigger
+            LANGUAGE plpgsql
+            AS $function$
+            BEGIN
+                -- Error: referencing column that doesn't exist in both tables
+                INSERT INTO audit_log(table_name, record_id, old_status)
+                VALUES (TG_TABLE_NAME, NEW.id, OLD.status);
+                RETURN NEW;
+            END;
+            $function$;
+
+            CREATE TRIGGER trg_audit_a
+            AFTER UPDATE ON table_a
+            FOR EACH ROW
+            EXECUTE FUNCTION public.generic_audit();
+
+            CREATE TRIGGER trg_audit_b
+            AFTER UPDATE ON table_b
+            FOR EACH ROW
+            EXECUTE FUNCTION public.generic_audit();
+        "#;
+
+        let create_fn_sql = r#"
+            CREATE OR REPLACE FUNCTION public.generic_audit()
+            RETURNS trigger
+            LANGUAGE plpgsql
+            AS $function$
+            BEGIN
+                -- Error: referencing column that doesn't exist in both tables
+                INSERT INTO audit_log(table_name, record_id, old_status)
+                VALUES (TG_TABLE_NAME, NEW.id, OLD.status);
+                RETURN NEW;
+            END;
+            $function$;
+        "#;
+
+        let (diagnostics, _span_texts) = run_plpgsql_check_test(&test_db, setup, create_fn_sql)
+            .await
+            .expect("Failed to run plpgsql_check test");
+
+        assert!(!diagnostics.is_empty());
+        assert!(diagnostics.len() >= 2);
+
+        let relations: Vec<_> = diagnostics
+            .iter()
+            .filter_map(|d| d.advices.relation.as_ref())
+            .collect();
+        assert!(relations.contains(&&"public.table_a".to_string()));
+        assert!(relations.contains(&&"public.table_b".to_string()));
+    }
+}
diff --git a/crates/pgt_query_ext/src/lib.rs b/crates/pgt_query_ext/src/lib.rs
index 4c630487..b0288da8 100644
--- a/crates/pgt_query_ext/src/lib.rs
+++ b/crates/pgt_query_ext/src/lib.rs
@@ -1 +1,2 @@
 pub mod diagnostics;
+pub mod utils;
diff --git a/crates/pgt_workspace/src/workspace/server/function_utils.rs b/crates/pgt_query_ext/src/utils.rs
similarity index 63%
rename from crates/pgt_workspace/src/workspace/server/function_utils.rs
rename to crates/pgt_query_ext/src/utils.rs
index 74e76ff2..6dedebea 100644
--- a/crates/pgt_workspace/src/workspace/server/function_utils.rs
+++ b/crates/pgt_query_ext/src/utils.rs
@@ -55,3 +55,46 @@ pub fn parse_name(nodes: &[pgt_query::protobuf::Node]) -> Option<(Option<String>, String)> {
         _ => None,
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::utils::{find_option_value, parse_name};
+
+    #[test]
+    fn test_find_option_value() {
+        let input = "
+            CREATE OR REPLACE FUNCTION public.f1()
+            RETURNS boolean
+            LANGUAGE plpgsql
+            AS $function$
+            declare r t1 := (select t1 from t1 where a = 1);
+            BEGIN
+                if r.c is null or
+                true is false
+                then -- there is bug - table t1 missing \"c\" column
+                RAISE NOTICE 'c is null';
+                end if;
+            END;
+            $function$;
+"
+        .trim();
+
+        let ast = pgt_query::parse(input).unwrap().into_root().unwrap();
+        let create_fn = match &ast {
+            pgt_query::NodeEnum::CreateFunctionStmt(stmt) => stmt,
+            _ => panic!("Expected CreateFunctionStmt"),
+        };
+
+        assert_eq!(
+            find_option_value(create_fn, "language"),
+            Some("plpgsql".to_string())
+        );
+
+        assert!(find_option_value(create_fn, "as").is_some(),);
+
+        assert_eq!(
+            parse_name(&create_fn.return_type.as_ref().unwrap().names),
+            Some((Some("pg_catalog".to_string()), "bool".to_string()))
+        );
+    }
+}
diff --git a/crates/pgt_schema_cache/src/queries/triggers.sql b/crates/pgt_schema_cache/src/queries/triggers.sql
index c28cc39f..895d1be0 100644
--- a/crates/pgt_schema_cache/src/queries/triggers.sql
+++ b/crates/pgt_schema_cache/src/queries/triggers.sql
@@ -1,17 +1,18 @@
--- we need to join tables from the pg_catalog since "TRUNCATE" triggers are 
+-- we need to join tables from the pg_catalog since "TRUNCATE" triggers are
 -- not available in the information_schema.trigger table.
-select 
-    t.tgname as "name!",
-    c.relname as "table_name!",
-    p.proname as "proc_name!",
-    n.nspname as "schema_name!",
-    t.tgtype as "details_bitmask!"
-from 
-    pg_catalog.pg_trigger t 
-    left join pg_catalog.pg_proc p on t.tgfoid = p.oid
-    left join pg_catalog.pg_class c on t.tgrelid = c.oid
-    left join pg_catalog.pg_namespace n on c.relnamespace = n.oid
-where 
-    -- triggers enforcing constraints (e.g. unique fields) should not be included.
-    t.tgisinternal = false and 
-    t.tgconstraint = 0;
+select
+    t.tgname as "name!",
+    c.relname as "table_name!",
+    p.proname as "proc_name!",
+    proc_ns.nspname as "proc_schema!",
+    table_ns.nspname as "table_schema!",
+    t.tgtype as "details_bitmask!"
+from
+    pg_catalog.pg_trigger t
+left join pg_catalog.pg_proc p on t.tgfoid = p.oid
+left join pg_catalog.pg_class c on t.tgrelid = c.oid
+left join pg_catalog.pg_namespace table_ns on c.relnamespace = table_ns.oid
+left join pg_catalog.pg_namespace proc_ns on p.pronamespace = proc_ns.oid
+where
+    t.tgisinternal = false and
+    t.tgconstraint = 0;
diff --git a/crates/pgt_schema_cache/src/schema_cache.rs b/crates/pgt_schema_cache/src/schema_cache.rs
index df7239ea..84bcd77c 100644
--- a/crates/pgt_schema_cache/src/schema_cache.rs
+++ b/crates/pgt_schema_cache/src/schema_cache.rs
@@ -7,7 +7,7 @@ use crate::schemas::Schema;
 use crate::tables::Table;
 use crate::types::PostgresType;
 use crate::versions::Version;
-use crate::{Role, Trigger};
+use crate::{Extension, Role, Trigger};
 
 #[derive(Debug, Default)]
 pub struct SchemaCache {
@@ -18,13 +18,25 @@ pub struct SchemaCache {
     pub versions: Vec<Version>,
     pub columns: Vec<Column>,
     pub policies: Vec<Policy>,
+    pub extensions: Vec<Extension>,
     pub triggers: Vec<Trigger>,
     pub roles: Vec<Role>,
 }
 
 impl SchemaCache {
     pub async fn load(pool: &PgPool) -> Result<SchemaCache, sqlx::Error> {
-        let (schemas, tables, functions, types, versions, columns, policies, triggers, roles) = futures_util::try_join!(
+        let (
+            schemas,
+            tables,
+            functions,
+            types,
+            versions,
+            columns,
+            policies,
+            triggers,
+            roles,
+            extensions,
+        ) = futures_util::try_join!(
             Schema::load(pool),
             Table::load(pool),
             Function::load(pool),
@@ -33,7 +45,8 @@ impl SchemaCache {
             Column::load(pool),
             Policy::load(pool),
             Trigger::load(pool),
-            Role::load(pool)
+            Role::load(pool),
+            Extension::load(pool),
         )?;
 
         Ok(SchemaCache {
@@ -46,6 +59,7 @@ impl SchemaCache {
             policies,
             triggers,
            roles,
+            extensions,
         })
     }
diff --git a/crates/pgt_schema_cache/src/triggers.rs b/crates/pgt_schema_cache/src/triggers.rs
index 2b2a3aff..d0a4788a 100644
--- a/crates/pgt_schema_cache/src/triggers.rs
+++ b/crates/pgt_schema_cache/src/triggers.rs
@@ -82,20 +82,22 @@ impl TryFrom<i16> for TriggerTiming {
 pub struct TriggerQueried {
     name: String,
     table_name: String,
-    schema_name: String,
+    table_schema: String,
     proc_name: String,
+    proc_schema: String,
     details_bitmask: i16,
 }
 
 #[derive(Debug, PartialEq, Eq)]
 pub struct Trigger {
-    name: String,
-    table_name: String,
-    schema_name: String,
-    proc_name: String,
-    affected: TriggerAffected,
-    timing: TriggerTiming,
-    events: Vec<TriggerEvent>,
+    pub name: String,
+    pub table_name: String,
+    pub table_schema: String,
+    pub proc_name: String,
+    pub proc_schema: String,
+    pub affected: TriggerAffected,
+    pub timing: TriggerTiming,
+    pub events: Vec<TriggerEvent>,
 }
 
 impl From<TriggerQueried> for Trigger {
@@ -104,7 +106,8 @@ impl From<TriggerQueried> for Trigger {
             name: value.name,
             table_name: value.table_name,
             proc_name: value.proc_name,
-            schema_name: value.schema_name,
+            proc_schema: value.proc_schema,
+            table_schema: value.table_schema,
             affected: value.details_bitmask.into(),
             timing: value.details_bitmask.try_into().unwrap(),
             events: TriggerEvents::from(value.details_bitmask).0,
@@ -141,7 +144,7 @@ mod tests {
             id serial primary key,
             name text
         );
-        
+
         create or replace function public.log_user_insert()
        returns trigger as $$
         begin
            insert into public.user_log (user_id) values (new.id);
             return new;
         end;
         $$ language plpgsql;
-        
+
         create trigger trg_users_insert
         before insert on public.users
         for each row
         execute function public.log_user_insert();
-        
+
         create trigger trg_users_update
         after update or insert on public.users
         for each statement
         execute function public.log_user_insert();
-        
+
         create trigger trg_users_delete
         before delete on public.users
         for each row
@@ -186,7 +189,7 @@ mod tests {
             .iter()
             .find(|t| t.name == "trg_users_insert")
             .unwrap();
-        assert_eq!(insert_trigger.schema_name, "public");
+        assert_eq!(insert_trigger.table_schema, "public");
         assert_eq!(insert_trigger.table_name, "users");
         assert_eq!(insert_trigger.timing, TriggerTiming::Before);
         assert_eq!(insert_trigger.affected, TriggerAffected::Row);
@@ -197,7 +200,7 @@ mod tests {
             .iter()
             .find(|t| t.name == "trg_users_update")
             .unwrap();
-        assert_eq!(insert_trigger.schema_name, "public");
+        assert_eq!(insert_trigger.table_schema, "public");
         assert_eq!(insert_trigger.table_name, "users");
         assert_eq!(update_trigger.timing, TriggerTiming::After);
         assert_eq!(update_trigger.affected, TriggerAffected::Statement);
@@ -209,7 +212,7 @@ mod tests {
             .iter()
             .find(|t| t.name == "trg_users_delete")
             .unwrap();
-        assert_eq!(insert_trigger.schema_name, "public");
+        assert_eq!(insert_trigger.table_schema, "public");
         assert_eq!(insert_trigger.table_name, "users");
         assert_eq!(delete_trigger.timing, TriggerTiming::Before);
         assert_eq!(delete_trigger.affected, TriggerAffected::Row);
@@ -275,7 +278,7 @@ mod tests {
             .iter()
             .find(|t| t.name == "trg_docs_instead_update")
             .unwrap();
-        assert_eq!(instead_trigger.schema_name, "public");
+        assert_eq!(instead_trigger.table_schema, "public");
         assert_eq!(instead_trigger.table_name, "docs_view");
         assert_eq!(instead_trigger.timing, TriggerTiming::Instead);
         assert_eq!(instead_trigger.affected, TriggerAffected::Row);
@@ -286,7 +289,7 @@ mod tests {
             .iter()
             .find(|t| t.name == "trg_docs_truncate")
             .unwrap();
-        assert_eq!(truncate_trigger.schema_name, "public");
+        assert_eq!(truncate_trigger.table_schema, "public");
         assert_eq!(truncate_trigger.table_name, "docs");
         assert_eq!(truncate_trigger.timing, TriggerTiming::After);
         assert_eq!(truncate_trigger.affected, TriggerAffected::Statement);
diff --git a/crates/pgt_typecheck/src/lib.rs b/crates/pgt_typecheck/src/lib.rs
index a3dde01d..ceb36b94 100644
--- a/crates/pgt_typecheck/src/lib.rs
+++ b/crates/pgt_typecheck/src/lib.rs
@@ -3,7 +3,6 @@ mod typed_identifier;
 
 pub use diagnostics::TypecheckDiagnostic;
 use diagnostics::create_type_error;
-use pgt_text_size::TextRange;
 use sqlx::postgres::PgDatabaseError;
 pub use sqlx::postgres::PgSeverity;
 use sqlx::{Executor, PgPool};
@@ -20,19 +19,6 @@ pub struct TypecheckParams<'a> {
     pub identifiers: Vec<TypedIdentifier>,
 }
 
-#[derive(Debug, Clone)]
-pub struct TypeError {
-    pub message: String,
-    pub code: String,
-    pub severity: PgSeverity,
-    pub position: Option<usize>,
-    pub range: Option<TextRange>,
-    pub table: Option<String>,
-    pub column: Option<String>,
-    pub data_type: Option<String>,
-    pub constraint: Option<String>,
-}
-
 pub async fn check_sql(
     params: TypecheckParams<'_>,
 ) -> Result<Option<TypecheckDiagnostic>, sqlx::Error> {
diff --git a/crates/pgt_workspace/Cargo.toml b/crates/pgt_workspace/Cargo.toml
index 4acc0600..efded47c 100644
--- a/crates/pgt_workspace/Cargo.toml
+++ b/crates/pgt_workspace/Cargo.toml
@@ -26,6 +26,7 @@ pgt_console = { workspace = true }
 pgt_diagnostics = { workspace = true }
 pgt_fs = { workspace = true, features = ["serde"] }
 pgt_lexer = { workspace = true }
+pgt_plpgsql_check = { workspace = true }
 pgt_query = { workspace = true }
 pgt_query_ext = { workspace = true }
 pgt_schema_cache = { workspace = true }
diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs
index c6ed0827..f4a9561f 100644
--- a/crates/pgt_workspace/src/workspace/server.rs
+++ b/crates/pgt_workspace/src/workspace/server.rs
@@ -55,7 +55,6 @@ mod async_helper;
 mod connection_key;
 mod connection_manager;
 pub(crate) mod document;
-mod function_utils;
 mod migration;
 mod pg_query;
 mod schema_cache_manager;
@@ -454,7 +453,8 @@ impl Workspace for WorkspaceServer {
         let path_clone = params.path.clone();
         let schema_cache = self.schema_cache.load(pool.clone())?;
         let input = doc.iter(TypecheckDiagnosticsMapper).collect::<Vec<_>>();
-        // sorry for the ugly code :(
+
+        // Combined async context for both typecheck and plpgsql_check
         let async_results = run_async(async move {
             stream::iter(input)
                 .map(|(id, range, ast, cst, sign)| {
@@ -462,8 +462,11 @@ impl Workspace for WorkspaceServer {
                     let path = path_clone.clone();
                     let schema_cache = Arc::clone(&schema_cache);
                     async move {
+                        let mut diagnostics = Vec::new();
+
                         if let Some(ast) = ast {
-                            pgt_typecheck::check_sql(TypecheckParams {
+                            // Type checking
+                            let typecheck_result = pgt_typecheck::check_sql(TypecheckParams {
                                 conn: &pool,
                                 sql: id.content(),
                                 ast: &ast,
@@ -486,18 +489,40 @@ impl Workspace for WorkspaceServer {
                                 })
                                 .unwrap_or_default(),
                             })
+                            .await;
+
+                            if let Ok(Some(diag)) = typecheck_result {
+                                let r = diag.location().span.map(|span| span + range.start());
+                                diagnostics.push(
+                                    diag.with_file_path(path.as_path().display().to_string())
+                                        .with_file_span(r.unwrap_or(range)),
+                                );
+                            }
+
+                            // plpgsql_check
+                            let plpgsql_check_results = pgt_plpgsql_check::check_plpgsql(
+                                pgt_plpgsql_check::PlPgSqlCheckParams {
+                                    conn: &pool,
+                                    sql: id.content(),
+                                    ast: &ast,
+                                    schema_cache: schema_cache.as_ref(),
+                                },
+                            )
                             .await
-                            .map(|d| {
-                                d.map(|d| {
-                                    let r = d.location().span.map(|span| span + range.start());
+                            .unwrap_or_else(|_| vec![]);
+
+                            for d in plpgsql_check_results {
+                                let r = d.span.map(|span| span + range.start());
+                                diagnostics.push(
                                     d.with_file_path(path.as_path().display().to_string())
-                                        .with_file_span(r.unwrap_or(range))
-                                })
-                            })
-                        } else {
-                            Ok(None)
+                                        .with_file_span(r.unwrap_or(range)),
+                                );
+                            }
                         }
+
+                        Ok::<Vec<_>, sqlx::Error>(diagnostics)
                     }
                 })
                 .buffer_unordered(10)
@@ -506,8 +531,8 @@ impl Workspace for WorkspaceServer {
         })?;
 
         for result in async_results.into_iter() {
-            let result = result?;
-            if let Some(diag) = result {
+            let diagnostics_batch = result?;
+            for diag in diagnostics_batch {
                 diagnostics.push(SDiagnostic::new(diag));
             }
         }
@@ -548,6 +573,20 @@ impl Workspace for WorkspaceServer {
             analysable_stmts.push(node);
         }
         if let Some(diag) = diagnostic {
+            // ignore the syntax error if we already have more specialized diagnostics for the
+            // same statement.
+            // this is important for create function statements, where we might already have detailed
+            // diagnostics from plpgsql_check.
+            if diagnostics.iter().any(|d| {
+                d.location().span.is_some_and(|async_loc| {
+                    diag.location()
+                        .span
+                        .is_some_and(|syntax_loc| syntax_loc.contains_range(async_loc))
+                })
+            }) {
+                continue;
+            }
+
             diagnostics.push(SDiagnostic::new(
                 diag.with_file_path(path.clone())
                     .with_severity(Severity::Error),
diff --git a/crates/pgt_workspace/src/workspace/server.tests.rs b/crates/pgt_workspace/src/workspace/server.tests.rs
index 0578f90f..ef5ba267 100644
--- a/crates/pgt_workspace/src/workspace/server.tests.rs
+++ b/crates/pgt_workspace/src/workspace/server.tests.rs
@@ -8,7 +8,7 @@ use pgt_configuration::{
 use pgt_diagnostics::Diagnostic;
 use pgt_fs::PgTPath;
 use pgt_text_size::TextRange;
-use sqlx::PgPool;
+use sqlx::{Executor, PgPool};
 
 use crate::{
     Workspace, WorkspaceError,
@@ -206,3 +206,74 @@ async fn correctly_ignores_files() {
 
     assert!(execute_statement_result.is_ok_and(|res| res == ExecuteStatementResult::default()));
 }
+
+#[cfg(all(test, not(target_os = "windows")))]
+#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+async fn test_dedupe_diagnostics(test_db: PgPool) {
+    let mut conf = PartialConfiguration::init();
+    conf.merge_with(PartialConfiguration {
+        db: Some(PartialDatabaseConfiguration {
+            database: Some(
+                test_db
+                    .connect_options()
+                    .get_database()
+                    .unwrap()
+                    .to_string(),
+            ),
+            ..Default::default()
+        }),
+        ..Default::default()
+    });
+
+    let workspace = get_test_workspace(Some(conf)).expect("Unable to create test workspace");
+
+    let path = PgTPath::new("test.sql");
+
+    let setup_sql = "CREATE EXTENSION IF NOT EXISTS plpgsql_check;";
+    test_db.execute(setup_sql).await.expect("setup sql failed");
+
+    let content = r#"
+        CREATE OR REPLACE FUNCTION public.f1()
+        RETURNS void
+        LANGUAGE plpgsql
+        AS $function$
+        decare r text;
+        BEGIN
+            select '1' into into r;
+        END;
+        $function$;
+    "#;
+
+    workspace
+        .open_file(OpenFileParams {
+            path: path.clone(),
+            content: content.into(),
+            version: 1,
+        })
+        .expect("Unable to open test file");
+
+    let diagnostics = workspace
+        .pull_diagnostics(crate::workspace::PullDiagnosticsParams {
+            path: path.clone(),
+            categories: RuleCategories::all(),
+            max_diagnostics: 100,
+            only: vec![],
+            skip: vec![],
+        })
+        .expect("Unable to pull diagnostics")
+        .diagnostics;
+
+    assert_eq!(diagnostics.len(), 1, "Expected one diagnostic");
+
+    let diagnostic = &diagnostics[0];
+
+    assert_eq!(
+        diagnostic.category().map(|c| c.name()),
+        Some("plpgsql_check")
+    );
+
+    assert_eq!(
+        diagnostic.location().span,
+        Some(TextRange::new(115.into(), 210.into()))
+    );
+}
diff --git a/crates/pgt_workspace/src/workspace/server/pg_query.rs b/crates/pgt_workspace/src/workspace/server/pg_query.rs
index e90dd41b..05f1425d 100644
--- a/crates/pgt_workspace/src/workspace/server/pg_query.rs
+++ b/crates/pgt_workspace/src/workspace/server/pg_query.rs
@@ -5,7 +5,6 @@ use lru::LruCache;
 use pgt_query_ext::diagnostics::*;
 use pgt_text_size::TextRange;
 
-use super::function_utils::find_option_value;
 use super::statement_identifier::StatementId;
 
 const DEFAULT_CACHE_SIZE: usize = 1000;
@@ -61,7 +60,7 @@ impl PgQueryStore {
             _ => return None,
         };
 
-        let language = find_option_value(create_fn, "language")?;
+        let language = pgt_query_ext::utils::find_option_value(create_fn, "language")?;
 
         if language != "plpgsql" {
             return None;
@@ -73,7 +72,7 @@ impl PgQueryStore {
             return Some(existing.clone());
         }
 
-        let sql_body = find_option_value(create_fn, "as")?;
+        let sql_body = pgt_query_ext::utils::find_option_value(create_fn, "as")?;
 
         let start = statement.content().find(&sql_body)?;
         let end = start + sql_body.len();
diff --git a/crates/pgt_workspace/src/workspace/server/sql_function.rs b/crates/pgt_workspace/src/workspace/server/sql_function.rs
index 0b230edc..4a1463b7 100644
--- a/crates/pgt_workspace/src/workspace/server/sql_function.rs
+++ b/crates/pgt_workspace/src/workspace/server/sql_function.rs
@@ -1,7 +1,5 @@
 use pgt_text_size::TextRange;
 
-use super::function_utils::{find_option_value, parse_name};
-
 #[derive(Debug, Clone)]
 pub struct ArgType {
     pub schema: Option<String>,
@@ -37,14 +35,14 @@ pub fn get_sql_fn_signature(ast: &pgt_query::NodeEnum) -> Option<SQLFunctionSignature>