From 64b58dd13b06a23d8429a73c225347b4cd3b2c3c Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Thu, 3 Apr 2025 18:53:48 +0000
Subject: [PATCH] Pass correct param-env to error_implies

---
 compiler/rustc_infer/src/infer/mod.rs         |  3 +-
 .../traits/fulfillment_errors.rs              | 32 ++++++++++-----
 .../src/error_reporting/traits/mod.rs         | 28 +++++--------
 ...onst-param-has-ty-goal-in-error-implies.rs | 41 +++++++++++++++++++
 ...-param-has-ty-goal-in-error-implies.stderr | 23 +++++++++++
 5 files changed, 98 insertions(+), 29 deletions(-)
 create mode 100644 tests/ui/const-generics/const-param-has-ty-goal-in-error-implies.rs
 create mode 100644 tests/ui/const-generics/const-param-has-ty-goal-in-error-implies.stderr

diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs
index fa8dea064daaa..5fc44b8f356b5 100644
--- a/compiler/rustc_infer/src/infer/mod.rs
+++ b/compiler/rustc_infer/src/infer/mod.rs
@@ -27,6 +27,7 @@ use rustc_middle::bug;
 use rustc_middle::infer::canonical::{CanonicalQueryInput, CanonicalVarValues};
 use rustc_middle::mir::ConstraintCategory;
 use rustc_middle::traits::select;
+use rustc_middle::traits::solve::Goal;
 use rustc_middle::ty::error::{ExpectedFound, TypeError};
 use rustc_middle::ty::{
     self, BoundVarReplacerDelegate, ConstVid, FloatVid, GenericArg, GenericArgKind, GenericArgs,
@@ -268,7 +269,7 @@ pub struct InferCtxt<'tcx> {
     /// The set of predicates on which errors have been reported, to
     /// avoid reporting the same error twice.
     pub reported_trait_errors:
-        RefCell<FxIndexMap<Span, (Vec<ty::Predicate<'tcx>>, ErrorGuaranteed)>>,
+        RefCell<FxIndexMap<Span, (Vec<Goal<'tcx, ty::Predicate<'tcx>>>, ErrorGuaranteed)>>,
 
     pub reported_signature_mismatch: RefCell<FxHashSet<(Span, Option<Span>)>>,
 
diff --git a/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs b/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs
index 73ae5177c4895..bc45fc11e9bbe 100644
--- a/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs
+++ b/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs
@@ -14,6 +14,7 @@ use rustc_hir::def_id::{DefId, LOCAL_CRATE, LocalDefId};
 use rustc_hir::intravisit::Visitor;
 use rustc_hir::{self as hir, LangItem, Node};
 use rustc_infer::infer::{InferOk, TypeTrace};
+use rustc_infer::traits::solve::Goal;
 use rustc_middle::traits::SignatureMismatchData;
 use rustc_middle::traits::select::OverflowError;
 use rustc_middle::ty::abstract_const::NotConstEvaluatable;
@@ -930,7 +931,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
             )) = arg.kind
             && let Node::Pat(pat) = self.tcx.hir_node(*hir_id)
             && let Some((preds, guar)) = self.reported_trait_errors.borrow().get(&pat.span)
-            && preds.contains(&obligation.predicate)
+            && preds.contains(&obligation.as_goal())
         {
             return Err(*guar);
         }
@@ -1292,6 +1293,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
 impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
     fn can_match_trait(
         &self,
+        param_env: ty::ParamEnv<'tcx>,
         goal: ty::TraitPredicate<'tcx>,
         assumption: ty::PolyTraitPredicate<'tcx>,
     ) -> bool {
@@ -1306,11 +1308,12 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
             assumption,
         );
 
-        self.can_eq(ty::ParamEnv::empty(), goal.trait_ref, trait_assumption.trait_ref)
+        self.can_eq(param_env, goal.trait_ref, trait_assumption.trait_ref)
     }
 
     fn can_match_projection(
         &self,
+        param_env: ty::ParamEnv<'tcx>,
         goal: ty::ProjectionPredicate<'tcx>,
         assumption: ty::PolyProjectionPredicate<'tcx>,
     ) -> bool {
@@ -1320,7 +1323,6 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
             assumption,
         );
 
