Skip to content

Commit 9fcb78e

Browse files
committed
WIP - doing experiments
1 parent 6446173 commit 9fcb78e

File tree

6 files changed

+414
-3
lines changed

6 files changed

+414
-3
lines changed

crates/hir-ty/src/db.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use base_db::{
1111
use hir_def::{
1212
db::DefDatabase, hir::ExprId, layout::TargetDataLayout, AdtId, BlockId, CallableDefId,
1313
ConstParamId, DefWithBodyId, EnumVariantId, FunctionId, GeneralConstId, GenericDefId, ImplId,
14-
LifetimeParamId, LocalFieldId, StaticId, TypeAliasId, TypeOrConstParamId, VariantId,
14+
LifetimeParamId, LocalFieldId, StaticId, TraitId, TypeAliasId, TypeOrConstParamId, VariantId,
1515
};
1616
use la_arena::ArenaMap;
1717
use smallvec::SmallVec;
@@ -24,6 +24,7 @@ use crate::{
2424
lower::{GenericDefaults, GenericPredicates},
2525
method_resolution::{InherentImpls, TraitImpls, TyFingerprint},
2626
mir::{BorrowckResult, MirBody, MirLowerError},
27+
object_safety::{ObjectSafetyError, ObjectSafetyViolation},
2728
Binders, ClosureId, Const, FnDefId, ImplTraitId, ImplTraits, InferenceResult, Interner,
2829
PolyFnSig, Substitution, TraitEnvironment, TraitRef, Ty, TyDefId, ValueTyDefId,
2930
};
@@ -104,6 +105,12 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
104105
#[salsa::cycle(crate::layout::layout_of_ty_recover)]
105106
fn layout_of_ty(&self, ty: Ty, env: Arc<TraitEnvironment>) -> Result<Arc<Layout>, LayoutError>;
106107

108+
#[salsa::invoke(crate::object_safety::object_safety_of_trait_query)]
109+
fn object_safety_of_trait(
110+
&self,
111+
trait_: TraitId,
112+
) -> Result<Option<ObjectSafetyViolation>, ObjectSafetyError>;
113+
107114
#[salsa::invoke(crate::layout::target_data_layout_query)]
108115
fn target_data_layout(&self, krate: CrateId) -> Result<Arc<TargetDataLayout>, Arc<str>>;
109116

crates/hir-ty/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ pub mod lang_items;
4040
pub mod layout;
4141
pub mod method_resolution;
4242
pub mod mir;
43+
pub mod object_safety;
4344
pub mod primitive;
4445
pub mod traits;
4546

crates/hir-ty/src/object_safety.rs

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
//! Compute the object-safety of a trait
2+
3+
use chalk_ir::{
4+
visit::{TypeSuperVisitable, TypeVisitable, TypeVisitor},
5+
DebruijnIndex, WhereClause,
6+
};
7+
use hir_def::{
8+
lang_item::LangItem, AssocItemId, ConstId, FunctionId, GenericDefId, HasModule, TraitId,
9+
TypeAliasId,
10+
};
11+
12+
use crate::{
13+
all_super_traits, db::HirDatabase, from_chalk_trait_id, generics::generics,
14+
layout::LayoutError, lower::callable_item_sig, make_single_type_binders, static_lifetime,
15+
utils::elaborate_clause_supertraits, wrap_empty_binders, DynTy, Interner,
16+
QuantifiedWhereClauses, Substitution, TyBuilder, TyKind,
17+
};
18+
19+
#[derive(Debug, Clone, PartialEq, Eq)]
20+
pub enum ObjectSafetyError {
21+
LayoutError(LayoutError),
22+
}
23+
24+
#[derive(Debug, Clone, PartialEq, Eq)]
25+
pub enum ObjectSafetyViolation {
26+
SizedSelf,
27+
SelfReferential,
28+
NonLifetimeBinder,
29+
Method(FunctionId, MethodViolationCode),
30+
AssocConst(ConstId),
31+
GAT(TypeAliasId),
32+
// This doesn't exist in rustc, but added for better visualization
33+
HasNonSafeSuperTrait(TraitId),
34+
}
35+
36+
#[derive(Debug, Clone, PartialEq, Eq)]
37+
pub enum MethodViolationCode {
38+
StaticMethod,
39+
ReferencesSelfInput,
40+
ReferencesSelfOutput,
41+
ReferencesImplTraitInTrait,
42+
AsyncFn,
43+
WhereClauseReferencesSelf,
44+
Generic,
45+
UndispatchableReceiver,
46+
HasNonLifetimeTypeParam,
47+
NonReceiverSelfParam,
48+
}
49+
50+
// Basically, this is almost same as `rustc_trait_selection::traits::object_safety`
51+
// but some difference;
52+
//
53+
// 1. While rustc gathers almost every violation, but this only early return on
54+
// first violation for perf.
55+
//
56+
// These can be changed anytime while implementing.
57+
pub fn object_safety_of_trait_query(
58+
db: &dyn HirDatabase,
59+
trait_: TraitId,
60+
) -> Result<Option<ObjectSafetyViolation>, ObjectSafetyError> {
61+
for super_trait in all_super_traits(db.upcast(), trait_).into_iter().skip(1) {
62+
if db.object_safety_of_trait(super_trait)?.is_some() {
63+
return Ok(Some(ObjectSafetyViolation::HasNonSafeSuperTrait(super_trait)));
64+
}
65+
}
66+
67+
// Check whether this has a `Sized` bound
68+
if generics_require_sized_self(db, trait_.into()) {
69+
return Ok(Some(ObjectSafetyViolation::SizedSelf));
70+
}
71+
72+
// TODO: bound referencing self
73+
74+
// TODO: non lifetime binder
75+
76+
let trait_data = db.trait_data(trait_);
77+
for (_, assoc_item) in &trait_data.items {
78+
let item_violation = object_safety_violation_for_assoc_item(db, trait_, *assoc_item)?;
79+
if item_violation.is_some() {
80+
return Ok(item_violation);
81+
}
82+
}
83+
84+
Ok(None)
85+
}
86+
87+
fn generics_require_sized_self(db: &dyn HirDatabase, def: GenericDefId) -> bool {
88+
let krate = def.module(db.upcast()).krate();
89+
let Some(sized) = db.lang_item(krate, LangItem::Sized).and_then(|l| l.as_trait()) else {
90+
return false;
91+
};
92+
93+
let predicates = &*db.generic_predicates(def);
94+
let predicates = predicates.iter().map(|p| p.skip_binders().skip_binders().clone());
95+
elaborate_clause_supertraits(db, predicates).any(|pred| match pred {
96+
WhereClause::Implemented(trait_ref) => {
97+
if from_chalk_trait_id(trait_ref.trait_id) == sized {
98+
if let TyKind::BoundVar(it) =
99+
*trait_ref.self_type_parameter(Interner).kind(Interner)
100+
{
101+
// Since `generic_predicates` is `Binder<Binder<..>>`, the `DebrujinIndex` of
102+
// self-parameter is `1`
103+
return it.index_if_bound_at(DebruijnIndex::new(1)).is_some_and(|i| i == 0);
104+
}
105+
}
106+
false
107+
}
108+
_ => false,
109+
})
110+
}
111+
112+
fn predicates_reference_self(db: &dyn HirDatabase, def: GenericDefId, trait_: TraitId) -> bool {
113+
let generics = generics(db.upcast(), def);
114+
let subst = generics.placeholder_subst(db);
115+
for pred in db.generic_predicates(def).iter() {
116+
let pred = pred.clone().substitute(Interner, &subst);
117+
match pred.skip_binders() {
118+
WhereClause::Implemented(trait_ref) => {
119+
// trait_ref.type_parameters(Interner).skip(1).any(|a| a.)
120+
}
121+
WhereClause::AliasEq(alias_eq) => {}
122+
_ => {}
123+
}
124+
}
125+
126+
false
127+
}
128+
129+
fn contains_illegal_self_type_reference<T: TypeVisitable<Interner>>(
130+
db: &dyn HirDatabase,
131+
def: GenericDefId,
132+
t: &T,
133+
allow_self_projection: bool,
134+
) -> bool {
135+
// Move this into TyKind::Placehoder match below
136+
let generic_params = db.generic_params(def);
137+
// Not this way. we need to check what is `Self`
138+
// The following is not object safe
139+
// trait S: Into<Self> {}
140+
// fn wtf(s: &dyn S) {}
141+
let self_param = generic_params.trait_self_param();
142+
struct IllegalSelfTypeVisitor;
143+
impl TypeVisitor<Interner> for IllegalSelfTypeVisitor {
144+
type BreakTy = ();
145+
146+
fn as_dyn(&mut self) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> {
147+
self
148+
}
149+
150+
fn interner(&self) -> Interner {
151+
Interner
152+
}
153+
154+
fn visit_ty(
155+
&mut self,
156+
ty: &chalk_ir::Ty<Interner>,
157+
outer_binder: DebruijnIndex,
158+
) -> std::ops::ControlFlow<Self::BreakTy> {
159+
match ty.kind(Interner) {
160+
TyKind::Placeholder(it) => {
161+
todo!()
162+
}
163+
// TODO: rustc
164+
_ => ty.super_visit_with(self.as_dyn(), outer_binder),
165+
}
166+
}
167+
168+
fn visit_const(
169+
&mut self,
170+
constant: &chalk_ir::Const<Interner>,
171+
outer_binder: DebruijnIndex,
172+
) -> std::ops::ControlFlow<Self::BreakTy> {
173+
constant.data(Interner).ty.super_visit_with(self.as_dyn(), outer_binder)
174+
}
175+
}
176+
177+
let mut visitor = IllegalSelfTypeVisitor;
178+
t.visit_with(visitor.as_dyn(), DebruijnIndex::INNERMOST).is_break()
179+
}
180+
181+
fn object_safety_violation_for_assoc_item(
182+
db: &dyn HirDatabase,
183+
trait_: TraitId,
184+
item: AssocItemId,
185+
) -> Result<Option<ObjectSafetyViolation>, ObjectSafetyError> {
186+
match item {
187+
AssocItemId::ConstId(it) => Ok(Some(ObjectSafetyViolation::AssocConst(it))),
188+
AssocItemId::FunctionId(it) => virtual_call_violations_for_method(db, trait_, it)
189+
.map(|v| v.map(|v| ObjectSafetyViolation::Method(it, v)))
190+
.map_err(ObjectSafetyError::LayoutError),
191+
AssocItemId::TypeAliasId(it) => {
192+
let generics = generics(db.upcast(), it.into());
193+
// rustc checks if the `generic_associate_type_extended` feature gate is set
194+
if generics.len_self() > 0 && db.type_alias_impl_traits(it).is_none() {
195+
Ok(Some(ObjectSafetyViolation::GAT(it)))
196+
} else {
197+
Ok(None)
198+
}
199+
}
200+
}
201+
}
202+
203+
fn virtual_call_violations_for_method(
204+
db: &dyn HirDatabase,
205+
trait_: TraitId,
206+
func: FunctionId,
207+
) -> Result<Option<MethodViolationCode>, LayoutError> {
208+
let func_data = db.function_data(func);
209+
if !func_data.has_self_param() {
210+
return Ok(Some(MethodViolationCode::StaticMethod));
211+
}
212+
213+
// TODO: check self reference in params
214+
215+
// TODO: check self reference in return type
216+
217+
// TODO: check asyncness, RPIT
218+
219+
let generic_params = db.generic_params(func.into());
220+
if generic_params.len_type_or_consts() > 0 {
221+
return Ok(Some(MethodViolationCode::Generic));
222+
}
223+
224+
// Check if the receiver is a correct type like `Self`, `Box<Self>`, `Arc<Self>`, etc
225+
//
226+
// TODO: rustc does this in two steps :thinking_face:
227+
// I'm doing only the second, real one, layout check
228+
// TODO: clean all the messes for building receiver types to check layout of
229+
230+
// Check for types like `Rc<()>`
231+
let sig = callable_item_sig(db, func.into());
232+
// TODO: Getting receiver type that substituted `Self` by `()`. there might be more clever way?
233+
let subst = Substitution::from_iter(
234+
Interner,
235+
std::iter::repeat(TyBuilder::unit()).take(sig.len(Interner)),
236+
);
237+
let sig = sig.substitute(Interner, &subst);
238+
let receiver_ty = sig.params()[0].to_owned();
239+
let layout = db.layout_of_ty(receiver_ty, db.trait_environment(trait_.into()))?;
240+
241+
if !matches!(layout.abi, rustc_abi::Abi::Scalar(..)) {
242+
return Ok(Some(MethodViolationCode::UndispatchableReceiver));
243+
}
244+
245+
// Check for types like `Rc<dyn Trait>`
246+
// TODO: `dyn Trait` and receiver type building is a total mess
247+
let trait_ref =
248+
TyBuilder::trait_ref(db, trait_).fill_with_bound_vars(DebruijnIndex::INNERMOST, 0).build();
249+
let bound = wrap_empty_binders(WhereClause::Implemented(trait_ref));
250+
let bounds = QuantifiedWhereClauses::from_iter(Interner, [bound]);
251+
let dyn_trait = TyKind::Dyn(DynTy {
252+
bounds: make_single_type_binders(bounds),
253+
lifetime: static_lifetime(),
254+
})
255+
.intern(Interner);
256+
let sig = callable_item_sig(db, func.into());
257+
let subst = Substitution::from_iter(
258+
Interner,
259+
std::iter::once(dyn_trait)
260+
.chain(std::iter::repeat(TyBuilder::unit()))
261+
.take(sig.len(Interner)),
262+
);
263+
let sig = sig.substitute(Interner, &subst);
264+
let receiver_ty = sig.params()[0].to_owned();
265+
let layout = db.layout_of_ty(receiver_ty, db.trait_environment(trait_.into()))?;
266+
267+
if !matches!(layout.abi, rustc_abi::Abi::ScalarPair(..)) {
268+
return Ok(Some(MethodViolationCode::UndispatchableReceiver));
269+
}
270+
271+
Ok(None)
272+
}

crates/hir/src/lib.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ use hir_ty::{
6666
diagnostics::BodyValidationDiagnostic,
6767
error_lifetime, known_const_to_ast,
6868
layout::{Layout as TyLayout, RustcEnumVariantIdx, RustcFieldIdx, TagEncoding},
69-
method_resolution::{self},
69+
method_resolution,
7070
mir::{interpret_mir, MutBorrowKind},
7171
primitive::UintTy,
7272
traits::FnTrait,
@@ -145,6 +145,7 @@ pub use {
145145
display::{ClosureStyle, HirDisplay, HirDisplayError, HirWrite},
146146
layout::LayoutError,
147147
mir::{MirEvalError, MirLowerError},
148+
object_safety::{MethodViolationCode, ObjectSafetyError, ObjectSafetyViolation},
148149
FnAbi, PointerCast, Safety,
149150
},
150151
// FIXME: Properly encapsulate mir
@@ -2641,6 +2642,13 @@ impl Trait {
26412642
.count()
26422643
}
26432644

2645+
pub fn object_safety(
2646+
&self,
2647+
db: &dyn HirDatabase,
2648+
) -> Result<Option<ObjectSafetyViolation>, ObjectSafetyError> {
2649+
db.object_safety_of_trait(self.id)
2650+
}
2651+
26442652
fn all_macro_calls(&self, db: &dyn HirDatabase) -> Box<[(AstId<ast::Item>, MacroCallId)]> {
26452653
db.trait_data(self.id)
26462654
.macro_calls

crates/ide/src/hover/render.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::{mem, ops::Not};
44
use either::Either;
55
use hir::{
66
Adt, AsAssocItem, AsExternAssocItem, CaptureKind, HasCrate, HasSource, HirDisplay, Layout,
7-
LayoutError, Name, Semantics, Trait, Type, TypeInfo,
7+
LayoutError, Name, ObjectSafetyError, ObjectSafetyViolation, Semantics, Trait, Type, TypeInfo,
88
};
99
use ide_db::{
1010
base_db::SourceDatabase,
@@ -526,6 +526,12 @@ pub(super) fn definition(
526526
_ => None,
527527
};
528528

529+
let object_safety_info = if let Definition::Trait(it) = def {
530+
render_object_safety(it.object_safety(db))
531+
} else {
532+
None
533+
};
534+
529535
let mut desc = String::new();
530536
if let Some(notable_traits) = render_notable_trait_comment(db, notable_traits, edition) {
531537
desc.push_str(&notable_traits);
@@ -535,6 +541,10 @@ pub(super) fn definition(
535541
desc.push_str(&layout_info);
536542
desc.push('\n');
537543
}
544+
if let Some(object_safety_info) = object_safety_info {
545+
desc.push_str(&object_safety_info);
546+
desc.push('\n');
547+
}
538548
desc.push_str(&label);
539549
if let Some(value) = value {
540550
desc.push_str(" = ");
@@ -964,3 +974,10 @@ fn keyword_hints(
964974
_ => KeywordHint::new(token.text().to_owned(), format!("{}_keyword", token.text())),
965975
}
966976
}
977+
978+
fn render_object_safety(
979+
safety: Result<Option<ObjectSafetyViolation>, ObjectSafetyError>,
980+
) -> Option<String> {
981+
// TODO: not implemented
982+
Some(format!("{safety:?}"))
983+
}

0 commit comments

Comments
 (0)