From 96082d876233ee748a5497d7a3682b2d1172bf76 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Tue, 26 Nov 2024 22:01:25 +0800 Subject: [PATCH 1/2] support visit option field --- derive/README.md | 49 ++++++++++++++++++++++++++++++++++++++++++++++ derive/src/lib.rs | 36 +++++++++++++++++++++++++++------- src/ast/mod.rs | 2 +- src/ast/visitor.rs | 13 ++++++++++-- 4 files changed, 90 insertions(+), 10 deletions(-) diff --git a/derive/README.md b/derive/README.md index aa70e7c71..b5ccc69e0 100644 --- a/derive/README.md +++ b/derive/README.md @@ -151,6 +151,55 @@ visitor.post_visit_expr() visitor.post_visit_expr() ``` +If the field is a `Option` and add `#[with = "visit_xxx"]` to the field, the generated code +will try to access the field only if it is `Some`: + +```rust +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct ShowStatementIn { + pub clause: ShowStatementInClause, + pub parent_type: Option, + #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] + pub parent_name: Option, +} +``` + +This will generate + +```rust +impl sqlparser::ast::Visit for ShowStatementIn { + fn visit( + &self, + visitor: &mut V, + ) -> ::std::ops::ControlFlow { + sqlparser::ast::Visit::visit(&self.clause, visitor)?; + sqlparser::ast::Visit::visit(&self.parent_type, visitor)?; + if let Some(value) = &self.parent_name { + visitor.pre_visit_relation(value)?; + sqlparser::ast::Visit::visit(value, visitor)?; + visitor.post_visit_relation(value)?; + } + ::std::ops::ControlFlow::Continue(()) + } +} + +impl sqlparser::ast::VisitMut for ShowStatementIn { + fn visit( + &mut self, + visitor: &mut V, + ) -> ::std::ops::ControlFlow { + sqlparser::ast::VisitMut::visit(&mut self.clause, visitor)?; + sqlparser::ast::VisitMut::visit(&mut self.parent_type, visitor)?; + if let Some(value) = &mut self.parent_name { + visitor.pre_visit_relation(value)?; + sqlparser::ast::VisitMut::visit(value, visitor)?; + visitor.post_visit_relation(value)?; + } + ::std::ops::ControlFlow::Continue(()) + } +} +``` + ## Releasing This crate's release is not automated. Instead it is released manually as needed diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 5ad1607f9..dd4d37b41 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -18,11 +18,8 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens}; use syn::spanned::Spanned; -use syn::{ - parse::{Parse, ParseStream}, - parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics, - Ident, Index, LitStr, Meta, Token, -}; +use syn::{parse::{Parse, ParseStream}, parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics, Ident, Index, LitStr, Meta, Token, Type, TypePath}; +use syn::{Path, PathArguments}; /// Implementation of `[#derive(Visit)]` #[proc_macro_derive(VisitMut, attributes(visit))] @@ -182,9 +179,21 @@ fn visit_children( Fields::Named(fields) => { let recurse = fields.named.iter().map(|f| { let name = &f.ident; + let is_option = is_option(&f.ty); let attributes = Attributes::parse(&f.attrs); - let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name)); - quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit) + if is_option && attributes.with.is_some() { + let (pre_visit, post_visit) = attributes.visit(quote!(value)); + quote_spanned!(f.span() => + if let Some(value) = &#modifier self.#name { + #pre_visit sqlparser::ast::#visit_trait::visit(value, visitor)?; #post_visit + } + ) + } else { + let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name)); + quote_spanned!(f.span() => + #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit + ) + } }); quote! { #(#recurse)* @@ -256,3 +265,16 @@ fn visit_children( Data::Union(_) => unimplemented!(), } } + +fn is_option(ty: &Type) -> bool { + if let Type::Path(TypePath { path: Path { segments, .. }, .. }) = ty { + if let Some(segment) = segments.last() { + if segment.ident == "Option" { + if let PathArguments::AngleBracketed(args) = &segment.arguments { + return args.args.len() == 1; + } + } + } + } + false +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 9185c9df4..35b82ed97 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -26,7 +26,6 @@ use alloc::{ use core::fmt::{self, Display}; use core::ops::Deref; - #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -7587,6 +7586,7 @@ impl fmt::Display for ShowStatementInParentType { pub struct ShowStatementIn { pub clause: ShowStatementInClause, pub parent_type: Option, + #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] pub parent_name: Option, } diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 418e0a299..001143132 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -734,7 +734,7 @@ mod tests { #[test] fn test_sql() { - let tests = vec![ + let tests: Vec<_> = vec![ ( "SELECT * from table_name as my_table", vec![ @@ -876,7 +876,16 @@ mod tests { "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID", ] - ) + ), + ( + "SHOW COLUMNS FROM t1", + vec![ + "PRE: STATEMENT: SHOW COLUMNS FROM t1", + "PRE: RELATION: t1", + "POST: RELATION: t1", + "POST: STATEMENT: SHOW COLUMNS FROM t1", + ], + ), ]; for (sql, expected) in tests { let actual = do_visit(sql); From 55f1eeb2f3c33787f6a02c0dd0833afbd7943fc1 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Tue, 26 Nov 2024 22:48:02 +0800 Subject: [PATCH 2/2] remove unnecessary --- src/ast/visitor.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 001143132..eacd268a4 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -734,7 +734,7 @@ mod tests { #[test] fn test_sql() { - let tests: Vec<_> = vec![ + let tests = vec![ ( "SELECT * from table_name as my_table", vec![