diff --git a/examples/chat.sh b/examples/chat.sh index 9a928ef05431a..bd48c83034acf 100755 --- a/examples/chat.sh +++ b/examples/chat.sh @@ -11,6 +11,8 @@ cd .. # # "--keep 48" is based on the contents of prompts/chat-with-bob.txt # -./main -m ./models/7B/ggml-model-q4_0.bin -c 512 -b 1024 -n 256 --keep 48 \ - --repeat_penalty 1.0 --color -i \ - -r "User:" -f prompts/chat-with-bob.txt +./main -m ./models/7B/ggml-model-q4_0.bin -c 512 -b 1024 -n -1 --keep 48 \ + --repeat_penalty 1.0 --color \ + -i --interactive-first \ + -r "User:" --in-prefix " " \ + -f prompts/chat-with-bob.txt diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 0a22f3c25ff46..5a2480e420b5e 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -631,6 +631,16 @@ int main(int argc, char ** argv) { llama_grammar_accept_token(ctx, grammar, id); } + // replace end of text token with newline token and inject reverse prompt when in interactive mode + if (id == llama_token_eos() && params.interactive && !params.instruct && !params.input_prefix_bos) { + id = llama_token_nl(); + if (params.antiprompt.size() != 0) { + // tokenize and inject first reverse prompt + const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false); + embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); + } + } + last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); } @@ -712,8 +722,8 @@ int main(int argc, char ** argv) { is_interacting = true; printf("\n"); - console::set_display(console::user_input); fflush(stdout); + console::set_display(console::user_input); } else if (params.instruct) { is_interacting = true; } @@ -722,6 +732,7 @@ int main(int argc, char ** argv) { if (n_past > 0 && is_interacting) { if (params.instruct) { printf("\n> "); + fflush(stdout); } if (params.input_prefix_bos) { @@ -732,6 +743,7 @@ int main(int argc, char ** argv) { if (!params.input_prefix.empty()) { buffer += params.input_prefix; printf("%s", buffer.c_str()); + fflush(stdout); } std::string line; @@ -751,6 +763,7 @@ int main(int argc, char ** argv) { if (!params.input_suffix.empty()) { buffer += params.input_suffix; printf("%s", params.input_suffix.c_str()); + fflush(stdout); } // instruct mode: insert instruction prefix