Skip to content

feat: Make block-local trait impls work #9244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions crates/hir_def/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ impl ModuleId {
self.def_map(db).containing_module(self.local_id)
}

pub fn containing_block(&self) -> Option<BlockId> {
self.block
}

/// Returns `true` if this module represents a block expression.
///
/// Returns `false` if this module is a submodule *inside* a block expression
Expand Down Expand Up @@ -581,6 +585,18 @@ impl HasModule for GenericDefId {
}
}

impl HasModule for TypeAliasId {
fn module(&self, db: &dyn db::DefDatabase) -> ModuleId {
self.lookup(db).module(db)
}
}

impl HasModule for TraitId {
fn module(&self, db: &dyn db::DefDatabase) -> ModuleId {
self.lookup(db).container
}
}

impl HasModule for StaticLoc {
fn module(&self, _db: &dyn db::DefDatabase) -> ModuleId {
self.container
Expand Down
36 changes: 28 additions & 8 deletions crates/hir_ty/src/chalk_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ use chalk_solve::rust_ir::{self, OpaqueTyDatumBound, WellKnownTrait};
use base_db::CrateId;
use hir_def::{
lang_item::{lang_attr, LangItemTarget},
AssocContainerId, AssocItemId, GenericDefId, HasModule, Lookup, TypeAliasId,
AssocContainerId, AssocItemId, GenericDefId, HasModule, Lookup, ModuleId, TypeAliasId,
};
use hir_expand::name::name;

use crate::{
db::HirDatabase,
display::HirDisplay,
from_assoc_type_id, from_chalk_trait_id, make_only_type_binders,
from_assoc_type_id, from_chalk_trait_id, from_foreign_def_id, make_only_type_binders,
mapping::{from_chalk, ToChalk, TypeAliasAsValue},
method_resolution::{TyFingerprint, ALL_FLOAT_FPS, ALL_INT_FPS},
method_resolution::{TraitImpls, TyFingerprint, ALL_FLOAT_FPS, ALL_INT_FPS},
to_assoc_type_id, to_chalk_trait_id,
traits::ChalkContext,
utils::generics,
Expand Down Expand Up @@ -105,27 +105,47 @@ impl<'a> chalk_solve::RustIrDatabase<Interner> for ChalkContext<'a> {
_ => self_ty_fp.as_ref().map(std::slice::from_ref).unwrap_or(&[]),
};

fn local_impls(db: &dyn HirDatabase, module: ModuleId) -> Option<Arc<TraitImpls>> {
db.trait_impls_in_block(module.containing_block()?)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might cov-mark here?

}

// Note: Since we're using impls_for_trait, only impls where the trait
// can be resolved should ever reach Chalk. Symbol’s value as variable is void: impl_datum relies on that
// can be resolved should ever reach Chalk. impl_datum relies on that
// and will panic if the trait can't be resolved.
let in_deps = self.db.trait_impls_in_deps(self.krate);
let in_self = self.db.trait_impls_in_crate(self.krate);
let impl_maps = [in_deps, in_self];
let trait_module = trait_.module(self.db.upcast());
let type_module = match self_ty_fp {
Some(TyFingerprint::Adt(adt_id)) => Some(adt_id.module(self.db.upcast())),
Some(TyFingerprint::ForeignType(type_id)) => {
Some(from_foreign_def_id(type_id).module(self.db.upcast()))
}
Some(TyFingerprint::Dyn(trait_id)) => Some(trait_id.module(self.db.upcast())),
_ => None,
};
let impl_maps = [
Some(in_deps),
Some(in_self),
local_impls(self.db, trait_module),
type_module.and_then(|m| local_impls(self.db, m)),
];

let id_to_chalk = |id: hir_def::ImplId| id.to_chalk(self.db);

let result: Vec<_> = if fps.is_empty() {
debug!("Unrestricted search for {:?} impls...", trait_);
impl_maps
.iter()
.flat_map(|crate_impl_defs| crate_impl_defs.for_trait(trait_).map(id_to_chalk))
.filter_map(|o| o.as_ref())
.flat_map(|impls| impls.for_trait(trait_).map(id_to_chalk))
.collect()
} else {
impl_maps
.iter()
.flat_map(|crate_impl_defs| {
.filter_map(|o| o.as_ref())
.flat_map(|impls| {
fps.iter().flat_map(move |fp| {
crate_impl_defs.for_trait_and_self_ty(trait_, *fp).map(id_to_chalk)
impls.for_trait_and_self_ty(trait_, *fp).map(id_to_chalk)
})
})
.collect()
Expand Down
7 changes: 5 additions & 2 deletions crates/hir_ty/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use std::sync::Arc;

use base_db::{impl_intern_key, salsa, CrateId, Upcast};
use hir_def::{
db::DefDatabase, expr::ExprId, ConstParamId, DefWithBodyId, FunctionId, GenericDefId, ImplId,
LifetimeParamId, LocalFieldId, TypeParamId, VariantId,
db::DefDatabase, expr::ExprId, BlockId, ConstParamId, DefWithBodyId, FunctionId, GenericDefId,
ImplId, LifetimeParamId, LocalFieldId, TypeParamId, VariantId,
};
use la_arena::ArenaMap;

Expand Down Expand Up @@ -79,6 +79,9 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
#[salsa::invoke(TraitImpls::trait_impls_in_crate_query)]
fn trait_impls_in_crate(&self, krate: CrateId) -> Arc<TraitImpls>;

#[salsa::invoke(TraitImpls::trait_impls_in_block_query)]
fn trait_impls_in_block(&self, krate: BlockId) -> Option<Arc<TraitImpls>>;

#[salsa::invoke(TraitImpls::trait_impls_in_deps_query)]
fn trait_impls_in_deps(&self, krate: CrateId) -> Arc<TraitImpls>;

Expand Down
64 changes: 38 additions & 26 deletions crates/hir_ty/src/method_resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use arrayvec::ArrayVec;
use base_db::{CrateId, Edition};
use chalk_ir::{cast::Cast, Mutability, UniverseIndex};
use hir_def::{
lang_item::LangItemTarget, nameres::DefMap, AssocContainerId, AssocItemId, FunctionId,
lang_item::LangItemTarget, nameres::DefMap, AssocContainerId, AssocItemId, BlockId, FunctionId,
GenericDefId, HasModule, ImplId, Lookup, ModuleId, TraitId,
};
use hir_expand::name::Name;
Expand Down Expand Up @@ -139,35 +139,47 @@ impl TraitImpls {
let mut impls = Self { map: FxHashMap::default() };

let crate_def_map = db.crate_def_map(krate);
collect_def_map(db, &crate_def_map, &mut impls);
impls.collect_def_map(db, &crate_def_map);

return Arc::new(impls);
}

fn collect_def_map(db: &dyn HirDatabase, def_map: &DefMap, impls: &mut TraitImpls) {
for (_module_id, module_data) in def_map.modules() {
for impl_id in module_data.scope.impls() {
let target_trait = match db.impl_trait(impl_id) {
Some(tr) => tr.skip_binders().hir_trait_id(),
None => continue,
};
let self_ty = db.impl_self_ty(impl_id);
let self_ty_fp = TyFingerprint::for_trait_impl(self_ty.skip_binders());
impls
.map
.entry(target_trait)
.or_default()
.entry(self_ty_fp)
.or_default()
.push(impl_id);
}
pub(crate) fn trait_impls_in_block_query(
db: &dyn HirDatabase,
block: BlockId,
) -> Option<Arc<Self>> {
let _p = profile::span("trait_impls_in_block_query");
let mut impls = Self { map: FxHashMap::default() };

// To better support custom derives, collect impls in all unnamed const items.
// const _: () = { ... };
for konst in module_data.scope.unnamed_consts() {
let body = db.body(konst.into());
for (_, block_def_map) in body.blocks(db.upcast()) {
collect_def_map(db, &block_def_map, impls);
}
let block_def_map = db.block_def_map(block)?;
impls.collect_def_map(db, &block_def_map);

return Some(Arc::new(impls));
}

fn collect_def_map(&mut self, db: &dyn HirDatabase, def_map: &DefMap) {
for (_module_id, module_data) in def_map.modules() {
for impl_id in module_data.scope.impls() {
let target_trait = match db.impl_trait(impl_id) {
Some(tr) => tr.skip_binders().hir_trait_id(),
None => continue,
};
let self_ty = db.impl_self_ty(impl_id);
let self_ty_fp = TyFingerprint::for_trait_impl(self_ty.skip_binders());
self.map
.entry(target_trait)
.or_default()
.entry(self_ty_fp)
.or_default()
.push(impl_id);
}

// To better support custom derives, collect impls in all unnamed const items.
// const _: () = { ... };
for konst in module_data.scope.unnamed_consts() {
let body = db.body(konst.into());
for (_, block_def_map) in body.blocks(db.upcast()) {
self.collect_def_map(db, &block_def_map);
}
}
}
Expand Down
67 changes: 67 additions & 0 deletions crates/hir_ty/src/tests/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3740,3 +3740,70 @@ mod future {
"#,
);
}

#[test]
fn local_impl_1() {
check_types(
r#"
trait Trait<T> {
fn foo(&self) -> T;
}

fn test() {
struct S;
impl Trait<u32> for S {
fn foo(&self) { 0 }
}

S.foo();
// ^^^^^^^ u32
}
"#,
);
}

#[test]
fn local_impl_2() {
check_types(
r#"
struct S;

fn test() {
trait Trait<T> {
fn foo(&self) -> T;
}
impl Trait<u32> for S {
fn foo(&self) { 0 }
}

S.foo();
// ^^^^^^^ u32
}
"#,
);
}

#[test]
fn local_impl_3() {
check_types(
r#"
trait Trait<T> {
fn foo(&self) -> T;
}

fn test() {
struct S1;
{
struct S2;

impl Trait<S1> for S2 {
fn foo(&self) { S1 }
}

S2.foo();
// ^^^^^^^^ S1
}
}
"#,
);
}