From 207b4b8e88d3fad105e3fe892c64cdf8d1ace252 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Mon, 6 May 2024 11:24:40 -0400
Subject: [PATCH 1/2] Record impl args in the InsepctCandiate rather than
 rematching during select

---
 .../rustc_middle/src/traits/solve/inspect.rs  |  2 +
 .../src/traits/solve/inspect/format.rs        |  3 +
 .../src/solve/eval_ctxt/mod.rs                |  5 ++
 .../src/solve/eval_ctxt/select.rs             | 81 +++++--------------
 .../src/solve/inspect/analyse.rs              | 42 ++++++++--
 .../src/solve/inspect/build.rs                | 30 ++++++-
 .../src/solve/trait_goals.rs                  |  1 +
 7 files changed, 99 insertions(+), 65 deletions(-)

diff --git a/compiler/rustc_middle/src/traits/solve/inspect.rs b/compiler/rustc_middle/src/traits/solve/inspect.rs
index cddf9d5f874a3..2ddcb8aab2530 100644
--- a/compiler/rustc_middle/src/traits/solve/inspect.rs
+++ b/compiler/rustc_middle/src/traits/solve/inspect.rs
@@ -123,6 +123,8 @@ pub enum ProbeStep<'tcx> {
     /// used whenever there are multiple candidates to prove the
     /// current goalby .
     NestedProbe(Probe<'tcx>),
+    /// A trait goal was satisfied by an impl candidate.
+    RecordImplArgs { impl_args: CanonicalState<'tcx, ty::GenericArgsRef<'tcx>> },
     /// A call to `EvalCtxt::evaluate_added_goals_make_canonical_response` with
     /// `Certainty` was made. This is the certainty passed in, so it's not unified
     /// with the certainty of the `try_evaluate_added_goals` that is done within;
diff --git a/compiler/rustc_middle/src/traits/solve/inspect/format.rs b/compiler/rustc_middle/src/traits/solve/inspect/format.rs
index 11aa0e10931cb..e652f0586c4ea 100644
--- a/compiler/rustc_middle/src/traits/solve/inspect/format.rs
+++ b/compiler/rustc_middle/src/traits/solve/inspect/format.rs
@@ -136,6 +136,9 @@ impl<'a, 'b> ProofTreeFormatter<'a, 'b> {
                     ProbeStep::MakeCanonicalResponse { shallow_certainty } => {
                         writeln!(this.f, "EVALUATE GOALS AND MAKE RESPONSE: {shallow_certainty:?}")?
                     }
+                    ProbeStep::RecordImplArgs { impl_args } => {
+                        writeln!(this.f, "RECORDED IMPL ARGS: {impl_args:?}")?
+                    }
                 }
             }
             Ok(())
diff --git a/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs b/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
index 773babde0d7b3..b103cbe9d480d 100644
--- a/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
+++ b/compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
@@ -587,6 +587,11 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> {
 
         Ok(unchanged_certainty)
     }
+
+    /// Record impl args in the proof tree for later access by `InspectCandidate`.
+    pub(crate) fn record_impl_args(&mut self, impl_args: ty::GenericArgsRef<'tcx>) {
+        self.inspect.record_impl_args(self.infcx, self.max_input_universe, impl_args)
+    }
 }
 
 impl<'tcx> EvalCtxt<'_, 'tcx> {
diff --git a/compiler/rustc_trait_selection/src/solve/eval_ctxt/select.rs b/compiler/rustc_trait_selection/src/solve/eval_ctxt/select.rs
index 16fe045b82de7..3700ddf7ef5d1 100644
--- a/compiler/rustc_trait_selection/src/solve/eval_ctxt/select.rs
+++ b/compiler/rustc_trait_selection/src/solve/eval_ctxt/select.rs
@@ -1,12 +1,11 @@
 use std::ops::ControlFlow;
 
-use rustc_hir::def_id::DefId;
-use rustc_infer::infer::{DefineOpaqueTypes, InferCtxt, InferOk};
+use rustc_infer::infer::InferCtxt;
 use rustc_infer::traits::solve::inspect::ProbeKind;
 use rustc_infer::traits::solve::{CandidateSource, Certainty, Goal};
 use rustc_infer::traits::{
     BuiltinImplSource, ImplSource, ImplSourceUserDefinedData, Obligation, ObligationCause,
-    PolyTraitObligation, PredicateObligation, Selection, SelectionError, SelectionResult,
+    PolyTraitObligation, Selection, SelectionError, SelectionResult,
 };
 use rustc_macros::extension;
 use rustc_span::Span;
@@ -133,32 +132,33 @@ fn to_selection<'tcx>(
         return None;
     }
 
