Commit f609ff4

Refactor interactive mode in main.cpp
1 parent cb58437 commit f609ff4

1 file changed: main.cpp (+96 -69)

@@ -55,6 +55,8 @@ void sigint_handler(int signo) {
 #endif
 
 
+void process_interactive_input(llama_context& ctx, const gpt_params& params);
+
 int main(int argc, char ** argv) {
     ggml_time_init();
     const int64_t t_main_start_us = ggml_time_us();
@@ -85,15 +87,18 @@ int main(int argc, char ** argv) {
     // params.prompt = R"(// this function checks if the number n is prime
     //bool is_prime(int n) {)";
 
-    int64_t t_load_us = 0;
-
     // load the model
-    llama_context* ctx_ptr = llama_init_from_params(params);
+    llama_context* ctx_ptr = nullptr;
+    {
+        ctx_ptr = llama_init_from_params(params);
+        if (!ctx_ptr) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+    }
+
     llama_context & ctx = *ctx_ptr;
-    gpt_vocab & vocab = llama_context_get_vocab(ctx);
-
-    // print system information
-    llama_print_context_info(ctx);
+    const gpt_vocab & vocab = llama_context_get_vocab(ctx);
 
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
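
The hunk above swaps the unchecked llama_init_from_params call for a guarded load that reports the failing model path and exits early. The standalone sketch below (not part of this commit) illustrates the same guard pattern; Model and load_model are hypothetical stand-ins for llama_context and llama_init_from_params, used only so the snippet compiles on its own.

#include <cstdio>
#include <string>

// Hypothetical stand-ins for llama_context / llama_init_from_params.
struct Model { std::string path; };

Model* load_model(const std::string& path) {
    if (path.empty()) return nullptr;  // pretend an empty path fails to load
    return new Model{path};
}

int main() {
    const std::string model_path = "models/7B/ggml-model.bin";  // example path only
    Model* model_ptr = load_model(model_path);
    if (!model_ptr) {
        // same early-exit shape as the diff: report which file failed, then bail out
        fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, model_path.c_str());
        return 1;
    }
    Model& model = *model_ptr;  // borrow a reference, as main.cpp does with ctx
    printf("loaded '%s'\n", model.path.c_str());
    delete model_ptr;
    return 0;
}
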
@@ -109,8 +114,13 @@ int main(int argc, char ** argv) {
     }
 
     // tokenize the reverse prompt
-    std::vector<gpt_vocab::id> antiprompt_inp = llama_tokenize_text(ctx, params.prompt);
+    std::vector<gpt_vocab::id> antiprompt_inp = llama_tokenize_text(ctx, params.antiprompt);
+<<<<<<< HEAD
+=======
+
+>>>>>>> b30724a (Fix main)
 
+    // Setup interactive mode
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
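
In the hunk above, antiprompt_inp is now built from params.antiprompt rather than params.prompt. Later in the loop these tokens are compared against the tail of the most recent output tokens to decide when to hand control back to the user; the old code did this inline with std::equal over reverse iterators, which this commit wraps in llama_is_anti_prompt_present. A small standalone sketch of that suffix check (not part of the commit; plain ints stand in for gpt_vocab::id):

#include <algorithm>
#include <cstdio>
#include <vector>

// True when `anti` is a suffix of `last_n` - the comparison the old code did
// inline and that llama_is_anti_prompt_present presumably performs internally.
static bool anti_prompt_present(const std::vector<int>& last_n, const std::vector<int>& anti) {
    return !anti.empty() &&
           anti.size() <= last_n.size() &&
           std::equal(anti.rbegin(), anti.rend(), last_n.rbegin());
}

int main() {
    const std::vector<int> antiprompt_inp = {11, 12};           // e.g. the tokens of "User:"
    const std::vector<int> last_n_tokens  = {3, 7, 9, 11, 12};  // recent output, newest last
    printf("reverse prompt found: %s\n",
           anti_prompt_present(last_n_tokens, antiprompt_inp) ? "yes" : "no");
    return 0;
}
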
@@ -146,50 +156,56 @@ int main(int argc, char ** argv) {
         is_interacting = true;
     }
 
-    bool input_noecho = false;
-
-    int remaining_tokens = params.n_predict;
+    // prompt user immediately after the starting prompt has been loaded
+    if (params.interactive_start) {
+        is_interacting = true;
+    }
 
     // set the color for the prompt which will be output initially
     if (params.use_color) {
         printf(ANSI_COLOR_YELLOW);
     }
 
-    if(!llama_ingest_input(ctx, params.prompt))
+    // Prepare the context with input
+    // Send "beginning of string"
+    llama_add_bos(ctx);
+
+    // load the input
+    llama_update_input(ctx, params.prompt);
+
+    llama_print_startup_stats(ctx);
+
+    if(!llama_prepare_context(ctx))
     {
-        fprintf(stderr, "Failed to ingest prompt\n");
+        fprintf(stderr, "%s: failed to prepare context\n", __func__);
         return 1;
-    };
-
-    // display text
-    input_noecho = false;
-    const std::vector<gpt_vocab::id>& embd = llama_context_get_embedding(ctx);
-    if (!input_noecho) {
-        for (auto id : embd) {
-            printf("%s", vocab.id_to_token[id].c_str());
-        }
-        fflush(stdout);
-    }
-
-    if (!input_noecho && params.use_color) {
-        printf(ANSI_COLOR_RESET);
     }
 
-    const std::vector<gpt_vocab::id>& last_n_tokens = llama_context_get_last_n_tokens(ctx);
-
-    while (llama_context_is_finished(ctx) != true) {
-        gpt_vocab::id model_output = 0;
-        bool response = llama_infer(ctx, model_output);
-        if (response) {
-            printf("%s", vocab.id_to_token[model_output].c_str());
-            fflush(stdout);
+    bool input_noecho = false;
+    bool is_end_of_text = false;
+    while (llama_context_is_finished(ctx) == false) {
+        std::string model_output{};
+
+        if (llama_has_unconsumed_input(ctx)) {
+            llama_ingest_all_pending_input(ctx, !input_noecho);
+            // reset color to default if we there is no pending user input
+            if (!input_noecho && params.use_color) {
+                printf(ANSI_COLOR_RESET);
+            }
+        }else{
+            // Run inference if we don't have any pending input
+            llama_infer(ctx, model_output, is_end_of_text);
+            // print the single token output
+            printf("%s", model_output.c_str());
+            input_noecho = false;
         }
 
         // in interactive mode, and not currently processing queued inputs;
         // check if we should prompt the user for more
-        if (params.interactive) {
+        if (params.interactive && !llama_has_unconsumed_input(ctx)) {
+            const std::vector<gpt_vocab::id>& last_n_tokens = llama_context_get_last_n_tokens(ctx);
             // check for reverse prompt
-            if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
+            if (antiprompt_inp.size() && llama_is_anti_prompt_present(ctx, antiprompt_inp)) {
                 // reverse prompt found
                 is_interacting = true;
             }
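
The rewritten loop above alternates between two states: while the context still has unconsumed input it ingests (and echoes) that input, otherwise it generates and prints one token per iteration, until llama_context_is_finished reports completion. The standalone sketch below (not part of this commit) condenses that control flow; Ctx and its methods are mock stand-ins for llama_context and the llama_* calls shown in the diff, not the real API.

#include <cstdio>
#include <string>
#include <vector>

// Mock context standing in for llama_context plus the llama_* helpers in the diff.
struct Ctx {
    std::vector<std::string> pending = {"Hello", ",", " world."};  // queued prompt pieces
    int tokens_left = 5;                                           // fake generation budget

    bool has_unconsumed_input() const { return !pending.empty(); }
    void ingest_all_pending_input() {
        for (const auto& piece : pending) printf("%s", piece.c_str());  // echo while ingesting
        pending.clear();
    }
    void infer(std::string& out, bool& is_end_of_text) {
        out = " tok";                        // pretend we sampled one token
        is_end_of_text = (--tokens_left == 0);
    }
    bool is_finished() const { return tokens_left == 0; }
};

int main() {
    Ctx ctx;
    bool is_end_of_text = false;
    // Same shape as the refactored loop: consume queued input first, otherwise
    // generate exactly one token per iteration and print it as it arrives.
    while (!ctx.is_finished()) {
        std::string model_output{};
        if (ctx.has_unconsumed_input()) {
            ctx.ingest_all_pending_input();
        } else {
            ctx.infer(model_output, is_end_of_text);
            printf("%s", model_output.c_str());
        }
    }
    if (is_end_of_text) printf("\n[end of text]\n");
    return 0;
}
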
@@ -202,38 +218,14 @@ int main(int argc, char ** argv) {
                 }
 
                 // currently being interactive
-                bool another_line = true;
-                while (another_line) {
-                    fflush(stdout);
-                    char buf[256] = {0};
-                    int n_read;
-                    if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
-                    if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
-                        // presumable empty line, consume the newline
-                        std::ignore = scanf("%*c");
-                        n_read=0;
-                    }
-                    if (params.use_color) printf(ANSI_COLOR_RESET);
-
-                    if (n_read > 0 && buf[n_read-1]=='\\') {
-                        another_line = true;
-                        buf[n_read-1] = '\n';
-                        buf[n_read] = 0;
-                    } else {
-                        another_line = false;
-                        buf[n_read] = '\n';
-                        buf[n_read+1] = 0;
-                    }
-                    // Do not clear existing context in interactive mode
-                    llama_update_context_with_prompt(ctx, buf, false);
-                }
-
+                process_interactive_input(ctx, params);
+                input_noecho = true; // do not echo this input again
                 is_interacting = false;
             }
         }
 
         // end of text token
-        if (embd.back() == EOS_TOKEN_ID) {
+        if (is_end_of_text) {
             if (params.interactive) {
                 is_interacting = true;
             } else {
@@ -243,23 +235,58 @@ int main(int argc, char ** argv) {
         }
 
         // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
-        if (params.interactive && remaining_tokens <= 0) {
-            remaining_tokens = params.n_predict;
+        if (params.interactive && llama_context_is_finished(ctx)) {
+            llama_context_reset_remaining_tokens(ctx)
             is_interacting = true;
         }
     }
 
-    // report timing from context
+
+#if defined (_WIN32)
+    signal(SIGINT, SIG_DFL);
+#endif
+
+    // report timing
     {
         const int64_t t_main_end_us = ggml_time_us();
         llama_print_end_stats(ctx);
         fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
     }
-    llama_free_context(ctx_ptr);
+
+    llama_free_context(ctx_ptr);
 
     if (params.use_color) {
         printf(ANSI_COLOR_RESET);
     }
-
     return 0;
 }
+
+void process_interactive_input(llama_context& ctx, const gpt_params& params)
+{
+    bool another_line = true;
+    while (another_line) {
+        fflush(stdout);
+        char buf[256] = {0};
+        int n_read;
+        if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
+        if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
+            // presumable empty line, consume the newline
+            std::ignore = scanf("%*c");
+            n_read=0;
+        }
+        if (params.use_color) printf(ANSI_COLOR_RESET);
+
+        if (n_read > 0 && buf[n_read-1]=='\\') {
+            another_line = true;
+            buf[n_read-1] = '\n';
+            buf[n_read] = 0;
+        } else {
+            another_line = false;
+            buf[n_read] = '\n';
+            buf[n_read+1] = 0;
+        }
+
+        // Do not clear existing context in interactive mode
+        llama_update_input(ctx, buf);
+    }
+}
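
The new process_interactive_input helper keeps the previous scanf-based reading logic: each line is read with "%255[^\n]%n%*c", a trailing backslash is rewritten to a newline and triggers another read, and the final line gets a newline appended before being handed to llama_update_input. The standalone sketch below (not part of this commit) reproduces just that reading behavior, collecting the text into a string instead of updating a context; the buffer is one byte larger than in the diff so the appended newline and terminator always fit.

#include <cstdio>
#include <string>
#include <tuple>  // std::ignore

// Read one logical user input with the same scanf pattern as
// process_interactive_input; a trailing '\' continues onto the next line.
static std::string read_user_input() {
    std::string collected;
    bool another_line = true;
    while (another_line) {
        fflush(stdout);
        char buf[257] = {0};  // 255 chars + appended '\n' + terminator
        int n_read;
        if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
            // empty line: nothing matched, so consume the pending newline
            std::ignore = scanf("%*c");
            n_read = 0;
        }
        if (n_read > 0 && buf[n_read-1] == '\\') {
            another_line = true;   // keep reading
            buf[n_read-1] = '\n';  // replace the backslash with a newline
            buf[n_read]   = 0;
        } else {
            another_line = false;  // this was the last line
            buf[n_read]   = '\n';
            buf[n_read+1] = 0;
        }
        collected += buf;  // the diff passes buf to llama_update_input here
    }
    return collected;
}

int main() {
    printf("> ");
    const std::string text = read_user_input();
    printf("captured %zu bytes:\n%s", text.size(), text.c_str());
    return 0;
}
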
