Skip to content

feat: type inference for generic associated types #13494

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 27, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions crates/hir-def/src/item_tree/lower.rs
Original file line number Diff line number Diff line change
@@ -662,8 +662,12 @@ fn desugar_future_path(orig: TypeRef) -> Path {
let mut generic_args: Vec<_> =
std::iter::repeat(None).take(path.segments().len() - 1).collect();
let mut last = GenericArgs::empty();
let binding =
AssociatedTypeBinding { name: name![Output], type_ref: Some(orig), bounds: Vec::new() };
let binding = AssociatedTypeBinding {
name: name![Output],
args: None,
type_ref: Some(orig),
bounds: Vec::new(),
};
last.bindings.push(binding);
generic_args.push(Some(Interned::new(last)));

3 changes: 3 additions & 0 deletions crates/hir-def/src/path.rs
Original file line number Diff line number Diff line change
@@ -68,6 +68,9 @@ pub struct GenericArgs {
pub struct AssociatedTypeBinding {
/// The name of the associated type.
pub name: Name,
/// The generic arguments to the associated type. e.g. For `Trait<Assoc<'a, T> = &'a T>`, this
/// would be `['a, T]`.
pub args: Option<Interned<GenericArgs>>,
/// The type bound to this associated type (in `Item = T`, this would be the
/// `T`). This can be `None` if there are bounds instead.
pub type_ref: Option<TypeRef>,
8 changes: 7 additions & 1 deletion crates/hir-def/src/path/lower.rs
Original file line number Diff line number Diff line change
@@ -163,6 +163,10 @@ pub(super) fn lower_generic_args(
ast::GenericArg::AssocTypeArg(assoc_type_arg) => {
if let Some(name_ref) = assoc_type_arg.name_ref() {
let name = name_ref.as_name();
let args = assoc_type_arg
.generic_arg_list()
.and_then(|args| lower_generic_args(lower_ctx, args))
.map(Interned::new);
let type_ref = assoc_type_arg.ty().map(|it| TypeRef::from_ast(lower_ctx, it));
let bounds = if let Some(l) = assoc_type_arg.type_bound_list() {
l.bounds()
@@ -171,7 +175,7 @@ pub(super) fn lower_generic_args(
} else {
Vec::new()
};
bindings.push(AssociatedTypeBinding { name, type_ref, bounds });
bindings.push(AssociatedTypeBinding { name, args, type_ref, bounds });
}
}
ast::GenericArg::LifetimeArg(lifetime_arg) => {
@@ -214,6 +218,7 @@ fn lower_generic_args_from_fn_path(
let type_ref = TypeRef::from_ast_opt(ctx, ret_type.ty());
bindings.push(AssociatedTypeBinding {
name: name![Output],
args: None,
type_ref: Some(type_ref),
bounds: Vec::new(),
});
@@ -222,6 +227,7 @@ fn lower_generic_args_from_fn_path(
let type_ref = TypeRef::Tuple(Vec::new());
bindings.push(AssociatedTypeBinding {
name: name![Output],
args: None,
type_ref: Some(type_ref),
bounds: Vec::new(),
});
17 changes: 10 additions & 7 deletions crates/hir-ty/src/chalk_ext.rs
Original file line number Diff line number Diff line change
@@ -11,9 +11,9 @@ use syntax::SmolStr;

use crate::{
db::HirDatabase, from_assoc_type_id, from_chalk_trait_id, from_foreign_def_id,
from_placeholder_idx, to_chalk_trait_id, AdtId, AliasEq, AliasTy, Binders, CallableDefId,
CallableSig, FnPointer, ImplTraitId, Interner, Lifetime, ProjectionTy, QuantifiedWhereClause,
Substitution, TraitRef, Ty, TyBuilder, TyKind, WhereClause,
from_placeholder_idx, to_chalk_trait_id, utils::generics, AdtId, AliasEq, AliasTy, Binders,
CallableDefId, CallableSig, FnPointer, ImplTraitId, Interner, Lifetime, ProjectionTy,
QuantifiedWhereClause, Substitution, TraitRef, Ty, TyBuilder, TyKind, WhereClause,
};

pub trait TyExt {
@@ -338,10 +338,13 @@ pub trait ProjectionTyExt {

impl ProjectionTyExt for ProjectionTy {
fn trait_ref(&self, db: &dyn HirDatabase) -> TraitRef {
TraitRef {
trait_id: to_chalk_trait_id(self.trait_(db)),
substitution: self.substitution.clone(),
}
// FIXME: something like `Split` trait from chalk-solve might be nice.
let generics = generics(db.upcast(), from_assoc_type_id(self.associated_ty_id).into());
let substitution = Substitution::from_iter(
Interner,
self.substitution.iter(Interner).skip(generics.len_self()),
);
TraitRef { trait_id: to_chalk_trait_id(self.trait_(db)), substitution }
}

fn trait_(&self, db: &dyn HirDatabase) -> TraitId {
36 changes: 26 additions & 10 deletions crates/hir-ty/src/display.rs
Original file line number Diff line number Diff line change
@@ -289,16 +289,18 @@ impl HirDisplay for ProjectionTy {
return write!(f, "{}", TYPE_HINT_TRUNCATION);
}

let trait_ = f.db.trait_data(self.trait_(f.db));
let trait_ref = self.trait_ref(f.db);
write!(f, "<")?;
self.self_type_parameter(f.db).hir_fmt(f)?;
write!(f, " as {}", trait_.name)?;
if self.substitution.len(Interner) > 1 {
fmt_trait_ref(&trait_ref, f, true)?;
write!(f, ">::{}", f.db.type_alias_data(from_assoc_type_id(self.associated_ty_id)).name)?;
let proj_params_count =
self.substitution.len(Interner) - trait_ref.substitution.len(Interner);
let proj_params = &self.substitution.as_slice(Interner)[..proj_params_count];
if !proj_params.is_empty() {
write!(f, "<")?;
f.write_joined(&self.substitution.as_slice(Interner)[1..], ", ")?;
f.write_joined(proj_params, ", ")?;
write!(f, ">")?;
}
write!(f, ">::{}", f.db.type_alias_data(from_assoc_type_id(self.associated_ty_id)).name)?;
Ok(())
}
}
@@ -641,9 +643,12 @@ impl HirDisplay for Ty {
// Use placeholder associated types when the target is test (https://rust-lang.github.io/chalk/book/clauses/type_equality.html#placeholder-associated-types)
if f.display_target.is_test() {
write!(f, "{}::{}", trait_.name, type_alias_data.name)?;
// Note that the generic args for the associated type come before those for the
// trait (including the self type).
// FIXME: reconsider the generic args order upon formatting?
if parameters.len(Interner) > 0 {
write!(f, "<")?;
f.write_joined(&*parameters.as_slice(Interner), ", ")?;
f.write_joined(parameters.as_slice(Interner), ", ")?;
write!(f, ">")?;
}
} else {
@@ -972,9 +977,20 @@ fn write_bounds_like_dyn_trait(
angle_open = true;
}
if let AliasTy::Projection(proj) = alias {
let type_alias =
f.db.type_alias_data(from_assoc_type_id(proj.associated_ty_id));
write!(f, "{} = ", type_alias.name)?;
let assoc_ty_id = from_assoc_type_id(proj.associated_ty_id);
let type_alias = f.db.type_alias_data(assoc_ty_id);
write!(f, "{}", type_alias.name)?;

let proj_arg_count = generics(f.db.upcast(), assoc_ty_id.into()).len_self();
if proj_arg_count > 0 {
write!(f, "<")?;
f.write_joined(
&proj.substitution.as_slice(Interner)[..proj_arg_count],
", ",
)?;
write!(f, ">")?;
}
write!(f, " = ")?;
}
ty.hir_fmt(f)?;
}
2 changes: 1 addition & 1 deletion crates/hir-ty/src/infer/path.rs
Original file line number Diff line number Diff line change
@@ -157,7 +157,7 @@ impl<'a> InferenceContext<'a> {
remaining_segments_for_ty,
true,
);
if let TyKind::Error = ty.kind(Interner) {
if ty.is_unknown() {
return None;
}

9 changes: 0 additions & 9 deletions crates/hir-ty/src/lib.rs
Original file line number Diff line number Diff line change
@@ -124,14 +124,6 @@ pub type ConstrainedSubst = chalk_ir::ConstrainedSubst<Interner>;
pub type Guidance = chalk_solve::Guidance<Interner>;
pub type WhereClause = chalk_ir::WhereClause<Interner>;

// FIXME: get rid of this
pub fn subst_prefix(s: &Substitution, n: usize) -> Substitution {
Substitution::from_iter(
Interner,
s.as_slice(Interner)[..std::cmp::min(s.len(Interner), n)].iter().cloned(),
)
}

/// Return an index of a parameter in the generic type parameter list by it's id.
pub fn param_idx(db: &dyn HirDatabase, id: TypeOrConstParamId) -> Option<usize> {
generics(db.upcast(), id.parent).param_idx(id)
@@ -382,7 +374,6 @@ pub(crate) fn fold_tys_and_consts<T: HasInterner<Interner = Interner> + TypeFold
pub fn replace_errors_with_variables<T>(t: &T) -> Canonical<T>
where
T: HasInterner<Interner = Interner> + TypeFoldable<Interner> + Clone,
T: HasInterner<Interner = Interner>,
{
use chalk_ir::{
fold::{FallibleTypeFolder, TypeSuperFoldable},
124 changes: 91 additions & 33 deletions crates/hir-ty/src/lower.rs
Original file line number Diff line number Diff line change
@@ -447,12 +447,31 @@ impl<'a> TyLoweringContext<'a> {
.db
.trait_data(trait_ref.hir_trait_id())
.associated_type_by_name(segment.name);

match found {
Some(associated_ty) => {
// FIXME handle type parameters on the segment
// FIXME: `substs_from_path_segment()` pushes `TyKind::Error` for every parent
// generic params. It's inefficient to splice the `Substitution`s, so we may want
// that method to optionally take parent `Substitution` as we already know them at
// this point (`trait_ref.substitution`).
let substitution = self.substs_from_path_segment(
segment,
Some(associated_ty.into()),
false,
None,
);
let len_self =
generics(self.db.upcast(), associated_ty.into()).len_self();
let substitution = Substitution::from_iter(
Interner,
substitution
.iter(Interner)
.take(len_self)
.chain(trait_ref.substitution.iter(Interner)),
);
TyKind::Alias(AliasTy::Projection(ProjectionTy {
associated_ty_id: to_assoc_type_id(associated_ty),
substitution: trait_ref.substitution,
substitution,
}))
.intern(Interner)
}
@@ -590,36 +609,48 @@ impl<'a> TyLoweringContext<'a> {
res,
Some(segment.name.clone()),
move |name, t, associated_ty| {
if name == segment.name {
let substs = match self.type_param_mode {
ParamLoweringMode::Placeholder => {
// if we're lowering to placeholders, we have to put
// them in now
let generics = generics(
self.db.upcast(),
self.resolver
.generic_def()
.expect("there should be generics if there's a generic param"),
);
let s = generics.placeholder_subst(self.db);
s.apply(t.substitution.clone(), Interner)
}
ParamLoweringMode::Variable => t.substitution.clone(),
};
// We need to shift in the bound vars, since
// associated_type_shorthand_candidates does not do that
let substs = substs.shifted_in_from(Interner, self.in_binders);
// FIXME handle type parameters on the segment
Some(
TyKind::Alias(AliasTy::Projection(ProjectionTy {
associated_ty_id: to_assoc_type_id(associated_ty),
substitution: substs,
}))
.intern(Interner),
)
} else {
None
if name != segment.name {
return None;
}

// FIXME: `substs_from_path_segment()` pushes `TyKind::Error` for every parent
// generic params. It's inefficient to splice the `Substitution`s, so we may want
// that method to optionally take parent `Substitution` as we already know them at
// this point (`t.substitution`).
let substs = self.substs_from_path_segment(
segment.clone(),
Some(associated_ty.into()),
false,
None,
);

let len_self = generics(self.db.upcast(), associated_ty.into()).len_self();

let substs = Substitution::from_iter(
Interner,
substs.iter(Interner).take(len_self).chain(t.substitution.iter(Interner)),
);

let substs = match self.type_param_mode {
ParamLoweringMode::Placeholder => {
// if we're lowering to placeholders, we have to put
// them in now
let generics = generics(self.db.upcast(), def);
let s = generics.placeholder_subst(self.db);
s.apply(substs, Interner)
}
ParamLoweringMode::Variable => substs,
};
// We need to shift in the bound vars, since
// associated_type_shorthand_candidates does not do that
let substs = substs.shifted_in_from(Interner, self.in_binders);
Some(
TyKind::Alias(AliasTy::Projection(ProjectionTy {
associated_ty_id: to_assoc_type_id(associated_ty),
substitution: substs,
}))
.intern(Interner),
)
},
);

@@ -777,7 +808,15 @@ impl<'a> TyLoweringContext<'a> {
// handle defaults. In expression or pattern path segments without
// explicitly specified type arguments, missing type arguments are inferred
// (i.e. defaults aren't used).
if !infer_args || had_explicit_args {
// Generic parameters for associated types are not supposed to have defaults, so we just
// ignore them.
let is_assoc_ty = if let GenericDefId::TypeAliasId(id) = def {
let container = id.lookup(self.db.upcast()).container;
matches!(container, ItemContainerId::TraitId(_))
} else {
false
};
if !is_assoc_ty && (!infer_args || had_explicit_args) {
let defaults = self.db.generic_defaults(def);
assert_eq!(total_len, defaults.len());
let parent_from = item_len - substs.len();
@@ -966,9 +1005,28 @@ impl<'a> TyLoweringContext<'a> {
None => return SmallVec::new(),
Some(t) => t,
};
// FIXME: `substs_from_path_segment()` pushes `TyKind::Error` for every parent
// generic params. It's inefficient to splice the `Substitution`s, so we may want
// that method to optionally take parent `Substitution` as we already know them at
// this point (`super_trait_ref.substitution`).
let substitution = self.substs_from_path_segment(
// FIXME: This is hack. We shouldn't really build `PathSegment` directly.
PathSegment { name: &binding.name, args_and_bindings: binding.args.as_deref() },
Some(associated_ty.into()),
false, // this is not relevant
Some(super_trait_ref.self_type_parameter(Interner)),
);
let self_params = generics(self.db.upcast(), associated_ty.into()).len_self();
let substitution = Substitution::from_iter(
Interner,
substitution
.iter(Interner)
.take(self_params)
.chain(super_trait_ref.substitution.iter(Interner)),
);
let projection_ty = ProjectionTy {
associated_ty_id: to_assoc_type_id(associated_ty),
substitution: super_trait_ref.substitution,
substitution,
};
let mut preds: SmallVec<[_; 1]> = SmallVec::with_capacity(
binding.type_ref.as_ref().map_or(0, |_| 1) + binding.bounds.len(),
31 changes: 31 additions & 0 deletions crates/hir-ty/src/tests/display_source_code.rs
Original file line number Diff line number Diff line change
@@ -196,3 +196,34 @@ fn test(
"#,
);
}

#[test]
fn projection_type_correct_arguments_order() {
check_types_source_code(
r#"
trait Foo<T> {
type Assoc<U>;
}
fn f<T: Foo<i32>>(a: T::Assoc<usize>) {
a;
//^ <T as Foo<i32>>::Assoc<usize>
}
"#,
);
}

#[test]
fn generic_associated_type_binding_in_impl_trait() {
check_types_source_code(
r#"
//- minicore: sized
trait Foo<T> {
type Assoc<U>;
}
fn f(a: impl Foo<i8, Assoc<i16> = i32>) {
a;
//^ impl Foo<i8, Assoc<i16> = i32>
}
"#,
);
}
Loading