-    let make_nested = || {
-        cand.instantiate_nested_goals(span)
-            .into_iter()
-            .map(|nested| {
-                Obligation::new(
-                    nested.infcx().tcx,
-                    ObligationCause::dummy_with_span(span),
-                    nested.goal().param_env,
-                    nested.goal().predicate,
-                )
-            })
-            .collect()
-    };
+    let (nested, impl_args) = cand.instantiate_nested_goals_and_opt_impl_args(span);
+    let nested = nested
+        .into_iter()
+        .map(|nested| {
+            Obligation::new(
+                nested.infcx().tcx,
+                ObligationCause::dummy_with_span(span),
+                nested.goal().param_env,
+                nested.goal().predicate,
+            )
+        })
+        .collect();
 
     Some(match cand.kind() {
         ProbeKind::TraitCandidate { source, result: _ } => match source {
             CandidateSource::Impl(impl_def_id) => {
                 // FIXME: Remove this in favor of storing this in the tree
                 // For impl candidates, we do the rematch manually to compute the args.
-                ImplSource::UserDefined(rematch_impl(cand.goal(), impl_def_id, span))
-            }
-            CandidateSource::BuiltinImpl(builtin) => ImplSource::Builtin(builtin, make_nested()),
-            CandidateSource::ParamEnv(_) => ImplSource::Param(make_nested()),
-            CandidateSource::AliasBound => {
-                ImplSource::Builtin(BuiltinImplSource::Misc, make_nested())
+                ImplSource::UserDefined(ImplSourceUserDefinedData {
+                    impl_def_id,
+                    args: impl_args.expect("expected recorded impl args for impl candidate"),
+                    nested,
+                })
             }
+            CandidateSource::BuiltinImpl(builtin) => ImplSource::Builtin(builtin, nested),
+            CandidateSource::ParamEnv(_) => ImplSource::Param(nested),
+            CandidateSource::AliasBound => ImplSource::Builtin(BuiltinImplSource::Misc, nested),
             CandidateSource::CoherenceUnknowable => {
                 span_bug!(span, "didn't expect to select an unknowable candidate")
             }
@@ -173,40 +173,3 @@ fn to_selection<'tcx>(
         }
     })
 }
-
-fn rematch_impl<'tcx>(
-    goal: &inspect::InspectGoal<'_, 'tcx>,
-    impl_def_id: DefId,
-    span: Span,
-) -> ImplSourceUserDefinedData<'tcx, PredicateObligation<'tcx>> {
-    let infcx = goal.infcx();
-    let goal_trait_ref = infcx
-        .enter_forall_and_leak_universe(goal.goal().predicate.to_opt_poly_trait_pred().unwrap())
-        .trait_ref;
-
-    let args = infcx.fresh_args_for_item(span, impl_def_id);
-    let impl_trait_ref =
-        infcx.tcx.impl_trait_ref(impl_def_id).unwrap().instantiate(infcx.tcx, args);
-
-    let InferOk { value: (), obligations: mut nested } = infcx
-        .at(&ObligationCause::dummy_with_span(span), goal.goal().param_env)
-        .eq(DefineOpaqueTypes::Yes, goal_trait_ref, impl_trait_ref)
-        .expect("rematching impl failed");
-
-    // FIXME(-Znext-solver=coinductive): We need to add supertraits here eventually.
-
-    nested.extend(
-        infcx.tcx.predicates_of(impl_def_id).instantiate(infcx.tcx, args).into_iter().map(
-            |(clause, _)| {
-                Obligation::new(
-                    infcx.tcx,
-                    ObligationCause::dummy_with_span(span),
-                    goal.goal().param_env,
-                    clause,
-                )
-            },
-        ),
-    );
-
-    ImplSourceUserDefinedData { impl_def_id, nested, args }
-}
diff --git a/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs b/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
index 4f79f1b2aafe0..fa4323a3a944d 100644
--- a/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
+++ b/compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
@@ -93,6 +93,7 @@ pub struct InspectCandidate<'a, 'tcx> {
     kind: inspect::ProbeKind<'tcx>,
     nested_goals: Vec<(GoalSource, inspect::CanonicalState<'tcx, Goal<'tcx, ty::Predicate<'tcx>>>)>,
     final_state: inspect::CanonicalState<'tcx, ()>,
+    impl_args: Option<inspect::CanonicalState<'tcx, ty::GenericArgsRef<'tcx>>>,
     result: QueryResult<'tcx>,
     shallow_certainty: Certainty,
 }
@@ -135,7 +136,20 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
 
     /// Instantiate the nested goals for the candidate without rolling back their
     /// inference constraints. This function modifies the state of the `infcx`.
+    ///
+    /// See [`Self::instantiate_nested_goals_and_opt_impl_args`] if you need the impl args too.
     pub fn instantiate_nested_goals(&self, span: Span) -> Vec<InspectGoal<'a, 'tcx>> {
+        self.instantiate_nested_goals_and_opt_impl_args(span).0
+    }
+
+    /// Instantiate the nested goals for the candidate without rolling back their
+    /// inference constraints, and optionally the args of an impl if this candidate
+    /// came from a `CandidateSource::Impl`. This function modifies the state of the
+    /// `infcx`.
+    pub fn instantiate_nested_goals_and_opt_impl_args(
+        &self,
+        span: Span,
+    ) -> (Vec<InspectGoal<'a, 'tcx>>, Option<ty::GenericArgsRef<'tcx>>) {
         let infcx = self.goal.infcx;
         let param_env = self.goal.goal.param_env;
         let mut orig_values = self.goal.orig_values.to_vec();
@@ -164,6 +178,17 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
             self.final_state,
         );
 
+        let impl_args = self.impl_args.map(|impl_args| {
+            canonical::instantiate_canonical_state(
+                infcx,
+                span,
+                param_env,
+                &mut orig_values,
+                impl_args,
+            )
+            .fold_with(&mut EagerResolver::new(infcx))
+        });
+
         if let Some(term_hack) = self.goal.normalizes_to_term_hack {
             // FIXME: We ignore the expected term of `NormalizesTo` goals
             // when computing the result of its candidates. This is
@@ -171,7 +196,7 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
             let _ = term_hack.constrain(infcx, span, param_env);
         }
 
-        instantiated_goals
+        let goals = instantiated_goals
             .into_iter()
             .map(|(source, goal)| match goal.predicate.kind().no_bound_vars() {
                 Some(ty::PredicateKind::NormalizesTo(ty::NormalizesTo { alias, term })) => {
@@ -208,7 +233,9 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
                     source,
                 ),
             })
-            .collect()
+            .collect();
+
+        (goals, impl_args)
     }
 
     /// Visit all nested goals of this candidate, rolling back
