diff --git a/Makefile b/Makefile
index 0ddff9961ac76..e4c338a580fcf 100644
--- a/Makefile
+++ b/Makefile
@@ -143,6 +143,11 @@ ifdef LLAMA_PERF
 	CFLAGS += -DGGML_PERF
 	CXXFLAGS += -DGGML_PERF
 endif
+ifdef LLAMA_USE_BOOST
+	LDFLAGS += -L/usr/lib/x86_64-linux-gnu/ -lboost_regex
+	CFLAGS += -DLLAMA_USE_BOOST
+	CXXFLAGS += -DLLAMA_USE_BOOST
+endif
 ifneq ($(filter aarch64%,$(UNAME_M)),)
 	# Apple M1, M2, etc.
 	# Raspberry Pi 3, 4, Zero 2 (64-bit)
diff --git a/examples/common.cpp b/examples/common.cpp
index f3085b08e5b25..84380e4603aaf 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -333,6 +333,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.input_suffix = argv[i];
+        } else if (arg == "--allowed-response-regex") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.allowed_regex = argv[i];
+        } else if (arg == "--response-bias-regex") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.bias_regex = argv[i];
+        } else if (arg == "--response-bias-value") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.bias_regex_value = std::stof(argv[i]);
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             gpt_print_usage(argc, argv, default_params);
diff --git a/examples/common.h b/examples/common.h
index 499671b2e8d6d..0c89e586bcb52 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -49,6 +49,9 @@ struct gpt_params {
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
     std::string input_prefix = ""; // string to prefix user inputs with
     std::string input_suffix = ""; // string to suffix user inputs with
+    std::string allowed_regex = ""; // regex string used to force prediction of matching tokens
+    std::string bias_regex = ""; // matching tokens are biased by bias_regex_value
+    float bias_regex_value = 0; // value to bias tokens matching bias_regex by
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
     std::string lora_adapter = ""; // lora adapter path
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index bd1c4ab558521..77e8e62072d43 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -18,6 +18,10 @@
 #include <string>
 #include <vector>
+#if defined(LLAMA_USE_BOOST)
+    #include <boost/regex.hpp>
+#endif
+
 
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
 #include <signal.h>
 #include <unistd.h>
@@ -97,6 +101,12 @@ int main(int argc, char ** argv) {
         params.prompt = gpt_random_prompt(rng);
     }
+
+#if defined(LLAMA_USE_BOOST)
+    boost::regex response_allowed_regex = boost::regex(params.allowed_regex);
+    boost::regex response_bias_regex = boost::regex(params.bias_regex);
+#endif
+
 
 // params.prompt = R"(// this function checks if the number n is prime
 //bool is_prime(int n) {)";
 
@@ -305,7 +315,7 @@ int main(int argc, char ** argv) {
     console_set_color(con_st, CONSOLE_COLOR_PROMPT);
 
     std::vector<llama_token> embd;
-
+    std::string partial_completion;
     while (n_remain != 0 || params.interactive) {
         // predict
         if (embd.size() > 0) {
@@ -410,6 +420,17 @@ int main(int argc, char ** argv) {
                 for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
                     logits[it->first] += it->second;
                 }
+#if defined(LLAMA_USE_BOOST)
+                if (params.allowed_regex != "" || params.bias_regex != "") {
+                    for (size_t i = 0; i < llama_n_vocab(ctx); i++) {
+                        if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial))
+                            logits[i] = -INFINITY;
+                        else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) {
+                            logits[i] += params.bias_regex_value;
+                        }
+                    }
+                }
+#endif
 
                 std::vector<llama_token_data> candidates;
                 candidates.reserve(n_vocab);
@@ -459,6 +480,7 @@ int main(int argc, char ** argv) {
 
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
+                partial_completion += llama_token_to_str(ctx, id);
             }
 
             // replace end of text token with newline token when in interactive mode
@@ -481,6 +503,7 @@ int main(int argc, char ** argv) {
             --n_remain;
         } else {
             // some user input remains from prompt or interaction, forward it to processing
+            partial_completion = "";
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
                 last_n_tokens.erase(last_n_tokens.begin());