Skip to content

Commit cfd6525

Browse files
committed
Add PostGuard. async-graphql#129
1 parent f7bdf20 commit cfd6525

File tree

8 files changed

+416
-79
lines changed

8 files changed

+416
-79
lines changed

async-graphql-derive/src/args.rs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use crate::utils::{get_rustdoc, parse_default, parse_default_with, parse_guards, parse_validator};
1+
use crate::utils::{
2+
get_rustdoc, parse_default, parse_default_with, parse_guards, parse_post_guards,
3+
parse_validator,
4+
};
25
use proc_macro2::TokenStream;
36
use quote::quote;
47
use syn::{Attribute, AttributeArgs, Error, Lit, Meta, MetaList, NestedMeta, Result, Type};
@@ -196,6 +199,7 @@ pub struct Field {
196199
pub requires: Option<String>,
197200
pub is_ref: bool,
198201
pub guard: Option<TokenStream>,
202+
pub post_guard: Option<TokenStream>,
199203
}
200204

201205
impl Field {
@@ -209,11 +213,13 @@ impl Field {
209213
let mut requires = None;
210214
let mut is_ref = false;
211215
let mut guard = None;
216+
let mut post_guard = None;
212217

213218
for attr in attrs {
214219
match attr.parse_meta()? {
215220
Meta::List(ls) if ls.path.is_ident("field") => {
216221
guard = parse_guards(crate_name, &ls)?;
222+
post_guard = parse_post_guards(crate_name, &ls)?;
217223
for meta in &ls.nested {
218224
match meta {
219225
NestedMeta::Meta(Meta::Path(p)) if p.is_ident("skip") => {
@@ -300,6 +306,7 @@ impl Field {
300306
requires,
301307
is_ref,
302308
guard,
309+
post_guard,
303310
}))
304311
}
305312
}
@@ -881,20 +888,17 @@ impl Scalar {
881888
}
882889
}
883890

884-
pub struct Entity {
885-
pub guard: Option<TokenStream>,
886-
}
891+
pub struct Entity {}
887892

888893
impl Entity {
889-
pub fn parse(crate_name: &TokenStream, attrs: &[Attribute]) -> Result<Option<Self>> {
894+
pub fn parse(_crate_name: &TokenStream, attrs: &[Attribute]) -> Result<Option<Self>> {
890895
for attr in attrs {
891896
match attr.parse_meta()? {
892897
Meta::List(ls) if ls.path.is_ident("entity") => {
893-
let guard = parse_guards(crate_name, &ls)?;
894-
return Ok(Some(Self { guard }));
898+
return Ok(Some(Self {}));
895899
}
896900
Meta::Path(p) if p.is_ident("entity") => {
897-
return Ok(Some(Self { guard: None }));
901+
return Ok(Some(Self {}));
898902
}
899903
_ => {}
900904
}

async-graphql-derive/src/object.rs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::args;
22
use crate::output_type::OutputType;
3-
use crate::utils::{check_reserved_name, get_crate_name, get_rustdoc};
3+
use crate::utils::{check_reserved_name, get_crate_name, get_param_getter_ident, get_rustdoc};
44
use inflector::Inflector;
55
use proc_macro::TokenStream;
66
use quote::quote;
@@ -44,7 +44,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
4444

4545
for item in &mut item_impl.items {
4646
if let ImplItem::Method(method) = item {
47-
if let Some(entity) = args::Entity::parse(&crate_name, &method.attrs)? {
47+
if args::Entity::parse(&crate_name, &method.attrs)?.is_some() {
4848
if method.sig.asyncness.is_none() {
4949
return Err(Error::new_spanned(&method, "Must be asynchronous"));
5050
}
@@ -159,16 +159,11 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
159159
}
160160
let do_find = quote! { self.#field_ident(ctx, #(#use_keys),*).await.map_err(|err| err.into_error(ctx.position()))? };
161161

162-
let guard = entity.guard.map(
163-
|guard| quote! { #guard.check(ctx).await.map_err(|err| err.into_error(ctx.position()))?; },
164-
);
165-
166162
find_entities.push((
167163
args.len(),
168164
quote! {
169165
if typename == &<#entity_type as #crate_name::Type>::type_name() {
170166
if let (#(#key_pat),*) = (#(#key_getter),*) {
171-
#guard
172167
let ctx_obj = ctx.with_selection_set(&ctx.selection_set);
173168
return #crate_name::OutputValueType::resolve(&#do_find, &ctx_obj, ctx.item).await;
174169
}
@@ -328,8 +323,10 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
328323
Some(default) => quote! { Some(|| -> #ty { #default }) },
329324
None => quote! { None },
330325
};
326+
let param_getter_name = get_param_getter_ident(&ident.ident.to_string());
331327
get_params.push(quote! {
332-
let #ident: #ty = ctx.param_value(#name, #default)?;
328+
let #param_getter_name = || -> #crate_name::Result<#ty> { ctx.param_value(#name, #default) };
329+
let #ident: #ty = #param_getter_name()?;
333330
});
334331
}
335332

@@ -381,14 +378,22 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
381378
#guard.check(ctx).await
382379
.map_err(|err| err.into_error_with_path(ctx.position(), ctx.path_node.as_ref().unwrap().to_json()))?;
383380
});
381+
let post_guard = field
382+
.post_guard
383+
.map(|guard| quote! {
384+
#guard.check(ctx, &res).await
385+
.map_err(|err| err.into_error_with_path(ctx.position(), ctx.path_node.as_ref().unwrap().to_json()))?;
386+
});
384387

385388
resolvers.push(quote! {
386389
if ctx.name.node == #field_name {
387390
use #crate_name::OutputValueType;
388391
#(#get_params)*
389392
#guard
390393
let ctx_obj = ctx.with_selection_set(&ctx.selection_set);
391-
return OutputValueType::resolve(&#resolve_obj, &ctx_obj, ctx.item).await;
394+
let res = #resolve_obj;
395+
#post_guard
396+
return OutputValueType::resolve(&res, &ctx_obj, ctx.item).await;
392397
}
393398
});
394399

@@ -462,6 +467,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
462467
}.into_error(ctx.position()))
463468
}
464469

470+
#[allow(unused_variables)]
465471
async fn find_entity(&self, ctx: &#crate_name::Context<'_>, params: &#crate_name::Value) -> #crate_name::Result<#crate_name::serde_json::Value> {
466472
let params = match params {
467473
#crate_name::Value::Object(params) => params,

async-graphql-derive/src/simple_object.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ pub fn generate(object_args: &args::Object, input: &mut DeriveInput) -> Result<T
9595
let guard = field
9696
.guard
9797
.map(|guard| quote! { #guard.check(ctx).await.map_err(|err| err.into_error_with_path(ctx.position(), ctx.path_node.as_ref().unwrap().to_json()))?; });
98+
let post_guard = field
99+
.post_guard
100+
.map(|guard| quote! { #guard.check(ctx, &res).await.map_err(|err| err.into_error_with_path(ctx.position(), ctx.path_node.as_ref().unwrap().to_json()))?; });
98101

99102
if field.is_ref {
100103
getters.push(quote! {
@@ -119,6 +122,7 @@ pub fn generate(object_args: &args::Object, input: &mut DeriveInput) -> Result<T
119122
#guard
120123
let res = self.#ident(ctx).await.map_err(|err| err.into_error_with_path(ctx.position(), ctx.path_node.as_ref().unwrap().to_json()))?;
121124
let ctx_obj = ctx.with_selection_set(&ctx.selection_set);
125+
#post_guard
122126
return #crate_name::OutputValueType::resolve(&res, &ctx_obj, ctx.item).await;
123127
}
124128
});

async-graphql-derive/src/subscription.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::args;
22
use crate::output_type::OutputType;
3-
use crate::utils::{check_reserved_name, get_crate_name, get_rustdoc};
3+
use crate::utils::{check_reserved_name, get_crate_name, get_param_getter_ident, get_rustdoc};
44
use inflector::Inflector;
55
use proc_macro::TokenStream;
66
use quote::quote;
@@ -176,7 +176,9 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
176176
Some(default) => quote! { Some(|| -> #ty { #default }) },
177177
None => quote! { None },
178178
};
179+
let param_getter_name = get_param_getter_ident(&ident.ident.to_string());
179180
get_params.push(quote! {
181+
let #param_getter_name = || -> #crate_name::Result<#ty> { ctx.param_value(#name, #default) };
180182
let #ident: #ty = ctx.param_value(#name, #default)?;
181183
});
182184
}
@@ -229,6 +231,12 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
229231
let guard = field.guard.map(|guard| quote! {
230232
#guard.check(ctx).await.map_err(|err| err.into_error_with_path(ctx.position(), ctx.path_node.as_ref().unwrap().to_json()))?;
231233
});
234+
if field.post_guard.is_some() {
235+
return Err(Error::new_spanned(
236+
method,
237+
"The subscription field does not support post guard",
238+
));
239+
}
232240

233241
create_stream.push(quote! {
234242
if ctx.name.node == #field_name {

async-graphql-derive/src/utils.rs

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ pub fn parse_guards(crate_name: &TokenStream, args: &MetaList) -> Result<Option<
133133
if let Lit::Str(value) = &nv.lit {
134134
let value_str = value.value();
135135
if value_str.starts_with('@') {
136-
let id = Ident::new(&value_str[1..], value.span());
137-
params.push(quote! { #name: &#id });
136+
let getter_name = get_param_getter_ident(&value_str[1..]);
137+
params.push(quote! { #name: #getter_name()? });
138138
} else {
139139
let expr = syn::parse_str::<Expr>(&value_str)?;
140140
params.push(quote! { #name: (#expr).into() });
@@ -170,6 +170,60 @@ pub fn parse_guards(crate_name: &TokenStream, args: &MetaList) -> Result<Option<
170170
Ok(None)
171171
}
172172

173+
pub fn parse_post_guards(crate_name: &TokenStream, args: &MetaList) -> Result<Option<TokenStream>> {
174+
for arg in &args.nested {
175+
if let NestedMeta::Meta(Meta::List(ls)) = arg {
176+
if ls.path.is_ident("post_guard") {
177+
let mut guards = None;
178+
179+
for item in &ls.nested {
180+
if let NestedMeta::Meta(Meta::List(ls)) = item {
181+
let ty = &ls.path;
182+
let mut params = Vec::new();
183+
for attr in &ls.nested {
184+
if let NestedMeta::Meta(Meta::NameValue(nv)) = attr {
185+
let name = &nv.path;
186+
if let Lit::Str(value) = &nv.lit {
187+
let value_str = value.value();
188+
if value_str.starts_with('@') {
189+
let getter_name = get_param_getter_ident(&value_str[1..]);
190+
params.push(quote! { #name: #getter_name()? });
191+
} else {
192+
let expr = syn::parse_str::<Expr>(&value_str)?;
193+
params.push(quote! { #name: (#expr).into() });
194+
}
195+
} else {
196+
return Err(Error::new_spanned(
197+
&nv.lit,
198+
"Value must be string literal",
199+
));
200+
}
201+
} else {
202+
return Err(Error::new_spanned(attr, "Invalid property for guard"));
203+
}
204+
}
205+
206+
let guard = quote! { #ty { #(#params),* } };
207+
if guards.is_none() {
208+
guards = Some(guard);
209+
} else {
210+
guards = Some(
211+
quote! { #crate_name::guard::PostGuardExt::and(#guard, #guards) },
212+
);
213+
}
214+
} else {
215+
return Err(Error::new_spanned(item, "Invalid guard"));
216+
}
217+
}
218+
219+
return Ok(guards);
220+
}
221+
}
222+
}
223+
224+
Ok(None)
225+
}
226+
173227
pub fn get_rustdoc(attrs: &[Attribute]) -> Result<Option<String>> {
174228
let mut full_docs = String::new();
175229
for attr in attrs {
@@ -231,3 +285,7 @@ pub fn parse_default_with(lit: &Lit) -> Result<TokenStream> {
231285
))
232286
}
233287
}
288+
289+
pub fn get_param_getter_ident(name: &str) -> Ident {
290+
Ident::new(&format!("__{}_getter", name), Span::call_site())
291+
}

src/guard.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
//! Field guards
22
33
use crate::{Context, FieldResult};
4+
use serde::export::PhantomData;
45

56
/// Field guard
67
///
7-
/// Guard is a precondition for a field that is resolved if `Ok(()` is returned, otherwise an error is returned.
8+
/// Guard is a pre-condition for a field that is resolved if `Ok(()` is returned, otherwise an error is returned.
89
#[async_trait::async_trait]
910
pub trait Guard {
1011
#[allow(missing_docs)]
@@ -31,3 +32,35 @@ impl<A: Guard + Send + Sync, B: Guard + Send + Sync> Guard for And<A, B> {
3132
self.1.check(ctx).await
3233
}
3334
}
35+
36+
/// Field post guard
37+
///
38+
/// Guard is a post-condition for a field that is resolved if `Ok(()` is returned, otherwise an error is returned.
39+
#[async_trait::async_trait]
40+
pub trait PostGuard<T: Send + Sync> {
41+
#[allow(missing_docs)]
42+
async fn check(&self, ctx: &Context<'_>, result: &T) -> FieldResult<()>;
43+
}
44+
45+
/// An extension trait for `PostGuard<T>`
46+
pub trait PostGuardExt<T: Send + Sync>: PostGuard<T> + Sized {
47+
/// Merge the two guards.
48+
fn and<R: PostGuard<T>>(self, other: R) -> PostAnd<T, Self, R> {
49+
PostAnd(self, other, PhantomData)
50+
}
51+
}
52+
53+
impl<T: PostGuard<R>, R: Send + Sync> PostGuardExt<R> for T {}
54+
55+
/// PostGuard for `GuardExt<T>::and`
56+
pub struct PostAnd<T: Send + Sync, A: PostGuard<T>, B: PostGuard<T>>(A, B, PhantomData<T>);
57+
58+
#[async_trait::async_trait]
59+
impl<T: Send + Sync, A: PostGuard<T> + Send + Sync, B: PostGuard<T> + Send + Sync> PostGuard<T>
60+
for PostAnd<T, A, B>
61+
{
62+
async fn check(&self, ctx: &Context<'_>, result: &T) -> FieldResult<()> {
63+
self.0.check(ctx, result).await?;
64+
self.1.check(ctx, result).await
65+
}
66+
}

0 commit comments

Comments
 (0)