diff --git a/ggml.c b/ggml.c
index beb7f464167d5..2797e358f61c8 100644
--- a/ggml.c
+++ b/ggml.c
@@ -16215,8 +16215,11 @@ static void clear_numa_thread_affinity(void) {}
 #endif
 
 struct ggml_compute_state_shared {
-    const struct ggml_cgraph * cgraph;
-    const struct ggml_cplan  * cplan;
+    const struct ggml_cgraph  * cgraph;
+    const struct ggml_cplan   * cplan;
+
+    struct ggml_compute_state * workers;
+    bool                        workers_created;
 
     int64_t perf_node_start_cycles;
     int64_t perf_node_start_time_us;
@@ -16246,6 +16249,8 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
     node->perf_time_us += time_us_cur;
 }
 
+void ggml_create_workers(struct ggml_compute_state_shared * state_shared);
+
 static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
 
@@ -16264,7 +16269,23 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             state->shared->node_n += 1;
             return (thread_ret_t) GGML_EXIT_ABORTED;
         }
-        if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
+
+        int n_active;
+        if (!state->shared->workers_created) {
+            // if the worker pool has not yet been created:
+            // there is only a single active thread
+            n_active = 1;
+        } else if (node_n == -1) {
+            // if the worker pool has been created by another thread and this is the first iteration:
+            // go straight to the else block as if the thread had been spinning all along
+            n_active = -1;
+        } else {
+            // if the worker pool has been created and this is not the first iteration:
+            // decrement the number of active threads and start spinning if there are still other active threads
+            n_active = atomic_fetch_sub(&state->shared->n_active, 1);
+        }
+
+        if (n_active == 1) {
             // all other threads are finished and spinning
             // do finalize and init here so we don't have synchronize again
             struct ggml_compute_params params = {
@@ -16316,6 +16337,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
                 ggml_graph_compute_perf_stats_node(node, state->shared);
             } else {
+                // lazily create worker pool only once there is a node with >1 tasks
+                if (!state->shared->workers_created) {
+                    state->shared->workers_created = true;
+                    ggml_create_workers(state->shared);
+                }
                 break;
             }
         }
@@ -16727,6 +16753,16 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
         }
 
+        bool node_and_src_all_cpu = node->backend == GGML_BACKEND_CPU;
+        for (int j = 0; node_and_src_all_cpu && j < GGML_MAX_SRC; ++j) {
+            if (node->src[j] != NULL && node->src[j]->backend != GGML_BACKEND_CPU) {
+                node_and_src_all_cpu = false;
+            }
+        }
+        if (!node_and_src_all_cpu) {
+            n_tasks = 1;
+        }
+
         cplan.n_tasks[i] = n_tasks;
     }
 
@@ -16741,6 +16777,22 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
     return cplan;
 }
 
+void ggml_create_workers(struct ggml_compute_state_shared * state_shared) {
+    if (state_shared->n_threads > 1) {
+        for (int j = 1; j < state_shared->n_threads; ++j) {
+            state_shared->workers[j] = (struct ggml_compute_state) {
+                .thrd   = 0,
+                .ith    = j,
+                .shared = state_shared,
+            };
+
+            const int rc = ggml_thread_create(&state_shared->workers[j].thrd, NULL,
+                ggml_graph_compute_thread, &state_shared->workers[j]);
+            GGML_ASSERT(rc == 0);
+        }
+    }
+}
+
 int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
     {
         GGML_ASSERT(cplan);
@@ -16759,9 +16811,12 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
 
     const int n_threads = cplan->n_threads;
 
+    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
     struct ggml_compute_state_shared state_shared = {
         /*.cgraph                  =*/ cgraph,
         /*.cgraph_plan             =*/ cplan,
+        /*.workers                 =*/ workers,
+        /*.workers_created         =*/ false,
         /*.perf_node_start_cycles  =*/ 0,
         /*.perf_node_start_time_us =*/ 0,
         /*.n_threads               =*/ n_threads,
@@ -16770,21 +16825,7 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
         /*.abort_callback          =*/ NULL,
         /*.abort_callback_data     =*/ NULL,
     };
-    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
 
-    // create thread pool
-    if (n_threads > 1) {
-        for (int j = 1; j < n_threads; ++j) {
-            workers[j] = (struct ggml_compute_state) {
-                .thrd   = 0,
-                .ith    = j,
-                .shared = &state_shared,
-            };
-
-            const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
-            GGML_ASSERT(rc == 0);
-        }
-    }
 
     workers[0].ith = 0;
     workers[0].shared = &state_shared;
@@ -16798,7 +16839,7 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
     clear_numa_thread_affinity();
 
     // join or kill thread pool
-    if (n_threads > 1) {
+    if (n_threads > 1 && state_shared.workers_created) {
        for (int j = 1; j < n_threads; j++) {
            const int rc = ggml_thread_join(workers[j].thrd, NULL);
            GGML_ASSERT(rc == 0);
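
Note for reviewers: below is a minimal standalone sketch of the lazy worker-pool pattern the diff implements, for reference only; it is not ggml code. It uses plain pthreads and C11 atomics instead of the ggml_thread_* wrappers, it omits the node_n == -1 first-iteration case and the per-node re-arming of n_active, and every name in it (pool_shared, worker_main, pool_create_workers, N_THREADS) is hypothetical.

#include <pthread.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stdio.h>

#define N_THREADS 4

struct pool_shared {
    pthread_t  thrd[N_THREADS];  // thrd[0] is unused; the caller acts as worker 0
    bool       workers_created;  // written once by the main thread before any worker starts
    atomic_int n_active;         // threads still working on the current node
    atomic_int node_n;           // index of the node being processed
};

static void * worker_main(void * arg);

// spawn threads 1..N_THREADS-1; the calling thread participates as worker 0
static void pool_create_workers(struct pool_shared * shared) {
    for (int j = 1; j < N_THREADS; ++j) {
        pthread_create(&shared->thrd[j], NULL, worker_main, shared);
    }
}

static void * worker_main(void * arg) {
    struct pool_shared * shared = arg;

    // mirror of the n_active bookkeeping in ggml_graph_compute_thread:
    // before the pool exists there is exactly one active thread,
    // so the atomic decrement can be skipped entirely
    int n_active;
    if (!shared->workers_created) {
        n_active = 1;
    } else {
        n_active = atomic_fetch_sub(&shared->n_active, 1);
    }

    if (n_active == 1) {
        // last active thread: in ggml this is where finalize/init runs
        printf("last thread finished node %d\n", atomic_load(&shared->node_n));
    }
    return NULL;
}

int main(void) {
    struct pool_shared shared = {
        .workers_created = false,
        .n_active        = N_THREADS,
        .node_n          = -1,
    };

    // single-threaded phase: no pool exists yet, no atomics touched
    worker_main(&shared);

    // first node with more than one task: create the pool exactly once ...
    atomic_store(&shared.node_n, 0);
    shared.workers_created = true;
    pool_create_workers(&shared);
    worker_main(&shared); // ... and join in as worker 0

    for (int j = 1; j < N_THREADS; ++j) {
        pthread_join(shared.thrd[j], NULL);
    }
    return 0;
}

As in the diff, joining only makes sense once workers_created is true; the sketch joins unconditionally because it always reaches the pool-creation phase. The key property in both is that workers_created is written exactly once, before any worker thread exists, so workers may read it without atomics.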