Commit 4a368ed

Exhaustively test skipping attn and MLP layers
Cover an extra alloc case where skipping could fail
1 parent e58816b commit 4a368ed

2 files changed, +42 -19 lines changed


examples/perplexity/perplexity.cpp

Lines changed: 38 additions & 19 deletions
@@ -324,16 +324,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     llama_batch batch = llama_batch_get_one(NULL, 0, 0, 0);
 
     const int32_t n_layers = 26;
-    const int test_count = 15;
+    const int test_count = 10;
     // 1 = attn, 2 = mlp, 3 = both
-    int32_t test_skip_type = 1;
+    int32_t test_skip_type = 0;
     std::vector<int32_t> layers;
     layers.resize(n_layers + 1);
     std::fill(layers.begin(), layers.end(), 0);
     batch.run_layers = layers.data();
     int32_t skip_layer = -1;
     std::vector<int32_t> skips;
-    int32_t curr_best_layer = -1;
+    std::vector<int32_t> skip_types;
+    skip_types.resize(n_layers);
+    std::fill(skip_types.begin(), skip_types.end(), 0);
+    int32_t curr_best_layer = -1, curr_best_type = 0;
     double curr_best_ppl = -1, ref_ppl = -1;
 
     int count = 0;
@@ -343,32 +346,47 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
+    static const char * label = "?AMB";
 
     auto test_t_start = std::chrono::high_resolution_clock::now();
     for (int i = 0; i < n_chunk; ++i) {
         if (i > 0 && i % test_count == 0) {
             auto test_t_end = std::chrono::high_resolution_clock::now();
             float test_t_total = std::chrono::duration<float>(test_t_end - test_t_start).count();
-            for (int32_t new_sl = std::max(0, skip_layer + 1); new_sl <= n_layers ; new_sl++) {
-                if (std::find(skips.begin(), skips.end(), new_sl) != skips.end()) continue;
+
+            skip_layer = n_layers;
+            for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
+                int32_t curr_skipped = (skip_types[new_sl] >> 2) | (skip_types[new_sl] & 3);
+                if (curr_skipped == 3) continue; // Already tested or perm skip.
                 skip_layer = new_sl;
+                test_skip_type = (curr_skipped & 1) != 0 ? 2 : 1;
                 break;
             }
             if (skip_layer >= n_layers) {
                 if (curr_best_layer == -1) break;
-                printf("\n\nADD SKIP %3d - ppl vs ref %.4f", curr_best_layer, curr_best_ppl - ref_ppl);
+                printf("\n\nADD SKIP %c%3d - ppl vs ref %.4f",
+                    int(label[curr_best_type]), curr_best_layer,
+                    curr_best_ppl - ref_ppl);
                 if (curr_best_ppl >= ref_ppl * 5) break;
-                skips.push_back(curr_best_layer);
+                skip_types[curr_best_layer] += curr_best_type;
+                if (std::find(skips.begin(), skips.end(), curr_best_layer) == skips.end()) {
+                    skips.push_back(curr_best_layer);
+                }
+                for (int i = 0; i < n_layers; i++) skip_types[i] &= 3;
                 curr_best_layer = -1;
                 curr_best_ppl = -1;
-                skip_layer = -1;
-                for (int32_t new_sl = skip_layer + 1; new_sl <= n_layers; new_sl++) {
-                    if (std::find(skips.begin(), skips.end(), new_sl) != skips.end()) continue;
+                curr_best_type = 0;
+                skip_layer = n_layers;
+                for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
+                    skip_types[new_sl] &= 3;
+                    if (skip_types[new_sl] == 3) continue; // Already tested or perm skip.
                     skip_layer = new_sl;
+                    test_skip_type = (skip_types[new_sl] & 1) != 0 ? 2 : 1;
                     break;
                 }
                 if (skip_layer == -1 || skip_layer == n_layers) break;
             }
+
             i = 0;
             count = 0;
             nll = 0;
@@ -377,18 +395,16 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             prob_history.clear();
 
             for (int32_t i = 0; i < n_layers; i++) {
-                if (i == skip_layer || std::find(skips.begin(), skips.end(), i) != skips.end()) {
-                    layers[i] = test_skip_type;
-                } else {
-                    layers[i] = 0;
-                }
+                layers[i] = (skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0);
             }
             layers[n_layers] = -1;
-            printf("\nSKIP %3d + [", skip_layer);
-            for (const auto l : skips) printf("%d,", l);
-            printf("] - len: %3zu, best:(%3d: %.3f), took %.2f sec\n",
+            printf("\nTEST %c%3d + [", int(label[test_skip_type]), skip_layer);
+            for (const auto l : skips) {
+                printf("%c%d, ", int(label[skip_types[l] & 3]), l);
+            }
+            printf("] - len: %3zu, best:(%c%3d @ %.3f), last took %.2f sec\n",
                 skips.size() + 1,
-                curr_best_layer,
+                int(label[curr_best_type]), curr_best_layer,
                 curr_best_ppl != -1 ? curr_best_ppl - ref_ppl : 0,
                 test_t_total);
             test_t_start = std::chrono::high_resolution_clock::now();
@@ -475,10 +491,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         fflush(stdout);
         if (skip_layer >= 0 && i + 1 == test_count) {
             double ppl = std::exp(nll / count);
+            skip_types[skip_layer] |= test_skip_type << 2;
             if (curr_best_layer == -1 || ppl < curr_best_ppl) {
                 curr_best_layer = skip_layer;
                 curr_best_ppl = ppl;
+                curr_best_type = test_skip_type;
             }
+            printf(" -- %.3f", ppl - ref_ppl);
         } else if (skip_layer < 0) {
             ref_ppl = std::exp(nll / count);
         }
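
The new skip_types vector packs two 2-bit fields per layer: the low bits record skips that have been made permanent (1 = attn, 2 = mlp, 3 = both, matching the existing comment), and the high bits record which variants have already been measured in the current round, which is why (skip_types[l] >> 2) | (skip_types[l] & 3) reaching 3 marks a layer as exhausted. The self-contained sketch below only illustrates that convention; the helper next_test_type is an illustrative name, not part of the patch.

#include <cstdint>
#include <cstdio>
#include <vector>

// Skip-type codes used by the patch: 0 = none, 1 = attn, 2 = mlp, 3 = both.
// Low 2 bits of a skip_types entry  -> skips made permanent for that layer.
// High 2 bits of a skip_types entry -> variants already tested this round.
static const char * label = "?AMB";

// Illustrative helper (not in the patch): pick the next variant to test for
// one layer, or 0 if attention and MLP are both already covered.
static int32_t next_test_type(int32_t skip_type) {
    const int32_t covered = (skip_type >> 2) | (skip_type & 3);
    if (covered == 3) return 0;         // already tested or permanently skipped
    return (covered & 1) != 0 ? 2 : 1;  // attn covered -> try mlp, else try attn
}

int main() {
    // Layer 0: untouched; layer 1: attn permanently skipped; layer 2: attn
    // permanently skipped and mlp already tested; layer 3: both skipped.
    std::vector<int32_t> skip_types = { 0, 1, 1 | (2 << 2), 3 };
    for (size_t l = 0; l < skip_types.size(); l++) {
        const int32_t t = next_test_type(skip_types[l]);
        printf("layer %zu: permanent=%c next_test=%c\n",
               l, label[skip_types[l] & 3], label[t]);
    }
    return 0;
}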

llama.cpp

Lines changed: 4 additions & 0 deletions
@@ -3246,6 +3246,10 @@ static struct ggml_cgraph * llm_build_llama(
             // No idea why this is needed, but otherwise we run out of space
             // when skipping attn or mlp (but not both) on the last layer
             run_mlp = false;
+        } else if (ggml_allocr_is_measure(lctx.alloc) && il == n_layer - 2) {
+            // No idea why this is needed, but otherwise we run out of space
+            // when skipping attn or mlp (but not both) on the last layer
+            run_attn = false;
         }
         if (!run_attn && !run_mlp) continue;
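
The run_layers array that perplexity.cpp hands to the batch uses the same 1 = attn, 2 = mlp, 3 = both codes, and llm_build_llama turns each entry into the run_attn/run_mlp flags seen in the hunk above. A rough sketch of that decoding, assuming the two codes act as independent bits (decode_run_layer is a hypothetical stand-in, not the fork's actual code):

#include <cstdint>
#include <cstdio>

// Hypothetical decoder for one run_layers entry, mirroring the
// 1 = skip attn, 2 = skip mlp, 3 = skip both convention from the patch.
static void decode_run_layer(int32_t entry, bool & run_attn, bool & run_mlp) {
    run_attn = (entry & 1) == 0;  // bit 0 set -> skip attention for this layer
    run_mlp  = (entry & 2) == 0;  // bit 1 set -> skip the MLP for this layer
}

int main() {
    for (int32_t entry = 0; entry <= 3; entry++) {
        bool run_attn = true, run_mlp = true;
        decode_run_layer(entry, run_attn, run_mlp);
        printf("run_layers entry %d: run_attn=%d run_mlp=%d\n",
               entry, int(run_attn), int(run_mlp));
    }
    return 0;
}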