From 01046648cfaf88d1105c894f83703bf6099775d3 Mon Sep 17 00:00:00 2001
From: JohannesGaessler <johannesg@5d6.de>
Date: Sat, 19 Aug 2023 19:54:56 +0200
Subject: [PATCH] ggml: create thread pool lazily

---
 ggml.c | 77 ++++++++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 59 insertions(+), 18 deletions(-)

diff --git a/ggml.c b/ggml.c
index beb7f464167d5..2797e358f61c8 100644
--- a/ggml.c
+++ b/ggml.c
@@ -16215,8 +16215,11 @@ static void clear_numa_thread_affinity(void) {}
 #endif
 
 struct ggml_compute_state_shared {
-    const struct ggml_cgraph * cgraph;
-    const struct ggml_cplan  * cplan;
+    const struct ggml_cgraph  * cgraph;
+    const struct ggml_cplan   * cplan;
+
+    struct ggml_compute_state * workers;
+    bool workers_created;
 
     int64_t perf_node_start_cycles;
     int64_t perf_node_start_time_us;
@@ -16246,6 +16249,8 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
     node->perf_time_us += time_us_cur;
 }
 
+void ggml_create_workers(struct ggml_compute_state_shared * state_shared);
+
 static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
 
@@ -16264,7 +16269,23 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             state->shared->node_n += 1;
             return (thread_ret_t) GGML_EXIT_ABORTED;
         }
-        if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
+
+        int n_active;
+        if (!state->shared->workers_created) {
+            // if the worker pool has not yet been created:
+            // there is only a single active thread
+            n_active = 1;
+        } else if (node_n == -1) {
+            // if the worker pool has been created by another thread and this is the first iteration:
+            // go straight to the else block as if the thread had been spinning all along
+            n_active = -1;
+        } else {
+            // if the worker pool has been created and this is not the first iteration:
+            // decrement the number of active threads and start spinning if there are still other active threads
+            n_active = atomic_fetch_sub(&state->shared->n_active, 1);
+        }
+
+        if (n_active == 1) {
             // all other threads are finished and spinning
             // do finalize and init here so we don't have synchronize again
             struct ggml_compute_params params = {
@@ -16316,6 +16337,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
                     ggml_graph_compute_perf_stats_node(node, state->shared);
                 } else {
+                    // lazily create worker pool only once there is a node with >1 tasks
+                    if (!state->shared->workers_created) {
+                        state->shared->workers_created = true;
+                        ggml_create_workers(state->shared);
+                    }
                     break;
                 }
 
@@ -16727,6 +16753,16 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
         }
 
+        bool node_and_src_all_cpu = node->backend == GGML_BACKEND_CPU;
+        for (int j = 0; node_and_src_all_cpu && j < GGML_MAX_SRC; ++j) {
+            if (node->src[j] != NULL && node->src[j]->backend != GGML_BACKEND_CPU) {
+                node_and_src_all_cpu = false;
+            }
+        }
+        if (!node_and_src_all_cpu) {
+            n_tasks = 1;
+        }
+
         cplan.n_tasks[i] = n_tasks;
     }
 
@@ -16741,6 +16777,22 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
     return cplan;
 }
 
+void ggml_create_workers(struct ggml_compute_state_shared * state_shared) {
+    if (state_shared->n_threads > 1) {
+        for (int j = 1; j < state_shared->n_threads; ++j) {
+            state_shared->workers[j] = (struct ggml_compute_state) {
+                .thrd   = 0,
+                .ith = j,
+                .shared = state_shared,
+            };
+
+            const int rc = ggml_thread_create(&state_shared->workers[j].thrd, NULL,
+                                              ggml_graph_compute_thread, &state_shared->workers[j]);
+            GGML_ASSERT(rc == 0);
+        }
+    }
+}
+
 int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
     {
         GGML_ASSERT(cplan);
@@ -16759,9 +16811,12 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
 
     const int n_threads = cplan->n_threads;
 
+    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
     struct ggml_compute_state_shared state_shared = {
         /*.cgraph                  =*/ cgraph,
         /*.cgraph_plan             =*/ cplan,
+        /*.workers                 =*/ workers,
+        /*.workers_created         =*/ false,
         /*.perf_node_start_cycles  =*/ 0,
         /*.perf_node_start_time_us =*/ 0,
         /*.n_threads               =*/ n_threads,
@@ -16770,21 +16825,7 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
         /*.abort_callback          =*/ NULL,
         /*.abort_callback_data     =*/ NULL,
     };
-    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
 
-    // create thread pool
-    if (n_threads > 1) {
-        for (int j = 1; j < n_threads; ++j) {
-            workers[j] = (struct ggml_compute_state) {
-                .thrd   = 0,
-                .ith = j,
-                .shared = &state_shared,
-            };
-
-            const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
-            GGML_ASSERT(rc == 0);
-        }
-    }
     workers[0].ith = 0;
     workers[0].shared = &state_shared;
 
@@ -16798,7 +16839,7 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
     clear_numa_thread_affinity();
 
     // join or kill thread pool
-    if (n_threads > 1) {
+    if (n_threads > 1 && state_shared.workers_created) {
         for (int j = 1; j < n_threads; j++) {
             const int rc = ggml_thread_join(workers[j].thrd, NULL);
             GGML_ASSERT(rc == 0);