From 9936a399dfd85b7ed07f948d602f332e5258f2c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?John=20K=C3=A5re=20Alsaker?= <john.kare.alsaker@gmail.com>
Date: Fri, 25 Aug 2023 17:26:24 +0200
Subject: [PATCH] Add a `CurrentGcx` type to let the deadlock handler access
 `TyCtxt`

---
 compiler/rustc_interface/src/interface.rs | 13 ++++--
 compiler/rustc_interface/src/passes.rs    |  1 +
 compiler/rustc_interface/src/util.rs      | 33 +++++++++-----
 compiler/rustc_middle/src/ty/context.rs   | 53 ++++++++++++++++++++++-
 compiler/rustc_middle/src/ty/mod.rs       |  4 +-
 5 files changed, 88 insertions(+), 16 deletions(-)

diff --git a/compiler/rustc_interface/src/interface.rs b/compiler/rustc_interface/src/interface.rs
index 656c7ffae1933..da2fb490a36f2 100644
--- a/compiler/rustc_interface/src/interface.rs
+++ b/compiler/rustc_interface/src/interface.rs
@@ -10,7 +10,9 @@ use rustc_data_structures::sync::Lrc;
 use rustc_errors::registry::Registry;
 use rustc_errors::{DiagCtxt, ErrorGuaranteed};
 use rustc_lint::LintStore;
+
 use rustc_middle::ty;
+use rustc_middle::ty::CurrentGcx;
 use rustc_middle::util::Providers;
 use rustc_parse::maybe_new_parser_from_source_str;
 use rustc_query_impl::QueryCtxt;
@@ -39,6 +41,7 @@ pub struct Compiler {
     pub sess: Session,
     pub codegen_backend: Box<dyn CodegenBackend>,
     pub(crate) override_queries: Option<fn(&Session, &mut Providers)>,
+    pub(crate) current_gcx: CurrentGcx,
 }
 
 /// Converts strings provided as `--cfg [cfgspec]` into a `Cfg`.
@@ -336,7 +339,7 @@ pub fn run_compiler<R: Send>(config: Config, f: impl FnOnce(&Compiler) -> R + Se
     util::run_in_thread_pool_with_globals(
         config.opts.edition,
         config.opts.unstable_opts.threads,
-        || {
+        |current_gcx| {
             crate::callbacks::setup_callbacks();
 
             let early_dcx = EarlyDiagCtxt::new(config.opts.error_format);
@@ -430,8 +433,12 @@ pub fn run_compiler<R: Send>(config: Config, f: impl FnOnce(&Compiler) -> R + Se
             }
             sess.lint_store = Some(Lrc::new(lint_store));
 
-            let compiler =
-                Compiler { sess, codegen_backend, override_queries: config.override_queries };
+            let compiler = Compiler {
+                sess,
+                codegen_backend,
+                override_queries: config.override_queries,
+                current_gcx,
+            };
 
             rustc_span::set_source_map(compiler.sess.psess.clone_source_map(), move || {
                 // There are two paths out of `f`.
diff --git a/compiler/rustc_interface/src/passes.rs b/compiler/rustc_interface/src/passes.rs
index 4b4c1d6cf672b..d763a12f816b0 100644
--- a/compiler/rustc_interface/src/passes.rs
+++ b/compiler/rustc_interface/src/passes.rs
@@ -680,6 +680,7 @@ pub fn create_global_ctxt<'tcx>(
                     incremental,
                 ),
                 providers.hooks,
+                compiler.current_gcx.clone(),
             )
         })
     })
diff --git a/compiler/rustc_interface/src/util.rs b/compiler/rustc_interface/src/util.rs
index d09f8d7d7cf75..d0f04fccc4810 100644
--- a/compiler/rustc_interface/src/util.rs
+++ b/compiler/rustc_interface/src/util.rs
@@ -5,6 +5,7 @@ use rustc_codegen_ssa::traits::CodegenBackend;
 #[cfg(parallel_compiler)]
 use rustc_data_structures::sync;
 use rustc_metadata::{load_symbol_from_dylib, DylibError};
+use rustc_middle::ty::CurrentGcx;
 use rustc_parse::validate_attr;
 use rustc_session as session;
 use rustc_session::config::{Cfg, OutFileName, OutputFilenames, OutputTypes};
@@ -64,7 +65,7 @@ fn init_stack_size() -> usize {
     })
 }
 
