From 51bb7f0eef1ba932913f2c7f555e98f43a758658 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 29 Jan 2024 14:58:40 +0200
Subject: [PATCH 1/4] server : fix context shift + simplify self-extend

---
 examples/server/server.cpp | 49 +++++++++++++++++---------------------
 1 file changed, 22 insertions(+), 27 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 11dd82c33c106..913e6098e972c 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -188,8 +188,6 @@ struct llama_client_slot
     int32_t ga_n = 1;// group-attention factor
     int32_t ga_w = 512; // group-attention width
 
-    int32_t n_past_se = 0; // self-extend
-
     // multimodal
     std::vector<slot_image> images;
 
@@ -219,7 +217,7 @@ struct llama_client_slot
         sent_token_probs_index = 0;
         infill                 = false;
         ga_i                   = 0;
-        n_past_se              = 0;
+
         generated_token_probs.clear();
 
         for (slot_image & img : images)
@@ -1364,18 +1362,18 @@ struct llama_server_context
                 kv_cache_clear();
             }
             return true;
-        } else {
-            task_server task;
-            task.type = TASK_TYPE_NEXT_RESPONSE;
-            task.target_id = -1;
-            queue_tasks.post(task);
         }
+        task_server task;
+        task.type = TASK_TYPE_NEXT_RESPONSE;
+        task.target_id = -1;
+        queue_tasks.post(task);
+
         for (llama_client_slot &slot : slots)
         {
             if (slot.ga_n == 1)
             {
-                if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
+                if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx)
                 {
                     // Shift context
                     const int n_left = slot.n_past - slot.params.n_keep - 1;
@@ -1428,8 +1426,7 @@ struct llama_server_context
 
             slot.i_batch = batch.n_tokens;
 
-            const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
+            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
         }
@@ -1527,7 +1524,6 @@ struct llama_server_context
                     llama_sampling_reset(slot.ctx_sampling);
 
                     slot.n_past = 0;
-                    slot.n_past_se = 0;
                     slot.ga_i = 0;
                     slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
                 }
@@ -1557,7 +1553,7 @@ struct llama_server_context
                             }
                             slot_npast++;
                         }
-                        slot.n_past_se = slot_npast;
+                        slot.n_past = slot_npast;
                         slot.ga_i = ga_i;
                     }
 
@@ -1577,7 +1573,7 @@ struct llama_server_context
                         slot.n_past--;
                         if (slot.ga_i > 0)
                         {
-                            slot.n_past_se--;
+                            slot.n_past--;
                         }
                     }
 
@@ -1591,7 +1587,6 @@ struct llama_server_context
                     // process the prefix of first image
                     std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
 
-                    int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
                     int ga_i = slot.ga_i;
                     int32_t ga_n = slot.ga_n;
                     int32_t ga_w = slot.ga_w;
@@ -1599,14 +1594,14 @@ struct llama_server_context
                     {
                         if (slot.ga_n != 1)
                         {
-                            while (slot_npast >= ga_i + ga_w) {
+                            while (slot.n_past >= ga_i + ga_w) {
                                 const int bd = (ga_w/ga_n)*(ga_n - 1);
-                                slot_npast -= bd;
+                                slot.n_past -= bd;
                                 ga_i += ga_w/ga_n;
                             }
                         }
-                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
-                        slot_npast += 1;
+                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, {slot.id }, false);
+                        slot.n_past += 1;
                     }
 
                     if (has_images && !ingest_images(slot, n_batch))
@@ -1642,28 +1637,28 @@ struct llama_server_context
                 if (slot.ga_n != 1)
                 {
                     // context extension via Self-Extend
-                    while (slot.n_past_se >= slot.ga_i + slot.ga_w)
+                    while (slot.n_past >= slot.ga_i + slot.ga_w)
                     {
                         const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
                         const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
                         const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
 
                         LOG_TEE("\n");
-                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past, ib * bd, slot.ga_i + ib * bd, slot.n_past + ib * bd);
                         LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past + ib * bd + dd);
 
-                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past, ib * bd);
                         llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
-                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past + ib * bd, dd);
 
-                        slot.n_past_se -= bd;
+                        slot.n_past -= bd;
 
                         slot.ga_i += slot.ga_w / slot.ga_n;
 
-                        LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
+                        LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past + bd, slot.n_past, slot.ga_i);
                     }
 
-                    slot.n_past_se += n_tokens;
+                    slot.n_past += n_tokens;
                 }
             }
 
             llama_batch batch_view =
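
[Editor's note: illustrative sketch, not part of the patch series. PATCH 1/4 collapses the separate self-extend counter into slot.n_past, and keeps the context-shift branch that halves the evictable part of the cache once system_tokens plus cached tokens reach n_ctx. The standalone function below mirrors the same n_left/n_discard arithmetic on a plain token vector; shift_context and its signature are hypothetical stand-ins for the llama_kv_cache_seq_rm and llama_kv_cache_seq_shift calls that do the real work in server.cpp.]

    #include <cstdio>
    #include <vector>

    // Keep the first n_keep + 1 tokens, discard the older half of the rest,
    // and slide the tail left so generation can continue in the freed space.
    static void shift_context(std::vector<int> & cache_tokens, int & n_past, int n_keep) {
        const int n_left    = n_past - n_keep - 1; // tokens eligible for eviction
        const int n_discard = n_left / 2;          // drop the older half of them

        // the real server removes [n_keep + 1, n_keep + 1 + n_discard) from the
        // KV cache and shifts the remaining positions down by n_discard
        for (size_t i = n_keep + 1 + n_discard; i < cache_tokens.size(); i++) {
            cache_tokens[i - n_discard] = cache_tokens[i];
        }
        cache_tokens.resize(cache_tokens.size() - n_discard);

        n_past -= n_discard;

        printf("context shift: n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
    }
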
From 8772d3ee638f11e575a5a38b82bb58b05b364aef Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 29 Jan 2024 15:52:18 +0200
Subject: [PATCH 2/4] server : take system_tokens into account

---
 examples/server/server.cpp | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 913e6098e972c..17039e8d93c78 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1225,7 +1225,7 @@ struct llama_server_context
                     std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
                     for (int i = 0; i < (int) append_tokens.size(); ++i)
                     {
-                        llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true);
+                        llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
                         slot.n_past += 1;
                     }
                 }
@@ -1376,12 +1376,12 @@ struct llama_server_context
                 if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx)
                 {
                     // Shift context
-                    const int n_left    = slot.n_past - slot.params.n_keep - 1;
+                    const int n_left    = system_tokens.size() + slot.n_past - slot.params.n_keep - 1;
                     const int n_discard = n_left / 2;
 
                     LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
                     llama_kv_cache_seq_rm   (ctx, slot.id, slot.params.n_keep + 1            , slot.params.n_keep + n_discard + 1);
-                    llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
+                    llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, system_tokens.size() + slot.n_past, -n_discard);
 
                     for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
                     {
@@ -1426,6 +1426,8 @@ struct llama_server_context
 
             slot.i_batch = batch.n_tokens;
 
+            // TODO: we always have to take into account the "system_tokens"
+            //       this is not great and needs to be improved somehow
             llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
@@ -1478,8 +1480,8 @@ struct llama_server_context
                         prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
                         prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
-                        prefix_tokens.insert(prefix_tokens.end(),   llama_token_suffix(model));
-                        prefix_tokens.insert(prefix_tokens.end(),   suffix_tokens.begin(), suffix_tokens.end());
+                        prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
+                        prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
                         prefix_tokens.push_back(llama_token_middle(model));
                         prompt_tokens = prefix_tokens;
                     }
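
[Editor's note: illustrative sketch, not part of the patch series. PATCH 2/4 enforces one invariant: positions handed to the KV cache are absolute, so every per-slot position must be offset by the shared system prompt. The hypothetical helper below states that rule in code; both the context-shift bounds and every llama_batch_add call in the patch are instances of this sum, and the TODO comment acknowledges that threading the offset through each call site by hand is error-prone.]

    #include <cstdint>
    #include <vector>

    struct slot_state {
        int32_t n_past = 0; // tokens already evaluated, relative to this slot
    };

    // absolute KV-cache position of the slot's next token: the shared system
    // prompt occupies [0, system_tokens.size()), slot tokens follow it
    static int32_t kv_pos(const std::vector<int32_t> & system_tokens, const slot_state & slot) {
        return (int32_t) system_tokens.size() + slot.n_past;
    }
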
From d0e10bf1b2617712dfb3a89b7ae264f77dde9a3d Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 30 Jan 2024 13:22:33 +0200
Subject: [PATCH 3/4] server : more n_past fixes

---
 examples/server/chat.sh    |   1 +
 examples/server/server.cpp | 122 ++++++++++++++++++-------------------
 2 files changed, 61 insertions(+), 62 deletions(-)

diff --git a/examples/server/chat.sh b/examples/server/chat.sh
index 0143601214b15..da0a6ca68ca6f 100755
--- a/examples/server/chat.sh
+++ b/examples/server/chat.sh
@@ -48,6 +48,7 @@ chat_completion() {
             top_p: 0.9,
             n_keep: $n_keep,
             n_predict: 256,
+            cache_prompt: true,
             stop: ["\n### Human:"],
             stream: true
         }')"
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 17039e8d93c78..7ff3622f9d6df 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -185,7 +185,7 @@ struct llama_client_slot
     llama_sampling_context *ctx_sampling = nullptr;
 
     int32_t ga_i = 0;   // group-attention state
-    int32_t ga_n = 1;// group-attention factor
+    int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
 
     // multimodal
@@ -1293,6 +1293,7 @@ struct llama_server_context
             for (llama_client_slot &slot : slots)
             {
                 slot.cache_tokens.clear();
+                slot.n_past = 0;
             }
         }
 
@@ -1429,7 +1430,6 @@ struct llama_server_context
             // TODO: we always have to take into account the "system_tokens"
             //       this is not great and needs to be improved somehow
             llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);
-
             slot.n_past += 1;
         }
 
@@ -1540,25 +1540,6 @@ struct llama_server_context
                         slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                         slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
 
-                        if (slot.ga_n != 1)
-                        {
-                            int ga_i = 0;
-                            int32_t ga_n = slot.ga_n;
-                            int32_t ga_w = slot.ga_w;
-                            int32_t slot_npast = 0;
-                            for (int k = 0; k < slot.n_past; ++k)
-                            {
-                                while (slot_npast >= ga_i + ga_w) {
-                                    const int bd = (ga_w/ga_n)*(ga_n - 1);
-                                    slot_npast -= bd;
-                                    ga_i += ga_w/ga_n;
-                                }
-                                slot_npast++;
-                            }
-                            slot.n_past = slot_npast;
-                            slot.ga_i = ga_i;
-                        }
-
                         LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
                     }
 
@@ -1573,25 +1554,44 @@ struct llama_server_context
                         // we have to evaluate at least 1 token to generate logits.
                         LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
                         slot.n_past--;
-                        if (slot.ga_i > 0)
-                        {
-                            slot.n_past--;
-                        }
                     }
 
                     LOG_VERBOSE("prompt ingested", {
-                                                    {"n_past", slot.n_past},
-                                                    {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
+                                                       {"n_past",  slot.n_past},
+                                                       {"cached",  tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
                                                        {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
                                                    });
 
+                    if (slot.ga_n != 1)
+                    {
+                        int ga_i = 0;
+                        int32_t ga_n = slot.ga_n;
+                        int32_t ga_w = slot.ga_w;
+                        int32_t slot_npast = 0;
+                        for (int k = 0; k < slot.n_past; ++k)
+                        {
+                            while (slot_npast >= ga_i + ga_w) {
+                                const int bd = (ga_w/ga_n)*(ga_n - 1);
+                                slot_npast -= bd;
+                                ga_i += ga_w/ga_n;
+                            }
+                            slot_npast++;
+                        }
+                        slot.n_past = slot_npast;
+                        slot.ga_i = ga_i;
+
+                        LOG_TEE("slot %d : applied self-extend to prompt: %i tokens\n", slot.id, slot.n_past);
+                    }
+
                     const bool has_images = process_images(slot);
 
                     // process the prefix of first image
                     std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
-                    int ga_i = slot.ga_i;
+
+                    int32_t ga_i = slot.ga_i;
                     int32_t ga_n = slot.ga_n;
                     int32_t ga_w = slot.ga_w;
+
                     for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
                     {
                         if (slot.ga_n != 1)
@@ -1603,7 +1603,6 @@ struct llama_server_context
                             }
                         }
                         llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, {slot.id }, false);
-                        slot.n_past += 1;
                     }
 
                     if (has_images && !ingest_images(slot, n_batch))
@@ -1660,7 +1659,6 @@ struct llama_server_context
                         LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past + bd, slot.n_past, slot.ga_i);
                     }
 
-                    slot.n_past += n_tokens;
                 }
             }
 
             llama_batch batch_view =
@@ -1779,51 +1777,51 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("                        not recommended: doubles context memory required and no measurable increase in quality\n");
     if (llama_mlock_supported())
     {
-        printf("  --mlock           force system to keep model in RAM rather than swapping or compressing\n");
+        printf("  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
     }
     if (llama_mmap_supported())
    {
-        printf("  --no-mmap         do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
+        printf("  --no-mmap             do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
     }
-    printf("  --numa            attempt optimizations that help on some NUMA systems\n");
+    printf("  --numa                attempt optimizations that help on some NUMA systems\n");
 #ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
     printf("  -ngl N, --n-gpu-layers N\n");
-    printf("                    number of layers to store in VRAM\n");
+    printf("                        number of layers to store in VRAM\n");
     printf("  -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
-    printf("                    how to split the model across multiple GPUs, one of:\n");
-    printf("                      - none: use one GPU only\n");
-    printf("                      - layer (default): split layers and KV across GPUs\n");
-    printf("                      - row: split rows across GPUs\n");
+    printf("                        how to split the model across multiple GPUs, one of:\n");
+    printf("                          - none: use one GPU only\n");
+    printf("                          - layer (default): split layers and KV across GPUs\n");
+    printf("                          - row: split rows across GPUs\n");
     printf("  -ts SPLIT --tensor-split SPLIT\n");
-    printf("                    fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
-    printf("  -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
-    printf("                    or for intermediate results and KV (with split-mode = row)\n");
+    printf("                        fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
+    printf("  -mg i, --main-gpu i   the GPU to use for the model (with split-mode = none),\n");
+    printf("                        or for intermediate results and KV (with split-mode = row)\n");
 #endif
     printf("  -m FNAME, --model FNAME\n");
-    printf("                    model path (default: %s)\n", params.model.c_str());
+    printf("                        model path (default: %s)\n", params.model.c_str());
     printf("  -a ALIAS, --alias ALIAS\n");
-    printf("                    set an alias for the model, will be added as `model` field in completion response\n");
-    printf("  --lora FNAME      apply LoRA adapter (implies --no-mmap)\n");
-    printf("  --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
-    printf("  --host            ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
-    printf("  --port PORT       port to listen (default (default: %d)\n", sparams.port);
-    printf("  --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
-    printf("  --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
-    printf("  --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
-    printf("  -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
-    printf("  --embedding       enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
-    printf("  -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
-    printf("  -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
-    printf("  -spf FNAME, --system-prompt-file FNAME\n");
-    printf("                    Set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
-    printf("  --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n");
-    printf("  --log-disable     disables logging to a file.\n");
+    printf("                        set an alias for the model, will be added as `model` field in completion response\n");
+    printf("  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
+    printf("  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
+    printf("  --host                ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
+    printf("  --port PORT           port to listen (default (default: %d)\n", sparams.port);
+    printf("  --path PUBLIC_PATH    path from which to serve static files (default %s)\n", sparams.public_path.c_str());
+    printf("  --api-key API_KEY     optional api key to enhance server security. If set, requests must include this key for access.\n");
+    printf("  --api-key-file FNAME  path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
+    printf("  -to N, --timeout N    server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
+    printf("  --embedding           enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
+    printf("  -np N, --parallel N   number of slots for process requests (default: %d)\n", params.n_parallel);
+    printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
+    printf("  -spf FNAME, --system-prompt-file FNAME\n");
+    printf("                        set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
+    printf("  --mmproj MMPROJ_FILE  path to a multimodal projector file for LLaVA.\n");
+    printf("  --log-disable         disables logging to a file.\n");
     printf("\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
-    printf("                    advanced option to override model metadata by key. may be specified multiple times.\n");
-    printf("                    types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
-    printf("  -gan N, --grp-attn-n N Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
-    printf("  -gaw N, --grp-attn-w N Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
+    printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
+    printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
+    printf("  -gan N, --grp-attn-n N  set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
+    printf("  -gaw N, --grp-attn-w N  set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
     printf("\n");
 }
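
[Editor's note: illustrative sketch, not part of the patch series. PATCH 3/4 moves the self-extend remapping of a cached prompt to a single place after prompt ingestion and folds its result into slot.n_past. The loop below is the same remapping written as a standalone function, in an assumed simplified form: every position at or beyond ga_i + ga_w is repeatedly pulled back by bd = (ga_w/ga_n)*(ga_n - 1), which is how grouped attention compresses absolute positions into a smaller coordinate range.]

    #include <cstdint>

    struct ga_remap_result {
        int32_t ga_i;   // start of the current group-attention window
        int32_t n_past; // position count in grouped-attention coordinates
    };

    static ga_remap_result self_extend_remap(int32_t n_past, int32_t ga_n, int32_t ga_w) {
        int32_t ga_i       = 0;
        int32_t slot_npast = 0;
        for (int32_t k = 0; k < n_past; ++k) {
            // pull the position back into the current window before counting it
            while (slot_npast >= ga_i + ga_w) {
                const int32_t bd = (ga_w/ga_n)*(ga_n - 1);
                slot_npast -= bd;
                ga_i       += ga_w/ga_n;
            }
            slot_npast++;
        }
        return { ga_i, slot_npast };
    }
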
From 05350f286289b1b95e6712af9fca36e8b38e0e62 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 30 Jan 2024 19:05:36 +0200
Subject: [PATCH 4/4] server : revert n_past_se changes

---
 examples/server/server.cpp | 80 ++++++++++++++++++++++----------------
 1 file changed, 47 insertions(+), 33 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 7ff3622f9d6df..21bdce8edb780 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -188,6 +188,8 @@ struct llama_client_slot
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
 
+    int32_t n_past_se = 0; // self-extend
+
     // multimodal
     std::vector<slot_image> images;
 
@@ -217,6 +219,7 @@ struct llama_client_slot
         sent_token_probs_index = 0;
         infill                 = false;
         ga_i                   = 0;
+        n_past_se              = 0;
 
         generated_token_probs.clear();
 
@@ -1293,7 +1296,8 @@ struct llama_server_context
             for (llama_client_slot &slot : slots)
             {
                 slot.cache_tokens.clear();
-                slot.n_past = 0;
+                slot.n_past    = 0;
+                slot.n_past_se = 0;
             }
         }
 
@@ -1427,9 +1431,11 @@ struct llama_server_context
 
             slot.i_batch = batch.n_tokens;
 
+            const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
+
             // TODO: we always have to take into account the "system_tokens"
             //       this is not great and needs to be improved somehow
-            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);
+            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
 
             slot.n_past += 1;
         }
@@ -1526,6 +1532,7 @@ struct llama_server_context
                     llama_sampling_reset(slot.ctx_sampling);
 
                     slot.n_past = 0;
+                    slot.n_past_se = 0;
                     slot.ga_i = 0;
                     slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
                 }
@@ -1540,6 +1547,25 @@ struct llama_server_context
                         slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                         slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
 
+                        if (slot.ga_n != 1)
+                        {
+                            int ga_i = 0;
+                            int32_t ga_n = slot.ga_n;
+                            int32_t ga_w = slot.ga_w;
+                            int32_t slot_npast = 0;
+                            for (int k = 0; k < slot.n_past; ++k)
+                            {
+                                while (slot_npast >= ga_i + ga_w) {
+                                    const int bd = (ga_w/ga_n)*(ga_n - 1);
+                                    slot_npast -= bd;
+                                    ga_i += ga_w/ga_n;
+                                }
+                                slot_npast++;
+                            }
+                            slot.n_past_se = slot_npast;
+                            slot.ga_i = ga_i;
+                        }
+
                         LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
                     }
 
@@ -1554,6 +1580,10 @@ struct llama_server_context
                         // we have to evaluate at least 1 token to generate logits.
                         LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
                         slot.n_past--;
+                        if (slot.ga_i > 0)
+                        {
+                            slot.n_past_se--;
+                        }
                     }
 
                     LOG_VERBOSE("prompt ingested", {
@@ -1562,32 +1592,13 @@ struct llama_server_context
                                                        {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
                                                    });
 
-                    if (slot.ga_n != 1)
-                    {
-                        int ga_i = 0;
-                        int32_t ga_n = slot.ga_n;
-                        int32_t ga_w = slot.ga_w;
-                        int32_t slot_npast = 0;
-                        for (int k = 0; k < slot.n_past; ++k)
-                        {
-                            while (slot_npast >= ga_i + ga_w) {
-                                const int bd = (ga_w/ga_n)*(ga_n - 1);
-                                slot_npast -= bd;
-                                ga_i += ga_w/ga_n;
-                            }
-                            slot_npast++;
-                        }
-                        slot.n_past = slot_npast;
-                        slot.ga_i = ga_i;
-
-                        LOG_TEE("slot %d : applied self-extend to prompt: %i tokens\n", slot.id, slot.n_past);
-                    }
-
                     const bool has_images = process_images(slot);
 
                     // process the prefix of first image
                     std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
 
+                    int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
+
                     int32_t ga_i = slot.ga_i;
                     int32_t ga_n = slot.ga_n;
                     int32_t ga_w = slot.ga_w;
 
@@ -1596,13 +1607,14 @@ struct llama_server_context
                     {
                         if (slot.ga_n != 1)
                         {
-                            while (slot.n_past >= ga_i + ga_w) {
+                            while (slot_npast >= ga_i + ga_w) {
                                 const int bd = (ga_w/ga_n)*(ga_n - 1);
-                                slot.n_past -= bd;
+                                slot_npast -= bd;
                                 ga_i += ga_w/ga_n;
                             }
                         }
-                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, {slot.id }, false);
+                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
+                        slot_npast++;
                     }
 
                     if (has_images && !ingest_images(slot, n_batch))
@@ -1638,29 +1650,31 @@ struct llama_server_context
                 if (slot.ga_n != 1)
                 {
                     // context extension via Self-Extend
-                    while (slot.n_past >= slot.ga_i + slot.ga_w)
+                    while (slot.n_past_se >= slot.ga_i + slot.ga_w)
                     {
                         const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
                         const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
                         const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
 
                         LOG_TEE("\n");
-                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past, ib * bd, slot.ga_i + ib * bd, slot.n_past + ib * bd);
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
                         LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past + ib * bd + dd);
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
 
-                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past, ib * bd);
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
                         llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
-                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past + ib * bd, dd);
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
 
-                        slot.n_past -= bd;
+                        slot.n_past_se -= bd;
 
                         slot.ga_i += slot.ga_w / slot.ga_n;
 
-                        LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past + bd, slot.n_past, slot.ga_i);
+                        LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
                     }
+
+                    slot.n_past_se += n_tokens;
                 }
             }
 
             llama_batch batch_view =
             {
                 n_tokens,
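
[Editor's note: illustrative sketch, not part of the patch series. PATCH 4/4 restores the two-counter scheme that PATCH 1/4 had removed: slot.n_past counts tokens in regular coordinates, while slot.n_past_se tracks the shifted self-extend positions and is consulted only when non-zero. The helper below is hypothetical and simply summarizes the selection rule that the restored slot_npast expressions implement at each llama_batch_add site.]

    #include <cstdint>

    struct slot_counters {
        int32_t n_past    = 0; // tokens processed, regular positions
        int32_t n_past_se = 0; // self-extend positions; 0 means self-extend is off
    };

    // position passed to the batch: prefer the self-extend counter when active,
    // always offset by the shared system prompt
    static int32_t batch_pos(const slot_counters & s, int32_t n_system_tokens) {
        const int32_t slot_npast = s.n_past_se > 0 ? s.n_past_se : s.n_past;
        return n_system_tokens + slot_npast;
    }
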