From 9b5a2a4a4861500d64b9f1ebbe6c1140eda0a493 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Tue, 24 Jan 2023 21:43:21 +0000
Subject: [PATCH] Use new solver during selection

---
 .../src/traits/select/mod.rs                  | 57 ++++++++++++++-----
 1 file changed, 43 insertions(+), 14 deletions(-)

diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs
index f90da95d51668..1d23634b6aacf 100644
--- a/compiler/rustc_trait_selection/src/traits/select/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs
@@ -38,6 +38,8 @@ use rustc_errors::Diagnostic;
 use rustc_hir as hir;
 use rustc_hir::def_id::DefId;
 use rustc_infer::infer::LateBoundRegionConversionTime;
+use rustc_infer::traits::TraitEngine;
+use rustc_infer::traits::TraitEngineExt;
 use rustc_middle::dep_graph::{DepKind, DepNodeIndex};
 use rustc_middle::mir::interpret::ErrorHandled;
 use rustc_middle::ty::abstract_const::NotConstEvaluatable;
@@ -47,6 +49,7 @@ use rustc_middle::ty::relate::TypeRelation;
 use rustc_middle::ty::SubstsRef;
 use rustc_middle::ty::{self, EarlyBinder, PolyProjectionPredicate, ToPolyTraitRef, ToPredicate};
 use rustc_middle::ty::{Ty, TyCtxt, TypeFoldable, TypeVisitable};
+use rustc_session::config::TraitSolver;
 use rustc_span::symbol::sym;
 
 use std::cell::{Cell, RefCell};
@@ -544,10 +547,14 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         obligation: &PredicateObligation<'tcx>,
     ) -> Result<EvaluationResult, OverflowError> {
         self.evaluation_probe(|this| {
-            this.evaluate_predicate_recursively(
-                TraitObligationStackList::empty(&ProvisionalEvaluationCache::default()),
-                obligation.clone(),
-            )
+            if this.tcx().sess.opts.unstable_opts.trait_solver != TraitSolver::Next {
+                this.evaluate_predicate_recursively(
+                    TraitObligationStackList::empty(&ProvisionalEvaluationCache::default()),
+                    obligation.clone(),
+                )
+            } else {
+                this.evaluate_predicates_recursively_in_new_solver([obligation.clone()])
+            }
         })
     }
 
@@ -586,18 +593,40 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
     where
         I: IntoIterator<Item = PredicateObligation<'tcx>> + std::fmt::Debug,
     {
-        let mut result = EvaluatedToOk;
-        for obligation in predicates {
-            let eval = self.evaluate_predicate_recursively(stack, obligation.clone())?;
-            if let EvaluatedToErr = eval {
-                // fast-path - EvaluatedToErr is the top of the lattice,
-                // so we don't need to look on the other predicates.
-                return Ok(EvaluatedToErr);
-            } else {
-                result = cmp::max(result, eval);
+        if self.tcx().sess.opts.unstable_opts.trait_solver != TraitSolver::Next {
+            let mut result = EvaluatedToOk;
+            for obligation in predicates {
+                let eval = self.evaluate_predicate_recursively(stack, obligation.clone())?;
+                if let EvaluatedToErr = eval {
+                    // fast-path - EvaluatedToErr is the top of the lattice,
+                    // so we don't need to look on the other predicates.
+                    return Ok(EvaluatedToErr);
+                } else {
+                    result = cmp::max(result, eval);
+                }
             }
+            Ok(result)
+        } else {
+            self.evaluate_predicates_recursively_in_new_solver(predicates)
         }
-        Ok(result)
+    }
+
+    /// Evaluates the predicates using the new solver when `-Ztrait-solver=next` is enabled
+    fn evaluate_predicates_recursively_in_new_solver(
+        &mut self,
+        predicates: impl IntoIterator<Item = PredicateObligation<'tcx>>,
+    ) -> Result<EvaluationResult, OverflowError> {
+        let mut fulfill_cx = crate::solve::FulfillmentCtxt::new();
+        fulfill_cx.register_predicate_obligations(self.infcx, predicates);
+        // True errors
+        if !fulfill_cx.select_where_possible(self.infcx).is_empty() {
+            return Ok(EvaluatedToErr);
+        }
+        if !fulfill_cx.select_all_or_error(self.infcx).is_empty() {
+            return Ok(EvaluatedToAmbig);
+        }
+        // Regions and opaques are handled in the `evaluation_probe` by looking at the snapshot
+        Ok(EvaluatedToOk)
     }
 
     #[instrument(