Skip to content

Layer skipping/self-speculation demo #3565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
This draft pull request wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 158 additions & 3 deletions examples/perplexity/perplexity.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "common.h"
#include "llama.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
Expand Down Expand Up @@ -320,15 +321,151 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_batch = params.n_batch;

llama_batch batch = llama_batch_get_one(NULL, 0, 0, 0);

// model layer count
const int32_t n_layers = 32;

// num perplexity chunks to run for each test
const int test_count = 4;

// prune this many of the worst results each pass
const size_t prune_target = 2;

// start with all but first/last layers disabled and start adding them back
const bool anti_mode = true;

// **** end tunables ***

// 1 = attn, 2 = mlp, 3 = both
int32_t test_skip_type = 0; // but don't mess with this, it's set automatically.
std::vector<int32_t> layers;
layers.resize(n_layers + 1);
std::fill(layers.begin(), layers.end(), 0);
batch.run_layers = layers.data();
int32_t skip_layer = -1;
std::vector<int32_t> skips;
std::vector<int32_t> skip_types;
skip_types.resize(n_layers);
std::fill(skip_types.begin(), skip_types.end(), 0);
std::vector<std::tuple<int32_t, int32_t, double>> pass_results;
std::vector<int32_t> extremes;
extremes.resize(n_layers);
std::fill(extremes.begin(), extremes.end(), 0);
// if (anti_mode) {
// // No point in starting with first/last layer disabled.
// skip_types[0] = 15;
// skip_types[n_layers - 1] = 15;
// skips.push_back(0); skips.push_back(0 + n_layers);
// skips.push_back(n_layers - 1); skips.push_back(n_layers - 1 + n_layers);
// }
int32_t curr_best_layer = -1, curr_best_type = 0;
double curr_best_ppl = -1, ref_ppl = -1;
const int32_t mask = anti_mode ? 3 : 0;

int count = 0;
double nll = 0.0;
double nll2 = 0.0;

fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);

std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
static const char * label = "?AMB";

auto test_t_start = std::chrono::high_resolution_clock::now();
for (int i = 0; i < n_chunk; ++i) {
if (i > 0 && i % test_count == 0) {
auto test_t_end = std::chrono::high_resolution_clock::now();
float test_t_total = std::chrono::duration<float>(test_t_end - test_t_start).count();

skip_layer = n_layers;
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
int32_t curr_skipped = (skip_types[new_sl] >> 2) | (skip_types[new_sl] & 3);
// printf("##%d, %d\n", new_sl, curr_skipped);
if (curr_skipped == 3) continue; // Already tested or perm skip.
skip_layer = new_sl;
test_skip_type = (curr_skipped & 1) != 0 ? 2 : 1;
break;
}
if (skip_layer >= n_layers) {
if (curr_best_layer == -1) break;
if (anti_mode || (prune_target > 0 && pass_results.size() >= prune_target * 2)) {
std::sort(pass_results.begin(), pass_results.end(),
[](const std::tuple<int32_t, int32_t, double> & a, const std::tuple<int32_t, int32_t, double> & b) {
if (anti_mode) return std::get<2>(b) > std::get<2>(a);
return std::get<2>(a) > std::get<2>(b);
}
);
const size_t num_prune = std::min(pass_results.size(), prune_target);
if (num_prune > 0) printf("\nPruning: ");
for (size_t temp = 0, pruned = 0; temp < pass_results.size(); temp++) {
int32_t lidx = std::get<0>(pass_results[temp]);
if (anti_mode) {
skip_types[lidx] |= std::get<1>(pass_results[temp]);
skips.push_back(std::get<1>(pass_results[temp]) == 1 ? lidx : lidx + n_layers);
}
if (lidx == curr_best_layer && std::get<1>(pass_results[temp]) == curr_best_type) continue;
extremes[lidx] |= std::get<1>(pass_results[temp]);
printf("[%zu: %d (%d) - %.2f], ", pruned + 1, lidx,
std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
if (++pruned >= num_prune) break;
}
}
pass_results.clear();
printf("\n\nADD %c%3d - ppl vs ref %.4f - cur:[",
int(label[curr_best_type]), curr_best_layer,
curr_best_ppl - ref_ppl);
if (!anti_mode) {
// if (curr_best_ppl > ref_ppl * 1.75) break;
skip_types[curr_best_layer] += curr_best_type;
skips.push_back(curr_best_type == 1 ? curr_best_layer : curr_best_layer + n_layers);
}
curr_best_layer = -1;
curr_best_ppl = -1;
curr_best_type = 0;
skip_layer = n_layers;
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
skip_types[new_sl] = (skip_types[new_sl] & 3) | (extremes[new_sl] << 2);
}
for (int32_t i = 0; i < n_layers; i++) {
const int val = mask ^ (skip_types[i] & 3);
printf("%d%s", val, i < n_layers - 1 ? ", " : "]");
}
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
int32_t curr_skipped = (skip_types[new_sl] >> 2) | (skip_types[new_sl] & 3);
// printf("||%d, %d\n", new_sl, curr_skipped);
if (curr_skipped == 3) continue; // Already tested or perm skip.
skip_layer = new_sl;
test_skip_type = (curr_skipped & 1) != 0 ? 2 : 1;
break;
}
if (skip_layer == -1 || skip_layer == n_layers) break;
}

