Skip to content

Commit f6f9896

Browse files
authored
metal : fix out-of-bounds access + inc concurrency nodes (#2416)
* metal : fix out-of-bounds access + style changes
* metal : increase concurrency nodes to 2*GGML_MAX_NODES
1 parent 34a14b2 commit f6f9896

File tree

1 file changed

+39
-18
lines changed

1 file changed

+39
-18
lines changed

ggml-metal.m

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
#import <Metal/Metal.h>
88
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
99

10+
#undef MIN
11+
#undef MAX
12+
#define MIN(a, b) ((a) < (b) ? (a) : (b))
13+
#define MAX(a, b) ((a) > (b) ? (a) : (b))
14+
1015
#ifdef GGML_METAL_NDEBUG
1116
#define metal_printf(...)
1217
#else
@@ -15,6 +20,8 @@
1520

1621
#define UNUSED(x) (void)(x)
1722

23+
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
24+
1825
struct ggml_metal_buffer {
1926
const char * name;
2027

@@ -36,7 +43,7 @@
3643
int n_buffers;
3744
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
3845

39-
int concur_list[GGML_MAX_NODES];
46+
int concur_list[GGML_MAX_CONCUR];
4047
int concur_list_len;
4148

4249
// custom kernels
@@ -370,44 +377,56 @@ void ggml_metal_graph_find_concurrency(
370377
struct ggml_metal_context * ctx,
371378
struct ggml_cgraph * gf) {
372379
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
373-
int nodes_unused[GGML_MAX_NODES];
380+
int nodes_unused[GGML_MAX_CONCUR];
374381

375-
for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
376-
for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
382+
for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
383+
for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
377384
ctx->concur_list_len = 0;
378385

379-
int n_left = gf->n_nodes;
380-
int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
381-
int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
386+
int n_left = gf->n_nodes;
387+
int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
388+
int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
382389

383390
while (n_left > 0) {
384391
// number of nodes at a layer (that can be issued concurrently)
385392
int concurrency = 0;
386393
for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
387394
if (nodes_unused[i]) {
388395
// if the requirements for gf->nodes[i] are satisfied
389-
int exe_flag=1;
396+
int exe_flag = 1;
397+
390398
// scan all srcs
391399
for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
392400
struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
393401
if (src_cur) {
394402
// if is leaf nodes it's satisfied.
395-
if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
403+
// TODO: ggml_is_leaf()
404+
if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
405+
continue;
406+
}
396407

397408
// otherwise this src should be the output from previous nodes.
398409
int is_found = 0;
410+
399411
// scan 2*search_depth back because we inserted barrier.
400-
for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
401-
if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
412+
//for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
413+
for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
414+
if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
415+
is_found = 1;
416+
break;
417+
}
418+
}
419+
if (is_found == 0) {
420+
exe_flag = 0;
421+
break;
402422
}
403-
if (is_found == 0) {exe_flag = 0; break;}
404423
}
405424
}
406425
if (exe_flag) {
407426
// check if nodes[i]'s data will be overwritten by a node before nodes[i].
408427
// if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
409428
int64_t data_start = (int64_t) gf->nodes[i]->data;
410-
int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
429+
int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
411430
for (int j = n_start; j < i; j++) {
412431
if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
413432
&& gf->nodes[j]->op != GGML_OP_VIEW \
@@ -416,9 +435,9 @@ void ggml_metal_graph_find_concurrency(
416435
if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
417436
((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
418437
continue;
419-
} else {
420-
exe_flag = 0;
421438
}
439+
440+
exe_flag = 0;
422441
}
423442
}
424443
}
@@ -435,11 +454,13 @@ void ggml_metal_graph_find_concurrency(
435454
ctx->concur_list[level_pos + concurrency] = -1;
436455
ctx->concur_list_len++;
437456
// jump all sorted nodes at nodes_bak
438-
while (!nodes_unused[n_start]) {n_start++;}
457+
while (!nodes_unused[n_start]) {
458+
n_start++;
459+
}
439460
level_pos += concurrency + 1;
440461
}
441462

442-
if (ctx->concur_list_len > GGML_MAX_NODES) {
463+
if (ctx->concur_list_len > GGML_MAX_CONCUR) {
443464
fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
444465
}
445466
}
@@ -453,7 +474,7 @@ void ggml_metal_graph_compute(
453474
// else fallback to serial dispatch
454475
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
455476

456-
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
477+
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
457478

458479
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
459480
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;

0 commit comments

Comments
 (0)