@@ -245,9 +272,10 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
         probe: &inspect::Probe<'tcx>,
     ) {
         let mut shallow_certainty = None;
+        let mut impl_args = None;
         for step in &probe.steps {
-            match step {
-                &inspect::ProbeStep::AddGoal(source, goal) => nested_goals.push((source, goal)),
+            match *step {
+                inspect::ProbeStep::AddGoal(source, goal) => nested_goals.push((source, goal)),
                 inspect::ProbeStep::NestedProbe(ref probe) => {
                     // Nested probes have to prove goals added in their parent
                     // but do not leak them, so we truncate the added goals
@@ -257,7 +285,10 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
                     nested_goals.truncate(num_goals);
                 }
                 inspect::ProbeStep::MakeCanonicalResponse { shallow_certainty: c } => {
-                    assert_eq!(shallow_certainty.replace(*c), None);
+                    assert_eq!(shallow_certainty.replace(c), None);
+                }
+                inspect::ProbeStep::RecordImplArgs { impl_args: i } => {
+                    assert_eq!(impl_args.replace(i), None);
                 }
                 inspect::ProbeStep::EvaluateGoals(_) => (),
             }
@@ -284,6 +315,7 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
                         final_state: probe.final_state,
                         result,
                         shallow_certainty,
+                        impl_args,
                     });
                 }
             }
