From 90bbce242a73895131ef6831bc1f56bad46353a2 Mon Sep 17 00:00:00 2001
From: PXLKSR
Date: Fri, 31 Mar 2023 00:15:06 +0200
Subject: [PATCH 1/3] Prevent instruction prompt and response from leaking
 into output.

---
 chat.cpp | 49 ++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 42 insertions(+), 7 deletions(-)

diff --git a/chat.cpp b/chat.cpp
index 4e9de5811f6ec..2453cabe761b9 100644
--- a/chat.cpp
+++ b/chat.cpp
@@ -320,7 +320,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
     fin.close();
 
     std::vector<uint8_t> tmp;
-    
+
     for (int i = 0; i < n_parts; ++i) {
         const int part_id = i;
         //const int part_id = n_parts - i - 1;
@@ -825,7 +825,7 @@ int main(int argc, char ** argv) {
     // load the model
     {
         const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
@@ -946,7 +946,17 @@ int main(int argc, char ** argv) {
         printf(ANSI_COLOR_YELLOW);
     }
 
-    
+    // remove token 1 from prompt_inp
+    std::vector<gpt_vocab::id> prompt_inp_sequence;
+    prompt_inp_sequence = std::vector<gpt_vocab::id>(prompt_inp.begin() + 1, prompt_inp.end());
+
+    std::vector<std::vector<gpt_vocab::id>> instruct_sequences = { // token sequences to look for in output
+        prompt_inp_sequence,
+        response_inp
+    };
+
+    // Initialize a vector to track the progress of each target sequence
+    std::vector<size_t> instruct_indices(instruct_sequences.size(), 0);
 
     while (remaining_tokens > 0) {
         // predict
@@ -1016,10 +1026,35 @@ int main(int argc, char ** argv) {
 
         // display text
        if (!input_noecho) {
-            for (auto id : embd) {
-                printf("%s", vocab.id_to_token[id].c_str());
+            for (size_t i = 0; i < embd.size(); ++i) {
+                gpt_vocab::id id = embd[i];
+                bool sequence_found = false;
+
+                for (size_t j = 0; j < instruct_sequences.size(); ++j) {
+                    if (id == instruct_sequences[j][instruct_indices[j]]) {
+                        instruct_indices[j]++;
+                        if (instruct_indices[j] == instruct_sequences[j].size()) { // If we have found the full instruct_sequence stop printing
+                            printf("\n");
+                            is_interacting = true;
+                            i += instruct_sequences[j].size() - 1; // Skip the rest of the found target sequence
+                            sequence_found = true;
+                            continue;
+                        }
+                    } else {
+                        // Handle partial match cases
+                        if (instruct_indices[j] > 0) {
+                            i -= instruct_indices[j]; // Move back by instruct_indices[j] steps
+                            instruct_indices[j] = 0;
+                            break;
+                        }
+                    }
+                }
+
+                if (!sequence_found) {
+                    printf("%s[%i] ", vocab.id_to_token[id].c_str(), id);
+                }
+                fflush(stdout);
             }
-            fflush(stdout);
         }
 
         // in interactive mode, and not currently processing queued inputs;
@@ -1035,7 +1070,7 @@ int main(int argc, char ** argv) {
                     // embd_inp.erase(embd_inp.begin());
                     input_consumed = embd_inp.size();
                     embd_inp.insert(embd_inp.end(), prompt_inp.begin(), prompt_inp.end());
-                    
+
                     printf("\n> ");

From 3bea6813de86f4cafa72c242890f3074b06ae14b Mon Sep 17 00:00:00 2001
From: PXLKSR
Date: Fri, 31 Mar 2023 00:28:21 +0200
Subject: [PATCH 2/3] remove debug output

---
 chat.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/chat.cpp b/chat.cpp
index 2453cabe761b9..539f5a1c18806 100644
--- a/chat.cpp
+++ b/chat.cpp
@@ -1051,7 +1051,7 @@ int main(int argc, char ** argv) {
                 }
 
                 if (!sequence_found) {
-                    printf("%s[%i] ", vocab.id_to_token[id].c_str(), id);
+                    printf("%s", vocab.id_to_token[id].c_str());
                 }
                 fflush(stdout);
             }

From 28cb8ba6aa872ece10ef5f25850bca8d09c766ae Mon Sep 17 00:00:00 2001
From: PXLKSR
Date: Fri, 31 Mar 2023 03:12:42 +0200
Subject: [PATCH 3/3] fix quite a few edge cases

---
 chat.cpp | 58 +++++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 41 insertions(+), 17 deletions(-)

diff --git a/chat.cpp b/chat.cpp
index 539f5a1c18806..7c0bb1ba2bd50 100644
--- a/chat.cpp
+++ b/chat.cpp
@@ -10,6 +10,7 @@
 #include <map>
 #include <string>
 #include <vector>
+#include <deque>
 
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
 #include <signal.h>
@@ -946,13 +947,16 @@ int main(int argc, char ** argv) {
         printf(ANSI_COLOR_YELLOW);
     }
 
-    // remove token 1 from prompt_inp
+    // remove token 1 from prompt_inp, and the last 2 tokens (line breaks) for each of prompt_inp and response_inp
     std::vector<gpt_vocab::id> prompt_inp_sequence;
-    prompt_inp_sequence = std::vector<gpt_vocab::id>(prompt_inp.begin() + 1, prompt_inp.end());
+    std::vector<gpt_vocab::id> response_inp_sequence;
+    prompt_inp_sequence = std::vector<gpt_vocab::id>(prompt_inp.begin() + 1, prompt_inp.end() - 2);
+    response_inp_sequence = std::vector<gpt_vocab::id>(response_inp.begin(), response_inp.end() - 2);
+
     std::vector<std::vector<gpt_vocab::id>> instruct_sequences = { // token sequences to look for in output
         prompt_inp_sequence,
-        response_inp
+        response_inp_sequence
     };
 
     // Initialize a vector to track the progress of each target sequence
@@ -1026,35 +1030,51 @@ int main(int argc, char ** argv) {
 
         // display text
         if (!input_noecho) {
-            for (size_t i = 0; i < embd.size(); ++i) {
+            std::deque<std::string> output_buffer;
+            bool sequence_found = false;
+
+            size_t i = 0;
+            while (i < embd.size() && !sequence_found) {
                 gpt_vocab::id id = embd[i];
-                bool sequence_found = false;
+                bool handled = false;
 
-                for (size_t j = 0; j < instruct_sequences.size(); ++j) {
+                size_t j = 0;
+                while (j < instruct_sequences.size() && !handled) {
                     if (id == instruct_sequences[j][instruct_indices[j]]) {
                         instruct_indices[j]++;
+                        handled = true;
                         if (instruct_indices[j] == instruct_sequences[j].size()) { // If we have found the full instruct_sequence stop printing
-                            printf("\n");
-                            is_interacting = true;
                             i += instruct_sequences[j].size() - 1; // Skip the rest of the found target sequence
                             sequence_found = true;
-                            continue;
-                        }
-                    } else {
-                        // Handle partial match cases
-                        if (instruct_indices[j] > 0) {
-                            i -= instruct_indices[j]; // Move back by instruct_indices[j] steps
-                            instruct_indices[j] = 0;
+                            is_interacting = true;
                             break;
                         }
+                    } else if (instruct_indices[j] > 0) {
+                        // Handle partial match cases
+                        i -= instruct_indices[j] - 1; // Move back by (instruct_indices[j] - 1) steps
+                        instruct_indices[j] = 0;
+                        handled = true;
+                        break;
+                    } else {
+                        j++; // Increment 'j' only when no match or partial match is found
                     }
                 }
 
-                if (!sequence_found) {
-                    printf("%s", vocab.id_to_token[id].c_str());
+                if (!handled) {
+                    output_buffer.push_back(vocab.id_to_token[id]);
                 }
+                i++;
                 fflush(stdout);
             }
+
+            // Flush the remaining elements in the buffer
+            while (!output_buffer.empty()) {
+                std::string output = output_buffer.front();
+                output_buffer.pop_front();
+                printf("%s", output.c_str());
+            }
+            fflush(stdout);
+
         }
 
         // in interactive mode, and not currently processing queued inputs;
@@ -1071,6 +1091,10 @@ int main(int argc, char ** argv) {
                     input_consumed = embd_inp.size();
                     embd_inp.insert(embd_inp.end(), prompt_inp.begin(), prompt_inp.end());
 
+                    // clear all indices
+                    for (size_t j = 0; j < instruct_indices.size(); ++j) {
+                        instruct_indices[j] = 0;
+                    }
                     printf("\n> ");
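Taken together, the three patches implement a small streaming multi-pattern matcher: generated tokens are held back while they could still be the start of the instruction prompt or the response marker, dropped outright once a full target sequence is confirmed, and released for printing as soon as the partial match breaks. The standalone sketch below illustrates that scheme under simplifying assumptions; it is not the chat.cpp code itself, and TokenId plus the hard-coded target sequence stand in for gpt_vocab::id and the real instruct_sequences built from prompt_inp and response_inp.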
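// Minimal sketch of the suppression scheme used in the patches above.
// TokenId and the demo data are stand-ins for gpt_vocab::id and the real
// instruction/response token sequences; this is illustrative only.
#include <cstdio>
#include <deque>
#include <vector>

using TokenId = int;

struct SuppressingPrinter {
    std::vector<std::vector<TokenId>> targets; // sequences to hide from output
    std::deque<TokenId> pending;               // tokens held back as a possible match

    // True if the held-back tokens match the first pending.size() tokens of t.
    bool is_prefix(const std::vector<TokenId> & t) const {
        if (pending.size() > t.size()) return false;
        for (size_t k = 0; k < pending.size(); ++k) {
            if (pending[k] != t[k]) return false;
        }
        return true;
    }

    // Feed one generated token. Returns true when a full target sequence was
    // matched and swallowed (the caller could then switch to interactive mode).
    bool feed(TokenId id) {
        pending.push_back(id);
        for (;;) {
            bool partial = false;
            for (const auto & t : targets) {
                if (is_prefix(t)) {
                    if (pending.size() == t.size()) { // full match: drop the tokens
                        pending.clear();
                        return true;
                    }
                    partial = true; // still a possible match: keep holding back
                }
            }
            if (partial || pending.empty()) return false;
            printf("%d ", pending.front()); // no longer a possible match: release
            pending.pop_front();
        }
    }
};

int main() {
    SuppressingPrinter p;
    p.targets = {{7, 8, 9}}; // hypothetical instruction-prompt tokens

    for (TokenId id : {1, 2, 7, 8, 3, 7, 8, 9, 4}) {
        if (p.feed(id)) printf("[suppressed] ");
    }
    printf("\n"); // prints: 1 2 7 8 3 [suppressed] 4
    return 0;
}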
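Patch 3's output_buffer plays a role similar to the pending deque here: output is staged and only flushed once the batch has been scanned, so a sequence that completes mid-batch is never partially printed. The sketch simplifies by re-anchoring on the buffered suffix instead of moving the index i backwards the way the patches do.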