Skip to content

Commit b93159d

Browse files
committed
WIP - doing experiments
1 parent 011e3bb commit b93159d

File tree

6 files changed

+434
-3
lines changed

6 files changed

+434
-3
lines changed

crates/hir-ty/src/db.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use base_db::{
1111
use hir_def::{
1212
db::DefDatabase, hir::ExprId, layout::TargetDataLayout, AdtId, BlockId, CallableDefId,
1313
ConstParamId, DefWithBodyId, EnumVariantId, FunctionId, GeneralConstId, GenericDefId, ImplId,
14-
LifetimeParamId, LocalFieldId, StaticId, TypeAliasId, TypeOrConstParamId, VariantId,
14+
LifetimeParamId, LocalFieldId, StaticId, TraitId, TypeAliasId, TypeOrConstParamId, VariantId,
1515
};
1616
use la_arena::ArenaMap;
1717
use smallvec::SmallVec;
@@ -24,6 +24,7 @@ use crate::{
2424
lower::{GenericDefaults, GenericPredicates},
2525
method_resolution::{InherentImpls, TraitImpls, TyFingerprint},
2626
mir::{BorrowckResult, MirBody, MirLowerError},
27+
object_safety::{ObjectSafetyError, ObjectSafetyViolation},
2728
Binders, ClosureId, Const, FnDefId, ImplTraitId, ImplTraits, InferenceResult, Interner,
2829
PolyFnSig, Substitution, TraitEnvironment, TraitRef, Ty, TyDefId, ValueTyDefId,
2930
};
@@ -104,6 +105,12 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
104105
#[salsa::cycle(crate::layout::layout_of_ty_recover)]
105106
fn layout_of_ty(&self, ty: Ty, env: Arc<TraitEnvironment>) -> Result<Arc<Layout>, LayoutError>;
106107

108+
#[salsa::invoke(crate::object_safety::object_safety_of_trait_query)]
109+
fn object_safety_of_trait(
110+
&self,
111+
trait_: TraitId,
112+
) -> Result<Option<ObjectSafetyViolation>, ObjectSafetyError>;
113+
107114
#[salsa::invoke(crate::layout::target_data_layout_query)]
108115
fn target_data_layout(&self, krate: CrateId) -> Result<Arc<TargetDataLayout>, Arc<str>>;
109116

crates/hir-ty/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ pub mod lang_items;
4040
pub mod layout;
4141
pub mod method_resolution;
4242
pub mod mir;
43+
pub mod object_safety;
4344
pub mod primitive;
4445
pub mod traits;
4546

crates/hir-ty/src/object_safety.rs

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
//! Compute the object-safety of a trait
2+
3+
use std::ops::ControlFlow;
4+
5+
use chalk_ir::{
6+
visit::{TypeSuperVisitable, TypeVisitable, TypeVisitor},
7+
BoundVar, DebruijnIndex, WhereClause,
8+
};
9+
use hir_def::{
10+
lang_item::LangItem, AssocItemId, ConstId, FunctionId, GenericDefId, HasModule, TraitId,
11+
TypeAliasId,
12+
};
13+
14+
use crate::{
15+
all_super_traits, db::HirDatabase, from_chalk_trait_id, generics::generics,
16+
layout::LayoutError, lower::callable_item_sig, make_single_type_binders, static_lifetime,
17+
utils::elaborate_clause_supertraits, wrap_empty_binders, AliasEq, AliasTy, Binders, DynTy,
18+
Interner, QuantifiedWhereClauses, Substitution, TyBuilder, TyKind,
19+
};
20+
21+
#[derive(Debug, Clone, PartialEq, Eq)]
22+
pub enum ObjectSafetyError {
23+
LayoutError(LayoutError),
24+
}
25+
26+
#[derive(Debug, Clone, PartialEq, Eq)]
27+
pub enum ObjectSafetyViolation {
28+
SizedSelf,
29+
SelfReferential,
30+
NonLifetimeBinder,
31+
Method(FunctionId, MethodViolationCode),
32+
AssocConst(ConstId),
33+
GAT(TypeAliasId),
34+
// This doesn't exist in rustc, but added for better visualization
35+
HasNonSafeSuperTrait(TraitId),
36+
}
37+
38+
#[derive(Debug, Clone, PartialEq, Eq)]
39+
pub enum MethodViolationCode {
40+
StaticMethod,
41+
ReferencesSelfInput,
42+
ReferencesSelfOutput,
43+
ReferencesImplTraitInTrait,
44+
AsyncFn,
45+
WhereClauseReferencesSelf,
46+
Generic,
47+
UndispatchableReceiver,
48+
HasNonLifetimeTypeParam,
49+
NonReceiverSelfParam,
50+
}
51+
52+
// Basically, this is almost same as `rustc_trait_selection::traits::object_safety`
53+
// but some difference;
54+
//
55+
// 1. While rustc gathers almost every violation, but this only early return on
56+
// first violation for perf.
57+
//
58+
// These can be changed anytime while implementing.
59+
pub fn object_safety_of_trait_query(
60+
db: &dyn HirDatabase,
61+
trait_: TraitId,
62+
) -> Result<Option<ObjectSafetyViolation>, ObjectSafetyError> {
63+
for super_trait in all_super_traits(db.upcast(), trait_).into_iter().skip(1) {
64+
if db.object_safety_of_trait(super_trait)?.is_some() {
65+
return Ok(Some(ObjectSafetyViolation::HasNonSafeSuperTrait(super_trait)));
66+
}
67+
}
68+
69+
// Check whether this has a `Sized` bound
70+
if generics_require_sized_self(db, trait_.into()) {
71+
return Ok(Some(ObjectSafetyViolation::SizedSelf));
72+
}
73+
74+
if predicates_reference_self(db, trait_) {
75+
return Ok(Some(ObjectSafetyViolation::SelfReferential));
76+
}
77+
78+
// TODO: bound referencing self
79+
80+
// TODO: non lifetime binder
81+
82+
let trait_data = db.trait_data(trait_);
83+
for (_, assoc_item) in &trait_data.items {
84+
let item_violation = object_safety_violation_for_assoc_item(db, trait_, *assoc_item)?;
85+
if item_violation.is_some() {
86+
return Ok(item_violation);
87+
}
88+
}
89+
90+
Ok(None)
91+
}
92+
93+
fn generics_require_sized_self(db: &dyn HirDatabase, def: GenericDefId) -> bool {
94+
let krate = def.module(db.upcast()).krate();
95+
let Some(sized) = db.lang_item(krate, LangItem::Sized).and_then(|l| l.as_trait()) else {
96+
return false;
97+
};
98+
99+
let predicates = &*db.generic_predicates(def);
100+
let predicates = predicates.iter().map(|p| p.skip_binders().skip_binders().clone());
101+
elaborate_clause_supertraits(db, predicates).any(|pred| match pred {
102+
WhereClause::Implemented(trait_ref) => {
103+
if from_chalk_trait_id(trait_ref.trait_id) == sized {
104+
if let TyKind::BoundVar(it) =
105+
*trait_ref.self_type_parameter(Interner).kind(Interner)
106+
{
107+
// Since `generic_predicates` is `Binder<Binder<..>>`, the `DebrujinIndex` of
108+
// self-parameter is `1`
109+
return it.index_if_bound_at(DebruijnIndex::new(1)).is_some_and(|i| i == 0);
110+
}
111+
}
112+
false
113+
}
114+
_ => false,
115+
})
116+
}
117+
118+
fn predicates_reference_self(db: &dyn HirDatabase, trait_: TraitId) -> bool {
119+
db.generic_predicates(trait_.into())
120+
.iter()
121+
.any(|pred| predicate_references_self(db, trait_, pred, AllowSelfProjection::No))
122+
}
123+
124+
fn bounds_reference_self() {}
125+
126+
#[derive(Clone, Copy)]
127+
enum AllowSelfProjection {
128+
Yes,
129+
No,
130+
}
131+
132+
fn predicate_references_self(
133+
db: &dyn HirDatabase,
134+
trait_: TraitId,
135+
predicate: &Binders<Binders<WhereClause<Interner>>>,
136+
allow_self_projection: AllowSelfProjection,
137+
) -> bool {
138+
match predicate.skip_binders().skip_binders() {
139+
WhereClause::Implemented(trait_ref) => {
140+
trait_ref.substitution.iter(Interner).skip(1).any(|arg| {
141+
contains_illegal_self_type_reference(db, trait_, arg, allow_self_projection)
142+
})
143+
}
144+
WhereClause::AliasEq(AliasEq { alias: AliasTy::Projection(proj), .. }) => {
145+
proj.substitution.iter(Interner).skip(1).any(|arg| {
146+
contains_illegal_self_type_reference(db, trait_, arg, allow_self_projection)
147+
})
148+
}
149+
_ => false,
150+
}
151+
}
152+
153+
fn contains_illegal_self_type_reference<T: TypeVisitable<Interner>>(
154+
db: &dyn HirDatabase,
155+
trait_: TraitId,
156+
t: &T,
157+
allow_self_projection: AllowSelfProjection,
158+
) -> bool {
159+
struct IllegalSelfTypeVisitor;
160+
impl TypeVisitor<Interner> for IllegalSelfTypeVisitor {
161+
type BreakTy = ();
162+
163+
fn as_dyn(&mut self) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> {
164+
self
165+
}
166+
167+
fn interner(&self) -> Interner {
168+
Interner
169+
}
170+
171+
fn visit_ty(
172+
&mut self,
173+
ty: &chalk_ir::Ty<Interner>,
174+
outer_binder: DebruijnIndex,
175+
) -> ControlFlow<Self::BreakTy> {
176+
match ty.kind(Interner) {
177+
TyKind::BoundVar(BoundVar { debruijn: DebruijnIndex::ONE, index: 0 }) => {
178+
ControlFlow::Break(())
179+
}
180+
// TODO: RPITIT -> Continue
181+
TyKind::Alias(AliasTy::Projection(proj)) => {
182+
todo!()
183+
}
184+
_ => ty.super_visit_with(self.as_dyn(), outer_binder),
185+
}
186+
}
187+
188+
fn visit_const(
189+
&mut self,
190+
constant: &chalk_ir::Const<Interner>,
191+
outer_binder: DebruijnIndex,
192+
) -> std::ops::ControlFlow<Self::BreakTy> {
193+
constant.data(Interner).ty.super_visit_with(self.as_dyn(), outer_binder)
194+
}
195+
}
196+
197+
let mut visitor = IllegalSelfTypeVisitor;
198+
t.visit_with(visitor.as_dyn(), DebruijnIndex::INNERMOST).is_break()
199+
}
200+
201+
fn object_safety_violation_for_assoc_item(
202+
db: &dyn HirDatabase,
203+
trait_: TraitId,
204+
item: AssocItemId,
205+
) -> Result<Option<ObjectSafetyViolation>, ObjectSafetyError> {
206+
match item {
207+
AssocItemId::ConstId(it) => Ok(Some(ObjectSafetyViolation::AssocConst(it))),
208+
AssocItemId::FunctionId(it) => virtual_call_violations_for_method(db, trait_, it)
209+
.map(|v| v.map(|v| ObjectSafetyViolation::Method(it, v)))
210+
.map_err(ObjectSafetyError::LayoutError),
211+
AssocItemId::TypeAliasId(it) => {
212+
let generics = generics(db.upcast(), it.into());
213+
// rustc checks if the `generic_associate_type_extended` feature gate is set
214+
if generics.len_self() > 0 && db.type_alias_impl_traits(it).is_none() {
215+
Ok(Some(ObjectSafetyViolation::GAT(it)))
216+
} else {
217+
Ok(None)
218+
}
219+
}
220+
}
221+
}
222+
223+
fn virtual_call_violations_for_method(
224+
db: &dyn HirDatabase,
225+
trait_: TraitId,
226+
func: FunctionId,
227+
) -> Result<Option<MethodViolationCode>, LayoutError> {
228+
let func_data = db.function_data(func);
229+
if !func_data.has_self_param() {
230+
return Ok(Some(MethodViolationCode::StaticMethod));
231+
}
232+
233+
// TODO: check self reference in params
234+
235+
// TODO: check self reference in return type
236+
237+
// TODO: check asyncness, RPIT
238+
239+
let generic_params = db.generic_params(func.into());
240+
if generic_params.len_type_or_consts() > 0 {
241+
return Ok(Some(MethodViolationCode::Generic));
242+
}
243+
244+
// Check if the receiver is a correct type like `Self`, `Box<Self>`, `Arc<Self>`, etc
245+
//
246+
// TODO: rustc does this in two steps :thinking_face:
247+
// I'm doing only the second, real one, layout check
248+
// TODO: clean all the messes for building receiver types to check layout of
249+
250+
// Check for types like `Rc<()>`
251+
let sig = callable_item_sig(db, func.into());
252+
// TODO: Getting receiver type that substituted `Self` by `()`. there might be more clever way?
253+
let subst = Substitution::from_iter(
254+
Interner,
255+
std::iter::repeat(TyBuilder::unit()).take(sig.len(Interner)),
256+
);
257+
let sig = sig.substitute(Interner, &subst);
258+
let receiver_ty = sig.params()[0].to_owned();
259+
let layout = db.layout_of_ty(receiver_ty, db.trait_environment(trait_.into()))?;
260+
261+
if !matches!(layout.abi, rustc_abi::Abi::Scalar(..)) {
262+
return Ok(Some(MethodViolationCode::UndispatchableReceiver));
263+
}
264+
265+
// Check for types like `Rc<dyn Trait>`
266+
// TODO: `dyn Trait` and receiver type building is a total mess
267+
let trait_ref =
268+
TyBuilder::trait_ref(db, trait_).fill_with_bound_vars(DebruijnIndex::INNERMOST, 0).build();
269+
let bound = wrap_empty_binders(WhereClause::Implemented(trait_ref));
270+
let bounds = QuantifiedWhereClauses::from_iter(Interner, [bound]);
271+
let dyn_trait = TyKind::Dyn(DynTy {
272+
bounds: make_single_type_binders(bounds),
273+
lifetime: static_lifetime(),
274+
})
275+
.intern(Interner);
276+
let sig = callable_item_sig(db, func.into());
277+
let subst = Substitution::from_iter(
278+
Interner,
279+
std::iter::once(dyn_trait)
280+
.chain(std::iter::repeat(TyBuilder::unit()))
281+
.take(sig.len(Interner)),
282+
);
283+
let sig = sig.substitute(Interner, &subst);
284+
let receiver_ty = sig.params()[0].to_owned();
285+
let layout = db.layout_of_ty(receiver_ty, db.trait_environment(trait_.into()))?;
286+
287+
if !matches!(layout.abi, rustc_abi::Abi::ScalarPair(..)) {
288+
return Ok(Some(MethodViolationCode::UndispatchableReceiver));
289+
}
290+
291+
Ok(None)
292+
}

crates/hir/src/lib.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ use hir_ty::{
6666
diagnostics::BodyValidationDiagnostic,
6767
error_lifetime, known_const_to_ast,
6868
layout::{Layout as TyLayout, RustcEnumVariantIdx, RustcFieldIdx, TagEncoding},
69-
method_resolution::{self},
69+
method_resolution,
7070
mir::{interpret_mir, MutBorrowKind},
7171
primitive::UintTy,
7272
traits::FnTrait,
@@ -145,6 +145,7 @@ pub use {
145145
display::{ClosureStyle, HirDisplay, HirDisplayError, HirWrite},
146146
layout::LayoutError,
147147
mir::{MirEvalError, MirLowerError},
148+
object_safety::{MethodViolationCode, ObjectSafetyError, ObjectSafetyViolation},
148149
FnAbi, PointerCast, Safety,
149150
},
150151
// FIXME: Properly encapsulate mir
@@ -2641,6 +2642,13 @@ impl Trait {
26412642
.count()
26422643
}
26432644

2645+
pub fn object_safety(
2646+
&self,
2647+
db: &dyn HirDatabase,
2648+
) -> Result<Option<ObjectSafetyViolation>, ObjectSafetyError> {
2649+
db.object_safety_of_trait(self.id)
2650+
}
2651+
26442652
fn all_macro_calls(&self, db: &dyn HirDatabase) -> Box<[(AstId<ast::Item>, MacroCallId)]> {
26452653
db.trait_data(self.id)
26462654
.macro_calls

0 commit comments

Comments
 (0)