-
Notifications
You must be signed in to change notification settings - Fork 13.4k
llama: implement YaRN RoPE scaling #2268
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
8dec38c
6aeb46b
9348aa4
a30ae20
b5ced4f
826269a
cf731d5
dcb058c
281b26e
a06c729
dc26a0d
904d4ed
56abb9a
43eaf06
fe788c4
e0b120c
19bb74e
4d5fe73
7466415
4f4e948
5d7a3a5
9bd050f
babf0e0
0050e1e
09c3102
57c3442
a20b3e6
9ef91b1
9ae10b3
14cf93b
237f1e7
bc8395d
4d5ed83
9fc8238
15f26ef
081f738
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -192,12 +192,46 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { | |
break; | ||
} | ||
params.rope_freq_scale = std::stof(argv[i]); | ||
} else if (arg == "--rope-scaling") { | ||
if (++i >= argc) { | ||
invalid_param = true; | ||
break; | ||
} | ||
std::string value(argv[i]); | ||
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; } | ||
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; } | ||
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; } | ||
else { invalid_param = true; break; } | ||
} else if (arg == "--rope-scale") { | ||
if (++i >= argc) { | ||
invalid_param = true; | ||
break; | ||
} | ||
params.rope_freq_scale = 1.0f/std::stof(argv[i]); | ||
} else if (arg == "--yarn-ext-factor") { | ||
if (++i >= argc) { | ||
invalid_param = true; | ||
break; | ||
} | ||
params.yarn_ext_factor = std::stof(argv[i]); | ||
} else if (arg == "--yarn-attn-factor") { | ||
if (++i >= argc) { | ||
invalid_param = true; | ||
break; | ||
} | ||
params.yarn_attn_factor = std::stof(argv[i]); | ||
} else if (arg == "--yarn-beta-fast") { | ||
if (++i >= argc) { | ||
invalid_param = true; | ||
break; | ||
} | ||
params.yarn_beta_fast = std::stof(argv[i]); | ||
} else if (arg == "--yarn-beta-slow") { | ||
if (++i >= argc) { | ||
invalid_param = true; | ||
break; | ||
} | ||
params.yarn_beta_slow = std::stof(argv[i]); | ||
} else if (arg == "--memory-f32") { | ||
params.memory_f16 = false; | ||
} else if (arg == "--top-p") { | ||
|
@@ -647,9 +681,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { | |
printf(" --cfg-negative-prompt-file FNAME\n"); | ||
printf(" negative prompt file to use for guidance. (default: empty)\n"); | ||
printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); | ||
printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n"); | ||
printf(" --rope-scaling {none,linear,yarn}\n"); | ||
printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n"); | ||
printf(" --rope-scale N RoPE context scaling factor, inverse of --rope-freq-scale\n"); | ||
cebtenzzre marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n"); | ||
printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n"); | ||
printf(" --rope-freq-scale N RoPE frequency scaling factor (default: loaded from model)\n"); | ||
printf(" --yarn-ext-factor N YaRN extrapolation mix factor (default: %.1f)\n", params.yarn_ext_factor); | ||
printf(" --yarn-attn-factor N YaRN magnitude scaling factor (default: %.1f)\n", params.yarn_attn_factor); | ||
printf(" --yarn-beta-fast N YaRN low correction dim (default: %.1f)\n", params.yarn_beta_fast); | ||
printf(" --yarn-beta-slow N YaRN high correction dim (default: %.1f)\n", params.yarn_beta_slow); | ||
|
||
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); | ||
printf(" --no-penalize-nl do not penalize newline token\n"); | ||
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); | ||
|
@@ -725,23 +765,28 @@ std::string gpt_random_prompt(std::mt19937 & rng) { | |
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { | ||
auto lparams = llama_context_default_params(); | ||
|
||
lparams.n_ctx = params.n_ctx; | ||
lparams.n_batch = params.n_batch; | ||
lparams.n_ctx = params.n_ctx; | ||
lparams.n_batch = params.n_batch; | ||
if (params.n_gpu_layers != -1) { | ||
lparams.n_gpu_layers = params.n_gpu_layers; | ||
} | ||
lparams.main_gpu = params.main_gpu; | ||
lparams.tensor_split = params.tensor_split; | ||
lparams.low_vram = params.low_vram; | ||
lparams.mul_mat_q = params.mul_mat_q; | ||
lparams.seed = params.seed; | ||
lparams.f16_kv = params.memory_f16; | ||
lparams.use_mmap = params.use_mmap; | ||
lparams.use_mlock = params.use_mlock; | ||
lparams.logits_all = params.perplexity; | ||
lparams.embedding = params.embedding; | ||
lparams.rope_freq_base = params.rope_freq_base; | ||
lparams.rope_freq_scale = params.rope_freq_scale; | ||
lparams.main_gpu = params.main_gpu; | ||
lparams.tensor_split = params.tensor_split; | ||
lparams.low_vram = params.low_vram; | ||
lparams.mul_mat_q = params.mul_mat_q; | ||
lparams.seed = params.seed; | ||
lparams.f16_kv = params.memory_f16; | ||
lparams.use_mmap = params.use_mmap; | ||
lparams.use_mlock = params.use_mlock; | ||
lparams.logits_all = params.perplexity; | ||
lparams.embedding = params.embedding; | ||
lparams.rope_scaling_type = params.rope_scaling_type; | ||
lparams.rope_freq_base = params.rope_freq_base; | ||
lparams.rope_freq_scale = params.rope_freq_scale; | ||
lparams.yarn_ext_factor = params.yarn_ext_factor; | ||
lparams.yarn_attn_factor = params.yarn_attn_factor; | ||
lparams.yarn_beta_fast = params.yarn_beta_fast; | ||
lparams.yarn_beta_slow = params.yarn_beta_slow; | ||
|
||
return lparams; | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.