Skip to content

Commit 477e410

Browse files
committed
refactor WorkerLocal
1 parent d610b0c commit 477e410

File tree

2 files changed

+48
-31
lines changed

2 files changed

+48
-31
lines changed

compiler/rustc_data_structures/Cargo.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ indexmap = { version = "1.9.1" }
1414
jobserver_crate = { version = "0.1.13", package = "jobserver" }
1515
libc = "0.2"
1616
measureme = "10.0.0"
17-
rayon-core = { version = "0.4.0", package = "rustc-rayon-core", optional = true }
17+
rayon-core = { version = "0.4.0", package = "rustc-rayon-core" }
1818
rayon = { version = "0.4.0", package = "rustc-rayon", optional = true }
1919
rustc_graphviz = { path = "../rustc_graphviz" }
2020
rustc-hash = "1.1.0"
@@ -43,4 +43,4 @@ winapi = { version = "0.3", features = ["fileapi", "psapi", "winerror"] }
4343
memmap2 = "0.2.1"
4444

4545
[features]
46-
rustc_use_parallel_compiler = ["indexmap/rustc-rayon", "rayon", "rayon-core"]
46+
rustc_use_parallel_compiler = ["indexmap/rustc-rayon", "rayon"]

compiler/rustc_data_structures/src/sync.rs

+46-29
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
use crate::owning_ref::{Erased, OwningRef};
2121
use std::collections::HashMap;
2222
use std::hash::{BuildHasher, Hash};
23+
use std::mem::MaybeUninit;
2324
use std::ops::{Deref, DerefMut};
2425
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
2526

@@ -30,6 +31,8 @@ pub use vec::AppendOnlyVec;
3031

3132
mod vec;
3233

34+
static PARALLEL: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
35+
3336
cfg_if! {
3437
if #[cfg(not(parallel_compiler))] {
3538
pub auto trait Send {}
@@ -182,33 +185,6 @@ cfg_if! {
182185

183186
use std::cell::Cell;
184187

185-
#[derive(Debug)]
186-
pub struct WorkerLocal<T>(OneThread<T>);
187-
188-
impl<T> WorkerLocal<T> {
189-
/// Creates a new worker local where the `initial` closure computes the
190-
/// value this worker local should take for each thread in the thread pool.
191-
#[inline]
192-
pub fn new<F: FnMut(usize) -> T>(mut f: F) -> WorkerLocal<T> {
193-
WorkerLocal(OneThread::new(f(0)))
194-
}
195-
196-
/// Returns the worker-local value for each thread
197-
#[inline]
198-
pub fn into_inner(self) -> Vec<T> {
199-
vec![OneThread::into_inner(self.0)]
200-
}
201-
}
202-
203-
impl<T> Deref for WorkerLocal<T> {
204-
type Target = T;
205-
206-
#[inline(always)]
207-
fn deref(&self) -> &T {
208-
&self.0
209-
}
210-
}
211-
212188
pub type MTRef<'a, T> = &'a mut T;
213189

214190
#[derive(Debug, Default)]
@@ -328,8 +304,6 @@ cfg_if! {
328304
};
329305
}
330306

331-
pub use rayon_core::WorkerLocal;
332-
333307
pub use rayon::iter::ParallelIterator;
334308
use rayon::iter::IntoParallelIterator;
335309

@@ -364,6 +338,49 @@ cfg_if! {
364338
}
365339
}
366340

341+
#[derive(Debug)]
342+
pub struct WorkerLocal<T> {
343+
single_thread: bool,
344+
inner: T,
345+
mt_inner: Option<rayon_core::WorkerLocal<T>>,
346+
}
347+
348+
impl<T> WorkerLocal<T> {
349+
/// Creates a new worker local where the `initial` closure computes the
350+
/// value this worker local should take for each thread in the thread pool.
351+
#[inline]
352+
pub fn new<F: FnMut(usize) -> T>(mut f: F) -> WorkerLocal<T> {
353+
if !PARALLEL.load(Ordering::Relaxed) {
354+
WorkerLocal { single_thread: true, inner: f(0), mt_inner: None }
355+
} else {
356+
// Safety: `inner` would never be accessed when multiple threads
357+
WorkerLocal {
358+
single_thread: false,
359+
inner: unsafe { MaybeUninit::uninit().assume_init() },
360+
mt_inner: Some(rayon_core::WorkerLocal::new(f)),
361+
}
362+
}
363+
}
364+
365+
/// Returns the worker-local value for each thread
366+
#[inline]
367+
pub fn into_inner(self) -> Vec<T> {
368+
if self.single_thread { vec![self.inner] } else { self.mt_inner.unwrap().into_inner() }
369+
}
370+
}
371+
372+
impl<T> Deref for WorkerLocal<T> {
373+
type Target = T;
374+
375+
#[inline(always)]
376+
fn deref(&self) -> &T {
377+
if self.single_thread { &self.inner } else { self.mt_inner.as_ref().unwrap().deref() }
378+
}
379+
}
380+
381+
// Just for speed test
382+
unsafe impl<T: Send> std::marker::Sync for WorkerLocal<T> {}
383+
367384
pub fn assert_sync<T: ?Sized + Sync>() {}
368385
pub fn assert_send<T: ?Sized + Send>() {}
369386
pub fn assert_send_val<T: ?Sized + Send>(_t: &T) {}

0 commit comments

Comments
 (0)