diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 19a7adb2d101b..4ea39ddde90c4 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -570,27 +570,27 @@ static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int n return true; } -// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[] +// Returns true if nodes with indices { node_idxs } are the sequence of ggml_ops in ops[] // and are fusable. Nodes are considered fusable according to this function if: // - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses). // - all nodes except the last are a src of the following node. // - all nodes are the same shape. // TODO: Consider allowing GGML_OP_NONE nodes in between -static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) { - if (node_idx + num_ops > cgraph->n_nodes) { - return false; - } - +static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const int * node_idxs, const enum ggml_op * ops, int num_ops) { for (int i = 0; i < num_ops; ++i) { - struct ggml_tensor * node = cgraph->nodes[node_idx + i]; + if (node_idxs[i] >= cgraph->n_nodes) { + return false; + } + + struct ggml_tensor * node = cgraph->nodes[node_idxs[i]]; if (node->op != ops[i]) { return false; } - if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) { + if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) { return false; } if (i > 0) { - struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1]; + struct ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]]; if (node->src[0] != prev && node->src[1] != prev) { return false; } @@ -602,6 +602,22 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx return true; } +// same as above, for sequential indices starting at node_idx +static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) { + assert(num_ops < 32); + + if (node_idx + num_ops > cgraph->n_nodes) { + return false; + } + + int idxs[32]; + for (int i = 0; i < num_ops; ++i) { + idxs[i] = node_idx + i; + } + + return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops); +} + #ifdef __cplusplus } #endif