diff --git a/chat.cpp b/chat.cpp
index 4e9de5811f6ec..7c0bb1ba2bd50 100644
--- a/chat.cpp
+++ b/chat.cpp
@@ -10,6 +10,7 @@
 #include <map>
 #include <string>
 #include <vector>
+#include <deque>
 
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
 #include <signal.h>
@@ -320,7 +321,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
     fin.close();
 
     std::vector<char> tmp;
-
+
     for (int i = 0; i < n_parts; ++i) {
         const int part_id = i;
         //const int part_id = n_parts - i - 1;
@@ -825,7 +826,7 @@ int main(int argc, char ** argv) {
     // load the model
     {
         const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }
@@ -946,7 +947,20 @@ int main(int argc, char ** argv) {
     if (params.use_color) {
         printf(ANSI_COLOR_YELLOW);
     }
-
+
+    // remove token 1 (BOS) from prompt_inp, and the last 2 tokens (line breaks) from each of prompt_inp and response_inp
+    std::vector<gpt_vocab::id> prompt_inp_sequence;
+    std::vector<gpt_vocab::id> response_inp_sequence;
+    prompt_inp_sequence = std::vector<gpt_vocab::id>(prompt_inp.begin() + 1, prompt_inp.end() - 2);
+    response_inp_sequence = std::vector<gpt_vocab::id>(response_inp.begin(), response_inp.end() - 2);
+
+    std::vector<std::vector<gpt_vocab::id>> instruct_sequences = { // token sequences to look for in the output
+        prompt_inp_sequence,
+        response_inp_sequence
+    };
+
+    // initialize a vector to track the matching progress of each target sequence
+    std::vector<size_t> instruct_indices(instruct_sequences.size(), 0);
 
     while (remaining_tokens > 0) {
         // predict
@@ -1016,10 +1030,49 @@ int main(int argc, char ** argv) {
 
         // display text
         if (!input_noecho) {
-            for (auto id : embd) {
-                printf("%s", vocab.id_to_token[id].c_str());
+            std::deque<std::string> output_buffer;
+            bool sequence_found = false;
+
+            size_t i = 0;
+            while (i < embd.size() && !sequence_found) {
+                gpt_vocab::id id = embd[i];
+                bool handled = false;
+
+                size_t j = 0;
+                while (j < instruct_sequences.size() && !handled) {
+                    if (id == instruct_sequences[j][instruct_indices[j]]) {
+                        instruct_indices[j]++;
+                        handled = true;
+                        if (instruct_indices[j] == instruct_sequences[j].size()) { // the full instruct_sequence has been found: stop printing
+                            i += instruct_sequences[j].size() - 1; // skip the rest of the found target sequence
+                            sequence_found = true;
+                            is_interacting = true;
+                            break;
+                        }
+                    } else if (instruct_indices[j] > 0) {
+                        // a partial match broke off: step back and restart this sequence
+                        i -= instruct_indices[j] - 1; // move back by (instruct_indices[j] - 1) steps
+                        instruct_indices[j] = 0;
+                        handled = true;
+                        break;
+                    } else {
+                        j++; // increment 'j' only when neither a match nor a partial match is found
+                    }
+                }
+
+                if (!handled) {
+                    output_buffer.push_back(vocab.id_to_token[id]);
+                }
+                i++;
+            }
+
+            // flush the remaining elements in the buffer
+            while (!output_buffer.empty()) {
+                std::string output = output_buffer.front();
+                output_buffer.pop_front();
+                printf("%s", output.c_str());
             }
             fflush(stdout);
         }
 
         // in interactive mode, and not currently processing queued inputs;
@@ -1035,7 +1088,11 @@ int main(int argc, char ** argv) {
                 //    embd_inp.erase(embd_inp.begin());
                 input_consumed = embd_inp.size();
                 embd_inp.insert(embd_inp.end(), prompt_inp.begin(), prompt_inp.end());
-
+
+                // clear all sequence-matching indices
+                for (size_t j = 0; j < instruct_indices.size(); ++j) {
+                    instruct_indices[j] = 0;
+                }
                 printf("\n> ");
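
For reference, below is a minimal standalone sketch of the streaming stop-sequence matcher this patch implements, using plain int token ids instead of gpt_vocab::id. StopMatcher and feed are hypothetical names, not llama.cpp API, and like the patch it only approximates backtracking when a partial match breaks off (full generality over overlapping prefixes would need KMP-style failure links).

#include <cstdio>
#include <utility>
#include <vector>

// Tracks the matching progress of several (non-empty) stop sequences over a
// token stream, holding tokens back while any partial match is in progress.
struct StopMatcher {
    std::vector<std::vector<int>> sequences; // token sequences to suppress
    std::vector<size_t>           progress;  // per-sequence match progress
    std::vector<int>              pending;   // tokens withheld during a partial match

    explicit StopMatcher(std::vector<std::vector<int>> seqs)
        : sequences(std::move(seqs)), progress(sequences.size(), 0) {}

    // Feeds one token; returns true once a full stop sequence has been seen.
    // Tokens that turn out not to belong to a stop sequence are appended to
    // `out` so the caller can print them.
    bool feed(int id, std::vector<int> & out) {
        pending.push_back(id);
        bool partial = false;
        for (size_t j = 0; j < sequences.size(); ++j) {
            if (id == sequences[j][progress[j]]) {
                progress[j]++;
            } else if (progress[j] > 0) {
                // the partial match broke off: restart, then check whether the
                // current token begins a new match (same approximation as the patch)
                progress[j] = (id == sequences[j][0]) ? 1 : 0;
            }
            if (progress[j] == sequences[j].size()) {
                progress[j] = 0;
                pending.clear(); // swallow the matched stop sequence
                return true;
            }
            partial = partial || progress[j] > 0;
        }
        if (!partial) {
            // no sequence is in progress: release everything held back
            out.insert(out.end(), pending.begin(), pending.end());
            pending.clear();
        }
        return false;
    }
};

int main() {
    // made-up token ids standing in for the tokenized instruction markers
    StopMatcher matcher({ {7, 8, 9}, {5, 5} });
    std::vector<int> visible;
    for (int id : {1, 2, 7, 8, 3, 7, 8, 9, 4}) {
        if (matcher.feed(id, visible)) {
            break; // full stop sequence seen: hand control back to the user
        }
    }
    for (int id : visible) {
        printf("%d ", id); // prints: 1 2 7 8 3
    }
    printf("\n");
    return 0;
}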