
Commit 610c291

ngxson, arlo-phoenix, and ggerganov authored and Neo Zhang committed
gemma2: add sliding window mask (ggml-org#8227)
* gemma2: add sliding window mask

* fix data_swa uninitialized

* better naming

* add co-author

Co-authored-by: Arlo Phoenix <[email protected]>

* replace list with single tensor

* update

* llama : minor styling

* convert : add sanity check for query_pre_attn_scalar

* fix small typo in README

---------

Co-authored-by: Arlo Phoenix <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 952b9a3 commit 610c291
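
For orientation, the sketch below contrasts the two kinds of masks this commit distinguishes: the ordinary causal mask, and a sliding-window (SWA) mask that additionally hides keys sitting n_swa or more positions behind the query. This is a toy Python illustration with made-up sizes, not code from the patch.

# Toy illustration of a causal mask vs. a sliding-window (SWA) mask.
# seq_len and n_swa are made-up values chosen for readability.
seq_len = 6
n_swa = 3

def allowed(q_pos, k_pos, swa=None):
    # Causal: a query may not look at future keys.
    if k_pos > q_pos:
        return False
    # Sliding window: keys n_swa or more positions behind the query are hidden.
    if swa is not None and q_pos - k_pos >= swa:
        return False
    return True

def render(swa=None):
    return "\n".join(
        " ".join("." if allowed(q, k, swa) else "x" for k in range(seq_len))
        for q in range(seq_len)
    )

print("causal mask ('.' = attend, 'x' = masked):")
print(render())
print("\nsliding-window mask (n_swa = 3):")
print(render(n_swa))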

File tree

5 files changed: +79 -32 lines changed

README.md
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
src/llama.cpp

README.md

Lines changed: 1 addition & 1 deletion

@@ -218,7 +218,7 @@ Unless otherwise noted these projects are open-source with permissive licensing:
 **Tools:**
 
 - [akx/ggify](https://github.com/akx/ggify) – download PyTorch models from HuggingFace Hub and convert them to GGML
-[crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
+- [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
 
 ---

convert-hf-to-gguf.py

Lines changed: 6 additions & 0 deletions

@@ -2369,6 +2369,12 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_final_logit_softcapping(
             self.hparams["final_logit_softcapping"]
         )
+        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
+
+        # sanity check
+        attn_scalar = self.hparams["query_pre_attn_scalar"]
+        if attn_scalar != hparams["hidden_size"] / hparams["num_attention_heads"]:
+            raise ValueError("query_pre_attn_scalar must be equal to n_embd / n_head")
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
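
The new sanity check rejects conversions where query_pre_attn_scalar differs from hidden_size / num_attention_heads (n_embd / n_head). A standalone sketch of the same check on a hypothetical HF-style config follows; the numbers are made up for illustration and are not real Gemma 2 values.

# Hypothetical HF-style config; the numbers are illustrative, not real Gemma 2 values.
hparams = {
    "hidden_size": 2048,
    "num_attention_heads": 8,
    "query_pre_attn_scalar": 256,  # 2048 / 8, so the check passes
}

attn_scalar = hparams["query_pre_attn_scalar"]
if attn_scalar != hparams["hidden_size"] / hparams["num_attention_heads"]:
    raise ValueError("query_pre_attn_scalar must be equal to n_embd / n_head")
print("sanity check passed: query_pre_attn_scalar =", attn_scalar)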

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions

@@ -66,6 +66,7 @@ class Attention:
         Q_LORA_RANK       = "{arch}.attention.q_lora_rank"
         KV_LORA_RANK      = "{arch}.attention.kv_lora_rank"
         REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
+        SLIDING_WINDOW    = "{arch}.attention.sliding_window"
 
     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions

@@ -552,6 +552,9 @@ def add_kv_lora_rank(self, length: int) -> None:
     def add_relative_attn_buckets_count(self, value: int) -> None:
         self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
 
+    def add_sliding_window(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
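
As the diff shows, add_sliding_window() is a thin wrapper that stores a uint32 under the {arch}.attention.sliding_window key. A hedged usage sketch, assuming the gguf package from this repository's gguf-py directory is installed; the output path and window size are hypothetical, and "gemma2" is assumed to be the architecture string used for these models.

from gguf import GGUFWriter  # gguf-py package shipped in this repository

# Hypothetical output path and window size, shown only to illustrate the new method.
writer = GGUFWriter("gemma2-example.gguf", "gemma2")
writer.add_sliding_window(4096)  # stored as the uint32 KV "gemma2.attention.sliding_window"
# ... the converter then adds the remaining metadata and tensors before writing the file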

src/llama.cpp

Lines changed: 68 additions & 31 deletions

@@ -317,6 +317,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_Q_LORA_RANK,
     LLM_KV_ATTENTION_KV_LORA_RANK,
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
+    LLM_KV_ATTENTION_SLIDING_WINDOW,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,

@@ -409,6 +410,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_Q_LORA_RANK,            "%s.attention.q_lora_rank"            },
     { LLM_KV_ATTENTION_KV_LORA_RANK,           "%s.attention.kv_lora_rank"           },
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+    { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
 
     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_FREQ_BASE,       "%s.rope.freq_base"       },
@@ -2085,6 +2087,7 @@ struct llama_hparams {
     uint32_t n_head_kv;
     uint32_t n_layer;
     uint32_t n_rot;
+    uint32_t n_swa = 0; // sliding window attention (SWA)
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_ff;

@@ -2139,6 +2142,7 @@
         if (this->n_head_kv     != other.n_head_kv)     return true;
         if (this->n_layer       != other.n_layer)       return true;
         if (this->n_rot         != other.n_rot)         return true;
+        if (this->n_swa         != other.n_swa)         return true;
         if (this->n_embd_head_k != other.n_embd_head_k) return true;
         if (this->n_embd_head_v != other.n_embd_head_v) return true;
         if (this->n_ff          != other.n_ff)          return true;
@@ -2649,17 +2653,18 @@ struct llama_context {
     void * abort_callback_data = nullptr;
 
     // input tensors
-    struct ggml_tensor * inp_tokens;    // I32 [n_batch]
-    struct ggml_tensor * inp_embd;      // F32 [n_embd, n_batch]
-    struct ggml_tensor * inp_pos;       // I32 [n_batch]
-    struct ggml_tensor * inp_out_ids;   // I32 [n_outputs]
-    struct ggml_tensor * inp_KQ_mask;   // F32 [kv_size, n_batch]
-    struct ggml_tensor * inp_K_shift;   // I32 [kv_size]
-    struct ggml_tensor * inp_mean;      // F32 [n_batch, n_batch]
-    struct ggml_tensor * inp_cls;       // I32 [n_batch]
-    struct ggml_tensor * inp_s_copy;    // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask;    // F32 [1, n_kv]
-    struct ggml_tensor * inp_s_seq;     // I32 [n_kv, n_batch]
+    struct ggml_tensor * inp_tokens;      // I32 [n_batch]
+    struct ggml_tensor * inp_embd;        // F32 [n_embd, n_batch]
+    struct ggml_tensor * inp_pos;         // I32 [n_batch]
+    struct ggml_tensor * inp_out_ids;     // I32 [n_outputs]
+    struct ggml_tensor * inp_KQ_mask;     // F32 [kv_size, n_batch]
+    struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
+    struct ggml_tensor * inp_K_shift;     // I32 [kv_size]
+    struct ggml_tensor * inp_mean;        // F32 [n_batch, n_batch]
+    struct ggml_tensor * inp_cls;         // I32 [n_batch]
+    struct ggml_tensor * inp_s_copy;      // I32 [kv_size]
+    struct ggml_tensor * inp_s_mask;      // F32 [1, n_kv]
+    struct ggml_tensor * inp_s_seq;       // I32 [n_kv, n_batch]
 
     // control vectors
     struct llama_control_vector cvec;
@@ -4710,6 +4715,8 @@ static void llm_load_hparams(
            } break;
        case LLM_ARCH_GEMMA2:
            {
+                hparams.n_swa = 4096; // default value of gemma 2
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa, false);
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,      hparams.f_attn_logit_softcapping, false);
                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING,     hparams.f_final_logit_softcapping, false);
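
The loader seeds n_swa with 4096 (the Gemma 2 default) and the trailing false makes the GGUF key optional, so older conversions without the new metadata still load. A minimal Python sketch of the same read-with-default pattern, using a plain dict to stand in for the GGUF metadata; the key name comes from the diff and the "gemma2." prefix is an assumed architecture name.

# Plain dict standing in for the GGUF metadata; the "gemma2." prefix is assumed here.
metadata = {}  # e.g. a conversion made before this commit, without the new key

n_swa = 4096  # default value of gemma 2
n_swa = metadata.get("gemma2.attention.sliding_window", n_swa)
print("n_swa =", n_swa)  # stays 4096 unless the file provides the key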
@@ -5420,6 +5427,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
    LLAMA_LOG_INFO("%s: n_head_kv        = %u\n",     __func__, hparams.n_head_kv);
    LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
    LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
+    LLAMA_LOG_INFO("%s: n_swa            = %u\n",     __func__, hparams.n_swa);
    LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
    LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
    LLAMA_LOG_INFO("%s: n_gqa            = %u\n",     __func__, hparams.n_gqa());
@@ -7776,17 +7784,18 @@ struct llm_build_context {
 
         ctx0 = ggml_init(params);
 
-        lctx.inp_tokens  = nullptr;
-        lctx.inp_embd    = nullptr;
-        lctx.inp_pos     = nullptr;
-        lctx.inp_out_ids = nullptr;
-        lctx.inp_KQ_mask = nullptr;
-        lctx.inp_K_shift = nullptr;
-        lctx.inp_mean    = nullptr;
-        lctx.inp_cls     = nullptr;
-        lctx.inp_s_copy  = nullptr;
-        lctx.inp_s_mask  = nullptr;
-        lctx.inp_s_seq   = nullptr;
+        lctx.inp_tokens      = nullptr;
+        lctx.inp_embd        = nullptr;
+        lctx.inp_pos         = nullptr;
+        lctx.inp_out_ids     = nullptr;
+        lctx.inp_KQ_mask     = nullptr;
+        lctx.inp_KQ_mask_swa = nullptr;
+        lctx.inp_K_shift     = nullptr;
+        lctx.inp_mean        = nullptr;
+        lctx.inp_cls         = nullptr;
+        lctx.inp_s_copy      = nullptr;
+        lctx.inp_s_mask      = nullptr;
+        lctx.inp_s_seq       = nullptr;
     }
 
     void free() {

@@ -7805,7 +7814,6 @@
         cb(lctx.inp_K_shift, "K_shift", -1);
         ggml_set_input(lctx.inp_K_shift);
 
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * rope_factors = build_rope_factors(il);
             struct ggml_tensor * tmp =
@@ -7940,16 +7948,27 @@
     }
 
     struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
-        if (causal) {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        } else {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        }
+        lctx.inp_KQ_mask = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         cb(lctx.inp_KQ_mask, "KQ_mask", -1);
         ggml_set_input(lctx.inp_KQ_mask);
+
         return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
     }
 
+    struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
+        GGML_ASSERT(hparams.n_swa > 0);
+
+        lctx.inp_KQ_mask_swa = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(lctx.inp_KQ_mask_swa);
+
+        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
+    }
+
     struct ggml_tensor * build_inp_mean() {
         lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
         cb(lctx.inp_mean, "inp_mean", -1);
@@ -11030,9 +11049,14 @@
        struct ggml_tensor * inp_pos = build_inp_pos();
 
        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        // gemma 2 requires different mask for layers using sliding window (SWA)
+        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask(true);
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
 
        for (int il = 0; il < n_layer; ++il) {
+            // (il % 2) layers use SWA
+            struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
+
            // norm
            cur = llm_build_norm(ctx0, inpL, hparams,
                    model.layers[il].attn_norm, NULL,
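
As the comment in the hunk notes, Gemma 2 layers alternate between the two masks: even-indexed layers (il % 2 == 0) attend through the sliding-window mask, the others through the regular causal mask. A small Python sketch of that selection, using a hypothetical layer count purely for illustration:

n_layer = 8  # hypothetical layer count; real Gemma 2 checkpoints have more layers

for il in range(n_layer):
    mask = "KQ_mask_swa" if il % 2 == 0 else "KQ_mask"
    print(f"layer {il:2d} -> {mask}")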
@@ -11068,7 +11092,7 @@
 
                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
            }
 
            cur = llm_build_norm(ctx0, cur, hparams,
@@ -12671,7 +12695,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
-            float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data     = (float *) lctx.inp_KQ_mask->data;
+            float * data_swa = nullptr;
+
+            if (lctx.inp_KQ_mask_swa) {
+                data_swa = (float *) lctx.inp_KQ_mask_swa->data;
+            }
 
            // For causal attention, use only the previous KV cells
            // of the correct sequence for each token of the batch.
@@ -12693,6 +12722,14 @@
                        }
                    }
                    data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+                    // may need to cut off old tokens for sliding window
+                    if (data_swa) {
+                        if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+                            f = -INFINITY;
+                        }
+                        data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                    }
                }
            }
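
The fill loop reuses the causal mask value f and, for the SWA copy, additionally forces -INFINITY whenever the cached key sits n_swa or more positions behind the token being processed. A simplified standalone sketch of that pass in Python, for a single sequence with hypothetical toy positions and window size (no batching, padding, or ALiBi handling):

import math

n_swa = 3                 # hypothetical window size
kv_pos = [0, 1, 2, 3, 4]  # positions of the cached keys (toy values)
batch_pos = [3, 4]        # positions of the tokens being decoded (toy values)

data = []      # flattened [token][kv] causal mask, as in llama_set_inputs
data_swa = []  # same layout, with the sliding-window cutoff applied

for pos in batch_pos:
    for kpos in kv_pos:
        f = 0.0 if kpos <= pos else -math.inf  # regular causal mask value
        data.append(f)
        if pos - kpos >= n_swa:                # cut off old tokens for the sliding window
            f = -math.inf
        data_swa.append(f)

print("KQ_mask:    ", data)
print("KQ_mask_swa:", data_swa)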
