
Commit 91f78e1

Merge pull request #220 from agray3/ag_ggml_graph_caching
ggml: avoid rebuild of GGML graph for each token (ggml-org#7456)
2 parents 33b1c1e + bca068f commit 91f78e1

File tree

  ggml/include/ggml-backend.h
  ggml/src/ggml-backend.c
  src/llama.cpp

3 files changed: +152 -8 lines changed

ggml/include/ggml-backend.h

Lines changed: 5 additions & 0 deletions
@@ -230,6 +230,11 @@ extern "C" {
     GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
     GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
 
+    // Utility to query whether a cached GGML graph is in use
+    GGML_API bool ggml_use_cached_graph(ggml_backend_sched_t sched);
+
+    // Set whether or not to use GGML graph caching
+    GGML_API void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value);
 
 #ifdef  __cplusplus
 }
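Note: a minimal sketch of how these two new calls are intended to be used together. The graph-building step is left abstract (build_graph() and the sched/gf handles are the caller's, not part of this commit); the other calls are the existing ggml-backend scheduler API.

    // Sketch: enable caching after the first build/alloc, query it on later tokens
    // to decide whether a full rebuild is needed.
    if (!ggml_use_cached_graph(sched)) {
        gf = build_graph();                          // placeholder for the caller's graph construction
        ggml_backend_sched_alloc_graph(sched, gf);
        ggml_set_cached_graph(sched, true);          // opt in to reusing this graph
    }
    ggml_backend_sched_graph_compute_async(sched, gf);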

ggml/src/ggml-backend.c

Lines changed: 32 additions & 1 deletion
@@ -1036,6 +1036,13 @@ struct ggml_backend_sched_split {
     struct ggml_cgraph graph;
 };
 
+// Object to facilitate GGML graph caching
+struct ggml_cached_graph {
+    bool is_active;
+    ggml_backend_t input_backend;
+    struct ggml_tensor * input_cpy[GGML_SCHED_MAX_SPLIT_INPUTS];
+};
+
 struct ggml_backend_sched {
     bool is_reset; // true if the scheduler has been reset since the last graph split
     bool is_alloc;
@@ -1087,6 +1094,8 @@ struct ggml_backend_sched {
     __attribute__((aligned(GGML_MEM_ALIGN)))
 #endif
     char context_buffer[GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
+
+    struct ggml_cached_graph cached_graph;
 };
 
 #define hash_id(tensor) ggml_hash_find_or_insert(sched->hash_set, tensor)
@@ -1753,6 +1762,14 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
             struct ggml_tensor * input = split->inputs[j];
             struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split_backend_id][sched->cur_copy];
 
+            if (!sched->cached_graph.is_active) {
+                sched->cached_graph.input_backend = input_backend;
+                sched->cached_graph.input_cpy[j] = input_cpy;
+            }
+            else {
+                input_backend = sched->cached_graph.input_backend;
+                input_cpy = sched->cached_graph.input_cpy[j];
+            }
             if (input->flags & GGML_TENSOR_FLAG_INPUT) {
                 // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
                 if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
@@ -1872,6 +1889,8 @@ ggml_backend_sched_t ggml_backend_sched_new(
 
     ggml_backend_sched_reset(sched);
 
+    sched->cached_graph.is_active = false;
+
     return sched;
 }

@@ -1947,6 +1966,9 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
 }
 
 enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+
+    if (!sched->cached_graph.is_active)
+    {
     if (!sched->is_reset && !sched->is_alloc) {
         ggml_backend_sched_reset(sched);
     }
@@ -1956,7 +1978,7 @@ enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
             return GGML_STATUS_ALLOC_FAILED;
         }
     }
-
+    }
     return ggml_backend_sched_compute_splits(sched);
 }

@@ -2223,3 +2245,12 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
 
     return true;
 }
+
+bool ggml_use_cached_graph(ggml_backend_sched_t sched) {
+    return sched->cached_graph.is_active;
+}
+
+void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value) {
+    sched->cached_graph.is_active = set_value;
+}
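Note: taken together, these scheduler changes mean that once caching is enabled, ggml_backend_sched_graph_compute_async() skips the reset/alloc path and ggml_backend_sched_compute_splits() falls back to the split-input bookkeeping recorded while caching was still off. The caller is therefore responsible for switching caching off whenever the graph it submits changes shape; a caller-side sketch (the condition name is hypothetical, not part of the commit):

    // Sketch: force a full reset/alloc on the next submission if the graph changed.
    if (graph_topology_changed) {
        ggml_set_cached_graph(sched, false);
    }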

src/llama.cpp

Lines changed: 115 additions & 7 deletions
@@ -2712,6 +2712,16 @@ struct llama_model {
     }
 };
 
+// Object used to allow caching of GGML graph between tokens where possible.
+struct ggml_cached_graph {
+    ggml_cgraph * gf;
+    size_t n;
+    ggml_backend_t backend_res;
+    ggml_backend_t backend_embd;
+    struct ggml_tensor * res;
+    struct ggml_tensor * embd;
+};
+
 struct llama_context {
     llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {}
     ~llama_context() {
@@ -2813,6 +2823,8 @@ struct llama_context {
 
     // control vectors
     struct llama_control_vector cvec;
+
+    struct ggml_cached_graph cached_graph;
 };
 
 static size_t llama_get_device_count(const llama_model & model) {
@@ -14524,12 +14536,37 @@ static int llama_decode_internal(
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
-        ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
-
+        ggml_cgraph * gf;
         // the output is always the last tensor in the graph
-        struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
-        struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+        struct ggml_tensor * res;
+        struct ggml_tensor * embd;
+
+        bool n_has_changed_since_last_token = false;
+        if (lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true;
+        lctx.cached_graph.n = kv_self.n;
+
+        // Re-build graph only if graph caching is not possible
+        if (!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
+
+            gf = llama_build_graph(lctx, u_batch, false);
+
+            // disable future graph caching in presence of env var,
+            // if there are multiple devices, or if batch size is greater than 1
+            // TODO: enable graph caching for these cases
+            bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
+                || (llama_get_device_count(model) > 1);
+            for (int i = 0; i < gf->n_nodes; i++) {
+                if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
+                    disable_cached_ggml_graph = true;
+                    break;
+                }
+            }
+
+            if (!disable_cached_ggml_graph) ggml_set_cached_graph(lctx.sched, true);
 
+            // the output is always the last tensor in the graph
+            res  = gf->nodes[gf->n_nodes - 1];
+            embd = gf->nodes[gf->n_nodes - 2];
         if (lctx.n_outputs == 0) {
             // no output
             res  = nullptr;
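Note: the node scan above is the batch-size check mentioned in the comment; it presumably relies on a GGML_OP_ADD node whose second source has ne[1] > 1 only appearing when more than one token is processed per step. A condensed sketch of the overall reuse decision, using names from the diff (illustrative only, not the literal commit code):

    // Sketch: the cached graph is reused only if caching is enabled and the KV view size is unchanged;
    // otherwise the graph is rebuilt, and caching is (re)enabled unless GGML_DISABLE_GRAPH_CACHING is set,
    // more than one device is present, or the batch holds more than one token.
    bool reuse_graph = ggml_use_cached_graph(lctx.sched) && lctx.cached_graph.n == kv_self.n;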
@@ -14545,10 +14582,71 @@ static int llama_decode_internal(
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+        lctx.cached_graph.res  = res;
+        lctx.cached_graph.embd = embd;
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
+        }
+        else {
+            gf   = lctx.cached_graph.gf;
+            res  = lctx.cached_graph.res;
+            embd = lctx.cached_graph.embd;
+        }
+        lctx.cached_graph.gf = gf;
+
+        if (ggml_use_cached_graph(lctx.sched)) {
+
+            // If using flash attention, find the mask node so it can be skipped when updating
+            // KV cache parameters in the cached graph nodes below
+            void * flash_attn_mask_node = nullptr;
+            if (cparams.flash_attn) {
+                for (int i = 0; i < gf->n_nodes; i++) {
+                    ggml_tensor * node = gf->nodes[i];
+                    if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                        flash_attn_mask_node = node->src[3];
+                        break;
+                    }
+                }
+            }
+
+            // Temporarily store the KV cache parameters that will need to be updated in the cached graph.
+            const struct llama_hparams & hparams = model.hparams;
+            const int64_t n_layer = hparams.n_layer;
+            const int64_t kv_head = kv_self.head;
+            std::vector<void *> kv_cache_ptrs;
+            for (int il = 0; il < n_layer; ++il) {
+                const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+                const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+                ggml_tensor * tmp_tensor = kv_self.k_l[il];
+                size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
+                kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+                tmp_tensor = kv_self.v_l[il];
+                if (cparams.flash_attn) {
+                    tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                } else {
+                    tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
+                }
+                kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+            }
+
+            // Update the KV cache parameters in the cached graph.
+            int copy_op_count = 0;
+            if (gf != nullptr && gf->nodes != nullptr) {
+                for (int i = 0; i < gf->n_nodes; i++) {
+                    ggml_tensor * node = gf->nodes[i];
+                    if (node->op == GGML_OP_CPY) {
+                        if (node != flash_attn_mask_node) {
+                            node->src[1]->data = kv_cache_ptrs[copy_op_count];
+                            copy_op_count++;
+                        }
+                    }
+                }
+            }
+
+        }
+
         llama_set_inputs(lctx, u_batch);
 
         llama_graph_compute(lctx, gf, n_threads);
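Note: the hunk above exists because the cached graph still contains the GGML_OP_CPY nodes that wrote K and V into the cache at the previous token's position, so their destination pointers must be re-derived for the current kv_self.head and patched into the graph. The offsets follow the existing KV-cache layout: K (and V when flash attention is enabled) advances by one row of n_embd_{k,v}_gqa elements per cache position, while the non-flash V layout advances by a single element per position since that V cache is stored transposed. A minimal sketch of the offset arithmetic, using names from the diff (illustrative only):

    // Sketch: destination pointers for layer il at cache position kv_head.
    char * k_dst = static_cast<char *>(kv_self.k_l[il]->data)
                 + ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head;
    char * v_dst = static_cast<char *>(kv_self.v_l[il]->data)
                 + (cparams.flash_attn
                        ? ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head
                        : ggml_element_size(kv_self.v_l[il])*kv_head);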
@@ -14571,11 +14669,15 @@ static int llama_decode_internal(
         // extract logits
         if (res) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
-            GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(lctx.logits != nullptr);
-
             float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
             const int32_t n_outputs_new = lctx.n_outputs;
+            if (!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_res = backend_res;
+            else
+                backend_res = lctx.cached_graph.backend_res;
+
+            GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(lctx.logits != nullptr);
 
             if (n_outputs_new) {
                 GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
@@ -14587,6 +14689,12 @@ static int llama_decode_internal(
         // extract embeddings
         if (embd) {
            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+
+            if (!ggml_use_cached_graph(lctx.sched))
+                lctx.cached_graph.backend_embd = backend_embd;
+            else
+                backend_embd = lctx.cached_graph.backend_embd;
+
             GGML_ASSERT(backend_embd != nullptr);
 
             switch (cparams.pooling_type) {
