Skip to content

Commit 9f8adaf

Browse files
committed
Add support for $$ in generic dialect ...
... in `create function` closes apache#1183
1 parent 6b03a25 commit 9f8adaf

File tree

3 files changed

+43
-24
lines changed

3 files changed

+43
-24
lines changed

derive/src/lib.rs

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,34 @@ use quote::{format_ident, quote, quote_spanned, ToTokens};
33
use syn::spanned::Spanned;
44
use syn::{
55
parse::{Parse, ParseStream},
6-
parse_macro_input, parse_quote, Attribute, Data, DeriveInput,
7-
Fields, GenericParam, Generics, Ident, Index, LitStr, Meta, Token
6+
parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
7+
Ident, Index, LitStr, Meta, Token,
88
};
99

10-
1110
/// Implementation of `[#derive(Visit)]`
1211
#[proc_macro_derive(VisitMut, attributes(visit))]
1312
pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
14-
derive_visit(input, &VisitType {
15-
visit_trait: quote!(VisitMut),
16-
visitor_trait: quote!(VisitorMut),
17-
modifier: Some(quote!(mut)),
18-
})
13+
derive_visit(
14+
input,
15+
&VisitType {
16+
visit_trait: quote!(VisitMut),
17+
visitor_trait: quote!(VisitorMut),
18+
modifier: Some(quote!(mut)),
19+
},
20+
)
1921
}
2022

2123
/// Implementation of `[#derive(Visit)]`
2224
#[proc_macro_derive(Visit, attributes(visit))]
2325
pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24-
derive_visit(input, &VisitType {
25-
visit_trait: quote!(Visit),
26-
visitor_trait: quote!(Visitor),
27-
modifier: None,
28-
})
26+
derive_visit(
27+
input,
28+
&VisitType {
29+
visit_trait: quote!(Visit),
30+
visitor_trait: quote!(Visitor),
31+
modifier: None,
32+
},
33+
)
2934
}
3035

3136
struct VisitType {
@@ -34,15 +39,16 @@ struct VisitType {
3439
modifier: Option<TokenStream>,
3540
}
3641

37-
fn derive_visit(
38-
input: proc_macro::TokenStream,
39-
visit_type: &VisitType,
40-
) -> proc_macro::TokenStream {
42+
fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_macro::TokenStream {
4143
// Parse the input tokens into a syntax tree.
4244
let input = parse_macro_input!(input as DeriveInput);
4345
let name = input.ident;
4446

45-
let VisitType { visit_trait, visitor_trait, modifier } = visit_type;
47+
let VisitType {
48+
visit_trait,
49+
visitor_trait,
50+
modifier,
51+
} = visit_type;
4652

4753
let attributes = Attributes::parse(&input.attrs);
4854
// Add a bound `T: Visit` to every type parameter T.
@@ -87,7 +93,10 @@ impl Parse for WithIdent {
8793
let mut result = WithIdent { with: None };
8894
let ident = input.parse::<Ident>()?;
8995
if ident != "with" {
90-
return Err(syn::Error::new(ident.span(), "Expected identifier to be `with`"));
96+
return Err(syn::Error::new(
97+
ident.span(),
98+
"Expected identifier to be `with`",
99+
));
91100
}
92101
input.parse::<Token!(=)>()?;
93102
let s = input.parse::<LitStr>()?;
@@ -131,17 +140,26 @@ impl Attributes {
131140
}
132141

133142
// Add a bound `T: Visit` to every type parameter T.
134-
fn add_trait_bounds(mut generics: Generics, VisitType{visit_trait, ..}: &VisitType) -> Generics {
143+
fn add_trait_bounds(mut generics: Generics, VisitType { visit_trait, .. }: &VisitType) -> Generics {
135144
for param in &mut generics.params {
136145
if let GenericParam::Type(ref mut type_param) = *param {
137-
type_param.bounds.push(parse_quote!(sqlparser::ast::#visit_trait));
146+
type_param
147+
.bounds
148+
.push(parse_quote!(sqlparser::ast::#visit_trait));
138149
}
139150
}
140151
generics
141152
}
142153

143154
// Generate the body of the visit implementation for the given type
144-
fn visit_children(data: &Data, VisitType{visit_trait, modifier, ..}: &VisitType) -> TokenStream {
155+
fn visit_children(
156+
data: &Data,
157+
VisitType {
158+
visit_trait,
159+
modifier,
160+
..
161+
}: &VisitType,
162+
) -> TokenStream {
145163
match data {
146164
Data::Struct(data) => match &data.fields {
147165
Fields::Named(fields) => {

src/parser/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5954,7 +5954,8 @@ impl<'a> Parser<'a> {
59545954
pub fn parse_function_definition(&mut self) -> Result<FunctionDefinition, ParserError> {
59555955
let peek_token = self.peek_token();
59565956
match peek_token.token {
5957-
Token::DollarQuotedString(value) if dialect_of!(self is PostgreSqlDialect) => {
5957+
Token::DollarQuotedString(value) if dialect_of!(self is PostgreSqlDialect | GenericDialect) =>
5958+
{
59585959
self.next_token();
59595960
Ok(FunctionDefinition::DoubleDollarDef(value.value))
59605961
}

tests/sqlparser_postgres.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3335,7 +3335,7 @@ fn parse_create_function() {
33353335

33363336
let sql = r#"CREATE OR REPLACE FUNCTION increment(i INTEGER) RETURNS INTEGER LANGUAGE plpgsql AS $$ BEGIN RETURN i + 1; END; $$"#;
33373337
assert_eq!(
3338-
pg().verified_stmt(sql),
3338+
pg_and_generic().verified_stmt(sql),
33393339
Statement::CreateFunction {
33403340
or_replace: true,
33413341
temporary: false,

0 commit comments

Comments
 (0)