-pub(crate) fn run_in_thread_with_globals<F: FnOnce() -> R + Send, R: Send>(
+pub(crate) fn run_in_thread_with_globals<F: FnOnce(CurrentGcx) -> R + Send, R: Send>(
     edition: Edition,
     f: F,
 ) -> R {
@@ -82,7 +83,9 @@ pub(crate) fn run_in_thread_with_globals<F: FnOnce() -> R + Send, R: Send>(
         // `unwrap` is ok here because `spawn_scoped` only panics if the thread
         // name contains null bytes.
         let r = builder
-            .spawn_scoped(s, move || rustc_span::create_session_globals_then(edition, f))
+            .spawn_scoped(s, move || {
+                rustc_span::create_session_globals_then(edition, || f(CurrentGcx::new()))
+            })
             .unwrap()
             .join();
 
@@ -94,7 +97,7 @@ pub(crate) fn run_in_thread_with_globals<F: FnOnce() -> R + Send, R: Send>(
 }
 
 #[cfg(not(parallel_compiler))]
-pub(crate) fn run_in_thread_pool_with_globals<F: FnOnce() -> R + Send, R: Send>(
+pub(crate) fn run_in_thread_pool_with_globals<F: FnOnce(CurrentGcx) -> R + Send, R: Send>(
     edition: Edition,
     _threads: usize,
     f: F,
@@ -103,7 +106,7 @@ pub(crate) fn run_in_thread_pool_with_globals<F: FnOnce() -> R + Send, R: Send>(
 }
 
 #[cfg(parallel_compiler)]
-pub(crate) fn run_in_thread_pool_with_globals<F: FnOnce() -> R + Send, R: Send>(
+pub(crate) fn run_in_thread_pool_with_globals<F: FnOnce(CurrentGcx) -> R + Send, R: Send>(
     edition: Edition,
     threads: usize,
     f: F,
@@ -117,24 +120,34 @@ pub(crate) fn run_in_thread_pool_with_globals<F: FnOnce() -> R + Send, R: Send>(
     let registry = sync::Registry::new(std::num::NonZero::new(threads).unwrap());
 
     if !sync::is_dyn_thread_safe() {
-        return run_in_thread_with_globals(edition, || {
+        return run_in_thread_with_globals(edition, |current_gcx| {
             // Register the thread for use with the `WorkerLocal` type.
             registry.register();
 
-            f()
+            f(current_gcx)
         });
     }
 
+    let current_gcx = FromDyn::from(CurrentGcx::new());
+    let current_gcx2 = current_gcx.clone();
+
     let builder = rayon::ThreadPoolBuilder::new()
         .thread_name(|_| "rustc".to_string())
         .acquire_thread_handler(jobserver::acquire_thread)
         .release_thread_handler(jobserver::release_thread)
         .num_threads(threads)
-        .deadlock_handler(|| {
+        .deadlock_handler(move || {
             // On deadlock, creates a new thread and forwards information in thread
             // locals to it. The new thread runs the deadlock handler.
-            let query_map =
-                FromDyn::from(tls::with(|tcx| QueryCtxt::new(tcx).collect_active_jobs()));
+
+            // Get a `GlobalCtxt` reference from `CurrentGcx` as we cannot rely on having a
+            // `TyCtxt` TLS reference here.
+            let query_map = current_gcx2.access(|gcx| {
+                tls::enter_context(&tls::ImplicitCtxt::new(gcx), || {
+                    tls::with(|tcx| QueryCtxt::new(tcx).collect_active_jobs())
+                })
+            });
+            let query_map = FromDyn::from(query_map);
             let registry = rayon_core::Registry::current();
             thread::Builder::new()
                 .name("rustc query cycle handler".to_string())
@@ -171,7 +184,7 @@ pub(crate) fn run_in_thread_pool_with_globals<F: FnOnce() -> R + Send, R: Send>(
                         })
                     },
                     // Run `f` on the first thread in the thread pool.
-                    move |pool: &rayon::ThreadPool| pool.install(f),
+                    move |pool: &rayon::ThreadPool| pool.install(|| f(current_gcx.into_inner())),
                 )
                 .unwrap()
         })
diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs
index 3393f44484388..188cb50849dc1 100644
--- a/compiler/rustc_middle/src/ty/context.rs
+++ b/compiler/rustc_middle/src/ty/context.rs
@@ -32,6 +32,7 @@ use crate::ty::{
 };
 use crate::ty::{GenericArg, GenericArgs, GenericArgsRef};
 use rustc_ast::{self as ast, attr};
+use rustc_data_structures::defer;
 use rustc_data_structures::fingerprint::Fingerprint;
 use rustc_data_structures::fx::{FxHashMap, FxHashSet};
 use rustc_data_structures::intern::Interned;
@@ -39,7 +40,7 @@ use rustc_data_structures::profiling::SelfProfilerRef;
 use rustc_data_structures::sharded::{IntoPointer, ShardedHashMap};
 use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
 use rustc_data_structures::steal::Steal;
-use rustc_data_structures::sync::{self, FreezeReadGuard, Lock, Lrc, WorkerLocal};
+use rustc_data_structures::sync::{self, FreezeReadGuard, Lock, Lrc, RwLock, WorkerLocal};
 #[cfg(parallel_compiler)]
 use rustc_data_structures::sync::{DynSend, DynSync};
 use rustc_data_structures::unord::UnordSet;
@@ -723,6 +724,8 @@ pub struct GlobalCtxt<'tcx> {
 
     /// Stores memory for globals (statics/consts).
     pub(crate) alloc_map: Lock<interpret::AllocMap<'tcx>>,
+
+    current_gcx: CurrentGcx,
 }
 
 impl<'tcx> GlobalCtxt<'tcx> {
@@ -733,6 +736,19 @@ impl<'tcx> GlobalCtxt<'tcx> {
         F: FnOnce(TyCtxt<'tcx>) -> R,
     {
         let icx = tls::ImplicitCtxt::new(self);
+
+        // Reset `current_gcx` to `None` when we exit.
+        let _on_drop = defer(move || {
+            *self.current_gcx.value.write() = None;
+        });
+
+        // Set this `GlobalCtxt` as the current one.
+        {
+            let mut guard = self.current_gcx.value.write();
+            assert!(guard.is_none(), "no `GlobalCtxt` is currently set");
+            *guard = Some(self as *const _ as *const ());
+        }
+
         tls::enter_context(&icx, || f(icx.tcx))
     }
 
@@ -741,6 +757,39 @@ impl<'tcx> GlobalCtxt<'tcx> {
     }
 }
 
+/// This is used to get a reference to a `GlobalCtxt` if one is available.
+///
+/// This is needed to allow the deadlock handler access to `GlobalCtxt` to look for query cycles.
+/// It cannot use the `TLV` global because that's only guaranteed to be defined on the thread
+/// creating the `GlobalCtxt`. Other threads have access to the `TLV` only inside Rayon jobs, but
+/// the deadlock handler is not called inside such a job.
+#[derive(Clone)]
+pub struct CurrentGcx {
+    /// This stores a pointer to a `GlobalCtxt`. This is set to `Some` inside `GlobalCtxt::enter`
+    /// and reset to `None` when that function returns or unwinds.
+    value: Lrc<RwLock<Option<*const ()>>>,
+}
+
+#[cfg(parallel_compiler)]
+unsafe impl DynSend for CurrentGcx {}
+#[cfg(parallel_compiler)]
+unsafe impl DynSync for CurrentGcx {}
+
+impl CurrentGcx {
+    pub fn new() -> Self {
+        Self { value: Lrc::new(RwLock::new(None)) }
+    }
+
+    pub fn access<R>(&self, f: impl for<'tcx> FnOnce(&'tcx GlobalCtxt<'tcx>) -> R) -> R {
+        let read_guard = self.value.read();
+        let gcx: *const GlobalCtxt<'_> = read_guard.unwrap() as *const _;
+        // SAFETY: We hold the read lock for the `GlobalCtxt` pointer. That prevents
+        // `GlobalCtxt::enter` from returning as it would first acquire the write lock.
+        // This ensures the `GlobalCtxt` is live during `f`.
+        f(unsafe { &*gcx })
+    }
+}
+
 impl<'tcx> TyCtxt<'tcx> {
     /// Expects a body and returns its codegen attributes.
     ///
@@ -859,6 +908,7 @@ impl<'tcx> TyCtxt<'tcx> {
         query_kinds: &'tcx [DepKindStruct<'tcx>],
         query_system: QuerySystem<'tcx>,
         hooks: crate::hooks::Providers,
+        current_gcx: CurrentGcx,
     ) -> GlobalCtxt<'tcx> {
         let data_layout = s.target.parse_data_layout().unwrap_or_else(|err| {
             s.dcx().emit_fatal(err);
@@ -893,6 +943,7 @@ impl<'tcx> TyCtxt<'tcx> {
             canonical_param_env_cache: Default::default(),
             data_layout,
             alloc_map: Lock::new(interpret::AllocMap::new()),
+            current_gcx,
         }
     }
 
diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs
index 6ce53ccc8cd7a..6168874192042 100644
--- a/compiler/rustc_middle/src/ty/mod.rs
+++ b/compiler/rustc_middle/src/ty/mod.rs
@@ -86,8 +86,8 @@ pub use self::consts::{
     Const, ConstData, ConstInt, ConstKind, Expr, ScalarInt, UnevaluatedConst, ValTree,
 };
 pub use self::context::{
-    tls, CtxtInterners, DeducedParamAttrs, Feed, FreeRegionInfo, GlobalCtxt, Lift, TyCtxt,
-    TyCtxtFeed,
+    tls, CtxtInterners, CurrentGcx, DeducedParamAttrs, Feed, FreeRegionInfo, GlobalCtxt, Lift,
+    TyCtxt, TyCtxtFeed,
 };
 pub use self::instance::{Instance, InstanceDef, ShortInstance, UnusedGenericParams};
 pub use self::list::List;