Skip to content

Commit 50fae10

Browse files
slaren and ggerganov
authored
Add --ignore-eos parameter (abetlen#181)
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 084e2f0 commit 50fae10

File tree

3 files changed

+14
-1
lines changed

3 files changed

+14
-1
lines changed

main.cpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
#define ANSI_COLOR_RESET "\x1b[0m"
2828
#define ANSI_BOLD "\x1b[1m"
2929

30+
static const int EOS_TOKEN_ID = 2;
31+
3032
// determine number of model parts based on the dimension
3133
static const std::map<int, int> LLAMA_N_PARTS = {
3234
{ 4096, 1 },
@@ -956,6 +958,11 @@ int main(int argc, char ** argv) {
956958
{
957959
const int64_t t_start_sample_us = ggml_time_us();
958960

961+
if (params.ignore_eos) {
962+
// set the logit of the eos token to zero to avoid sampling it
963+
logits[logits.size() - n_vocab + EOS_TOKEN_ID] = 0;
964+
}
965+
959966
id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
960967

961968
last_n_tokens.erase(last_n_tokens.begin());
@@ -1055,7 +1062,8 @@ int main(int argc, char ** argv) {
10551062
}
10561063

10571064
// end of text token
1058-
if (embd.back() == 2) {
1065+
1066+
if (embd.back() == EOS_TOKEN_ID) {
10591067
if (params.interactive) {
10601068
is_interacting = true;
10611069
} else {

utils.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
7171
params.use_color = true;
7272
} else if (arg == "-r" || arg == "--reverse-prompt") {
7373
params.antiprompt = argv[++i];
74+
} else if (arg == "--ignore-eos") {
75+
params.ignore_eos = true;
7476
} else if (arg == "-h" || arg == "--help") {
7577
gpt_print_usage(argc, argv, params);
7678
exit(0);
@@ -106,6 +108,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
106108
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
107109
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
108110
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
111+
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n");
109112
fprintf(stderr, " --memory_f16 use f16 instead of f32 for memory key+value\n");
110113
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
111114
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);

utils.h

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ struct gpt_params {
3636

3737
bool interactive = false; // interactive mode
3838
bool instruct = false; // instruction mode (used for Alpaca models)
39+
40+
bool ignore_eos = false; // do not stop generating after eos
3941
};
4042

4143
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);

0 commit comments

Comments (0)