@@ -325,6 +325,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     double nll = 0.0;
     double nll2 = 0.0;
 
+    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+
+    std::vector<float> logits;
+    if (num_batches > 1) {
+        logits.reserve((size_t)n_ctx * n_vocab);
+    }
+
     fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
@@ -333,10 +340,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         const int start = i * n_ctx;
         const int end = start + n_ctx;
 
-        const int num_batches = (n_ctx + n_batch - 1) / n_batch;
-
-        std::vector<float> logits;
-
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
@@ -362,8 +365,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             // restore the original token in case it was set to BOS
            tokens[batch_start] = token_org;
 
-            const auto * batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            if (num_batches > 1) {
+                const auto * batch_logits = llama_get_logits(ctx);
+                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            }
         }
 
         const auto t_end = std::chrono::high_resolution_clock::now();
@@ -392,7 +397,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         // last 256 tokens. Then, we split the input up into context window size chunks to
         // process the entire prompt.
         const int first = n_ctx/2;
-        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
+        process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
                 workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
         count += n_ctx - first - 1;
@@ -406,6 +412,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             printf("%8d  %.4lf  %4lf  %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
         }
         fflush(stdout);
+
+        logits.clear();
     }
     printf("\n");
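The change above skips the per-chunk copy of logits into a scratch vector whenever a context chunk fits in a single batch, reading the model's own output buffer directly instead. Below is a minimal, self-contained sketch of that control flow, not the actual llama.cpp code: decode_batch(), get_logits(), and model_logits are hypothetical stand-ins for the real decode call and llama_get_logits(ctx), and the n_ctx/n_batch/n_vocab values are illustrative only.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

static std::vector<float> model_logits;                 // stand-in for the model's internal logits buffer

static void decode_batch(int batch_size, int n_vocab) { // stand-in for the real decode/eval call
    model_logits.assign((size_t)batch_size * n_vocab, 0.0f);
}

static const float * get_logits() {                     // stand-in for llama_get_logits(ctx)
    return model_logits.data();
}

int main() {
    const int n_ctx   = 512;
    const int n_batch = 512;   // n_batch >= n_ctx -> single batch per chunk, no copy needed
    const int n_vocab = 32000;

    const int num_batches = (n_ctx + n_batch - 1) / n_batch;

    std::vector<float> logits;                          // scratch copy, only needed for multi-batch chunks
    if (num_batches > 1) {
        logits.reserve((size_t)n_ctx * n_vocab);
    }

    for (int j = 0; j < num_batches; ++j) {
        const int batch_size = std::min(n_ctx - j*n_batch, n_batch);
        decode_batch(batch_size, n_vocab);
        if (num_batches > 1) {                          // accumulate only when the chunk spans several batches
            const float * batch_logits = get_logits();
            logits.insert(logits.end(), batch_logits, batch_logits + (size_t)batch_size * n_vocab);
        }
    }

    // Single-batch chunks read the model's buffer directly, avoiding the copy.
    const float * all_logits = num_batches > 1 ? logits.data() : get_logits();
    printf("num_batches = %d, reading %s\n", num_batches, num_batches > 1 ? "copied logits" : "model buffer");
    (void) all_logits;
    return 0;
}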