@@ -7,6 +7,11 @@
 #import <Metal/Metal.h>
 #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
 
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
 #ifdef GGML_METAL_NDEBUG
 #define metal_printf(...)
 #else
@@ -15,6 +20,8 @@
 
 #define UNUSED(x) (void)(x)
 
+#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
+
 struct ggml_metal_buffer {
     const char * name;
 
@@ -36,7 +43,7 @@
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
-    int concur_list[GGML_MAX_NODES];
+    int concur_list[GGML_MAX_CONCUR];
     int concur_list_len;
 
     // custom kernels
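Aside on the new bound: each concurrency level appended to concur_list is terminated by a -1 barrier entry (see the find_concurrency hunk below), so a fully serial graph of n nodes degenerates to one node per level and consumes two slots per node. A worst-case sketch of the layout, for illustration only:

    // n = 3 nodes, no concurrency found:
    // concur_list = { 0, -1, 1, -1, 2, -1 }   // 2*n entries used,
    // hence the GGML_MAX_CONCUR = 2*GGML_MAX_NODES sizing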
@@ -370,44 +377,56 @@ void ggml_metal_graph_find_concurrency(
         struct ggml_metal_context * ctx,
         struct ggml_cgraph * gf) {
     int search_depth = gf->n_nodes; // we only find concurrency in this range to avoid wasting too much time
-    int nodes_unused[GGML_MAX_NODES];
+    int nodes_unused[GGML_MAX_CONCUR];
 
-    for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
-    for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+    for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
+    for (int i = 0; i < gf->n_nodes;     i++) { nodes_unused[i]     = 1; }
     ctx->concur_list_len = 0;
 
-    int n_left = gf->n_nodes;
-    int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
-    int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
+    int n_left    = gf->n_nodes;
+    int n_start   = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+    int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
 
     while (n_left > 0) {
         // number of nodes at a layer (that can be issued concurrently)
         int concurrency = 0;
         for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
             if (nodes_unused[i]) {
                 // if the requirements for gf->nodes[i] are satisfied
-                int exe_flag=1;
+                int exe_flag = 1;
+
                 // scan all srcs
                 for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
                     struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
                     if (src_cur) {
                         // if is leaf nodes it's satisfied.
-                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+                        // TODO: ggml_is_leaf()
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
+                            continue;
+                        }
 
                         // otherwise this src should be the output from previous nodes.
                         int is_found = 0;
+
                         // scan 2*search_depth back because we inserted barrier.
-                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
-                            if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+                        //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                        for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
+                            if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
+                                is_found = 1;
+                                break;
+                            }
+                        }
+                        if (is_found == 0) {
+                            exe_flag = 0;
+                            break;
                         }
-                        if (is_found == 0) {exe_flag = 0; break;}
                     }
                 }
                 if (exe_flag) {
                     // check if nodes[i]'s data will be overwritten by a node before nodes[i].
                     // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
                     int64_t data_start = (int64_t) gf->nodes[i]->data;
-                    int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+                    int64_t length     = (int64_t) ggml_nbytes(gf->nodes[i]);
                     for (int j = n_start; j < i; j++) {
                         if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
                                             && gf->nodes[j]->op != GGML_OP_VIEW \
@@ -416,9 +435,9 @@ void ggml_metal_graph_find_concurrency(
                             if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
                                 ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
                                 continue;
-                            } else {
-                                exe_flag = 0;
                             }
+
+                            exe_flag = 0;
                         }
                     }
                 }
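The aliasing test above is the standard disjointness check for half-open byte ranges [start, start + nbytes): the two tensors cannot conflict iff one range ends at or before the other begins, and exe_flag is cleared on any overlap. A self-contained sketch of the same predicate (hypothetical helper name, not part of the patch):

    #include <stdbool.h>
    #include <stdint.h>

    // ranges [a, a + alen) and [b, b + blen) are disjoint iff one
    // starts at or after the point where the other ends
    static bool ranges_disjoint(int64_t a, int64_t alen, int64_t b, int64_t blen) {
        return b >= a + alen || b + blen <= a;
    }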
@@ -435,11 +454,13 @@ void ggml_metal_graph_find_concurrency(
         ctx->concur_list[level_pos + concurrency] = -1;
         ctx->concur_list_len++;
         // jump all sorted nodes at nodes_bak
-        while (!nodes_unused[n_start]) {n_start++;}
+        while (!nodes_unused[n_start]) {
+            n_start++;
+        }
         level_pos += concurrency + 1;
     }
 
-    if (ctx->concur_list_len > GGML_MAX_NODES) {
+    if (ctx->concur_list_len > GGML_MAX_CONCUR) {
         fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
     }
 }
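This is where the -1 barrier entries come from, and they explain the ctx->concur_list[j] >= 0 guard added to the dependency scan above: without it, the backward scan evaluated gf->nodes[ctx->concur_list[j]] with an index of -1 whenever j landed on a level boundary. Illustrative layout, inferred from the code above:

    // concur_list = { 3, 7, -1, 2, 5, -1, ... }
    //                 \level 0/  \level 1/      each level ends with -1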
@@ -453,7 +474,7 @@ void ggml_metal_graph_compute(
     // else fallback to serial dispatch
     MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
 
-    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
 
     const int n_nodes  = has_concur ? ctx->concur_list_len : gf->n_nodes;
     edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
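For context on how edesc takes effect (typical usage, inferred; the encoder creation is outside this diff's hunks): the descriptor's dispatchType applies when a compute encoder is created from it, so every dispatch recorded on that encoder inherits the concurrent or serial ordering. A minimal Objective-C sketch with an assumed command_buffer variable:

    // assumed surrounding code, not shown in this diff
    id<MTLComputeCommandEncoder> encoder =
        [command_buffer computeCommandEncoderWithDescriptor:edesc];
    // under MTLDispatchTypeConcurrent, dispatches on this encoder may run
    // concurrently on the GPU, so ordering must be enforced with explicit
    // memory barriers where needed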