Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct Foo;
#[derive(Copy)]
struct Foo;

impl < > core::marker::Copy for Foo< > {}"#]],
impl < > core::marker::Copy for Foo< > where {}"#]],
);
}

Expand All @@ -41,7 +41,7 @@ macro Copy {}
#[derive(Copy)]
struct Foo;

impl < > crate ::marker::Copy for Foo< > {}"#]],
impl < > crate ::marker::Copy for Foo< > where {}"#]],
);
}

Expand All @@ -57,7 +57,7 @@ struct Foo<A, B>;
#[derive(Copy)]
struct Foo<A, B>;

impl <T0: core::marker::Copy, T1: core::marker::Copy, > core::marker::Copy for Foo<T0, T1, > {}"#]],
impl <A: core::marker::Copy, B: core::marker::Copy, > core::marker::Copy for Foo<A, B, > where {}"#]],
);
}

Expand All @@ -74,7 +74,7 @@ struct Foo<A, B, 'a, 'b>;
#[derive(Copy)]
struct Foo<A, B, 'a, 'b>;

impl <T0: core::marker::Copy, T1: core::marker::Copy, > core::marker::Copy for Foo<T0, T1, > {}"#]],
impl <A: core::marker::Copy, B: core::marker::Copy, > core::marker::Copy for Foo<A, B, > where {}"#]],
);
}

Expand All @@ -90,7 +90,7 @@ struct Foo<A, B>;
#[derive(Clone)]
struct Foo<A, B>;

impl <T0: core::clone::Clone, T1: core::clone::Clone, > core::clone::Clone for Foo<T0, T1, > {}"#]],
impl <A: core::clone::Clone, B: core::clone::Clone, > core::clone::Clone for Foo<A, B, > where {}"#]],
);
}

Expand All @@ -106,6 +106,6 @@ struct Foo<const X: usize, T>(u32);
#[derive(Clone)]
struct Foo<const X: usize, T>(u32);

impl <const T0: usize, T1: core::clone::Clone, > core::clone::Clone for Foo<T0, T1, > {}"#]],
impl <const X: usize, T: core::clone::Clone, > core::clone::Clone for Foo<X, T, > where {}"#]],
);
}
131 changes: 111 additions & 20 deletions crates/hir-expand/src/builtin_derive_macro.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
//! Builtin derives.

use base_db::{CrateOrigin, LangCrateOrigin};
use std::collections::HashSet;
use tracing::debug;

use crate::tt::{self, TokenId};
use syntax::{
ast::{self, AstNode, HasGenericParams, HasModuleItem, HasName},
ast::{self, AstNode, HasGenericParams, HasModuleItem, HasName, HasTypeBounds, PathType},
match_ast,
};

Expand Down Expand Up @@ -60,8 +61,11 @@ pub fn find_builtin_derive(ident: &name::Name) -> Option<BuiltinDeriveExpander>

struct BasicAdtInfo {
name: tt::Ident,
/// `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
param_types: Vec<Option<tt::Subtree>>,
/// first field is the name, and
/// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
/// third fields is where bounds, if any
param_types: Vec<(tt::Subtree, Option<tt::Subtree>, Option<tt::Subtree>)>,
associated_types: Vec<tt::Subtree>,
}

fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
Expand All @@ -86,46 +90,126 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
},
}
};
let name = name.ok_or_else(|| {
debug!("parsed item has no name");
ExpandError::Other("missing name".into())
})?;
let name_token_id =
token_map.token_by_range(name.syntax().text_range()).unwrap_or_else(TokenId::unspecified);
let name_token = tt::Ident { span: name_token_id, text: name.text().into() };
let mut param_type_set: HashSet<String> = HashSet::new();
let param_types = params
.into_iter()
.flat_map(|param_list| param_list.type_or_const_params())
.map(|param| {
if let ast::TypeOrConstParam::Const(param) = param {
let name = {
let this = param.name();
match this {
Some(x) => {
param_type_set.insert(x.to_string());
mbe::syntax_node_to_token_tree(x.syntax()).0
}
None => tt::Subtree::empty(),
}
};
let bounds = match &param {
ast::TypeOrConstParam::Type(x) => {
x.type_bound_list().map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
}
ast::TypeOrConstParam::Const(_) => None,
};
let ty = if let ast::TypeOrConstParam::Const(param) = param {
let ty = param
.ty()
.map(|ty| mbe::syntax_node_to_token_tree(ty.syntax()).0)
.unwrap_or_else(tt::Subtree::empty);
Some(ty)
} else {
None
}
};
(name, ty, bounds)
})
.collect();
Ok(BasicAdtInfo { name: name_token, param_types })
let is_associated_type = |p: &PathType| {
if let Some(p) = p.path() {
if let Some(parent) = p.qualifier() {
if let Some(x) = parent.segment() {
if let Some(x) = x.path_type() {
if let Some(x) = x.path() {
if let Some(pname) = x.as_single_name_ref() {
if param_type_set.contains(&pname.to_string()) {
// <T as Trait>::Assoc
return true;
}
}
}
}
}
if let Some(pname) = parent.as_single_name_ref() {
if param_type_set.contains(&pname.to_string()) {
// T::Assoc
return true;
}
}
}
}
false
};
let associated_types = node
.descendants()
.filter_map(PathType::cast)
.filter(is_associated_type)
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
.collect::<Vec<_>>();
let name = name.ok_or_else(|| {
debug!("parsed item has no name");
ExpandError::Other("missing name".into())
})?;
let name_token_id =
token_map.token_by_range(name.syntax().text_range()).unwrap_or_else(TokenId::unspecified);
let name_token = tt::Ident { span: name_token_id, text: name.text().into() };
Ok(BasicAdtInfo { name: name_token, param_types, associated_types })
}

