From 779d7969c01aa6dd30fe510b68a04482fafa3619 Mon Sep 17 00:00:00 2001
From: Branden Butler
Date: Wed, 10 May 2023 14:32:05 -0500
Subject: [PATCH 1/6] Add llama_get_num_logits() function to Llama.cpp API

---
 llama.cpp | 4 ++++
 llama.h   | 2 ++
 2 files changed, 6 insertions(+)

diff --git a/llama.cpp b/llama.cpp
index 4bba93a111ae4..efff9d9ade0f4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2779,6 +2779,10 @@ float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits.data();
 }
 
+size_t llama_get_num_logits(struct llama_context * ctx) {
+    return ctx->logits.size();
+}
+
 float * llama_get_embeddings(struct llama_context * ctx) {
     return ctx->embedding.data();
 }
diff --git a/llama.h b/llama.h
index 58c6e0699a999..25a7d6ef5e696 100644
--- a/llama.h
+++ b/llama.h
@@ -178,6 +178,8 @@ extern "C" {
     // Cols: n_vocab
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);
 
+    LLAMA_API size_t llama_get_num_logits(struct llama_context * ctx);
+
     // Get the embeddings for the input
     // shape: [n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

From 30754bbaf98733e1a6df29d80843edfecbadef69 Mon Sep 17 00:00:00 2001
From: Branden Butler
Date: Wed, 10 May 2023 14:39:40 -0500
Subject: [PATCH 2/6] Add allowed response regex, response bias regex, and
 response bias value to main example

---
 Makefile               |  5 +++++
 examples/common.cpp    | 18 ++++++++++++++++++
 examples/common.h      |  3 +++
 examples/main/main.cpp | 24 +++++++++++++++++++++++-
 4 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index 0ddff9961ac76..e4c338a580fcf 100644
--- a/Makefile
+++ b/Makefile
@@ -143,6 +143,11 @@ ifdef LLAMA_PERF
     CFLAGS   += -DGGML_PERF
     CXXFLAGS += -DGGML_PERF
 endif
+ifdef LLAMA_USE_BOOST
+    LDFLAGS  += -L/usr/lib/x86_64-linux-gnu/ -lboost_regex
+    CFLAGS   += -DLLAMA_USE_BOOST
+    CXXFLAGS += -DLLAMA_USE_BOOST
+endif
 ifneq ($(filter aarch64%,$(UNAME_M)),)
     # Apple M1, M2, etc.
     # Raspberry Pi 3, 4, Zero 2 (64-bit)
diff --git a/examples/common.cpp b/examples/common.cpp
index f3085b08e5b25..84380e4603aaf 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -333,6 +333,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.input_suffix = argv[i];
+        } else if (arg == "--allowed-response-regex") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.allowed_regex = argv[i];
+        } else if (arg == "--response-bias-regex") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.bias_regex = argv[i];
+        } else if (arg == "--response-bias-value") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.bias_regex_value = std::stof(argv[i]);
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             gpt_print_usage(argc, argv, default_params);
diff --git a/examples/common.h b/examples/common.h
index 499671b2e8d6d..e7293cd8fddcd 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -49,6 +49,9 @@ struct gpt_params {
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
     std::string input_prefix = "";      // string to prefix user inputs with
     std::string input_suffix = "";      // string to suffix user inputs with
+    std::string allowed_regex = "";     // regex string used to force prediction of matching tokens
+    std::string bias_regex = "";        // matching tokens are biased by bias_regex_value
+    float bias_regex_value = -1;        // value to bias tokens matching bias_regex by
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
     std::string lora_adapter = ""; // lora adapter path
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index bd1c4ab558521..af31952e4d79f 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -18,6 +18,10 @@
 #include <string>
 #include <vector>
 
+#if defined(LLAMA_USE_BOOST)
+    #include <boost/regex.hpp>
+#endif
+
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
 #include <signal.h>
 #include <unistd.h>
@@ -51,6 +55,8 @@ void sigint_handler(int signo) {
 int main(int argc, char ** argv) {
     gpt_params params;
     params.model = "models/llama-7B/ggml-model.bin";
+// boost::regex regex = boost::regex("(?:(?:\\([a-z A-Z 0-9]*, [a-z A-Z 0-9]*, [a-z A-Z 0-9]*\\))(?:<\\|>\\([a-z A-Z 0-9]*, [a-z A-Z 0-9]*, [a-z A-Z 0-9]*\\))*)|NONE");
+// boost::regex negative_bias_regex = boost::regex("^NONE");
 
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
@@ -97,6 +103,12 @@ int main(int argc, char ** argv) {
         params.prompt = gpt_random_prompt(rng);
     }
 
+
+#if defined(LLAMA_USE_BOOST)
+    boost::regex response_allowed_regex = boost::regex(params.allowed_regex);
+    boost::regex response_bias_regex = boost::regex(params.bias_regex);
+#endif
+
     // params.prompt = R"(// this function checks if the number n is prime
     //bool is_prime(int n) {)";
 
@@ -305,7 +317,7 @@ int main(int argc, char ** argv) {
     console_set_color(con_st, CONSOLE_COLOR_PROMPT);
 
     std::vector<llama_token> embd;
-
+    std::string partial_completion;
     while (n_remain != 0 || params.interactive) {
         // predict
         if (embd.size() > 0) {
@@ -410,6 +422,15 @@ int main(int argc, char ** argv) {
             for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
                 logits[it->first] += it->second;
             }
+#if defined(LLAMA_USE_BOOST)
+            for (size_t i = 0; i < llama_get_num_logits(ctx); i++) {
+                if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial))
+                    logits[i] = -INFINITY;
+                else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) {
+                    logits[i] += params.bias_regex_value;
+                }
+            }
+#endif
 
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);
@@ -459,6 +480,7 @@ int main(int argc, char ** argv) {
 
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
+                partial_completion += llama_token_to_str(ctx, id);
             }
 
             // replace end of text token with newline token when in interactive mode
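The heart of PATCH 2 is the boost::match_partial flag: while a completion is being generated, the accumulated text is only a prefix of the final response, so a token should survive as long as partial_completion plus that token could still grow into a full match of the allowed regex. Below is a minimal standalone sketch of that idea; the regex and candidate strings are illustrative, not taken from the patches, and it assumes Boost.Regex is installed (build with e.g. -lboost_regex):

    // Sketch: keep only tokens whose appended text could still complete a match.
    #include <boost/regex.hpp>
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
        boost::regex allowed("yes|no");  // full completions must match this
        std::string partial = "y";       // text generated so far
        std::vector<std::string> tokens = {"e", "o", "x"};

        for (const std::string & tok : tokens) {
            // With match_partial, regex_match also returns true when the input
            // ran out mid-match, i.e. it is a prefix of some possible full
            // match. "ye" is kept; "yo" and "yx" can never extend into "yes"
            // or "no", so the patch would mask their logits to -INFINITY.
            bool viable = boost::regex_match(partial + tok, allowed, boost::match_partial);
            std::cout << partial + tok << " -> " << (viable ? "keep" : "mask") << "\n";
        }
        return 0;
    }

Because the partial-match test is anchored to the start of the current response, partial_completion must start empty for each new response — which is exactly the bug the next patch fixes.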
From e0acd1a7bf974a012855a41ae524cd2cb8e5a970 Mon Sep 17 00:00:00 2001
From: Branden Butler
Date: Wed, 10 May 2023 14:54:32 -0500
Subject: [PATCH 3/6] Fix partial_completion not being reset

---
 examples/main/main.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index af31952e4d79f..e79540d020588 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -503,6 +503,7 @@ int main(int argc, char ** argv) {
             --n_remain;
         } else {
             // some user input remains from prompt or interaction, forward it to processing
+            partial_completion = "";
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
                 last_n_tokens.erase(last_n_tokens.begin());

From 58d848dadd4a2badbdbd99907d5242d30005c49b Mon Sep 17 00:00:00 2001
From: Branden Butler
Date: Wed, 10 May 2023 17:47:22 -0500
Subject: [PATCH 4/6] Remove unneeded llama_get_num_logits() function

---
 llama.cpp | 4 ----
 llama.h   | 2 --
 2 files changed, 6 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index efff9d9ade0f4..4bba93a111ae4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2779,10 +2779,6 @@ float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits.data();
 }
 
-size_t llama_get_num_logits(struct llama_context * ctx) {
-    return ctx->logits.size();
-}
-
 float * llama_get_embeddings(struct llama_context * ctx) {
     return ctx->embedding.data();
 }
diff --git a/llama.h b/llama.h
index 25a7d6ef5e696..58c6e0699a999 100644
--- a/llama.h
+++ b/llama.h
@@ -178,8 +178,6 @@ extern "C" {
     // Cols: n_vocab
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);
 
-    LLAMA_API size_t llama_get_num_logits(struct llama_context * ctx);
-
     // Get the embeddings for the input
     // shape: [n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

From 827ac3a457b5588d4695a2d1dd54da4dcf31e6e5 Mon Sep 17 00:00:00 2001
From: Branden Butler
Date: Wed, 10 May 2023 17:48:37 -0500
Subject: [PATCH 5/6] Add check for whether regex should be used

---
 examples/common.h      |  2 +-
 examples/main/main.cpp | 12 +++++++-----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/examples/common.h b/examples/common.h
index e7293cd8fddcd..0c89e586bcb52 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -51,7 +51,7 @@ struct gpt_params {
     std::string input_suffix = "";      // string to suffix user inputs with
     std::string allowed_regex = "";     // regex string used to force prediction of matching tokens
     std::string bias_regex = "";        // matching tokens are biased by bias_regex_value
-    float bias_regex_value = -1;        // value to bias tokens matching bias_regex by
+    float bias_regex_value = 0;         // value to bias tokens matching bias_regex by
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
     std::string lora_adapter = ""; // lora adapter path
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index e79540d020588..2e98504fe1c6b 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -423,11 +423,13 @@ int main(int argc, char ** argv) {
                 logits[it->first] += it->second;
             }
 #if defined(LLAMA_USE_BOOST)
-            for (size_t i = 0; i < llama_get_num_logits(ctx); i++) {
-                if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial))
-                    logits[i] = -INFINITY;
-                else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) {
-                    logits[i] += params.bias_regex_value;
+            if (params.allowed_regex != "" || params.bias_regex != "") {
+                for (size_t i = 0; i < llama_n_vocab(ctx); i++) {
+                    if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial))
+                        logits[i] = -INFINITY;
+                    else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) {
+                        logits[i] += params.bias_regex_value;
+                    }
                 }
             }
 #endif

From 2a8935297b9cfdbe503107c5246c4198482c6416 Mon Sep 17 00:00:00 2001
From: Branden Butler
Date: Wed, 10 May 2023 17:49:52 -0500
Subject: [PATCH 6/6] Remove commented-out regex code

---
 examples/main/main.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 2e98504fe1c6b..77e8e62072d43 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -55,8 +55,6 @@ void sigint_handler(int signo) {
 int main(int argc, char ** argv) {
     gpt_params params;
     params.model = "models/llama-7B/ggml-model.bin";
-// boost::regex regex = boost::regex("(?:(?:\\([a-z A-Z 0-9]*, [a-z A-Z 0-9]*, [a-z A-Z 0-9]*\\))(?:<\\|>\\([a-z A-Z 0-9]*, [a-z A-Z 0-9]*, [a-z A-Z 0-9]*\\))*)|NONE");
-// boost::regex negative_bias_regex = boost::regex("^NONE");
 
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
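Taken together, the series leaves the main example with two optional, Boost-backed sampling controls: a token whose text, appended to the response so far, can no longer even partially match --allowed-response-regex has its logit forced to -INFINITY, while a token that still partially matches --response-bias-regex gets --response-bias-value added to its logit. Since adding b to a logit scales that token's pre-softmax weight by e^b, PATCH 5's change of the default bias from -1 to 0 makes the bias a no-op unless explicitly requested, and its emptiness check skips the per-token vocabulary scan entirely when neither flag is set. A hypothetical invocation, assuming a tree built with LLAMA_USE_BOOST=1 and purely illustrative regex values: ./main -m models/llama-7B/ggml-model.bin --allowed-response-regex "(yes|no)" --response-bias-regex "^no" --response-bias-value -5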