From 1038d1d2bc6aa44b931bd8849149483d1943ed5c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 27 Jul 2023 10:10:51 +0300 Subject: [PATCH 1/2] metal : fix out-of-bounds access + style changes --- ggml-metal.m | 47 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 74a6bff404117..e63a24ef07b33 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -7,6 +7,11 @@ #import #import +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + #ifdef GGML_METAL_NDEBUG #define metal_printf(...) #else @@ -372,13 +377,13 @@ void ggml_metal_graph_find_concurrency( int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time int nodes_unused[GGML_MAX_NODES]; - for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;} - for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;} + for (int i = 0; i < GGML_MAX_NODES; i++) { ctx->concur_list[i] = 0; } + for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; } ctx->concur_list_len = 0; - int n_left = gf->n_nodes; - int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list - int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos + int n_left = gf->n_nodes; + int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list + int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos while (n_left > 0) { // number of nodes at a layer (that can be issued concurrently) @@ -386,28 +391,40 @@ void ggml_metal_graph_find_concurrency( for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) { if (nodes_unused[i]) { // if the requirements for gf->nodes[i] are satisfied - int exe_flag=1; + int exe_flag = 1; + // scan all srcs for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) { struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind]; if (src_cur) { // if is leaf nodes it's satisfied. - if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;} + // TODO: ggml_is_leaf() + if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) { + continue; + } // otherwise this src should be the output from previous nodes. int is_found = 0; + // scan 2*search_depth back because we inserted barrier. - for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) { - if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;} + //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) { + for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) { + if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) { + is_found = 1; + break; + } + } + if (is_found == 0) { + exe_flag = 0; + break; } - if (is_found == 0) {exe_flag = 0; break;} } } if (exe_flag) { // check if nodes[i]'s data will be overwritten by a node before nodes[i]. // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3] int64_t data_start = (int64_t) gf->nodes[i]->data; - int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]); + int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]); for (int j = n_start; j < i; j++) { if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \ && gf->nodes[j]->op != GGML_OP_VIEW \ @@ -416,9 +433,9 @@ void ggml_metal_graph_find_concurrency( if (((int64_t)gf->nodes[j]->data) >= data_start + length || \ ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) { continue; - } else { - exe_flag = 0; } + + exe_flag = 0; } } } @@ -435,7 +452,9 @@ void ggml_metal_graph_find_concurrency( ctx->concur_list[level_pos + concurrency] = -1; ctx->concur_list_len++; // jump all sorted nodes at nodes_bak - while (!nodes_unused[n_start]) {n_start++;} + while (!nodes_unused[n_start]) { + n_start++; + } level_pos += concurrency + 1; } From 30ea0e16853a0e3f1e1d06ddcf190165bcb0a773 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 7 Aug 2023 10:52:13 +0300 Subject: [PATCH 2/2] metal : increase concurrency nodes to 2*GGML_MAX_NODES --- ggml-metal.m | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e63a24ef07b33..a6f41718a454a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -20,6 +20,8 @@ #define UNUSED(x) (void)(x) +#define GGML_MAX_CONCUR (2*GGML_MAX_NODES) + struct ggml_metal_buffer { const char * name; @@ -41,7 +43,7 @@ int n_buffers; struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; - int concur_list[GGML_MAX_NODES]; + int concur_list[GGML_MAX_CONCUR]; int concur_list_len; // custom kernels @@ -375,10 +377,10 @@ void ggml_metal_graph_find_concurrency( struct ggml_metal_context * ctx, struct ggml_cgraph * gf) { int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time - int nodes_unused[GGML_MAX_NODES]; + int nodes_unused[GGML_MAX_CONCUR]; - for (int i = 0; i < GGML_MAX_NODES; i++) { ctx->concur_list[i] = 0; } - for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; } + for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; } + for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; } ctx->concur_list_len = 0; int n_left = gf->n_nodes; @@ -458,7 +460,7 @@ void ggml_metal_graph_find_concurrency( level_pos += concurrency + 1; } - if (ctx->concur_list_len > GGML_MAX_NODES) { + if (ctx->concur_list_len > GGML_MAX_CONCUR) { fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__); } } @@ -472,7 +474,7 @@ void ggml_metal_graph_compute( // else fallback to serial dispatch MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; - const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES; + const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR; const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes; edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;