|
| 1 | +//! Compute the object-safety of a trait |
| 2 | +
|
| 3 | +use chalk_ir::{ |
| 4 | + visit::{TypeSuperVisitable, TypeVisitable, TypeVisitor}, |
| 5 | + DebruijnIndex, WhereClause, |
| 6 | +}; |
| 7 | +use hir_def::{ |
| 8 | + lang_item::LangItem, AssocItemId, ConstId, FunctionId, GenericDefId, HasModule, TraitId, |
| 9 | + TypeAliasId, |
| 10 | +}; |
| 11 | + |
| 12 | +use crate::{ |
| 13 | + all_super_traits, db::HirDatabase, from_chalk_trait_id, generics::generics, |
| 14 | + layout::LayoutError, lower::callable_item_sig, make_single_type_binders, static_lifetime, |
| 15 | + utils::elaborate_clause_supertraits, wrap_empty_binders, DynTy, Interner, |
| 16 | + QuantifiedWhereClauses, Substitution, TyBuilder, TyKind, |
| 17 | +}; |
| 18 | + |
| 19 | +#[derive(Debug, Clone, PartialEq, Eq)] |
| 20 | +pub enum ObjectSafetyError { |
| 21 | + LayoutError(LayoutError), |
| 22 | +} |
| 23 | + |
| 24 | +#[derive(Debug, Clone, PartialEq, Eq)] |
| 25 | +pub enum ObjectSafetyViolation { |
| 26 | + SizedSelf, |
| 27 | + SelfReferential, |
| 28 | + NonLifetimeBinder, |
| 29 | + Method(FunctionId, MethodViolationCode), |
| 30 | + AssocConst(ConstId), |
| 31 | + GAT(TypeAliasId), |
| 32 | + // This doesn't exist in rustc, but added for better visualization |
| 33 | + HasNonSafeSuperTrait(TraitId), |
| 34 | +} |
| 35 | + |
| 36 | +#[derive(Debug, Clone, PartialEq, Eq)] |
| 37 | +pub enum MethodViolationCode { |
| 38 | + StaticMethod, |
| 39 | + ReferencesSelfInput, |
| 40 | + ReferencesSelfOutput, |
| 41 | + ReferencesImplTraitInTrait, |
| 42 | + AsyncFn, |
| 43 | + WhereClauseReferencesSelf, |
| 44 | + Generic, |
| 45 | + UndispatchableReceiver, |
| 46 | + HasNonLifetimeTypeParam, |
| 47 | + NonReceiverSelfParam, |
| 48 | +} |
| 49 | + |
| 50 | +// Basically, this is almost same as `rustc_trait_selection::traits::object_safety` |
| 51 | +// but some difference; |
| 52 | +// |
| 53 | +// 1. While rustc gathers almost every violation, but this only early return on |
| 54 | +// first violation for perf. |
| 55 | +// |
| 56 | +// These can be changed anytime while implementing. |
| 57 | +pub fn object_safety_of_trait_query( |
| 58 | + db: &dyn HirDatabase, |
| 59 | + trait_: TraitId, |
| 60 | +) -> Result<Option<ObjectSafetyViolation>, ObjectSafetyError> { |
| 61 | + for super_trait in all_super_traits(db.upcast(), trait_).into_iter().skip(1) { |
| 62 | + if db.object_safety_of_trait(super_trait)?.is_some() { |
| 63 | + return Ok(Some(ObjectSafetyViolation::HasNonSafeSuperTrait(super_trait))); |
| 64 | + } |
| 65 | + } |
| 66 | + |
| 67 | + // Check whether this has a `Sized` bound |
| 68 | + if generics_require_sized_self(db, trait_.into()) { |
| 69 | + return Ok(Some(ObjectSafetyViolation::SizedSelf)); |
| 70 | + } |
| 71 | + |
| 72 | + // TODO: bound referencing self |
| 73 | + |
| 74 | + // TODO: non lifetime binder |
| 75 | + |
| 76 | + let trait_data = db.trait_data(trait_); |
| 77 | + for (_, assoc_item) in &trait_data.items { |
| 78 | + let item_violation = object_safety_violation_for_assoc_item(db, trait_, *assoc_item)?; |
| 79 | + if item_violation.is_some() { |
| 80 | + return Ok(item_violation); |
| 81 | + } |
| 82 | + } |
| 83 | + |
| 84 | + Ok(None) |
| 85 | +} |
| 86 | + |
| 87 | +fn generics_require_sized_self(db: &dyn HirDatabase, def: GenericDefId) -> bool { |
| 88 | + let krate = def.module(db.upcast()).krate(); |
| 89 | + let Some(sized) = db.lang_item(krate, LangItem::Sized).and_then(|l| l.as_trait()) else { |
| 90 | + return false; |
| 91 | + }; |
| 92 | + |
| 93 | + let predicates = &*db.generic_predicates(def); |
| 94 | + let predicates = predicates.iter().map(|p| p.skip_binders().skip_binders().clone()); |
| 95 | + elaborate_clause_supertraits(db, predicates).any(|pred| match pred { |
| 96 | + WhereClause::Implemented(trait_ref) => { |
| 97 | + if from_chalk_trait_id(trait_ref.trait_id) == sized { |
| 98 | + if let TyKind::BoundVar(it) = |
| 99 | + *trait_ref.self_type_parameter(Interner).kind(Interner) |
| 100 | + { |
| 101 | + // Since `generic_predicates` is `Binder<Binder<..>>`, the `DebrujinIndex` of |
| 102 | + // self-parameter is `1` |
| 103 | + return it.index_if_bound_at(DebruijnIndex::new(1)).is_some_and(|i| i == 0); |
| 104 | + } |
| 105 | + } |
| 106 | + false |
| 107 | + } |
| 108 | + _ => false, |
| 109 | + }) |
| 110 | +} |
| 111 | + |
| 112 | +fn predicates_reference_self(db: &dyn HirDatabase, def: GenericDefId, trait_: TraitId) -> bool { |
| 113 | + let generics = generics(db.upcast(), def); |
| 114 | + let subst = generics.placeholder_subst(db); |
| 115 | + for pred in db.generic_predicates(def).iter() { |
| 116 | + let pred = pred.clone().substitute(Interner, &subst); |
| 117 | + match pred.skip_binders() { |
| 118 | + WhereClause::Implemented(trait_ref) => { |
| 119 | + // trait_ref.type_parameters(Interner).skip(1).any(|a| a.) |
| 120 | + } |
| 121 | + WhereClause::AliasEq(alias_eq) => {} |
| 122 | + _ => {} |
| 123 | + } |
| 124 | + } |
| 125 | + |
| 126 | + false |
| 127 | +} |
| 128 | + |
| 129 | +fn contains_illegal_self_type_reference<T: TypeVisitable<Interner>>( |
| 130 | + db: &dyn HirDatabase, |
| 131 | + def: GenericDefId, |
| 132 | + t: &T, |
| 133 | + allow_self_projection: bool, |
| 134 | +) -> bool { |
| 135 | + // Move this into TyKind::Placehoder match below |
| 136 | + let generic_params = db.generic_params(def); |
| 137 | + // Not this way. we need to check what is `Self` |
| 138 | + // The following is not object safe |
| 139 | + // trait S: Into<Self> {} |
| 140 | + // fn wtf(s: &dyn S) {} |
| 141 | + let self_param = generic_params.trait_self_param(); |
| 142 | + struct IllegalSelfTypeVisitor; |
| 143 | + impl TypeVisitor<Interner> for IllegalSelfTypeVisitor { |
| 144 | + type BreakTy = (); |
| 145 | + |
| 146 | + fn as_dyn(&mut self) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> { |
| 147 | + self |
| 148 | + } |
| 149 | + |
| 150 | + fn interner(&self) -> Interner { |
| 151 | + Interner |
| 152 | + } |
| 153 | + |
| 154 | + fn visit_ty( |
| 155 | + &mut self, |
| 156 | + ty: &chalk_ir::Ty<Interner>, |
| 157 | + outer_binder: DebruijnIndex, |
| 158 | + ) -> std::ops::ControlFlow<Self::BreakTy> { |
| 159 | + match ty.kind(Interner) { |
| 160 | + TyKind::Placeholder(it) => { |
| 161 | + todo!() |
| 162 | + } |
| 163 | + // TODO: rustc |
| 164 | + _ => ty.super_visit_with(self.as_dyn(), outer_binder), |
| 165 | + } |
| 166 | + } |
| 167 | + |
| 168 | + fn visit_const( |
| 169 | + &mut self, |
| 170 | + constant: &chalk_ir::Const<Interner>, |
| 171 | + outer_binder: DebruijnIndex, |
| 172 | + ) -> std::ops::ControlFlow<Self::BreakTy> { |
| 173 | + constant.data(Interner).ty.super_visit_with(self.as_dyn(), outer_binder) |
| 174 | + } |
| 175 | + } |
| 176 | + |
| 177 | + let mut visitor = IllegalSelfTypeVisitor; |
| 178 | + t.visit_with(visitor.as_dyn(), DebruijnIndex::INNERMOST).is_break() |
| 179 | +} |
| 180 | + |
| 181 | +fn object_safety_violation_for_assoc_item( |
| 182 | + db: &dyn HirDatabase, |
| 183 | + trait_: TraitId, |
| 184 | + item: AssocItemId, |
| 185 | +) -> Result<Option<ObjectSafetyViolation>, ObjectSafetyError> { |
| 186 | + match item { |
| 187 | + AssocItemId::ConstId(it) => Ok(Some(ObjectSafetyViolation::AssocConst(it))), |
| 188 | + AssocItemId::FunctionId(it) => virtual_call_violations_for_method(db, trait_, it) |
| 189 | + .map(|v| v.map(|v| ObjectSafetyViolation::Method(it, v))) |
| 190 | + .map_err(ObjectSafetyError::LayoutError), |
| 191 | + AssocItemId::TypeAliasId(it) => { |
| 192 | + let generics = generics(db.upcast(), it.into()); |
| 193 | + // rustc checks if the `generic_associate_type_extended` feature gate is set |
| 194 | + if generics.len_self() > 0 && db.type_alias_impl_traits(it).is_none() { |
| 195 | + Ok(Some(ObjectSafetyViolation::GAT(it))) |
| 196 | + } else { |
| 197 | + Ok(None) |
| 198 | + } |
| 199 | + } |
| 200 | + } |
| 201 | +} |
| 202 | + |
| 203 | +fn virtual_call_violations_for_method( |
| 204 | + db: &dyn HirDatabase, |
| 205 | + trait_: TraitId, |
| 206 | + func: FunctionId, |
| 207 | +) -> Result<Option<MethodViolationCode>, LayoutError> { |
| 208 | + let func_data = db.function_data(func); |
| 209 | + if !func_data.has_self_param() { |
| 210 | + return Ok(Some(MethodViolationCode::StaticMethod)); |
| 211 | + } |
| 212 | + |
| 213 | + // TODO: check self reference in params |
| 214 | + |
| 215 | + // TODO: check self reference in return type |
| 216 | + |
| 217 | + // TODO: check asyncness, RPIT |
| 218 | + |
| 219 | + let generic_params = db.generic_params(func.into()); |
| 220 | + if generic_params.len_type_or_consts() > 0 { |
| 221 | + return Ok(Some(MethodViolationCode::Generic)); |
| 222 | + } |
| 223 | + |
| 224 | + // Check if the receiver is a correct type like `Self`, `Box<Self>`, `Arc<Self>`, etc |
| 225 | + // |
| 226 | + // TODO: rustc does this in two steps :thinking_face: |
| 227 | + // I'm doing only the second, real one, layout check |
| 228 | + // TODO: clean all the messes for building receiver types to check layout of |
| 229 | + |
| 230 | + // Check for types like `Rc<()>` |
| 231 | + let sig = callable_item_sig(db, func.into()); |
| 232 | + // TODO: Getting receiver type that substituted `Self` by `()`. there might be more clever way? |
| 233 | + let subst = Substitution::from_iter( |
| 234 | + Interner, |
| 235 | + std::iter::repeat(TyBuilder::unit()).take(sig.len(Interner)), |
| 236 | + ); |
| 237 | + let sig = sig.substitute(Interner, &subst); |
| 238 | + let receiver_ty = sig.params()[0].to_owned(); |
| 239 | + let layout = db.layout_of_ty(receiver_ty, db.trait_environment(trait_.into()))?; |
| 240 | + |
| 241 | + if !matches!(layout.abi, rustc_abi::Abi::Scalar(..)) { |
| 242 | + return Ok(Some(MethodViolationCode::UndispatchableReceiver)); |
| 243 | + } |
| 244 | + |
| 245 | + // Check for types like `Rc<dyn Trait>` |
| 246 | + // TODO: `dyn Trait` and receiver type building is a total mess |
| 247 | + let trait_ref = |
| 248 | + TyBuilder::trait_ref(db, trait_).fill_with_bound_vars(DebruijnIndex::INNERMOST, 0).build(); |
| 249 | + let bound = wrap_empty_binders(WhereClause::Implemented(trait_ref)); |
| 250 | + let bounds = QuantifiedWhereClauses::from_iter(Interner, [bound]); |
| 251 | + let dyn_trait = TyKind::Dyn(DynTy { |
| 252 | + bounds: make_single_type_binders(bounds), |
| 253 | + lifetime: static_lifetime(), |
| 254 | + }) |
| 255 | + .intern(Interner); |
| 256 | + let sig = callable_item_sig(db, func.into()); |
| 257 | + let subst = Substitution::from_iter( |
| 258 | + Interner, |
| 259 | + std::iter::once(dyn_trait) |
| 260 | + .chain(std::iter::repeat(TyBuilder::unit())) |
| 261 | + .take(sig.len(Interner)), |
| 262 | + ); |
| 263 | + let sig = sig.substitute(Interner, &subst); |
| 264 | + let receiver_ty = sig.params()[0].to_owned(); |
| 265 | + let layout = db.layout_of_ty(receiver_ty, db.trait_environment(trait_.into()))?; |
| 266 | + |
| 267 | + if !matches!(layout.abi, rustc_abi::Abi::ScalarPair(..)) { |
| 268 | + return Ok(Some(MethodViolationCode::UndispatchableReceiver)); |
| 269 | + } |
| 270 | + |
| 271 | + Ok(None) |
| 272 | +} |
0 commit comments