From fd0eb663cecd434bf55b9d431b54844175835cc8 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Wed, 31 May 2023 00:20:51 -0400 Subject: [PATCH 01/18] llama, main : constrain sampling to grammar --- Makefile | 5 +- examples/CMakeLists.txt | 2 + examples/common.cpp | 7 + examples/common.h | 1 + examples/grammar-parser.cpp | 315 ++++++++++++++++++++++++++++++++++++ examples/grammar-parser.h | 26 +++ examples/main/main.cpp | 35 ++++ llama.cpp | 240 +++++++++++++++++++++++++++ llama.h | 32 ++++ 9 files changed, 662 insertions(+), 1 deletion(-) create mode 100644 examples/grammar-parser.cpp create mode 100644 examples/grammar-parser.h diff --git a/Makefile b/Makefile index 39265164b322c..71b1baecf16e0 100644 --- a/Makefile +++ b/Makefile @@ -250,6 +250,9 @@ llama.o: llama.cpp ggml.h ggml-cuda.h llama.h llama-util.h common.o: examples/common.cpp examples/common.h $(CXX) $(CXXFLAGS) -c $< -o $@ +grammar-parser.o: examples/grammar-parser.cpp examples/grammar-parser.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + libllama.so: llama.o ggml.o $(OBJS) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) @@ -260,7 +263,7 @@ clean: # Examples # -main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS) +main: examples/main/main.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) @echo @echo '==== Run ./main -h for help. ====' diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 3deff4077f80e..bd043ed68b1c9 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -13,6 +13,8 @@ set(TARGET common) add_library(${TARGET} OBJECT common.h common.cpp + grammar-parser.h + grammar-parser.cpp ) if (BUILD_SHARED_LIBS) diff --git a/examples/common.cpp b/examples/common.cpp index f5d886acf6539..b20a6826085d3 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -388,6 +388,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.input_suffix = argv[i]; + } else if (arg == "--grammar") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.grammar = argv[i]; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); gpt_print_usage(argc, argv, default_params); @@ -458,6 +464,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n"); fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); + fprintf(stderr, " --grammar GRAMMAR BNF-like grammar (TODO explain) to constrain generations\n"); fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); fprintf(stderr, " --no-penalize-nl do not penalize newline token\n"); diff --git a/examples/common.h b/examples/common.h index 826e2ae59cec1..5eb6118417a15 100644 --- a/examples/common.h +++ b/examples/common.h @@ -52,6 +52,7 @@ struct gpt_params { std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string input_prefix = ""; // string to prefix user inputs with std::string input_suffix = ""; // string to suffix user inputs with + std::string grammar = ""; // optional BNF-like grammar to constrain sampling std::vector antiprompt; // string upon seeing which more user input is prompted std::string lora_adapter = ""; // lora adapter path diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp new file mode 100644 index 0000000000000..53f9b26d5b424 --- /dev/null +++ b/examples/grammar-parser.cpp @@ -0,0 +1,315 @@ +#include "grammar-parser.h" +#include +#include +#include +#include + +namespace grammar_parser { + uint16_t get_symbol_id(parse_state & state, const char * src, size_t len) { + uint16_t next_id = static_cast(state.symbol_ids.size()); + auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); + return result.first->second; + } + + uint16_t generate_symbol_id(parse_state & state, const std::string & base_name) { + uint16_t next_id = static_cast(state.symbol_ids.size()); + state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; + return next_id; + } + + bool is_word_char(char c) { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); + } + + int hex_to_int(char c) { + if ('a' <= c && c <= 'f') { + return c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + return c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + return c - '0'; + } + return -1; + } + + const char * parse_space(const char * src) { + const char * pos = src; + // TODO: support newlines in some cases + while (*pos == ' ' || *pos == '\t') { + pos++; + } + return pos; + } + + std::pair parse_name(const char * src) { + const char * pos = src; + while (is_word_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::string("expecting name at ") + src; + } + return std::make_pair(pos, parse_space(pos)); + } + + std::pair parse_char(const char * src) { + if (*src == '\\') { + char esc = src[1]; + if (esc == 'x') { + int first = hex_to_int(src[2]); + if (first > -1) { + int second = hex_to_int(src[3]); + if (second > -1) { + return std::make_pair((first << 4) + second, src + 4); + } + } + throw std::string("expecting \\xNN at ") + src; + } else if (esc == '"' || esc == '[' || esc == ']') { + return std::make_pair(esc, src + 2); + } else if (esc == 'r') { + return std::make_pair('\r', src + 2); + } else if (esc == 'n') { + return std::make_pair('\n', src + 2); + } else if (esc == 't') { + return std::make_pair('\t', src + 2); + } + throw std::string("unknown escape at ") + src; + } else if (*src) { + return std::make_pair(*src, src + 1); + } + throw std::string("unexpected end of input"); + } + + const char * parse_alternates( + parse_state & state, + const char * src, + const std::string & rule_name, + uint16_t rule_id); + + const char * parse_sequence( + parse_state & state, + const char * src, + const std::string & rule_name, + std::vector & outbuf) { + size_t out_start = outbuf.size(); + + // sequence size, will be replaced at end when known + outbuf.push_back(0); + + size_t last_sym_start = outbuf.size(); + const char * pos = src; + while (*pos) { + if (*pos == '"') { // literal string + pos++; + last_sym_start = outbuf.size(); + while (*pos != '"') { + auto char_pair = parse_char(pos); + pos = char_pair.second; + + // each char of a literal is encoded as a "range" of char - char + outbuf.push_back(2); + outbuf.push_back(char_pair.first); + outbuf.push_back(char_pair.first); + } + pos = parse_space(pos + 1); + } else if (*pos == '[') { // char range(s) + pos++; + last_sym_start = outbuf.size(); + // num chars in range - replaced at end of loop + outbuf.push_back(0); + while (*pos != ']') { + auto char_pair = parse_char(pos); + pos = char_pair.second; + + outbuf.push_back(char_pair.first); + if (pos[0] == '-' && pos[1] != ']') { + auto endchar_pair = parse_char(pos + 1); + pos = endchar_pair.second; + outbuf.push_back(endchar_pair.first); + } else { + // chars that aren't part of a c1-c2 range are just doubled (i.e., c-c) + outbuf.push_back(char_pair.first); + } + } + // replace num chars with actual + outbuf[last_sym_start] = static_cast(outbuf.size() - last_sym_start - 1); + pos = parse_space(pos + 1); + } else if (is_word_char(*pos)) { // rule reference + auto name_pair = parse_name(pos); + uint16_t ref_rule_id = get_symbol_id(state, pos, name_pair.first - pos); + pos = name_pair.second; + last_sym_start = outbuf.size(); + outbuf.push_back(1); + outbuf.push_back(ref_rule_id); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule + pos = parse_space(pos + 1); + uint16_t sub_rule_id = generate_symbol_id(state, rule_name); + pos = parse_alternates(state, pos, rule_name, sub_rule_id); + last_sym_start = outbuf.size(); + // output reference to synthesized rule + outbuf.push_back(1); + outbuf.push_back(sub_rule_id); + if (*pos != ')') { + throw std::string("expecting ')' at ") + pos; + } + pos = parse_space(pos + 1); + } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator + if (outbuf.size() - out_start - 1 == 0) { + throw std::string("expecting preceeding item to */+/? at ") + pos; + } + std::vector & out_grammar = state.out_grammar; + + // apply transformation to previous symbol (last_sym_start - + // end) according to rewrite rules: + // S* --> S' ::= S S' | + // S+ --> S' ::= S S' | S + // S? --> S' ::= S | + uint16_t sub_rule_id = generate_symbol_id(state, rule_name); + out_grammar.push_back(sub_rule_id); + size_t sub_rule_start = out_grammar.size(); + // placeholder for size of 1st alternate + out_grammar.push_back(0); + // add preceding symbol to generated rule + out_grammar.insert(out_grammar.end(), outbuf.begin() + last_sym_start, outbuf.end()); + if (*pos == '*' || *pos == '+') { + // cause generated rule to recurse + out_grammar.push_back(1); + out_grammar.push_back(sub_rule_id); + } + // apply actual size + out_grammar[sub_rule_start] = out_grammar.size() - sub_rule_start; + // mark end of 1st alternate + out_grammar.push_back(0); + sub_rule_start = out_grammar.size(); + // placeholder for size of 2nd alternate + out_grammar.push_back(0); + if (*pos == '+') { + // add preceding symbol as alternate only for '+' + out_grammar.insert(out_grammar.end(), outbuf.begin() + last_sym_start, outbuf.end()); + } + // apply actual size of 2nd alternate + out_grammar[sub_rule_start] = out_grammar.size() - sub_rule_start; + // mark end of 2nd alternate, then end of rule + out_grammar.push_back(0); + out_grammar.push_back(0); + + // in original rule, replace previous symbol with reference to generated rule + outbuf.resize(last_sym_start); + outbuf.push_back(1); + outbuf.push_back(sub_rule_id); + + pos = parse_space(pos + 1); + } else { + break; + } + } + // apply actual size of this alternate sequence + outbuf[out_start] = static_cast(outbuf.size() - out_start); + // mark end of alternate + outbuf.push_back(0); + return pos; + } + + const char * parse_alternates( + parse_state & state, + const char * src, + const std::string & rule_name, + uint16_t rule_id) { + std::vector outbuf; + const char * pos = parse_sequence(state, src, rule_name, outbuf); + while (*pos == '|') { + pos = parse_space(pos + 1); + pos = parse_sequence(state, pos, rule_name, outbuf); + } + state.out_grammar.push_back(rule_id); + state.out_grammar.insert(state.out_grammar.end(), outbuf.begin(), outbuf.end()); + state.out_grammar.push_back(0); + return pos; + } + + const char * parse_rule(parse_state & state, const char * src) { + auto name_pair = parse_name(src); + const char * pos = name_pair.second; + size_t name_len = name_pair.first - src; + uint16_t rule_id = get_symbol_id(state, src, name_len); + const std::string name(src, name_len); + + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { + throw std::string("expecting ::= at ") + pos; + } + pos = parse_space(pos + 3); + + pos = parse_alternates(state, pos, name, rule_id); + + if (*pos == '\r') { + pos += pos[1] == '\n' ? 2 : 1; + } else if (*pos == '\n') { + pos++; + } else if (*pos) { + throw std::string("expecting newline or end at ") + pos; + } + return parse_space(pos); + } + + parse_state parse(const char * src) { + parse_state state; + const char * pos = parse_space(src); + while (*pos) { + pos = parse_rule(state, pos); + } + state.out_grammar.push_back(0xffff); + return state; + } + + const uint16_t * print_rule( + FILE * file, + const uint16_t * base, + const uint16_t * src, + const std::map & symbol_id_names) { + uint16_t rule_id = *src; + fprintf(file, "<%zu>%s ::= ", src - base, symbol_id_names.at(rule_id).c_str()); + const uint16_t * pos = src + 1; + while (*pos) { + if (pos - 1 > src) { + fprintf(file, "| "); + } + pos++; // sequence size, not needed here + while (*pos) { + if (*pos == 1) { + uint16_t ref_rule_id = pos[1]; + fprintf(file, "<%zu>%s ", pos - base, symbol_id_names.at(ref_rule_id).c_str()); + pos += 2; + } else { + fprintf(file, "<%zu>[", pos - base); + uint16_t num_chars = *pos; + pos++; + + for (uint16_t i = 0; i < num_chars; i += 2) { + fprintf(file, "%lc-", static_cast(pos[i])); // REVIEW + if (i + 1 < num_chars) { + fprintf(file, "%lc", static_cast(pos[i + 1])); + } + } + fprintf(file, "] "); + pos += num_chars; + } + } + pos++; + } + fprintf(file, "\n"); + return pos + 1; + } + + void print_grammar(FILE * file, const parse_state & state) { + std::map symbol_id_names; + for (auto kv : state.symbol_ids) { + symbol_id_names[kv.second] = kv.first; + } + const uint16_t * pos = state.out_grammar.data(); + while (*pos != 0xffff) { + pos = print_rule(file, state.out_grammar.data(), pos, symbol_id_names); + } + } +} + diff --git a/examples/grammar-parser.h b/examples/grammar-parser.h new file mode 100644 index 0000000000000..c9e27d4cd8cb8 --- /dev/null +++ b/examples/grammar-parser.h @@ -0,0 +1,26 @@ +// Implements a parser for an extended Backus-Naur form (BNF), producing the +// binary context-free grammar format specified by llama.h. Supports character +// ranges, grouping, and repetition operators. As an example, a grammar for +// arithmetic might look like: +// +// root ::= expr +// expr ::= term ([-+*/] term)* +// term ::= num | "(" space expr ")" space +// num ::= [0-9]+ space +// space ::= [ \t\n]* + +#pragma once +#include +#include +#include +#include + +namespace grammar_parser { + struct parse_state { + std::map symbol_ids; + std::vector out_grammar; + }; + + parse_state parse(const char * src); + void print_grammar(FILE * file, const parse_state & state); +} diff --git a/examples/main/main.cpp b/examples/main/main.cpp index de63faa3eea76..f43eb4fe42d68 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -6,6 +6,7 @@ #include "common.h" #include "llama.h" #include "build-info.h" +#include "grammar-parser.h" #include #include @@ -291,6 +292,17 @@ int main(int argc, char ** argv) { fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "\n\n"); + grammar_parser::parse_state parsed_grammar; + llama_grammar * grammar = NULL; + if (!params.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + fprintf(stderr, "%s: grammar:\n", __func__); + grammar_parser::print_grammar(stderr, parsed_grammar); + fprintf(stderr, "\n"); + grammar = llama_grammar_init( + parsed_grammar.out_grammar.data(), parsed_grammar.symbol_ids.at("root")); + } + // TODO: replace with ring-buffer std::vector last_n_tokens(n_ctx); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); @@ -454,6 +466,10 @@ int main(int argc, char ** argv) { logits[llama_token_nl()] = nl_logit; } + if (grammar != NULL) { + llama_sample_grammar(ctx, &candidates_p, grammar); + } + if (temp <= 0) { // Greedy sampling id = llama_sample_token_greedy(ctx, &candidates_p); @@ -479,6 +495,10 @@ int main(int argc, char ** argv) { } // printf("`%d`", candidates_p.size); + if (grammar != NULL) { + id = llama_grammar_accept_token(ctx, grammar, id); + } + last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); } @@ -609,6 +629,17 @@ int main(int argc, char ** argv) { } if (n_past > 0) { + if (is_interacting) { + // reset grammar state if we're restarting generation + if (!params.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + if (grammar != NULL) { + llama_grammar_free(grammar); + } + grammar = llama_grammar_init( + parsed_grammar.out_grammar.data(), parsed_grammar.symbol_ids.at("root")); + } + } is_interacting = false; } } @@ -638,5 +669,9 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); + if (grammar != NULL) { + llama_grammar_free(grammar); + } + return 0; } diff --git a/llama.cpp b/llama.cpp index 16d6f6ef1c68c..877a6a2a175cf 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1821,6 +1821,168 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co return output; } +// +// grammar - internal +// + +struct llama_grammar { + const std::vector rules; + std::vector> stacks; +}; + +// transforms a grammar pushdown stack into N possible stacks, all terminating +// at a character range (terminal element) +static void llama_grammar_advance_stack( + const std::vector & rules, + const std::vector & stack, + std::vector> & new_stacks) { + + if (stack.empty()) { + new_stacks.push_back(stack); + return; + } + + const uint16_t * pos = stack.back(); + + if (*pos == 1) { + // rule reference, apply rule to stack + const uint16_t * subpos = rules[pos[1]] + 1; + while (*subpos) { + // init new stack without the top (pos) + std::vector new_stack(stack.begin(), stack.end() - 1); + if (pos[2]) { + // if the rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 2); + } + if (subpos[1]) { + // if the referenced rule is nonempty, add that to the stack + new_stack.push_back(subpos + 1); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + subpos += 1 + *subpos; + } + } else { + // rule element size > 1 -> character reference + LLAMA_ASSERT(*pos); + new_stacks.push_back(stack); + } +} + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `llama_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +static std::vector> llama_grammar_accept( + const std::vector & rules, + const std::vector> & stacks, + const uint16_t chr) { + + std::vector> new_stacks; + + for (const auto & stack : stacks) { + if (stack.empty()) { + continue; + } + + const uint16_t * pos = stack.back(); + const uint16_t num_chars = *pos; + LLAMA_ASSERT(num_chars > 1); + + pos++; // skip num chars indicator + bool found = false; + // loop over the inclusive char pairs to find a match on the given char + for (int i = 0; i < num_chars; i += 2) { + if (pos[i] <= chr && (i + 1 == num_chars || chr <= pos[i + 1])) { + found = true; + break; + } + } + if (!found) { + continue; + } + + // advance past char range, updating top of stack to next element, if any + pos += num_chars; + std::vector new_stack(stack.begin(), stack.end() - 1); + if (*pos) { + new_stack.push_back(pos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + } + + return new_stacks; +} + +// returns `true` if one of the pushdown stacks can accept the given char. +static bool llama_grammar_peek( + const std::vector> & stacks, + const uint16_t chr) { + + for (const auto & stack : stacks) { + if (stack.empty()) { + if (!chr) { + return true; + } + } else { + const uint16_t * pos = stack.back(); + const uint16_t num_chars = *pos; + LLAMA_ASSERT(num_chars > 1); + + pos++; + for (int i = 0; i < num_chars; i += 2) { + if (pos[i] <= chr && (i + 1 == num_chars || chr <= pos[i + 1])) { + return true; + } + } + } + } + return false; +} + + +// +// grammar - external +// + +struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_rule_id) { + const uint16_t * pos = src; + std::vector rules; + + // build `rules` as list of pointers to rules embedded in binary grammar `src` + while (*pos != 0xffff) { + uint16_t rule_id = *pos; + if (rules.size() <= rule_id) { + rules.resize(rule_id + 1); + } + rules[rule_id] = pos; + // skip rule id + pos++; + // skip rule alternates + while (*pos) { + pos += 1 + *pos; + } + // skip 0 denoting end of rule + pos++; + } + + // TODO: handle if start rule has alternates + const uint16_t * start_rule = rules[start_rule_id]; + + // rule starts with rule id and 1st alternate's size; skip that so initial + // stack starts at 1st element in 1st alternate + LLAMA_ASSERT(start_rule[0] == start_rule_id && start_rule[1]); + const std::vector stack = { start_rule + 2 }; + + std::vector> stacks; + llama_grammar_advance_stack(rules, stack, stacks); + + return new llama_grammar{ rules, stacks }; +} + +void llama_grammar_free(struct llama_grammar * grammar) { + delete grammar; +} + // // sampling // @@ -2097,6 +2259,30 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l } } +void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { + assert(ctx); + const int64_t t_start_sample_us = ggml_time_us(); + const llama_token eos = llama_token_eos(); + // since many llama tokens are prefixed with a single space, special case a lookahead on ' ' + const auto stacks_after_space = llama_grammar_accept(grammar->rules, grammar->stacks, ' '); + + for (size_t i = 0; i < candidates->size; ++i) { + const llama_token id = candidates->data[i].id; + const char * str = llama_token_to_str(ctx, id); + + // prune tokens based on first char only - in `llama_grammar_accept_token` we will find the + // full matching prefix of the selected token + const bool valid = str[0] == ' ' + ? llama_grammar_peek(stacks_after_space, str[1]) + : llama_grammar_peek(grammar->stacks, id == eos ? 0 : str[0]); + + if (!valid) { + candidates->data[i].logit = -INFINITY; + } + } + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; +} llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { assert(ctx); @@ -2223,6 +2409,60 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra return result; } +llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { + const int64_t t_start_sample_us = ggml_time_us(); + + if (token == llama_token_eos()) { + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + return token; + } + LLAMA_ASSERT(false); + } + } + + const char * str = llama_token_to_str(ctx, token); + const char * suffix = str; + + // Find prefix of selected token that matches grammar, expecting at least 1 char + auto new_stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *suffix); + LLAMA_ASSERT(!new_stacks.empty()); + if (*suffix) { + ++suffix; + for ( ; *suffix; ++suffix) { + new_stacks = llama_grammar_accept(grammar->rules, new_stacks, *suffix); + if (new_stacks.empty()) { + break; + } + } + } + + // if full token is matched, accept new stacks + if (!(*suffix)) { + grammar->stacks = new_stacks; + return token; + } + + // otherwise, tokenize the string prefix that did match + llama_token tokens[32]; // TODO - determine actual max token size + const std::string prefix_str(str, suffix - str); + int n_tokens = llama_tokenize(ctx, prefix_str.c_str(), tokens, 32, false); + if (n_tokens < 1) { + return token; // REVIEW + } + + // accept the first token of the matching prefix into the grammar + llama_token first_prefix_token = tokens[0]; + const char * first_prefix_str = llama_token_to_str(ctx, first_prefix_token); + for ( ; *first_prefix_str; ++first_prefix_str) { + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *first_prefix_str); + LLAMA_ASSERT(!grammar->stacks.empty()); + } + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + return first_prefix_token; +} + // // quantization // diff --git a/llama.h b/llama.h index dc033b71dc036..49e26b1b8fd79 100644 --- a/llama.h +++ b/llama.h @@ -55,6 +55,8 @@ extern "C" { struct llama_context; + struct llama_grammar; + typedef int llama_token; typedef struct llama_token_data { @@ -233,6 +235,30 @@ extern "C" { LLAMA_API llama_token llama_token_eos(); LLAMA_API llama_token llama_token_nl(); + // Grammar + // + // Accepts a binary encoding of a context-free grammar. The returned struct can be used to + // constrain sampled tokens (see below). + // + // The binary format represents one or more production rules, each with one or more alternate + // defininitions: + // + // ( ( )+ 0000)+ FFFF + // + // rule_ids should be assigned sequentially from zero but may appear out of order. Each + // rule alternate is a sequence of zero or more symbols, each prefixed with size: + // + // ( )* 0000 + // + // A symbol of size 1 is interpreted as a rule reference (whose value is the single following + // u16). Symbols sized greater than 1 are interpreted as inclusive pairs of 16-bit chars to + // match. Note that symbol sizes greater than 7FFF are reserved for future use. + // + // The provided `src` must be kept valid for the lifetime of the `llama_grammar`. + // + LLAMA_API struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_rule_id); + LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); + // Sampling functions /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. @@ -257,6 +283,9 @@ extern "C" { LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); + /// @details Apply constraints from grammar + LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. @@ -278,6 +307,9 @@ extern "C" { /// @details Randomly selects a token from the candidates based on their probabilities. LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); + /// @details Accepts the sampled token into the grammar, possibly transforming to a new token + LLAMA_API llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); + // Performance information LLAMA_API void llama_print_timings(struct llama_context * ctx); LLAMA_API void llama_reset_timings(struct llama_context * ctx); From 834d423edf2c5bfe44bd19860d36fc4f27455041 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Sun, 11 Jun 2023 22:37:16 -0400 Subject: [PATCH 02/18] allow loading grammar from file --- examples/common.cpp | 17 +++++++++++++++++ grammars/arithmetic.gbnf | 6 ++++++ grammars/chess.gbnf | 5 +++++ grammars/json.gbnf | 8 ++++++++ 4 files changed, 36 insertions(+) create mode 100644 grammars/arithmetic.gbnf create mode 100644 grammars/chess.gbnf create mode 100644 grammars/json.gbnf diff --git a/examples/common.cpp b/examples/common.cpp index b20a6826085d3..d68bd4ba7f95c 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -394,6 +394,22 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.grammar = argv[i]; + } else if (arg == "--grammar-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(params.grammar) + ); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); gpt_print_usage(argc, argv, default_params); @@ -465,6 +481,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); fprintf(stderr, " --grammar GRAMMAR BNF-like grammar (TODO explain) to constrain generations\n"); + fprintf(stderr, " --grammar-file FNAME file to read grammar from\n"); fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); fprintf(stderr, " --no-penalize-nl do not penalize newline token\n"); diff --git a/grammars/arithmetic.gbnf b/grammars/arithmetic.gbnf new file mode 100644 index 0000000000000..3aa95a9dda7e8 --- /dev/null +++ b/grammars/arithmetic.gbnf @@ -0,0 +1,6 @@ +root ::= (expr "=" ws term "\n")+ +expr ::= term ([-+*/] term)* +term ::= ident | num | "(" ws expr ")" ws +ident ::= [a-z] [a-z0-9_]* ws +num ::= [0-9]+ ws +ws ::= [ \t\n]* diff --git a/grammars/chess.gbnf b/grammars/chess.gbnf new file mode 100644 index 0000000000000..2da6b7139dd48 --- /dev/null +++ b/grammars/chess.gbnf @@ -0,0 +1,5 @@ +root ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+ +move ::= (pawn | nonpawn | castle) [+#]? +nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8] +pawn ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])? +castle ::= "O-O" "-O"? diff --git a/grammars/json.gbnf b/grammars/json.gbnf new file mode 100644 index 0000000000000..cba085d97411c --- /dev/null +++ b/grammars/json.gbnf @@ -0,0 +1,8 @@ +root ::= object | array +value ::= object | array | string | number | boolean +object ::= "{" ws (string ":" ws value ("," ws string ":" ws value)*)? "}" +array ::= "[" ws (value ("," ws value)*)? "]" +string ::= "\"" [ \t!#-\[\]-~]* "\"" ws +number ::= [0-9]+ ws +boolean ::= ("true" | "false") ws +ws ::= [ \t\t] ws | From 9e77f42ef7e7f58692952cca44d14ce7bbf55aa7 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Sun, 11 Jun 2023 22:38:13 -0400 Subject: [PATCH 03/18] fix whitespace errors --- examples/grammar-parser.cpp | 3 +-- examples/main/main.cpp | 2 +- llama.cpp | 8 ++++---- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp index 53f9b26d5b424..650c3cfc5a637 100644 --- a/examples/grammar-parser.cpp +++ b/examples/grammar-parser.cpp @@ -265,7 +265,7 @@ namespace grammar_parser { const uint16_t * print_rule( FILE * file, const uint16_t * base, - const uint16_t * src, + const uint16_t * src, const std::map & symbol_id_names) { uint16_t rule_id = *src; fprintf(file, "<%zu>%s ::= ", src - base, symbol_id_names.at(rule_id).c_str()); @@ -312,4 +312,3 @@ namespace grammar_parser { } } } - diff --git a/examples/main/main.cpp b/examples/main/main.cpp index f43eb4fe42d68..47b4728e97797 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -293,7 +293,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "\n\n"); grammar_parser::parse_state parsed_grammar; - llama_grammar * grammar = NULL; + llama_grammar * grammar = NULL; if (!params.grammar.empty()) { parsed_grammar = grammar_parser::parse(params.grammar.c_str()); fprintf(stderr, "%s: grammar:\n", __func__); diff --git a/llama.cpp b/llama.cpp index 877a6a2a175cf..cc66a601e0c76 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1843,7 +1843,7 @@ static void llama_grammar_advance_stack( } const uint16_t * pos = stack.back(); - + if (*pos == 1) { // rule reference, apply rule to stack const uint16_t * subpos = rules[pos[1]] + 1; @@ -1871,7 +1871,7 @@ static void llama_grammar_advance_stack( // takes a set of possible pushdown stacks on a grammar, which are required to // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those -// positions +// positions static std::vector> llama_grammar_accept( const std::vector & rules, const std::vector> & stacks, @@ -1913,7 +1913,7 @@ static std::vector> llama_grammar_accept( return new_stacks; } -// returns `true` if one of the pushdown stacks can accept the given char. +// returns `true` if one of the pushdown stacks can accept the given char. static bool llama_grammar_peek( const std::vector> & stacks, const uint16_t chr) { @@ -1942,7 +1942,7 @@ static bool llama_grammar_peek( // // grammar - external -// +// struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_rule_id) { const uint16_t * pos = src; From 674bb08b20749d8aa866c6106c9c28bb729267df Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Sun, 11 Jun 2023 22:40:01 -0400 Subject: [PATCH 04/18] handle & print parser errors --- examples/grammar-parser.cpp | 33 +++++++++++++++++++-------------- examples/main/main.cpp | 11 ++++++----- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp index 650c3cfc5a637..bdd48e7960255 100644 --- a/examples/grammar-parser.cpp +++ b/examples/grammar-parser.cpp @@ -47,7 +47,7 @@ namespace grammar_parser { pos++; } if (pos == src) { - throw std::string("expecting name at ") + src; + throw std::runtime_error(std::string("expecting name at ") + src); } return std::make_pair(pos, parse_space(pos)); } @@ -63,7 +63,7 @@ namespace grammar_parser { return std::make_pair((first << 4) + second, src + 4); } } - throw std::string("expecting \\xNN at ") + src; + throw std::runtime_error(std::string("expecting \\xNN at ") + src); } else if (esc == '"' || esc == '[' || esc == ']') { return std::make_pair(esc, src + 2); } else if (esc == 'r') { @@ -73,11 +73,11 @@ namespace grammar_parser { } else if (esc == 't') { return std::make_pair('\t', src + 2); } - throw std::string("unknown escape at ") + src; + throw std::runtime_error(std::string("unknown escape at ") + src); } else if (*src) { return std::make_pair(*src, src + 1); } - throw std::string("unexpected end of input"); + throw std::runtime_error("unexpected end of input"); } const char * parse_alternates( @@ -151,12 +151,12 @@ namespace grammar_parser { outbuf.push_back(1); outbuf.push_back(sub_rule_id); if (*pos != ')') { - throw std::string("expecting ')' at ") + pos; + throw std::runtime_error(std::string("expecting ')' at ") + pos); } pos = parse_space(pos + 1); } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator if (outbuf.size() - out_start - 1 == 0) { - throw std::string("expecting preceeding item to */+/? at ") + pos; + throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); } std::vector & out_grammar = state.out_grammar; @@ -236,7 +236,7 @@ namespace grammar_parser { const std::string name(src, name_len); if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { - throw std::string("expecting ::= at ") + pos; + throw std::runtime_error(std::string("expecting ::= at ") + pos); } pos = parse_space(pos + 3); @@ -247,19 +247,24 @@ namespace grammar_parser { } else if (*pos == '\n') { pos++; } else if (*pos) { - throw std::string("expecting newline or end at ") + pos; + throw std::runtime_error(std::string("expecting newline or end at ") + pos); } return parse_space(pos); } parse_state parse(const char * src) { - parse_state state; - const char * pos = parse_space(src); - while (*pos) { - pos = parse_rule(state, pos); + try { + parse_state state; + const char * pos = parse_space(src); + while (*pos) { + pos = parse_rule(state, pos); + } + state.out_grammar.push_back(0xffff); + return state; + } catch (const std::exception & err) { + fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + return parse_state(); } - state.out_grammar.push_back(0xffff); - return state; } const uint16_t * print_rule( diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 47b4728e97797..8d9371e19d3b8 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -296,6 +296,10 @@ int main(int argc, char ** argv) { llama_grammar * grammar = NULL; if (!params.grammar.empty()) { parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.out_grammar.empty()) { + return 1; + } fprintf(stderr, "%s: grammar:\n", __func__); grammar_parser::print_grammar(stderr, parsed_grammar); fprintf(stderr, "\n"); @@ -631,11 +635,8 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { // reset grammar state if we're restarting generation - if (!params.grammar.empty()) { - parsed_grammar = grammar_parser::parse(params.grammar.c_str()); - if (grammar != NULL) { - llama_grammar_free(grammar); - } + if (grammar != NULL) { + llama_grammar_free(grammar); grammar = llama_grammar_init( parsed_grammar.out_grammar.data(), parsed_grammar.symbol_ids.at("root")); } From 98a9587ce481f0e643f7e4dd361d14caef9849eb Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Sun, 11 Jun 2023 23:41:25 -0400 Subject: [PATCH 05/18] add comments to grammar syntax and allow newlines where unambiguous --- examples/grammar-parser.cpp | 67 +++++++++++++++++++++---------------- grammars/chess.gbnf | 8 +++++ grammars/json.gbnf | 28 ++++++++++++---- 3 files changed, 68 insertions(+), 35 deletions(-) diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp index bdd48e7960255..9ff57d92fd360 100644 --- a/examples/grammar-parser.cpp +++ b/examples/grammar-parser.cpp @@ -32,16 +32,22 @@ namespace grammar_parser { return -1; } - const char * parse_space(const char * src) { + const char * parse_space(const char * src, bool newline_ok) { const char * pos = src; - // TODO: support newlines in some cases - while (*pos == ' ' || *pos == '\t') { - pos++; + while (*pos == ' ' || *pos == '\t' || *pos == '#' || + (newline_ok && (*pos == '\r' || *pos == '\n'))) { + if (*pos == '#') { + while (*pos && *pos != '\r' && *pos != '\n') { + pos++; + } + } else { + pos++; + } } return pos; } - std::pair parse_name(const char * src) { + const char * parse_name(const char * src) { const char * pos = src; while (is_word_char(*pos)) { pos++; @@ -49,7 +55,7 @@ namespace grammar_parser { if (pos == src) { throw std::runtime_error(std::string("expecting name at ") + src); } - return std::make_pair(pos, parse_space(pos)); + return pos; } std::pair parse_char(const char * src) { @@ -84,13 +90,15 @@ namespace grammar_parser { parse_state & state, const char * src, const std::string & rule_name, - uint16_t rule_id); + uint16_t rule_id, + bool is_nested); const char * parse_sequence( parse_state & state, const char * src, const std::string & rule_name, - std::vector & outbuf) { + std::vector & outbuf, + bool is_nested) { size_t out_start = outbuf.size(); // sequence size, will be replaced at end when known @@ -111,7 +119,7 @@ namespace grammar_parser { outbuf.push_back(char_pair.first); outbuf.push_back(char_pair.first); } - pos = parse_space(pos + 1); + pos = parse_space(pos + 1, is_nested); } else if (*pos == '[') { // char range(s) pos++; last_sym_start = outbuf.size(); @@ -133,19 +141,19 @@ namespace grammar_parser { } // replace num chars with actual outbuf[last_sym_start] = static_cast(outbuf.size() - last_sym_start - 1); - pos = parse_space(pos + 1); + pos = parse_space(pos + 1, is_nested); } else if (is_word_char(*pos)) { // rule reference - auto name_pair = parse_name(pos); - uint16_t ref_rule_id = get_symbol_id(state, pos, name_pair.first - pos); - pos = name_pair.second; + const char * name_end = parse_name(pos); + uint16_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); + pos = parse_space(name_end, is_nested); last_sym_start = outbuf.size(); outbuf.push_back(1); outbuf.push_back(ref_rule_id); } else if (*pos == '(') { // grouping // parse nested alternates into synthesized rule - pos = parse_space(pos + 1); + pos = parse_space(pos + 1, true); uint16_t sub_rule_id = generate_symbol_id(state, rule_name); - pos = parse_alternates(state, pos, rule_name, sub_rule_id); + pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); last_sym_start = outbuf.size(); // output reference to synthesized rule outbuf.push_back(1); @@ -153,7 +161,7 @@ namespace grammar_parser { if (*pos != ')') { throw std::runtime_error(std::string("expecting ')' at ") + pos); } - pos = parse_space(pos + 1); + pos = parse_space(pos + 1, is_nested); } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator if (outbuf.size() - out_start - 1 == 0) { throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); @@ -199,7 +207,7 @@ namespace grammar_parser { outbuf.push_back(1); outbuf.push_back(sub_rule_id); - pos = parse_space(pos + 1); + pos = parse_space(pos + 1, is_nested); } else { break; } @@ -215,12 +223,13 @@ namespace grammar_parser { parse_state & state, const char * src, const std::string & rule_name, - uint16_t rule_id) { + uint16_t rule_id, + bool is_nested) { std::vector outbuf; - const char * pos = parse_sequence(state, src, rule_name, outbuf); + const char * pos = parse_sequence(state, src, rule_name, outbuf, is_nested); while (*pos == '|') { - pos = parse_space(pos + 1); - pos = parse_sequence(state, pos, rule_name, outbuf); + pos = parse_space(pos + 1, true); + pos = parse_sequence(state, pos, rule_name, outbuf, is_nested); } state.out_grammar.push_back(rule_id); state.out_grammar.insert(state.out_grammar.end(), outbuf.begin(), outbuf.end()); @@ -229,18 +238,18 @@ namespace grammar_parser { } const char * parse_rule(parse_state & state, const char * src) { - auto name_pair = parse_name(src); - const char * pos = name_pair.second; - size_t name_len = name_pair.first - src; - uint16_t rule_id = get_symbol_id(state, src, name_len); + const char * name_end = parse_name(src); + const char * pos = parse_space(name_end, false); + size_t name_len = name_end - src; + uint16_t rule_id = get_symbol_id(state, src, name_len); const std::string name(src, name_len); if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { throw std::runtime_error(std::string("expecting ::= at ") + pos); } - pos = parse_space(pos + 3); + pos = parse_space(pos + 3, true); - pos = parse_alternates(state, pos, name, rule_id); + pos = parse_alternates(state, pos, name, rule_id, false); if (*pos == '\r') { pos += pos[1] == '\n' ? 2 : 1; @@ -249,13 +258,13 @@ namespace grammar_parser { } else if (*pos) { throw std::runtime_error(std::string("expecting newline or end at ") + pos); } - return parse_space(pos); + return parse_space(pos, true); } parse_state parse(const char * src) { try { parse_state state; - const char * pos = parse_space(src); + const char * pos = parse_space(src, true); while (*pos) { pos = parse_rule(state, pos); } diff --git a/grammars/chess.gbnf b/grammars/chess.gbnf index 2da6b7139dd48..ef0fc1b07f01c 100644 --- a/grammars/chess.gbnf +++ b/grammars/chess.gbnf @@ -1,5 +1,13 @@ +# Specifies chess moves as a list in algebraic notation, using PGN conventions + +# Force first move to "1. ", then any 1-2 digit number after, relying on model to follow the pattern root ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+ move ::= (pawn | nonpawn | castle) [+#]? + +# piece type, optional file/rank, optional capture, dest file & rank nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8] + +# optional file & capture, dest file & rank, optional promotion pawn ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])? + castle ::= "O-O" "-O"? diff --git a/grammars/json.gbnf b/grammars/json.gbnf index cba085d97411c..d145b4cc17f5c 100644 --- a/grammars/json.gbnf +++ b/grammars/json.gbnf @@ -1,8 +1,24 @@ -root ::= object | array +# Grammar for subset of JSON - doesn't support full string or number syntax + +root ::= object | array value ::= object | array | string | number | boolean -object ::= "{" ws (string ":" ws value ("," ws string ":" ws value)*)? "}" -array ::= "[" ws (value ("," ws value)*)? "]" -string ::= "\"" [ \t!#-\[\]-~]* "\"" ws -number ::= [0-9]+ ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" + +# Subset of JSON primitives: strings without escapes and only regular integers +string ::= "\"" [ \t!#-\[\]-~]* "\"" ws +number ::= "-"? [0-9]+ ws boolean ::= ("true" | "false") ws -ws ::= [ \t\t] ws | + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= [ \t\n] ws | From 3e78f0071a76fac0a9807bd32de805d2ac67401a Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Mon, 12 Jun 2023 00:13:27 -0400 Subject: [PATCH 06/18] add missing include --- examples/grammar-parser.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp index 9ff57d92fd360..56dc7da2f6f2d 100644 --- a/examples/grammar-parser.cpp +++ b/examples/grammar-parser.cpp @@ -3,6 +3,8 @@ #include #include #include +#include +#include namespace grammar_parser { uint16_t get_symbol_id(parse_state & state, const char * src, size_t len) { From 421c6e1ca158a1b34e4f2f8d148e819c2fb7da62 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Wed, 14 Jun 2023 23:53:12 -0400 Subject: [PATCH 07/18] support alternates in root rule --- llama.cpp | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/llama.cpp b/llama.cpp index b0c2270c61502..44147d9356714 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1870,7 +1870,7 @@ static void llama_grammar_advance_stack( new_stack.push_back(pos + 2); } if (subpos[1]) { - // if the referenced rule is nonempty, add that to the stack + // if this alternate is nonempty, add that to the stack new_stack.push_back(subpos + 1); } llama_grammar_advance_stack(rules, new_stack, new_stacks); @@ -1980,16 +1980,22 @@ struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_r pos++; } - // TODO: handle if start rule has alternates const uint16_t * start_rule = rules[start_rule_id]; - // rule starts with rule id and 1st alternate's size; skip that so initial - // stack starts at 1st element in 1st alternate - LLAMA_ASSERT(start_rule[0] == start_rule_id && start_rule[1]); - const std::vector stack = { start_rule + 2 }; + LLAMA_ASSERT(*start_rule == start_rule_id); + // loop over alternates of start rule to build initial stacks + pos = start_rule + 1; std::vector> stacks; - llama_grammar_advance_stack(rules, stack, stacks); + while (*pos) { + std::vector stack; + if (pos[1]) { + // if alernate is nonempty, add to stack + stack.push_back(pos + 1); + } + llama_grammar_advance_stack(rules, stack, stacks); + pos += 1 + *pos; + } return new llama_grammar{ rules, stacks }; } From b876d19cff85dcb2bbb74539af6a3b8609a46f34 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Wed, 14 Jun 2023 23:53:55 -0400 Subject: [PATCH 08/18] fix bugs with empty token and EOS --- llama.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index 44147d9356714..8cd2209d615a1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2295,7 +2295,9 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c // full matching prefix of the selected token const bool valid = str[0] == ' ' ? llama_grammar_peek(stacks_after_space, str[1]) - : llama_grammar_peek(grammar->stacks, id == eos ? 0 : str[0]); + : str[0] || id == eos + ? llama_grammar_peek(grammar->stacks, id == eos ? 0 : str[0]) + : false; if (!valid) { candidates->data[i].logit = -INFINITY; @@ -2438,8 +2440,8 @@ llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_ if (stack.empty()) { return token; } - LLAMA_ASSERT(false); } + LLAMA_ASSERT(false); } const char * str = llama_token_to_str(ctx, token); From 58ca9bc6c00d02d0a5bc1c62bb73d6a8044898fc Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Thu, 15 Jun 2023 00:06:54 -0400 Subject: [PATCH 09/18] adjust JSON grammar --- grammars/.json.gbnf.swp | Bin 0 -> 12288 bytes grammars/json.gbnf | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 grammars/.json.gbnf.swp diff --git a/grammars/.json.gbnf.swp b/grammars/.json.gbnf.swp new file mode 100644 index 0000000000000000000000000000000000000000..a7a12008433c3c04b131966bdee999c09cace69a GIT binary patch literal 12288 zcmeI2&2AGh5XW7PMQ!;yaGW|wq@|m*R6;ZoRH}p!h*G4gy(Bq!H{N6`$6ndqgose? zNIV0GSD=poS9k}Wfji7@HlR`v9D0c~mj0RD@pxwZ+vKo{!}Y^QyRcPTXE-h~cKOZg z&d&WC?8`i3av*f>{?~YbaW2^1c~%qND-#J-)6$<^b(|-xD<%4(7Nu#uYel~=tg9b6 zqifwz$7k1&hX@dXb0CoUIJmpOHg9hfKlKV>Y5g-CY zU{(nDbdJ5kW-pWtUoY+4Z??Q8Lj;Hb5g-CYfCvx)B0vO)01+SpM1Tm)AOX=~th&J1 zJ5>Jv|MdO;WQnm4sQ0KoN}&Ym5OoK24RsMUkNUpI*f-Qy)MwNu)JIf?dV<RXcox9gSq|*dbTO|Jf<*7+=+n|*v~>IJ(w6C@9#Z@)F%DJCj;pkFv)_A zCBA1e57I>GVyH6KQI!hd5LCJ)1lmj4ungk@n&_ F`vJ^+2O Date: Thu, 15 Jun 2023 00:13:13 -0400 Subject: [PATCH 10/18] remove swp file --- grammars/.json.gbnf.swp | Bin 12288 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 grammars/.json.gbnf.swp diff --git a/grammars/.json.gbnf.swp b/grammars/.json.gbnf.swp deleted file mode 100644 index a7a12008433c3c04b131966bdee999c09cace69a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12288 zcmeI2&2AGh5XW7PMQ!;yaGW|wq@|m*R6;ZoRH}p!h*G4gy(Bq!H{N6`$6ndqgose? zNIV0GSD=poS9k}Wfji7@HlR`v9D0c~mj0RD@pxwZ+vKo{!}Y^QyRcPTXE-h~cKOZg z&d&WC?8`i3av*f>{?~YbaW2^1c~%qND-#J-)6$<^b(|-xD<%4(7Nu#uYel~=tg9b6 zqifwz$7k1&hX@dXb0CoUIJmpOHg9hfKlKV>Y5g-CY zU{(nDbdJ5kW-pWtUoY+4Z??Q8Lj;Hb5g-CYfCvx)B0vO)01+SpM1Tm)AOX=~th&J1 zJ5>Jv|MdO;WQnm4sQ0KoN}&Ym5OoK24RsMUkNUpI*f-Qy)MwNu)JIf?dV<RXcox9gSq|*dbTO|Jf<*7+=+n|*v~>IJ(w6C@9#Z@)F%DJCj;pkFv)_A zCBA1e57I>GVyH6KQI!hd5LCJ)1lmj4ungk@n&_ F`vJ^+2O Date: Sat, 17 Jun 2023 21:58:17 -0400 Subject: [PATCH 11/18] rewrite ternary expressions Co-authored-by: Henri Vasserman --- llama.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 8cd2209d615a1..37d19ea912051 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2293,11 +2293,14 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c // prune tokens based on first char only - in `llama_grammar_accept_token` we will find the // full matching prefix of the selected token - const bool valid = str[0] == ' ' - ? llama_grammar_peek(stacks_after_space, str[1]) - : str[0] || id == eos - ? llama_grammar_peek(grammar->stacks, id == eos ? 0 : str[0]) - : false; + bool valid = false; + if (id == eos) { + valid = llama_grammar_peek(grammar->stacks, 0); + } else if (str[0] == ' ') { + valid = llama_grammar_peek(stacks_after_space, str[1]); + } else if (str[0] != 0) { + valid = llama_grammar_peek(grammar->stacks, str[0]); + } if (!valid) { candidates->data[i].logit = -INFINITY; From f8baad235d1056527a5e594c80abc35c149129f8 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Tue, 20 Jun 2023 00:06:38 -0400 Subject: [PATCH 12/18] use struct for grammar elements and add Unicode support --- examples/grammar-parser.cpp | 315 ++++++++++++++++++++++-------------- examples/grammar-parser.h | 7 +- examples/main/main.cpp | 12 +- grammars/japanese.gbnf | 7 + llama.cpp | 262 ++++++++++++++++++------------ llama.h | 58 ++++--- 6 files changed, 413 insertions(+), 248 deletions(-) create mode 100644 grammars/japanese.gbnf diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp index 56dc7da2f6f2d..206ccac56c3b0 100644 --- a/examples/grammar-parser.cpp +++ b/examples/grammar-parser.cpp @@ -7,18 +7,45 @@ #include namespace grammar_parser { - uint16_t get_symbol_id(parse_state & state, const char * src, size_t len) { - uint16_t next_id = static_cast(state.symbol_ids.size()); + // NOTE: assumes valid utf8 (but checks for overrun) + // copied from llama.cpp + std::pair decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t first_byte = static_cast(*src); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may overrun! + const char * pos = src + 1; + for ( ; pos < end && *pos; pos++) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + return std::make_pair(value, pos); + } + + uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { + uint32_t next_id = static_cast(state.symbol_ids.size()); auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); return result.first->second; } - uint16_t generate_symbol_id(parse_state & state, const std::string & base_name) { - uint16_t next_id = static_cast(state.symbol_ids.size()); + uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { + uint32_t next_id = static_cast(state.symbol_ids.size()); state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; return next_id; } + void add_rule( + parse_state & state, + uint32_t rule_id, + const std::vector & rule) { + if (state.rules.size() <= rule_id) { + state.rules.resize(rule_id + 1); + } + state.rules[rule_id] = rule; + } + bool is_word_char(char c) { return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); } @@ -60,9 +87,10 @@ namespace grammar_parser { return pos; } - std::pair parse_char(const char * src) { + std::pair parse_char(const char * src) { if (*src == '\\') { char esc = src[1]; + // TODO: 16- and 32-bit escapes if (esc == 'x') { int first = hex_to_int(src[2]); if (first > -1) { @@ -83,7 +111,8 @@ namespace grammar_parser { } throw std::runtime_error(std::string("unknown escape at ") + src); } else if (*src) { - return std::make_pair(*src, src + 1); + auto decoded = decode_utf8(src); + return std::make_pair(decoded.first, decoded.second); } throw std::runtime_error("unexpected end of input"); } @@ -92,132 +121,101 @@ namespace grammar_parser { parse_state & state, const char * src, const std::string & rule_name, - uint16_t rule_id, + uint32_t rule_id, bool is_nested); const char * parse_sequence( - parse_state & state, - const char * src, - const std::string & rule_name, - std::vector & outbuf, - bool is_nested) { - size_t out_start = outbuf.size(); - - // sequence size, will be replaced at end when known - outbuf.push_back(0); - - size_t last_sym_start = outbuf.size(); + parse_state & state, + const char * src, + const std::string & rule_name, + std::vector & out_elements, + bool is_nested) { + size_t last_sym_start = out_elements.size(); const char * pos = src; while (*pos) { if (*pos == '"') { // literal string pos++; - last_sym_start = outbuf.size(); + last_sym_start = out_elements.size(); while (*pos != '"') { auto char_pair = parse_char(pos); pos = char_pair.second; - - // each char of a literal is encoded as a "range" of char - char - outbuf.push_back(2); - outbuf.push_back(char_pair.first); - outbuf.push_back(char_pair.first); + out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); } pos = parse_space(pos + 1, is_nested); } else if (*pos == '[') { // char range(s) pos++; - last_sym_start = outbuf.size(); - // num chars in range - replaced at end of loop - outbuf.push_back(0); + last_sym_start = out_elements.size(); while (*pos != ']') { auto char_pair = parse_char(pos); pos = char_pair.second; + enum llama_gretype type = last_sym_start < out_elements.size() + ? LLAMA_GRETYPE_CHAR_ALT + : LLAMA_GRETYPE_CHAR; - outbuf.push_back(char_pair.first); + out_elements.push_back({type, char_pair.first}); if (pos[0] == '-' && pos[1] != ']') { auto endchar_pair = parse_char(pos + 1); pos = endchar_pair.second; - outbuf.push_back(endchar_pair.first); - } else { - // chars that aren't part of a c1-c2 range are just doubled (i.e., c-c) - outbuf.push_back(char_pair.first); + out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); } } - // replace num chars with actual - outbuf[last_sym_start] = static_cast(outbuf.size() - last_sym_start - 1); pos = parse_space(pos + 1, is_nested); } else if (is_word_char(*pos)) { // rule reference const char * name_end = parse_name(pos); - uint16_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); + uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); pos = parse_space(name_end, is_nested); - last_sym_start = outbuf.size(); - outbuf.push_back(1); - outbuf.push_back(ref_rule_id); + last_sym_start = out_elements.size(); + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); } else if (*pos == '(') { // grouping // parse nested alternates into synthesized rule pos = parse_space(pos + 1, true); - uint16_t sub_rule_id = generate_symbol_id(state, rule_name); + uint32_t sub_rule_id = generate_symbol_id(state, rule_name); pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); - last_sym_start = outbuf.size(); + last_sym_start = out_elements.size(); // output reference to synthesized rule - outbuf.push_back(1); - outbuf.push_back(sub_rule_id); + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); if (*pos != ')') { throw std::runtime_error(std::string("expecting ')' at ") + pos); } pos = parse_space(pos + 1, is_nested); } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator - if (outbuf.size() - out_start - 1 == 0) { + if (last_sym_start == out_elements.size()) { throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); } - std::vector & out_grammar = state.out_grammar; - // apply transformation to previous symbol (last_sym_start - - // end) according to rewrite rules: + // apply transformation to previous symbol (last_sym_start to end) according to + // rewrite rules: // S* --> S' ::= S S' | // S+ --> S' ::= S S' | S // S? --> S' ::= S | - uint16_t sub_rule_id = generate_symbol_id(state, rule_name); - out_grammar.push_back(sub_rule_id); - size_t sub_rule_start = out_grammar.size(); - // placeholder for size of 1st alternate - out_grammar.push_back(0); + uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + std::vector sub_rule; // add preceding symbol to generated rule - out_grammar.insert(out_grammar.end(), outbuf.begin() + last_sym_start, outbuf.end()); + sub_rule.insert( + sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); if (*pos == '*' || *pos == '+') { // cause generated rule to recurse - out_grammar.push_back(1); - out_grammar.push_back(sub_rule_id); + sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); } - // apply actual size - out_grammar[sub_rule_start] = out_grammar.size() - sub_rule_start; - // mark end of 1st alternate - out_grammar.push_back(0); - sub_rule_start = out_grammar.size(); - // placeholder for size of 2nd alternate - out_grammar.push_back(0); + // mark start of alternate def + sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); if (*pos == '+') { - // add preceding symbol as alternate only for '+' - out_grammar.insert(out_grammar.end(), outbuf.begin() + last_sym_start, outbuf.end()); + // add preceding symbol as alternate only for '+' (otherwise empty) + sub_rule.insert( + sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); } - // apply actual size of 2nd alternate - out_grammar[sub_rule_start] = out_grammar.size() - sub_rule_start; - // mark end of 2nd alternate, then end of rule - out_grammar.push_back(0); - out_grammar.push_back(0); + sub_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, sub_rule_id, sub_rule); // in original rule, replace previous symbol with reference to generated rule - outbuf.resize(last_sym_start); - outbuf.push_back(1); - outbuf.push_back(sub_rule_id); + out_elements.resize(last_sym_start); + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); pos = parse_space(pos + 1, is_nested); } else { break; } } - // apply actual size of this alternate sequence - outbuf[out_start] = static_cast(outbuf.size() - out_start); - // mark end of alternate - outbuf.push_back(0); return pos; } @@ -225,17 +223,17 @@ namespace grammar_parser { parse_state & state, const char * src, const std::string & rule_name, - uint16_t rule_id, + uint32_t rule_id, bool is_nested) { - std::vector outbuf; - const char * pos = parse_sequence(state, src, rule_name, outbuf, is_nested); + std::vector rule; + const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); while (*pos == '|') { + rule.push_back({LLAMA_GRETYPE_ALT, 0}); pos = parse_space(pos + 1, true); - pos = parse_sequence(state, pos, rule_name, outbuf, is_nested); + pos = parse_sequence(state, pos, rule_name, rule, is_nested); } - state.out_grammar.push_back(rule_id); - state.out_grammar.insert(state.out_grammar.end(), outbuf.begin(), outbuf.end()); - state.out_grammar.push_back(0); + rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, rule_id, rule); return pos; } @@ -243,7 +241,7 @@ namespace grammar_parser { const char * name_end = parse_name(src); const char * pos = parse_space(name_end, false); size_t name_len = name_end - src; - uint16_t rule_id = get_symbol_id(state, src, name_len); + uint32_t rule_id = get_symbol_id(state, src, name_len); const std::string name(src, name_len); if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { @@ -270,7 +268,6 @@ namespace grammar_parser { while (*pos) { pos = parse_rule(state, pos); } - state.out_grammar.push_back(0xffff); return state; } catch (const std::exception & err) { fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); @@ -278,53 +275,131 @@ namespace grammar_parser { } } - const uint16_t * print_rule( - FILE * file, - const uint16_t * base, - const uint16_t * src, - const std::map & symbol_id_names) { - uint16_t rule_id = *src; - fprintf(file, "<%zu>%s ::= ", src - base, symbol_id_names.at(rule_id).c_str()); - const uint16_t * pos = src + 1; - while (*pos) { - if (pos - 1 > src) { - fprintf(file, "| "); + void print_grammar_char(FILE * file, uint32_t c) { + if (0x20 <= c && c <= 0x7f) { + fprintf(file, "%c", static_cast(c)); + } else { + // cop out of encoding UTF-8 + fprintf(file, "", c); + } + } + + bool is_char_element(llama_grammar_element elem) { + switch (elem.type) { + case LLAMA_GRETYPE_CHAR: return true; + case LLAMA_GRETYPE_CHAR_ALT: return true; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + default: return false; + } + } + + void print_rule_binary(FILE * file, const std::vector & rule) { + for (auto elem : rule) { + switch (elem.type) { + case LLAMA_GRETYPE_END: fprintf(file, "END"); break; + case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; + case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; + case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; + case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_RNG_UPPER"); break; } - pos++; // sequence size, not needed here - while (*pos) { - if (*pos == 1) { - uint16_t ref_rule_id = pos[1]; - fprintf(file, "<%zu>%s ", pos - base, symbol_id_names.at(ref_rule_id).c_str()); - pos += 2; - } else { - fprintf(file, "<%zu>[", pos - base); - uint16_t num_chars = *pos; - pos++; + switch (elem.type) { + case LLAMA_GRETYPE_END: + case LLAMA_GRETYPE_ALT: + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "(%u) ", elem.value); + break; + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ALT: + fprintf(file, "(\""); + print_grammar_char(file, elem.value); + fprintf(file, "\") "); + break; + } + } + fprintf(file, "\n"); + } - for (uint16_t i = 0; i < num_chars; i += 2) { - fprintf(file, "%lc-", static_cast(pos[i])); // REVIEW - if (i + 1 < num_chars) { - fprintf(file, "%lc", static_cast(pos[i + 1])); - } + void print_rule( + FILE * file, + uint32_t rule_id, + const std::vector & rule, + const std::map & symbol_id_names) { + if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { + throw std::runtime_error( + "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); + } + fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); + for (size_t i = 0, end = rule.size() - 1; i < end; i++) { + llama_grammar_element elem = rule[i]; + switch (elem.type) { + case LLAMA_GRETYPE_END: + throw std::runtime_error( + "unexpected end of rule: " + std::to_string(rule_id) + "," + + std::to_string(i)); + case LLAMA_GRETYPE_ALT: + fprintf(file, "| "); + break; + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); + break; + case LLAMA_GRETYPE_CHAR: + fprintf(file, "["); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + fprintf(file, "-"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ALT: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); } - fprintf(file, "] "); - pos += num_chars; + print_grammar_char(file, elem.value); + break; + } + if (is_char_element(elem)) { + switch (rule[i + 1].type) { + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + break; + default: + fprintf(file, "] "); } } - pos++; } fprintf(file, "\n"); - return pos + 1; } void print_grammar(FILE * file, const parse_state & state) { - std::map symbol_id_names; - for (auto kv : state.symbol_ids) { - symbol_id_names[kv.second] = kv.first; + try { + std::map symbol_id_names; + for (auto kv : state.symbol_ids) { + symbol_id_names[kv.second] = kv.first; + } + for (size_t i = 0, end = state.rules.size(); i < end; i++) { + // fprintf(file, "%zu: ", i); + // print_rule_binary(file, state.rules[i]); + print_rule(file, i, state.rules[i], symbol_id_names); + } + } catch (const std::exception & err) { + fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); } - const uint16_t * pos = state.out_grammar.data(); - while (*pos != 0xffff) { - pos = print_rule(file, state.out_grammar.data(), pos, symbol_id_names); + } + + std::vector parse_state::c_rules() { + std::vector ret; + for (const auto & rule : rules) { + ret.push_back(rule.data()); } + return ret; } } diff --git a/examples/grammar-parser.h b/examples/grammar-parser.h index c9e27d4cd8cb8..9037d72728a42 100644 --- a/examples/grammar-parser.h +++ b/examples/grammar-parser.h @@ -10,6 +10,7 @@ // space ::= [ \t\n]* #pragma once +#include "llama.h" #include #include #include @@ -17,8 +18,10 @@ namespace grammar_parser { struct parse_state { - std::map symbol_ids; - std::vector out_grammar; + std::map symbol_ids; + std::vector> rules; + + std::vector c_rules(); }; parse_state parse(const char * src); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 7f9636aae8ac5..b6538ac132876 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -300,14 +300,16 @@ int main(int argc, char ** argv) { if (!params.grammar.empty()) { parsed_grammar = grammar_parser::parse(params.grammar.c_str()); // will be empty (default) if there are parse errors - if (parsed_grammar.out_grammar.empty()) { + if (parsed_grammar.rules.empty()) { return 1; } fprintf(stderr, "%s: grammar:\n", __func__); grammar_parser::print_grammar(stderr, parsed_grammar); fprintf(stderr, "\n"); + + std::vector grammar_rules(parsed_grammar.c_rules()); grammar = llama_grammar_init( - parsed_grammar.out_grammar.data(), parsed_grammar.symbol_ids.at("root")); + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } // TODO: replace with ring-buffer @@ -653,8 +655,12 @@ int main(int argc, char ** argv) { // reset grammar state if we're restarting generation if (grammar != NULL) { llama_grammar_free(grammar); + + std::vector grammar_rules( + parsed_grammar.c_rules()); grammar = llama_grammar_init( - parsed_grammar.out_grammar.data(), parsed_grammar.symbol_ids.at("root")); + grammar_rules.data(), grammar_rules.size(), + parsed_grammar.symbol_ids.at("root")); } } is_interacting = false; diff --git a/grammars/japanese.gbnf b/grammars/japanese.gbnf new file mode 100644 index 0000000000000..43f25ab598586 --- /dev/null +++ b/grammars/japanese.gbnf @@ -0,0 +1,7 @@ +# A probably incorrect grammar for Japanese +root ::= jp-char+ ([ \t\n] jp-char+)* +jp-char ::= hiragana | katakana | punctuation | cjk +hiragana ::= [ぁ-ゟ] +katakana ::= [ァ-ヿ] +punctuation ::= [、-〾] +cjk ::= [一-鿿] diff --git a/llama.cpp b/llama.cpp index 37d19ea912051..0986c24890e69 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1841,45 +1841,86 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co // struct llama_grammar { - const std::vector rules; - std::vector> stacks; + const std::vector> rules; + std::vector> stacks; }; -// transforms a grammar pushdown stack into N possible stacks, all terminating +// NOTE: assumes valid utf8 (but checks for overrun) +std::pair decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t first_byte = static_cast(*src); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may overrun! + const char * pos = src + 1; // may overrun! + for ( ; pos < end && *pos; pos++) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + return std::make_pair(value, pos); +} + +// returns true iff pos points to the end of one of the definitions of a rule +static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { + switch (pos->type) { + case LLAMA_GRETYPE_END: return true; + case LLAMA_GRETYPE_ALT: return true; + default: return false; + } +} + +// transforms a grammar pushdown stack into N possible stacks, all ending // at a character range (terminal element) static void llama_grammar_advance_stack( - const std::vector & rules, - const std::vector & stack, - std::vector> & new_stacks) { + const std::vector> & rules, + const std::vector & stack, + std::vector> & new_stacks) { if (stack.empty()) { new_stacks.push_back(stack); return; } - const uint16_t * pos = stack.back(); + const llama_grammar_element * pos = stack.back(); - if (*pos == 1) { - // rule reference, apply rule to stack - const uint16_t * subpos = rules[pos[1]] + 1; - while (*subpos) { - // init new stack without the top (pos) - std::vector new_stack(stack.begin(), stack.end() - 1); - if (pos[2]) { - // if the rule ref is followed by another element, add that to stack - new_stack.push_back(pos + 2); - } - if (subpos[1]) { - // if this alternate is nonempty, add that to the stack - new_stack.push_back(subpos + 1); - } - llama_grammar_advance_stack(rules, new_stack, new_stacks); - subpos += 1 + *subpos; + switch (pos->type) { + case LLAMA_GRETYPE_RULE_REF: { + const size_t rule_id = static_cast(pos->value); + const llama_grammar_element * subpos = rules[rule_id].data(); + do { + // init new stack without the top (pos) + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!llama_grammar_is_end_of_sequence(subpos)) { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + while (!llama_grammar_is_end_of_sequence(subpos)) { + // scan to end of alternate def + subpos++; + } + if (subpos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + subpos++; + } else { + break; + } + } while (true); + break; } - } else { - // rule element size > 1 -> character reference - LLAMA_ASSERT(*pos); - new_stacks.push_back(stack); + case LLAMA_GRETYPE_CHAR: + new_stacks.push_back(stack); + break; + default: + // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range + // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + LLAMA_ASSERT(false); } } @@ -1887,39 +1928,43 @@ static void llama_grammar_advance_stack( // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those // positions -static std::vector> llama_grammar_accept( - const std::vector & rules, - const std::vector> & stacks, - const uint16_t chr) { +static std::vector> llama_grammar_accept( + const std::vector> & rules, + const std::vector> & stacks, + const uint32_t chr) { - std::vector> new_stacks; + std::vector> new_stacks; for (const auto & stack : stacks) { if (stack.empty()) { continue; } - const uint16_t * pos = stack.back(); - const uint16_t num_chars = *pos; - LLAMA_ASSERT(num_chars > 1); + const llama_grammar_element * pos = stack.back(); + LLAMA_ASSERT(pos->type == LLAMA_GRETYPE_CHAR); - pos++; // skip num chars indicator bool found = false; - // loop over the inclusive char pairs to find a match on the given char - for (int i = 0; i < num_chars; i += 2) { - if (pos[i] <= chr && (i + 1 == num_chars || chr <= pos[i + 1])) { - found = true; - break; + do { + bool matches_range; + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + matches_range = pos->value <= chr && chr <= pos[1].value; + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + matches_range = pos->value == chr; + pos += 1; } - } + found = found || matches_range; + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + if (!found) { continue; } - // advance past char range, updating top of stack to next element, if any - pos += num_chars; - std::vector new_stack(stack.begin(), stack.end() - 1); - if (*pos) { + // update top of stack to next element, if any + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } llama_grammar_advance_stack(rules, new_stack, new_stacks); @@ -1930,8 +1975,8 @@ static std::vector> llama_grammar_accept( // returns `true` if one of the pushdown stacks can accept the given char. static bool llama_grammar_peek( - const std::vector> & stacks, - const uint16_t chr) { + const std::vector> & stacks, + const uint32_t chr) { for (const auto & stack : stacks) { if (stack.empty()) { @@ -1939,16 +1984,24 @@ static bool llama_grammar_peek( return true; } } else { - const uint16_t * pos = stack.back(); - const uint16_t num_chars = *pos; - LLAMA_ASSERT(num_chars > 1); - - pos++; - for (int i = 0; i < num_chars; i += 2) { - if (pos[i] <= chr && (i + 1 == num_chars || chr <= pos[i + 1])) { - return true; + const llama_grammar_element * pos = stack.back(); + LLAMA_ASSERT(pos->type == LLAMA_GRETYPE_CHAR); + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + if (pos->value <= chr && chr <= pos[1].value) { + return true; + } + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + if (pos->value == chr) { + return true; + } + pos += 1; } - } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); } } return false; @@ -1959,45 +2012,44 @@ static bool llama_grammar_peek( // grammar - external // -struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_rule_id) { - const uint16_t * pos = src; - std::vector rules; +struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { + const llama_grammar_element * pos; - // build `rules` as list of pointers to rules embedded in binary grammar `src` - while (*pos != 0xffff) { - uint16_t rule_id = *pos; - if (rules.size() <= rule_id) { - rules.resize(rule_id + 1); + // copy rule definitions into vectors + std::vector> vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); } - rules[rule_id] = pos; - // skip rule id - pos++; - // skip rule alternates - while (*pos) { - pos += 1 + *pos; - } - // skip 0 denoting end of rule - pos++; + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); } - const uint16_t * start_rule = rules[start_rule_id]; - - LLAMA_ASSERT(*start_rule == start_rule_id); - // loop over alternates of start rule to build initial stacks - pos = start_rule + 1; - std::vector> stacks; - while (*pos) { - std::vector stack; - if (pos[1]) { - // if alernate is nonempty, add to stack - stack.push_back(pos + 1); + std::vector> stacks; + pos = rules[start_rule_index]; + do { + std::vector stack; + if (!llama_grammar_is_end_of_sequence(pos)) { + // if alternate is nonempty, add to stack + stack.push_back(pos); + } + llama_grammar_advance_stack(vec_rules, stack, stacks); + while (!llama_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; } - llama_grammar_advance_stack(rules, stack, stacks); - pos += 1 + *pos; - } + if (pos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); - return new llama_grammar{ rules, stacks }; + return new llama_grammar{ std::move(vec_rules), std::move(stacks) }; } void llama_grammar_free(struct llama_grammar * grammar) { @@ -2285,7 +2337,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c const int64_t t_start_sample_us = ggml_time_us(); const llama_token eos = llama_token_eos(); // since many llama tokens are prefixed with a single space, special case a lookahead on ' ' - const auto stacks_after_space = llama_grammar_accept(grammar->rules, grammar->stacks, ' '); + const auto stacks_after_space = llama_grammar_accept(grammar->rules, grammar->stacks, U' '); for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; @@ -2296,10 +2348,15 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c bool valid = false; if (id == eos) { valid = llama_grammar_peek(grammar->stacks, 0); - } else if (str[0] == ' ') { - valid = llama_grammar_peek(stacks_after_space, str[1]); - } else if (str[0] != 0) { - valid = llama_grammar_peek(grammar->stacks, str[0]); + } else { + const auto decoded = decode_utf8(str); + const uint32_t chr = decoded.first; + if (chr == U' ') { + const char * next = decoded.second; + valid = llama_grammar_peek(stacks_after_space, decode_utf8(next).first); + } else if (chr != 0) { + valid = llama_grammar_peek(grammar->stacks, chr); + } } if (!valid) { @@ -2451,13 +2508,15 @@ llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_ const char * suffix = str; // Find prefix of selected token that matches grammar, expecting at least 1 char - auto new_stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *suffix); + auto decoded = decode_utf8(suffix); + auto new_stacks = llama_grammar_accept(grammar->rules, grammar->stacks, decoded.first); LLAMA_ASSERT(!new_stacks.empty()); if (*suffix) { - ++suffix; - for ( ; *suffix; ++suffix) { - new_stacks = llama_grammar_accept(grammar->rules, new_stacks, *suffix); - if (new_stacks.empty()) { + suffix = decoded.second; + for ( ; *suffix; suffix = decoded.second) { + decoded = decode_utf8(suffix); + new_stacks = llama_grammar_accept(grammar->rules, new_stacks, decoded.first); + if (new_stacks.empty() ) { break; } } @@ -2480,8 +2539,9 @@ llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_ // accept the first token of the matching prefix into the grammar llama_token first_prefix_token = tokens[0]; const char * first_prefix_str = llama_token_to_str(ctx, first_prefix_token); - for ( ; *first_prefix_str; ++first_prefix_str) { - grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *first_prefix_str); + for ( ; *first_prefix_str; first_prefix_str = decoded.second) { + decoded = decode_utf8(first_prefix_str); + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, decoded.first); LLAMA_ASSERT(!grammar->stacks.empty()); } diff --git a/llama.h b/llama.h index c7f2c841ec6e1..e827fa33c502a 100644 --- a/llama.h +++ b/llama.h @@ -55,8 +55,6 @@ extern "C" { struct llama_context; - struct llama_grammar; - typedef int llama_token; typedef struct llama_token_data { @@ -125,6 +123,37 @@ extern "C" { bool quantize_output_tensor; // quantize output.weight } llama_model_quantize_params; + // grammar types + struct llama_grammar; + + // grammar element type + enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 5, + }; + + typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point or rule ID + } llama_grammar_element; + LLAMA_API struct llama_context_params llama_context_default_params(); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(); @@ -243,26 +272,11 @@ extern "C" { // Grammar // - // Accepts a binary encoding of a context-free grammar. The returned struct can be used to - // constrain sampled tokens (see below). - // - // The binary format represents one or more production rules, each with one or more alternate - // defininitions: - // - // ( ( )+ 0000)+ FFFF - // - // rule_ids should be assigned sequentially from zero but may appear out of order. Each - // rule alternate is a sequence of zero or more symbols, each prefixed with size: - // - // ( )* 0000 - // - // A symbol of size 1 is interpreted as a rule reference (whose value is the single following - // u16). Symbols sized greater than 1 are interpreted as inclusive pairs of 16-bit chars to - // match. Note that symbol sizes greater than 7FFF are reserved for future use. - // - // The provided `src` must be kept valid for the lifetime of the `llama_grammar`. - // - LLAMA_API struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_rule_id); + LLAMA_API struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); // Sampling functions From 014fbfd4a993c164ba94da5832c47ddd175b6484 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Mon, 10 Jul 2023 23:26:09 -0400 Subject: [PATCH 13/18] add unicode escapes --- examples/grammar-parser.cpp | 65 ++++++++++++++++++++----------------- grammars/json.gbnf | 11 +++++-- 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp index 206ccac56c3b0..5ac0d41f8b689 100644 --- a/examples/grammar-parser.cpp +++ b/examples/grammar-parser.cpp @@ -50,15 +50,27 @@ namespace grammar_parser { return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); } - int hex_to_int(char c) { - if ('a' <= c && c <= 'f') { - return c - 'a' + 10; - } else if ('A' <= c && c <= 'F') { - return c - 'A' + 10; - } else if ('0' <= c && c <= '9') { - return c - '0'; + std::pair parse_hex(const char * src, int size) { + const char * pos = src; + const char * end = src + size; + uint32_t value = 0; + for ( ; pos < end && *pos; pos++) { + value <<= 4; + char c = *pos; + if ('a' <= c && c <= 'f') { + value += c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + value += c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + value += c - '0'; + } else { + break; + } } - return -1; + if (pos != end) { + throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); + } + return std::make_pair(value, pos); } const char * parse_space(const char * src, bool newline_ok) { @@ -89,30 +101,23 @@ namespace grammar_parser { std::pair parse_char(const char * src) { if (*src == '\\') { - char esc = src[1]; - // TODO: 16- and 32-bit escapes - if (esc == 'x') { - int first = hex_to_int(src[2]); - if (first > -1) { - int second = hex_to_int(src[3]); - if (second > -1) { - return std::make_pair((first << 4) + second, src + 4); - } - } - throw std::runtime_error(std::string("expecting \\xNN at ") + src); - } else if (esc == '"' || esc == '[' || esc == ']') { - return std::make_pair(esc, src + 2); - } else if (esc == 'r') { - return std::make_pair('\r', src + 2); - } else if (esc == 'n') { - return std::make_pair('\n', src + 2); - } else if (esc == 't') { - return std::make_pair('\t', src + 2); + switch (src[1]) { + case 'x': return parse_hex(src + 2, 2); + case 'u': return parse_hex(src + 2, 4); + case 'U': return parse_hex(src + 2, 8); + case 't': return std::make_pair('\t', src + 2); + case 'r': return std::make_pair('\r', src + 2); + case 'n': return std::make_pair('\n', src + 2); + case '\\': + case '"': + case '[': + case ']': + return std::make_pair(src[1], src + 2); + default: + throw std::runtime_error(std::string("unknown escape at ") + src); } - throw std::runtime_error(std::string("unknown escape at ") + src); } else if (*src) { - auto decoded = decode_utf8(src); - return std::make_pair(decoded.first, decoded.second); + return decode_utf8(src); } throw std::runtime_error("unexpected end of input"); } diff --git a/grammars/json.gbnf b/grammars/json.gbnf index 72f4857e4de87..668836df89812 100644 --- a/grammars/json.gbnf +++ b/grammars/json.gbnf @@ -15,10 +15,15 @@ array ::= ("," ws value)* )? "]" -# Subset of JSON primitives: strings without escapes and only regular integers -string ::= "\"" [ \t!#-\[\]-~]* "\"" ws +string ::= + "\"" ( + [\x20\x21\x23-\x5b\x5d-\U0010FFFF] | # any code point except " (\x22) and \ (\x5c) + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" + +# Only plain integers currently number ::= "-"? [0-9]+ ws boolean ::= ("true" | "false") ws # Optional space: by convention, applied in this grammar after literal chars when allowed -ws ::= [ \t\n] ws | +ws ::= ([ \t\n] ws)? From 8d37755bdc22b0bc2d6876cade32fa1534e75d3d Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Tue, 18 Jul 2023 21:54:44 -0400 Subject: [PATCH 14/18] add inverse char ranges --- examples/grammar-parser.cpp | 17 +++++++- grammars/json.gbnf | 6 +-- grammars/list.gbnf | 4 ++ llama.cpp | 80 +++++++++++++++++-------------------- llama.h | 7 +++- 5 files changed, 63 insertions(+), 51 deletions(-) create mode 100644 grammars/list.gbnf diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp index 5ac0d41f8b689..019d5e1bf9ceb 100644 --- a/examples/grammar-parser.cpp +++ b/examples/grammar-parser.cpp @@ -149,13 +149,18 @@ namespace grammar_parser { pos = parse_space(pos + 1, is_nested); } else if (*pos == '[') { // char range(s) pos++; + enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + if (*pos == '^') { + pos++; + start_type = LLAMA_GRETYPE_CHAR_NOT; + } last_sym_start = out_elements.size(); while (*pos != ']') { auto char_pair = parse_char(pos); pos = char_pair.second; enum llama_gretype type = last_sym_start < out_elements.size() ? LLAMA_GRETYPE_CHAR_ALT - : LLAMA_GRETYPE_CHAR; + : start_type; out_elements.push_back({type, char_pair.first}); if (pos[0] == '-' && pos[1] != ']') { @@ -292,6 +297,7 @@ namespace grammar_parser { bool is_char_element(llama_grammar_element elem) { switch (elem.type) { case LLAMA_GRETYPE_CHAR: return true; + case LLAMA_GRETYPE_CHAR_NOT: return true; case LLAMA_GRETYPE_CHAR_ALT: return true; case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; default: return false; @@ -305,8 +311,9 @@ namespace grammar_parser { case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; + case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; - case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_RNG_UPPER"); break; + case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; } switch (elem.type) { case LLAMA_GRETYPE_END: @@ -315,6 +322,7 @@ namespace grammar_parser { fprintf(file, "(%u) ", elem.value); break; case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: case LLAMA_GRETYPE_CHAR_RNG_UPPER: case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "(\""); @@ -353,6 +361,10 @@ namespace grammar_parser { fprintf(file, "["); print_grammar_char(file, elem.value); break; + case LLAMA_GRETYPE_CHAR_NOT: + fprintf(file, "[^"); + print_grammar_char(file, elem.value); + break; case LLAMA_GRETYPE_CHAR_RNG_UPPER: if (i == 0 || !is_char_element(rule[i - 1])) { throw std::runtime_error( @@ -394,6 +406,7 @@ namespace grammar_parser { // fprintf(file, "%zu: ", i); // print_rule_binary(file, state.rules[i]); print_rule(file, i, state.rules[i], symbol_id_names); + // fprintf(file, "\n"); } } catch (const std::exception & err) { fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); diff --git a/grammars/json.gbnf b/grammars/json.gbnf index 668836df89812..40fa2b6373255 100644 --- a/grammars/json.gbnf +++ b/grammars/json.gbnf @@ -1,7 +1,7 @@ # Grammar for subset of JSON - doesn't support full string or number syntax root ::= object -value ::= object | array | string | number | boolean +value ::= object | array | string | number | boolean | "null" object ::= "{" ws ( @@ -17,9 +17,9 @@ array ::= string ::= "\"" ( - [\x20\x21\x23-\x5b\x5d-\U0010FFFF] | # any code point except " (\x22) and \ (\x5c) + [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes - )* "\"" + )* "\"" ws # Only plain integers currently number ::= "-"? [0-9]+ ws diff --git a/grammars/list.gbnf b/grammars/list.gbnf new file mode 100644 index 0000000000000..51e6c9c4b0329 --- /dev/null +++ b/grammars/list.gbnf @@ -0,0 +1,4 @@ +root ::= item+ + +# Excludes various line break characters +item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" diff --git a/llama.cpp b/llama.cpp index 8cd13807db132..0f6eddb50bb41 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1925,6 +1925,31 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) } } +// returns true iff chr satisfies the char range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static std::pair llama_grammar_match_char( + const llama_grammar_element * pos, + const uint32_t chr) { + + bool found = false; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + found = found || (pos->value <= chr && chr <= pos[1].value); + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + found = found || pos->value == chr; + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return std::make_pair(found == is_positive_char, pos); +} + // transforms a grammar pushdown stack into N possible stacks, all ending // at a character range (terminal element) static void llama_grammar_advance_stack( @@ -1969,6 +1994,7 @@ static void llama_grammar_advance_stack( break; } case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: new_stacks.push_back(stack); break; default: @@ -1995,34 +2021,17 @@ static std::vector> llama_grammar_acc continue; } - const llama_grammar_element * pos = stack.back(); - LLAMA_ASSERT(pos->type == LLAMA_GRETYPE_CHAR); + auto match = llama_grammar_match_char(stack.back(), chr); + if (match.first) { + const llama_grammar_element * pos = match.second; - bool found = false; - do { - bool matches_range; - if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { - // inclusive range, e.g. [a-z] - matches_range = pos->value <= chr && chr <= pos[1].value; - pos += 2; - } else { - // exact char match, e.g. [a] or "a" - matches_range = pos->value == chr; - pos += 1; + // update top of stack to next element, if any + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos)) { + new_stack.push_back(pos); } - found = found || matches_range; - } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); - - if (!found) { - continue; + llama_grammar_advance_stack(rules, new_stack, new_stacks); } - - // update top of stack to next element, if any - std::vector new_stack(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(pos)) { - new_stack.push_back(pos); - } - llama_grammar_advance_stack(rules, new_stack, new_stacks); } return new_stacks; @@ -2038,25 +2047,8 @@ static bool llama_grammar_peek( if (!chr) { return true; } - } else { - const llama_grammar_element * pos = stack.back(); - LLAMA_ASSERT(pos->type == LLAMA_GRETYPE_CHAR); - - do { - if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { - // inclusive range, e.g. [a-z] - if (pos->value <= chr && chr <= pos[1].value) { - return true; - } - pos += 2; - } else { - // exact char match, e.g. [a] or "a" - if (pos->value == chr) { - return true; - } - pos += 1; - } - } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + } else if (llama_grammar_match_char(stack.back(), chr).first) { + return true; } } return false; diff --git a/llama.h b/llama.h index 9d584b1303637..c9b2210382f2f 100644 --- a/llama.h +++ b/llama.h @@ -151,13 +151,16 @@ extern "C" { // terminal element: character (code point) LLAMA_GRETYPE_CHAR = 3, + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to // be an inclusive range ([a-z]) - LLAMA_GRETYPE_CHAR_RNG_UPPER = 4, + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, // modifies a preceding LLAMA_GRETYPE_CHAR or // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) - LLAMA_GRETYPE_CHAR_ALT = 5, + LLAMA_GRETYPE_CHAR_ALT = 6, }; typedef struct llama_grammar_element { From c047e8aec230098c9f2992bd4f528fc15eeaef74 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Wed, 19 Jul 2023 23:28:57 -0400 Subject: [PATCH 15/18] only sample full tokens (no peeking or truncation) --- examples/main/main.cpp | 2 +- llama.cpp | 202 ++++++++++++++++++++++++----------------- llama.h | 4 +- 3 files changed, 120 insertions(+), 88 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 035afbe28eeeb..ed86b5022d20a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -613,7 +613,7 @@ int main(int argc, char ** argv) { // printf("`%d`", candidates_p.size); if (grammar != NULL) { - id = llama_grammar_accept_token(ctx, grammar, id); + llama_grammar_accept_token(ctx, grammar, id); } last_n_tokens.erase(last_n_tokens.begin()); diff --git a/llama.cpp b/llama.cpp index 0f6eddb50bb41..d8c67a6622567 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1900,20 +1900,32 @@ struct llama_grammar { std::vector> stacks; }; +struct llama_grammar_candidate { + size_t index; + const uint32_t * code_points; +}; + // NOTE: assumes valid utf8 (but checks for overrun) -std::pair decode_utf8(const char * src) { - static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; - uint8_t first_byte = static_cast(*src); - uint8_t highbits = first_byte >> 4; - int len = lookup[highbits]; - uint8_t mask = (1 << (8 - len)) - 1; - uint32_t value = first_byte & mask; - const char * end = src + len; // may overrun! - const char * pos = src + 1; // may overrun! - for ( ; pos < end && *pos; pos++) { - value = (value << 6) + (static_cast(*pos) & 0x3F); - } - return std::make_pair(value, pos); +// adds a terminating 0 for use as pointer +std::vector decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + const char * pos = src; + std::vector code_points; + while (*pos != 0) { + uint8_t first_byte = static_cast(*pos); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = pos + len; // may overrun! + ++pos; + for ( ; pos < end && *pos != 0; ++pos) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + code_points.push_back(value); + } + code_points.push_back(0); + return code_points; } // returns true iff pos points to the end of one of the definitions of a rule @@ -2037,23 +2049,72 @@ static std::vector> llama_grammar_acc return new_stacks; } -// returns `true` if one of the pushdown stacks can accept the given char. -static bool llama_grammar_peek( +static std::vector llama_grammar_reject_candidates( + const std::vector> & rules, const std::vector> & stacks, - const uint32_t chr) { + const std::vector & candidates); - for (const auto & stack : stacks) { - if (stack.empty()) { - if (!chr) { - return true; +static std::vector llama_grammar_reject_candidates_for_stack( + const std::vector> & rules, + const std::vector & stack, + const std::vector & candidates) { + + std::vector rejects; + + if (stack.empty()) { + // accept nothing; EOS is handled elsewhere + rejects.insert(rejects.end(), candidates.begin(), candidates.end()); + return rejects; + } + + const llama_grammar_element * stack_pos = stack.back(); + + std::vector next_candidates; + for (auto tok : candidates) { + if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) { + if (tok.code_points[1] != 0) { + next_candidates.push_back({ tok.index, tok.code_points + 1 }); } - } else if (llama_grammar_match_char(stack.back(), chr).first) { - return true; + } else { + rejects.push_back(tok); } } - return false; + + auto stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; + + // update top of stack to next element, if any + std::vector stack_after(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(stack_pos_after)) { + stack_after.push_back(stack_pos_after); + } + std::vector> next_stacks; + llama_grammar_advance_stack(rules, stack_after, next_stacks); + + auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); + for (auto tok : next_rejects) { + rejects.push_back({ tok.index, tok.code_points - 1 }); + } + + return rejects; } +static std::vector llama_grammar_reject_candidates( + const std::vector> & rules, + const std::vector> & stacks, + const std::vector & candidates) { + LLAMA_ASSERT(!stacks.empty()); // REVIEW + + if (candidates.empty()) { + return std::vector(); + } + + auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); + + for (size_t i = 1, size = stacks.size(); i < size; ++i) { + rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); + } + return rejects; +} // // grammar - external @@ -2383,34 +2444,39 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c assert(ctx); const int64_t t_start_sample_us = ggml_time_us(); const llama_token eos = llama_token_eos(); - // since many llama tokens are prefixed with a single space, special case a lookahead on ' ' - const auto stacks_after_space = llama_grammar_accept(grammar->rules, grammar->stacks, U' '); - for (size_t i = 0; i < candidates->size; ++i) { - const llama_token id = candidates->data[i].id; - const char * str = llama_token_to_str(ctx, id); + bool allow_eos = false; + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + allow_eos = true; + break; + } + } - // prune tokens based on first char only - in `llama_grammar_accept_token` we will find the - // full matching prefix of the selected token - bool valid = false; + std::vector> decoded_candidates; + std::vector grammar_candidates; + + for (size_t i = 0; i < candidates->size; ++i) { + const llama_token id = candidates->data[i].id; + const char * str = llama_token_to_str(ctx, id); if (id == eos) { - valid = llama_grammar_peek(grammar->stacks, 0); - } else { - const auto decoded = decode_utf8(str); - const uint32_t chr = decoded.first; - if (chr == U' ') { - const char * next = decoded.second; - valid = llama_grammar_peek(stacks_after_space, decode_utf8(next).first); - } else if (chr != 0) { - valid = llama_grammar_peek(grammar->stacks, chr); + if (!allow_eos) { + candidates->data[i].logit = -INFINITY; } - } - - if (!valid) { + } else if (*str == 0) { candidates->data[i].logit = -INFINITY; + } else { + decoded_candidates.push_back(decode_utf8(str)); + grammar_candidates.push_back({ i, decoded_candidates.back().data() }); } } + auto rejects = + llama_grammar_reject_candidates(grammar->rules, grammar->stacks, grammar_candidates); + for (auto reject : rejects) { + candidates->data[reject.index].logit = -INFINITY; + } + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } @@ -2599,61 +2665,27 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra return result; } -llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { +void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { const int64_t t_start_sample_us = ggml_time_us(); if (token == llama_token_eos()) { for (const auto & stack : grammar->stacks) { if (stack.empty()) { - return token; + return; } } LLAMA_ASSERT(false); } - const char * str = llama_token_to_str(ctx, token); - const char * suffix = str; - - // Find prefix of selected token that matches grammar, expecting at least 1 char - auto decoded = decode_utf8(suffix); - auto new_stacks = llama_grammar_accept(grammar->rules, grammar->stacks, decoded.first); - LLAMA_ASSERT(!new_stacks.empty()); - if (*suffix) { - suffix = decoded.second; - for ( ; *suffix; suffix = decoded.second) { - decoded = decode_utf8(suffix); - new_stacks = llama_grammar_accept(grammar->rules, new_stacks, decoded.first); - if (new_stacks.empty() ) { - break; - } - } - } - - // if full token is matched, accept new stacks - if (!(*suffix)) { - grammar->stacks = new_stacks; - return token; - } - - // otherwise, tokenize the string prefix that did match - llama_token tokens[32]; // TODO - determine actual max token size - const std::string prefix_str(str, suffix - str); - int n_tokens = llama_tokenize(ctx, prefix_str.c_str(), tokens, 32, false); - if (n_tokens < 1) { - return token; // REVIEW - } - - // accept the first token of the matching prefix into the grammar - llama_token first_prefix_token = tokens[0]; - const char * first_prefix_str = llama_token_to_str(ctx, first_prefix_token); - for ( ; *first_prefix_str; first_prefix_str = decoded.second) { - decoded = decode_utf8(first_prefix_str); - grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, decoded.first); - LLAMA_ASSERT(!grammar->stacks.empty()); + const char * str = llama_token_to_str(ctx, token); + // Note terminating 0 in decoded string + auto code_points = decode_utf8(str); + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); } + LLAMA_ASSERT(!grammar->stacks.empty()); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - return first_prefix_token; } // diff --git a/llama.h b/llama.h index c9b2210382f2f..25cd001ad9f78 100644 --- a/llama.h +++ b/llama.h @@ -404,8 +404,8 @@ extern "C" { /// @details Randomly selects a token from the candidates based on their probabilities. LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); - /// @details Accepts the sampled token into the grammar, possibly transforming to a new token - LLAMA_API llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); + /// @details Accepts the sampled token into the grammar + LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); From 11315b1d61352944791db9c81db1b7bd8bd39f2e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 21 Jul 2023 15:11:23 +0300 Subject: [PATCH 16/18] llama : minor style changes blindly applied in online editor - hopefully I didn't break something --- llama.cpp | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/llama.cpp b/llama.cpp index d8c67a6622567..c8aa8a022563d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2442,8 +2442,7 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { assert(ctx); - const int64_t t_start_sample_us = ggml_time_us(); - const llama_token eos = llama_token_eos(); + const int64_t t_start_sample_us = ggml_time_us(); bool allow_eos = false; for (const auto & stack : grammar->stacks) { @@ -2453,8 +2452,10 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c } } - std::vector> decoded_candidates; - std::vector grammar_candidates; + const llama_token eos = llama_token_eos(); + + std::vector> candidates_decoded; + std::vector candidates_grammar; for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; @@ -2466,14 +2467,14 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c } else if (*str == 0) { candidates->data[i].logit = -INFINITY; } else { - decoded_candidates.push_back(decode_utf8(str)); - grammar_candidates.push_back({ i, decoded_candidates.back().data() }); + candidates_decoded.push_back(decode_utf8(str)); + candidates_grammar.push_back({ i, candidates_decoded.back().data() }); } } - auto rejects = - llama_grammar_reject_candidates(grammar->rules, grammar->stacks, grammar_candidates); - for (auto reject : rejects) { + const auto rejects = + llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + for (auto & reject : rejects) { candidates->data[reject.index].logit = -INFINITY; } From f7f1d266e395d9deed0f504a45fb628ee7ed08c4 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Sun, 23 Jul 2023 21:45:48 -0400 Subject: [PATCH 17/18] update help text --- examples/common.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/common.cpp b/examples/common.cpp index 8186000924e11..779605f9d1cb7 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -536,7 +536,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " modifies the likelihood of token appearing in the completion,\n"); fprintf(stdout, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); - fprintf(stdout, " --grammar GRAMMAR BNF-like grammar (TODO explain) to constrain generations\n"); + fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n"); fprintf(stdout, " --grammar-file FNAME file to read grammar from\n"); fprintf(stdout, " --cfg-negative-prompt PROMPT \n"); fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n"); From 4cd9711dacb5c060d05159c6b3427c55f1fccdfc Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Sun, 23 Jul 2023 22:58:56 -0400 Subject: [PATCH 18/18] add warning message if EOS is disabled --- examples/main/main.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 9995ab8edd8aa..16ddc22747f6b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -350,6 +350,14 @@ int main(int argc, char ** argv) { grammar_parser::print_grammar(stderr, parsed_grammar); fprintf(stderr, "\n"); + { + auto it = params.logit_bias.find(llama_token_eos()); + if (it != params.logit_bias.end() && it->second == -INFINITY) { + fprintf(stderr, + "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__); + } + } + std::vector grammar_rules(parsed_grammar.c_rules()); grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));