Add attention and final logit soft-capping, update scaling factor to Gemma2 #8197
Changes from all commits (all diff hunks below are from llama.cpp):
@@ -302,6 +302,8 @@ enum llm_kv {
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
+    LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_FINAL_LOGIT_SOFTCAPPING,

     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -392,6 +394,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_POOLING_TYPE,             "%s.pooling_type"             },
     { LLM_KV_LOGIT_SCALE,              "%s.logit_scale"              },
     { LLM_KV_DECODER_START_TOKEN_ID,   "%s.decoder_start_token_id"   },
+    { LLM_KV_ATTN_LOGIT_SOFTCAPPING,   "%s.attn_logit_softcapping"   },
+    { LLM_KV_FINAL_LOGIT_SOFTCAPPING,  "%s.final_logit_softcapping"  },

     { LLM_KV_ATTENTION_HEAD_COUNT,     "%s.attention.head_count"     },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,  "%s.attention.head_count_kv"  },
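For orientation (not part of the diff): the %s placeholder is filled with the architecture name at lookup time, so a converted Gemma 2 GGUF would be expected to expose metadata entries roughly like the following, assuming the architecture string is "gemma2" and the default cap values introduced below:

```
gemma2.attn_logit_softcapping  = 50.0
gemma2.final_logit_softcapping = 30.0
```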
@@ -2099,6 +2103,9 @@ struct llama_hparams {
     float f_norm_eps;
     float f_norm_rms_eps;

+    float f_attn_logit_softcapping  = 50.0f;
+    float f_final_logit_softcapping = 30.0f;
+
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
     float rope_freq_scale_train;
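Both caps apply the same squashing, softcap(x) = c * tanh(x / c): approximately the identity for |x| much smaller than c, and smoothly bounded to (-c, c) otherwise. The defaults of 50.0 (attention logits) and 30.0 (final logits) correspond to the values Gemma 2 ships with, and can be overridden per model through the GGUF keys added above.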
@@ -2115,8 +2122,9 @@ struct llama_hparams {
     float f_max_alibi_bias = 0.0f;
     float f_logit_scale    = 0.0f;

-    bool causal_attn = true;
-    bool use_alibi   = false;
+    bool causal_attn   = true;
+    bool use_alibi     = false;
+    bool attn_soft_cap = false;

     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
@@ -4702,6 +4710,9 @@ static void llm_load_hparams(
         case LLM_ARCH_GEMMA2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
+                hparams.attn_soft_cap = true;

                 switch (hparams.n_layer) {
                     case 42: model.type = e_model::MODEL_9B; break;
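Both get_key calls pass false for the trailing required flag, so the keys are optional: GGUF files converted before these keys existed still load and simply fall back to the 50.0 / 30.0 defaults declared in llama_hparams.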
@@ -7579,6 +7590,12 @@ static struct ggml_tensor * llm_build_kqv(
             kq = ggml_scale(ctx, kq, 30);
         }

+        if (hparams.attn_soft_cap) {
+            kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+            kq = ggml_tanh(ctx, kq);
+            kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+        }
+
         kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
         cb(kq, "kq_soft_max_ext", il);
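As a quick numerical sanity check on what this scale / tanh / scale chain does, here is a small standalone sketch (plain C++ outside of ggml; the softcap helper name and the sample inputs are illustrative, not from the PR):

```cpp
// Standalone illustration of the soft-cap used in this PR:
//   softcap(x) = cap * tanh(x / cap)
// Small inputs pass through almost unchanged; large ones are smoothly
// bounded to (-cap, cap). Values below use the PR's default caps.
#include <cmath>
#include <cstdio>

static float softcap(float x, float cap) {
    return cap * std::tanh(x / cap);
}

int main() {
    const float attn_cap  = 50.0f; // default f_attn_logit_softcapping
    const float final_cap = 30.0f; // default f_final_logit_softcapping

    std::printf("softcap(10, 50)   = %.3f\n", softcap(10.0f,   attn_cap));  // ~9.87
    std::printf("softcap(1000, 50) = %.3f\n", softcap(1000.0f, attn_cap));  // ~50.00
    std::printf("softcap(200, 30)  = %.3f\n", softcap(200.0f,  final_cap)); // ~30.00
    return 0;
}
```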
@@ -11039,7 +11056,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);

-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
                 cb(Qcur, "Qcur_scaled", il);

                 Kcur = ggml_rope_ext(

(slaren marked a review conversation on this line as resolved.)
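The two expressions only coincide when the per-head dimension equals n_embd / n_head, which is not the case for Gemma 2. The numbers in the sketch below are assumed from Gemma-2-9B's published configuration rather than taken from this diff:

```cpp
// Sketch only: the two candidate query scales differ whenever the per-head
// dimension (n_embd_head_k) is not equal to n_embd / n_head. The Gemma-2-9B
// numbers are an assumption, not something stated in the diff.
#include <cmath>
#include <cstdio>

int main() {
    const int n_embd        = 3584; // assumed hidden size
    const int n_head        = 16;   // assumed attention head count
    const int n_embd_head_k = 256;  // assumed per-head K/Q dimension

    std::printf("1/sqrt(n_embd_head_k) = %.4f\n", 1.0f / std::sqrt((float) n_embd_head_k));     // 0.0625
    std::printf("1/sqrt(n_embd/n_head) = %.4f\n", 1.0f / std::sqrt((float) (n_embd / n_head))); // ~0.0668
    return 0;
}
```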
@@ -11106,6 +11123,12 @@ struct llm_build_context {

        // lm_head
        cur = ggml_mul_mat(ctx0, model.output, cur);
+
+       // final logit soft-capping
+       cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+       cur = ggml_tanh(ctx0, cur);
+       cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+
        cb(cur, "result_output", -1);

        ggml_build_forward_expand(gf, cur);

A review comment on the first ggml_scale line above reads: "Total nitpick that probably should be ignored. I came here out of curiosity, and I know this is merged by now and I have absolutely no place to comment. But while I'm a proponent of the “rule of 3”, I think there's merit in extracting it to something like a separate …"
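That suggestion would amount to something like the following hypothetical helper (the name llm_build_softcap is invented here and assumes the surrounding llama.cpp/ggml context; the PR itself keeps the sequence inline at both call sites):

```cpp
// Hypothetical refactor (not part of the PR): wrap the scale -> tanh -> scale
// sequence that now appears both in llm_build_kqv and in the Gemma 2 lm_head.
static struct ggml_tensor * llm_build_softcap(
        struct ggml_context * ctx,
        struct ggml_tensor  * x,
        float                 cap) {
    x = ggml_scale(ctx, x, 1.0f / cap); // compress into tanh's useful range
    x = ggml_tanh (ctx, x);             // bound to (-1, 1)
    return ggml_scale(ctx, x, cap);     // stretch back out to (-cap, cap)
}

// The two call sites would then read along the lines of:
//   kq  = llm_build_softcap(ctx,  kq,  hparams.f_attn_logit_softcapping);
//   cur = llm_build_softcap(ctx0, cur, hparams.f_final_logit_softcapping);
```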
@@ -17379,6 +17402,12 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }

+    if (params.flash_attn && model->hparams.attn_soft_cap) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
     if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
         LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;
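For reference, when a Gemma 2 model is loaded with flash attention requested (for example via llama.cpp's --flash-attn option, assuming the standard CLI), context creation would be expected to log the warning added above and continue with flash attention disabled, i.e. something like:

```
llama_new_context_with_model: flash_attn is not compatible with attn_soft_cap - forcing off
```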