
Commit 81d4162

backend : group nodes in a single compute when the user doesn't need them
1 parent 041284d commit 81d4162

File tree: 3 files changed, 37 additions (+) and 20 deletions (−)

  examples/simple/simple.cpp
  ggml-backend.c
  ggml-backend.h

examples/simple/simple.cpp

Lines changed: 16 additions & 12 deletions

@@ -8,23 +8,27 @@
 
 // a function that can be called for every computed node during graph evaluation
 // the user can choose to whether to observe the data of the node depending on the tensor parameters
-static bool observe_compute(int node_index, struct ggml_tensor * t, void * user_data) {
+static bool observe_compute(int node_index, struct ggml_tensor * t, bool ask, void * user_data) {
     GGML_UNUSED(user_data);
 
-    // check if name contains soft_max
-    if (strstr(t->name, "soft_max") != 0) {
-        printf("%s: node_index = %5d, t->name = %32s, t->op = %12s, [%5d, %5d, %5d, %5d]\n",
-                __func__, node_index, t->name, ggml_op_name(t->op), (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
+    // the scheduler is asking us if we want to observe this node
+    if (ask) {
+        // check if name contains soft_max
+        return strstr(t->name, "soft_max") != 0;
+    }
 
-        std::vector<float> t_data(ggml_nelements(t));
-        ggml_backend_tensor_get(t, t_data.data(), 0, ggml_nbytes(t));
+    // print the node data
+    printf("%s: node_index = %5d, t->name = %32s, t->op = %12s, [%5d, %5d, %5d, %5d]\n",
+            __func__, node_index, t->name, ggml_op_name(t->op), (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
 
-        // print first row
-        for (int i = 0; i < t->ne[0]; i++) {
-            printf("%8.4f ", t_data[i]);
-        }
-        printf("\n");
+    std::vector<float> t_data(ggml_nelements(t));
+    ggml_backend_tensor_get(t, t_data.data(), 0, ggml_nbytes(t));
+
+    // print first row
+    for (int i = 0; i < t->ne[0]; i++) {
+        printf("%8.4f ", t_data[i]);
     }
+    printf("\n");
 
     return true;
 }
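For context, here is a hedged sketch of how a callback like observe_compute above is typically wired into the scheduler. The registration and compute calls (ggml_backend_sched_set_eval_callback, ggml_backend_sched_graph_compute) are assumed from the surrounding ggml-backend API and do not appear in this diff.

```cpp
// sketch only: assumes ggml_backend_sched_set_eval_callback and
// ggml_backend_sched_graph_compute from the surrounding ggml-backend API
#include "ggml.h"
#include "ggml-backend.h"

static void run_with_observer(ggml_backend_sched_t sched, struct ggml_cgraph * gf) {
    // register the observer; no user data is needed here
    ggml_backend_sched_set_eval_callback(sched, observe_compute, NULL);

    // during this call the scheduler first asks observe_compute(..., ask = true)
    // for each node; consecutive nodes the user does not care about are grouped
    // into a single backend compute, and only the interesting ones trigger a
    // second observe_compute(..., ask = false) call once their data is ready
    ggml_backend_sched_graph_compute(sched, gf);
}
```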

ggml-backend.c

Lines changed: 14 additions & 7 deletions

@@ -1337,18 +1337,25 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
             for (int j = 0; j < split->graph.n_nodes; j++) {
                 struct ggml_tensor * t = split->graph.nodes[j];
 
-                struct ggml_cgraph gv = ggml_graph_view(&split->graph, j, j + 1);
+                int k = j;
 
-                ggml_backend_graph_compute(split_backend, &gv);
-
-                if (ggml_is_view_op(t->op)) {
-                    continue;
+                // check if the user needs data from this node
+                while (!sched->callback_eval(k, t, true, sched->callback_eval_user_data) && k < split->graph.n_nodes - 1) {
+                    t = split->graph.nodes[++k];
                 }
 
-                // TODO: j is node index in the split, not in the original graph
-                if (!sched->callback_eval(j, t, sched->callback_eval_user_data)) {
+                struct ggml_cgraph gv = ggml_graph_view(&split->graph, j, k + 1);
+
+                ggml_backend_graph_compute(split_backend, &gv);
+
+                // TODO: k is node index in the split, not in the original graph
+                // TODO: avoid the ask == true call here
+                if (sched->callback_eval(k, t, true, sched->callback_eval_user_data) &&
+                    !sched->callback_eval(k, t, false, sched->callback_eval_user_data)) {
                     break;
                 }
+
+                j = k;
             }
         }
         uint64_t compute_end_us = ggml_time_us();
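To make the new control flow easier to follow outside of the scheduler internals, below is a small self-contained sketch of the same grouping idea. The needs_data predicate and the node count are illustrative stand-ins for sched->callback_eval(..., ask == true) and split->graph.n_nodes; nothing here is part of the real ggml API.

```cpp
#include <cstdio>

// illustrative stand-in for sched->callback_eval(..., ask == true):
// pretend the user only wants to observe every fourth node
static bool needs_data(int node_index) {
    return node_index % 4 == 3;
}

int main() {
    const int n_nodes = 10; // stand-in for split->graph.n_nodes

    for (int j = 0; j < n_nodes; j++) {
        int k = j;

        // advance k past all nodes the user does not need, so that the whole
        // range [j, k] can be evaluated with a single backend compute call
        while (!needs_data(k) && k < n_nodes - 1) {
            k++;
        }

        printf("compute nodes [%d, %d] in one call\n", j, k);

        // as in the real scheduler, node k is only reported back to the user
        // if the ask == true query confirmed interest in it
        if (needs_data(k)) {
            printf("  -> observe node %d\n", k);
        }

        j = k; // continue after the grouped range
    }

    return 0;
}
```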

ggml-backend.h

Lines changed: 7 additions & 1 deletion

@@ -148,8 +148,14 @@ extern "C" {
     struct ggml_backend_sched;
     typedef struct ggml_backend_sched * ggml_backend_sched_t;
 
+    // when ask == true, the scheduler wants to know if the user wants to observe this node
+    // this allows the scheduler to batch nodes together in order to evaluate them in a single call
+    //
+    // when ask == false, the scheduler is passing the node tensor to the user for observation
+    // if the user returns false, the scheduler will cancel the graph compute
+    //
     // TODO: propose to rename to ggml_backend_sched_callback_eval
-    typedef bool (*ggml_backend_sched_eval_callback)(int node_index, struct ggml_tensor * t, void * user_data);
+    typedef bool (*ggml_backend_sched_eval_callback)(int node_index, struct ggml_tensor * t, bool ask, void * user_data);
 
     // Initialize a backend scheduler
     GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
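As a quick reference for the contract described in the new header comment, here is a minimal skeleton of a conforming callback. Only the signature comes from the header; the filter on GGML_OP_MUL_MAT and the callback name are purely illustrative.

```cpp
#include "ggml.h"

// minimal sketch of the two-phase protocol:
//   ask == true  -> return whether this node should be observed
//   ask == false -> inspect the node; return false to cancel the remaining compute
static bool my_eval_callback(int node_index, struct ggml_tensor * t, bool ask, void * user_data) {
    (void) node_index;
    (void) user_data;

    if (ask) {
        // illustrative filter: only observe matrix multiplications
        return t->op == GGML_OP_MUL_MAT;
    }

    // the data of t has been computed by the backend at this point;
    // fetch it with ggml_backend_tensor_get() if it needs to be inspected
    return true;
}
```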
