Commit 993fba8

perplexity: avoid unnecessary allocations and logit copies (#5035)

Author: ikawrakow (Iwan Kawrakow)
Co-authored-by: Iwan Kawrakow <[email protected]>

Parent: 8b20858

File tree

1 file changed: +15, -7 lines


examples/perplexity/perplexity.cpp

Lines changed: 15 additions & 7 deletions
@@ -325,6 +325,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     double nll = 0.0;
     double nll2 = 0.0;
 
+    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+
+    std::vector<float> logits;
+    if (num_batches > 1) {
+        logits.reserve((size_t)n_ctx * n_vocab);
+    }
+
     fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
@@ -333,10 +340,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         const int start = i * n_ctx;
         const int end = start + n_ctx;
 
-        const int num_batches = (n_ctx + n_batch - 1) / n_batch;
-
-        std::vector<float> logits;
-
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
@@ -362,8 +365,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             // restore the original token in case it was set to BOS
             tokens[batch_start] = token_org;
 
-            const auto * batch_logits = llama_get_logits(ctx);
-            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            if (num_batches > 1) {
+                const auto * batch_logits = llama_get_logits(ctx);
+                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+            }
         }
 
         const auto t_end = std::chrono::high_resolution_clock::now();
@@ -392,7 +397,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         // last 256 tokens. Then, we split the input up into context window size chunks to
         // process the entire prompt.
         const int first = n_ctx/2;
-        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
+        process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
                        workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
         count += n_ctx - first - 1;
 
@@ -406,6 +412,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
         }
         fflush(stdout);
+
+        logits.clear();
     }
     printf("\n");

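In short: num_batches and the logits buffer move out of the per-chunk loop, the buffer is reserved once up front, and the copy out of llama_get_logits(ctx) is skipped entirely when a chunk fits in a single batch. Below is a minimal sketch of that pattern with the batch evaluation and the perplexity math elided; score_chunks and the loop bounds are illustrative stand-ins, not the exact code in perplexity.cpp, and only llama_context and llama_get_logits are real llama.cpp names.

#include <cstddef>
#include <vector>

#include "llama.h" // provides llama_context and llama_get_logits

// Sketch only: batch evaluation and scoring are elided. The shape of the
// loops mirrors the committed change, not the full perplexity() function.
static void score_chunks(llama_context * ctx, int n_chunk, int n_ctx, int n_batch, int n_vocab) {
    const int num_batches = (n_ctx + n_batch - 1) / n_batch;

    // Hoisted out of the chunk loop and reserved once, so no chunk pays
    // for allocating an n_ctx * n_vocab float buffer.
    std::vector<float> logits;
    if (num_batches > 1) {
        logits.reserve((size_t)n_ctx * n_vocab);
    }

    for (int i = 0; i < n_chunk; ++i) {
        for (int j = 0; j < num_batches; ++j) {
            // ... evaluate batch j of chunk i ...

            if (num_batches > 1) {
                // Copy only when the chunk spans several batches; with a
                // single batch the context's own buffer already holds all
                // the logits for the chunk.
                const float * batch_logits = llama_get_logits(ctx);
                logits.insert(logits.end(), batch_logits, batch_logits + (size_t)n_batch * n_vocab);
            }
        }

        // Single-batch case: read straight from the context, no copy at all.
        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
        (void)all_logits; // ... compute nll/perplexity over all_logits ...

        logits.clear(); // resets size but keeps capacity for the next chunk
    }
}

The logits.clear() at the end of each chunk is what makes the single reserve() sufficient: it resets the vector's size without releasing its capacity, so every subsequent chunk reuses the same storage.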