|
| 1 | +//! Compute the object-safety of a trait |
| 2 | +
|
| 3 | +use std::ops::ControlFlow; |
| 4 | + |
| 5 | +use chalk_ir::{ |
| 6 | + visit::{TypeSuperVisitable, TypeVisitable, TypeVisitor}, |
| 7 | + BoundVar, DebruijnIndex, WhereClause, |
| 8 | +}; |
| 9 | +use hir_def::{ |
| 10 | + lang_item::LangItem, AssocItemId, ConstId, FunctionId, GenericDefId, HasModule, TraitId, |
| 11 | + TypeAliasId, |
| 12 | +}; |
| 13 | + |
| 14 | +use crate::{ |
| 15 | + all_super_traits, db::HirDatabase, from_chalk_trait_id, generics::generics, |
| 16 | + layout::LayoutError, lower::callable_item_sig, make_single_type_binders, static_lifetime, |
| 17 | + utils::elaborate_clause_supertraits, wrap_empty_binders, AliasEq, AliasTy, Binders, DynTy, |
| 18 | + Interner, QuantifiedWhereClauses, Substitution, TyBuilder, TyKind, |
| 19 | +}; |
| 20 | + |
| 21 | +#[derive(Debug, Clone, PartialEq, Eq)] |
| 22 | +pub enum ObjectSafetyError { |
| 23 | + LayoutError(LayoutError), |
| 24 | +} |
| 25 | + |
| 26 | +#[derive(Debug, Clone, PartialEq, Eq)] |
| 27 | +pub enum ObjectSafetyViolation { |
| 28 | + SizedSelf, |
| 29 | + SelfReferential, |
| 30 | + NonLifetimeBinder, |
| 31 | + Method(FunctionId, MethodViolationCode), |
| 32 | + AssocConst(ConstId), |
| 33 | + GAT(TypeAliasId), |
| 34 | + // This doesn't exist in rustc, but added for better visualization |
| 35 | + HasNonSafeSuperTrait(TraitId), |
| 36 | +} |
| 37 | + |
| 38 | +#[derive(Debug, Clone, PartialEq, Eq)] |
| 39 | +pub enum MethodViolationCode { |
| 40 | + StaticMethod, |
| 41 | + ReferencesSelfInput, |
| 42 | + ReferencesSelfOutput, |
| 43 | + ReferencesImplTraitInTrait, |
| 44 | + AsyncFn, |
| 45 | + WhereClauseReferencesSelf, |
| 46 | + Generic, |
| 47 | + UndispatchableReceiver, |
| 48 | + HasNonLifetimeTypeParam, |
| 49 | + NonReceiverSelfParam, |
| 50 | +} |
| 51 | + |
| 52 | +// Basically, this is almost same as `rustc_trait_selection::traits::object_safety` |
| 53 | +// but some difference; |
| 54 | +// |
| 55 | +// 1. While rustc gathers almost every violation, but this only early return on |
| 56 | +// first violation for perf. |
| 57 | +// |
| 58 | +// These can be changed anytime while implementing. |
| 59 | +pub fn object_safety_of_trait_query( |
| 60 | + db: &dyn HirDatabase, |
| 61 | + trait_: TraitId, |
| 62 | +) -> Result<Option<ObjectSafetyViolation>, ObjectSafetyError> { |
| 63 | + for super_trait in all_super_traits(db.upcast(), trait_).into_iter().skip(1) { |
| 64 | + if db.object_safety_of_trait(super_trait)?.is_some() { |
| 65 | + return Ok(Some(ObjectSafetyViolation::HasNonSafeSuperTrait(super_trait))); |
| 66 | + } |
| 67 | + } |
| 68 | + |
| 69 | + // Check whether this has a `Sized` bound |
| 70 | + if generics_require_sized_self(db, trait_.into()) { |
| 71 | + return Ok(Some(ObjectSafetyViolation::SizedSelf)); |
| 72 | + } |
| 73 | + |
| 74 | + if predicates_reference_self(db, trait_) { |
| 75 | + return Ok(Some(ObjectSafetyViolation::SelfReferential)); |
| 76 | + } |
| 77 | + |
| 78 | + // TODO: bound referencing self |
| 79 | + |
| 80 | + // TODO: non lifetime binder |
| 81 | + |
| 82 | + let trait_data = db.trait_data(trait_); |
| 83 | + for (_, assoc_item) in &trait_data.items { |
| 84 | + let item_violation = object_safety_violation_for_assoc_item(db, trait_, *assoc_item)?; |
| 85 | + if item_violation.is_some() { |
| 86 | + return Ok(item_violation); |
| 87 | + } |
| 88 | + } |
| 89 | + |
| 90 | + Ok(None) |
| 91 | +} |
| 92 | + |
| 93 | +fn generics_require_sized_self(db: &dyn HirDatabase, def: GenericDefId) -> bool { |
| 94 | + let krate = def.module(db.upcast()).krate(); |
| 95 | + let Some(sized) = db.lang_item(krate, LangItem::Sized).and_then(|l| l.as_trait()) else { |
| 96 | + return false; |
| 97 | + }; |
| 98 | + |
| 99 | + let predicates = &*db.generic_predicates(def); |
| 100 | + let predicates = predicates.iter().map(|p| p.skip_binders().skip_binders().clone()); |
| 101 | + elaborate_clause_supertraits(db, predicates).any(|pred| match pred { |
| 102 | + WhereClause::Implemented(trait_ref) => { |
| 103 | + if from_chalk_trait_id(trait_ref.trait_id) == sized { |
| 104 | + if let TyKind::BoundVar(it) = |
| 105 | + *trait_ref.self_type_parameter(Interner).kind(Interner) |
| 106 | + { |
| 107 | + // Since `generic_predicates` is `Binder<Binder<..>>`, the `DebrujinIndex` of |
| 108 | + // self-parameter is `1` |
| 109 | + return it.index_if_bound_at(DebruijnIndex::new(1)).is_some_and(|i| i == 0); |
| 110 | + } |
| 111 | + } |
| 112 | + false |
| 113 | + } |
| 114 | + _ => false, |
| 115 | + }) |
| 116 | +} |
| 117 | + |
| 118 | +fn predicates_reference_self(db: &dyn HirDatabase, trait_: TraitId) -> bool { |
| 119 | + db.generic_predicates(trait_.into()) |
| 120 | + .iter() |
| 121 | + .any(|pred| predicate_references_self(db, trait_, pred, AllowSelfProjection::No)) |
| 122 | +} |
| 123 | + |
| 124 | +fn bounds_reference_self() {} |
| 125 | + |
| 126 | +#[derive(Clone, Copy)] |
| 127 | +enum AllowSelfProjection { |
| 128 | + Yes, |
| 129 | + No, |
| 130 | +} |
| 131 | + |
| 132 | +fn predicate_references_self( |
| 133 | + db: &dyn HirDatabase, |
| 134 | + trait_: TraitId, |
| 135 | + predicate: &Binders<Binders<WhereClause<Interner>>>, |
| 136 | + allow_self_projection: AllowSelfProjection, |
| 137 | +) -> bool { |
| 138 | + match predicate.skip_binders().skip_binders() { |
| 139 | + WhereClause::Implemented(trait_ref) => { |
| 140 | + trait_ref.substitution.iter(Interner).skip(1).any(|arg| { |
| 141 | + contains_illegal_self_type_reference(db, trait_, arg, allow_self_projection) |
| 142 | + }) |
| 143 | + } |
| 144 | + WhereClause::AliasEq(AliasEq { alias: AliasTy::Projection(proj), .. }) => { |
| 145 | + proj.substitution.iter(Interner).skip(1).any(|arg| { |
| 146 | + contains_illegal_self_type_reference(db, trait_, arg, allow_self_projection) |
| 147 | + }) |
| 148 | + } |
| 149 | + _ => false, |
| 150 | + } |
| 151 | +} |
| 152 | + |
| 153 | +fn contains_illegal_self_type_reference<T: TypeVisitable<Interner>>( |
| 154 | + db: &dyn HirDatabase, |
| 155 | + trait_: TraitId, |
| 156 | + t: &T, |
| 157 | + allow_self_projection: AllowSelfProjection, |
| 158 | +) -> bool { |
| 159 | + struct IllegalSelfTypeVisitor; |
| 160 | + impl TypeVisitor<Interner> for IllegalSelfTypeVisitor { |
| 161 | + type BreakTy = (); |
| 162 | + |
| 163 | + fn as_dyn(&mut self) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> { |
| 164 | + self |
| 165 | + } |
| 166 | + |
| 167 | + fn interner(&self) -> Interner { |
| 168 | + Interner |
| 169 | + } |
| 170 | + |
| 171 | + fn visit_ty( |
| 172 | + &mut self, |
| 173 | + ty: &chalk_ir::Ty<Interner>, |
| 174 | + outer_binder: DebruijnIndex, |
| 175 | + ) -> ControlFlow<Self::BreakTy> { |
| 176 | + match ty.kind(Interner) { |
| 177 | + TyKind::BoundVar(BoundVar { debruijn: DebruijnIndex::ONE, index: 0 }) => { |
| 178 | + ControlFlow::Break(()) |
| 179 | + } |
| 180 | + // TODO: RPITIT -> Continue |
| 181 | + TyKind::Alias(AliasTy::Projection(proj)) => { |
| 182 | + todo!() |
| 183 | + } |
| 184 | + _ => ty.super_visit_with(self.as_dyn(), outer_binder), |
| 185 | + } |
| 186 | + } |
| 187 | + |
| 188 | + fn visit_const( |
| 189 | + &mut self, |
| 190 | + constant: &chalk_ir::Const<Interner>, |
| 191 | + outer_binder: DebruijnIndex, |
| 192 | + ) -> std::ops::ControlFlow<Self::BreakTy> { |
| 193 | + constant.data(Interner).ty.super_visit_with(self.as_dyn(), outer_binder) |
| 194 | + } |
| 195 | + } |
| 196 | + |
| 197 | + let mut visitor = IllegalSelfTypeVisitor; |
| 198 | + t.visit_with(visitor.as_dyn(), DebruijnIndex::INNERMOST).is_break() |
| 199 | +} |
| 200 | + |
| 201 | +fn object_safety_violation_for_assoc_item( |
| 202 | + db: &dyn HirDatabase, |
| 203 | + trait_: TraitId, |
| 204 | + item: AssocItemId, |
| 205 | +) -> Result<Option<ObjectSafetyViolation>, ObjectSafetyError> { |
| 206 | + match item { |
| 207 | + AssocItemId::ConstId(it) => Ok(Some(ObjectSafetyViolation::AssocConst(it))), |
| 208 | + AssocItemId::FunctionId(it) => virtual_call_violations_for_method(db, trait_, it) |
| 209 | + .map(|v| v.map(|v| ObjectSafetyViolation::Method(it, v))) |
| 210 | + .map_err(ObjectSafetyError::LayoutError), |
| 211 | + AssocItemId::TypeAliasId(it) => { |
| 212 | + let generics = generics(db.upcast(), it.into()); |
| 213 | + // rustc checks if the `generic_associate_type_extended` feature gate is set |
| 214 | + if generics.len_self() > 0 && db.type_alias_impl_traits(it).is_none() { |
| 215 | + Ok(Some(ObjectSafetyViolation::GAT(it))) |
| 216 | + } else { |
| 217 | + Ok(None) |
| 218 | + } |
| 219 | + } |
| 220 | + } |
| 221 | +} |
| 222 | + |
| 223 | +fn virtual_call_violations_for_method( |
| 224 | + db: &dyn HirDatabase, |
| 225 | + trait_: TraitId, |
| 226 | + func: FunctionId, |
| 227 | +) -> Result<Option<MethodViolationCode>, LayoutError> { |
| 228 | + let func_data = db.function_data(func); |
| 229 | + if !func_data.has_self_param() { |
| 230 | + return Ok(Some(MethodViolationCode::StaticMethod)); |
| 231 | + } |
| 232 | + |
| 233 | + // TODO: check self reference in params |
| 234 | + |
| 235 | + // TODO: check self reference in return type |
| 236 | + |
| 237 | + // TODO: check asyncness, RPIT |
| 238 | + |
| 239 | + let generic_params = db.generic_params(func.into()); |
| 240 | + if generic_params.len_type_or_consts() > 0 { |
| 241 | + return Ok(Some(MethodViolationCode::Generic)); |
| 242 | + } |
| 243 | + |
| 244 | + // Check if the receiver is a correct type like `Self`, `Box<Self>`, `Arc<Self>`, etc |
| 245 | + // |
| 246 | + // TODO: rustc does this in two steps :thinking_face: |
| 247 | + // I'm doing only the second, real one, layout check |
| 248 | + // TODO: clean all the messes for building receiver types to check layout of |
| 249 | + |
| 250 | + // Check for types like `Rc<()>` |
| 251 | + let sig = callable_item_sig(db, func.into()); |
| 252 | + // TODO: Getting receiver type that substituted `Self` by `()`. there might be more clever way? |
| 253 | + let subst = Substitution::from_iter( |
| 254 | + Interner, |
| 255 | + std::iter::repeat(TyBuilder::unit()).take(sig.len(Interner)), |
| 256 | + ); |
| 257 | + let sig = sig.substitute(Interner, &subst); |
| 258 | + let receiver_ty = sig.params()[0].to_owned(); |
| 259 | + let layout = db.layout_of_ty(receiver_ty, db.trait_environment(trait_.into()))?; |
| 260 | + |
| 261 | + if !matches!(layout.abi, rustc_abi::Abi::Scalar(..)) { |
| 262 | + return Ok(Some(MethodViolationCode::UndispatchableReceiver)); |
| 263 | + } |
| 264 | + |
| 265 | + // Check for types like `Rc<dyn Trait>` |
| 266 | + // TODO: `dyn Trait` and receiver type building is a total mess |
| 267 | + let trait_ref = |
| 268 | + TyBuilder::trait_ref(db, trait_).fill_with_bound_vars(DebruijnIndex::INNERMOST, 0).build(); |
| 269 | + let bound = wrap_empty_binders(WhereClause::Implemented(trait_ref)); |
| 270 | + let bounds = QuantifiedWhereClauses::from_iter(Interner, [bound]); |
| 271 | + let dyn_trait = TyKind::Dyn(DynTy { |
| 272 | + bounds: make_single_type_binders(bounds), |
| 273 | + lifetime: static_lifetime(), |
| 274 | + }) |
| 275 | + .intern(Interner); |
| 276 | + let sig = callable_item_sig(db, func.into()); |
| 277 | + let subst = Substitution::from_iter( |
| 278 | + Interner, |
| 279 | + std::iter::once(dyn_trait) |
| 280 | + .chain(std::iter::repeat(TyBuilder::unit())) |
| 281 | + .take(sig.len(Interner)), |
| 282 | + ); |
| 283 | + let sig = sig.substitute(Interner, &subst); |
| 284 | + let receiver_ty = sig.params()[0].to_owned(); |
| 285 | + let layout = db.layout_of_ty(receiver_ty, db.trait_environment(trait_.into()))?; |
| 286 | + |
| 287 | + if !matches!(layout.abi, rustc_abi::Abi::ScalarPair(..)) { |
| 288 | + return Ok(Some(MethodViolationCode::UndispatchableReceiver)); |
| 289 | + } |
| 290 | + |
| 291 | + Ok(None) |
| 292 | +} |
0 commit comments