From 04c603c4d73c348bd56f54d60bf01f1b7e0c9573 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Sat, 13 Aug 2022 00:56:29 +0000
Subject: [PATCH 1/2] Don't create an extra infcx in
 report_closure_arg_mismatch

---
 .../src/traits/error_reporting/suggestions.rs                 | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
index 219413121d812..a1ba4b7b647a9 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
@@ -20,7 +20,6 @@ use rustc_hir::def_id::DefId;
 use rustc_hir::intravisit::Visitor;
 use rustc_hir::lang_items::LangItem;
 use rustc_hir::{AsyncGeneratorKind, GeneratorKind, Node};
-use rustc_infer::infer::TyCtxtInferExt;
 use rustc_middle::hir::map;
 use rustc_middle::ty::{
     self, suggest_arbitrary_trait_bound, suggest_constraining_type_param, AdtKind, DefIdTree,
@@ -1589,8 +1588,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
         let expected = build_fn_sig_ty(self.tcx, expected);
         let found = build_fn_sig_ty(self.tcx, found);
 
-        let (expected_str, found_str) =
-            self.tcx.infer_ctxt().enter(|infcx| infcx.cmp(expected, found));
+        let (expected_str, found_str) = self.cmp(expected, found);
 
         let signature_kind = format!("{argument_kind} signature");
         err.note_expected_found(&signature_kind, expected_str, &signature_kind, found_str);

From 0320b4de7d8eb8cefa7ee720549718c12c3b8654 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Sat, 13 Aug 2022 02:29:48 +0000
Subject: [PATCH 2/2] Use real inference variable in build_fn_sig_ty

---
 .../src/traits/error_reporting/suggestions.rs | 25 ++++++++++++-------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
index a1ba4b7b647a9..a5295685c5354 100644
--- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
+++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs
@@ -20,6 +20,7 @@ use rustc_hir::def_id::DefId;
 use rustc_hir::intravisit::Visitor;
 use rustc_hir::lang_items::LangItem;
 use rustc_hir::{AsyncGeneratorKind, GeneratorKind, Node};
+use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
 use rustc_middle::hir::map;
 use rustc_middle::ty::{
     self, suggest_arbitrary_trait_bound, suggest_constraining_type_param, AdtKind, DefIdTree,
@@ -1540,32 +1541,38 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
         expected: ty::PolyTraitRef<'tcx>,
     ) -> DiagnosticBuilder<'tcx, ErrorGuaranteed> {
         pub(crate) fn build_fn_sig_ty<'tcx>(
-            tcx: TyCtxt<'tcx>,
+            infcx: &InferCtxt<'_, 'tcx>,
             trait_ref: ty::PolyTraitRef<'tcx>,
         ) -> Ty<'tcx> {
             let inputs = trait_ref.skip_binder().substs.type_at(1);
             let sig = match inputs.kind() {
                 ty::Tuple(inputs)
-                    if tcx.fn_trait_kind_from_lang_item(trait_ref.def_id()).is_some() =>
+                    if infcx.tcx.fn_trait_kind_from_lang_item(trait_ref.def_id()).is_some() =>
                 {
-                    tcx.mk_fn_sig(
+                    infcx.tcx.mk_fn_sig(
                         inputs.iter(),
-                        tcx.mk_ty_infer(ty::TyVar(ty::TyVid::from_u32(0))),
+                        infcx.next_ty_var(TypeVariableOrigin {
+                            span: DUMMY_SP,
+                            kind: TypeVariableOriginKind::MiscVariable,
+                        }),
                         false,
                         hir::Unsafety::Normal,
                         abi::Abi::Rust,
                     )
                 }
-                _ => tcx.mk_fn_sig(
+                _ => infcx.tcx.mk_fn_sig(
                     std::iter::once(inputs),
-                    tcx.mk_ty_infer(ty::TyVar(ty::TyVid::from_u32(0))),
+                    infcx.next_ty_var(TypeVariableOrigin {
+                        span: DUMMY_SP,
+                        kind: TypeVariableOriginKind::MiscVariable,
+                    }),
                     false,
                     hir::Unsafety::Normal,
                     abi::Abi::Rust,
                 ),
             };
 
-            tcx.mk_fn_ptr(trait_ref.rebind(sig))
+            infcx.tcx.mk_fn_ptr(trait_ref.rebind(sig))
         }
 
         let argument_kind = match expected.skip_binder().self_ty().kind() {
@@ -1585,8 +1592,8 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
         let found_span = found_span.unwrap_or(span);
         err.span_label(found_span, "found signature defined here");
 
-        let expected = build_fn_sig_ty(self.tcx, expected);
-        let found = build_fn_sig_ty(self.tcx, found);
+        let expected = build_fn_sig_ty(self, expected);
+        let found = build_fn_sig_ty(self, found);
 
         let (expected_str, found_str) = self.cmp(expected, found);