diff --git a/examples/common.cpp b/examples/common.cpp
index bd39d9220cd14..12f4368f4a925 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -119,6 +119,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             if (params.n_threads <= 0) {
                 params.n_threads = std::thread::hardware_concurrency();
             }
+        } else if (arg == "-ppt" || arg == "--pp-threads") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.pp_threads = std::stoi(argv[i]);
+            if (params.pp_threads <= 0) {
+                params.pp_threads = std::thread::hardware_concurrency();
+            }
         } else if (arg == "-p" || arg == "--prompt") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -545,6 +554,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --color               colorise output to distinguish prompt and user input from generations\n");
     fprintf(stdout, "  -s SEED, --seed SEED  RNG seed (default: -1, use random seed for < 0)\n");
     fprintf(stdout, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
+    fprintf(stdout, "  -ppt N, --pp-threads N\n");
+    fprintf(stdout, "                        number of threads to use during prompt processing (default: %d)\n", params.pp_threads);
     fprintf(stdout, "  -p PROMPT, --prompt PROMPT\n");
     fprintf(stdout, "                        prompt to start generation with (default: empty)\n");
     fprintf(stdout, "  -e                    process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
diff --git a/examples/common.h b/examples/common.h
index 375bc0a3db416..d59861aa0b255 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -19,6 +19,7 @@ int32_t get_num_physical_cores();
 struct gpt_params {
     uint32_t seed        = -1;  // RNG seed
     int32_t n_threads    = get_num_physical_cores();
+    int32_t pp_threads   = get_num_physical_cores();
     int32_t n_predict    = -1;  // new tokens to predict
     int32_t n_ctx        = 512; // context size
     int32_t n_batch      = 512; // batch size for prompt processing (must be >=32 to use BLAS)
diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp
index 2185b9b0e2839..d5df68637f4ec 100644
--- a/examples/embd-input/embd-input-lib.cpp
+++ b/examples/embd-input/embd-input-lib.cpp
@@ -83,7 +83,7 @@ bool eval_float(void * model, float * input, int N){
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) {
+        if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads, params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return false;
         }
@@ -104,7 +104,7 @@ bool eval_tokens(void * model, std::vector<llama_token> tokens) {
         if (n_eval > params.n_batch) {
             n_eval = params.n_batch;
         }
-        if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) {
+        if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads, params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return false;
         }
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 5192d6df5c2f8..58fa2753edd45 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -50,8 +50,8 @@ int main(int argc, char ** argv) {
     // print system information
     {
         fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+        fprintf(stderr, "system_info: n_threads = %d / %d | pp_threads = %d / %d | %s\n",
+                params.n_threads, std::thread::hardware_concurrency(), params.pp_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
     int n_past = 0;
@@ -74,7 +74,7 @@ int main(int argc, char ** argv) {
 
     if (params.embedding){
         if (embd_inp.size() > 0) {
-            if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads)) {
+            if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads, params.pp_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return 1;
             }
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 266c8eab3b2f6..d69dcaf1a06b6 100755
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -853,7 +853,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
     int n_processed = 0;
     while (n_processed < n_prompt) {
         int n_tokens = std::min(n_prompt - n_processed, n_batch);
-        llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads);
+        llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads, n_threads);
         n_processed += n_tokens;
     }
 }
@@ -861,7 +861,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
 
 static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
     llama_token token = llama_token_bos();
     for (int i = 0; i < n_gen; i++) {
-        llama_eval(ctx, &token, 1, n_past + i, n_threads);
+        llama_eval(ctx, &token, 1, n_past + i, n_threads, n_threads);
     }
 }
diff --git a/examples/main/README.md b/examples/main/README.md
index 60e3907d52f5b..4a51c5550c7c6 100644
--- a/examples/main/README.md
+++ b/examples/main/README.md
@@ -263,6 +263,7 @@ These options help improve the performance and memory usage of the LLaMA models.
 ### Number of Threads
 
 - `-t N, --threads N`: Set the number of threads to use during computation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). Using the correct number of threads can greatly improve performance.
+- `-ppt N, --pp-threads N`: Set the number of threads to use during prompt processing only.
 
 ### Mlock
 
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index a632bea1cf2b9..e0db33a370975 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -133,8 +133,8 @@ int main(int argc, char ** argv) {
     // print system information
     {
         fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+        fprintf(stderr, "system_info: n_threads = %d / %d | pp_threads = %d / %d | %s\n",
+                params.n_threads, std::thread::hardware_concurrency(), params.pp_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
     // determine the maximum memory usage needed to do inference for the given n_batch and n_ctx parameters
@@ -513,7 +513,8 @@ int main(int argc, char ** argv) {
 
             for (int i = 0; i < input_size; i += params.n_batch) {
                 int n_eval = std::min(input_size - i, params.n_batch);
-                if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
+                int eval_thr = n_eval > 1 ? params.pp_threads : params.n_threads;
+                if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, eval_thr)) {
                     fprintf(stderr, "%s : failed to eval\n", __func__);
                     return 1;
                 }
@@ -527,7 +528,8 @@ int main(int argc, char ** argv) {
                 if (n_eval > params.n_batch) {
                     n_eval = params.n_batch;
                 }
-                if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
+                int eval_thr = n_eval > 1 ? params.pp_threads : params.n_threads;
+                if (llama_eval(ctx, &embd[i], n_eval, n_past, eval_thr)) {
                     fprintf(stderr, "%s : failed to eval\n", __func__);
                     return 1;
                 }
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 2409db69f1afd..51bb5267cd6c9 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -67,7 +67,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
             tokens[batch_start] = llama_token_bos();
         }
 
-        if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
+        if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads, params.pp_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return;
         }
@@ -130,7 +130,7 @@ std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vec
     for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
         size_t n_tokens = tokens.size() - i_chunk * n_batch;
         n_tokens = std::min(n_tokens, size_t(n_batch));
-        if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
+        if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread, n_thread)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return {};
         }
@@ -402,8 +402,8 @@ int main(int argc, char ** argv) {
     // print system information
     {
         fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+        fprintf(stderr, "system_info: pp_threads = %d / %d | %s\n",
+                params.pp_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
     if (params.hellaswag) {
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 61c71c3589fdf..2a466c2e20fac 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -10,6 +10,7 @@ int main(int argc, char ** argv) {
     gpt_params params;
     params.seed = 42;
     params.n_threads = 4;
+    params.pp_threads = 4;
     params.repeat_last_n = 64;
     params.prompt = "The quick brown fox";
 
@@ -56,7 +57,7 @@ int main(int argc, char ** argv) {
     }
 
     // evaluate prompt
-    llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);
+    llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads, params.pp_threads);
 
     last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
     n_past += n_prompt_tokens;
@@ -93,7 +94,7 @@ int main(int argc, char ** argv) {
         last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str);
-        if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
+        if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads, params.pp_threads)) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_free(ctx);
            llama_free_model(model);
@@ -153,7 +154,7 @@ int main(int argc, char ** argv) {
         last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str);
-        if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
+        if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads, params.pp_threads)) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_free(ctx2);
            llama_free_model(model);
diff --git a/examples/server/README.md b/examples/server/README.md
index 1559dd3f2639a..b605e5269680b 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -5,6 +5,7 @@ This example demonstrates a simple HTTP API server and a simple web front end to
 Command line options:
 
 - `--threads N`, `-t N`: Set the number of threads to use during computation.
+- `-ppt N`, `--pp-threads N`: Set the number of threads to use during prompt processing only.
 - `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`).
 - `-m ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses.
 - `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. The size may differ in other models, for example, baichuan models were build with a context of 4096.
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 99660455ac0b1..b1c42bb79f155 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -385,7 +385,7 @@ struct llama_server_context
             {
                 n_eval = params.n_batch;
             }
-            if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
+            if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads, params.pp_threads))
             {
                 LOG_ERROR("failed to eval", {
                     {"n_eval", n_eval},
@@ -651,6 +651,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     fprintf(stdout, "  -h, --help            show this help message and exit\n");
     fprintf(stdout, "  -v, --verbose         verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
     fprintf(stdout, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
+    fprintf(stdout, "  -ppt N, --pp-threads N\n");
+    fprintf(stdout, "                        number of threads to use during prompt processing (default: %d)\n", params.pp_threads);
     fprintf(stdout, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stdout, "  -gqa N, --gqa N       grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
     fprintf(stdout, "  -eps N, --rms-norm-eps N  rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
@@ -822,6 +824,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.n_threads = std::stoi(argv[i]);
         }
+        else if (arg == "-ppt" || arg == "--pp-threads")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            params.pp_threads = std::stoi(argv[i]);
+        }
         else if (arg == "-b" || arg == "--batch-size")
         {
             if (++i >= argc)
@@ -1185,6 +1196,7 @@ int main(int argc, char **argv)
                             {"commit", BUILD_COMMIT}});
     LOG_INFO("system info", {
                                 {"n_threads", params.n_threads},
+                                {"pp_threads", params.pp_threads},
                                 {"total_threads", std::thread::hardware_concurrency()},
                                 {"system_info", llama_print_system_info()},
                             });
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index 97137a6584aa3..f093da32a3789 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -123,7 +123,7 @@ int main(int argc, char ** argv)
     // Evaluate the tokens :
     //---------------------------------
 
-    if ( llama_eval( ctx , tokens_list.data() , int(tokens_list.size()) , llama_get_kv_cache_token_count( ctx ) , params.n_threads ) )
+    if ( llama_eval( ctx , tokens_list.data() , int(tokens_list.size()) , llama_get_kv_cache_token_count( ctx ) , params.n_threads , params.n_threads ) )
     {
         fprintf( stderr, "%s : failed to eval\n" , __func__ );
         return 1;
diff --git a/llama.cpp b/llama.cpp
index f2cbe764142e5..ef2672ae9d510 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1786,7 +1786,7 @@ static struct ggml_cgraph * llama_build_graph(
 //   - embd       embeddings input
 //   - n_tokens   number of tokens
 //   - n_past:    the context size so far
-//   - n_threads: number of threads to use
+//   - n_threads: number of threads to use for inference
 //
 static bool llama_eval_internal(
          llama_context & lctx,
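
For reviewers, a minimal sketch of the calling convention this patch establishes. The matching `llama.h` change is not included in the excerpt above, so the six-argument `llama_eval(ctx, tokens, n_tokens, n_past, n_threads, pp_threads)` form below is an assumption inferred from the updated call sites, and `eval_in_batches` is a hypothetical helper used only for illustration, not part of the patch.

```cpp
// Sketch only: assumes llama_eval now accepts a separate prompt-processing
// thread count (pp_threads) after n_threads, as the updated call sites suggest.
#include <algorithm>
#include <cstdio>
#include <vector>

#include "llama.h"

// Hypothetical helper: evaluate a token sequence in batches of n_batch,
// forwarding both thread counts to llama_eval.
static bool eval_in_batches(llama_context * ctx, const std::vector<llama_token> & tokens,
                            int n_batch, int & n_past, int n_threads, int pp_threads) {
    for (int i = 0; i < (int) tokens.size(); i += n_batch) {
        const int n_eval = std::min((int) tokens.size() - i, n_batch);
        // Per the updated comment in llama.cpp, n_threads covers inference
        // (single-token steps), while pp_threads is intended for multi-token
        // prompt batches; main.cpp instead selects one count per call (eval_thr).
        if (llama_eval(ctx, tokens.data() + i, n_eval, n_past, n_threads, pp_threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return false;
        }
        n_past += n_eval;
    }
    return true;
}
```

On the command line the same split would look roughly like `./main -m models/7B/ggml-model.bin -t 4 -ppt 8 -p "..."` (model path and thread values illustrative), using the `-t/--threads` and `-ppt/--pp-threads` options documented in the README changes above.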