From 31476ccc0e529d36ce3727e994b2d41d28b98606 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Tue, 29 Aug 2023 22:56:29 -0400 Subject: [PATCH 1/8] whisper : add grammar-based sampling --- Makefile | 4 +- examples/command/command.cpp | 45 +++- examples/grammar-parser.cpp | 423 ++++++++++++++++++++++++++++++++ examples/grammar-parser.h | 29 +++ whisper.cpp | 454 +++++++++++++++++++++++++++++++++++ whisper.h | 36 +++ 6 files changed, 983 insertions(+), 8 deletions(-) create mode 100644 examples/grammar-parser.cpp create mode 100644 examples/grammar-parser.h diff --git a/Makefile b/Makefile index a2631011bea..363a160ecfc 100644 --- a/Makefile +++ b/Makefile @@ -297,8 +297,8 @@ quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON) stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) $(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS) -command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) - $(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS) +command: examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) + $(CXX) $(CXXFLAGS) examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS) lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) $(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS) diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 54e3549f3bc..e93d8908b6d 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -9,6 +9,7 @@ #include "common.h" #include "common-sdl.h" #include "whisper.h" +#include "grammar-parser.h" #include #include @@ -32,6 +33,7 @@ struct whisper_params { float vad_thold = 0.6f; float freq_thold = 100.0f; + float grammar_penalty = 100.0f; bool speed_up = false; bool translate = false; @@ -44,6 +46,7 @@ struct whisper_params { std::string fname_out; std::string commands; std::string prompt; + std::string grammar; }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); @@ -73,6 +76,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; } else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; } + else if ( arg == "--grammar") { params.grammar = argv[++i]; } + else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -106,6 +111,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str()); fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str()); + fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); + fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); fprintf(stderr, "\n"); } @@ -115,6 +122,9 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con prob = 0.0f; t_ms = 0; + grammar_parser::parse_state parsed_grammar; + std::vector grammar_rules; + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); wparams.print_progress = false; @@ -131,6 +141,15 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; + if (!params.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + grammar_rules = parsed_grammar.c_rules(); + wparams.grammar_rules = grammar_rules.data(); + wparams.n_grammar_rules = grammar_rules.size(); + wparams.i_start_rule = parsed_grammar.symbol_ids.at("root"); + wparams.grammar_penalty = params.grammar_penalty; + } + if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { return ""; } @@ -648,12 +667,26 @@ int main(int argc, char ** argv) { int ret_val = 0; - if (!params.commands.empty()) { - ret_val = process_command_list(ctx, audio, params); - } else if (!params.prompt.empty()) { - ret_val = always_prompt_transcription(ctx, audio, params); - } else { - ret_val = process_general_transcription(ctx, audio, params); + if (!params.grammar.empty()) { + auto parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + ret_val = 1; + } else { + fprintf(stderr, "%s: grammar:\n", __func__); + grammar_parser::print_grammar(stderr, parsed_grammar); + fprintf(stderr, "\n"); + } + } + + if (ret_val == 0) { + if (!params.commands.empty()) { + ret_val = process_command_list(ctx, audio, params); + } else if (!params.prompt.empty()) { + ret_val = always_prompt_transcription(ctx, audio, params); + } else { + ret_val = process_general_transcription(ctx, audio, params); + } } audio.pause(); diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp new file mode 100644 index 00000000000..b5b607fa9d0 --- /dev/null +++ b/examples/grammar-parser.cpp @@ -0,0 +1,423 @@ +#include "grammar-parser.h" +#include +#include +#include +#include +#include +#include + +namespace grammar_parser { + // NOTE: assumes valid utf8 (but checks for overrun) + // copied from whisper.cpp + std::pair decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t first_byte = static_cast(*src); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may overrun! + const char * pos = src + 1; + for ( ; pos < end && *pos; pos++) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + return std::make_pair(value, pos); + } + + uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { + uint32_t next_id = static_cast(state.symbol_ids.size()); + auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); + return result.first->second; + } + + uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { + uint32_t next_id = static_cast(state.symbol_ids.size()); + state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; + return next_id; + } + + void add_rule( + parse_state & state, + uint32_t rule_id, + const std::vector & rule) { + if (state.rules.size() <= rule_id) { + state.rules.resize(rule_id + 1); + } + state.rules[rule_id] = rule; + } + + bool is_word_char(char c) { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); + } + + std::pair parse_hex(const char * src, int size) { + const char * pos = src; + const char * end = src + size; + uint32_t value = 0; + for ( ; pos < end && *pos; pos++) { + value <<= 4; + char c = *pos; + if ('a' <= c && c <= 'f') { + value += c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + value += c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + value += c - '0'; + } else { + break; + } + } + if (pos != end) { + throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); + } + return std::make_pair(value, pos); + } + + const char * parse_space(const char * src, bool newline_ok) { + const char * pos = src; + while (*pos == ' ' || *pos == '\t' || *pos == '#' || + (newline_ok && (*pos == '\r' || *pos == '\n'))) { + if (*pos == '#') { + while (*pos && *pos != '\r' && *pos != '\n') { + pos++; + } + } else { + pos++; + } + } + return pos; + } + + const char * parse_name(const char * src) { + const char * pos = src; + while (is_word_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting name at ") + src); + } + return pos; + } + + std::pair parse_char(const char * src) { + if (*src == '\\') { + switch (src[1]) { + case 'x': return parse_hex(src + 2, 2); + case 'u': return parse_hex(src + 2, 4); + case 'U': return parse_hex(src + 2, 8); + case 't': return std::make_pair('\t', src + 2); + case 'r': return std::make_pair('\r', src + 2); + case 'n': return std::make_pair('\n', src + 2); + case '\\': + case '"': + case '[': + case ']': + return std::make_pair(src[1], src + 2); + default: + throw std::runtime_error(std::string("unknown escape at ") + src); + } + } else if (*src) { + return decode_utf8(src); + } + throw std::runtime_error("unexpected end of input"); + } + + const char * parse_alternates( + parse_state & state, + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested); + + const char * parse_sequence( + parse_state & state, + const char * src, + const std::string & rule_name, + std::vector & out_elements, + bool is_nested) { + size_t last_sym_start = out_elements.size(); + const char * pos = src; + while (*pos) { + if (*pos == '"') { // literal string + pos++; + last_sym_start = out_elements.size(); + while (*pos != '"') { + auto char_pair = parse_char(pos); + pos = char_pair.second; + out_elements.push_back({WHISPER_GRETYPE_CHAR, char_pair.first}); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '[') { // char range(s) + pos++; + enum whisper_gretype start_type = WHISPER_GRETYPE_CHAR; + if (*pos == '^') { + pos++; + start_type = WHISPER_GRETYPE_CHAR_NOT; + } + last_sym_start = out_elements.size(); + while (*pos != ']') { + auto char_pair = parse_char(pos); + pos = char_pair.second; + enum whisper_gretype type = last_sym_start < out_elements.size() + ? WHISPER_GRETYPE_CHAR_ALT + : start_type; + + out_elements.push_back({type, char_pair.first}); + if (pos[0] == '-' && pos[1] != ']') { + auto endchar_pair = parse_char(pos + 1); + pos = endchar_pair.second; + out_elements.push_back({WHISPER_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + } + } + pos = parse_space(pos + 1, is_nested); + } else if (is_word_char(*pos)) { // rule reference + const char * name_end = parse_name(pos); + uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); + pos = parse_space(name_end, is_nested); + last_sym_start = out_elements.size(); + out_elements.push_back({WHISPER_GRETYPE_RULE_REF, ref_rule_id}); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule + pos = parse_space(pos + 1, true); + uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); + last_sym_start = out_elements.size(); + // output reference to synthesized rule + out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id}); + if (*pos != ')') { + throw std::runtime_error(std::string("expecting ')' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator + if (last_sym_start == out_elements.size()) { + throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); + } + + // apply transformation to previous symbol (last_sym_start to end) according to + // rewrite rules: + // S* --> S' ::= S S' | + // S+ --> S' ::= S S' | S + // S? --> S' ::= S | + uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + std::vector sub_rule; + // add preceding symbol to generated rule + sub_rule.insert( + sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + if (*pos == '*' || *pos == '+') { + // cause generated rule to recurse + sub_rule.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id}); + } + // mark start of alternate def + sub_rule.push_back({WHISPER_GRETYPE_ALT, 0}); + if (*pos == '+') { + // add preceding symbol as alternate only for '+' (otherwise empty) + sub_rule.insert( + sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + } + sub_rule.push_back({WHISPER_GRETYPE_END, 0}); + add_rule(state, sub_rule_id, sub_rule); + + // in original rule, replace previous symbol with reference to generated rule + out_elements.resize(last_sym_start); + out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id}); + + pos = parse_space(pos + 1, is_nested); + } else { + break; + } + } + return pos; + } + + const char * parse_alternates( + parse_state & state, + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested) { + std::vector rule; + const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); + while (*pos == '|') { + rule.push_back({WHISPER_GRETYPE_ALT, 0}); + pos = parse_space(pos + 1, true); + pos = parse_sequence(state, pos, rule_name, rule, is_nested); + } + rule.push_back({WHISPER_GRETYPE_END, 0}); + add_rule(state, rule_id, rule); + return pos; + } + + const char * parse_rule(parse_state & state, const char * src) { + const char * name_end = parse_name(src); + const char * pos = parse_space(name_end, false); + size_t name_len = name_end - src; + uint32_t rule_id = get_symbol_id(state, src, name_len); + const std::string name(src, name_len); + + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { + throw std::runtime_error(std::string("expecting ::= at ") + pos); + } + pos = parse_space(pos + 3, true); + + pos = parse_alternates(state, pos, name, rule_id, false); + + if (*pos == '\r') { + pos += pos[1] == '\n' ? 2 : 1; + } else if (*pos == '\n') { + pos++; + } else if (*pos) { + throw std::runtime_error(std::string("expecting newline or end at ") + pos); + } + return parse_space(pos, true); + } + + parse_state parse(const char * src) { + try { + parse_state state; + const char * pos = parse_space(src, true); + while (*pos) { + pos = parse_rule(state, pos); + } + return state; + } catch (const std::exception & err) { + fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + return parse_state(); + } + } + + void print_grammar_char(FILE * file, uint32_t c) { + if (0x20 <= c && c <= 0x7f) { + fprintf(file, "%c", static_cast(c)); + } else { + // cop out of encoding UTF-8 + fprintf(file, "", c); + } + } + + bool is_char_element(whisper_grammar_element elem) { + switch (elem.type) { + case WHISPER_GRETYPE_CHAR: return true; + case WHISPER_GRETYPE_CHAR_NOT: return true; + case WHISPER_GRETYPE_CHAR_ALT: return true; + case WHISPER_GRETYPE_CHAR_RNG_UPPER: return true; + default: return false; + } + } + + void print_rule_binary(FILE * file, const std::vector & rule) { + for (auto elem : rule) { + switch (elem.type) { + case WHISPER_GRETYPE_END: fprintf(file, "END"); break; + case WHISPER_GRETYPE_ALT: fprintf(file, "ALT"); break; + case WHISPER_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; + case WHISPER_GRETYPE_CHAR: fprintf(file, "CHAR"); break; + case WHISPER_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; + case WHISPER_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; + case WHISPER_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + } + switch (elem.type) { + case WHISPER_GRETYPE_END: + case WHISPER_GRETYPE_ALT: + case WHISPER_GRETYPE_RULE_REF: + fprintf(file, "(%u) ", elem.value); + break; + case WHISPER_GRETYPE_CHAR: + case WHISPER_GRETYPE_CHAR_NOT: + case WHISPER_GRETYPE_CHAR_RNG_UPPER: + case WHISPER_GRETYPE_CHAR_ALT: + fprintf(file, "(\""); + print_grammar_char(file, elem.value); + fprintf(file, "\") "); + break; + } + } + fprintf(file, "\n"); + } + + void print_rule( + FILE * file, + uint32_t rule_id, + const std::vector & rule, + const std::map & symbol_id_names) { + if (rule.empty() || rule.back().type != WHISPER_GRETYPE_END) { + throw std::runtime_error( + "malformed rule, does not end with WHISPER_GRETYPE_END: " + std::to_string(rule_id)); + } + fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); + for (size_t i = 0, end = rule.size() - 1; i < end; i++) { + whisper_grammar_element elem = rule[i]; + switch (elem.type) { + case WHISPER_GRETYPE_END: + throw std::runtime_error( + "unexpected end of rule: " + std::to_string(rule_id) + "," + + std::to_string(i)); + case WHISPER_GRETYPE_ALT: + fprintf(file, "| "); + break; + case WHISPER_GRETYPE_RULE_REF: + fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); + break; + case WHISPER_GRETYPE_CHAR: + fprintf(file, "["); + print_grammar_char(file, elem.value); + break; + case WHISPER_GRETYPE_CHAR_NOT: + fprintf(file, "[^"); + print_grammar_char(file, elem.value); + break; + case WHISPER_GRETYPE_CHAR_RNG_UPPER: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "WHISPER_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + fprintf(file, "-"); + print_grammar_char(file, elem.value); + break; + case WHISPER_GRETYPE_CHAR_ALT: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "WHISPER_GRETYPE_CHAR_ALT without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + print_grammar_char(file, elem.value); + break; + } + if (is_char_element(elem)) { + switch (rule[i + 1].type) { + case WHISPER_GRETYPE_CHAR_ALT: + case WHISPER_GRETYPE_CHAR_RNG_UPPER: + break; + default: + fprintf(file, "] "); + } + } + } + fprintf(file, "\n"); + } + + void print_grammar(FILE * file, const parse_state & state) { + try { + std::map symbol_id_names; + for (auto kv : state.symbol_ids) { + symbol_id_names[kv.second] = kv.first; + } + for (size_t i = 0, end = state.rules.size(); i < end; i++) { + // fprintf(file, "%zu: ", i); + // print_rule_binary(file, state.rules[i]); + print_rule(file, uint32_t(i), state.rules[i], symbol_id_names); + // fprintf(file, "\n"); + } + } catch (const std::exception & err) { + fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); + } + } + + std::vector parse_state::c_rules() { + std::vector ret; + for (const auto & rule : rules) { + ret.push_back(rule.data()); + } + return ret; + } +} diff --git a/examples/grammar-parser.h b/examples/grammar-parser.h new file mode 100644 index 00000000000..ef0ec44174f --- /dev/null +++ b/examples/grammar-parser.h @@ -0,0 +1,29 @@ +// Implements a parser for an extended Backus-Naur form (BNF), producing the +// binary context-free grammar format specified by whisper.h. Supports character +// ranges, grouping, and repetition operators. As an example, a grammar for +// arithmetic might look like: +// +// root ::= expr +// expr ::= term ([-+*/] term)* +// term ::= num | "(" space expr ")" space +// num ::= [0-9]+ space +// space ::= [ \t\n]* + +#pragma once +#include "whisper.h" +#include +#include +#include +#include + +namespace grammar_parser { + struct parse_state { + std::map symbol_ids; + std::vector> rules; + + std::vector c_rules(); + }; + + parse_state parse(const char * src); + void print_grammar(FILE * file, const parse_state & state); +} diff --git a/whisper.cpp b/whisper.cpp index 1f4f8a06121..9a9f5fdd17b 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -586,6 +586,25 @@ struct whisper_model { std::map tensors; }; +struct whisper_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct whisper_grammar { + /*const*/ std::vector> rules; + std::vector> stacks; + + // buffer for partially generated UTF-8 sequence from accepted tokens + whisper_partial_utf8 partial_utf8; +}; + +struct whisper_grammar_candidate { + whisper_token id; + const uint32_t * code_points; + whisper_partial_utf8 partial_utf8; +}; + struct whisper_sequence { std::vector tokens; @@ -607,6 +626,9 @@ struct whisper_decoder { // the currently generated sequence of tokens whisper_sequence sequence; + // grammar parse state of generated sequence of tokens + whisper_grammar grammar; + int seek_delta; // the window shift found so far based on the decoded timestamp tokens bool failed; // has the current segment failed to decode? @@ -3495,6 +3517,422 @@ const char * whisper_print_system_info(void) { return s.c_str(); } +////////////////////////////////// +// Grammar - ported from llama.cpp +////////////////////////////////// + +// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as +// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`. +std::pair, whisper_partial_utf8> decode_utf8( + const char * src, + whisper_partial_utf8 partial_start) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; + const char * pos = src; + std::vector code_points; + uint32_t value = partial_start.value; + int n_remain = partial_start.n_remain; + + // continue previous decode, if applicable + while (*pos != 0 && n_remain > 0) { + uint8_t next_byte = static_cast(*pos); + if ((next_byte >> 6) != 2) { + // invalid sequence, abort + code_points.push_back(0); + return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 }); + } + value = (value << 6) + (next_byte & 0x3F); + ++pos; + --n_remain; + } + + if (partial_start.n_remain > 0 && n_remain == 0) { + code_points.push_back(value); + } + + // decode any subsequent utf-8 sequences, which may end in an incomplete one + while (*pos != 0) { + uint8_t first_byte = static_cast(*pos); + uint8_t highbits = first_byte >> 4; + n_remain = lookup[highbits] - 1; + + if (n_remain < 0) { + // invalid sequence, abort + code_points.clear(); + code_points.push_back(0); + return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain }); + } + + uint8_t mask = (1 << (7 - n_remain)) - 1; + value = first_byte & mask; + ++pos; + while (*pos != 0 && n_remain > 0) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + ++pos; + --n_remain; + } + if (n_remain == 0) { + code_points.push_back(value); + } + } + code_points.push_back(0); + + return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain }); +} + +// returns true iff pos points to the end of one of the definitions of a rule +static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) { + switch (pos->type) { + case WHISPER_GRETYPE_END: return true; // NOLINT + case WHISPER_GRETYPE_ALT: return true; // NOLINT + default: return false; + } +} + +// returns true iff chr satisfies the char range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static std::pair whisper_grammar_match_char( + const whisper_grammar_element * pos, + const uint32_t chr) { + + bool found = false; + bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR; + + WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT + + do { + if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + found = found || (pos->value <= chr && chr <= pos[1].value); + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + found = found || pos->value == chr; + pos += 1; + } + } while (pos->type == WHISPER_GRETYPE_CHAR_ALT); + + return std::make_pair(found == is_positive_char, pos); +} + +// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char +// range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static bool whisper_grammar_match_partial_char( + const whisper_grammar_element * pos, + const whisper_partial_utf8 partial_utf8) { + + bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR; + WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); + + uint32_t partial_value = partial_utf8.value; + int n_remain = partial_utf8.n_remain; + + // invalid sequence or 7-bit char split across 2 bytes (overlong) + if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) { + return false; + } + + // range of possible code points this partial UTF-8 sequence could complete to + uint32_t low = partial_value << (n_remain * 6); + uint32_t high = low | ((1 << (n_remain * 6)) - 1); + + if (low == 0) { + if (n_remain == 2) { + low = 1 << 11; + } else if (n_remain == 3) { + low = 1 << 16; + } + } + + do { + if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + if (pos->value <= high && low <= pos[1].value) { + return is_positive_char; + } + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + if (low <= pos->value && pos->value <= high) { + return is_positive_char; + } + pos += 1; + } + } while (pos->type == WHISPER_GRETYPE_CHAR_ALT); + + return !is_positive_char; +} + + +// transforms a grammar pushdown stack into N possible stacks, all ending +// at a character range (terminal element) +static void whisper_grammar_advance_stack( + const std::vector> & rules, + const std::vector & stack, + std::vector> & new_stacks) { + + if (stack.empty()) { + new_stacks.push_back(stack); + return; + } + + const whisper_grammar_element * pos = stack.back(); + + switch (pos->type) { + case WHISPER_GRETYPE_RULE_REF: { + const size_t rule_id = static_cast(pos->value); + const whisper_grammar_element * subpos = rules[rule_id].data(); + do { + // init new stack without the top (pos) + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!whisper_grammar_is_end_of_sequence(pos + 1)) { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!whisper_grammar_is_end_of_sequence(subpos)) { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + whisper_grammar_advance_stack(rules, new_stack, new_stacks); + while (!whisper_grammar_is_end_of_sequence(subpos)) { + // scan to end of alternate def + subpos++; + } + if (subpos->type == WHISPER_GRETYPE_ALT) { + // there's another alternate def of this rule to process + subpos++; + } else { + break; + } + } while (true); + break; + } + case WHISPER_GRETYPE_CHAR: + case WHISPER_GRETYPE_CHAR_NOT: + new_stacks.push_back(stack); + break; + default: + // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range + // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + WHISPER_ASSERT(false); + } +} + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `whisper_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +static std::vector> whisper_grammar_accept( + const std::vector> & rules, + const std::vector> & stacks, + const uint32_t chr) { + + std::vector> new_stacks; + + for (const auto & stack : stacks) { + if (stack.empty()) { + continue; + } + + auto match = whisper_grammar_match_char(stack.back(), chr); + if (match.first) { + const whisper_grammar_element * pos = match.second; + + // update top of stack to next element, if any + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!whisper_grammar_is_end_of_sequence(pos)) { + new_stack.push_back(pos); + } + whisper_grammar_advance_stack(rules, new_stack, new_stacks); + } + } + + return new_stacks; +} + +static std::vector whisper_grammar_reject_candidates( + const std::vector> & rules, + const std::vector> & stacks, + const std::vector & candidates); + +static std::vector whisper_grammar_reject_candidates_for_stack( + const std::vector> & rules, + const std::vector & stack, + const std::vector & candidates) { + + std::vector rejects; + + if (stack.empty()) { + for (auto tok : candidates) { + if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { + rejects.push_back(tok); + } + } + return rejects; + } + + const whisper_grammar_element * stack_pos = stack.back(); + + std::vector next_candidates; + for (auto tok : candidates) { + if (*tok.code_points == 0) { + // reached end of full codepoints in token, reject iff it ended in a partial sequence + // that cannot satisfy this position in grammar + if (tok.partial_utf8.n_remain != 0 && + !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { + rejects.push_back(tok); + } + } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) { + next_candidates.push_back({ tok.id, tok.code_points + 1, tok.partial_utf8 }); + } else { + rejects.push_back(tok); + } + } + + const auto * stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second; + + // update top of stack to next element, if any + std::vector stack_after(stack.begin(), stack.end() - 1); + if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) { + stack_after.push_back(stack_pos_after); + } + std::vector> next_stacks; + whisper_grammar_advance_stack(rules, stack_after, next_stacks); + + auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates); + for (auto tok : next_rejects) { + rejects.push_back({ tok.id, tok.code_points - 1, tok.partial_utf8 }); + } + + return rejects; +} + +static std::vector whisper_grammar_reject_candidates( + const std::vector> & rules, + const std::vector> & stacks, + const std::vector & candidates) { + if (candidates.empty() || stacks.empty()) { + return std::vector(); + } + + auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); + + for (size_t i = 1, size = stacks.size(); i < size; ++i) { + rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); + } + return rejects; +} + +static struct whisper_grammar whisper_grammar_init( + const whisper_grammar_element ** rules, + size_t n_rules, + size_t i_start_rule) { + const whisper_grammar_element * pos; + + // copy rule definitions into vectors + std::vector> vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({WHISPER_GRETYPE_END, 0}); + } + + // loop over alternates of start rule to build initial stacks + std::vector> stacks; + pos = rules[i_start_rule]; + do { + std::vector stack; + if (!whisper_grammar_is_end_of_sequence(pos)) { + // if alternate is nonempty, add to stack + stack.push_back(pos); + } + whisper_grammar_advance_stack(vec_rules, stack, stacks); + while (!whisper_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == WHISPER_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); + + return { std::move(vec_rules), std::move(stacks), {} }; +} + +static void whisper_suppress_invalid_grammar( + whisper_context & ctx, + const whisper_full_params & params, + std::vector & logits, + const whisper_grammar & grammar) { + + if (grammar.rules.empty() || grammar.stacks.empty()) { + return; + } + + // bool allow_eot = false; + // for (const auto & stack : grammar.stacks) { + // if (stack.empty()) { + // allow_eot = true; + // break; + // } + // } + + std::vector, whisper_partial_utf8>> candidates_decoded; + std::vector candidates_grammar; + + size_t size = logits.size(); + for (whisper_token id = 0; id < size; ++id) { + const std::string & text = ctx.vocab.id_to_token[id]; + if (!text.empty() && text.rfind("[_", 0) != 0) { + candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8)); + candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second }); + } + } + + const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); + for (const auto & reject : rejects) { + if (logits[reject.id] > 0) { + logits[reject.id] /= params.grammar_penalty; + } else { + logits[reject.id] *= params.grammar_penalty; + } + } + // fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size()); +} + +static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) { + if (grammar.rules.empty() || grammar.stacks.empty()) { + return; + } + + // fprintf(stderr, "Accept: '%s'", ctx.vocab.id_to_token[token].c_str()); + + const std::string & text = ctx.vocab.id_to_token[token]; + + if (text.rfind("[_", 0) == 0) { + // fprintf(stderr, " (skipped)\n"); + return; + } + // fprintf(stderr, "\n"); + + // Note terminating 0 in decoded string + const auto decoded = decode_utf8(text.c_str(), grammar.partial_utf8); + const auto & code_points = decoded.first; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, *it); + } + grammar.partial_utf8 = decoded.second; +} + +////////////// +// END grammar +////////////// + //////////////////////////////////////////////////////////////////////////// struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) { @@ -3575,6 +4013,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.logits_filter_callback =*/ nullptr, /*.logits_filter_callback_user_data =*/ nullptr, + + /*.grammar_rules =*/ nullptr, + /*.n_grammar_rules =*/ 0, + /*.i_start_rule =*/ 0, + /*.grammar_penalty =*/ 1000.0f, }; switch (strategy) { @@ -3744,6 +4187,8 @@ static void whisper_process_logits( params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); } + whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar); + // suppress non-speech tokens // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 if (params.suppress_non_speech_tokens) { @@ -4344,6 +4789,13 @@ int whisper_full_with_state( decoder.failed = false; decoder.completed = false; decoder.has_ts = false; + + if (params.grammar_rules != nullptr) { + decoder.grammar = whisper_grammar_init( + params.grammar_rules, params.n_grammar_rules, params.i_start_rule); + } else { + decoder.grammar = {}; + } } // init prompt and kv cache for the current iteration @@ -4526,6 +4978,8 @@ int whisper_full_with_state( has_ts = true; } + whisper_grammar_accept_token(*ctx, decoder.grammar, token.id); + #ifdef WHISPER_DEBUG { const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]"; diff --git a/whisper.h b/whisper.h index 73ab4d799a2..23f61ed5b06 100644 --- a/whisper.h +++ b/whisper.h @@ -96,6 +96,37 @@ extern "C" { void (*close)(void * ctx); } whisper_model_loader; + // grammar element type + enum whisper_gretype { + // end of rule definition + WHISPER_GRETYPE_END = 0, + + // start of alternate definition for rule + WHISPER_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + WHISPER_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + WHISPER_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + WHISPER_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding WHISPER_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + WHISPER_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding WHISPER_GRETYPE_CHAR or + // WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + WHISPER_GRETYPE_CHAR_ALT = 6, + }; + + typedef struct whisper_grammar_element { + enum whisper_gretype type; + uint32_t value; // Unicode code point or rule ID + } whisper_grammar_element; + // Various functions for loading a ggml whisper model. // Allocate (almost) all memory needed for the model. // Return NULL on failure @@ -431,6 +462,11 @@ extern "C" { // called by each decoder to filter obtained logits whisper_logits_filter_callback logits_filter_callback; void * logits_filter_callback_user_data; + + const whisper_grammar_element ** grammar_rules; + size_t n_grammar_rules; + size_t i_start_rule; + float grammar_penalty; }; // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params() From b0306cd5cfc9733e4b8640468bc3a67a831442e2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 6 Sep 2023 14:27:51 +0300 Subject: [PATCH 2/8] build : fix after master merge --- examples/CMakeLists.txt | 1 + examples/talk-llama/llama.cpp | 25 ++++++++++++------------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e0895669767..d91019197d9 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -23,6 +23,7 @@ add_library(${TARGET} STATIC common.cpp common-ggml.h common-ggml.cpp + grammar-parser.cpp ) include(DefaultTargetOptions) diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index 77550faa43c..aecae009d05 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -1164,7 +1164,7 @@ static bool llama_eval_internal( const llama_token * tokens, const int n_tokens, const int n_past, - const int n_threads) { + int n_threads) { // enforce that the first token is BOS if (n_past == 0 && tokens[0] != llama_token_bos()) { @@ -1190,6 +1190,8 @@ static bool llama_eval_internal( const int n_vocab = hparams.n_vocab; const int n_rot = hparams.n_embd/hparams.n_head; + const float eps = 5e-6f; // TODO: take from hparams + auto & mem_per_token = lctx.mem_per_token; auto & buf_compute = lctx.buf_compute; @@ -1204,7 +1206,7 @@ static bool llama_eval_internal( // for big prompts, if BLAS is enabled, it is better to use only one thread // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance ggml_cgraph gf = {}; - gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads; + n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads; struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); ggml_set_name(embd, "embd"); @@ -1221,7 +1223,7 @@ static bool llama_eval_internal( // norm { - cur = ggml_rms_norm(ctx0, inpL); + cur = ggml_rms_norm(ctx0, inpL, eps); // cur = cur*attention_norm(broadcasted) cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm); @@ -1329,7 +1331,7 @@ static bool llama_eval_internal( { // norm { - cur = ggml_rms_norm(ctx0, inpFF); + cur = ggml_rms_norm(ctx0, inpFF, eps); // cur = cur*ffn_norm(broadcasted) cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); @@ -1367,7 +1369,7 @@ static bool llama_eval_internal( // norm { - inpL = ggml_rms_norm(ctx0, inpL); + inpL = ggml_rms_norm(ctx0, inpL, eps); // inpL = inpL*norm(broadcasted) inpL = ggml_mul(ctx0, inpL, model.norm); @@ -1384,8 +1386,8 @@ static bool llama_eval_internal( //inpL = ggml_soft_max_inplace(ctx0, inpL); // run the computation - ggml_build_forward_expand(&gf, inpL); - ggml_graph_compute (ctx0, &gf); + ggml_build_forward_expand (&gf, inpL); + ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); #ifdef GGML_PERF // print timing information per ggml operation (for debugging purposes) @@ -2488,8 +2490,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * } struct ggml_cgraph gf = ggml_build_forward(r); - gf.n_threads = n_threads; - ggml_graph_compute(lora_ctx, &gf); + ggml_graph_compute_with_ctx(lora_ctx, &gf, n_threads); // we won't need these tensors again, reset the context to save memory ggml_free(lora_ctx); @@ -2635,7 +2636,6 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true }); ggml_cgraph gf{}; - gf.n_threads = 1; ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); kout3d->data = out; @@ -2655,7 +2655,7 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); - ggml_graph_compute(cpy_ctx, &gf); + ggml_graph_compute_with_ctx(cpy_ctx, &gf, 1); ggml_free(cpy_ctx); } @@ -2743,7 +2743,6 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true }); ggml_cgraph gf{}; - gf.n_threads = 1; ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); kin3d->data = (void *) inp; @@ -2763,7 +2762,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); - ggml_graph_compute(cpy_ctx, &gf); + ggml_graph_compute_with_ctx(cpy_ctx, &gf, 1); ggml_free(cpy_ctx); } From 97ebb48b990b4e37b76e6c87e3e90347f9464f69 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 6 Sep 2023 15:05:17 +0300 Subject: [PATCH 3/8] command : fix exception when recognizing the command --- examples/command/command.cpp | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/examples/command/command.cpp b/examples/command/command.cpp index e93d8908b6d..85789d35de6 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -587,7 +587,11 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud // find the prompt in the text float best_sim = 0.0f; size_t best_len = 0; - for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) { + for (size_t n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) { + if (n >= txt.size()) { + break; + } + const auto prompt = txt.substr(0, n); const float sim = similarity(prompt, k_prompt); @@ -600,9 +604,15 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud } } - const std::string command = ::trim(txt.substr(best_len)); + if (best_len == 0) { + fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__); + } else { + // cut the prompt from the decoded text + const std::string command = ::trim(txt.substr(best_len)); + + fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms); + } - fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms); fprintf(stdout, "\n"); } From b8f34d1ed786194d9787b1f1d086b89136e361f3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 6 Sep 2023 17:05:05 +0300 Subject: [PATCH 4/8] whisper : fine-tuning grammar functionality --- examples/command/command.cpp | 18 ++++-- whisper.cpp | 118 ++++++++++++++++++++++++----------- 2 files changed, 94 insertions(+), 42 deletions(-) diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 85789d35de6..f33f8e15ffb 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -31,8 +31,9 @@ struct whisper_params { int32_t max_tokens = 32; int32_t audio_ctx = 0; - float vad_thold = 0.6f; - float freq_thold = 100.0f; + float vad_thold = 0.6f; + float freq_thold = 100.0f; + float grammar_penalty = 100.0f; bool speed_up = false; @@ -138,6 +139,9 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con wparams.language = params.language.c_str(); wparams.n_threads = params.n_threads; + // disable fallback - seems not useful for command recognition + wparams.temperature_inc = 0.0f; + wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; @@ -508,7 +512,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi // general-purpose mode // freely transcribe the voice into text -int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) { +int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) { bool is_running = true; bool have_prompt = false; bool ask_prompt = true; @@ -519,7 +523,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud std::vector pcmf32_cur; std::vector pcmf32_prompt; - const std::string k_prompt = "Ok Whisper, start listening for commands."; + //const std::string k_prompt = "Ok Whisper, start listening for commands."; + //const std::string k_prompt = "Начало."; + const std::string k_prompt = "Добре Уиспър, започни да слушаш за команди."; fprintf(stderr, "\n"); fprintf(stderr, "%s: general-purpose mode\n", __func__); @@ -578,6 +584,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud // prepend the prompt audio pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); + // append 1 second of silence + pcmf32_cur.insert(pcmf32_cur.end(), 1000*WHISPER_SAMPLE_RATE/1000, 0.0f); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); prob = 100.0f*(prob - prob0); @@ -604,6 +613,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud } } + fprintf(stdout, "%s: DEBUG: txt = '%s'\n", __func__, txt.c_str()); if (best_len == 0) { fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__); } else { diff --git a/whisper.cpp b/whisper.cpp index 078841b391c..5e3b86a88b8 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3865,7 +3865,7 @@ static struct whisper_grammar whisper_grammar_init( static void whisper_suppress_invalid_grammar( whisper_context & ctx, const whisper_full_params & params, - std::vector & logits, + std::vector & logprobs, const whisper_grammar & grammar) { if (grammar.rules.empty() || grammar.stacks.empty()) { @@ -3883,8 +3883,8 @@ static void whisper_suppress_invalid_grammar( std::vector, whisper_partial_utf8>> candidates_decoded; std::vector candidates_grammar; - size_t size = logits.size(); - for (whisper_token id = 0; id < size; ++id) { + size_t size = logprobs.size(); + for (whisper_token id = 0; id < (int) size; ++id) { const std::string & text = ctx.vocab.id_to_token[id]; if (!text.empty() && text.rfind("[_", 0) != 0) { candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8)); @@ -3893,14 +3893,18 @@ static void whisper_suppress_invalid_grammar( } const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); + for (const auto & reject : rejects) { - if (logits[reject.id] > 0) { - logits[reject.id] /= params.grammar_penalty; - } else { - logits[reject.id] *= params.grammar_penalty; - } + logprobs[reject.id] -= params.grammar_penalty; } - // fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size()); + + // when the grammar does not allow any continuation, we don't want to penalize the EOT token + // TODO: is there are better way to do this? + printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2); + if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) { + logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty; + } + //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size()); } static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) { @@ -3908,10 +3912,10 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar return; } - // fprintf(stderr, "Accept: '%s'", ctx.vocab.id_to_token[token].c_str()); + fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str()); const std::string & text = ctx.vocab.id_to_token[token]; - + if (text.rfind("[_", 0) == 0) { // fprintf(stderr, " (skipped)\n"); return; @@ -4015,7 +4019,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.grammar_rules =*/ nullptr, /*.n_grammar_rules =*/ 0, /*.i_start_rule =*/ 0, - /*.grammar_penalty =*/ 1000.0f, + /*.grammar_penalty =*/ 100.0f, }; switch (strategy) { @@ -4181,12 +4185,18 @@ static void whisper_process_logits( logits[vocab.token_translate] = -INFINITY; logits[vocab.token_transcribe] = -INFINITY; + // suppress lang tokens + for (size_t i = 0; i < g_lang.size(); ++i) { + logits[whisper_token_lang(&ctx, i)] = -INFINITY; + } + + // suppress prev token + logits[vocab.token_prev] = -INFINITY; + if (params.logits_filter_callback) { params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); } - whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar); - // suppress non-speech tokens // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 if (params.suppress_non_speech_tokens) { @@ -4293,10 +4303,19 @@ static void whisper_process_logits( //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); if (timestamp_logprob > max_text_token_logprob) { + //printf("sampling timestamp\n"); for (int i = 0; i < vocab.token_beg; ++i) { logits[i] = -INFINITY; logprobs[i] = -INFINITY; } + } else { + //printf("sampling text\n"); + for (int i = vocab.token_beg; i < n_logits; ++i) { + logits[i] = -INFINITY; + logprobs[i] = -INFINITY; + } + + whisper_suppress_invalid_grammar(ctx, params, logprobs, decoder.grammar); } } } @@ -4312,34 +4331,57 @@ static void whisper_process_logits( } } -#if 0 +#if 1 // print first 100 logits - token string : logit - for (int i = 0; i < 100; i++) { - const auto token = vocab.id_to_token.at(i); - const auto prob = probs[i]; - const auto logit = logits[i]; - const auto logprob = logprobs[i]; - printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); + //for (int i = 0; i < 10; i++) { + // const auto token = vocab.id_to_token.at(i); + // const auto prob = probs[i]; + // const auto logit = logits[i]; + // const auto logprob = logprobs[i]; + // printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); + //} + + // print sorted + { + std::vector> pairs; + + for (int i = 0; i < n_logits; ++i) { + pairs.push_back(std::make_pair(probs[i], i)); + } + + std::sort(pairs.begin(), pairs.end(), [](const std::pair& a, const std::pair& b) { + return a.first > b.first; + }); + + for (int i = 0; i < 10; i++) { + const auto token = vocab.id_to_token.at(pairs[i].second); + const auto prob = pairs[i].first; + const auto logit = logits[pairs[i].second]; + const auto logprob = logprobs[pairs[i].second]; + printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str()); + } + + printf("----------------\n"); } // "And", "and", " And", " and" - printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); - printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); - printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); - printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); - printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); - - printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); - printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); - printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); - printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); - printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); - - printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); - printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); - printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); - printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); - printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); + //printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); + //printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); + //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); + //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); + //printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); + + //printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); + //printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); + //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); + //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); + //printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); + + //printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); + //printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); + //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); + //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); + //printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); #endif } From 54d168db67f63344ee3cd3bfc0137960d5f4d3fa Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Sep 2023 17:58:54 +0300 Subject: [PATCH 5/8] command : grammar-related improvements - option to read grammar from file - add sample grammars for colors and chess moves - fine-tune the performance further --- examples/command/command.cpp | 77 ++++++++++++++++++++++++------------ examples/grammar-parser.cpp | 2 +- examples/grammar-parser.h | 2 +- grammars/chess.gbnf | 27 +++++++++++++ grammars/colors.gbnf | 24 +++++++++++ whisper.cpp | 65 ++++++++++++++++++------------ 6 files changed, 144 insertions(+), 53 deletions(-) create mode 100644 grammars/chess.gbnf create mode 100644 grammars/colors.gbnf diff --git a/examples/command/command.cpp b/examples/command/command.cpp index f33f8e15ffb..dfffbda72bd 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -22,6 +22,11 @@ #include #include +bool file_exists(const std::string & fname) { + std::ifstream f(fname.c_str()); + return f.good(); +} + // command-line parameters struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); @@ -36,6 +41,8 @@ struct whisper_params { float grammar_penalty = 100.0f; + grammar_parser::parse_state grammar_parsed; + bool speed_up = false; bool translate = false; bool print_special = false; @@ -117,15 +124,18 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "\n"); } -std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector & pcmf32, float & prob, int64_t & t_ms) { +std::string transcribe( + whisper_context * ctx, + const whisper_params & params, + const std::vector & pcmf32, + const std::string & grammar_rule, + float & prob, + int64_t & t_ms) { const auto t_start = std::chrono::high_resolution_clock::now(); prob = 0.0f; t_ms = 0; - grammar_parser::parse_state parsed_grammar; - std::vector grammar_rules; - whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); wparams.print_progress = false; @@ -140,17 +150,20 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con wparams.n_threads = params.n_threads; // disable fallback - seems not useful for command recognition - wparams.temperature_inc = 0.0f; + wparams.temperature_inc = 0.00f; - wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; + wparams.audio_ctx = params.audio_ctx; + wparams.speed_up = params.speed_up; - if (!params.grammar.empty()) { - parsed_grammar = grammar_parser::parse(params.grammar.c_str()); - grammar_rules = parsed_grammar.c_rules(); + //wparams.initial_prompt = params.prompt.data(); + + const auto & grammar_parsed = params.grammar_parsed; + auto grammar_rules = grammar_parsed.c_rules(); + + if (!params.grammar_parsed.rules.empty()) { wparams.grammar_rules = grammar_rules.data(); wparams.n_grammar_rules = grammar_rules.size(); - wparams.i_start_rule = parsed_grammar.symbol_ids.at("root"); + wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule); wparams.grammar_penalty = params.grammar_penalty; } @@ -270,7 +283,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const fprintf(stderr, " ]\n"); } - std::string k_prompt = "select one from the available words: "; + std::string k_prompt = "select one from the available words: "; for (int i = 0; i < (int) allowed_commands.size(); ++i) { if (i > 0) { k_prompt += ", "; @@ -476,7 +489,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi // detect the commands audio.get(params.command_ms, pcmf32_cur); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", prob, t_ms)); const auto words = get_words(txt); @@ -523,9 +536,10 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au std::vector pcmf32_cur; std::vector pcmf32_prompt; - //const std::string k_prompt = "Ok Whisper, start listening for commands."; - //const std::string k_prompt = "Начало."; - const std::string k_prompt = "Добре Уиспър, започни да слушаш за команди."; + std::string k_prompt = "Ok Whisper, start listening for commands."; + if (!params.prompt.empty()) { + k_prompt = params.prompt; + } fprintf(stderr, "\n"); fprintf(stderr, "%s: general-purpose mode\n", __func__); @@ -558,7 +572,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au // wait for activation phrase audio.get(params.prompt_ms, pcmf32_cur); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms)); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", prob0, t_ms)); fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); @@ -581,13 +595,16 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au // we have heard the activation phrase, now detect the commands audio.get(params.command_ms, pcmf32_cur); + //printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE); + //printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE); + + // prepend 3 second of silence + pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f); + // prepend the prompt audio pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); - // append 1 second of silence - pcmf32_cur.insert(pcmf32_cur.end(), 1000*WHISPER_SAMPLE_RATE/1000, 0.0f); - - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", prob, t_ms)); prob = 100.0f*(prob - prob0); @@ -688,13 +705,23 @@ int main(int argc, char ** argv) { int ret_val = 0; if (!params.grammar.empty()) { - auto parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + auto & grammar = params.grammar_parsed; + if (file_exists(params.grammar.c_str())) { + // read grammar from file + std::ifstream ifs(params.grammar.c_str()); + const std::string txt = std::string((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); + grammar = grammar_parser::parse(txt.c_str()); + } else { + // read grammar from string + grammar = grammar_parser::parse(params.grammar.c_str()); + } + // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) { + if (grammar.rules.empty()) { ret_val = 1; } else { fprintf(stderr, "%s: grammar:\n", __func__); - grammar_parser::print_grammar(stderr, parsed_grammar); + grammar_parser::print_grammar(stderr, grammar); fprintf(stderr, "\n"); } } @@ -702,7 +729,7 @@ int main(int argc, char ** argv) { if (ret_val == 0) { if (!params.commands.empty()) { ret_val = process_command_list(ctx, audio, params); - } else if (!params.prompt.empty()) { + } else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) { ret_val = always_prompt_transcription(ctx, audio, params); } else { ret_val = process_general_transcription(ctx, audio, params); diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp index b5b607fa9d0..2daaaef4504 100644 --- a/examples/grammar-parser.cpp +++ b/examples/grammar-parser.cpp @@ -413,7 +413,7 @@ namespace grammar_parser { } } - std::vector parse_state::c_rules() { + std::vector parse_state::c_rules() const{ std::vector ret; for (const auto & rule : rules) { ret.push_back(rule.data()); diff --git a/examples/grammar-parser.h b/examples/grammar-parser.h index ef0ec44174f..47d019c33e1 100644 --- a/examples/grammar-parser.h +++ b/examples/grammar-parser.h @@ -21,7 +21,7 @@ namespace grammar_parser { std::map symbol_ids; std::vector> rules; - std::vector c_rules(); + std::vector c_rules() const; }; parse_state parse(const char * src); diff --git a/grammars/chess.gbnf b/grammars/chess.gbnf new file mode 100644 index 00000000000..122ce1239e2 --- /dev/null +++ b/grammars/chess.gbnf @@ -0,0 +1,27 @@ +# - bishop to c3 +# - rook to d4 +# - knight to e5 +# - d4 d5 knight to c3 +# - c3 queen to d4 king b1 +# - pawn to a1 bishop to b2 knight to c3 +# +# initial prompt: +# +# "pawn to a1, bishop to b2, knight to c3, rook to d4, queen to e5, king to f6," +# +# example: +# +# ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "pawn knight king a1 f5 h6" +# + +root ::= init (move? move? move? ".") +prompt ::= init "." + +# leading space is very important! +init ::= " pawn knight king a1 f5 h6" + +move ::= " " ((piece | pawn | king) " " "to "?)? [a-h] [1-8] + +piece ::= "bishop" | "rook" | "knight" | "queen" +king ::= "king" +pawn ::= "pawn" diff --git a/grammars/colors.gbnf b/grammars/colors.gbnf new file mode 100644 index 00000000000..f4a4930ca62 --- /dev/null +++ b/grammars/colors.gbnf @@ -0,0 +1,24 @@ +# - red +# - green +# - blue +# - red green +# - red blue +# - green red +# - green blue green +# +# initial prompt: +# +# "red green blue" +# +# example: +# +# ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red green blue" +# + +root ::= init color (color)? (color)? "." +prompt ::= init "." + +# leading space is very important! +init ::= " red green blue" + +color ::= " " ("red" | "green" | "blue") diff --git a/whisper.cpp b/whisper.cpp index 5e3b86a88b8..c357994b7f5 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3865,28 +3865,29 @@ static struct whisper_grammar whisper_grammar_init( static void whisper_suppress_invalid_grammar( whisper_context & ctx, const whisper_full_params & params, - std::vector & logprobs, + std::vector & logits, const whisper_grammar & grammar) { if (grammar.rules.empty() || grammar.stacks.empty()) { return; } - // bool allow_eot = false; - // for (const auto & stack : grammar.stacks) { - // if (stack.empty()) { - // allow_eot = true; - // break; - // } - // } + bool allow_eot = false; + for (const auto & stack : grammar.stacks) { + if (stack.empty()) { + allow_eot = true; + break; + } + } + + const whisper_token eot = whisper_token_eot(&ctx); std::vector, whisper_partial_utf8>> candidates_decoded; std::vector candidates_grammar; - size_t size = logprobs.size(); - for (whisper_token id = 0; id < (int) size; ++id) { + for (whisper_token id = 0; id < eot; ++id) { const std::string & text = ctx.vocab.id_to_token[id]; - if (!text.empty() && text.rfind("[_", 0) != 0) { + if (!text.empty()) { candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8)); candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } @@ -3895,14 +3896,12 @@ static void whisper_suppress_invalid_grammar( const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); for (const auto & reject : rejects) { - logprobs[reject.id] -= params.grammar_penalty; + logits[reject.id] -= params.grammar_penalty; } - // when the grammar does not allow any continuation, we don't want to penalize the EOT token - // TODO: is there are better way to do this? - printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2); - if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) { - logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty; + // when the grammar allows a continuation, we penalize the end-of-text token + if (!allow_eot) { + logits[eot] -= params.grammar_penalty; } //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size()); } @@ -3912,7 +3911,7 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar return; } - fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str()); + //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str()); const std::string & text = ctx.vocab.id_to_token[token]; @@ -4308,14 +4307,28 @@ static void whisper_process_logits( logits[i] = -INFINITY; logprobs[i] = -INFINITY; } - } else { - //printf("sampling text\n"); - for (int i = vocab.token_beg; i < n_logits; ++i) { - logits[i] = -INFINITY; - logprobs[i] = -INFINITY; - } + } else if (params.n_grammar_rules > 0) { + whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar); + + // populate the logprobs array (log_softmax) + { + const float logit_max = *std::max_element(logits.begin(), logits.end()); + float logsumexp = 0.0f; + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logsumexp += expf(logits[i] - logit_max); + } + } + logsumexp = logf(logsumexp) + logit_max; - whisper_suppress_invalid_grammar(ctx, params, logprobs, decoder.grammar); + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logprobs[i] = logits[i] - logsumexp; + } else { + logprobs[i] = -INFINITY; + } + } + } } } } @@ -4331,7 +4344,7 @@ static void whisper_process_logits( } } -#if 1 +#if 0 // print first 100 logits - token string : logit //for (int i = 0; i < 10; i++) { // const auto token = vocab.id_to_token.at(i); From 7a2abb311db8cc2c373eb606947b8ac8ebfa39df Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Sep 2023 20:24:58 +0300 Subject: [PATCH 6/8] grammars : add assistant + update comments --- examples/command/command.cpp | 4 +-- grammars/assistant.gbnf | 57 ++++++++++++++++++++++++++++++++++++ grammars/chess.gbnf | 4 --- grammars/colors.gbnf | 4 --- 4 files changed, 58 insertions(+), 11 deletions(-) create mode 100644 grammars/assistant.gbnf diff --git a/examples/command/command.cpp b/examples/command/command.cpp index dfffbda72bd..ef5b09be961 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -137,6 +137,7 @@ std::string transcribe( t_ms = 0; whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + //whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH); wparams.print_progress = false; wparams.print_special = params.print_special; @@ -149,9 +150,6 @@ std::string transcribe( wparams.language = params.language.c_str(); wparams.n_threads = params.n_threads; - // disable fallback - seems not useful for command recognition - wparams.temperature_inc = 0.00f; - wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; diff --git a/grammars/assistant.gbnf b/grammars/assistant.gbnf new file mode 100644 index 00000000000..c7432672008 --- /dev/null +++ b/grammars/assistant.gbnf @@ -0,0 +1,57 @@ +# - "turn on lights." +# - "set thermostat to 22." +# - "increase TV by 10." +# - "decrease oven by 50." +# - "play music." +# - "stop podcast." +# - "schedule cleaning at 3pm." +# - "cancel cleaning." +# - "remind me to buy milk at 5pm." +# - "show me security system." +# - "hide washing machine." +# - "what is the lights status?" +# - "what is the current thermostat value?" +# - "what is the security system status?" +# - "what is the door lock status?" +# - "what is the camera battery level?" +# - "what is the weather like today?" +# - "what is the forecast for tomorrow?" +# - "what is the time?" +# - "what is my schedule for today?" +# - "what tasks do I have?" +# - "what reminders do I have?" +# +# example: +# +# ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "Ok Whisper, start listening for commands" +# + +root ::= init " " (command | question) "." +prompt ::= init "." + +# leading space is very important! +init ::= " Ok Whisper, start listening for commands" + +command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value | + "Increase " device " by " value | "Decrease " device " by " value | + "Play " media | "Stop " media | "Schedule " task " at " time | "Cancel " task | + "Remind me to " task " at " time | "Show me " device | "Hide " device + +question ::= "What is the " device " status?" | "What is the current " device " value?" | + "What is the " device " temperature?" | "What is the " device " humidity?" | + "What is the " device " power consumption?" | "What is the " device " battery level?" | + "What is the weather like today?" | "What is the forecast for tomorrow?" | + "What is the time?" | "What is my schedule for today?" | "What tasks do I have?" | + "What reminders do I have?" + +device ::= "lights" | "thermostat" | "security system" | "door lock" | "camera" | "speaker" | "TV" | + "music player" | "coffee machine" | "oven" | "refrigerator" | "washing machine" | + "vacuum cleaner" + +value ::= [0-9]+ + +media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie" + +task ::= [a-zA-Z]+ (" " [a-zA-Z]+)? + +time ::= [0-9] [0-9]? ":" [0-9] [0-9] ("am" | "pm")? diff --git a/grammars/chess.gbnf b/grammars/chess.gbnf index 122ce1239e2..9a3ca1d5ff5 100644 --- a/grammars/chess.gbnf +++ b/grammars/chess.gbnf @@ -5,10 +5,6 @@ # - c3 queen to d4 king b1 # - pawn to a1 bishop to b2 knight to c3 # -# initial prompt: -# -# "pawn to a1, bishop to b2, knight to c3, rook to d4, queen to e5, king to f6," -# # example: # # ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "pawn knight king a1 f5 h6" diff --git a/grammars/colors.gbnf b/grammars/colors.gbnf index f4a4930ca62..039c83d96fa 100644 --- a/grammars/colors.gbnf +++ b/grammars/colors.gbnf @@ -6,10 +6,6 @@ # - green red # - green blue green # -# initial prompt: -# -# "red green blue" -# # example: # # ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red green blue" From 37de5dcf2bd8181f4643c2d44980484bed61a966 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Sep 2023 12:22:57 +0300 Subject: [PATCH 7/8] command : enable beam-search, add "no_timestamps", add "context", add p --- .gitignore | 1 + examples/command/command.cpp | 70 +++++++++++++++++++++++------------- grammars/assistant.gbnf | 8 ++--- grammars/chess.gbnf | 14 +++++--- grammars/colors.gbnf | 12 +++---- whisper.cpp | 45 ++++++++++++++++------- whisper.h | 1 + 7 files changed, 98 insertions(+), 53 deletions(-) diff --git a/.gitignore b/.gitignore index a1adabaf40b..b30a1d19f01 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ build/ build-em/ build-debug/ build-release/ +build-rwdi/ build-static/ build-cublas/ build-no-accel/ diff --git a/examples/command/command.cpp b/examples/command/command.cpp index ef5b09be961..d5612bd0a00 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -54,6 +54,7 @@ struct whisper_params { std::string fname_out; std::string commands; std::string prompt; + std::string context; std::string grammar; }; @@ -84,6 +85,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; } else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; } + else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; } else if ( arg == "--grammar") { params.grammar = argv[++i]; } else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); } else { @@ -119,6 +121,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str()); fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str()); + fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str()); fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); fprintf(stderr, "\n"); @@ -129,15 +132,19 @@ std::string transcribe( const whisper_params & params, const std::vector & pcmf32, const std::string & grammar_rule, - float & prob, + float & logprob_min, + float & logprob_sum, + int & n_tokens, int64_t & t_ms) { const auto t_start = std::chrono::high_resolution_clock::now(); - prob = 0.0f; + logprob_min = 0.0f; + logprob_sum = 0.0f; + n_tokens = 0; t_ms = 0; - whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - //whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH); + //whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH); wparams.print_progress = false; wparams.print_special = params.print_special; @@ -145,6 +152,7 @@ std::string transcribe( wparams.print_timestamps = !params.no_timestamps; wparams.translate = params.translate; wparams.no_context = true; + wparams.no_timestamps = params.no_timestamps; wparams.single_segment = true; wparams.max_tokens = params.max_tokens; wparams.language = params.language.c_str(); @@ -153,12 +161,18 @@ std::string transcribe( wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; - //wparams.initial_prompt = params.prompt.data(); + wparams.temperature = 0.4f; + wparams.temperature_inc = 1.0f; + wparams.greedy.best_of = 5; + + wparams.beam_search.beam_size = 5; + + wparams.initial_prompt = params.context.data(); const auto & grammar_parsed = params.grammar_parsed; auto grammar_rules = grammar_parsed.c_rules(); - if (!params.grammar_parsed.rules.empty()) { + if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) { wparams.grammar_rules = grammar_rules.data(); wparams.n_grammar_rules = grammar_rules.size(); wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule); @@ -169,7 +183,6 @@ std::string transcribe( return ""; } - int prob_n = 0; std::string result; const int n_segments = whisper_full_n_segments(ctx); @@ -178,19 +191,17 @@ std::string transcribe( result += text; - const int n_tokens = whisper_full_n_tokens(ctx, i); - for (int j = 0; j < n_tokens; ++j) { + const int n = whisper_full_n_tokens(ctx, i); + for (int j = 0; j < n; ++j) { const auto token = whisper_full_get_token_data(ctx, i, j); - prob += token.p; - ++prob_n; + if(token.plog > 0.0f) exit(0); + logprob_min = std::min(logprob_min, token.plog); + logprob_sum += token.plog; + ++n_tokens; } } - if (prob_n > 0) { - prob /= prob_n; - } - const auto t_end = std::chrono::high_resolution_clock::now(); t_ms = std::chrono::duration_cast(t_end - t_start).count(); @@ -449,7 +460,9 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi bool is_running = true; bool ask_prompt = true; - float prob = 0.0f; + float logprob_min = 0.0f; + float logprob_sum = 0.0f; + int n_tokens = 0; std::vector pcmf32_cur; @@ -487,7 +500,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi // detect the commands audio.get(params.command_ms, pcmf32_cur); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", prob, t_ms)); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms)); const auto words = get_words(txt); @@ -528,8 +541,14 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au bool have_prompt = false; bool ask_prompt = true; - float prob0 = 0.0f; - float prob = 0.0f; + float logprob_min0 = 0.0f; + float logprob_min = 0.0f; + + float logprob_sum0 = 0.0f; + float logprob_sum = 0.0f; + + int n_tokens0 = 0; + int n_tokens = 0; std::vector pcmf32_cur; std::vector pcmf32_prompt; @@ -570,9 +589,11 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au // wait for activation phrase audio.get(params.prompt_ms, pcmf32_cur); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", prob0, t_ms)); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms)); + + const float p = 100.0f * std::exp(logprob_min0); - fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); + fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p); const float sim = similarity(txt, k_prompt); @@ -602,9 +623,10 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au // prepend the prompt audio pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", prob, t_ms)); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms)); - prob = 100.0f*(prob - prob0); + //const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0)); + const float p = 100.0f * std::exp(logprob_min); //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str()); @@ -628,7 +650,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au } } - fprintf(stdout, "%s: DEBUG: txt = '%s'\n", __func__, txt.c_str()); + fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p); if (best_len == 0) { fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__); } else { diff --git a/grammars/assistant.gbnf b/grammars/assistant.gbnf index c7432672008..c445778a11d 100644 --- a/grammars/assistant.gbnf +++ b/grammars/assistant.gbnf @@ -23,14 +23,14 @@ # # example: # -# ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "Ok Whisper, start listening for commands" +# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands." --context "Whisper is a home assistant. It recognizes voice commands. Time is 11pm." --grammar-penalty 10 # root ::= init " " (command | question) "." -prompt ::= init "." +prompt ::= init # leading space is very important! -init ::= " Ok Whisper, start listening for commands" +init ::= " Ok Whisper, start listening for commands." command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value | "Increase " device " by " value | "Decrease " device " by " value | @@ -54,4 +54,4 @@ media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie" task ::= [a-zA-Z]+ (" " [a-zA-Z]+)? -time ::= [0-9] [0-9]? ":" [0-9] [0-9] ("am" | "pm")? +time ::= [0-9] [0-9]? ("am" | "pm")? diff --git a/grammars/chess.gbnf b/grammars/chess.gbnf index 9a3ca1d5ff5..ec8c8423c85 100644 --- a/grammars/chess.gbnf +++ b/grammars/chess.gbnf @@ -5,18 +5,24 @@ # - c3 queen to d4 king b1 # - pawn to a1 bishop to b2 knight to c3 # +# The prompt (--prompt) is the initial phrase that the user has to say. +# This is used to prime Whisper with how the user is expected to speak. +# +# Provide long context (--context) with sample moves to help Whisper decode the correct sequence. +# Longer context is better, but it slightly increases the processing time. +# # example: # -# ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "pawn knight king a1 f5 h6" +# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "rook to b4, f3," --context "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8," --grammar-penalty 100 # -root ::= init (move? move? move? ".") +root ::= init move move? move? "." prompt ::= init "." # leading space is very important! -init ::= " pawn knight king a1 f5 h6" +init ::= " rook to b4, f3" -move ::= " " ((piece | pawn | king) " " "to "?)? [a-h] [1-8] +move ::= ", " ((piece | pawn | king) " " "to "?)? [a-h] [1-8] piece ::= "bishop" | "rook" | "knight" | "queen" king ::= "king" diff --git a/grammars/colors.gbnf b/grammars/colors.gbnf index 039c83d96fa..1d9945054b0 100644 --- a/grammars/colors.gbnf +++ b/grammars/colors.gbnf @@ -1,20 +1,16 @@ # - red # - green # - blue -# - red green -# - red blue -# - green red -# - green blue green # # example: # -# ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red green blue" +# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red, green, blue," --context "green, red, blue," # -root ::= init color (color)? (color)? "." +root ::= init color "." prompt ::= init "." # leading space is very important! -init ::= " red green blue" +init ::= " red, green, blue" -color ::= " " ("red" | "green" | "blue") +color ::= ", " ("red" | "green" | "blue") diff --git a/whisper.cpp b/whisper.cpp index c357994b7f5..0f5fef83be3 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3872,13 +3872,13 @@ static void whisper_suppress_invalid_grammar( return; } - bool allow_eot = false; - for (const auto & stack : grammar.stacks) { - if (stack.empty()) { - allow_eot = true; - break; - } - } + //bool allow_eot = false; + //for (const auto & stack : grammar.stacks) { + // if (stack.empty()) { + // allow_eot = true; + // break; + // } + //} const whisper_token eot = whisper_token_eot(&ctx); @@ -3900,9 +3900,9 @@ static void whisper_suppress_invalid_grammar( } // when the grammar allows a continuation, we penalize the end-of-text token - if (!allow_eot) { - logits[eot] -= params.grammar_penalty; - } + //if (!allow_eot) { + // logits[eot] -= params.grammar_penalty; + //} //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size()); } @@ -3955,6 +3955,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.translate =*/ false, /*.no_context =*/ true, + /*.no_timestamps =*/ false, /*.single_segment =*/ false, /*.print_special =*/ false, /*.print_progress =*/ true, @@ -4170,6 +4171,11 @@ static void whisper_process_logits( // suppress <|notimestamps|> token // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 logits[vocab.token_not] = -INFINITY; + if (params.no_timestamps) { + for (int i = vocab.token_beg; i < n_logits; ++i) { + logits[i] = -INFINITY; + } + } // suppress sot and nosp tokens logits[vocab.token_sot] = -INFINITY; @@ -4515,8 +4521,11 @@ static std::vector whisper_sample_token_topk( ptsum = sum_ts; } + std::discrete_distribution<> dist(probs.begin(), probs.end()); + for (int i = 0; i < k; ++i) { - const auto id = logits_id[i].second; + const auto id = dist(state.rng); + //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum); result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, }); @@ -4726,7 +4735,7 @@ int whisper_full_with_state( state->exp_n_audio_ctx = params.audio_ctx; // these tokens determine the task that will be performed - std::vector prompt_init = { whisper_token_sot(ctx) }; + std::vector prompt_init = { whisper_token_sot(ctx), }; if (whisper_is_multilingual(ctx)) { const int lang_id = whisper_lang_id(params.language); state->lang_id = lang_id; @@ -4737,6 +4746,9 @@ int whisper_full_with_state( prompt_init.push_back(whisper_token_transcribe(ctx)); } } + if (params.no_timestamps) { + prompt_init.push_back(whisper_token_not(ctx)); + } int seek = seek_start; @@ -4821,7 +4833,7 @@ int whisper_full_with_state( n_decoders_cur = std::max(1, n_decoders_cur); - WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur); + WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur); // TAGS: WHISPER_DECODER_INIT for (int j = 0; j < n_decoders_cur; ++j) { @@ -4978,8 +4990,15 @@ int whisper_full_with_state( continue; } + if (cur_c >= beam_candidates.size()) { + cur_c = 0; + } + auto & cur = beam_candidates[cur_c++]; + // TODO: test if this is better: + //while (beam_candidates.size() > cur_c && i > 0) { + while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) { ++cur_c; } diff --git a/whisper.h b/whisper.h index 23f61ed5b06..fe50e73fb2b 100644 --- a/whisper.h +++ b/whisper.h @@ -389,6 +389,7 @@ extern "C" { bool translate; bool no_context; // do not use past transcription (if any) as initial prompt for the decoder + bool no_timestamps; // do not generate timestamps bool single_segment; // force single segment output (useful for streaming) bool print_special; // print special tokens (e.g. , , , etc.) bool print_progress; // print progress information From 3c50be2217cc413db3beb636710de31adc37d022 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Sep 2023 13:27:06 +0300 Subject: [PATCH 8/8] whisper : remove comment --- whisper.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 0f5fef83be3..4753232b6ad 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -4996,9 +4996,6 @@ int whisper_full_with_state( auto & cur = beam_candidates[cur_c++]; - // TODO: test if this is better: - //while (beam_candidates.size() > cur_c && i > 0) { - while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) { ++cur_c; }