i = 0;
count = 0;
nll = 0;
nll2 = 0;
logit_history.clear();
prob_history.clear();

int alive = 0;
for (int32_t i = 0; i < n_layers; i++) {
layers[i] = mask ^ ((skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0));
alive += !(layers[i] & 1) + !(layers[i] & 2);
}
layers[n_layers] = -1;
printf("\nTEST %c%3d + [", int(label[test_skip_type]), skip_layer);
for (auto l : skips) {
printf("%c%d, ", int(label[skip_types[l % n_layers] & 3]), l % n_layers);
}
printf("] - live: %3d/%3d, best:(%c%3d @ %.3f), last took %.2f sec\n",
alive, n_layers * 2,
int(label[curr_best_type]), curr_best_layer,
curr_best_ppl != -1 ? curr_best_ppl - ref_ppl : 0,
test_t_total);
test_t_start = std::chrono::high_resolution_clock::now();
}
const int start = i * n_ctx;
const int end = start + n_ctx;

Expand All @@ -353,7 +490,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
tokens[batch_start] = llama_token_bos(ctx);
}

if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
batch.n_tokens = batch_size;
batch.token = tokens.data() + batch_start;
batch.all_pos_0 = j * n_batch;

if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
Expand All @@ -367,7 +508,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

const auto t_end = std::chrono::high_resolution_clock::now();

if (i == 0) {
if (i == 0 && skip_layer < 0 && ref_ppl < 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
Expand Down Expand Up @@ -396,15 +537,29 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
count += n_ctx - first - 1;

// perplexity is e^(average negative log-likelihood)
double ppl = std::exp(nll / count);
if (params.ppl_output_type == 0) {
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
printf("[%d]%.4lf,", i + 1, ppl);
} else {
double av = nll/count;
double av2 = nll2/count - av*av;
if (av2 > 0) av2 = sqrt(av2/(count-1));
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
}
fflush(stdout);
if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 30))) {
i = test_count - 1;
skip_types[skip_layer] |= test_skip_type << 2;
if (curr_best_layer == -1 || ppl < curr_best_ppl) {
curr_best_layer = skip_layer;
curr_best_ppl = ppl;
curr_best_type = test_skip_type;
}
printf(" -- %.3f", ppl - ref_ppl);
pass_results.push_back({skip_layer, test_skip_type, ppl});
} else if (skip_layer < 0) {
ref_ppl = ppl;
}
}
printf("\n");

Expand Down
Loading