Commit 7592d41

james7132 and notgull authored
feat: Use actual thread local queues instead of using a RwLock
Currently, runner-local queues rely on a `RwLock<Vec<Arc<ConcurrentQueue>>>` to store the queues rather than actual thread-local storage. This change adds `thread_local` as a dependency, but in exchange the executor can work-steal without holding a lock, and tasks can be scheduled directly onto the local queue where possible instead of always going through the global injector queue.

Fixes #62

Co-authored-by: John Nunley <[email protected]>
1 parent 188f976 commit 7592d41
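
Not part of the commit, but for orientation: a minimal sketch contrasting the two layouts. The real `LocalQueue` (defined at the bottom of src/lib.rs below) also carries an `AtomicWaker`; here a bare queue of `u32` stands in for it.

```rust
use std::sync::{Arc, RwLock};

use concurrent_queue::ConcurrentQueue;
use thread_local::ThreadLocal;

// Before: a shared registry of runner queues. Every steal takes a read
// lock; registering or removing a runner takes a write lock.
#[allow(dead_code)]
struct Before {
    local_queues: RwLock<Vec<Arc<ConcurrentQueue<u32>>>>,
}

// After: one queue per thread, created lazily on first access. Any
// thread can still iterate over all queues for stealing, lock-free.
struct After {
    local_queue: ThreadLocal<ConcurrentQueue<u32>>,
}

fn main() {
    let state = After {
        local_queue: ThreadLocal::new(),
    };

    // The current thread gets (or lazily creates) its own queue...
    state
        .local_queue
        .get_or(|| ConcurrentQueue::bounded(512))
        .push(1)
        .unwrap();

    // ...while a would-be thief can walk every per-thread queue without
    // taking any lock.
    let queued: usize = state.local_queue.iter().map(|q| q.len()).sum();
    assert_eq!(queued, 1);
}
```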

4 files changed: +92 -47 lines changed

Cargo.toml

Lines changed: 3 additions & 1 deletion
```diff
@@ -6,7 +6,7 @@ name = "async-executor"
 version = "1.8.0"
 authors = ["Stjepan Glavina <[email protected]>"]
 edition = "2021"
-rust-version = "1.60"
+rust-version = "1.61"
 description = "Async executor"
 license = "Apache-2.0 OR MIT"
 repository = "https://github.com/smol-rs/async-executor"
@@ -17,10 +17,12 @@ exclude = ["/.*"]
 [dependencies]
 async-lock = "3.0.0"
 async-task = "4.4.0"
+atomic-waker = "1.0"
 concurrent-queue = "2.0.0"
 fastrand = "2.0.0"
 futures-lite = { version = "2.0.0", default-features = false }
 slab = "0.4.4"
+thread_local = "1.1"

 [target.'cfg(target_family = "wasm")'.dependencies]
 futures-lite = { version = "2.0.0", default-features = false, features = ["std"] }
```

benches/executor.rs

Lines changed: 0 additions & 1 deletion
```diff
@@ -1,4 +1,3 @@
-use std::future::Future;
 use std::thread::available_parallelism;

 use async_executor::Executor;
```

examples/priority.rs

Lines changed: 0 additions & 1 deletion
```diff
@@ -1,6 +1,5 @@
 //! An executor with task priorities.

-use std::future::Future;
 use std::thread;

 use async_executor::{Executor, Task};
```
src/lib.rs

Lines changed: 89 additions & 44 deletions
```diff
@@ -34,19 +34,20 @@
 )]

 use std::fmt;
-use std::future::Future;
 use std::marker::PhantomData;
 use std::panic::{RefUnwindSafe, UnwindSafe};
 use std::rc::Rc;
 use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::{Arc, Mutex, RwLock, TryLockError};
+use std::sync::{Arc, Mutex, TryLockError};
 use std::task::{Poll, Waker};

 use async_lock::OnceCell;
 use async_task::{Builder, Runnable};
+use atomic_waker::AtomicWaker;
 use concurrent_queue::ConcurrentQueue;
 use futures_lite::{future, prelude::*};
 use slab::Slab;
+use thread_local::ThreadLocal;

 #[doc(no_inline)]
 pub use async_task::Task;
```
```diff
@@ -266,8 +267,23 @@ impl<'a> Executor<'a> {
     fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
         let state = self.state().clone();

-        // TODO: If possible, push into the current local queue and notify the ticker.
-        move |runnable| {
+        move |mut runnable| {
+            // If possible, push into the current local queue and notify the ticker.
+            if let Some(local) = state.local_queue.get() {
+                runnable = if let Err(err) = local.queue.push(runnable) {
+                    err.into_inner()
+                } else {
+                    // Wake up this thread if it's asleep, otherwise notify another
+                    // thread to try to have the task stolen.
+                    if let Some(waker) = local.waker.take() {
+                        waker.wake();
+                    } else {
+                        state.notify();
+                    }
+                    return;
+                }
+            }
+            // If the local queue is full, fallback to pushing onto the global injector queue.
             state.queue.push(runnable).unwrap();
             state.notify();
         }
```
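
The `err.into_inner()` above works because `concurrent-queue` hands the rejected element back inside the `PushError`, so a runnable that does not fit into the bounded local queue is never lost. A standalone sketch of that recovery-and-fallback pattern, with plain integers standing in for runnables (not code from this commit):

```rust
use concurrent_queue::ConcurrentQueue;

fn main() {
    let local: ConcurrentQueue<u32> = ConcurrentQueue::bounded(1);
    let global: ConcurrentQueue<u32> = ConcurrentQueue::unbounded();

    for task in 0..3 {
        // Try the bounded local queue first; on failure, `into_inner`
        // returns ownership of the rejected value so it can be re-routed.
        if let Err(err) = local.push(task) {
            global.push(err.into_inner()).unwrap();
        }
    }

    assert_eq!(local.len(), 1); // task 0 stayed local
    assert_eq!(global.len(), 2); // tasks 1 and 2 spilled over
}
```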
```diff
@@ -510,7 +526,16 @@ struct State {
     queue: ConcurrentQueue<Runnable>,

     /// Local queues created by runners.
-    local_queues: RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>,
+    ///
+    /// If possible, tasks are scheduled onto the local queue, and will only
+    /// defer to the global queue when the local queue is full, or when the task
+    /// is being scheduled from a thread without a runner.
+    ///
+    /// Note: if a runner terminates and drains its local queue, any subsequent
+    /// spawn calls from the same thread will be added to the same queue, but won't
+    /// be executed until `Executor::run` is run on the thread again, or another
+    /// thread steals the task.
+    local_queue: ThreadLocal<LocalQueue>,

     /// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
     notified: AtomicBool,
@@ -527,7 +552,7 @@ impl State {
     fn new() -> State {
         State {
             queue: ConcurrentQueue::unbounded(),
-            local_queues: RwLock::new(Vec::new()),
+            local_queue: ThreadLocal::new(),
             notified: AtomicBool::new(true),
             sleepers: Mutex::new(Sleepers {
                 count: 0,
```
```diff
@@ -654,6 +679,12 @@ impl Ticker<'_> {
     ///
     /// Returns `false` if the ticker was already sleeping and unnotified.
     fn sleep(&mut self, waker: &Waker) -> bool {
+        self.state
+            .local_queue
+            .get_or_default()
+            .waker
+            .register(waker);
+
         let mut sleepers = self.state.sleepers.lock().unwrap();

         match self.sleeping {
```
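
This registration is one half of a handshake with the `schedule` closure above: the ticker registers its waker before going to sleep, and a producer that pushes onto this thread's local queue takes the waker and wakes it. A minimal sketch of that protocol, using only the `atomic-waker` and `futures-lite` crates already in the dependency list (not code from this commit):

```rust
use std::sync::Arc;
use std::task::Poll;

use atomic_waker::AtomicWaker;
use concurrent_queue::ConcurrentQueue;
use futures_lite::future;

fn main() {
    let waker = Arc::new(AtomicWaker::new());
    let queue = Arc::new(ConcurrentQueue::<u32>::unbounded());

    // Producer: push an item, then wake whoever registered last.
    {
        let (waker, queue) = (waker.clone(), queue.clone());
        std::thread::spawn(move || {
            queue.push(42).unwrap();
            // `take` hands out the registered waker at most once.
            if let Some(w) = waker.take() {
                w.wake();
            }
        });
    }

    // Consumer: register *before* the final emptiness check, so a push
    // racing with us going to sleep still results in a wake-up.
    let item = future::block_on(future::poll_fn(|cx| {
        if let Ok(item) = queue.pop() {
            return Poll::Ready(item);
        }
        waker.register(cx.waker());
        match queue.pop() {
            Ok(item) => Poll::Ready(item),
            Err(_) => Poll::Pending,
        }
    }));
    assert_eq!(item, 42);
}
```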
```diff
@@ -692,7 +723,14 @@ impl Ticker<'_> {

     /// Waits for the next runnable task to run.
     async fn runnable(&mut self) -> Runnable {
-        self.runnable_with(|| self.state.queue.pop().ok()).await
+        self.runnable_with(|| {
+            self.state
+                .local_queue
+                .get()
+                .and_then(|local| local.queue.pop().ok())
+                .or_else(|| self.state.queue.pop().ok())
+        })
+        .await
     }

     /// Waits for the next runnable task to run, given a function that searches for a task.
@@ -754,9 +792,6 @@ struct Runner<'a> {
     /// Inner ticker.
     ticker: Ticker<'a>,

-    /// The local queue.
-    local: Arc<ConcurrentQueue<Runnable>>,
-
     /// Bumped every time a runnable task is found.
     ticks: usize,
 }
@@ -767,38 +802,34 @@ impl Runner<'_> {
         let runner = Runner {
             state,
             ticker: Ticker::new(state),
-            local: Arc::new(ConcurrentQueue::bounded(512)),
             ticks: 0,
         };
-        state
-            .local_queues
-            .write()
-            .unwrap()
-            .push(runner.local.clone());
         runner
     }

     /// Waits for the next runnable task to run.
     async fn runnable(&mut self, rng: &mut fastrand::Rng) -> Runnable {
+        let local = self.state.local_queue.get_or_default();
+
         let runnable = self
             .ticker
             .runnable_with(|| {
                 // Try the local queue.
-                if let Ok(r) = self.local.pop() {
+                if let Ok(r) = local.queue.pop() {
                     return Some(r);
                 }

                 // Try stealing from the global queue.
                 if let Ok(r) = self.state.queue.pop() {
-                    steal(&self.state.queue, &self.local);
+                    steal(&self.state.queue, &local.queue);
                     return Some(r);
                 }

                 // Try stealing from other runners.
-                let local_queues = self.state.local_queues.read().unwrap();
+                let local_queues = &self.state.local_queue;

                 // Pick a random starting point in the iterator list and rotate the list.
-                let n = local_queues.len();
+                let n = local_queues.iter().count();
                 let start = rng.usize(..n);
                 let iter = local_queues
                     .iter()
@@ -807,12 +838,12 @@ impl Runner<'_> {
                     .take(n);

                 // Remove this runner's local queue.
-                let iter = iter.filter(|local| !Arc::ptr_eq(local, &self.local));
+                let iter = iter.filter(|other| !core::ptr::eq(*other, local));

                 // Try stealing from each local queue in the list.
-                for local in iter {
-                    steal(local, &self.local);
-                    if let Ok(r) = self.local.pop() {
+                for other in iter {
+                    steal(&other.queue, &local.queue);
+                    if let Ok(r) = local.queue.pop() {
                         return Some(r);
                     }
                 }
```
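
The hidden context lines between these two hunks turn `start` into a rotation of the queue list: every queue is visited exactly once, beginning at a random offset, so no single runner is always the first steal target. A self-contained sketch of the rotate-and-truncate idiom, using `cycle` as one of several ways to express the wraparound (not code from this commit):

```rust
fn main() {
    let queues = vec!["q0", "q1", "q2", "q3"];
    let n = queues.len();
    let start = fastrand::usize(..n);

    // Visit every queue exactly once, starting at a random offset and
    // wrapping around at the end of the list.
    for q in queues.iter().cycle().skip(start).take(n) {
        println!("try stealing from {q}");
    }
}
```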
```diff
@@ -826,7 +857,7 @@ impl Runner<'_> {

         if self.ticks % 64 == 0 {
             // Steal tasks from the global queue to ensure fair task scheduling.
-            steal(&self.state.queue, &self.local);
+            steal(&self.state.queue, &local.queue);
         }

         runnable
@@ -836,15 +867,13 @@ impl Runner<'_> {
 impl Drop for Runner<'_> {
     fn drop(&mut self) {
         // Remove the local queue.
-        self.state
-            .local_queues
-            .write()
-            .unwrap()
-            .retain(|local| !Arc::ptr_eq(local, &self.local));
-
-        // Re-schedule remaining tasks in the local queue.
-        while let Ok(r) = self.local.pop() {
-            r.schedule();
+        if let Some(local) = self.state.local_queue.get() {
+            // Re-schedule remaining tasks in the local queue.
+            for r in local.queue.try_iter() {
+                // Explicitly reschedule the runnable back onto the global
+                // queue to avoid rescheduling onto the local one.
+                self.state.queue.push(r).unwrap();
+            }
         }
     }
 }
@@ -904,18 +933,13 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 }

 /// Debug wrapper for the local runners.
-struct LocalRunners<'a>(&'a RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>);
+struct LocalRunners<'a>(&'a ThreadLocal<LocalQueue>);

 impl fmt::Debug for LocalRunners<'_> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match self.0.try_read() {
-            Ok(lock) => f
-                .debug_list()
-                .entries(lock.iter().map(|queue| queue.len()))
-                .finish(),
-            Err(TryLockError::WouldBlock) => f.write_str("<locked>"),
-            Err(TryLockError::Poisoned(_)) => f.write_str("<poisoned>"),
-        }
+        f.debug_list()
+            .entries(self.0.iter().map(|local| local.queue.len()))
+            .finish()
     }
 }
@@ -935,11 +959,32 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     f.debug_struct(name)
         .field("active", &ActiveTasks(&state.active))
         .field("global_tasks", &state.queue.len())
-        .field("local_runners", &LocalRunners(&state.local_queues))
+        .field("local_runners", &LocalRunners(&state.local_queue))
         .field("sleepers", &SleepCount(&state.sleepers))
         .finish()
 }

+/// A queue local to each thread.
+///
+/// Its `Default` implementation is used for initializing each
+/// thread's queue via `ThreadLocal::get_or_default`.
+///
+/// The local queue *must* be flushed, and all pending runnables
+/// rescheduled onto the global queue when a runner is dropped.
+struct LocalQueue {
+    queue: ConcurrentQueue<Runnable>,
+    waker: AtomicWaker,
+}
+
+impl Default for LocalQueue {
+    fn default() -> Self {
+        Self {
+            queue: ConcurrentQueue::bounded(512),
+            waker: AtomicWaker::new(),
+        }
+    }
+}
+
 /// Runs a closure when dropped.
 struct CallOnDrop<F: FnMut()>(F);
```
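
Not from the commit: a small demo of how `ThreadLocal::get_or_default` lazily builds a thread's queue through a `Default` impl like the one above, and how `try_iter` drains it the way `Runner::drop` re-injects leftovers into the global queue. This uses a simplified `LocalQueue` without the waker and `u32` in place of `Runnable`.

```rust
use concurrent_queue::ConcurrentQueue;
use thread_local::ThreadLocal;

struct LocalQueue {
    queue: ConcurrentQueue<u32>,
}

impl Default for LocalQueue {
    fn default() -> Self {
        Self {
            queue: ConcurrentQueue::bounded(512),
        }
    }
}

fn main() {
    let tls: ThreadLocal<LocalQueue> = ThreadLocal::new();
    let global: ConcurrentQueue<u32> = ConcurrentQueue::unbounded();

    // First access on this thread runs `Default::default`; later
    // accesses return the same queue.
    tls.get_or_default().queue.push(7).unwrap();

    // `try_iter` pops items until the queue is empty -- the same drain
    // a runner performs on shutdown before re-injecting globally.
    if let Some(local) = tls.get() {
        for task in local.queue.try_iter() {
            global.push(task).unwrap();
        }
    }
    assert_eq!(global.len(), 1);
}
```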