/// Given that we are deriving a trait `DerivedTrait` for a type like:
///
/// ```ignore (only-for-syntax-highlight)
/// struct Struct<'a, ..., 'z, A, B: DeclaredTrait, C, ..., Z> where C: WhereTrait {
/// a: A,
/// b: B::Item,
/// b1: <B as DeclaredTrait>::Item,
/// c1: <C as WhereTrait>::Item,
/// c2: Option<<C as WhereTrait>::Item>,
/// ...
/// }
/// ```
///
/// create an impl like:
///
/// ```ignore (only-for-syntax-highlight)
/// impl<'a, ..., 'z, A, B: DeclaredTrait, C, ... Z> where
/// C: WhereTrait,
/// A: DerivedTrait + B1 + ... + BN,
/// B: DerivedTrait + B1 + ... + BN,
/// C: DerivedTrait + B1 + ... + BN,
/// B::Item: DerivedTrait + B1 + ... + BN,
/// <C as WhereTrait>::Item: DerivedTrait + B1 + ... + BN,
/// ...
/// {
/// ...
/// }
/// ```
///
/// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and
/// therefore does not get bound by the derived trait.
fn expand_simple_derive(tt: &tt::Subtree, trait_path: tt::Subtree) -> ExpandResult<tt::Subtree> {
let info = match parse_adt(tt) {
Ok(info) => info,
Err(e) => return ExpandResult::with_err(tt::Subtree::empty(), e),
};
let mut where_block = vec![];
let (params, args): (Vec<_>, Vec<_>) = info
.param_types
.into_iter()
.enumerate()
.map(|(idx, param_ty)| {
let ident = tt::Leaf::Ident(tt::Ident {
span: tt::TokenId::unspecified(),
text: format!("T{idx}").into(),
});
.map(|(ident, param_ty, bound)| {
let ident_ = ident.clone();
if let Some(b) = bound {
let ident = ident.clone();
where_block.push(quote! { #ident : #b , });
}
if let Some(ty) = param_ty {
(quote! { const #ident : #ty , }, quote! { #ident_ , })
} else {
Expand All @@ -134,9 +218,16 @@ fn expand_simple_derive(tt: &tt::Subtree, trait_path: tt::Subtree) -> ExpandResu
}
})
.unzip();

where_block.extend(info.associated_types.iter().map(|x| {
let x = x.clone();
let bound = trait_path.clone();
quote! { #x : #bound , }
}));

let name = info.name;
let expanded = quote! {
impl < ##params > #trait_path for #name < ##args > {}
impl < ##params > #trait_path for #name < ##args > where ##where_block {}
};
ExpandResult::ok(expanded)
}
Expand Down
60 changes: 60 additions & 0 deletions crates/hir-ty/src/tests/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4315,3 +4315,63 @@ impl Trait for () {
"#,
);
}

#[test]
fn derive_macro_bounds() {
check_types(
r#"
//- minicore: clone, derive
#[derive(Clone)]
struct Copy;
struct NotCopy;
#[derive(Clone)]
struct Generic<T>(T);
trait Tr {
type Assoc;
}
impl Tr for Copy {
type Assoc = NotCopy;
}
#[derive(Clone)]
struct AssocGeneric<T: Tr>(T::Assoc);

#[derive(Clone)]
struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);

#[derive(Clone)]
struct AssocGeneric3<T: Tr>(Generic<T::Assoc>);

#[derive(Clone)]
struct Vec<T>();

#[derive(Clone)]
struct R1(Vec<R2>);
#[derive(Clone)]
struct R2(R1);

fn f() {
let x = (&Copy).clone();
//^ Copy
let x = (&NotCopy).clone();
//^ &NotCopy
let x = (&Generic(Copy)).clone();
//^ Generic<Copy>
let x = (&Generic(NotCopy)).clone();
//^ &Generic<NotCopy>
let x: &AssocGeneric<Copy> = &AssocGeneric(NotCopy);
let x = x.clone();
//^ &AssocGeneric<Copy>
let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
let x = x.clone();
//^ &AssocGeneric2<Copy>
let x: &AssocGeneric3<Copy> = &AssocGeneric3(Generic(NotCopy));
let x = x.clone();
//^ &AssocGeneric3<Copy>
let x = (&R1(Vec())).clone();
//^ R1
let x = (&R2(R1(Vec()))).clone();
//^ R2
}
"#,
);
}
8 changes: 4 additions & 4 deletions crates/ide/src/expand_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ struct Foo {}
"#,
expect![[r#"
Clone
impl < >core::clone::Clone for Foo< >{}
impl < >core::clone::Clone for Foo< >where{}
"#]],
);
}
Expand All @@ -488,7 +488,7 @@ struct Foo {}
"#,
expect![[r#"
Copy
impl < >core::marker::Copy for Foo< >{}
impl < >core::marker::Copy for Foo< >where{}
"#]],
);
}
Expand All @@ -504,7 +504,7 @@ struct Foo {}
"#,
expect![[r#"
Copy
impl < >core::marker::Copy for Foo< >{}
impl < >core::marker::Copy for Foo< >where{}
"#]],
);
check(
Expand All @@ -516,7 +516,7 @@ struct Foo {}
"#,
expect![[r#"
Clone
impl < >core::clone::Clone for Foo< >{}
impl < >core::clone::Clone for Foo< >where{}
"#]],
);
}
Expand Down
6 changes: 6 additions & 0 deletions crates/test-utils/src/minicore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ pub mod clone {
pub trait Clone: Sized {
fn clone(&self) -> Self;
}

impl<T> Clone for &T {
fn clone(&self) -> Self {
*self
}
}
// region:derive
#[rustc_builtin_macro]
pub macro Clone($item:item) {}
Expand Down