Save and restore prompt evaluation state for much faster startup times #1169

Merged 1 commit on Apr 28, 2023

4 changes: 2 additions & 2 deletions examples/chat-13B.sh
@@ -31,8 +31,6 @@ The transcript only includes text, it does not include markup like HTML and Mark

$USER_NAME: Hello, $AI_NAME!
$AI_NAME: Hello $USER_NAME! How may I help you today?
-$USER_NAME: What time is it?
-$AI_NAME: It is $(date +%H:%M).
$USER_NAME: What year is it?
$AI_NAME: We are in $(date +%Y).
$USER_NAME: Please tell me the largest city in Europe.
@@ -50,4 +48,6 @@ $AI_NAME: The arguments are stored in process.argv.
argv[3] is the second argument passed to the script and so on.
$USER_NAME: Name a color.
$AI_NAME: Blue
+$USER_NAME: What time is it?
+$AI_NAME: It is $(date +%H:%M).
$USER_NAME:" "$@"

(The time question moves to the end of the prompt so that the minute-resolution $(date +%H:%M) output does not change the cached session prefix between runs.)
7 changes: 7 additions & 0 deletions examples/common.cpp
@@ -61,6 +61,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.prompt = argv[i];
} else if (arg == "--session") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.path_session = argv[i];
} else if (arg == "-f" || arg == "--file") {
if (++i >= argc) {
invalid_param = true;
@@ -228,6 +234,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
fprintf(stderr, " prompt to start generation with (default: empty)\n");
fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n");
fprintf(stderr, " --random-prompt start with a randomized prompt.\n");
fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n");
fprintf(stderr, " -f FNAME, --file FNAME\n");
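
With the new flag, a session-backed run of the main example would look something like `./main -m models/7B/ggml-model.bin --session ./prompt.session -p "..."` (paths here are illustrative, not from this PR): the first run evaluates the prompt and writes the cached state, and later runs with a matching prompt prefix skip most of that work.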
1 change: 1 addition & 0 deletions examples/common.h
@@ -31,6 +31,7 @@ struct gpt_params {

std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt = "";
std::string path_session = ""; // path to file for saving/loading model eval state
std::string input_prefix = ""; // string to prefix user inputs with
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

89 changes: 89 additions & 0 deletions examples/main/main.cpp
@@ -157,6 +157,32 @@ int main(int argc, char ** argv) {
// Add a space in front of the first character to match OG llama tokenizer behavior
params.prompt.insert(0, 1, ' ');

std::string path_session = params.path_session;
std::vector<llama_token> session_tokens;

if (!path_session.empty()) {
fprintf(stderr, "%s: attempting to load saved session from %s..\n", __func__, path_session.c_str());

// REVIEW - fopen to check for existing session
FILE * fp = std::fopen(path_session.c_str(), "rb");
if (fp != NULL) {
std::fclose(fp);

session_tokens.resize(params.n_ctx);
size_t n_token_count_out = 0;
const size_t n_session_bytes = llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out);
session_tokens.resize(n_token_count_out);

if (n_session_bytes > 0) {
fprintf(stderr, "%s: loaded %zu bytes of session data!\n", __func__, n_session_bytes);
} else {
fprintf(stderr, "%s: could not load session file, will recreate\n", __func__);
}
} else {
fprintf(stderr, "%s: session file does not exist, will create\n", __func__);
}
}

// tokenize the prompt
auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);

@@ -167,6 +193,26 @@ int main(int argc, char ** argv) {
return 1;
}

// debug message about similarity of saved session, if applicable
size_t n_matching_session_tokens = 0;
if (session_tokens.size()) {
for (llama_token id : session_tokens) {
if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
break;
}
n_matching_session_tokens++;
}
if (n_matching_session_tokens >= embd_inp.size()) {
fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
} else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
__func__, n_matching_session_tokens, embd_inp.size());
} else {
fprintf(stderr, "%s: session file matches %zu / %zu tokens of prompt\n",
__func__, n_matching_session_tokens, embd_inp.size());
}
}

// number of tokens to keep when resetting context
if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) {
params.n_keep = (int)embd_inp.size();
@@ -252,9 +298,16 @@ int main(int argc, char ** argv) {
bool is_antiprompt = false;
bool input_noecho = false;

// HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
// if we loaded a session with at least 75% similarity. It's currently just used to speed up the
// initial prompt so it doesn't need to be an exact match.
bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);

int n_past = 0;
int n_remain = params.n_predict;
int n_consumed = 0;
int n_session_consumed = 0;

// the first thing we will do is to output the prompt, so set color accordingly
set_console_color(con_st, CONSOLE_COLOR_PROMPT);
@@ -276,6 +329,9 @@ int main(int argc, char ** argv) {
// insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());

// REVIEW - stop saving session if we run out of context
path_session = "";

//printf("\n---\n");
//printf("resetting: '");
//for (int i = 0; i < (int) embd.size(); i++) {
@@ -285,6 +341,28 @@ int main(int argc, char ** argv) {
//printf("\n---\n");
}

// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
// REVIEW
if (n_session_consumed < (int) session_tokens.size()) {
size_t i = 0;
for ( ; i < embd.size(); i++) {
if (embd[i] != session_tokens[n_session_consumed]) {
session_tokens.resize(n_session_consumed);
break;
}

n_past++;
n_session_consumed++;

if (n_session_consumed >= (int) session_tokens.size()) {
break;
}
}
if (i > 0) {
embd.erase(embd.begin(), embd.begin() + i);
}
}

// evaluate tokens in batches
// embd is typically prepared beforehand to fit within a batch, but not always
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
@@ -298,6 +376,11 @@ int main(int argc, char ** argv) {
}
n_past += n_eval;
}

if (embd.size() > 0 && !path_session.empty()) {
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
n_session_consumed = session_tokens.size();
}
}

embd.clear();
@@ -309,6 +392,12 @@ int main(int argc, char ** argv) {
const float temp = params.temp;
const float repeat_penalty = params.repeat_penalty;

// optionally save the session on first sample (for faster prompt loading next time)
if (!path_session.empty() && need_to_save_session) {
need_to_save_session = false;
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
}

llama_token id = 0;

{
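
The reuse logic above amounts to a longest-common-prefix match between the loaded session tokens and the freshly tokenized prompt: every token inside the prefix is credited to n_past and skipped, and only the remainder is evaluated. A minimal standalone sketch of that step (hypothetical helper name, detached from the loop state in the patch):

#include <cstddef>
#include <vector>

using llama_token = int; // matches the llama_token typedef in llama.h

// Count how many leading tokens of the freshly tokenized prompt are already
// covered by the loaded session. Everything inside this prefix can be
// credited to n_past instead of being re-evaluated.
static size_t common_prefix_len(const std::vector<llama_token> & session,
                                const std::vector<llama_token> & prompt) {
    size_t n = 0;
    while (n < session.size() && n < prompt.size() && session[n] == prompt[n]) {
        n++;
    }
    return n;
}

The patch performs the same comparison incrementally per batch via n_session_consumed, and truncates session_tokens at the first mismatch so that freshly evaluated tokens can be appended and saved later.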
53 changes: 53 additions & 0 deletions llama.cpp
@@ -2412,3 +2412,56 @@ std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_te
return ctx->model.tensors_by_name;
}

size_t llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
// TODO leverage mmap
llama_file file(path_session, "rb");
const uint32_t magic = file.read_u32();
const uint32_t version = file.read_u32();

if (!(magic == 'ggsn' && version == 0)) {
fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
return 0;
}

llama_hparams session_hparams;
file.read_raw(&session_hparams, sizeof(llama_hparams));

// REVIEW
if (session_hparams != ctx->model.hparams) {
fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__);
return 0;
}

const uint32_t n_token_count = file.read_u32();
LLAMA_ASSERT(n_token_capacity >= n_token_count);
file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
*n_token_count_out = n_token_count;

const size_t n_state_size = file.size - file.tell();
const size_t n_orig_state_size = llama_get_state_size(ctx);
if (n_state_size != n_orig_state_size) {
fprintf(stderr, "%s : failed to validate state size\n", __func__);
}
std::unique_ptr<uint8_t[]> state_data(new uint8_t[n_state_size]);
file.read_raw(state_data.get(), n_state_size);
return llama_set_state_data(ctx, state_data.get());
}

size_t llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
// TODO save temp & swap
llama_file file(path_session, "wb");

const size_t n_state_size = llama_get_state_size(ctx);
std::unique_ptr<uint8_t[]> state_data(new uint8_t[n_state_size]);
llama_copy_state_data(ctx, state_data.get());

file.write_u32('ggsn'); // magic
file.write_u32(0); // version
file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));

file.write_u32((uint32_t) n_token_count); // REVIEW
file.write_raw(tokens, sizeof(llama_token) * n_token_count);

file.write_raw(state_data.get(), n_state_size);
return n_state_size; // REVIEW
}
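
For reference, the file written above (and read back in the same order by llama_load_session_file) is a flat concatenation:

magic          u32 'ggsn'
version        u32 0
hparams        raw llama_hparams struct bytes
n_token_count  u32
tokens         n_token_count * sizeof(llama_token) bytes
state          llama_get_state_size(ctx) bytes, the opaque blob filled by llama_copy_state_data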
4 changes: 4 additions & 0 deletions llama.h
@@ -127,6 +127,10 @@ extern "C" {
// Returns the number of bytes read
LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src);

// Save/load session file
LLAMA_API size_t llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
LLAMA_API size_t llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);

// Run the llama inference to obtain the logits and probabilities for the next token.
// tokens + n_tokens is the provided batch of new tokens to process
// n_past is the number of tokens to use from previous eval calls
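
Taken together, a minimal caller-side sketch of the two new entry points (a hypothetical flow: ctx is assumed to be an initialized llama_context, n_ctx its context size, and "prompt.session" an illustrative path; error handling and the eval loop are elided):

std::vector<llama_token> session_tokens(n_ctx);
size_t n_token_count = 0;

// llama_load_session_file returns the number of state bytes restored, or 0 on failure
if (llama_load_session_file(ctx, "prompt.session",
                            session_tokens.data(), session_tokens.size(), &n_token_count) > 0) {
    session_tokens.resize(n_token_count); // these tokens are already evaluated inside ctx
} else {
    session_tokens.clear(); // no usable session; evaluate the prompt from scratch
}

// ... evaluate any prompt tokens past the matching prefix, appending them to session_tokens ...

llama_save_session_file(ctx, "prompt.session",
                        session_tokens.data(), session_tokens.size());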