From 10f1c9ed30b191aae2ff093625a030dd21424e46 Mon Sep 17 00:00:00 2001 From: Johnman <> Date: Sun, 19 Mar 2023 16:26:21 +0100 Subject: [PATCH 1/5] Never exit the main loop in interactive mode. --- main.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/main.cpp b/main.cpp index 105dd91ee6065..df49839a30a83 100644 --- a/main.cpp +++ b/main.cpp @@ -1035,10 +1035,20 @@ int main(int argc, char ** argv) { } } - // end of text token - if (embd.back() == 2) { - fprintf(stderr, " [end of text]\n"); - break; + if (params.interactive) { + if (embd.size() && embd.back() == 2) { + is_interacting = true; + } + if (remaining_tokens == 0) { + remaining_tokens = params.n_predict; + is_interacting = true; + } + } else { + // end of text token + if (embd.size() && embd.back() == 2) { + fprintf(stderr, " [end of text]\n"); + break; + } } } From b78caa6bff8932fca656775b2d460f6d68166034 Mon Sep 17 00:00:00 2001 From: Johnman <> Date: Sun, 19 Mar 2023 16:57:02 +0100 Subject: [PATCH 2/5] Pause sampling if waiting for user input. --- main.cpp | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/main.cpp b/main.cpp index df49839a30a83..7646fd42d7f93 100644 --- a/main.cpp +++ b/main.cpp @@ -935,35 +935,37 @@ int main(int argc, char ** argv) { embd.clear(); if (embd_inp.size() <= input_consumed) { - // out of user input, sample next token - const float top_k = params.top_k; - const float top_p = params.top_p; - const float temp = params.temp; - const float repeat_penalty = params.repeat_penalty; + if (!is_interacting) { + // out of user input, sample next token + const float top_k = params.top_k; + const float top_p = params.top_p; + const float temp = params.temp; + const float repeat_penalty = params.repeat_penalty; - const int n_vocab = model.hparams.n_vocab; + const int n_vocab = model.hparams.n_vocab; - gpt_vocab::id id = 0; + gpt_vocab::id id = 0; - { - const int64_t t_start_sample_us = ggml_time_us(); + { + const int64_t t_start_sample_us = ggml_time_us(); - id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng); + id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng); - last_n_tokens.erase(last_n_tokens.begin()); - last_n_tokens.push_back(id); + last_n_tokens.erase(last_n_tokens.begin()); + last_n_tokens.push_back(id); - t_sample_us += ggml_time_us() - t_start_sample_us; - } + t_sample_us += ggml_time_us() - t_start_sample_us; + } - // add it to the context - embd.push_back(id); + // add it to the context + embd.push_back(id); - // echo this to console - input_noecho = false; + // echo this to console + input_noecho = false; - // decrement remaining sampling budget - --remaining_tokens; + // decrement remaining sampling budget + --remaining_tokens; + } } else { // some user input remains from prompt or interaction, forward it to processing while (embd_inp.size() > input_consumed) { From c62cffc2d96235dc3662760a33d7d974f98b3c46 Mon Sep 17 00:00:00 2001 From: Johnman <> Date: Sun, 19 Mar 2023 16:59:45 +0100 Subject: [PATCH 3/5] Make prompt randomization optional. --- main.cpp | 2 +- utils.cpp | 5 ++++- utils.h | 2 ++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/main.cpp b/main.cpp index 7646fd42d7f93..99ed9d9ce91f5 100644 --- a/main.cpp +++ b/main.cpp @@ -805,7 +805,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); std::mt19937 rng(params.seed); - if (params.prompt.empty()) { + if (params.random_prompt) { params.prompt = gpt_random_prompt(rng); } diff --git a/utils.cpp b/utils.cpp index efa2e3c35f728..5aab13f88036c 100644 --- a/utils.cpp +++ b/utils.cpp @@ -75,6 +75,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else if (arg == "-h" || arg == "--help") { gpt_print_usage(argc, argv, params); exit(0); + } else if (arg == "--random-prompt") { + params.random_prompt = true; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); gpt_print_usage(argc, argv, params); @@ -98,7 +100,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) { fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stderr, " -p PROMPT, --prompt PROMPT\n"); - fprintf(stderr, " prompt to start generation with (default: random)\n"); + fprintf(stderr, " prompt to start generation with (default: empty)\n"); + fprintf(stderr, " --random-prompt start with a randomized prompt.\n"); fprintf(stderr, " -f FNAME, --file FNAME\n"); fprintf(stderr, " prompt file to start generation.\n"); fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict); diff --git a/utils.h b/utils.h index c1a8498a78d68..f5b145b429c4b 100644 --- a/utils.h +++ b/utils.h @@ -30,6 +30,8 @@ struct gpt_params { std::string model = "models/lamma-7B/ggml-model.bin"; // model path std::string prompt; + bool random_prompt = false; + bool use_color = false; // use color to distinguish generations and inputs bool interactive = false; // interactive mode From 80825b01735ebf4135b063c8de1ff67f085c6edf Mon Sep 17 00:00:00 2001 From: Johnman <> Date: Sun, 19 Mar 2023 17:29:27 +0100 Subject: [PATCH 4/5] Support for multiple reverse prompts. --- main.cpp | 30 ++++++++++++++++++++---------- utils.cpp | 5 +++-- utils.h | 2 +- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/main.cpp b/main.cpp index 99ed9d9ce91f5..3765c05ac2e41 100644 --- a/main.cpp +++ b/main.cpp @@ -850,7 +850,11 @@ int main(int argc, char ** argv) { params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size()); // tokenize the reverse prompt - std::vector antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false); + std::vector> antipromptv_inp; + + for (auto antiprompt : params.antiprompt) { + antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false)); + } fprintf(stderr, "\n"); fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); @@ -872,13 +876,16 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: interactive mode on.\n", __func__); - if(antiprompt_inp.size()) { - fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str()); - fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size()); - for (int i = 0; i < (int) antiprompt_inp.size(); i++) { - fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str()); + if(antipromptv_inp.size()) { + for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) { + auto antiprompt_inp = antipromptv_inp.at(apindex); + fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str()); + fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size()); + for (int i = 0; i < (int) antiprompt_inp.size(); i++) { + fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str()); + } + fprintf(stderr, "\n"); } - fprintf(stderr, "\n"); } } fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); @@ -996,9 +1003,12 @@ int main(int argc, char ** argv) { // check if we should prompt the user for more if (params.interactive && embd_inp.size() <= input_consumed) { // check for reverse prompt - if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) { - // reverse prompt found - is_interacting = true; + for (auto antiprompt_inp : antipromptv_inp) { + if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) { + // reverse prompt found + is_interacting = true; + break; + } } if (is_interacting) { // currently being interactive diff --git a/utils.cpp b/utils.cpp index 5aab13f88036c..207dc0e58fe64 100644 --- a/utils.cpp +++ b/utils.cpp @@ -71,7 +71,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else if (arg == "--color") { params.use_color = true; } else if (arg == "-r" || arg == "--reverse-prompt") { - params.antiprompt = argv[++i]; + params.antiprompt.push_back(argv[++i]); } else if (arg == "-h" || arg == "--help") { gpt_print_usage(argc, argv, params); exit(0); @@ -95,7 +95,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) { fprintf(stderr, " -i, --interactive run in interactive mode\n"); fprintf(stderr, " --interactive-start run in interactive mode and poll user input at startup\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); - fprintf(stderr, " in interactive mode, poll user input upon seeing PROMPT\n"); + fprintf(stderr, " in interactive mode, poll user input upon seeing PROMPT (can be\n"); + fprintf(stderr, " specified more than once for multiple prompts).\n"); fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n"); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); diff --git a/utils.h b/utils.h index f5b145b429c4b..b7e0bc31a9a9d 100644 --- a/utils.h +++ b/utils.h @@ -36,7 +36,7 @@ struct gpt_params { bool interactive = false; // interactive mode bool interactive_start = false; // reverse prompt immediately - std::string antiprompt = ""; // string upon seeing which more user input is prompted + std::vector antiprompt; // string upon seeing which more user input is prompted }; bool gpt_params_parse(int argc, char ** argv, gpt_params & params); From bb5e8ec79a0fb0ec4f53272e1f8e0eba7f0ab372 Mon Sep 17 00:00:00 2001 From: Johnman <> Date: Sun, 19 Mar 2023 16:26:21 +0100 Subject: [PATCH 5/5] Never exit the main loop in interactive mode. If the end of stream token mark is found, when in interactive mode, ask for user input instead of exiting the main loop. In case of running out of token budget, reset it and ask for user input. With these changes, embd can end up empty and cause a crash in the next iteration of the loop, so we check for its size as well. --- main.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/main.cpp b/main.cpp index 105dd91ee6065..df49839a30a83 100644 --- a/main.cpp +++ b/main.cpp @@ -1035,10 +1035,20 @@ int main(int argc, char ** argv) { } } - // end of text token - if (embd.back() == 2) { - fprintf(stderr, " [end of text]\n"); - break; + if (params.interactive) { + if (embd.size() && embd.back() == 2) { + is_interacting = true; + } + if (remaining_tokens == 0) { + remaining_tokens = params.n_predict; + is_interacting = true; + } + } else { + // end of text token + if (embd.size() && embd.back() == 2) { + fprintf(stderr, " [end of text]\n"); + break; + } } }