Skip to content

metal : fix out-of-bounds access + style changes #2416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>

#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

#ifdef GGML_METAL_NDEBUG
#define metal_printf(...)
#else
Expand All @@ -15,6 +20,8 @@

#define UNUSED(x) (void)(x)

#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)

struct ggml_metal_buffer {
const char * name;

Expand All @@ -36,7 +43,7 @@
int n_buffers;
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];

int concur_list[GGML_MAX_NODES];
int concur_list[GGML_MAX_CONCUR];
int concur_list_len;

// custom kernels
Expand Down Expand Up @@ -370,44 +377,56 @@ void ggml_metal_graph_find_concurrency(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
int nodes_unused[GGML_MAX_NODES];
int nodes_unused[GGML_MAX_CONCUR];

for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
ctx->concur_list_len = 0;

int n_left = gf->n_nodes;
int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
int n_left = gf->n_nodes;
int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos

while (n_left > 0) {
// number of nodes at a layer (that can be issued concurrently)
int concurrency = 0;
for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
if (nodes_unused[i]) {
// if the requirements for gf->nodes[i] are satisfied
int exe_flag=1;
int exe_flag = 1;

// scan all srcs
for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
if (src_cur) {
// if is leaf nodes it's satisfied.
if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
// TODO: ggml_is_leaf()
if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
continue;
}

// otherwise this src should be the output from previous nodes.
int is_found = 0;

// scan 2*search_depth back because we inserted barrier.
for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
//for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
is_found = 1;
break;
}
}
if (is_found == 0) {
exe_flag = 0;
break;
}
if (is_found == 0) {exe_flag = 0; break;}
}
}
if (exe_flag) {
// check if nodes[i]'s data will be overwritten by a node before nodes[i].
// if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
int64_t data_start = (int64_t) gf->nodes[i]->data;
int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
for (int j = n_start; j < i; j++) {
if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
&& gf->nodes[j]->op != GGML_OP_VIEW \
Expand All @@ -416,9 +435,9 @@ void ggml_metal_graph_find_concurrency(
if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
continue;
} else {
exe_flag = 0;
}

exe_flag = 0;
}
}
}
Expand All @@ -435,11 +454,13 @@ void ggml_metal_graph_find_concurrency(
ctx->concur_list[level_pos + concurrency] = -1;
ctx->concur_list_len++;
// jump all sorted nodes at nodes_bak
while (!nodes_unused[n_start]) {n_start++;}
while (!nodes_unused[n_start]) {
n_start++;
}
level_pos += concurrency + 1;
}

if (ctx->concur_list_len > GGML_MAX_NODES) {
if (ctx->concur_list_len > GGML_MAX_CONCUR) {
fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
}
}
Expand All @@ -453,7 +474,7 @@ void ggml_metal_graph_compute(
// else fallback to serial dispatch
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;

const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;

const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
Expand Down