-        let param_env = ty::ParamEnv::empty();
         self.can_eq(param_env, goal.projection_term, assumption.projection_term)
             && self.can_eq(param_env, goal.term, assumption.term)
     }
@@ -1330,24 +1332,32 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
     #[instrument(level = "debug", skip(self), ret)]
     pub(super) fn error_implies(
         &self,
-        cond: ty::Predicate<'tcx>,
-        error: ty::Predicate<'tcx>,
+        cond: Goal<'tcx, ty::Predicate<'tcx>>,
+        error: Goal<'tcx, ty::Predicate<'tcx>>,
     ) -> bool {
         if cond == error {
             return true;
         }
 
-        if let Some(error) = error.as_trait_clause() {
+        // FIXME: We could be smarter about this, i.e. if cond's param-env is a
+        // subset of error's param-env. This only matters when binders will carry
+        // predicates though, and obviously only matters for error reporting.
+        if cond.param_env != error.param_env {
+            return false;
+        }
+        let param_env = error.param_env;
+
+        if let Some(error) = error.predicate.as_trait_clause() {
             self.enter_forall(error, |error| {
-                elaborate(self.tcx, std::iter::once(cond))
+                elaborate(self.tcx, std::iter::once(cond.predicate))
                     .filter_map(|implied| implied.as_trait_clause())
-                    .any(|implied| self.can_match_trait(error, implied))
+                    .any(|implied| self.can_match_trait(param_env, error, implied))
             })
-        } else if let Some(error) = error.as_projection_clause() {
+        } else if let Some(error) = error.predicate.as_projection_clause() {
             self.enter_forall(error, |error| {
-                elaborate(self.tcx, std::iter::once(cond))
+                elaborate(self.tcx, std::iter::once(cond.predicate))
                     .filter_map(|implied| implied.as_projection_clause())
-                    .any(|implied| self.can_match_projection(error, implied))
+                    .any(|implied| self.can_match_projection(param_env, error, implied))
             })
         } else {
             false
diff --git a/compiler/rustc_trait_selection/src/error_reporting/traits/mod.rs b/compiler/rustc_trait_selection/src/error_reporting/traits/mod.rs
index efee6e2aa1d45..8ff7030717a8b 100644
--- a/compiler/rustc_trait_selection/src/error_reporting/traits/mod.rs
+++ b/compiler/rustc_trait_selection/src/error_reporting/traits/mod.rs
@@ -12,6 +12,7 @@ use rustc_errors::{Applicability, Diag, E0038, E0276, MultiSpan, struct_span_cod
 use rustc_hir::def_id::{DefId, LocalDefId};
 use rustc_hir::intravisit::Visitor;
 use rustc_hir::{self as hir, AmbigArg, LangItem};
+use rustc_infer::traits::solve::Goal;
 use rustc_infer::traits::{
     DynCompatibilityViolation, Obligation, ObligationCause, ObligationCauseCode,
     PredicateObligation, SelectionError,
@@ -144,7 +145,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
 
         #[derive(Debug)]
         struct ErrorDescriptor<'tcx> {
-            predicate: ty::Predicate<'tcx>,
+            goal: Goal<'tcx, ty::Predicate<'tcx>>,
             index: Option<usize>, // None if this is an old error
         }
 
@@ -152,15 +153,8 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
             .reported_trait_errors
             .borrow()
             .iter()
-            .map(|(&span, predicates)| {
-                (
-                    span,
-                    predicates
-                        .0
-                        .iter()
-                        .map(|&predicate| ErrorDescriptor { predicate, index: None })
-                        .collect(),
-                )
+            .map(|(&span, goals)| {
+                (span, goals.0.iter().map(|&goal| ErrorDescriptor { goal, index: None }).collect())
             })
             .collect();
 
@@ -186,10 +180,10 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
                 span = expn_data.call_site;
             }
 
-            error_map.entry(span).or_default().push(ErrorDescriptor {
-                predicate: error.obligation.predicate,
-                index: Some(index),
-            });
+            error_map
+                .entry(span)
+                .or_default()
+                .push(ErrorDescriptor { goal: error.obligation.as_goal(), index: Some(index) });
         }
 
         // We do this in 2 passes because we want to display errors in order, though
@@ -210,9 +204,9 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
                             continue;
                         }
 
-                        if self.error_implies(error2.predicate, error.predicate)
+                        if self.error_implies(error2.goal, error.goal)
                             && !(error2.index >= error.index
-                                && self.error_implies(error.predicate, error2.predicate))
+                                && self.error_implies(error.goal, error2.goal))
                         {
                             info!("skipping {:?} (implied by {:?})", error, error2);
                             is_suppressed[index] = true;
@@ -243,7 +237,7 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
                         .entry(span)
                         .or_insert_with(|| (vec![], guar))
                         .0
-                        .push(error.obligation.predicate);
+                        .push(error.obligation.as_goal());
                 }
             }
         }
