RFC: Work-stealing scheduler #43366
Changes from all commits: bcaf64e, f2b47d2, b702ba9, a730ab1, dd7d5c2, 36a8436, b1faf76, 46173e6, f94ca11, ecec557, a4a20cb, e862411, 531e479, 0e50d01
@@ -14,7 +14,6 @@
extern "C" {
#endif

// thread sleep state

// default to DEFAULT_THREAD_SLEEP_THRESHOLD; set via $JULIA_THREAD_SLEEP_THRESHOLD
@@ -57,6 +56,12 @@ JL_DLLEXPORT int jl_set_task_tid(jl_task_t *task, int tid) JL_NOTSAFEPOINT
extern int jl_gc_mark_queue_obj_explicit(jl_gc_mark_cache_t *gc_cache,
                                         jl_gc_mark_sp_t *sp, jl_value_t *obj) JL_NOTSAFEPOINT;

// partr dynamic dispatch
void (*jl_gc_mark_enqueued_tasks)(jl_gc_mark_cache_t *, jl_gc_mark_sp_t *);
static int (*partr_enqueue_task)(jl_task_t *, int16_t);
static jl_task_t *(*partr_dequeue_task)(void);
static int (*partr_check_empty)(void);

// multiq
// ---
@@ -83,20 +88,6 @@ static int32_t heap_p;
static uint64_t cong_unbias;

-static inline void multiq_init(void)
-{
-    heap_p = heap_c * jl_n_threads;
-    heaps = (taskheap_t *)calloc(heap_p, sizeof(taskheap_t));
-    for (int32_t i = 0; i < heap_p; ++i) {
-        uv_mutex_init(&heaps[i].lock);
-        heaps[i].tasks = (jl_task_t **)calloc(tasks_per_heap, sizeof(jl_task_t*));
-        jl_atomic_store_relaxed(&heaps[i].ntasks, 0);
-        jl_atomic_store_relaxed(&heaps[i].prio, INT16_MAX);
-    }
-    unbias_cong(heap_p, &cong_unbias);
-}

static inline void sift_up(taskheap_t *heap, int32_t idx)
{
    if (idx > 0) {
@@ -208,7 +199,7 @@ static inline jl_task_t *multiq_deletemin(void)
}

-void jl_gc_mark_enqueued_tasks(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp_t *sp)
+void multiq_gc_mark_enqueued_tasks(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp_t *sp)
{
    int32_t i, j;
    for (i = 0; i < heap_p; ++i)
@@ -228,6 +219,170 @@ static int multiq_check_empty(void)
}

static inline void multiq_init(void)
{
    heap_p = heap_c * jl_n_threads;
    heaps = (taskheap_t *)calloc(heap_p, sizeof(taskheap_t));
    for (int32_t i = 0; i < heap_p; ++i) {
        uv_mutex_init(&heaps[i].lock);
        heaps[i].tasks = (jl_task_t **)calloc(tasks_per_heap, sizeof(jl_task_t*));
        jl_atomic_store_relaxed(&heaps[i].ntasks, 0);
        jl_atomic_store_relaxed(&heaps[i].prio, INT16_MAX);
    }
    unbias_cong(heap_p, &cong_unbias);
    jl_gc_mark_enqueued_tasks = &multiq_gc_mark_enqueued_tasks;
    partr_enqueue_task = &multiq_insert;
    partr_dequeue_task = &multiq_deletemin;
    partr_check_empty = &multiq_check_empty;
}
// work-stealing deque
// ---

// The work-stealing deque by Chase and Lev (2005). Le et al. (2013) provide
// C11-compliant memory orderings.
//
// Ref:
// * Chase and Lev (2005) https://doi.org/10.1145/1073970.1073974
// * Le et al. (2013) https://doi.org/10.1145/2442516.2442524
//
// TODO: Dynamic buffer resizing.
typedef struct _wsdeque_t {
    union {
        struct {
            jl_task_t **tasks;
            _Atomic(int64_t) top;
            _Atomic(int64_t) bottom;
        };
        // Pad to a cache line so per-thread deques do not falsely share.
        uint8_t padding[JL_CACHE_BYTE_ALIGNMENT];
    };
} wsdeque_t;

static wsdeque_t *wsdeques;
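A note on the layout above: the union pads each per-thread record out to JL_CACHE_BYTE_ALIGNMENT bytes so that neighboring entries of the wsdeques array do not share a cache line. That only holds while the three fields fit in one cache line, so a compile-time guard could pin it down. A minimal sketch, assuming a C11 toolchain (this guard is not part of the PR):

#include <assert.h>

// Hypothetical guard (not in this PR): if the fields ever outgrow the padding,
// sizeof(wsdeque_t) silently exceeds one cache line and both the false-sharing
// protection and the manual alignment in wsdeque_init stop working as intended.
static_assert(sizeof(wsdeque_t) == JL_CACHE_BYTE_ALIGNMENT,
              "wsdeque_t must occupy exactly one cache line");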
static int wsdeque_push(jl_task_t *task, int16_t priority_ignored)
{
    int16_t tid = jl_threadid();
    int64_t b = jl_atomic_load_relaxed(&wsdeques[tid].bottom);
    int64_t t = jl_atomic_load_acquire(&wsdeques[tid].top);
Review comment: What is being load-acquired? There are no stores in …
Suggested change: …

Author reply: I'm trying to digest this and I realized that I still don't get it. As we discussed, there is a CAS on `top` in `wsdeque_steal` (lines 316 to 318 in 0e50d01). The loaded task only matters in the success path, where we have a seq_cst write that is a superset of a release write. So we have: … So they establish happens-before, and it looks like we know that the task is loaded by the time we load the …

Review reply: I am not sure. But edges are not transitive unless all of them are seq-cst, IIUC.

Review reply: But yes, it looks like we have a proper release/acquire pair here on `top` to ensure the ops on `tasks` are okay.

Author reply: Hmm... Lahav et al.'s definition of happens-before … But I don't know if that's a mismatch between their definition and the C11 semantics (though they use the same notions for discussing C11, so probably not) and/or some difference from the actual definition in the standard.
    int64_t size = b - t;
    if (size >= tasks_per_heap - 1) // full
        return -1;
    jl_atomic_store_relaxed(
        (_Atomic(jl_task_t *) *)&wsdeques[tid].tasks[b % tasks_per_heap], task);
    if (jl_atomic_load_acquire(&task->tid) != -1)
        // If the `task` still hasn't finished the context switch at this point, abort push
        // and put it in the sticky queue.
        return -1;
Review comment on lines +275 to +278 (suggested change): The …
    jl_fence_release();

Review comment: The only store was the relaxed to …

    jl_atomic_store_relaxed(&wsdeques[tid].bottom, b + 1);
    return 0;
}
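To make the release/acquire argument from the thread above concrete, here is a stripped-down C11 sketch. Everything in it (toy_deque_t, toy_steal, toy_push, the fixed CAP buffer) is illustrative and not the PR's code: it only shows the claimed edge, namely that a thief reads a slot before claiming it with a CAS on top whose success ordering is at least release, and the owner's acquire load of top synchronizes with that CAS, so a slot is never reused before the thief's read of it is complete.

#include <stdatomic.h>

enum { CAP = 64 };
typedef struct {
    void *buf[CAP];
    atomic_llong top;
    atomic_llong bottom;
} toy_deque_t;

// Thief: read the slot first, then publish the claim.
void *toy_steal(toy_deque_t *q)
{
    long long t = atomic_load_explicit(&q->top, memory_order_acquire);
    void *item = q->buf[t % CAP];                 // (A) read before claiming
    if (!atomic_compare_exchange_strong_explicit(
            &q->top, &t, t + 1,
            memory_order_seq_cst,                 // success ordering includes release
            memory_order_relaxed))
        return NULL;                              // lost the race; stole nothing
    return item;
}

// Owner: the acquire load of top synchronizes with a successful CAS above,
// so (A) happens-before (B) whenever the owner observes the incremented top.
int toy_push(toy_deque_t *q, void *item)
{
    long long b = atomic_load_explicit(&q->bottom, memory_order_relaxed);
    long long t = atomic_load_explicit(&q->top, memory_order_acquire);
    if (b - t >= CAP)
        return -1;                                // full
    q->buf[b % CAP] = item;                       // (B) reuse of the slot
    atomic_store_explicit(&q->bottom, b + 1, memory_order_release);
    return 0;
}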
static jl_task_t *wsdeque_pop(void)
{
    int16_t tid = jl_threadid();
    int64_t b = jl_atomic_load_relaxed(&wsdeques[tid].bottom) - 1;
    jl_atomic_store_relaxed(&wsdeques[tid].bottom, b);
    jl_fence();
    int64_t t = jl_atomic_load_relaxed(&wsdeques[tid].top);
    int64_t size = b - t;
    if (size < 0) {
        jl_atomic_store_relaxed(&wsdeques[tid].bottom, t);
        return NULL;
    }
    jl_task_t *task = jl_atomic_load_relaxed(
        (_Atomic(jl_task_t *) *)&wsdeques[tid].tasks[b % tasks_per_heap]);
    if (size > 0)
        return task;
    if (!jl_atomic_cmpswap(&wsdeques[tid].top, &t, t + 1))

Review comment: this might need a fence, since it is trying to order the relaxed load above with the relaxed store below it?

        task = NULL;
    jl_atomic_store_relaxed(&wsdeques[tid].bottom, b + 1);
    return task;
}
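For context on the fence question above: in Le et al.'s formulation the owner's decrement of bottom and the thief's read of bottom are deliberately weak, and seq_cst fences on both sides are what guarantee that either the owner observes the thief's increment of top or the thief observes the decremented bottom, so the last element cannot be handed out twice. A stripped-down C11 illustration with toy names (not the PR's code):

#include <stdatomic.h>

static atomic_llong top_, bottom_;

// Owner (pop side): publish the decrement, then read top.
long long owner_read_top(void)
{
    long long b = atomic_load_explicit(&bottom_, memory_order_relaxed) - 1;
    atomic_store_explicit(&bottom_, b, memory_order_relaxed);
    atomic_thread_fence(memory_order_seq_cst);  // orders the store above before the load below
    return atomic_load_explicit(&top_, memory_order_relaxed);
}

// Thief (steal side): read top, then bottom, separated by a matching fence.
long long thief_observed_size(void)
{
    long long t = atomic_load_explicit(&top_, memory_order_acquire);
    atomic_thread_fence(memory_order_seq_cst);  // pairs with the owner's fence
    long long b = atomic_load_explicit(&bottom_, memory_order_acquire);
    return b - t;
}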
static jl_task_t *wsdeque_steal(int16_t tid)
{
    int64_t t = jl_atomic_load_acquire(&wsdeques[tid].top);
    jl_fence();
    int64_t b = jl_atomic_load_acquire(&wsdeques[tid].bottom);
    int64_t size = b - t;
    if (size <= 0)
        return NULL;
    jl_task_t *task = jl_atomic_load_relaxed(
        (_Atomic(jl_task_t *) *)&wsdeques[tid].tasks[t % tasks_per_heap]);
    if (!jl_atomic_cmpswap(&wsdeques[tid].top, &t, t + 1))

Review comment: possibly also need an explicit fence? I am not certain if a seq-cst cmpswap is sufficient to enforce an order on relaxed ops nearby.

Review comment: I am surprised this doesn't try to steal size*fraction items, instead of a constant (1), but that does not need to be changed for this PR.

Author reply: There's the Stealing Multi-Queue that does something similar to what you said: https://arxiv.org/abs/2109.00657. But stealing like this works quite well for fork-join use cases, in which you typically only enqueue/"materialize" O(log n) tasks in the queue while the task itself can "reveal" O(n) child tasks upon execution (n = e.g., the length of the array for a parallel map).

        return NULL;
    return task;
}
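On the steal-granularity point in the thread above, a batch variant in the same style could claim roughly half of the victim's visible tasks with a single CAS on top. The sketch below is hypothetical: wsdeque_steal_batch is not part of this PR, and it inherits exactly the memory-ordering caveats discussed in this thread.

// Hypothetical batch steal (not in this PR): claim up to half of the victim's
// visible tasks with one CAS on top. On success, out[0..n-1] holds the stolen
// tasks; on a lost race, nothing is stolen.
static int64_t wsdeque_steal_batch(int16_t tid, jl_task_t **out, int64_t max_out)
{
    int64_t t = jl_atomic_load_acquire(&wsdeques[tid].top);
    jl_fence();
    int64_t b = jl_atomic_load_acquire(&wsdeques[tid].bottom);
    int64_t size = b - t;
    if (size <= 0)
        return 0;
    int64_t n = (size + 1) / 2;          // steal about half of what is visible
    if (n > max_out)
        n = max_out;
    for (int64_t i = 0; i < n; ++i)      // read the slots before claiming them
        out[i] = jl_atomic_load_relaxed(
            (_Atomic(jl_task_t *) *)&wsdeques[tid].tasks[(t + i) % tasks_per_heap]);
    if (!jl_atomic_cmpswap(&wsdeques[tid].top, &t, t + n))
        return 0;                        // another thief won the race
    return n;
}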
static jl_task_t *wsdeque_pop_or_steal(void)
{
    jl_ptls_t ptls = jl_current_task->ptls;
    jl_task_t *task = wsdeque_pop();
    if (task || jl_n_threads < 2)
        return task;

    int ntries = jl_n_threads;
    for (int i = 0; i < ntries; ++i) {
        // Pick a random victim among the other threads (skipping ourselves).
        uint64_t tid = cong(jl_n_threads - 1, cong_unbias, &ptls->rngseed);
        if (tid >= ptls->tid)
            tid++;
        task = wsdeque_steal(tid);
        if (task)
            return task;
    }
    return NULL;
}
void wsdeque_gc_mark_enqueued_tasks(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp_t *sp)
{
    for (int i = 0; i < jl_n_threads; ++i) {
        int64_t t = jl_atomic_load_relaxed(&wsdeques[i].top);
        int64_t b = jl_atomic_load_relaxed(&wsdeques[i].bottom);
        for (int64_t j = t; j < b; ++j)
            jl_gc_mark_queue_obj_explicit(
                gc_cache, sp, (jl_value_t *)wsdeques[i].tasks[j % tasks_per_heap]);
    }
}
static int wsdeque_check_empty(void)
{
    for (int32_t i = 0; i < jl_n_threads; ++i) {
        int64_t t = jl_atomic_load_relaxed(&wsdeques[i].top);
        int64_t b = jl_atomic_load_relaxed(&wsdeques[i].bottom);
        int64_t size = b - t;
        if (size > 0)
            return 0;
    }
    return 1;
}
static void wsdeque_init(void)
{
    // Manually align the pointer since `jl_malloc_aligned` is not available here.
    wsdeques = (wsdeque_t *)(((uintptr_t)calloc(1, sizeof(wsdeque_t) * jl_n_threads +
                                                       JL_CACHE_BYTE_ALIGNMENT - 1) +
                              JL_CACHE_BYTE_ALIGNMENT - 1) &
                             (-JL_CACHE_BYTE_ALIGNMENT));

Review comment on lines +372 to +375 (suggested change): …

    for (int32_t i = 0; i < jl_n_threads; ++i) {
        wsdeques[i].tasks = (jl_task_t **)calloc(tasks_per_heap, sizeof(jl_task_t *));
    }
    unbias_cong(jl_n_threads, &cong_unbias);
    jl_gc_mark_enqueued_tasks = &wsdeque_gc_mark_enqueued_tasks;
    partr_enqueue_task = &wsdeque_push;
    partr_dequeue_task = &wsdeque_pop_or_steal;
    partr_check_empty = &wsdeque_check_empty;
}
// parallel task runtime
// ---

@@ -236,8 +391,12 @@ static int multiq_check_empty(void)
// (used only by the main thread)
void jl_init_threadinginfra(void)
{
-    /* initialize the synchronization trees pool and the multiqueue */
-    multiq_init();
+    /* choose and initialize the scheduler */
+    char *sch = getenv("JULIA_THREAD_SCHEDULER");
+    if (sch && !strncasecmp(sch, "workstealing", 12))
+        wsdeque_init();
+    else
+        multiq_init();

    sleep_threshold = DEFAULT_THREAD_SLEEP_THRESHOLD;
    char *cp = getenv(THREAD_SLEEP_THRESHOLD_NAME);
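With this hunk the scheduler is chosen once at startup: setting JULIA_THREAD_SCHEDULER=workstealing in the environment (matched case-insensitively) installs the work-stealing deques, while any other value, or an unset variable, keeps the existing multiqueue.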
@@ -292,7 +451,7 @@ void jl_threadfun(void *arg)
// enqueue the specified task for execution
JL_DLLEXPORT int jl_enqueue_task(jl_task_t *task)
{
-    if (multiq_insert(task, task->prio) == -1)
+    if (partr_enqueue_task(task, task->prio) == -1)
        return 1;
    return 0;
}
@@ -419,7 +578,7 @@ static jl_task_t *get_next_task(jl_value_t *trypoptask, jl_value_t *q)
        jl_set_task_tid(task, self);
        return task;
    }
-    return multiq_deletemin();
+    return partr_dequeue_task();
}

static int may_sleep(jl_ptls_t ptls) JL_NOTSAFEPOINT
@@ -444,7 +603,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q)

    // quick, race-y check to see if there seems to be any stuff in there
    jl_cpu_pause();
-    if (!multiq_check_empty()) {
+    if (!partr_check_empty()) {
        start_cycles = 0;
        continue;
    }
@@ -453,7 +612,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q)
    jl_ptls_t ptls = ct->ptls;
    if (sleep_check_after_threshold(&start_cycles) || (!jl_atomic_load_relaxed(&_threadedregion) && ptls->tid == 0)) {
        jl_atomic_store(&ptls->sleep_check_state, sleeping); // acquire sleep-check lock
-        if (!multiq_check_empty()) {
+        if (!partr_check_empty()) {
            if (jl_atomic_load_relaxed(&ptls->sleep_check_state) != not_sleeping)
                jl_atomic_store(&ptls->sleep_check_state, not_sleeping); // let other threads know they don't need to wake us
            continue;
Review comment: this has no alignment whatsoever
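One way the alignment concern could be addressed is to request aligned, zeroed memory directly instead of rounding a calloc pointer, which also preserves a pointer that can legally be passed to free. A minimal sketch, assuming posix_memalign is acceptable at this layer of the runtime (this helper is hypothetical, not a change the PR makes):

#include <stdlib.h>
#include <string.h>

// Hypothetical replacement for the manual rounding in wsdeque_init.
static wsdeque_t *alloc_wsdeques(int32_t nthreads)
{
    void *p = NULL;
    if (posix_memalign(&p, JL_CACHE_BYTE_ALIGNMENT, sizeof(wsdeque_t) * nthreads) != 0)
        return NULL;
    memset(p, 0, sizeof(wsdeque_t) * nthreads);  // match calloc's zero-initialization
    return (wsdeque_t *)p;
}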