@@ -317,6 +317,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_Q_LORA_RANK,
     LLM_KV_ATTENTION_KV_LORA_RANK,
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
+    LLM_KV_ATTENTION_SLIDING_WINDOW,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
@@ -409,6 +410,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
     { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+    { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
 
     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
@@ -2085,6 +2087,7 @@ struct llama_hparams {
     uint32_t n_head_kv;
     uint32_t n_layer;
     uint32_t n_rot;
+    uint32_t n_swa = 0; // sliding window attention (SWA)
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_ff;
@@ -2139,6 +2142,7 @@ struct llama_hparams {
         if (this->n_head_kv != other.n_head_kv) return true;
         if (this->n_layer != other.n_layer) return true;
         if (this->n_rot != other.n_rot) return true;
+        if (this->n_swa != other.n_swa) return true;
         if (this->n_embd_head_k != other.n_embd_head_k) return true;
         if (this->n_embd_head_v != other.n_embd_head_v) return true;
         if (this->n_ff != other.n_ff) return true;
@@ -2649,17 +2653,18 @@ struct llama_context {
     void * abort_callback_data = nullptr;
 
     // input tensors
-    struct ggml_tensor * inp_tokens;  // I32 [n_batch]
-    struct ggml_tensor * inp_embd;    // F32 [n_embd, n_batch]
-    struct ggml_tensor * inp_pos;     // I32 [n_batch]
-    struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
-    struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
-    struct ggml_tensor * inp_K_shift; // I32 [kv_size]
-    struct ggml_tensor * inp_mean;    // F32 [n_batch, n_batch]
-    struct ggml_tensor * inp_cls;     // I32 [n_batch]
-    struct ggml_tensor * inp_s_copy;  // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask;  // F32 [1, n_kv]
-    struct ggml_tensor * inp_s_seq;   // I32 [n_kv, n_batch]
+    struct ggml_tensor * inp_tokens;      // I32 [n_batch]
+    struct ggml_tensor * inp_embd;        // F32 [n_embd, n_batch]
+    struct ggml_tensor * inp_pos;         // I32 [n_batch]
+    struct ggml_tensor * inp_out_ids;     // I32 [n_outputs]
+    struct ggml_tensor * inp_KQ_mask;     // F32 [kv_size, n_batch]
+    struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
+    struct ggml_tensor * inp_K_shift;     // I32 [kv_size]
+    struct ggml_tensor * inp_mean;        // F32 [n_batch, n_batch]
+    struct ggml_tensor * inp_cls;         // I32 [n_batch]
+    struct ggml_tensor * inp_s_copy;      // I32 [kv_size]
+    struct ggml_tensor * inp_s_mask;      // F32 [1, n_kv]
+    struct ggml_tensor * inp_s_seq;       // I32 [n_kv, n_batch]
 
     // control vectors
     struct llama_control_vector cvec;
@@ -4710,6 +4715,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_GEMMA2:
             {
+                hparams.n_swa = 4096; // default value of gemma 2
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
@@ -5420,6 +5427,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
     LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
     LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
+    LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
     LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
     LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
     LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
@@ -7776,17 +7784,18 @@ struct llm_build_context {
 
         ctx0 = ggml_init(params);
 
-        lctx.inp_tokens  = nullptr;
-        lctx.inp_embd    = nullptr;
-        lctx.inp_pos     = nullptr;
-        lctx.inp_out_ids = nullptr;
-        lctx.inp_KQ_mask = nullptr;
-        lctx.inp_K_shift = nullptr;
-        lctx.inp_mean    = nullptr;
-        lctx.inp_cls     = nullptr;
-        lctx.inp_s_copy  = nullptr;
-        lctx.inp_s_mask  = nullptr;
-        lctx.inp_s_seq   = nullptr;
+        lctx.inp_tokens      = nullptr;
+        lctx.inp_embd        = nullptr;
+        lctx.inp_pos         = nullptr;
+        lctx.inp_out_ids     = nullptr;
+        lctx.inp_KQ_mask     = nullptr;
+        lctx.inp_KQ_mask_swa = nullptr;
+        lctx.inp_K_shift     = nullptr;
+        lctx.inp_mean        = nullptr;
+        lctx.inp_cls         = nullptr;
+        lctx.inp_s_copy      = nullptr;
+        lctx.inp_s_mask      = nullptr;
+        lctx.inp_s_seq       = nullptr;
     }
 
     void free() {
@@ -7805,7 +7814,6 @@ struct llm_build_context {
         cb(lctx.inp_K_shift, "K_shift", -1);
         ggml_set_input(lctx.inp_K_shift);
 
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * rope_factors = build_rope_factors(il);
             struct ggml_tensor * tmp =
@@ -7940,16 +7948,27 @@ struct llm_build_context {
     }
 
     struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
-        if (causal) {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        } else {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        }
+        lctx.inp_KQ_mask = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         cb(lctx.inp_KQ_mask, "KQ_mask", -1);
         ggml_set_input(lctx.inp_KQ_mask);
+
        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
     }
 
+    struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
+        GGML_ASSERT(hparams.n_swa > 0);
+
+        lctx.inp_KQ_mask_swa = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(lctx.inp_KQ_mask_swa);
+
+        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
+    }
+
     struct ggml_tensor * build_inp_mean() {
         lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
         cb(lctx.inp_mean, "inp_mean", -1);
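Note on the two mask builders above: both allocate an F32 tensor of shape [n_kv, padded n_tokens], with the token dimension rounded up so the flash-attention path receives an aligned mask. A minimal standalone sketch of that padding arithmetic follows (illustration only, not llama.cpp code; the pad multiple of 32 is an assumption based on GGML_KQ_MASK_PAD, and the n_kv / n_tokens values are hypothetical):

#include <cstdint>
#include <cstdio>

// Round x up to the next multiple of n, like GGML_PAD(x, n).
static int64_t pad_to(int64_t x, int64_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const int64_t n_kv     = 512; // hypothetical: KV cells visible to this batch
    const int64_t n_tokens = 70;  // hypothetical: tokens in the batch
    const int64_t n_pad    = pad_to(n_tokens, 32); // 32 assumed from GGML_KQ_MASK_PAD

    // KQ_mask and KQ_mask_swa would each hold one float per (kv cell, token) pair,
    // broadcast across attention heads.
    std::printf("mask shape: [%lld, %lld], %lld floats\n",
            (long long) n_kv, (long long) n_pad, (long long) (n_kv * n_pad));
    return 0;
}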
@@ -11030,9 +11049,14 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        // gemma 2 requires different mask for layers using sliding window (SWA)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true);
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
 
         for (int il = 0; il < n_layer; ++il) {
+            // (il % 2) layers use SWA
+            struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
+
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
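For orientation, the per-layer selection above alternates the two masks: even-indexed layers attend through the sliding-window mask, odd-indexed layers through the full causal mask. A toy sketch of the resulting pattern (illustration only; the layer count is hypothetical):

#include <cstdio>

int main() {
    const int n_layer = 8; // hypothetical layer count
    for (int il = 0; il < n_layer; ++il) {
        // Mirrors the (il % 2 == 0) check in the graph builder above.
        const bool use_swa = (il % 2 == 0);
        std::printf("layer %2d -> %s\n", il,
                use_swa ? "KQ_mask_swa (sliding window)" : "KQ_mask (global causal)");
    }
    return 0;
}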
@@ -11068,7 +11092,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask,   n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             cur = llm_build_norm(ctx0, cur, hparams,
@@ -12671,7 +12695,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
-        float * data = (float *) lctx.inp_KQ_mask->data;
+        float * data     = (float *) lctx.inp_KQ_mask->data;
+        float * data_swa = nullptr;
+
+        if (lctx.inp_KQ_mask_swa) {
+            data_swa = (float *) lctx.inp_KQ_mask_swa->data;
+        }
 
         // For causal attention, use only the previous KV cells
         // of the correct sequence for each token of the batch.
@@ -12693,6 +12722,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                         }
                     }
                    data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+                    // may need to cut off old tokens for sliding window
+                    if (data_swa) {
+                        if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+                            f = -INFINITY;
+                        }
+                        data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                    }
                 }
             }
 
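Putting the two mask-filling rules in the last hunk side by side: the regular mask hides KV cells that lie in the future of the query token, and the SWA mask additionally hides cells whose position is n_swa or more tokens behind it. Below is a small self-contained sketch of that logic (illustration only, not llama.cpp code): the sizes, positions, and window length are hypothetical, sequence-id checks are omitted, and the flattened indexing mirrors data[h*(n_kv*n_tokens) + j*n_kv + i] with h = 0.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int   n_tokens = 4;          // hypothetical batch of query tokens
    const int   n_kv     = 8;          // hypothetical KV cells, holding positions 0..7
    const int   n_swa    = 3;          // hypothetical sliding window length
    const float NEG_INF  = -INFINITY;

    std::vector<float> mask(n_kv * n_tokens);      // global causal mask
    std::vector<float> mask_swa(n_kv * n_tokens);  // sliding-window mask

    for (int j = 0; j < n_tokens; ++j) {
        const int pos = 4 + j;                     // query token positions 4..7
        for (int i = 0; i < n_kv; ++i) {
            const int cell_pos = i;                // here cell i simply holds position i
            float f = (cell_pos > pos) ? NEG_INF : 0.0f;  // causal check
            mask[j*n_kv + i] = f;                  // layout: data[h*(n_kv*n_tokens) + j*n_kv + i], h = 0
            if (pos - cell_pos >= n_swa) {
                f = NEG_INF;                       // cut off tokens outside the window
            }
            mask_swa[j*n_kv + i] = f;
        }
    }

    // Print the SWA mask: 0 = attend, -inf = masked.
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            std::printf("%6s", std::isinf(mask_swa[j*n_kv + i]) ? "-inf" : "0");
        }
        std::printf("\n");
    }
    return 0;
}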