From 221b9dc4c577539d86a9dca10118a596766cbef6 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Thu, 21 Nov 2024 16:36:37 +0800 Subject: [PATCH 1/6] extract `support_period_map_access_key` config --- src/dialect/bigquery.rs | 4 ++++ src/dialect/duckdb.rs | 5 +++++ src/dialect/generic.rs | 4 ++++ src/dialect/mod.rs | 9 +++++++++ 4 files changed, 22 insertions(+) diff --git a/src/dialect/bigquery.rs b/src/dialect/bigquery.rs index 96633552b..0e6514aee 100644 --- a/src/dialect/bigquery.rs +++ b/src/dialect/bigquery.rs @@ -72,4 +72,8 @@ impl Dialect for BigQueryDialect { fn require_interval_qualifier(&self) -> bool { true } + + fn support_period_map_access_key(&self) -> bool { + true + } } diff --git a/src/dialect/duckdb.rs b/src/dialect/duckdb.rs index 905b04e36..9eea23986 100644 --- a/src/dialect/duckdb.rs +++ b/src/dialect/duckdb.rs @@ -71,4 +71,9 @@ impl Dialect for DuckDbDialect { fn supports_load_extension(&self) -> bool { true } + + /// See DuckDB + fn support_period_map_access_key(&self) -> bool { + true + } } diff --git a/src/dialect/generic.rs b/src/dialect/generic.rs index 4998e0f4b..080e9d29a 100644 --- a/src/dialect/generic.rs +++ b/src/dialect/generic.rs @@ -119,4 +119,8 @@ impl Dialect for GenericDialect { fn supports_load_extension(&self) -> bool { true } + + fn support_period_map_access_key(&self) -> bool { + true + } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 956a58986..a6db1d5cc 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -657,6 +657,15 @@ pub trait Dialect: Debug + Any { fn supports_create_table_select(&self) -> bool { false } + + /// Return true if the dialect supports the period map access key + /// + /// Access on BigQuery nested and repeated expressions can + /// mix notations in the same expression. + /// https://cloud.google.com/bigquery/docs/nested-repeated#query_nested_and_repeated_columns + fn support_period_map_access_key(&self) -> bool { + false + } } /// This represents the operators for which precedence must be defined From 27782113623b0507abf70dfbd1f0228950e0603f Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Thu, 21 Nov 2024 16:37:49 +0800 Subject: [PATCH 2/6] handle the chain of the subscript and map accesses for generic and duckdb --- src/parser/mod.rs | 64 ++++++++++++++++++++++++++++----------- tests/sqlparser_common.rs | 22 ++++++++++++++ 2 files changed, 68 insertions(+), 18 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index a583112a7..8dafbf026 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -2919,12 +2919,23 @@ impl<'a> Parser<'a> { }) } else if Token::LBracket == tok { if dialect_of!(self is PostgreSqlDialect | DuckDbDialect | GenericDialect) { - self.parse_subscript(expr) + let expr = self.parse_multi_dim_subscript(expr)?; + if self.dialect.support_period_map_access_key() { + self.parse_map_access(expr, vec![]) + } else { + Ok(expr) + } } else if dialect_of!(self is SnowflakeDialect) { self.prev_token(); self.parse_json_access(expr) } else { - self.parse_map_access(expr) + let key = self.parse_expr()?; + self.expect_token(&Token::RBracket)?; + let keys = vec![MapAccessKey { + key, + syntax: MapAccessSyntax::Bracket, + }]; + self.parse_map_access(expr, keys) } } else if dialect_of!(self is SnowflakeDialect | GenericDialect) && Token::Colon == tok { self.prev_token(); @@ -3020,6 +3031,19 @@ impl<'a> Parser<'a> { }) } + /// Parse an multi-dimension array accessing like '[1:3][1][1]' + /// + /// Parser is right after the first `[` + pub fn parse_multi_dim_subscript(&mut self, mut expr: Expr) -> Result { + loop { + expr = self.parse_subscript(expr)?; + if !self.consume_token(&Token::LBracket) { + break; + } + } + Ok(expr) + } + /// Parses an array subscript like `[1:3]` /// /// Parser is right after `[` @@ -3085,14 +3109,15 @@ impl<'a> Parser<'a> { }) } - pub fn parse_map_access(&mut self, expr: Expr) -> Result { - let key = self.parse_expr()?; - self.expect_token(&Token::RBracket)?; - - let mut keys = vec![MapAccessKey { - key, - syntax: MapAccessSyntax::Bracket, - }]; + /// Parse the map access like `[key]` or `.key` if [Dialect::support_period_map_access_key] is true + /// It could be an access-chain like `[key1][key2].key3` + /// + /// The parameter `keys` is an initialized buffer that could contain some keys parsed from other places. + pub fn parse_map_access( + &mut self, + expr: Expr, + mut keys: Vec, + ) -> Result { loop { let key = match self.peek_token().token { Token::LBracket => { @@ -3104,10 +3129,7 @@ impl<'a> Parser<'a> { syntax: MapAccessSyntax::Bracket, } } - // Access on BigQuery nested and repeated expressions can - // mix notations in the same expression. - // https://cloud.google.com/bigquery/docs/nested-repeated#query_nested_and_repeated_columns - Token::Period if dialect_of!(self is BigQueryDialect) => { + Token::Period if self.dialect.support_period_map_access_key() => { self.next_token(); // consume `.` MapAccessKey { key: self.parse_expr()?, @@ -3119,10 +3141,16 @@ impl<'a> Parser<'a> { keys.push(key); } - Ok(Expr::MapAccess { - column: Box::new(expr), - keys, - }) + // If no any key be collected, it means the elements have been parsed to [Subscript] + // e.g. `select abc[1]` or `select abc[1][2]` + if keys.is_empty() { + Ok(expr) + } else { + Ok(Expr::MapAccess { + column: Box::new(expr), + keys, + }) + } } /// Parses the parens following the `[ NOT ] IN` operator. diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 2ffb5f44b..17cc16356 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -12013,3 +12013,25 @@ fn parse_create_table_select() { ); } } + +#[test] +fn test_period_map_access() { + let supported_dialects = TestedDialects::new(vec![ + Box::new(GenericDialect {}), + Box::new(DuckDbDialect {}), + ]); + let sqls = [ + "SELECT abc[1] FROM t", + "SELECT abc[1].f1 FROM t", + "SELECT abc[1].f1.f2 FROM t", + "SELECT f1.abc[1] FROM t", + "SELECT f1.f2.abc[1] FROM t", + "SELECT f1.abc[1].f2 FROM t", + "SELECT abc['a'][1].f1 FROM t", + "SELECT abc['a'].f1[1].f2 FROM t", + "SELECT abc['a'].f1[1].f2[2] FROM t", + ]; + for sql in sqls { + supported_dialects.verified_stmt(sql); + } +} From a4a544850d559f38cc91a8d5ea75d1dbb7ab48ba Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Thu, 21 Nov 2024 16:45:39 +0800 Subject: [PATCH 3/6] fix the doc test --- src/dialect/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 9f39100ed..1d5b234be 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -680,7 +680,7 @@ pub trait Dialect: Debug + Any { /// /// Access on BigQuery nested and repeated expressions can /// mix notations in the same expression. - /// https://cloud.google.com/bigquery/docs/nested-repeated#query_nested_and_repeated_columns + /// fn support_period_map_access_key(&self) -> bool { false } From dc5e54078ea2738bd778d9e89a6b1d6fa2db480a Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Thu, 21 Nov 2024 16:47:32 +0800 Subject: [PATCH 4/6] fix doc --- src/parser/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index ad279b80f..cba0556b8 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -3047,7 +3047,7 @@ impl<'a> Parser<'a> { }) } - /// Parse an multi-dimension array accessing like '[1:3][1][1]' + /// Parse an multi-dimension array accessing like `[1:3][1][1]` /// /// Parser is right after the first `[` pub fn parse_multi_dim_subscript(&mut self, mut expr: Expr) -> Result { From 96082d876233ee748a5497d7a3682b2d1172bf76 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Tue, 26 Nov 2024 22:01:25 +0800 Subject: [PATCH 5/6] support visit option field --- derive/README.md | 49 ++++++++++++++++++++++++++++++++++++++++++++++ derive/src/lib.rs | 36 +++++++++++++++++++++++++++------- src/ast/mod.rs | 2 +- src/ast/visitor.rs | 13 ++++++++++-- 4 files changed, 90 insertions(+), 10 deletions(-) diff --git a/derive/README.md b/derive/README.md index aa70e7c71..b5ccc69e0 100644 --- a/derive/README.md +++ b/derive/README.md @@ -151,6 +151,55 @@ visitor.post_visit_expr() visitor.post_visit_expr() ``` +If the field is a `Option` and add `#[with = "visit_xxx"]` to the field, the generated code +will try to access the field only if it is `Some`: + +```rust +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct ShowStatementIn { + pub clause: ShowStatementInClause, + pub parent_type: Option, + #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] + pub parent_name: Option, +} +``` + +This will generate + +```rust +impl sqlparser::ast::Visit for ShowStatementIn { + fn visit( + &self, + visitor: &mut V, + ) -> ::std::ops::ControlFlow { + sqlparser::ast::Visit::visit(&self.clause, visitor)?; + sqlparser::ast::Visit::visit(&self.parent_type, visitor)?; + if let Some(value) = &self.parent_name { + visitor.pre_visit_relation(value)?; + sqlparser::ast::Visit::visit(value, visitor)?; + visitor.post_visit_relation(value)?; + } + ::std::ops::ControlFlow::Continue(()) + } +} + +impl sqlparser::ast::VisitMut for ShowStatementIn { + fn visit( + &mut self, + visitor: &mut V, + ) -> ::std::ops::ControlFlow { + sqlparser::ast::VisitMut::visit(&mut self.clause, visitor)?; + sqlparser::ast::VisitMut::visit(&mut self.parent_type, visitor)?; + if let Some(value) = &mut self.parent_name { + visitor.pre_visit_relation(value)?; + sqlparser::ast::VisitMut::visit(value, visitor)?; + visitor.post_visit_relation(value)?; + } + ::std::ops::ControlFlow::Continue(()) + } +} +``` + ## Releasing This crate's release is not automated. Instead it is released manually as needed diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 5ad1607f9..dd4d37b41 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -18,11 +18,8 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; use syn::spanned::Spanned; -use syn::{ - parse::{Parse, ParseStream}, - parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics, - Ident, Index, LitStr, Meta, Token, -}; +use syn::{parse::{Parse, ParseStream}, parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics, Ident, Index, LitStr, Meta, Token, Type, TypePath}; +use syn::{Path, PathArguments}; /// Implementation of `[#derive(Visit)]` #[proc_macro_derive(VisitMut, attributes(visit))] @@ -182,9 +179,21 @@ fn visit_children( Fields::Named(fields) => { let recurse = fields.named.iter().map(|f| { let name = &f.ident; + let is_option = is_option(&f.ty); let attributes = Attributes::parse(&f.attrs); - let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name)); - quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit) + if is_option && attributes.with.is_some() { + let (pre_visit, post_visit) = attributes.visit(quote!(value)); + quote_spanned!(f.span() => + if let Some(value) = &#modifier self.#name { + #pre_visit sqlparser::ast::#visit_trait::visit(value, visitor)?; #post_visit + } + ) + } else { + let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name)); + quote_spanned!(f.span() => + #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit + ) + } }); quote! { #(#recurse)* @@ -256,3 +265,16 @@ fn visit_children( Data::Union(_) => unimplemented!(), } } + +fn is_option(ty: &Type) -> bool { + if let Type::Path(TypePath { path: Path { segments, .. }, .. }) = ty { + if let Some(segment) = segments.last() { + if segment.ident == "Option" { + if let PathArguments::AngleBracketed(args) = &segment.arguments { + return args.args.len() == 1; + } + } + } + } + false +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 9185c9df4..35b82ed97 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -26,7 +26,6 @@ use alloc::{ use core::fmt::{self, Display}; use core::ops::Deref; - #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -7587,6 +7586,7 @@ impl fmt::Display for ShowStatementInParentType { pub struct ShowStatementIn { pub clause: ShowStatementInClause, pub parent_type: Option, + #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] pub parent_name: Option, } diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 418e0a299..001143132 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -734,7 +734,7 @@ mod tests { #[test] fn test_sql() { - let tests = vec![ + let tests: Vec<_> = vec![ ( "SELECT * from table_name as my_table", vec![ @@ -876,7 +876,16 @@ mod tests { "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", ] - ) + ), + ( + "SHOW COLUMNS FROM t1", + vec![ + "PRE: STATEMENT: SHOW COLUMNS FROM t1", + "PRE: RELATION: t1", + "POST: RELATION: t1", + "POST: STATEMENT: SHOW COLUMNS FROM t1", + ], + ), ]; for (sql, expected) in tests { let actual = do_visit(sql); From 55f1eeb2f3c33787f6a02c0dd0833afbd7943fc1 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Tue, 26 Nov 2024 22:48:02 +0800 Subject: [PATCH 6/6] remove unnecessary --- src/ast/visitor.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 001143132..eacd268a4 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -734,7 +734,7 @@ mod tests { #[test] fn test_sql() { - let tests: Vec<_> = vec![ + let tests = vec![ ( "SELECT * from table_name as my_table", vec![