@@ -50,25 +50,6 @@ void sigint_handler(int signo) {
 }
 #endif
 
-const char * llama_print_system_info(void) {
-    static std::string s;
-
-    s = "";
-    s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
-    s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
-    s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
-    s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
-    s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
-    s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
-    s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
-    s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
-    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
-    s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
-    s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
-    s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
-
-    return s.c_str();
-}
 
 int main(int argc, char ** argv) {
     ggml_time_init();
@@ -97,49 +78,20 @@ int main(int argc, char ** argv) {
 
     int64_t t_load_us = 0;
 
-    gpt_vocab vocab;
-    llama_model model;
-
     // load the model
-    {
-        const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
-            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-
-        t_load_us = ggml_time_us() - t_start_us;
-    }
+    llama_context* ctx_ptr = llama_init_from_params(params);
+    llama_context & ctx = *ctx_ptr;
+    gpt_vocab & vocab = llama_context_get_vocab(ctx);
 
     // print system information
-    {
-        fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
-    }
-
-    int n_past = 0;
-
-    int64_t t_sample_us = 0;
-    int64_t t_predict_us = 0;
-
-    std::vector<float> logits;
+    llama_print_context_info(ctx);
 
     // tokenize the prompt
-    std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
-
-    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
+    std::vector<gpt_vocab::id> embd_inp = llama_tokenize_text(ctx, params.prompt);
 
     // tokenize the reverse prompt
-    std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false);
+    std::vector<gpt_vocab::id> antiprompt_inp = llama_tokenize_text(ctx, params.antiprompt);
 
-    fprintf(stderr, "\n");
-    fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-    fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
-    for (int i = 0; i < (int) embd_inp.size(); i++) {
-        fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
-    }
-    fprintf(stderr, "\n");
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -165,17 +117,6 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
     fprintf(stderr, "\n\n");
 
-    std::vector<gpt_vocab::id> embd;
-
-    // determine the required inference memory per token:
-    size_t mem_per_token = 0;
-    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
-
-    int last_n_size = params.repeat_last_n;
-    std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
-    std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
-
-
     if (params.interactive) {
         fprintf(stderr, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
@@ -185,8 +126,6 @@ int main(int argc, char ** argv) {
                " - If you want to submit another line, end your input in '\\'.\n");
     }
 
-    int remaining_tokens = params.n_predict;
-    int input_consumed = 0;
     bool input_noecho = false;
 
     // prompt user immediately after the starting prompt has been loaded
@@ -199,81 +138,39 @@ int main(int argc, char ** argv) {
         printf(ANSI_COLOR_YELLOW);
     }
 
-    while (remaining_tokens > 0) {
-        // predict
-        if (embd.size() > 0) {
-            const int64_t t_start_us = ggml_time_us();
-
-            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
-                fprintf(stderr, "Failed to predict\n");
-                return 1;
-            }
-
-            t_predict_us += ggml_time_us() - t_start_us;
-        }
-
-        n_past += embd.size();
-        embd.clear();
-
-        if (embd_inp.size() <= input_consumed) {
-            // out of user input, sample next token
-            const float top_k = params.top_k;
-            const float top_p = params.top_p;
-            const float temp = params.temp;
-            const float repeat_penalty = params.repeat_penalty;
-
-            const int n_vocab = model.hparams.n_vocab;
-
-            gpt_vocab::id id = 0;
-
-            {
-                const int64_t t_start_sample_us = ggml_time_us();
-
-                id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
-
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(id);
+    if (!llama_injest_input(ctx, params.prompt))
+    {
+        fprintf(stderr, "Failed to injest prompt\n");
+        return 1;
+    };
 
-                t_sample_us += ggml_time_us() - t_start_sample_us;
-            }
+    // display text
+    input_noecho = false;
+    const std::vector<gpt_vocab::id> & embd = llama_context_get_embd(ctx);
+    if (!input_noecho) {
+        for (auto id : embd) {
+            printf("%s", vocab.id_to_token[id].c_str());
+        }
+        fflush(stdout);
+    }
 
-            // add it to the context
-            embd.push_back(id);
+    if (!input_noecho && params.use_color) {
+        printf(ANSI_COLOR_RESET);
+    }
 
-            // echo this to console
-            input_noecho = false;
+    const std::vector<gpt_vocab::id> & last_n_tokens = llama_context_get_last_n_tokens(ctx);
 
-            // decrement remaining sampling budget
-            --remaining_tokens;
-        } else {
-            // some user input remains from prompt or interaction, forward it to processing
-            // Copy at most n_batch elements from embd_inp to embd
-            size_t num_copied = std::min((size_t) params.n_batch, embd_inp.size() - input_consumed);
-            std::copy(embd_inp.begin() + input_consumed, embd_inp.begin() + input_consumed + num_copied, std::back_inserter(embd));
-            input_consumed += num_copied;
-
-            // Copy the last `last_n_size` elements copied into embd to last_n_tokens
-            size_t num_copied_last_n = std::min(num_copied, (size_t) last_n_size);
-            last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin()+num_copied_last_n);
-            last_n_tokens.insert(last_n_tokens.end(), embd.end() - num_copied_last_n, embd.end());
-
-            // reset color to default if we there is no pending user input
-            if (!input_noecho && params.use_color && embd_inp.size() == input_consumed) {
-                printf(ANSI_COLOR_RESET);
-            }
-        }
-
-        // display text
-        if (!input_noecho) {
-            for (auto id : embd) {
-                printf("%s", vocab.id_to_token[id].c_str());
-            }
+    while (llama_context_not_finished(ctx) > 0) {
+        std::optional<gpt_vocab::id> model_output = llama_inference(ctx);
+        if (model_output.has_value()) {
+            printf("%s", vocab.id_to_token[model_output.value()].c_str());
             fflush(stdout);
         }
 
+
         // in interactive mode, and not currently processing queued inputs;
         // check if we should prompt the user for more
-        if (params.interactive && embd_inp.size() <= input_consumed) {
+        if (params.interactive) {
             // check for reverse prompt
             if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
                 // reverse prompt found
@@ -303,13 +200,8 @@ int main(int argc, char ** argv) {
                         buf[n_read] = '\n';
                         buf[n_read+1] = 0;
                     }
-
-                    std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buf, false);
-                    embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
-
-                    remaining_tokens -= line_inp.size();
-
-                    input_noecho = true; // do not echo this again
+                    // Do not clear existing context in interactive mode
+                    llama_init_context_with_prompt(ctx, buf, false);
                 }
 
                 is_interacting = false;
@@ -322,24 +214,14 @@ int main(int argc, char ** argv) {
             break;
         }
     }
-
-#if defined (_WIN32)
-    signal(SIGINT, SIG_DFL);
-#endif
-
-    // report timing
+
+    // report timing from context
     {
         const int64_t t_main_end_us = ggml_time_us();
-
-        fprintf(stderr, "\n\n");
-        fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
-        fprintf(stderr, "%s:    load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
-        fprintf(stderr, "%s:  sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
-        fprintf(stderr, "%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        llama_print_end_stats(ctx);
         fprintf(stderr, "%s:   total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
     }
-
-    ggml_free(model.ctx);
+    llama_free_context(ctx_ptr);
 
     if (params.use_color) {
         printf(ANSI_COLOR_RESET);
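
For reference, here is a minimal sketch of how the new `llama_context` flow introduced by this diff fits together on the non-interactive path. It only uses calls that appear in the diff above (`llama_init_from_params`, `llama_injest_input`, `llama_inference`, `llama_print_end_stats`, `llama_free_context`); the exact signatures, header name, and null-return behaviour are assumptions, not the final API.

```cpp
// Sketch only: assumes the context API from this diff, with minimal error handling.
#include <cstdio>
#include <optional>
// #include "llama.h"   // assumed header providing llama_context, gpt_params, gpt_vocab

int generate_once(gpt_params & params) {
    llama_context * ctx_ptr = llama_init_from_params(params);   // loads model + vocab (replaces llama_model_load)
    if (ctx_ptr == nullptr) {                                    // assumed failure convention
        return 1;
    }
    llama_context & ctx = *ctx_ptr;
    gpt_vocab & vocab = llama_context_get_vocab(ctx);

    llama_print_context_info(ctx);                               // replaces the old system_info block

    if (!llama_injest_input(ctx, params.prompt)) {               // tokenizes and queues the prompt
        llama_free_context(ctx_ptr);
        return 1;
    }

    // n_past, logits, last_n_tokens and the sampling budget now live inside the context
    while (llama_context_not_finished(ctx) > 0) {
        std::optional<gpt_vocab::id> id = llama_inference(ctx);  // evaluate + sample the next token
        if (id.has_value()) {
            printf("%s", vocab.id_to_token[id.value()].c_str());
            fflush(stdout);
        }
    }

    llama_print_end_stats(ctx);                                  // timing report moved into the context
    llama_free_context(ctx_ptr);                                 // replaces ggml_free(model.ctx)
    return 0;
}
```

The point of the refactor is that the mutable generation state previously kept as locals in `main()` (vocab, model, `n_past`, `logits`, `last_n_tokens`, timing counters) is owned by `llama_context`, so callers drive generation through a handful of context calls instead of the hand-rolled predict/sample loop that this diff removes.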