Skip to content

Add bounds for fields in derive macro #14521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct Foo;
#[derive(Copy)]
struct Foo;

impl < > core::marker::Copy for Foo< > {}"#]],
impl < > core::marker::Copy for Foo< > where {}"#]],
);
}

Expand All @@ -41,7 +41,7 @@ macro Copy {}
#[derive(Copy)]
struct Foo;

impl < > crate ::marker::Copy for Foo< > {}"#]],
impl < > crate ::marker::Copy for Foo< > where {}"#]],
);
}

Expand All @@ -57,7 +57,7 @@ struct Foo<A, B>;
#[derive(Copy)]
struct Foo<A, B>;

impl <T0: core::marker::Copy, T1: core::marker::Copy, > core::marker::Copy for Foo<T0, T1, > {}"#]],
impl <A: core::marker::Copy, B: core::marker::Copy, > core::marker::Copy for Foo<A, B, > where {}"#]],
);
}

Expand All @@ -74,7 +74,7 @@ struct Foo<A, B, 'a, 'b>;
#[derive(Copy)]
struct Foo<A, B, 'a, 'b>;

impl <T0: core::marker::Copy, T1: core::marker::Copy, > core::marker::Copy for Foo<T0, T1, > {}"#]],
impl <A: core::marker::Copy, B: core::marker::Copy, > core::marker::Copy for Foo<A, B, > where {}"#]],
);
}

Expand All @@ -90,7 +90,7 @@ struct Foo<A, B>;
#[derive(Clone)]
struct Foo<A, B>;

impl <T0: core::clone::Clone, T1: core::clone::Clone, > core::clone::Clone for Foo<T0, T1, > {}"#]],
impl <A: core::clone::Clone, B: core::clone::Clone, > core::clone::Clone for Foo<A, B, > where {}"#]],
);
}

Expand All @@ -106,6 +106,6 @@ struct Foo<const X: usize, T>(u32);
#[derive(Clone)]
struct Foo<const X: usize, T>(u32);

impl <const T0: usize, T1: core::clone::Clone, > core::clone::Clone for Foo<T0, T1, > {}"#]],
impl <const X: usize, T: core::clone::Clone, > core::clone::Clone for Foo<X, T, > where u32: core::clone::Clone, {}"#]],
);
}
73 changes: 56 additions & 17 deletions crates/hir-expand/src/builtin_derive_macro.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
//! Builtin derives.

use base_db::{CrateOrigin, LangCrateOrigin};
use either::Either;
use tracing::debug;

use crate::tt::{self, TokenId};
use syntax::{
ast::{self, AstNode, HasGenericParams, HasModuleItem, HasName},
ast::{self, AstNode, HasGenericParams, HasModuleItem, HasName, HasTypeBounds},
match_ast,
};

Expand Down Expand Up @@ -60,8 +61,11 @@ pub fn find_builtin_derive(ident: &name::Name) -> Option<BuiltinDeriveExpander>

struct BasicAdtInfo {
name: tt::Ident,
/// `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
param_types: Vec<Option<tt::Subtree>>,
/// first field is the name, and
/// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
/// third fields is where bounds, if any
param_types: Vec<(tt::Subtree, Option<tt::Subtree>, Option<tt::Subtree>)>,
field_types: Vec<tt::Subtree>,
}

fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
Expand All @@ -75,17 +79,34 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
ExpandError::Other("no item found".into())
})?;
let node = item.syntax();
let (name, params) = match_ast! {
let (name, params, fields) = match_ast! {
match node {
ast::Struct(it) => (it.name(), it.generic_param_list()),
ast::Enum(it) => (it.name(), it.generic_param_list()),
ast::Union(it) => (it.name(), it.generic_param_list()),
ast::Struct(it) => {
(it.name(), it.generic_param_list(), it.field_list().into_iter().collect::<Vec<_>>())
},
ast::Enum(it) => (it.name(), it.generic_param_list(), it.variant_list().into_iter().flat_map(|x| x.variants()).filter_map(|x| x.field_list()).collect()),
ast::Union(it) => (it.name(), it.generic_param_list(), it.record_field_list().into_iter().map(|x| ast::FieldList::RecordFieldList(x)).collect()),
_ => {
debug!("unexpected node is {:?}", node);
return Err(ExpandError::Other("expected struct, enum or union".into()))
},
}
};
let field_types = fields
.into_iter()
.flat_map(|f| match f {
ast::FieldList::RecordFieldList(x) => Either::Left(
x.fields()
.filter_map(|x| x.ty())
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0),
),
ast::FieldList::TupleFieldList(x) => Either::Right(
x.fields()
.filter_map(|x| x.ty())
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0),
),
})
.collect::<Vec<_>>();
let name = name.ok_or_else(|| {
debug!("parsed item has no name");
ExpandError::Other("missing name".into())
Expand All @@ -97,35 +118,46 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
.into_iter()
.flat_map(|param_list| param_list.type_or_const_params())
.map(|param| {
if let ast::TypeOrConstParam::Const(param) = param {
let name = param
.name()
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
.unwrap_or_else(tt::Subtree::empty);
let bounds = match &param {
ast::TypeOrConstParam::Type(x) => {
x.type_bound_list().map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
}
ast::TypeOrConstParam::Const(_) => None,
};
let ty = if let ast::TypeOrConstParam::Const(param) = param {
let ty = param
.ty()
.map(|ty| mbe::syntax_node_to_token_tree(ty.syntax()).0)
.unwrap_or_else(tt::Subtree::empty);
Some(ty)
} else {
None
}
};
(name, ty, bounds)
})
.collect();
Ok(BasicAdtInfo { name: name_token, param_types })
Ok(BasicAdtInfo { name: name_token, param_types, field_types })
}

fn expand_simple_derive(tt: &tt::Subtree, trait_path: tt::Subtree) -> ExpandResult<tt::Subtree> {
let info = match parse_adt(tt) {
Ok(info) => info,
Err(e) => return ExpandResult::with_err(tt::Subtree::empty(), e),
};
let mut where_block = vec![];
let (params, args): (Vec<_>, Vec<_>) = info
.param_types
.into_iter()
.enumerate()
.map(|(idx, param_ty)| {
let ident = tt::Leaf::Ident(tt::Ident {
span: tt::TokenId::unspecified(),
text: format!("T{idx}").into(),
});
.map(|(ident, param_ty, bound)| {
let ident_ = ident.clone();
if let Some(b) = bound {
let ident = ident.clone();
where_block.push(quote! { #ident : #b , });
}
if let Some(ty) = param_ty {
(quote! { const #ident : #ty , }, quote! { #ident_ , })
} else {
Expand All @@ -134,9 +166,16 @@ fn expand_simple_derive(tt: &tt::Subtree, trait_path: tt::Subtree) -> ExpandResu
}
})
.unzip();

where_block.extend(info.field_types.iter().map(|x| {
let x = x.clone();
let bound = trait_path.clone();
quote! { #x : #bound , }
}));

let name = info.name;
let expanded = quote! {
impl < ##params > #trait_path for #name < ##args > {}
impl < ##params > #trait_path for #name < ##args > where ##where_block {}
};
ExpandResult::ok(expanded)
}
Expand Down
8 changes: 4 additions & 4 deletions crates/ide/src/expand_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ struct Foo {}
"#,
expect![[r#"
Clone
impl < >core::clone::Clone for Foo< >{}
impl < >core::clone::Clone for Foo< >where{}
"#]],
);
}
Expand All @@ -488,7 +488,7 @@ struct Foo {}
"#,
expect![[r#"
Copy
impl < >core::marker::Copy for Foo< >{}
impl < >core::marker::Copy for Foo< >where{}
"#]],
);
}
Expand All @@ -504,7 +504,7 @@ struct Foo {}
"#,
expect![[r#"
Copy
impl < >core::marker::Copy for Foo< >{}
impl < >core::marker::Copy for Foo< >where{}
"#]],
);
check(
Expand All @@ -516,7 +516,7 @@ struct Foo {}
"#,
expect![[r#"
Clone
impl < >core::clone::Clone for Foo< >{}
impl < >core::clone::Clone for Foo< >where{}
"#]],
);
}
Expand Down