diff --git a/tests/ui/const-generics/const-param-has-ty-goal-in-error-implies.rs b/tests/ui/const-generics/const-param-has-ty-goal-in-error-implies.rs
new file mode 100644
index 0000000000000..231bb5252de1a
--- /dev/null
+++ b/tests/ui/const-generics/const-param-has-ty-goal-in-error-implies.rs
@@ -0,0 +1,41 @@
+// compile-flags: -Znext-solver
+
+// Test for a weird diagnostics corner case. In the error reporting code, when reporting
+// fulfillment errors for goals A and B, we try to see if elaborating A will result in
+// another goal that can equate with B. That would signal that B is "implied by" A,
+// allowing us to skip reporting it, which is beneficial for cutting down on the number
+// of diagnostics we report. In the new trait solver especially, but even in the old trait
+// solver through things like defining opaque type usages, this `can_equate` call was not
+// properly taking the param-env of the goals, resulting in nested obligations that had
+// empty param-envs. If one of these nested obligations was a `ConstParamHasTy` goal, then
+// we would ICE, since those goals are particularly strict about the param-env they're
+// evaluated in.
+
+// This is morally a fix for <https://github.com/rust-lang/rust/issues/139314>, but that
+// repro uses details about how defining usages in the `check_opaque_well_formed` code
+// can spring out of type equality, and will likely stop failing soon coincidentally once
+// we start using `PostBorrowck` mode in that check.
+
+trait Foo: Baz<()> {}
+trait Baz<T> {}
+
+trait IdentityWithConstArgGoal<const N: usize> {
+    type Assoc;
+}
+impl<T, const N: usize> IdentityWithConstArgGoal<N> for T {
+    type Assoc = T;
+}
+
+fn unsatisfied<T, const N: usize>()
+where
+    T: Foo,
+    T: Baz<<T as IdentityWithConstArgGoal<N>>::Assoc>,
+{
+}
+
+fn test<const N: usize>() {
+    unsatisfied::<(), N>();
+    //~^ ERROR the trait bound `(): Foo` is not satisfied
+}
+
+fn main() {}
diff --git a/tests/ui/const-generics/const-param-has-ty-goal-in-error-implies.stderr b/tests/ui/const-generics/const-param-has-ty-goal-in-error-implies.stderr
new file mode 100644
index 0000000000000..77bba4945523a
--- /dev/null
+++ b/tests/ui/const-generics/const-param-has-ty-goal-in-error-implies.stderr
@@ -0,0 +1,23 @@
+error[E0277]: the trait bound `(): Foo` is not satisfied
+  --> $DIR/const-param-has-ty-goal-in-error-implies.rs:37:19
+   |
+LL |     unsatisfied::<(), N>();
+   |                   ^^ the trait `Foo` is not implemented for `()`
+   |
+help: this trait has no implementations, consider adding one
+  --> $DIR/const-param-has-ty-goal-in-error-implies.rs:19:1
+   |
+LL | trait Foo: Baz<()> {}
+   | ^^^^^^^^^^^^^^^^^^
+note: required by a bound in `unsatisfied`
+  --> $DIR/const-param-has-ty-goal-in-error-implies.rs:31:8
+   |
+LL | fn unsatisfied<T, const N: usize>()
+   |    ----------- required by a bound in this function
+LL | where
+LL |     T: Foo,
+   |        ^^^ required by this bound in `unsatisfied`
+
+error: aborting due to 1 previous error
+
+For more information about this error, try `rustc --explain E0277`.