Skip to content

Structurally resolve projections (but actually) in the new solver #108833

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 60 additions & 25 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
@@ -38,11 +38,13 @@ use rustc_middle::ty::{
use rustc_span::def_id::CRATE_DEF_ID;
use rustc_span::{Span, DUMMY_SP};
use rustc_target::abi::VariantIdx;
use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;
use rustc_trait_selection::traits::query::type_op::custom::scrape_region_constraints;
use rustc_trait_selection::traits::query::type_op::custom::CustomTypeOp;
use rustc_trait_selection::traits::query::type_op::{TypeOp, TypeOpOutput};
use rustc_trait_selection::traits::query::Fallible;
use rustc_trait_selection::traits::PredicateObligation;
use rustc_trait_selection::traits::{fully_solve_obligation, Obligation, ObligationCause};

use rustc_mir_dataflow::impls::MaybeInitializedPlaces;
use rustc_mir_dataflow::move_paths::MoveData;
@@ -1154,6 +1156,31 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
self.infcx.tcx
}

fn structurally_resolved_ty(&self, ty: Ty<'tcx>) -> Ty<'tcx> {
if self.tcx().trait_solver_next() && let ty::Alias(ty::Projection, projection_ty) = *ty.kind() {
let new_infer_ty = self.infcx.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::NormalizeProjectionType,
span: DUMMY_SP,
});
let obligation = Obligation::new(
self.tcx(),
ObligationCause::dummy(),
self.param_env,
ty::Binder::dummy(ty::ProjectionPredicate {
projection_ty,
term: new_infer_ty.into(),
}),
);

if self.infcx.predicate_may_hold(&obligation) {
fully_solve_obligation(&self.infcx, obligation);
return self.infcx.resolve_vars_if_possible(new_infer_ty);
}
}

ty
}

