@@ -7794,7 +7794,9 @@ static void llm_build_kv_store(
     cb(k_cache_view, "k_cache_view", il);
 
     // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
+    ggml_tensor * tmp = ggml_cpy(ctx, k_cur, k_cache_view);
+    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_K;
+    ggml_build_forward_expand(graph, tmp);
 
     assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
 
@@ -7812,8 +7814,9 @@ static void llm_build_kv_store(
         v_cur = ggml_transpose(ctx, v_cur);
     }
     cb(v_cache_view, "v_cache_view", il);
-
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
+    tmp = ggml_cpy(ctx, v_cur, v_cache_view);
+    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_V;
+    ggml_build_forward_expand(graph, tmp);
 }
 
 static struct ggml_tensor * llm_build_norm(
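
The tagging above relies on a flag that is not visible in these hunks. As a rough sketch of what the companion ggml change presumably looks like (only GGML_KV_CACHE_FLAG_K, GGML_KV_CACHE_FLAG_V and the kv_cache_flag member are confirmed by the diff; the NONE default, the enum name and the exact declarations are assumptions):

// Assumed addition to ggml.h (sketch, not part of this diff): a per-tensor
// marker so the GGML_OP_CPY nodes that write K and V into the KV cache can be
// recognized later when the cached graph is reused.
enum ggml_kv_cache_flag {
    GGML_KV_CACHE_FLAG_NONE = 0, // assumed default for all other tensors
    GGML_KV_CACHE_FLAG_K    = 1, // this copy node writes k_cur into the K cache view
    GGML_KV_CACHE_FLAG_V    = 2, // this copy node writes v_cur into the V cache view
};

// struct ggml_tensor would then carry the field, e.g.:
//     enum ggml_kv_cache_flag kv_cache_flag; // GGML_KV_CACHE_FLAG_NONE unless tagged

With the copy nodes tagged at graph-build time, llama_decode_internal no longer needs to infer which GGML_OP_CPY nodes belong to the KV cache by counting them and special-casing the flash-attention mask copy, as the hunk below shows.
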
@@ -14606,48 +14609,42 @@ static int llama_decode_internal(
 
     if(ggml_use_cached_graph(lctx.sched)) {
 
-        // If using flash attention, find mask node so it can be skipped when updating
-        // KV cache paramaters in cached graph nodes below
-        void * flash_attn_mask_node = nullptr;
-        if(cparams.flash_attn) {
-            for (int i = 0; i < gf->n_nodes; i++) {
-                ggml_tensor * node = gf->nodes[i];
-                if (node->op == GGML_OP_FLASH_ATTN_EXT) {
-                    flash_attn_mask_node = node->src[3];
-                    break;
-                }
-            }
-        }
-
         // Temporarily store KV cache parameters that will need updated in cached graph.
         const struct llama_hparams & hparams = model.hparams;
         const int64_t n_layer = hparams.n_layer;
         const int64_t kv_head = kv_self.head;
         std::vector<void *> kv_cache_ptrs;
+        std::vector<void *> k_cache_ptrs;
+        std::vector<void *> v_cache_ptrs;
         for (int il = 0; il < n_layer; ++il) {
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
             const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
             ggml_tensor * tmp_tensor = kv_self.k_l[il];
             size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
             kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+            k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
             tmp_tensor = kv_self.v_l[il];
             if (cparams.flash_attn) {
                 tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
             } else {
                 tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
             }
             kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+            v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
         }
 
         // Update KV cache parameters in cached graph.
-        int copy_op_count = 0;
+        int k_count = 0;
+        int v_count = 0;
         if(gf != nullptr && gf->nodes != nullptr){
             for (int i = 0; i < gf->n_nodes; i++) {
                 ggml_tensor * node = gf->nodes[i];
                 if (node->op == GGML_OP_CPY) {
-                    if (node != flash_attn_mask_node) {
-                        node->src[1]->data = kv_cache_ptrs[copy_op_count];
-                        copy_op_count++;
+                    if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_K) {
+                        node->src[1]->data = k_cache_ptrs[k_count++];
+                    }
+                    if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_V) {
+                        node->src[1]->data = v_cache_ptrs[v_count++];
                     }
                 }
             }
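
For reference, here is a minimal, self-contained illustration (stand-in types in plain C, not ggml's API) of the dispatch the updated loop performs: copy nodes tagged K or V consume the per-layer destination pointers in graph order, while untagged copies, such as the flash-attention mask copy the removed code had to locate and skip, are left untouched.

#include <stdio.h>

enum kv_flag { KV_NONE = 0, KV_K, KV_V };

struct node {
    enum kv_flag flag; /* stands in for the assumed tensor->kv_cache_flag   */
    void * dst;        /* stands in for node->src[1]->data, the cache view  */
};

/* Refresh destination pointers on tagged copy nodes, mirroring the loop above. */
static void update_cached_graph(struct node * nodes, int n_nodes,
                                void ** k_ptrs, void ** v_ptrs) {
    int k_count = 0;
    int v_count = 0;
    for (int i = 0; i < n_nodes; i++) {
        if (nodes[i].flag == KV_K) { nodes[i].dst = k_ptrs[k_count++]; }
        if (nodes[i].flag == KV_V) { nodes[i].dst = v_ptrs[v_count++]; }
        /* KV_NONE: unrelated copy (e.g. attention mask), not modified */
    }
}

int main(void) {
    char k_cache[2][16], v_cache[2][16], mask[16];
    void * k_ptrs[2] = { k_cache[0], k_cache[1] };
    void * v_ptrs[2] = { v_cache[0], v_cache[1] };
    struct node nodes[] = {
        { KV_K,    NULL },   /* layer 0: K cache write    */
        { KV_V,    NULL },   /* layer 0: V cache write    */
        { KV_NONE, mask },   /* unrelated copy, untouched */
        { KV_K,    NULL },   /* layer 1: K cache write    */
        { KV_V,    NULL },   /* layer 1: V cache write    */
    };
    update_cached_graph(nodes, 5, k_ptrs, v_ptrs);
    printf("layer 1 K destination updated: %s\n",
           nodes[3].dst == (void *) k_cache[1] ? "yes" : "no");
    return 0;
}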