diff --git a/compiler/rustc_trait_selection/src/solve/inspect/build.rs b/compiler/rustc_trait_selection/src/solve/inspect/build.rs
index 466d0d8006018..79d8901a260b4 100644
--- a/compiler/rustc_trait_selection/src/solve/inspect/build.rs
+++ b/compiler/rustc_trait_selection/src/solve/inspect/build.rs
@@ -242,6 +242,7 @@ enum WipProbeStep<'tcx> {
     EvaluateGoals(WipAddedGoalsEvaluation<'tcx>),
     NestedProbe(WipProbe<'tcx>),
     MakeCanonicalResponse { shallow_certainty: Certainty },
+    RecordImplArgs { impl_args: inspect::CanonicalState<'tcx, ty::GenericArgsRef<'tcx>> },
 }
 
 impl<'tcx> WipProbeStep<'tcx> {
@@ -250,6 +251,9 @@ impl<'tcx> WipProbeStep<'tcx> {
             WipProbeStep::AddGoal(source, goal) => inspect::ProbeStep::AddGoal(source, goal),
             WipProbeStep::EvaluateGoals(eval) => inspect::ProbeStep::EvaluateGoals(eval.finalize()),
             WipProbeStep::NestedProbe(probe) => inspect::ProbeStep::NestedProbe(probe.finalize()),
+            WipProbeStep::RecordImplArgs { impl_args } => {
+                inspect::ProbeStep::RecordImplArgs { impl_args }
+            }
             WipProbeStep::MakeCanonicalResponse { shallow_certainty } => {
                 inspect::ProbeStep::MakeCanonicalResponse { shallow_certainty }
             }
@@ -534,6 +538,30 @@ impl<'tcx> ProofTreeBuilder<'tcx> {
         }
     }
 
+    pub(crate) fn record_impl_args(
+        &mut self,
+        infcx: &InferCtxt<'tcx>,
+        max_input_universe: ty::UniverseIndex,
+        impl_args: ty::GenericArgsRef<'tcx>,
+    ) {
+        match self.as_mut() {
+            Some(DebugSolver::GoalEvaluationStep(state)) => {
+                let impl_args = canonical::make_canonical_state(
+                    infcx,
+                    &state.var_values,
+                    max_input_universe,
+                    impl_args,
+                );
+                state
+                    .current_evaluation_scope()
+                    .steps
+                    .push(WipProbeStep::RecordImplArgs { impl_args });
+            }
+            None => {}
+            _ => bug!(),
+        }
+    }
+
     pub fn make_canonical_response(&mut self, shallow_certainty: Certainty) {
         match self.as_mut() {
             Some(DebugSolver::GoalEvaluationStep(state)) => {
@@ -543,7 +571,7 @@ impl<'tcx> ProofTreeBuilder<'tcx> {
                     .push(WipProbeStep::MakeCanonicalResponse { shallow_certainty });
             }
             None => {}
-            _ => {}
+            _ => bug!(),
         }
     }
 
diff --git a/compiler/rustc_trait_selection/src/solve/trait_goals.rs b/compiler/rustc_trait_selection/src/solve/trait_goals.rs
index d2b893d6383bd..0fde9dd4cd680 100644
--- a/compiler/rustc_trait_selection/src/solve/trait_goals.rs
+++ b/compiler/rustc_trait_selection/src/solve/trait_goals.rs
@@ -75,6 +75,7 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> {
 
         ecx.probe_trait_candidate(CandidateSource::Impl(impl_def_id)).enter(|ecx| {
             let impl_args = ecx.fresh_args_for_item(impl_def_id);
+            ecx.record_impl_args(impl_args);
             let impl_trait_ref = impl_trait_header.trait_ref.instantiate(tcx, impl_args);
 
             ecx.eq(goal.param_env, goal.predicate.trait_ref, impl_trait_ref)?;

From e34723997ae8e44538cea5f9f68085d828771e54 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Mon, 6 May 2024 14:11:21 -0400
Subject: [PATCH 2/2] Use correct ImplSource for alias bounds

---
 .../src/solve/eval_ctxt/select.rs                   |  3 +--
 .../next-solver/select-alias-bound-as-param.rs      | 13 +++++++++++++
 2 files changed, 14 insertions(+), 2 deletions(-)
 create mode 100644 tests/ui/traits/next-solver/select-alias-bound-as-param.rs

diff --git a/compiler/rustc_trait_selection/src/solve/eval_ctxt/select.rs b/compiler/rustc_trait_selection/src/solve/eval_ctxt/select.rs
index 3700ddf7ef5d1..f586a74258d7e 100644
--- a/compiler/rustc_trait_selection/src/solve/eval_ctxt/select.rs
+++ b/compiler/rustc_trait_selection/src/solve/eval_ctxt/select.rs
@@ -157,8 +157,7 @@ fn to_selection<'tcx>(
                 })
             }
             CandidateSource::BuiltinImpl(builtin) => ImplSource::Builtin(builtin, nested),
-            CandidateSource::ParamEnv(_) => ImplSource::Param(nested),
-            CandidateSource::AliasBound => ImplSource::Builtin(BuiltinImplSource::Misc, nested),
+            CandidateSource::ParamEnv(_) | CandidateSource::AliasBound => ImplSource::Param(nested),
             CandidateSource::CoherenceUnknowable => {
                 span_bug!(span, "didn't expect to select an unknowable candidate")
             }
diff --git a/tests/ui/traits/next-solver/select-alias-bound-as-param.rs b/tests/ui/traits/next-solver/select-alias-bound-as-param.rs
new file mode 100644
index 0000000000000..fd40ef1f872f3
--- /dev/null
+++ b/tests/ui/traits/next-solver/select-alias-bound-as-param.rs
@@ -0,0 +1,13 @@
+//@ check-pass
+//@ compile-flags: -Znext-solver
+
+pub(crate) fn y() -> impl FnMut() {
+    || {}
+}
+
+pub(crate) fn x(a: (), b: ()) {
+    let x = ();
+    y()()
+}
+
+fn main() {}