diff --git a/common/common.cpp b/common/common.cpp index 6b07f119718c0..83a0b1bddc653 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -446,6 +446,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.n_batch = std::stoi(argv[i]); + } else if (arg == "-ub" || arg == "--ubatch-size") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_ubatch = std::stoi(argv[i]); } else if (arg == "--keep") { if (++i >= argc) { invalid_param = true; @@ -926,6 +932,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); + printf(" -ub N, --ubatch-size N\n"); + printf(" micro batch size for prompt processing (default: %d)\n", params.n_ubatch); printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n"); printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str()); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); @@ -1171,6 +1179,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.n_ctx = params.n_ctx; cparams.n_batch = params.n_batch; + cparams.n_ubatch = params.n_ubatch; cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; cparams.mul_mat_q = params.mul_mat_q; diff --git a/common/common.h b/common/common.h index 214a379b57d1b..30d5f860c0144 100644 --- a/common/common.h +++ b/common/common.h @@ -51,7 +51,8 @@ struct gpt_params { int32_t n_threads_batch_draft = -1; int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 512; // context size - int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_batch = 4096; // batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_ubatch = 256; // batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt int32_t n_draft = 8; // number of tokens to draft during speculative decoding int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 97325b5bd634f..f6eb515a1cb52 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -202,7 +202,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); printf(" -mmq, --mul-mat-q <0|1> (default: %s)\n", join(cmd_params_defaults.mul_mat_q, ",").c_str()); - printf(" -ts, --tensor_split (default: 0)\n"); + printf(" -ts, --tensor-split (default: 0)\n"); printf(" -r, --repetitions (default: %d)\n", cmd_params_defaults.reps); printf(" -o, --output (default: %s)\n", output_format_str(cmd_params_defaults.output_format)); printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); @@ -1041,16 +1041,22 @@ struct sql_printer : public printer { }; static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) { + llama_set_n_threads(ctx, n_threads, n_threads); + + std::vector tokens(n_prompt, llama_token_bos(llama_get_model(ctx))); + llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt, n_past, 0)); + + GGML_UNUSED(n_batch); + +/* std::vector tokens(n_batch, llama_token_bos(llama_get_model(ctx))); int n_processed = 0; - llama_set_n_threads(ctx, n_threads, n_threads); - while (n_processed < n_prompt) { int n_tokens = std::min(n_prompt - n_processed, n_batch); - llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0)); n_processed += n_tokens; } +*/ } static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { @@ -1149,11 +1155,12 @@ int main(int argc, char ** argv) { // warmup run if (t.n_prompt > 0) { - test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads); + test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads); } if (t.n_gen > 0) { test_gen(ctx, 1, 0, t.n_threads); } + llama_get_logits(ctx); // force sync for (int i = 0; i < params.reps; i++) { llama_kv_cache_clear(ctx); @@ -1165,6 +1172,8 @@ int main(int argc, char ** argv) { if (t.n_gen > 0) { test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads); } + llama_get_logits(ctx); // force sync + uint64_t t_ns = get_time_ns() - t_start; t.samples_ns.push_back(t_ns); } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 8d2204969c0cb..03136b3e64e99 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -1782,13 +1782,13 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { int main(int argc, char ** argv) { gpt_params params; - params.n_batch = 512; + //params.n_batch = 512; if (!gpt_params_parse(argc, argv, params)) { return 1; } params.logits_all = true; - params.n_batch = std::min(params.n_batch, params.n_ctx); + //params.n_batch = std::min(params.n_batch, params.n_ctx); if (params.ppl_stride > 0) { fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n", diff --git a/ggml-alloc.c b/ggml-alloc.c index 60141a34d8f6a..da11001f88ff6 100644 --- a/ggml-alloc.c +++ b/ggml-alloc.c @@ -319,6 +319,18 @@ struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t alloc) { return alloc->buffer; } +void ggml_tallocr_set_buffer(ggml_tallocr_t talloc, struct ggml_backend_buffer * buffer) { + GGML_ASSERT(talloc->measure == false); + // FIXME: buffer ownership semantics + // if the user is doing this, they probably want to take ownership of the buffer + // or they need to restore the original buffer before freeing the allocator + //talloc->buffer_owned = false; + talloc->buffer = buffer; + talloc->base = ggml_backend_buffer_get_base(buffer); + talloc->alignment = ggml_backend_buffer_get_alignment(buffer); + ggml_tallocr_reset(talloc); +} + void ggml_tallocr_free(ggml_tallocr_t alloc) { if (alloc == NULL) { return; diff --git a/ggml-alloc.h b/ggml-alloc.h index 4e59975213406..08c3d84d36d8c 100644 --- a/ggml-alloc.h +++ b/ggml-alloc.h @@ -59,6 +59,7 @@ GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_buft(struct ggml_backend_b GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend); GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc); +GGML_API void ggml_tallocr_set_buffer(ggml_tallocr_t talloc, struct ggml_backend_buffer * buffer); GGML_API void ggml_tallocr_free (ggml_tallocr_t talloc); GGML_API bool ggml_tallocr_is_measure (ggml_tallocr_t talloc); diff --git a/ggml-backend-impl.h b/ggml-backend-impl.h index 1397828d9ac71..fb4980d94eb74 100644 --- a/ggml-backend-impl.h +++ b/ggml-backend-impl.h @@ -78,14 +78,14 @@ extern "C" { ggml_backend_buffer_type_t (*GGML_CALL get_default_buffer_type)(ggml_backend_t backend); // (optional) asynchronous tensor data access - void (*GGML_CALL set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*GGML_CALL get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - bool (*GGML_CALL cpy_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * src, struct ggml_tensor * dst); + void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); // (optional) complete all pending operations void (*GGML_CALL synchronize)(ggml_backend_t backend); - // compute graph with a plan + // compute graph with a plan (not used currently) ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph); void (*GGML_CALL graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); void (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); @@ -95,14 +95,25 @@ extern "C" { // check if the backend supports an operation bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op); + + // (optional) event synchronization + ggml_backend_event_t (*GGML_CALL event_new) (ggml_backend_t backend); + void (*GGML_CALL event_free) (ggml_backend_event_t event); + void (*GGML_CALL event_record) (ggml_backend_event_t event); + void (*GGML_CALL event_wait) (ggml_backend_t backend, ggml_backend_event_t event); + void (*GGML_CALL event_synchronize) (ggml_backend_event_t event); }; struct ggml_backend { struct ggml_backend_i iface; - ggml_backend_context_t context; }; + struct ggml_backend_event { + ggml_backend_t backend; + void * context; + }; + // // Backend registry // diff --git a/ggml-backend.c b/ggml-backend.c index 423512defc132..f1358c07a8bc1 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -194,21 +194,21 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten GGML_CALL void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(buf != NULL && "tensor buffer not set"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); - tensor->buffer->iface.set_tensor(buf, tensor, data, offset, size); + buf->iface.set_tensor(buf, tensor, data, offset, size); } GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); - tensor->buffer->iface.get_tensor(buf, tensor, data, offset, size); + buf->iface.get_tensor(buf, tensor, data, offset, size); } void ggml_backend_synchronize(ggml_backend_t backend) { @@ -279,30 +279,52 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst } } -void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) { +void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); if (src == dst) { return; } - if (ggml_backend_buft_supports_backend(src->buffer->buft, backend) && ggml_backend_buft_supports_backend(dst->buffer->buft, backend)) { - if (backend->iface.cpy_tensor_async != NULL) { - if (backend->iface.cpy_tensor_async(backend, src, dst)) { - return; - } + if (backend_dst->iface.cpy_tensor_async != NULL) { + if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) { + return; } } size_t nbytes = ggml_nbytes(src); if (ggml_backend_buffer_is_host(src->buffer)) { - ggml_backend_tensor_set_async(backend, dst, src->data, 0, nbytes); + // wait for src to be ready before copy + ggml_backend_synchronize(backend_src); + ggml_backend_tensor_set_async(backend_dst, dst, src->data, 0, nbytes); } else { ggml_backend_tensor_copy(src, dst); } } +// events + +ggml_backend_event_t ggml_backend_event_new(ggml_backend_t backend) { + return backend->iface.event_new(backend); +} + +void ggml_backend_event_free(ggml_backend_event_t event) { + event->backend->iface.event_free(event); + free(event); +} + +void ggml_backend_event_record(ggml_backend_event_t event) { + event->backend->iface.event_record(event); +} + +void ggml_backend_event_synchronize(ggml_backend_event_t event) { + event->backend->iface.event_synchronize(event); +} + +void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + backend->iface.event_wait(backend, event); +} // backend registry @@ -716,6 +738,11 @@ static struct ggml_backend_i cpu_backend_i = { /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute, /* .graph_compute = */ ggml_backend_cpu_graph_compute, /* .supports_op = */ ggml_backend_cpu_supports_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .event_synchronize = */ NULL, }; ggml_backend_t ggml_backend_cpu_init(void) { @@ -853,6 +880,8 @@ static ggml_tallocr_t sched_allocr_from_buffer(ggml_backend_sched_t sched, ggml_ return sched->tallocs[i]; } } + + fprintf(stderr, "%s: error: no backend supports buffer type %s\n", __func__, ggml_backend_buffer_name(buffer)); GGML_ASSERT(false && "tensor buffer type not supported by any backend"); } @@ -1315,6 +1344,7 @@ static void sched_compute_splits(ggml_backend_sched_t sched) { // copy the input tensors to the split backend uint64_t copy_start_us = ggml_time_us(); for (int j = 0; j < split->n_inputs; j++) { + ggml_backend_t input_backend = get_allocr_backend(sched, node_allocr(split->inputs[j])); struct ggml_tensor * input = split->inputs[j]; struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][split_backend_id]; @@ -1323,7 +1353,7 @@ static void sched_compute_splits(ggml_backend_sched_t sched) { // TODO: avoid this copy if it was already copied in a previous split, and the input didn't change // this is important to avoid copying constants such as KQ_mask and inp_pos multiple times - ggml_backend_tensor_copy_async(split_backend, input, input_cpy); + ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy); } //ggml_backend_synchronize(split_backend); // necessary to measure copy time int64_t copy_end_us = ggml_time_us(); @@ -1335,7 +1365,6 @@ static void sched_compute_splits(ggml_backend_sched_t sched) { ggml_graph_dump_dot(split->graph, NULL, split_filename); #endif - uint64_t compute_start_us = ggml_time_us(); if (!sched->callback_eval) { ggml_backend_graph_compute(split_backend, &split->graph); @@ -1471,6 +1500,11 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) { sched_reset(sched); } +void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) { + for (int i = 0; i < sched->n_backends; i++) { + ggml_backend_synchronize(sched->backends[i]); + } +} void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) { sched->callback_eval = callback; diff --git a/ggml-backend.h b/ggml-backend.h index ab4ad773ffbce..03f7dc1c4a636 100644 --- a/ggml-backend.h +++ b/ggml-backend.h @@ -9,6 +9,7 @@ extern "C" { typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t; typedef struct ggml_backend_buffer * ggml_backend_buffer_t; + typedef struct ggml_backend_event * ggml_backend_event_t; typedef struct ggml_backend * ggml_backend_t; typedef void * ggml_backend_graph_plan_t; @@ -47,7 +48,6 @@ extern "C" { // Backend // - GGML_API const char * ggml_backend_name(ggml_backend_t backend); GGML_API void ggml_backend_free(ggml_backend_t backend); @@ -72,7 +72,14 @@ extern "C" { // tensor copy between different backends GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); - GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); // automatic fallback to sync copy + GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t src_backend, ggml_backend_t dst_backend, struct ggml_tensor * src, struct ggml_tensor * dst); // automatic fallback to sync copy + + // events + GGML_API ggml_backend_event_t ggml_backend_event_new (ggml_backend_t backend); + GGML_API void ggml_backend_event_free (ggml_backend_event_t event); + GGML_API void ggml_backend_event_record (ggml_backend_event_t event); // can only be called from the backend that created the event + GGML_API void ggml_backend_event_synchronize(ggml_backend_event_t event); // can only be called from the backend that created the event + GGML_API void ggml_backend_event_wait (ggml_backend_t backend, ggml_backend_event_t event); // can be called from any backend // // CPU backend @@ -118,17 +125,21 @@ extern "C" { /* Example usage: - sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, num_backends); + // operations that use tensors allocated in a buffer with USAGE_WEIGHTS + // will be assigned preferrably to run on the buffer backend by ggml_backend_sched + ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + + sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE); // sched is initialized with measure allocators and cannot be used until allocated with a measure graph // initialize buffers from a measure graph measure_graph = build_graph(sched); // use the allocr to allocate inputs as needed // in build_graph: - build_graph(...) { + void build_graph(...) { // allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer) - alloc_cpu = ggml_backend_sched_get_allocr(sched, backend_cpu); - ggml_allocr_alloc(alloc_cpu, tensor); + alloc_cpu = ggml_backend_sched_get_tallocr(sched, backend_cpu); + ggml_tallocr_alloc(alloc_cpu, tensor); // manually assigning nodes to a backend (optional, shouldn't be needed in most cases) struct ggml_tensor * node = ggml_mul_mat(ctx, ...); @@ -143,6 +154,7 @@ extern "C" { // compute graph = build_graph(sched); ggml_backend_sched_graph_compute(sched, graph); + */ struct ggml_backend_sched; @@ -176,6 +188,9 @@ extern "C" { // Reset all assignments and allocators - must be called before using the sched allocators to allocate inputs GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched); + // Synchronize all backends + GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched); + // Set a callback to be called for each resulting node during graph compute GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data); diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7f460449eaa05..abaaee7496e5d 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -10813,6 +10813,10 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() { return &ggml_backend_cuda_buffer_type_host; } +static bool ggml_backend_buffer_is_cuda_host(ggml_backend_buffer_t buffer) { + return buffer->buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name; +} + // backend GGML_CALL static const char * ggml_backend_cuda_name(ggml_backend_t backend) { @@ -10836,8 +10840,9 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_cuda_get_default_buffer GGML_CALL static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; - GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[cuda_ctx->device][0])); @@ -10845,22 +10850,79 @@ GGML_CALL static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, GGML_CALL static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; - GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[cuda_ctx->device][0])); } -GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; +GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { + GGML_ASSERT(ggml_backend_is_cuda(backend_src) || ggml_backend_is_cuda(backend_dst)); + + ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer; + ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; - if (dst->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && ggml_backend_buffer_is_cuda(src->buffer)) { - CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, g_cudaStreams[cuda_ctx->device][0])); + // host -> device + if (ggml_backend_buffer_is_cuda_host(buf_src) && ggml_backend_buffer_is_cuda(buf_dst)) { + ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context; + // make sure the data is ready on the source backend + // the CPU backend does not support async compute, so this does nothing at the moment + // but conceptually, it is necessary to synchronize with the source backend + ggml_backend_synchronize(backend_src); + CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, g_cudaStreams[cuda_ctx_dst->device][0])); return true; } - return false; + // device -> host + if (ggml_backend_buffer_is_cuda_host(buf_dst) && ggml_backend_buffer_is_cuda(buf_src)) { + // this shoudln't happen currently because the dst backend is our own backend, which does not support host buffers + GGML_ASSERT(false); + ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context; + CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToHost, g_cudaStreams[cuda_ctx_src->device][0])); + return true; + } + + if (!ggml_backend_buffer_is_cuda(src->buffer)) { + return false; + } + + if (!ggml_backend_buffer_is_cuda(dst->buffer)) { + return false; + } + + // device -> device + ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context; + ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context; + + if (backend_src != backend_dst) { + ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; + ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; + + GGML_ASSERT(cuda_ctx_src->device == buf_ctx_src->device); + GGML_ASSERT(cuda_ctx_dst->device == buf_ctx_dst->device); + + ggml_cuda_set_device(cuda_ctx_src->device); + + cudaEvent_t event; + CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + + // record event on src stream + CUDA_CHECK(cudaEventRecord(event, g_cudaStreams[cuda_ctx_src->device][0])); + + // wait on dst stream + ggml_cuda_set_device(cuda_ctx_dst->device); + CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[cuda_ctx_dst->device][0], event, 0)); + + CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), g_cudaStreams[cuda_ctx_dst->device][0])); + + CUDA_CHECK(cudaEventDestroy(event)); + } else { + // copy + CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, g_cudaStreams[cuda_ctx_dst->device][0])); + } + return true; } GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { @@ -11027,6 +11089,58 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons UNUSED(backend); } +static ggml_backend_event_t ggml_backend_cuda_event_new(ggml_backend_t backend) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + + ggml_cuda_set_device(cuda_ctx->device); + + cudaEvent_t event; + CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + + return new ggml_backend_event { + /* .backend = */ backend, + /* .context = */ event, + }; +} + +static void ggml_backend_cuda_event_free(ggml_backend_event_t event) { + CUDA_CHECK(cudaEventDestroy((cudaEvent_t)event->context)); + + delete event; +} + +static void ggml_backend_cuda_event_record(ggml_backend_event_t event) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)event->backend->context; + + ggml_cuda_set_device(cuda_ctx->device); + + CUDA_CHECK(cudaEventRecord((cudaEvent_t)event->context, g_cudaStreams[cuda_ctx->device][0])); +} + +static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + + if (ggml_backend_is_cuda(event->backend)) { + + ggml_cuda_set_device(cuda_ctx->device); + + CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[cuda_ctx->device][0], (cudaEvent_t)event->context, 0)); + } else { + auto wait_fn = [](void * user_data) { + ggml_backend_event_t event = (ggml_backend_event_t)user_data; + ggml_backend_event_synchronize(event); + }; + + CUDA_CHECK(cudaLaunchHostFunc(g_cudaStreams[cuda_ctx->device][0], wait_fn, event)); + } +} + +static void ggml_backend_cuda_event_synchronize(ggml_backend_event_t event) { + assert(backend == event->backend); + + CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context)); +} + static ggml_backend_i ggml_backend_cuda_interface = { /* .get_name = */ ggml_backend_cuda_name, /* .free = */ ggml_backend_cuda_free, @@ -11040,6 +11154,11 @@ static ggml_backend_i ggml_backend_cuda_interface = { /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .supports_op = */ ggml_backend_cuda_supports_op, + /* .event_new = */ ggml_backend_cuda_event_new, + /* .event_free = */ ggml_backend_cuda_event_free, + /* .event_record = */ ggml_backend_cuda_event_record, + /* .event_wait = */ ggml_backend_cuda_event_wait, + /* .event_synchronize = */ ggml_backend_cuda_event_synchronize, }; GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) { diff --git a/ggml.c b/ggml.c index ca98fde8ab239..d09450b83aaf3 100644 --- a/ggml.c +++ b/ggml.c @@ -2553,6 +2553,7 @@ static struct ggml_tensor * ggml_new_tensor_impl( /*.nb =*/ { 0, 0, 0, 0 }, /*.op =*/ GGML_OP_NONE, /*.op_params =*/ { 0 }, + /*.flags =*/ 0, /*.is_param =*/ false, /*.grad =*/ NULL, /*.src =*/ { NULL }, @@ -10902,8 +10903,6 @@ static void ggml_compute_forward_get_rows_q( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } @@ -10911,7 +10910,7 @@ static void ggml_compute_forward_get_rows_q( GGML_TENSOR_BINARY_OP_LOCALS const int64_t nc = ne00; - const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr); + const int64_t nr = ggml_nelements(src1); const enum ggml_type type = src0->type; ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; @@ -10921,17 +10920,25 @@ static void ggml_compute_forward_get_rows_q( assert(nb00 == ggml_type_size(type)); assert(ggml_nrows(dst) == nr); - // TODO: multi-thread - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = 0; i11 < ne11; ++i11) { - for (int64_t i10 = 0; i10 < ne10; ++i10) { - const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + const int ith = params->ith; + const int nth = params->nth; - dequantize_row_q( - (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); - } - } + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i/(ne11*ne10); + const int64_t i11 = (i - i12*ne11*ne10)/ne10; + const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + dequantize_row_q( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); } } @@ -10940,8 +10947,6 @@ static void ggml_compute_forward_get_rows_f16( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } @@ -10949,24 +10954,32 @@ static void ggml_compute_forward_get_rows_f16( GGML_TENSOR_BINARY_OP_LOCALS const int64_t nc = ne00; - const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr); + const int64_t nr = ggml_nelements(src1); assert(ne0 == nc); assert(ne02 == ne11); assert(nb00 == sizeof(ggml_fp16_t)); assert(ggml_nrows(dst) == nr); - // TODO: multi-thread - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = 0; i11 < ne11; ++i11) { - for (int64_t i10 = 0; i10 < ne10; ++i10) { - const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int dr = (nr + nth - 1)/nth; - ggml_fp16_to_fp32_row( - (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); - } - } + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i/(ne11*ne10); + const int64_t i11 = (i - i12*ne11*ne10)/ne10; + const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + ggml_fp16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); } } @@ -10975,8 +10988,6 @@ static void ggml_compute_forward_get_rows_f32( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - assert(params->ith == 0); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } @@ -10984,24 +10995,32 @@ static void ggml_compute_forward_get_rows_f32( GGML_TENSOR_BINARY_OP_LOCALS const int64_t nc = ne00; - const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr); + const int64_t nr = ggml_nelements(src1); assert(ne0 == nc); assert(ne02 == ne11); assert(nb00 == sizeof(float)); assert(ggml_nrows(dst) == nr); - // TODO: multi-thread - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = 0; i11 < ne11; ++i11) { - for (int64_t i10 = 0; i10 < ne10; ++i10) { - const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + const int ith = params->ith; + const int nth = params->nth; - ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), - (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); - } - } + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i/(ne11*ne10); + const int64_t i11 = (i - i12*ne11*ne10)/ne10; + const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), + (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); } } @@ -16565,6 +16584,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { { n_tasks = n_threads; } break; + case GGML_OP_GET_ROWS: + { + n_tasks = MIN(n_threads, ggml_nelements(node->src[1])); + } break; case GGML_OP_SCALE: case GGML_OP_SET: case GGML_OP_CONT: @@ -16572,7 +16595,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: - case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS_BACK: case GGML_OP_DIAG: { diff --git a/ggml.h b/ggml.h index 1c49762716774..fab7e9c79d1bb 100644 --- a/ggml.h +++ b/ggml.h @@ -503,11 +503,17 @@ extern "C" { enum ggml_log_level { GGML_LOG_LEVEL_ERROR = 2, - GGML_LOG_LEVEL_WARN = 3, - GGML_LOG_LEVEL_INFO = 4, + GGML_LOG_LEVEL_WARN = 3, + GGML_LOG_LEVEL_INFO = 4, GGML_LOG_LEVEL_DEBUG = 5 }; + enum ggml_tensor_flags { + GGML_TENSOR_INPUT = 1, + GGML_TENSOR_OUTPUT = 2, + GGML_TENSOR_PARAM = 4, + }; + // ggml object struct ggml_object { size_t offs; @@ -541,7 +547,9 @@ extern "C" { // op params - allocated as int32_t for alignment int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; - bool is_param; + int32_t flags; + + bool is_param; // TODO: move to flags struct ggml_tensor * grad; struct ggml_tensor * src[GGML_MAX_SRC]; @@ -560,7 +568,7 @@ extern "C" { void * extra; // extra things e.g. for ggml-cuda.cu - char padding[8]; + char padding[12]; }; static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); diff --git a/llama.cpp b/llama.cpp index 114046db92675..0f0c1d5dda4c1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1420,6 +1420,7 @@ struct llama_hparams { struct llama_cparams { uint32_t n_ctx; // context size used during inference uint32_t n_batch; + uint32_t n_ubatch; uint32_t n_threads; // number of threads to use for generation uint32_t n_threads_batch; // number of threads to use for batch processing @@ -1664,12 +1665,23 @@ struct llama_model { struct llama_context { llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {} ~llama_context() { + for (auto & it : bufs_compute) { + // restore the original buffer in the tallocr + ggml_tallocr_t allocr = ggml_backend_sched_get_tallocr(sched, it.first); + ggml_tallocr_set_buffer(allocr, it.second[0]); + // free the rest of the buffers + for (size_t i = 1; i < it.second.size(); ++i) { + ggml_backend_buffer_free(it.second[i]); + } + } + ggml_backend_sched_free(sched); for (ggml_backend_t backend : backends) { ggml_backend_free(backend); } + ggml_backend_buffer_free(buf_logits); ggml_backend_buffer_free(buf_input); ggml_free(ctx_input); } @@ -1702,7 +1714,12 @@ struct llama_context { int32_t n_eval = 0; // number of eval calls // decode output (2-dimensional array: [n_tokens][n_vocab]) - std::vector logits; + //std::vector logits; + + ggml_backend_buffer_t buf_logits = nullptr; + size_t logits_size = 0; + float * logits = nullptr; + #ifndef NDEBUG // guard against access to unset logits std::vector logits_valid; @@ -1716,7 +1733,11 @@ struct llama_context { std::vector buf_compute_meta; ggml_backend_sched_t sched = nullptr; // allocator for the input tensors - ggml_tallocr * alloc = nullptr; + ggml_tallocr_t alloc_cpu = nullptr; + + std::map> bufs_compute; + size_t n_compute_bufs = 0; + size_t i_compute_buf = 0; // input tensors ggml_backend_buffer_t buf_input = nullptr; @@ -3292,7 +3313,8 @@ static bool llm_load_tensors( const int64_t i_gpu_start = std::max((int64_t) hparams.n_layer - n_gpu_layers, (int64_t) 0); // there is very little benefit to offloading the input layer, so always keep it on the CPU - model.buft_input = llama_default_buffer_type_cpu(true); + //model.buft_input = llama_default_buffer_type_cpu(true); + model.buft_input = llama_default_buffer_type_offload(main_gpu); model.buft_layer.resize(n_layer); @@ -6374,7 +6396,7 @@ static struct ggml_cgraph * llama_build_graph( const auto & model = lctx.model; // check if we should build the worst-case graph (for memory measurement) - const bool worst_case = ggml_tallocr_is_measure(lctx.alloc); + const bool worst_case = ggml_tallocr_is_measure(lctx.alloc_cpu); // this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.) llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) { @@ -6400,7 +6422,7 @@ static struct ggml_cgraph * llama_build_graph( // set input data // - if (!ggml_tallocr_is_measure(lctx.alloc)) { + if (!ggml_tallocr_is_measure(lctx.alloc_cpu)) { if (batch.token) { const int64_t n_tokens = batch.n_tokens; @@ -6540,10 +6562,11 @@ static struct ggml_cgraph * llama_build_graph( // static int llama_decode_internal( llama_context & lctx, - llama_batch batch) { - const uint32_t n_tokens = batch.n_tokens; + llama_batch all_batch) { - if (n_tokens == 0) { + const uint32_t n_tokens_all = all_batch.n_tokens; + + if (n_tokens_all == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__); return -1; } @@ -6552,12 +6575,9 @@ static int llama_decode_internal( const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; - const auto n_batch = cparams.n_batch; - - GGML_ASSERT(n_tokens <= n_batch); + GGML_ASSERT((!all_batch.token && all_batch.embd) || (all_batch.token && !all_batch.embd)); // NOLINT - int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; - GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + GGML_ASSERT(n_tokens_all <= cparams.n_ctx); const int64_t t_start_us = ggml_time_us(); @@ -6567,206 +6587,239 @@ static int llama_decode_internal( //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); #endif - GGML_ASSERT(n_threads > 0); - auto & kv_self = lctx.kv_self; const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; - // helpers for smoother batch API transition - // after deprecating the llama_eval calls, these will be removed - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_arr; - std::vector> seq_id; + auto * logits_out = lctx.logits; - if (batch.pos == nullptr) { - pos.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - pos[i] = batch.all_pos_0 + i*batch.all_pos_1; - } +#ifndef NDEBUG + auto & logits_valid = lctx.logits_valid; + logits_valid.clear(); + logits_valid.resize(n_tokens_all); - batch.pos = pos.data(); - } + memset(logits_out, 0, lctx.logits_size*sizeof(float)); +#endif - if (batch.seq_id == nullptr) { - n_seq_id.resize(n_tokens); - seq_id.resize(n_tokens); - seq_id_arr.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - n_seq_id[i] = 1; - seq_id[i].resize(1); - seq_id[i][0] = batch.all_seq_id; - seq_id_arr[i] = seq_id[i].data(); - } - batch.n_seq_id = n_seq_id.data(); - batch.seq_id = seq_id_arr.data(); - } + const uint32_t n_ubatch = cparams.n_ubatch; - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self.head > kv_self.used + 2*n_tokens) { - kv_self.head = 0; - } + //printf("n_tokens_all = %u, n_ubatch = %u\n", n_tokens_all, n_ubatch); - if (!llama_kv_cache_find_slot(kv_self, batch)) { - return 1; - } + for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) { + const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); - //kv_self.n = llama_kv_cache_cell_max(kv_self); + llama_batch batch = { + /* .n_tokens = */ (int32_t) n_tokens, + /* .token = */ all_batch.token ? all_batch.token + cur_token : nullptr, + /* .embd = */ all_batch.embd ? all_batch.embd + cur_token*n_embd : nullptr, + /* .pos = */ all_batch.pos ? all_batch.pos + cur_token : nullptr, + /* .n_seq_id = */ all_batch.n_seq_id ? all_batch.n_seq_id + cur_token : nullptr, + /* .seq_id = */ all_batch.seq_id ? all_batch.seq_id + cur_token : nullptr, + /* .logits = */ all_batch.logits ? all_batch.logits + cur_token : nullptr, + /* .all_pos_0 = */ all_batch.all_pos_0 + (llama_pos) cur_token*all_batch.all_pos_1, + /* .all_pos_1 = */ all_batch.all_pos_1, + /* .all_seq_id = */ all_batch.all_seq_id, + }; - //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); + int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; + GGML_ASSERT(n_threads > 0); - ggml_backend_sched_reset(lctx.sched); - ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); + // helpers for smoother batch API transition + // after deprecating the llama_eval calls, these will be removed + std::vector pos; - ggml_cgraph * gf = llama_build_graph(lctx, batch); + std::vector n_seq_id; + std::vector seq_id_arr; + std::vector> seq_id; - // the output is always the last tensor in the graph - struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - GGML_ASSERT(strcmp(res->name, "result_output") == 0); + if (batch.pos == nullptr) { + pos.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + pos[i] = batch.all_pos_0 + i*batch.all_pos_1; + } - // the embeddings could be the second to last tensor, or the third to last tensor - struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; - if (strcmp(embeddings->name, "result_norm") != 0) { - embeddings = gf->nodes[gf->n_nodes - 3]; - GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); - } + batch.pos = pos.data(); + } - // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + if (batch.seq_id == nullptr) { + n_seq_id.resize(n_tokens); + seq_id.resize(n_tokens); + seq_id_arr.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + n_seq_id[i] = 1; + seq_id[i].resize(1); + seq_id[i][0] = batch.all_seq_id; + seq_id_arr[i] = seq_id[i].data(); + } - // for big prompts, if BLAS is enabled, it is better to use only one thread - // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance - // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well - // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering - // with the BLAS calls. need a better solution - if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) { - n_threads = std::min(4, n_threads); - } + batch.n_seq_id = n_seq_id.data(); + batch.seq_id = seq_id_arr.data(); + } - const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 1; - if (ggml_cpu_has_cublas() && fully_offloaded) { - n_threads = 1; - } + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (kv_self.head > kv_self.used + 2*n_tokens) { + kv_self.head = 0; + } -#ifdef GGML_USE_MPI - const int64_t n_layer = hparams.n_layer; - ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); -#endif + if (!llama_kv_cache_find_slot(kv_self, batch)) { + LLAMA_LOG_ERROR("%s: failed to find a slot in the cache", __func__); + return 1; + } -#ifdef GGML_USE_METAL - if (ggml_backend_is_metal(lctx.backend_metal)) { - ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads); - } -#endif + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + //kv_self.n = llama_kv_cache_cell_max(kv_self); - if (lctx.backend_cpu != nullptr) { - ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads); - } - ggml_backend_sched_graph_compute(lctx.sched, gf); + //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); - // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched)); + // change the CPU compute buffer to avoid overwriting inputs + size_t i_compute_buf = lctx.i_compute_buf; + lctx.i_compute_buf = (lctx.i_compute_buf + 1) % lctx.n_compute_bufs; + if (i_compute_buf == 0 && cur_token > 0) { + // sync all backends to ensure that the current buffer is not in use + printf("not enough buffers, syncing now\n"); + ggml_backend_sched_synchronize(lctx.sched); + } + for (auto it : lctx.bufs_compute) { + ggml_tallocr_t alloc = ggml_backend_sched_get_tallocr(lctx.sched, it.first); + ggml_tallocr_set_buffer(alloc, it.second.at(i_compute_buf)); + } -#ifdef GGML_USE_MPI - ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); -#endif + ggml_backend_sched_reset(lctx.sched); - // update the kv ring buffer - { - if (kv_self.has_shift) { - kv_self.has_shift = false; - for (uint32_t i = 0; i < kv_self.size; ++i) { - kv_self.cells[i].delta = 0; - } + ggml_cgraph * gf = llama_build_graph(lctx, batch); + + // the output is always the last tensor in the graph + struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + GGML_ASSERT(strcmp(res->name, "result_output") == 0); + + // the embeddings could be the second to last tensor, or the third to last tensor + struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; + if (strcmp(embeddings->name, "result_norm") != 0) { + embeddings = gf->nodes[gf->n_nodes - 3]; + GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); } - kv_self.head += n_tokens; + // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); - // Ensure kv cache head points to a valid index. - if (kv_self.head >= kv_self.size) { - kv_self.head = 0; + // for big prompts, if BLAS is enabled, it is better to use only one thread + // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance + // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well + // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering + // with the BLAS calls. need a better solution + if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) { + n_threads = std::min(4, n_threads); } - } -#ifdef GGML_PERF - // print timing information per ggml operation (for debugging purposes) - // requires GGML_PERF to be defined - ggml_graph_print(gf); -#endif + #ifdef GGML_USE_MPI + const int64_t n_layer = hparams.n_layer; + ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); + #endif - // plot the computation graph in dot format (for debugging purposes) - //if (n_past%100 == 0) { - // ggml_graph_dump_dot(gf, NULL, "llama.dot"); - //} + #ifdef GGML_USE_METAL + if (lctx.backend_metal != nullptr) { + ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads); + } + #endif - // extract logits - // TODO: do not compute and extract logits if only embeddings are needed - // need to update the graphs to skip "result_output" - { - auto & logits_out = lctx.logits; + if (lctx.backend_cpu != nullptr) { + ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads); + } -#ifndef NDEBUG - auto & logits_valid = lctx.logits_valid; - logits_valid.clear(); - logits_valid.resize(n_tokens); + ggml_backend_sched_graph_compute(lctx.sched, gf); - logits_out.clear(); -#endif + #ifdef GGML_USE_MPI + ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); + #endif - ggml_backend_t res_backend = ggml_backend_sched_get_node_backend(lctx.sched, res); - GGML_ASSERT(res_backend != nullptr); - if (batch.logits) { - logits_out.resize(n_vocab * n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - if (batch.logits[i] == 0) { - continue; + // update the kv ring buffer + { + if (kv_self.has_shift) { + kv_self.has_shift = false; + for (uint32_t i = 0; i < kv_self.size; ++i) { + kv_self.cells[i].delta = 0; } - ggml_backend_tensor_get_async(res_backend, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float)); -#ifndef NDEBUG - logits_valid[i] = true; -#endif } - } else if (lctx.logits_all) { - logits_out.resize(n_vocab * n_tokens); - ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float)); -#ifndef NDEBUG - std::fill(logits_valid.begin(), logits_valid.end(), true); -#endif - } else { - logits_out.resize(n_vocab); - ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float)); -#ifndef NDEBUG - logits_valid[0] = true; -#endif + + kv_self.head += n_tokens; + + // Ensure kv cache head points to a valid index. + if (kv_self.head >= kv_self.size) { + kv_self.head = 0; + } } - ggml_backend_synchronize(res_backend); - } - // extract embeddings - if (!lctx.embedding.empty()) { - auto & embedding_out = lctx.embedding; + #ifdef GGML_PERF + // print timing information per ggml operation (for debugging purposes) + // requires GGML_PERF to be defined + ggml_graph_print(gf); + #endif + + // plot the computation graph in dot format (for debugging purposes) + //if (n_past%100 == 0) { + // ggml_graph_dump_dot(gf, NULL, "llama.dot"); + //} - embedding_out.resize(n_embd); - ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings); - ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float)); - ggml_backend_synchronize(embeddings_backend); + // extract logits + // TODO: do not compute and extract logits if only embeddings are needed + // need to update the graphs to skip "result_output" + { + ggml_backend_t res_backend = ggml_backend_sched_get_node_backend(lctx.sched, res); + GGML_ASSERT(res_backend != nullptr); + if (batch.logits) { + for (uint32_t i = 0; i < n_tokens; i++) { + if (batch.logits[i] == 0) { + continue; + } + ggml_backend_tensor_get_async(res_backend, res, logits_out + n_vocab*(cur_token + i), n_vocab*i*sizeof(float), n_vocab*sizeof(float)); + #ifndef NDEBUG + logits_valid[cur_token + i] = true; + #endif + } + } else if (lctx.logits_all) { + ggml_backend_tensor_get_async(res_backend, res, logits_out + n_vocab*cur_token, 0, n_vocab*n_tokens*sizeof(float)); + #ifndef NDEBUG + std::fill(logits_valid.begin() + cur_token, logits_valid.begin() + cur_token + n_tokens, true); + #endif + } else { + if (cur_token + n_tokens >= n_tokens_all) { + ggml_backend_tensor_get_async(res_backend, res, logits_out, n_vocab*(n_tokens - 1)*sizeof(float), n_vocab*sizeof(float)); + #ifndef NDEBUG + logits_valid[0] = true; + #endif + } + } + } + + // FIXME + // extract embeddings + if (!lctx.embedding.empty()) { + GGML_ASSERT(!"not implemented"); + auto & embedding_out = lctx.embedding; + + embedding_out.resize(n_embd); + ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings); + ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float)); + } } + ggml_backend_sched_synchronize(lctx.sched); + lctx.i_compute_buf = 0; + // measure the performance only for the single-token evals - if (n_tokens == 1) { + if (n_tokens_all == 1) { lctx.t_eval_us += ggml_time_us() - t_start_us; lctx.n_eval++; } - else if (n_tokens > 1) { + else if (n_tokens_all > 1) { lctx.t_p_eval_us += ggml_time_us() - t_start_us; - lctx.n_p_eval += n_tokens; + lctx.n_p_eval += n_tokens_all; } // get a more accurate load time, upon first eval @@ -9655,7 +9708,8 @@ struct llama_context_params llama_context_default_params() { struct llama_context_params result = { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_ctx =*/ 512, - /*.n_batch =*/ 512, + /*.n_batch =*/ 4096, + /*.n_ubatch =*/ 256, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED, @@ -9792,6 +9846,7 @@ struct llama_context * llama_new_context_with_model( auto & cparams = ctx->cparams; cparams.n_batch = params.n_batch; + cparams.n_ubatch = params.n_ubatch == 0 ? params.n_batch : params.n_ubatch; cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -9830,6 +9885,8 @@ struct llama_context * llama_new_context_with_model( } LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); + LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); + LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -9913,7 +9970,8 @@ struct llama_context * llama_new_context_with_model( } // resized during inference, reserve maximum - ctx->logits.reserve(hparams.n_vocab*cparams.n_batch); + //ctx->logits.reserve(hparams.n_vocab*cparams.n_batch); + ctx->logits_size = hparams.n_vocab*cparams.n_ctx; if (params.embedding){ ctx->embedding.resize(hparams.n_embd); @@ -9964,17 +10022,17 @@ struct llama_context * llama_new_context_with_model( ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead()); ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES); - ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu); + ctx->alloc_cpu = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu); // build worst-case graph - int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch); + int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch); int n_past = cparams.n_ctx - n_tokens; llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0)); // initialize scheduler with the worst-case graph ggml_backend_sched_init_measure(ctx->sched, gf); - ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu); + ctx->alloc_cpu = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu); for (ggml_backend_t backend : ctx->backends) { ggml_backend_buffer_t buf = ggml_backend_sched_get_buffer(ctx->sched, backend); @@ -9983,6 +10041,41 @@ struct llama_context * llama_new_context_with_model( ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0); } + // duplicate cpu buffers for microbatching + const int n_ub = (cparams.n_batch + cparams.n_ubatch - 1) / cparams.n_ubatch; + ctx->n_compute_bufs = n_ub; + LLAMA_LOG_INFO("%s: allocating %d compute buffers\n", __func__, n_ub); + + for (ggml_backend_t b : ctx->backends) { + ggml_tallocr_t alloc = ggml_backend_sched_get_tallocr(ctx->sched, b); + ggml_backend_buffer_t buf = ggml_tallocr_get_buffer(alloc); + size_t buf_size = ggml_backend_buffer_get_size(buf); + ctx->bufs_compute[b].push_back(buf); + auto * buft = ggml_backend_buffer_get_type(buf); + for (int i = 1; i < n_ub; ++i) { + ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, buf_size); + if (buf == nullptr) { + LLAMA_LOG_ERROR("%s: failed to allocate compute buffer\n", __func__); + llama_free(ctx); + return nullptr; + } + ctx->bufs_compute[b].push_back(buf); + } + } + + // allocate buffer for logits output + ctx->buf_logits = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), hparams.n_vocab*cparams.n_ctx*sizeof(float)); + if (ctx->buf_logits == nullptr) { + LLAMA_LOG_ERROR("%s: failed to allocate logits buffer\n", __func__); + llama_free(ctx); + return nullptr; + } + ctx->logits = (float *) ggml_backend_buffer_get_base(ctx->buf_logits); + ggml_backend_buffer_clear(ctx->buf_logits, 0); + LLAMA_LOG_INFO("%s: logits buffer size = %8.2f MiB, type = %s\n", __func__, + ggml_backend_buffer_get_size(ctx->buf_logits) / 1024.0 / 1024.0, + ggml_backend_buffer_name(ctx->buf_logits)); + // note: the number of splits during measure is higher than during inference due to the kv shift int n_splits = ggml_backend_sched_get_n_splits(ctx->sched); LLAMA_LOG_INFO("%s: graph splits (measure): %d\n", __func__, n_splits); @@ -10292,7 +10385,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) { const size_t s_rng = LLAMA_MAX_RNG_STATE; const size_t s_logits_size = sizeof(size_t); // assume worst case for logits although only currently set ones are serialized - const size_t s_logits = ctx->logits.capacity() * sizeof(float); + const size_t s_logits = ctx->logits_size * sizeof(float); const size_t s_embedding_size = sizeof(size_t); const size_t s_embedding = ctx->embedding.size() * sizeof(float); const size_t s_kv_size = sizeof(size_t); @@ -10384,12 +10477,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat // copy logits { - const size_t logits_size = ctx->logits.size(); + const size_t logits_size = ctx->logits_size; data_ctx->write(&logits_size, sizeof(logits_size)); if (logits_size) { - data_ctx->write(ctx->logits.data(), logits_size * sizeof(float)); + data_ctx->write(ctx->logits, logits_size * sizeof(float)); } } @@ -10491,12 +10584,12 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size); - GGML_ASSERT(ctx->logits.capacity() >= logits_size); + GGML_ASSERT(ctx->logits_size >= logits_size); if (logits_size) { - ctx->logits.resize(logits_size); + //ctx->logits.resize(logits_size); - memcpy(ctx->logits.data(), inp, logits_size * sizeof(float)); + memcpy(ctx->logits, inp, logits_size * sizeof(float)); inp += logits_size * sizeof(float); } } @@ -10771,15 +10864,23 @@ int32_t llama_decode( } float * llama_get_logits(struct llama_context * ctx) { - return ctx->logits.data(); + ggml_backend_sched_synchronize(ctx->sched); + ctx->i_compute_buf = 0; + return ctx->logits; } float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { + ggml_backend_sched_synchronize(ctx->sched); + ctx->i_compute_buf = 0; + assert(ctx->logits_valid.at(i)); - return ctx->logits.data() + i*ctx->model.hparams.n_vocab; + return ctx->logits + i*ctx->model.hparams.n_vocab; } float * llama_get_embeddings(struct llama_context * ctx) { + ggml_backend_sched_synchronize(ctx->sched); + ctx->i_compute_buf = 0; + return ctx->embedding.data(); } diff --git a/llama.h b/llama.h index bb60545575932..21bdf84172273 100644 --- a/llama.h +++ b/llama.h @@ -219,7 +219,8 @@ extern "C" { struct llama_context_params { uint32_t seed; // RNG seed, -1 for random uint32_t n_ctx; // text context, 0 = from model - uint32_t n_batch; // prompt processing maximum batch size + uint32_t n_batch; // prompt processing maximum batch size (ignored if n_ubatch is set) + uint32_t n_ubatch; // prompt processing maximum batch size uint32_t n_threads; // number of threads to use for generation uint32_t n_threads_batch; // number of threads to use for batch processing int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`