
Commit abaf0c6

Improve identification of K and V nodes for param updates
1 parent a34900a commit abaf0c6

3 files changed: 29 additions & 22 deletions

ggml/include/ggml.h

Lines changed: 10 additions & 1 deletion

@@ -552,6 +552,13 @@ extern "C" {
         GGML_TENSOR_FLAG_PARAM = 4,
     };

+    // Flag (used on GGML_OP_CPY nodes) indicating whether the node is associated with the K or V cache
+    enum ggml_kv_cache_flag {
+        GGML_KV_CACHE_FLAG_NONE = 0,
+        GGML_KV_CACHE_FLAG_K    = 1,
+        GGML_KV_CACHE_FLAG_V    = 2
+    };
+
     // ggml object
     struct ggml_object {
         size_t offs;
@@ -586,6 +593,8 @@ extern "C" {
         // op params - allocated as int32_t for alignment
         int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];

+        enum ggml_kv_cache_flag kv_cache_flag;
+
         int32_t flags;

         struct ggml_tensor * grad;
@@ -601,7 +610,7 @@ extern "C" {

         void * extra; // extra things e.g. for ggml-cuda.cu

-        // char padding[4];
+        char padding[1];
     };

     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
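For context, a minimal usage sketch (not part of the commit; the helper name is hypothetical): every tensor now starts out with kv_cache_flag == GGML_KV_CACHE_FLAG_NONE, and graph-building code tags the GGML_OP_CPY nodes that write into the K and V caches so they can be recognized later without inspecting the surrounding graph, as the llm_build_kv_store hunk below does.

#include "ggml.h"

// Hypothetical helper: build the two cache-write copies for one layer and tag them.
void build_and_tag_kv_copies(struct ggml_context * ctx, struct ggml_cgraph * gf,
                             struct ggml_tensor * k_cur, struct ggml_tensor * k_cache_view,
                             struct ggml_tensor * v_cur, struct ggml_tensor * v_cache_view) {
    struct ggml_tensor * k_cpy = ggml_cpy(ctx, k_cur, k_cache_view);
    k_cpy->kv_cache_flag = GGML_KV_CACHE_FLAG_K;   // mark as a K-cache write
    ggml_build_forward_expand(gf, k_cpy);

    struct ggml_tensor * v_cpy = ggml_cpy(ctx, v_cur, v_cache_view);
    v_cpy->kv_cache_flag = GGML_KV_CACHE_FLAG_V;   // mark as a V-cache write
    ggml_build_forward_expand(gf, v_cpy);
}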

ggml/src/ggml.c

Lines changed: 2 additions & 1 deletion

@@ -3638,6 +3638,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.nb           =*/ { 0, 0, 0, 0 },
         /*.op           =*/ GGML_OP_NONE,
         /*.op_params    =*/ { 0 },
+        /*.kv_cache_flag=*/ GGML_KV_CACHE_FLAG_NONE,
         /*.flags        =*/ 0,
         /*.grad         =*/ NULL,
         /*.src          =*/ { NULL },
@@ -3646,7 +3647,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.data         =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
         /*.name         =*/ { 0 },
         /*.extra        =*/ NULL,
-        ///*.padding      =*/ { 0 },
+        /*.padding      =*/ { 0 },
     };

 #ifdef __clang__

src/llama.cpp

Lines changed: 17 additions & 20 deletions

@@ -7794,7 +7794,9 @@ static void llm_build_kv_store(
     cb(k_cache_view, "k_cache_view", il);

     // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
+    ggml_tensor * tmp = ggml_cpy(ctx, k_cur, k_cache_view);
+    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_K;
+    ggml_build_forward_expand(graph, tmp);

     assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);

@@ -7812,8 +7814,9 @@ static void llm_build_kv_store(
         v_cur = ggml_transpose(ctx, v_cur);
     }
     cb(v_cache_view, "v_cache_view", il);
-
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
+    tmp = ggml_cpy(ctx, v_cur, v_cache_view);
+    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_V;
+    ggml_build_forward_expand(graph, tmp);
 }

 static struct ggml_tensor * llm_build_norm(
@@ -14606,48 +14609,42 @@ static int llama_decode_internal(

     if(ggml_use_cached_graph(lctx.sched)) {

-        // If using flash attention, find mask node so it can be skipped when updating
-        // KV cache paramaters in cached graph nodes below
-        void * flash_attn_mask_node = nullptr;
-        if(cparams.flash_attn) {
-            for (int i = 0; i < gf->n_nodes; i++) {
-                ggml_tensor * node = gf->nodes[i];
-                if (node->op == GGML_OP_FLASH_ATTN_EXT) {
-                    flash_attn_mask_node = node->src[3];
-                    break;
-                }
-            }
-        }
-
         // Temporarily store KV cache parameters that will need updated in cached graph.
         const struct llama_hparams & hparams = model.hparams;
         const int64_t n_layer = hparams.n_layer;
         const int64_t kv_head = kv_self.head;
         std::vector<void *> kv_cache_ptrs;
+        std::vector<void *> k_cache_ptrs;
+        std::vector<void *> v_cache_ptrs;
         for (int il = 0; il < n_layer; ++il) {
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
             const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
             ggml_tensor * tmp_tensor = kv_self.k_l[il];
             size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
             kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+            k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
             tmp_tensor = kv_self.v_l[il];
             if (cparams.flash_attn) {
                 tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
             } else {
                 tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
             }
             kv_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+            v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
         }

         // Update KV cache parameters in cached graph.
-        int copy_op_count = 0;
+        int k_count = 0;
+        int v_count = 0;
         if(gf != nullptr && gf->nodes != nullptr){
             for (int i = 0; i < gf->n_nodes; i++) {
                 ggml_tensor * node = gf->nodes[i];
                 if (node->op == GGML_OP_CPY) {
-                    if (node != flash_attn_mask_node) {
-                        node->src[1]->data = kv_cache_ptrs[copy_op_count];
-                        copy_op_count++;
+                    if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_K) {
+                        node->src[1]->data = k_cache_ptrs[k_count++];
+                    }
+                    if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_V) {
+                        node->src[1]->data = v_cache_ptrs[v_count++];
                     }
                 }
             }
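Again as an illustration only (the function and vector names are hypothetical, not from the commit): with the flag in place, refreshing the destination pointers of a cached graph reduces to a dispatch on kv_cache_flag, instead of counting GGML_OP_CPY nodes and special-casing the flash-attention mask copy as the removed code did.

#include <vector>
#include "ggml.h"

// Hypothetical sketch: re-point the K/V cache writes of a cached graph at new
// destination addresses (e.g. after kv_self.head has moved), one pointer per layer.
void update_cached_kv_ptrs(struct ggml_cgraph * gf,
                           const std::vector<void *> & k_cache_ptrs,
                           const std::vector<void *> & v_cache_ptrs) {
    size_t k_count = 0;
    size_t v_count = 0;
    for (int i = 0; i < gf->n_nodes; i++) {
        struct ggml_tensor * node = gf->nodes[i];
        if (node->op != GGML_OP_CPY) {
            continue;
        }
        if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_K) {
            node->src[1]->data = k_cache_ptrs[k_count++];   // src[1] is the K cache view being written
        } else if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_V) {
            node->src[1]->data = v_cache_ptrs[v_count++];   // src[1] is the V cache view being written
        }
        // copies tagged GGML_KV_CACHE_FLAG_NONE (e.g. the flash-attention mask copy) are left untouched
    }
}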
