Skip to content

Never exit the main loop in interactive mode. #297

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 57 additions & 35 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);

std::mt19937 rng(params.seed);
if (params.prompt.empty()) {
if (params.random_prompt) {
params.prompt = gpt_random_prompt(rng);
}

Expand Down Expand Up @@ -850,7 +850,11 @@ int main(int argc, char ** argv) {
params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());

// tokenize the reverse prompt
std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false);
std::vector<std::vector<gpt_vocab::id>> antipromptv_inp;

for (auto antiprompt : params.antiprompt) {
antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false));
}

fprintf(stderr, "\n");
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
Expand All @@ -872,13 +876,16 @@ int main(int argc, char ** argv) {

fprintf(stderr, "%s: interactive mode on.\n", __func__);

if(antiprompt_inp.size()) {
fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str());
fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
if(antipromptv_inp.size()) {
for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) {
auto antiprompt_inp = antipromptv_inp.at(apindex);
fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
}
fprintf(stderr, "\n");
}
fprintf(stderr, "\n");
}
}
fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
Expand Down Expand Up @@ -935,35 +942,37 @@ int main(int argc, char ** argv) {
embd.clear();

if (embd_inp.size() <= input_consumed) {
// out of user input, sample next token
const float top_k = params.top_k;
const float top_p = params.top_p;
const float temp = params.temp;
const float repeat_penalty = params.repeat_penalty;
if (!is_interacting) {
// out of user input, sample next token
const float top_k = params.top_k;
const float top_p = params.top_p;
const float temp = params.temp;
const float repeat_penalty = params.repeat_penalty;

const int n_vocab = model.hparams.n_vocab;
const int n_vocab = model.hparams.n_vocab;

gpt_vocab::id id = 0;
gpt_vocab::id id = 0;

{
const int64_t t_start_sample_us = ggml_time_us();
{
const int64_t t_start_sample_us = ggml_time_us();

id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);

last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);

t_sample_us += ggml_time_us() - t_start_sample_us;
}
t_sample_us += ggml_time_us() - t_start_sample_us;
}

// add it to the context
embd.push_back(id);
// add it to the context
embd.push_back(id);

// echo this to console
input_noecho = false;
// echo this to console
input_noecho = false;

// decrement remaining sampling budget
--remaining_tokens;
// decrement remaining sampling budget
--remaining_tokens;
}
} else {
// some user input remains from prompt or interaction, forward it to processing
while (embd_inp.size() > input_consumed) {
Expand Down Expand Up @@ -994,9 +1003,12 @@ int main(int argc, char ** argv) {
// check if we should prompt the user for more
if (params.interactive && embd_inp.size() <= input_consumed) {
// check for reverse prompt
if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
// reverse prompt found
is_interacting = true;
for (auto antiprompt_inp : antipromptv_inp) {
if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
// reverse prompt found
is_interacting = true;
break;
}
}
if (is_interacting) {
// currently being interactive
Expand Down Expand Up @@ -1035,10 +1047,20 @@ int main(int argc, char ** argv) {
}
}

// end of text token
if (embd.back() == 2) {
fprintf(stderr, " [end of text]\n");
break;
if (params.interactive) {
if (embd.size() && embd.back() == 2) {
is_interacting = true;
}
if (remaining_tokens == 0) {
remaining_tokens = params.n_predict;
is_interacting = true;
}
} else {
// end of text token
if (embd.size() && embd.back() == 2) {
fprintf(stderr, " [end of text]\n");
break;
}
}
}

Expand Down
10 changes: 7 additions & 3 deletions utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else if (arg == "--color") {
params.use_color = true;
} else if (arg == "-r" || arg == "--reverse-prompt") {
params.antiprompt = argv[++i];
params.antiprompt.push_back(argv[++i]);
} else if (arg == "-h" || arg == "--help") {
gpt_print_usage(argc, argv, params);
exit(0);
} else if (arg == "--random-prompt") {
params.random_prompt = true;
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, params);
Expand All @@ -93,12 +95,14 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
fprintf(stderr, " -i, --interactive run in interactive mode\n");
fprintf(stderr, " --interactive-start run in interactive mode and poll user input at startup\n");
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
fprintf(stderr, " in interactive mode, poll user input upon seeing PROMPT\n");
fprintf(stderr, " in interactive mode, poll user input upon seeing PROMPT (can be\n");
fprintf(stderr, " specified more than once for multiple prompts).\n");
fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
fprintf(stderr, " prompt to start generation with (default: random)\n");
fprintf(stderr, " prompt to start generation with (default: empty)\n");
fprintf(stderr, " --random-prompt start with a randomized prompt.\n");
fprintf(stderr, " -f FNAME, --file FNAME\n");
fprintf(stderr, " prompt file to start generation.\n");
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
Expand Down
4 changes: 3 additions & 1 deletion utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ struct gpt_params {
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt;

bool random_prompt = false;

bool use_color = false; // use color to distinguish generations and inputs

bool interactive = false; // interactive mode
bool interactive_start = false; // reverse prompt immediately
std::string antiprompt = ""; // string upon seeing which more user input is prompted
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
};

bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
Expand Down