#[instrument(skip(self, body, location), level = "debug")]
fn check_stmt(&mut self, body: &Body<'tcx>, stmt: &Statement<'tcx>, location: Location) {
let tcx = self.tcx();
@@ -1871,6 +1898,13 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
Rvalue::Cast(cast_kind, op, ty) => {
self.check_operand(op, location);

let structurally_resolved_cast_tys = || {
(
self.structurally_resolved_ty(op.ty(body, tcx)),
self.structurally_resolved_ty(*ty),
)
};

match cast_kind {
CastKind::Pointer(PointerCast::ReifyFnPointer) => {
let fn_sig = op.ty(body, tcx).fn_sig(tcx);
@@ -1902,7 +1936,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}

CastKind::Pointer(PointerCast::ClosureFnPointer(unsafety)) => {
let sig = match op.ty(body, tcx).kind() {
let sig = match self.structurally_resolved_ty(op.ty(body, tcx)).kind() {
ty::Closure(_, substs) => substs.as_closure().sig(),
_ => bug!(),
};
@@ -1971,10 +2005,11 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
// get the constraints from the target type (`dyn* Clone`)
//
// apply them to prove that the source type `Foo` implements `Clone` etc
let (existential_predicates, region) = match ty.kind() {
Dynamic(predicates, region, ty::DynStar) => (predicates, region),
_ => panic!("Invalid dyn* cast_ty"),
};
let (existential_predicates, region) =
match self.structurally_resolved_ty(*ty).kind() {
Dynamic(predicates, region, ty::DynStar) => (predicates, region),
_ => panic!("Invalid dyn* cast_ty"),
};

let self_ty = op.ty(body, tcx);

@@ -2001,7 +2036,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
let ty::RawPtr(ty::TypeAndMut {
ty: ty_from,
mutbl: hir::Mutability::Mut,
}) = op.ty(body, tcx).kind() else {
}) = self.structurally_resolved_ty(op.ty(body, tcx)).kind() else {
span_mirbug!(
self,
rvalue,
@@ -2013,7 +2048,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
let ty::RawPtr(ty::TypeAndMut {
ty: ty_to,
mutbl: hir::Mutability::Not,
}) = ty.kind() else {
}) = self.structurally_resolved_ty(*ty).kind() else {
span_mirbug!(
self,
rvalue,
@@ -2042,7 +2077,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
CastKind::Pointer(PointerCast::ArrayToPointer) => {
let ty_from = op.ty(body, tcx);

let opt_ty_elem_mut = match ty_from.kind() {
let opt_ty_elem_mut = match self.structurally_resolved_ty(ty_from).kind() {
ty::RawPtr(ty::TypeAndMut { mutbl: array_mut, ty: array_ty }) => {
match array_ty.kind() {
ty::Array(ty_elem, _) => Some((ty_elem, *array_mut)),
@@ -2062,7 +2097,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
return;
};

let (ty_to, ty_to_mut) = match ty.kind() {
let (ty_to, ty_to_mut) = match self.structurally_resolved_ty(*ty).kind() {
ty::RawPtr(ty::TypeAndMut { mutbl: ty_to_mut, ty: ty_to }) => {
(ty_to, *ty_to_mut)
}
@@ -2106,9 +2141,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}

CastKind::PointerExposeAddress => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::Ptr(_) | CastTy::FnPtr), Some(CastTy::Int(_))) => (),
_ => {
@@ -2124,9 +2159,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}

CastKind::PointerFromExposedAddress => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::Int(_)), Some(CastTy::Ptr(_))) => (),
_ => {
@@ -2141,9 +2176,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
CastKind::IntToInt => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::Int(_)), Some(CastTy::Int(_))) => (),
_ => {
@@ -2158,9 +2193,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
CastKind::IntToFloat => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::Int(_)), Some(CastTy::Float)) => (),
_ => {
@@ -2175,9 +2210,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
CastKind::FloatToInt => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::Float), Some(CastTy::Int(_))) => (),
_ => {
@@ -2192,9 +2227,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
CastKind::FloatToFloat => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::Float), Some(CastTy::Float)) => (),
_ => {
@@ -2209,9 +2244,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
CastKind::FnPtrToPtr => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::FnPtr), Some(CastTy::Ptr(_))) => (),
_ => {
@@ -2226,9 +2261,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
CastKind::PtrToPtr => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::Ptr(_)), Some(CastTy::Ptr(_))) => (),
_ => {
13 changes: 12 additions & 1 deletion compiler/rustc_hir_typeck/src/coercion.rs
Original file line number Diff line number Diff line change
@@ -62,6 +62,7 @@ use rustc_span::{self, BytePos, DesugaringKind, Span};
use rustc_target::spec::abi::Abi;
use rustc_trait_selection::infer::InferCtxtExt as _;
use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _;
use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;
use rustc_trait_selection::traits::{
self, NormalizeExt, ObligationCause, ObligationCauseCode, ObligationCtxt,
};
@@ -144,12 +145,22 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
debug!("unify(a: {:?}, b: {:?}, use_lub: {})", a, b, self.use_lub);
self.commit_if_ok(|_| {
let at = self.at(&self.cause, self.fcx.param_env).define_opaque_types(true);
if self.use_lub {
let result = if self.use_lub {
at.lub(b, a)
} else {
at.sup(b, a)
.map(|InferOk { value: (), obligations }| InferOk { value: a, obligations })
}?;

if self.tcx.trait_solver_next() {
for obligation in &result.obligations {
if !self.predicate_may_hold(obligation) {
return Err(TypeError::Mismatch);
}
}
}

Ok(result)
})
}

24 changes: 23 additions & 1 deletion compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@ use rustc_hir_analysis::astconv::{
};
use rustc_infer::infer::canonical::{Canonical, OriginalQueryValues, QueryResponse};
use rustc_infer::infer::error_reporting::TypeAnnotationNeeded::E0282;
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
use rustc_infer::infer::InferResult;
use rustc_middle::ty::adjustment::{Adjust, Adjustment, AutoBorrow, AutoBorrowMutability};
use rustc_middle::ty::error::TypeError;
@@ -34,6 +35,7 @@ use rustc_span::hygiene::DesugaringKind;
use rustc_span::symbol::{kw, sym, Ident};
use rustc_span::Span;
use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _;
use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;
use rustc_trait_selection::traits::{self, NormalizeExt, ObligationCauseCode, ObligationCtxt};

use std::collections::hash_map::Entry;
@@ -1408,7 +1410,27 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
/// Resolves `typ` by a single level if `typ` is a type variable.
/// If no resolution is possible, then an error is reported.
/// Numeric inference variables may be left unresolved.
pub fn structurally_resolved_type(&self, sp: Span, ty: Ty<'tcx>) -> Ty<'tcx> {
pub fn structurally_resolved_type(&self, sp: Span, mut ty: Ty<'tcx>) -> Ty<'tcx> {
if self.tcx.trait_solver_next() && let ty::Alias(ty::Projection, projection_ty) = *ty.kind() {
let new_infer_ty = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::NormalizeProjectionType,
span: sp,
});
let obligation = traits::Obligation::new(
self.tcx,
self.misc(sp),
self.param_env,
ty::Binder::dummy(ty::ProjectionPredicate {
projection_ty,
term: new_infer_ty.into(),
}),
);
if self.predicate_may_hold(&obligation) {
self.register_predicate(obligation);
ty = new_infer_ty;
}
}

let ty = self.resolve_vars_with_obligations(ty);
if !ty.is_ty_var() {
ty
12 changes: 6 additions & 6 deletions compiler/rustc_mir_build/src/build/expr/as_rvalue.rs
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@ use rustc_middle::mir::AssertKind;
use rustc_middle::mir::Place;
use rustc_middle::mir::*;
use rustc_middle::thir::*;
use rustc_middle::ty::cast::{mir_cast_kind, CastTy};
use rustc_middle::ty::cast::{mir_cast_kind};
use rustc_middle::ty::{self, Ty, UpvarSubsts};
use rustc_span::Span;

@@ -263,11 +263,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
);
(source, ty)
};
let from_ty = CastTy::from_ty(ty);
let cast_ty = CastTy::from_ty(expr.ty);
debug!("ExprKind::Cast from_ty={from_ty:?}, cast_ty={:?}/{cast_ty:?}", expr.ty,);
let cast_kind = mir_cast_kind(ty, expr.ty);
block.and(Rvalue::Cast(cast_kind, source, expr.ty))

let ty = this.structurally_resolved_ty(ty);
let cast_ty = this.structurally_resolved_ty(expr.ty);
let cast_kind = mir_cast_kind(ty, cast_ty);
block.and(Rvalue::Cast(cast_kind, source, cast_ty))
}
ExprKind::Pointer { cast, source } => {
let source = unpack!(
30 changes: 29 additions & 1 deletion compiler/rustc_mir_build/src/build/mod.rs
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@ use rustc_hir::def::DefKind;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_hir::{GeneratorKind, Node};
use rustc_index::vec::{Idx, IndexVec};
use rustc_infer::infer::type_variable::{TypeVariableOriginKind, TypeVariableOrigin};
use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
use rustc_middle::hir::place::PlaceBase as HirPlaceBase;
use rustc_middle::middle::region;
@@ -22,9 +23,11 @@ use rustc_middle::thir::{
};
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
use rustc_span::symbol::sym;
use rustc_span::Span;
use rustc_span::{Span, DUMMY_SP};
use rustc_span::Symbol;
use rustc_target::spec::abi::Abi;
use rustc_trait_selection::traits;
use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;

use super::lints;

@@ -229,6 +232,31 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
fn var_local_id(&self, id: LocalVarId, for_guard: ForGuard) -> Local {
self.var_indices[&id].local_id(for_guard)
}

fn structurally_resolved_ty(&self, ty: Ty<'tcx>) -> Ty<'tcx> {
if self.tcx.trait_solver_next() && let ty::Alias(ty::Projection, projection_ty) = *ty.kind() {
let new_infer_ty = self.infcx.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::NormalizeProjectionType,
span: DUMMY_SP,
});
let obligation = traits::Obligation::new(
self.tcx,
traits::ObligationCause::dummy(),
self.param_env,
ty::Binder::dummy(ty::ProjectionPredicate {
projection_ty,
term: new_infer_ty.into(),
}),
);

if self.infcx.predicate_may_hold(&obligation) {
traits::fully_solve_obligation(&self.infcx, obligation);
return self.infcx.resolve_vars_if_possible(new_infer_ty);
}
}

ty
}
}

impl BlockContext {
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// compile-flags: -Ztrait-solver=next
// known-bug: unknown
// check-pass

fn main() {
(0u8 + 0u8) as char;

This file was deleted.