From cec409aa98e9ae1cbdb0ba0003951ce5a2e100b2 Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Fri, 19 Apr 2024 05:09:03 -0700 Subject: [PATCH 01/17] DRAFT: Introduction of CUDA Graphs to LLama.cpp --- ggml-cuda.cu | 112 ++++++++++++++++++++++++++++++++++++++++++++++ ggml-cuda/cpy.cu | 29 ++++++++++++ ggml-cuda/cpy.cuh | 2 + 3 files changed, 143 insertions(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d277104d12177..07a3bc32e76d4 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2405,11 +2405,63 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { GGML_UNUSED(backend); } +struct ggml_cudaGraph { + int count=0; + cudaGraph_t graph = nullptr; + cudaGraphExec_t instance = nullptr; + size_t numNodes = 0; + int softmax_ne0 = 0; +}; + GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ggml_cuda_set_device(cuda_ctx->device); + // Objects required for CUDA Graph +#define MAX_NODES_IN_CUDA_GRAPH 10000 + static ggml_cudaGraph cudaGraph; //TO DO move this to a suitable persistant location (and avoid use of static memory) + bool useCudaGraph = (cudaGraph.count>=2); //avoid CUDA graphs on first 2 steps due to incompatible initialisations. + char** updatedKernelArg[MAX_NODES_IN_CUDA_GRAPH]; + bool cudaGraphUpdateRequired = false; + // pointer to CUDA cpy kernel, which is required to identify + // kernel parameters which need updated in the graph for each token + void* ggmlCudaCpyFn = nullptr; + if(useCudaGraph) { + + if(cudaGraph.instance == nullptr) cudaGraphUpdateRequired=true; + + // Loop over nodes in GGML graph to obtain info needed for CUDA graph + int k=0; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + // Identify if the graph needs updated for this token due to the number of elements changing + // (identified by inspecting soft max op parameters) + if(node->op == GGML_OP_SOFT_MAX) { + if(node->src[0]->ne[0] != cudaGraph.softmax_ne0) { + cudaGraphUpdateRequired = true; + cudaGraph.softmax_ne0 = node->src[0]->ne[0]; + } + } + if(node->op == GGML_OP_CPY) { + // store the copy op parameter which changes with each token. + updatedKernelArg[k++]=(char**) &(node->src[1]->data); + if(ggmlCudaCpyFn == nullptr){ + // store a pointer to the copy op CUDA kernel to identify it later + ggmlCudaCpyFn = ggml_cuda_cpy_fn(node->src[0], node->src[1]); + } + } + } + } + + if(useCudaGraph && cudaGraphUpdateRequired) { // Start CUDA graph capture + CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeGlobal)); + } + + // Only perfom the graph exection if CUDA graphs are not enebled, or we are capturing the graph. + // With use of CUDA graphs, the execution will be performed by the graph launch. + if(!useCudaGraph || cudaGraphUpdateRequired) { + //temporarily avoid indenting here to make code review easier for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2432,7 +2484,67 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } GGML_ASSERT(ok); } + } + + if(useCudaGraph && (cudaGraphUpdateRequired)) { // End CUDA graph capture + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cudaGraph.graph)); + } + if(useCudaGraph){ + + if(cudaGraph.instance == nullptr) { // Create executable graph from captured graph. 
+ CUDA_CHECK(cudaGraphInstantiate(&cudaGraph.instance, cudaGraph.graph, NULL, NULL, 0)); + } + + + // Perform update to graph (if required for this token), and change copy parameter (required for every token) + + cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH]; + CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH]; + cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH]; + if(cudaGraphUpdateRequired) { + // Extract nodes from graph + if(cudaGraph.numNodes == 0) { + CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nullptr, &cudaGraph.numNodes)); + } + CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nodes, &cudaGraph.numNodes)); + + // Loop over nodes, and extract kernel parameters fro each node + for(size_t i=0; istream())); + } + cudaGraph.count++; return GGML_STATUS_SUCCESS; } diff --git a/ggml-cuda/cpy.cu b/ggml-cuda/cpy.cu index 16d9c8fffb4b3..12d741f017d3b 100644 --- a/ggml-cuda/cpy.cu +++ b/ggml-cuda/cpy.cu @@ -459,3 +459,32 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; ggml_cuda_cpy(ctx, src0, dst); } + +void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_f32_f16; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { + return (void*) cpy_f32_f16; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { + return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { + return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { + return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { + return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { + return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { + return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { + return (void*) cpy_f32_f16; + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_f32_f16; + } else { + fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, + ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ASSERT(false); + } +} + diff --git a/ggml-cuda/cpy.cuh b/ggml-cuda/cpy.cuh index f0b2c453bfe6a..7961674266ee1 100644 --- a/ggml-cuda/cpy.cuh +++ b/ggml-cuda/cpy.cuh @@ -5,3 +5,5 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1); void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1); From c8dd0e7c1c6864c889df33dc40ddc33589adcfb5 Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Mon, 22 Apr 2024 01:32:06 -0700 Subject: [PATCH 02/17] FIx issues raised in comments --- ggml-cuda.cu | 62 ++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 43 insertions(+), 19 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 07a3bc32e76d4..670ba78a028f7 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2405,23 +2405,33 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { GGML_UNUSED(backend); } +#if (CUDART_VERSION >= 12000) +#define USE_CUDA_GRAPH +#endif + +#ifdef USE_CUDA_GRAPH +#define MAX_NODES_IN_CUDA_GRAPH 10000 struct ggml_cudaGraph { int count=0; cudaGraph_t graph = nullptr; 
cudaGraphExec_t instance = nullptr; size_t numNodes = 0; int softmax_ne0 = 0; + cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH]; + CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH]; + cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH]; }; +#endif GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ggml_cuda_set_device(cuda_ctx->device); +#ifdef USE_CUDA_GRAPH // Objects required for CUDA Graph -#define MAX_NODES_IN_CUDA_GRAPH 10000 - static ggml_cudaGraph cudaGraph; //TO DO move this to a suitable persistant location (and avoid use of static memory) - bool useCudaGraph = (cudaGraph.count>=2); //avoid CUDA graphs on first 2 steps due to incompatible initialisations. + static ggml_cudaGraph cudaGraph; + bool useCudaGraph = (cudaGraph.count>=7); //avoid CUDA graphs on first few steps due to incompatible initialisations. char** updatedKernelArg[MAX_NODES_IN_CUDA_GRAPH]; bool cudaGraphUpdateRequired = false; // pointer to CUDA cpy kernel, which is required to identify @@ -2458,6 +2468,11 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeGlobal)); } +#else + bool useCudaGraph = false; + bool cudaGraphUpdateRequired = false; +#endif + // Only perfom the graph exection if CUDA graphs are not enebled, or we are capturing the graph. // With use of CUDA graphs, the execution will be performed by the graph launch. if(!useCudaGraph || cudaGraphUpdateRequired) { @@ -2486,6 +2501,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } } + #ifdef USE_CUDA_GRAPH if(useCudaGraph && (cudaGraphUpdateRequired)) { // End CUDA graph capture CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cudaGraph.graph)); } @@ -2498,26 +2514,26 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // Perform update to graph (if required for this token), and change copy parameter (required for every token) - cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH]; - CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH]; - cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH]; - if(cudaGraphUpdateRequired) { // Extract nodes from graph if(cudaGraph.numNodes == 0) { CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nullptr, &cudaGraph.numNodes)); } - CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nodes, &cudaGraph.numNodes)); + CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, cudaGraph.nodes, &cudaGraph.numNodes)); // Loop over nodes, and extract kernel parameters fro each node for(size_t i=0; istream())); } cudaGraph.count++; +#endif return GGML_STATUS_SUCCESS; } From 800f4fe48eba1f36c0aba66a5a28521ab38f971d Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Mon, 22 Apr 2024 04:50:39 -0700 Subject: [PATCH 03/17] Tidied to now only use CUDA runtime (not mixed with driver calls) --- ggml-cuda.cu | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 670ba78a028f7..7da061240a540 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2418,8 +2418,7 @@ struct ggml_cudaGraph { size_t numNodes = 0; int softmax_ne0 = 0; cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH]; - CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH]; - cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH]; + cudaKernelNodeParams 
params[MAX_NODES_IN_CUDA_GRAPH]; }; #endif @@ -2523,12 +2522,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // Loop over nodes, and extract kernel parameters fro each node for(size_t i=0; i Date: Mon, 22 Apr 2024 09:01:44 -0700 Subject: [PATCH 04/17] disable for multi-gpu and batch size > 1 --- ggml-cuda.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7da061240a540..6f1973f2c663a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2436,6 +2436,11 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // pointer to CUDA cpy kernel, which is required to identify // kernel parameters which need updated in the graph for each token void* ggmlCudaCpyFn = nullptr; + + if(ggml_backend_cuda_get_device_count() > 1){ + useCudaGraph = false; // disable CUDA graphs for multi-gpu for now. TO DO investigate + } + if(useCudaGraph) { if(cudaGraph.instance == nullptr) cudaGraphUpdateRequired=true; @@ -2447,6 +2452,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // Identify if the graph needs updated for this token due to the number of elements changing // (identified by inspecting soft max op parameters) if(node->op == GGML_OP_SOFT_MAX) { + if(node->src[1]->ne[1] > 1){ + useCudaGraph = false; // disable CUDA graphs for batch size > 1 for now. TO DO investigate + } if(node->src[0]->ne[0] != cudaGraph.softmax_ne0) { cudaGraphUpdateRequired = true; cudaGraph.softmax_ne0 = node->src[0]->ne[0]; From df4719ec7e3fa19617e671d75dcd8319b6777397 Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Tue, 23 Apr 2024 06:27:08 -0700 Subject: [PATCH 05/17] Disable CUDA graphs for old GPU arch and with env var --- ggml-cuda.cu | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6f1973f2c663a..5b696719b82cc 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2419,9 +2419,12 @@ struct ggml_cudaGraph { int softmax_ne0 = 0; cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH]; cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH]; + bool disableDueToGpuArch=false; }; #endif +const bool disableCudaGraphs = (getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr); + GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; @@ -2437,8 +2440,21 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // kernel parameters which need updated in the graph for each token void* ggmlCudaCpyFn = nullptr; - if(ggml_backend_cuda_get_device_count() > 1){ - useCudaGraph = false; // disable CUDA graphs for multi-gpu for now. TO DO investigate + + if(cudaGraph.count==0){ + cudaDeviceProp prop; + int device; + cudaGetDevice(&device); + cudaGetDeviceProperties(&prop, device); + if (prop.major < 8){ + cudaGraph.disableDueToGpuArch=true; + } + } + + // Disable CUDA graphs in presence of env var or old GPU. + // Also disable for multi-gpu for now. 
TO DO investigate + if(disableCudaGraphs || cudaGraph.disableDueToGpuArch || ggml_backend_cuda_get_device_count() > 1){ + useCudaGraph = false; } if(useCudaGraph) { From c3d4ead1367627eb9332f03e9c01f19fae65528a Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Wed, 24 Apr 2024 02:37:57 -0700 Subject: [PATCH 06/17] added missing CUDA_CHECKs --- ggml-cuda.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5b696719b82cc..1fc21f5405073 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2440,12 +2440,11 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // kernel parameters which need updated in the graph for each token void* ggmlCudaCpyFn = nullptr; - if(cudaGraph.count==0){ cudaDeviceProp prop; int device; - cudaGetDevice(&device); - cudaGetDeviceProperties(&prop, device); + CUDA_CHECK(cudaGetDevice(&device)); + CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); if (prop.major < 8){ cudaGraph.disableDueToGpuArch=true; } From d403b180a69cc4114eef04f1c3b2a6b0db507619 Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Wed, 24 Apr 2024 05:43:26 -0700 Subject: [PATCH 07/17] Addressed comments --- ggml-cuda.cu | 129 ++++++++++++++++++++++++++------------------------- 1 file changed, 66 insertions(+), 63 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1fc21f5405073..344d7d61aa7fb 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2411,19 +2411,19 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { #ifdef USE_CUDA_GRAPH #define MAX_NODES_IN_CUDA_GRAPH 10000 -struct ggml_cudaGraph { - int count=0; +struct ggml_cuda_graph { + int count = 0; cudaGraph_t graph = nullptr; cudaGraphExec_t instance = nullptr; - size_t numNodes = 0; + size_t num_nodes = 0; int softmax_ne0 = 0; cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH]; cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH]; - bool disableDueToGpuArch=false; + bool disable_due_to_gpu_arch = false; }; #endif -const bool disableCudaGraphs = (getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr); +const bool disable_cuda_graphs = (getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr); GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; @@ -2432,33 +2432,29 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t #ifdef USE_CUDA_GRAPH // Objects required for CUDA Graph - static ggml_cudaGraph cudaGraph; - bool useCudaGraph = (cudaGraph.count>=7); //avoid CUDA graphs on first few steps due to incompatible initialisations. - char** updatedKernelArg[MAX_NODES_IN_CUDA_GRAPH]; - bool cudaGraphUpdateRequired = false; + static ggml_cuda_graph cuda_graph; + bool use_cuda_graph = (cuda_graph.count >= 7); //avoid CUDA graphs on first few steps due to incompatible initialisations. 
+ char ** updated_kernel_arg[MAX_NODES_IN_CUDA_GRAPH]; + bool cuda_graph_update_required = false; // pointer to CUDA cpy kernel, which is required to identify // kernel parameters which need updated in the graph for each token - void* ggmlCudaCpyFn = nullptr; + void * ggml_cuda_cpy_fn_ptr = nullptr; - if(cudaGraph.count==0){ - cudaDeviceProp prop; - int device; - CUDA_CHECK(cudaGetDevice(&device)); - CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); - if (prop.major < 8){ - cudaGraph.disableDueToGpuArch=true; + if(cuda_graph.count == 0){ + if (ggml_cuda_info().devices[cuda_ctx->device].cc < 800){ + cuda_graph.disable_due_to_gpu_arch=true; } } // Disable CUDA graphs in presence of env var or old GPU. // Also disable for multi-gpu for now. TO DO investigate - if(disableCudaGraphs || cudaGraph.disableDueToGpuArch || ggml_backend_cuda_get_device_count() > 1){ - useCudaGraph = false; + if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || ggml_backend_cuda_get_device_count() > 1){ + use_cuda_graph = false; } - if(useCudaGraph) { + if(use_cuda_graph) { - if(cudaGraph.instance == nullptr) cudaGraphUpdateRequired=true; + if(cuda_graph.instance == nullptr) cuda_graph_update_required=true; // Loop over nodes in GGML graph to obtain info needed for CUDA graph int k=0; @@ -2468,36 +2464,36 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // (identified by inspecting soft max op parameters) if(node->op == GGML_OP_SOFT_MAX) { if(node->src[1]->ne[1] > 1){ - useCudaGraph = false; // disable CUDA graphs for batch size > 1 for now. TO DO investigate + use_cuda_graph = false; // disable CUDA graphs for batch size > 1 for now. TO DO investigate } - if(node->src[0]->ne[0] != cudaGraph.softmax_ne0) { - cudaGraphUpdateRequired = true; - cudaGraph.softmax_ne0 = node->src[0]->ne[0]; + if(node->src[0]->ne[0] != cuda_graph.softmax_ne0) { + cuda_graph_update_required = true; + cuda_graph.softmax_ne0 = node->src[0]->ne[0]; } } if(node->op == GGML_OP_CPY) { // store the copy op parameter which changes with each token. - updatedKernelArg[k++]=(char**) &(node->src[1]->data); - if(ggmlCudaCpyFn == nullptr){ + updated_kernel_arg[k++]=(char **) &(node->src[1]->data); + if(ggml_cuda_cpy_fn_ptr == nullptr){ // store a pointer to the copy op CUDA kernel to identify it later - ggmlCudaCpyFn = ggml_cuda_cpy_fn(node->src[0], node->src[1]); + ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); } } } } - if(useCudaGraph && cudaGraphUpdateRequired) { // Start CUDA graph capture + if(use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeGlobal)); } #else - bool useCudaGraph = false; - bool cudaGraphUpdateRequired = false; + bool use_cuda_graph = false; + bool cuda_graph_update_required = false; #endif // Only perfom the graph exection if CUDA graphs are not enebled, or we are capturing the graph. // With use of CUDA graphs, the execution will be performed by the graph launch. 
- if(!useCudaGraph || cudaGraphUpdateRequired) { + if(!use_cuda_graph || cuda_graph_update_required) { //temporarily avoid indenting here to make code review easier for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2524,67 +2520,74 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } #ifdef USE_CUDA_GRAPH - if(useCudaGraph && (cudaGraphUpdateRequired)) { // End CUDA graph capture - CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cudaGraph.graph)); + if(use_cuda_graph && (cuda_graph_update_required)) { // End CUDA graph capture + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_graph.graph)); } - if(useCudaGraph){ + if(use_cuda_graph){ - if(cudaGraph.instance == nullptr) { // Create executable graph from captured graph. - CUDA_CHECK(cudaGraphInstantiate(&cudaGraph.instance, cudaGraph.graph, NULL, NULL, 0)); + if(cuda_graph.instance == nullptr) { // Create executable graph from captured graph. + CUDA_CHECK(cudaGraphInstantiate(&cuda_graph.instance, cuda_graph.graph, NULL, NULL, 0)); } // Perform update to graph (if required for this token), and change copy parameter (required for every token) - if(cudaGraphUpdateRequired) { + if(cuda_graph_update_required) { // Extract nodes from graph - if(cudaGraph.numNodes == 0) { - CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nullptr, &cudaGraph.numNodes)); + if(cuda_graph.num_nodes == 0) { + // First call with null argument gets number of nodes in graph + CUDA_CHECK(cudaGraphGetNodes(cuda_graph.graph, nullptr, &cuda_graph.num_nodes)); } - CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, cudaGraph.nodes, &cudaGraph.numNodes)); + // Subsequent call with non-null argument gets nodes + CUDA_CHECK(cudaGraphGetNodes(cuda_graph.graph, cuda_graph.nodes, &cuda_graph.num_nodes)); // Loop over nodes, and extract kernel parameters fro each node - for(size_t i=0; istream())); + CUDA_CHECK(cudaGraphLaunch(cuda_graph.instance, cuda_ctx->stream())); } - cudaGraph.count++; + cuda_graph.count++; #endif return GGML_STATUS_SUCCESS; } From 408759687fc669ccab7407700e9404e26d7b46fd Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Wed, 24 Apr 2024 06:31:08 -0700 Subject: [PATCH 08/17] further addressed comments --- ggml-cuda.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 344d7d61aa7fb..46cbb7c7d95c9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2460,7 +2460,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t int k=0; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - // Identify if the graph needs updated for this token due to the number of elements changing + // Identify if the graph needs to be updated for this token due to the number of elements changing // (identified by inspecting soft max op parameters) if(node->op == GGML_OP_SOFT_MAX) { if(node->src[1]->ne[1] > 1){ @@ -2489,10 +2489,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t #else bool use_cuda_graph = false; bool cuda_graph_update_required = false; -#endif +#endif // USE_CUDA_GRAPH - // Only perfom the graph exection if CUDA graphs are not enebled, or we are capturing the graph. - // With use of CUDA graphs, the execution will be performed by the graph launch. + // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. + // With the use of CUDA graphs, the execution will be performed by the graph launch. 
if(!use_cuda_graph || cuda_graph_update_required) { //temporarily avoid indenting here to make code review easier for (int i = 0; i < cgraph->n_nodes; i++) { @@ -2519,7 +2519,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } } - #ifdef USE_CUDA_GRAPH +#ifdef USE_CUDA_GRAPH if(use_cuda_graph && (cuda_graph_update_required)) { // End CUDA graph capture CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_graph.graph)); } @@ -2541,7 +2541,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // Subsequent call with non-null argument gets nodes CUDA_CHECK(cudaGraphGetNodes(cuda_graph.graph, cuda_graph.nodes, &cuda_graph.num_nodes)); - // Loop over nodes, and extract kernel parameters fro each node + // Loop over nodes, and extract kernel parameters from each node for(size_t i=0; istream())); } cuda_graph.count++; -#endif +#endif // USE_CUDA_GRAPH return GGML_STATUS_SUCCESS; } From 0640427f7b0343cf8832589f912407d19510870f Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Thu, 25 Apr 2024 00:51:48 -0700 Subject: [PATCH 09/17] limit to GGML_ALLOW_CUDA_GRAPHS defined in llama.cpp cmake --- CMakeLists.txt | 1 + ggml-cuda.cu | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f134a153bb4ff..a5a2304923efd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -413,6 +413,7 @@ if (LLAMA_CUDA) list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu") add_compile_definitions(GGML_USE_CUDA) + add_compile_definitions(GGML_ALLOW_CUDA_GRAPHS) if (LLAMA_CUDA_FORCE_DMMV) add_compile_definitions(GGML_CUDA_FORCE_DMMV) endif() diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 46cbb7c7d95c9..a63b9b554a84f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2405,7 +2405,7 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { GGML_UNUSED(backend); } -#if (CUDART_VERSION >= 12000) +#if (CUDART_VERSION >= 12000) && defined(GGML_ALLOW_CUDA_GRAPHS) #define USE_CUDA_GRAPH #endif From d44e0fb22c36f4b18761296e84bbd1d256891b87 Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Tue, 30 Apr 2024 03:29:35 -0700 Subject: [PATCH 10/17] Added more comprehensive graph node checking --- ggml-cuda.cu | 68 +++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 59 insertions(+), 9 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a63b9b554a84f..2977902bd696b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2409,6 +2409,14 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { #define USE_CUDA_GRAPH #endif +struct ggml_graph_node_properties { + void * node_address; + int node_op; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + void * src_address[GGML_MAX_SRC]; +}; + #ifdef USE_CUDA_GRAPH #define MAX_NODES_IN_CUDA_GRAPH 10000 struct ggml_cuda_graph { @@ -2416,15 +2424,42 @@ struct ggml_cuda_graph { cudaGraph_t graph = nullptr; cudaGraphExec_t instance = nullptr; size_t num_nodes = 0; - int softmax_ne0 = 0; cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH]; cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH]; bool disable_due_to_gpu_arch = false; + bool disable_due_to_too_many_updates = false; + int number_consecutive_updates = 0; + ggml_graph_node_properties ggml_graph_properties[MAX_NODES_IN_CUDA_GRAPH]; }; #endif const bool disable_cuda_graphs = (getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr); +GGML_CALL static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { + 
graph_node_properties->node_address = node; + graph_node_properties->node_op = node->op; + for(int i=0; ine[i] = node->ne[i]; + graph_node_properties->nb[i] = node->nb[i]; + } + for(int i=0; isrc_address[i] = node->src[i]; + } +} + +GGML_CALL static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { + if(node != graph_node_properties->node_address) return false; + if(node->op != graph_node_properties->node_op) return false; + for(int i=0; ine[i] != graph_node_properties->ne[i]) return false; + if(node->nb[i] != graph_node_properties->nb[i]) return false; + } + for(int i=0; isrc[i] != graph_node_properties->src_address[i]) return false; + } + return true; +} + GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; @@ -2446,9 +2481,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } } - // Disable CUDA graphs in presence of env var or old GPU. + // Disable CUDA graphs in presence of env var, old GPU or use-case which is changing too rapidly. // Also disable for multi-gpu for now. TO DO investigate - if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || ggml_backend_cuda_get_device_count() > 1){ + if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || cuda_graph.disable_due_to_too_many_updates || + ggml_backend_cuda_get_device_count() > 1){ use_cuda_graph = false; } @@ -2456,20 +2492,25 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t if(cuda_graph.instance == nullptr) cuda_graph_update_required=true; + // Loop over nodes in GGML graph to determine if CUDA graph update is required + // and store properties to allow this comparison for the next token + for (int i = 0; i < cgraph->n_nodes; i++) { + bool has_matching_properties = true; + if(!cuda_graph_update_required) { + has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_graph.ggml_graph_properties[i]); + } + if(!has_matching_properties) cuda_graph_update_required = true; + set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph.ggml_graph_properties[i]); + } + // Loop over nodes in GGML graph to obtain info needed for CUDA graph int k=0; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - // Identify if the graph needs to be updated for this token due to the number of elements changing - // (identified by inspecting soft max op parameters) if(node->op == GGML_OP_SOFT_MAX) { if(node->src[1]->ne[1] > 1){ use_cuda_graph = false; // disable CUDA graphs for batch size > 1 for now. TO DO investigate } - if(node->src[0]->ne[0] != cuda_graph.softmax_ne0) { - cuda_graph_update_required = true; - cuda_graph.softmax_ne0 = node->src[0]->ne[0]; - } } if(node->op == GGML_OP_CPY) { // store the copy op parameter which changes with each token. @@ -2480,6 +2521,15 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } } } + + // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. 
+ if(cuda_graph_update_required) { + cuda_graph.number_consecutive_updates++; + } + else { + cuda_graph.number_consecutive_updates = 0; + } + if (cuda_graph.number_consecutive_updates >= 4) cuda_graph.disable_due_to_too_many_updates = true; } if(use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture From eb9f15fb6fcb81384f732c4601a5b25c016a5143 Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Tue, 30 Apr 2024 06:19:51 -0700 Subject: [PATCH 11/17] With mechanism to fall back if graph capture fails --- ggml-cuda.cu | 35 ++++++++++++++++++++++++++++++----- ggml-cuda/common.cuh | 1 - 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2977902bd696b..840d61ac9085b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -48,11 +48,20 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); -[[noreturn]] +static bool disable_cuda_graphs_due_to_failed_capture = false; + void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails cudaGetDevice(&id); + if(strcmp(msg,"operation not permitted when stream is capturing")==0 || + strcmp(msg,"operation failed due to a previous error during capture")==0) { + // CUDA graph capture has failed, but we can fall back to regular stream-based CUDA + // so mark as failed, clear the error and return. + disable_cuda_graphs_due_to_failed_capture = true; + cudaGetLastError(); + return; + } fprintf(stderr, "CUDA error: %s\n", msg); fprintf(stderr, " current device: %d, in function %s at %s:%d\n", id, func, file, line); fprintf(stderr, " %s\n", stmt); @@ -2428,6 +2437,7 @@ struct ggml_cuda_graph { cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH]; bool disable_due_to_gpu_arch = false; bool disable_due_to_too_many_updates = false; + bool disable_due_to_failed_graph_capture = false; int number_consecutive_updates = 0; ggml_graph_node_properties ggml_graph_properties[MAX_NODES_IN_CUDA_GRAPH]; }; @@ -2481,9 +2491,11 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } } - // Disable CUDA graphs in presence of env var, old GPU or use-case which is changing too rapidly. + // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, + // or previous graph capture failure. // Also disable for multi-gpu for now. TO DO investigate - if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || cuda_graph.disable_due_to_too_many_updates || + if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || + cuda_graph.disable_due_to_too_many_updates || cuda_graph.disable_due_to_failed_graph_capture || ggml_backend_cuda_get_device_count() > 1){ use_cuda_graph = false; } @@ -2540,11 +2552,16 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t bool use_cuda_graph = false; bool cuda_graph_update_required = false; #endif // USE_CUDA_GRAPH - + + bool graph_evaluated_or_captured = false; + + while(!graph_evaluated_or_captured) { + // Temporarily avoid indenting here (and below the following if) to make code review easier + // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // With the use of CUDA graphs, the execution will be performed by the graph launch. 
if(!use_cuda_graph || cuda_graph_update_required) { - //temporarily avoid indenting here to make code review easier + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2572,6 +2589,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t #ifdef USE_CUDA_GRAPH if(use_cuda_graph && (cuda_graph_update_required)) { // End CUDA graph capture CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_graph.graph)); + if(disable_cuda_graphs_due_to_failed_capture) { + use_cuda_graph = false; + cuda_graph.disable_due_to_failed_graph_capture = true; + } + } + else { + graph_evaluated_or_captured = true; + } } if(use_cuda_graph){ diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 481065b2a3484..418056a98955c 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -172,7 +172,6 @@ #define GGML_CUDA_MAX_STREAMS 8 -[[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg); #define CUDA_CHECK_GEN(err, success, error_fn) \ From 909e4c664b4fe2a5e596240eb4ae10a64329ddb1 Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Tue, 30 Apr 2024 11:39:59 -0700 Subject: [PATCH 12/17] Revert "With mechanism to fall back if graph capture fails" This reverts commit eb9f15fb6fcb81384f732c4601a5b25c016a5143. --- ggml-cuda.cu | 35 +++++------------------------------ ggml-cuda/common.cuh | 1 + 2 files changed, 6 insertions(+), 30 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 840d61ac9085b..2977902bd696b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -48,20 +48,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); -static bool disable_cuda_graphs_due_to_failed_capture = false; - +[[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails cudaGetDevice(&id); - if(strcmp(msg,"operation not permitted when stream is capturing")==0 || - strcmp(msg,"operation failed due to a previous error during capture")==0) { - // CUDA graph capture has failed, but we can fall back to regular stream-based CUDA - // so mark as failed, clear the error and return. - disable_cuda_graphs_due_to_failed_capture = true; - cudaGetLastError(); - return; - } fprintf(stderr, "CUDA error: %s\n", msg); fprintf(stderr, " current device: %d, in function %s at %s:%d\n", id, func, file, line); fprintf(stderr, " %s\n", stmt); @@ -2437,7 +2428,6 @@ struct ggml_cuda_graph { cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH]; bool disable_due_to_gpu_arch = false; bool disable_due_to_too_many_updates = false; - bool disable_due_to_failed_graph_capture = false; int number_consecutive_updates = 0; ggml_graph_node_properties ggml_graph_properties[MAX_NODES_IN_CUDA_GRAPH]; }; @@ -2491,11 +2481,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } } - // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, - // or previous graph capture failure. + // Disable CUDA graphs in presence of env var, old GPU or use-case which is changing too rapidly. // Also disable for multi-gpu for now. 
TO DO investigate - if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || - cuda_graph.disable_due_to_too_many_updates || cuda_graph.disable_due_to_failed_graph_capture || + if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || cuda_graph.disable_due_to_too_many_updates || ggml_backend_cuda_get_device_count() > 1){ use_cuda_graph = false; } @@ -2552,16 +2540,11 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t bool use_cuda_graph = false; bool cuda_graph_update_required = false; #endif // USE_CUDA_GRAPH - - bool graph_evaluated_or_captured = false; - - while(!graph_evaluated_or_captured) { - // Temporarily avoid indenting here (and below the following if) to make code review easier - + // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // With the use of CUDA graphs, the execution will be performed by the graph launch. if(!use_cuda_graph || cuda_graph_update_required) { - + //temporarily avoid indenting here to make code review easier for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2589,14 +2572,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t #ifdef USE_CUDA_GRAPH if(use_cuda_graph && (cuda_graph_update_required)) { // End CUDA graph capture CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_graph.graph)); - if(disable_cuda_graphs_due_to_failed_capture) { - use_cuda_graph = false; - cuda_graph.disable_due_to_failed_graph_capture = true; - } - } - else { - graph_evaluated_or_captured = true; - } } if(use_cuda_graph){ diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 418056a98955c..481065b2a3484 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -172,6 +172,7 @@ #define GGML_CUDA_MAX_STREAMS 8 +[[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg); #define CUDA_CHECK_GEN(err, success, error_fn) \ From 58199503a85231e512ecb1483b5f7cad0d0df384 Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Tue, 30 Apr 2024 06:19:51 -0700 Subject: [PATCH 13/17] Fall back if graph capture fails and address other comments --- ggml-cuda.cu | 108 ++++++++++++++++++++++++++++--------------- ggml-cuda/common.cuh | 6 ++- 2 files changed, 77 insertions(+), 37 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2977902bd696b..88c1814c1da8a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -48,11 +48,20 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); -[[noreturn]] +static bool disable_cuda_graphs_due_to_failed_capture = false; + void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails cudaGetDevice(&id); + if(strcmp(msg,"operation not permitted when stream is capturing")==0 || + strcmp(msg,"operation failed due to a previous error during capture")==0) { + // CUDA graph capture has failed, but we can fall back to regular stream-based CUDA + // so mark as failed, clear the error and return. 
+ disable_cuda_graphs_due_to_failed_capture = true; + cudaGetLastError(); + return; + } fprintf(stderr, "CUDA error: %s\n", msg); fprintf(stderr, " current device: %d, in function %s at %s:%d\n", id, func, file, line); fprintf(stderr, " %s\n", stmt); @@ -2428,6 +2437,7 @@ struct ggml_cuda_graph { cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH]; bool disable_due_to_gpu_arch = false; bool disable_due_to_too_many_updates = false; + bool disable_due_to_failed_graph_capture = false; int number_consecutive_updates = 0; ggml_graph_node_properties ggml_graph_properties[MAX_NODES_IN_CUDA_GRAPH]; }; @@ -2436,26 +2446,28 @@ struct ggml_cuda_graph { const bool disable_cuda_graphs = (getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr); GGML_CALL static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { - graph_node_properties->node_address = node; + graph_node_properties->node_address = node->data; graph_node_properties->node_op = node->op; for(int i=0; ine[i] = node->ne[i]; graph_node_properties->nb[i] = node->nb[i]; } for(int i=0; isrc_address[i] = node->src[i]; + graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; } } GGML_CALL static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { - if(node != graph_node_properties->node_address) return false; + if(node->data != graph_node_properties->node_address && + node->op != GGML_OP_CPY && node->op != GGML_OP_VIEW) return false; if(node->op != graph_node_properties->node_op) return false; for(int i=0; ine[i] != graph_node_properties->ne[i]) return false; if(node->nb[i] != graph_node_properties->nb[i]) return false; } for(int i=0; isrc[i] != graph_node_properties->src_address[i]) return false; + if(node->src[i] && node->src[i]->data != graph_node_properties->src_address[i] && + node->op != GGML_OP_CPY && node->op != GGML_OP_VIEW) return false; } return true; } @@ -2467,46 +2479,54 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t #ifdef USE_CUDA_GRAPH // Objects required for CUDA Graph - static ggml_cuda_graph cuda_graph; - bool use_cuda_graph = (cuda_graph.count >= 7); //avoid CUDA graphs on first few steps due to incompatible initialisations. + if(cuda_ctx->cuda_graph == nullptr) + { + cuda_ctx->cuda_graph = (ggml_cuda_graph *) malloc(sizeof(ggml_cuda_graph)); + } + bool use_cuda_graph = (cuda_ctx->cuda_graph->count >= 7); //avoid CUDA graphs on first few steps due to incompatible initialisations. char ** updated_kernel_arg[MAX_NODES_IN_CUDA_GRAPH]; bool cuda_graph_update_required = false; // pointer to CUDA cpy kernel, which is required to identify // kernel parameters which need updated in the graph for each token void * ggml_cuda_cpy_fn_ptr = nullptr; - if(cuda_graph.count == 0){ + if(cuda_ctx->cuda_graph->count == 0){ if (ggml_cuda_info().devices[cuda_ctx->device].cc < 800){ - cuda_graph.disable_due_to_gpu_arch=true; + cuda_ctx->cuda_graph->disable_due_to_gpu_arch=true; } } - // Disable CUDA graphs in presence of env var, old GPU or use-case which is changing too rapidly. + // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, + // or previous graph capture failure. // Also disable for multi-gpu for now. 
TO DO investigate - if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || cuda_graph.disable_due_to_too_many_updates || + if(disable_cuda_graphs || cuda_ctx->cuda_graph->disable_due_to_gpu_arch || + cuda_ctx->cuda_graph->disable_due_to_too_many_updates || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture || ggml_backend_cuda_get_device_count() > 1){ use_cuda_graph = false; } if(use_cuda_graph) { - if(cuda_graph.instance == nullptr) cuda_graph_update_required=true; + if(cuda_ctx->cuda_graph->instance == nullptr) cuda_graph_update_required=true; // Loop over nodes in GGML graph to determine if CUDA graph update is required // and store properties to allow this comparison for the next token for (int i = 0; i < cgraph->n_nodes; i++) { bool has_matching_properties = true; if(!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_graph.ggml_graph_properties[i]); + has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); } if(!has_matching_properties) cuda_graph_update_required = true; - set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph.ggml_graph_properties[i]); + set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); } // Loop over nodes in GGML graph to obtain info needed for CUDA graph int k=0; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; + if(node->op == GGML_OP_MUL_MAT_ID) { + use_cuda_graph = false; // This node type is not supported by CUDA graph capture + } if(node->op == GGML_OP_SOFT_MAX) { if(node->src[1]->ne[1] > 1){ use_cuda_graph = false; // disable CUDA graphs for batch size > 1 for now. TO DO investigate @@ -2524,12 +2544,12 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. if(cuda_graph_update_required) { - cuda_graph.number_consecutive_updates++; + cuda_ctx->cuda_graph->number_consecutive_updates++; } else { - cuda_graph.number_consecutive_updates = 0; + cuda_ctx->cuda_graph->number_consecutive_updates = 0; } - if (cuda_graph.number_consecutive_updates >= 4) cuda_graph.disable_due_to_too_many_updates = true; + if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; } if(use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture @@ -2540,11 +2560,16 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t bool use_cuda_graph = false; bool cuda_graph_update_required = false; #endif // USE_CUDA_GRAPH - + + bool graph_evaluated_or_captured = false; + + while(!graph_evaluated_or_captured) { + // Temporarily avoid indenting here (and below the following if) to make code review easier + // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // With the use of CUDA graphs, the execution will be performed by the graph launch. 
if(!use_cuda_graph || cuda_graph_update_required) { - //temporarily avoid indenting here to make code review easier + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2571,12 +2596,23 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t #ifdef USE_CUDA_GRAPH if(use_cuda_graph && (cuda_graph_update_required)) { // End CUDA graph capture - CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_graph.graph)); + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); + if(disable_cuda_graphs_due_to_failed_capture) { + use_cuda_graph = false; + cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true; + } + else { + graph_evaluated_or_captured = true; // CUDA graph has been captured + } + } + else { + graph_evaluated_or_captured = true; // ggml graph has been directly evaluated + } } if(use_cuda_graph){ - if(cuda_graph.instance == nullptr) { // Create executable graph from captured graph. - CUDA_CHECK(cudaGraphInstantiate(&cuda_graph.instance, cuda_graph.graph, NULL, NULL, 0)); + if(cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. + CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); } @@ -2584,19 +2620,19 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t if(cuda_graph_update_required) { // Extract nodes from graph - if(cuda_graph.num_nodes == 0) { + if(cuda_ctx->cuda_graph->num_nodes == 0) { // First call with null argument gets number of nodes in graph - CUDA_CHECK(cudaGraphGetNodes(cuda_graph.graph, nullptr, &cuda_graph.num_nodes)); + CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); } // Subsequent call with non-null argument gets nodes - CUDA_CHECK(cudaGraphGetNodes(cuda_graph.graph, cuda_graph.nodes, &cuda_graph.num_nodes)); + CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes, &cuda_ctx->cuda_graph->num_nodes)); // Loop over nodes, and extract kernel parameters from each node - for(size_t i=0; icuda_graph->num_nodes; i++) { cudaGraphNodeType node_type; - CUDA_CHECK(cudaGraphNodeGetType(cuda_graph.nodes[i], &node_type)); + CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type)); if (node_type == cudaGraphNodeTypeKernel) { - auto stat = cudaGraphKernelNodeGetParams(cuda_graph.nodes[i], &cuda_graph.params[i]); // Get params using runtime + auto stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime if(stat == cudaErrorInvalidDeviceFunction) { // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. // We don't need to update blas nodes, so clear error and move on. 
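As an aside on the kernel-parameter patching performed in this hunk, here is a self-contained sketch of the same CUDA runtime technique: walk the captured graph's nodes, rewrite one entry of kernelParams for a known kernel, and fold the change into the executable graph with cudaGraphExecUpdate. It is illustrative only, not llama.cpp code: scale_kernel stands in for the cpy kernel, argument index 1 is assumed, and the fixed-size node array is a simplification.

#include <cuda_runtime.h>

__global__ void scale_kernel(float * x, float v, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { x[i] *= v; }
}

// Point argument 1 ('v') of every scale_kernel node in 'graph' at *new_v, then
// refresh the executable graph 'instance' so the next cudaGraphLaunch uses it.
static bool patch_scale_arg(cudaGraph_t graph, cudaGraphExec_t instance, float * new_v) {
    cudaGraphNode_t nodes[256];
    size_t num_nodes = 256; // capacity on input, actual node count on output
    if (cudaGraphGetNodes(graph, nodes, &num_nodes) != cudaSuccess) {
        return false;
    }
    for (size_t i = 0; i < num_nodes; i++) {
        cudaGraphNodeType type;
        cudaGraphNodeGetType(nodes[i], &type);
        if (type != cudaGraphNodeTypeKernel) {
            continue; // skip memcpy/memset/child-graph nodes
        }
        cudaKernelNodeParams params;
        if (cudaGraphKernelNodeGetParams(nodes[i], &params) != cudaSuccess) {
            cudaGetLastError(); // e.g. nodes recorded by library calls: clear the error and skip
            continue;
        }
        if (params.func == (void *) scale_kernel) {
            params.kernelParams[1] = new_v; // kernelParams[i] points at the i-th argument value
            cudaGraphKernelNodeSetParams(nodes[i], &params);
        }
    }
    // Propagate the modified node parameters into the executable graph without re-instantiating.
    cudaGraphExecUpdateResultInfo result_info;
    if (cudaGraphExecUpdate(instance, graph, &result_info) != cudaSuccess) {
        cudaGetLastError(); // topology changed too much: caller should re-instantiate instead
        return false;
    }
    return true;
}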
@@ -2613,31 +2649,31 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // replace that argument with the updated value in the CUDA graph if(!cuda_graph_update_required) { // on update steps, the live parameters will already be captured int k=0; - for(size_t i=0; icuda_graph->num_nodes; i++) { + if(cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) { char ** updated_kernel_arg_ptr = updated_kernel_arg[k++]; - cuda_graph.params[i].kernelParams[1] = updated_kernel_arg_ptr; - CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_graph.nodes[i], &cuda_graph.params[i])); + cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr; + CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i])); } } } // Update graph executable cudaGraphExecUpdateResultInfo result_info; - auto stat = cudaGraphExecUpdate(cuda_graph.instance, cuda_graph.graph, &result_info); + auto stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); if(stat == cudaErrorGraphExecUpdateFailure) { // The pre-existing graph exec cannot be updated due to violated constraints // so instead clear error and re-instantiate cudaGetLastError(); - CUDA_CHECK(cudaGraphInstantiate(&cuda_graph.instance, cuda_graph.graph, NULL, NULL, 0)); + CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); } else { GGML_ASSERT(stat == cudaSuccess); } // Launch graph - CUDA_CHECK(cudaGraphLaunch(cuda_graph.instance, cuda_ctx->stream())); + CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); } - cuda_graph.count++; + cuda_ctx->cuda_graph->count++; #endif // USE_CUDA_GRAPH return GGML_STATUS_SUCCESS; } diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 481065b2a3484..7dbd823c4940c 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -172,7 +172,6 @@ #define GGML_CUDA_MAX_STREAMS 8 -[[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg); #define CUDA_CHECK_GEN(err, success, error_fn) \ @@ -479,6 +478,8 @@ struct ggml_tensor_extra_gpu { cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs }; +struct ggml_cuda_graph; + struct ggml_backend_cuda_context { int device; std::string name; @@ -487,6 +488,8 @@ struct ggml_backend_cuda_context { cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } }; cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; + ggml_cuda_graph * cuda_graph = nullptr; + explicit ggml_backend_cuda_context(int device) : device(device), name(GGML_CUDA_NAME + std::to_string(device)) { @@ -506,6 +509,7 @@ struct ggml_backend_cuda_context { CUBLAS_CHECK(cublasDestroy(cublas_handles[i])); } } + if(cuda_graph != nullptr) free(cuda_graph); } cudaStream_t stream(int device, int stream) { From e830949e985afc6ee9bd138ebca615846ffcf7dd Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 7 May 2024 15:05:10 +0200 Subject: [PATCH 14/17] - renamed GGML_ALLOW_CUDA_GRAPHS to GGML_CUDA_USE_GRAPHS - rename env variable to disable CUDA graphs to GGML_CUDA_DISABLE_GRAPHS - updated Makefile build to enable CUDA graphs - removed graph capture failure checking in ggml_cuda_error using a global variable to track this is not thread safe, but I am also not safistied with checking an error by string if this is necessary to workaround some issues with graph capture with 
eg. cuBLAS, we can pass the ggml_backend_cuda_context to the error checking macro and store the result in the context - fixed several resource leaks - fixed issue with zero node graphs - changed fixed size arrays to vectors - removed the count of number of evaluations before start capturing, and instead changed the capture mode to relaxed - removed the check for multiple devices so that it is still possible to use a single device, instead checks for split buffers to disable cuda graphs with -sm row - changed the op for checking batch size to GGML_OP_ADD, should be more reliable than GGML_OP_SOFT_MAX - code style fixes - things to look into - VRAM usage of the cudaGraphExec_t, if it is significant we may need to make it optional - possibility of using cudaStreamBeginCaptureToGraph to keep track of which ggml graph nodes correspond to which cuda graph nodes --- CMakeLists.txt | 2 +- Makefile | 2 +- ggml-cuda.cu | 332 +++++++++++++++++++++++-------------------- ggml-cuda/clamp.cu | 1 - ggml-cuda/common.cuh | 42 +++++- ggml-cuda/convert.cu | 4 +- ggml-cuda/mmq.cu | 30 ++-- ggml-cuda/mmvq.cu | 6 +- ggml-cuda/scale.cu | 1 - 9 files changed, 235 insertions(+), 185 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3824321599eaf..07b3a4388b0a6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -405,7 +405,7 @@ if (LLAMA_CUDA) list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu") add_compile_definitions(GGML_USE_CUDA) - add_compile_definitions(GGML_ALLOW_CUDA_GRAPHS) + add_compile_definitions(GGML_CUDA_USE_GRAPHS) if (LLAMA_CUDA_FORCE_DMMV) add_compile_definitions(GGML_CUDA_FORCE_DMMV) endif() diff --git a/Makefile b/Makefile index c568dd008f350..3fa56d13a4671 100644 --- a/Makefile +++ b/Makefile @@ -433,7 +433,7 @@ ifdef LLAMA_CUDA else CUDA_PATH ?= /usr/local/cuda endif - MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include + MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib OBJS += ggml-cuda.o OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu)) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e42a81627e595..5beb51f46a21e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -49,20 +49,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); -static bool disable_cuda_graphs_due_to_failed_capture = false; - +[[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails cudaGetDevice(&id); - if(strcmp(msg,"operation not permitted when stream is capturing")==0 || - strcmp(msg,"operation failed due to a previous error during capture")==0) { - // CUDA graph capture has failed, but we can fall back to regular stream-based CUDA - // so mark as failed, clear the error and return. 
- disable_cuda_graphs_due_to_failed_capture = true; - cudaGetLastError(); - return; - } fprintf(stderr, "CUDA error: %s\n", msg); fprintf(stderr, " current device: %d, in function %s at %s:%d\n", id, func, file, line); fprintf(stderr, " %s\n", stmt); @@ -1656,7 +1647,7 @@ static void ggml_cuda_op_mul_mat( } } -static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ +static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation @@ -1679,7 +1670,7 @@ static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const gg ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); } -static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ +static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_is_permuted(src0)); @@ -2419,60 +2410,46 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { GGML_UNUSED(backend); } -#if (CUDART_VERSION >= 12000) && defined(GGML_ALLOW_CUDA_GRAPHS) -#define USE_CUDA_GRAPH -#endif - -struct ggml_graph_node_properties { - void * node_address; - int node_op; - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; - void * src_address[GGML_MAX_SRC]; -}; - -#ifdef USE_CUDA_GRAPH -#define MAX_NODES_IN_CUDA_GRAPH 10000 -struct ggml_cuda_graph { - int count = 0; - cudaGraph_t graph = nullptr; - cudaGraphExec_t instance = nullptr; - size_t num_nodes = 0; - cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH]; - cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH]; - bool disable_due_to_gpu_arch = false; - bool disable_due_to_too_many_updates = false; - bool disable_due_to_failed_graph_capture = false; - int number_consecutive_updates = 0; - ggml_graph_node_properties ggml_graph_properties[MAX_NODES_IN_CUDA_GRAPH]; -}; -#endif - -const bool disable_cuda_graphs = (getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr); - -GGML_CALL static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { +static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { graph_node_properties->node_address = node->data; graph_node_properties->node_op = node->op; - for(int i=0; ine[i] = node->ne[i]; graph_node_properties->nb[i] = node->nb[i]; } - for(int i=0; isrc_address[i] = node->src[i] ? 
node->src[i]->data : nullptr; } } -GGML_CALL static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { - if(node->data != graph_node_properties->node_address && - node->op != GGML_OP_CPY && node->op != GGML_OP_VIEW) return false; - if(node->op != graph_node_properties->node_op) return false; - for(int i=0; ine[i] != graph_node_properties->ne[i]) return false; - if(node->nb[i] != graph_node_properties->nb[i]) return false; +static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { + if (node->data != graph_node_properties->node_address && + node->op != GGML_OP_CPY && + node->op != GGML_OP_VIEW) { + return false; + } + + if (node->op != graph_node_properties->node_op) { + return false; + } + + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (node->ne[i] != graph_node_properties->ne[i]) { + return false; + } + if (node->nb[i] != graph_node_properties->nb[i]) { + return false; + } } - for(int i=0; isrc[i] && node->src[i]->data != graph_node_properties->src_address[i] && - node->op != GGML_OP_CPY && node->op != GGML_OP_VIEW) return false; + + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node->src[i] && + node->src[i]->data != graph_node_properties->src_address[i] && + node->op != GGML_OP_CPY && + node->op != GGML_OP_VIEW + ) { + return false; + } } return true; } @@ -2483,82 +2460,121 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t ggml_cuda_set_device(cuda_ctx->device); #ifdef USE_CUDA_GRAPH + static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); + // Objects required for CUDA Graph - if(cuda_ctx->cuda_graph == nullptr) - { - cuda_ctx->cuda_graph = (ggml_cuda_graph *) malloc(sizeof(ggml_cuda_graph)); + if (cuda_ctx->cuda_graph == nullptr) { + cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); } - bool use_cuda_graph = (cuda_ctx->cuda_graph->count >= 7); //avoid CUDA graphs on first few steps due to incompatible initialisations. - char ** updated_kernel_arg[MAX_NODES_IN_CUDA_GRAPH]; + + bool use_cuda_graph = true; bool cuda_graph_update_required = false; // pointer to CUDA cpy kernel, which is required to identify // kernel parameters which need updated in the graph for each token void * ggml_cuda_cpy_fn_ptr = nullptr; - if(cuda_ctx->cuda_graph->count == 0){ - if (ggml_cuda_info().devices[cuda_ctx->device].cc < 800){ - cuda_ctx->cuda_graph->disable_due_to_gpu_arch=true; + if (cuda_ctx->cuda_graph->graph == nullptr) { + if (ggml_cuda_info().devices[cuda_ctx->device].cc < 800) { + cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; +#ifndef NDEBUG + fprintf(stderr, "%s: disabling CUDA graphs due to GPU architecture\n", __func__); +#endif } } // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, // or previous graph capture failure. // Also disable for multi-gpu for now. 
TO DO investigate - if(disable_cuda_graphs || cuda_ctx->cuda_graph->disable_due_to_gpu_arch || - cuda_ctx->cuda_graph->disable_due_to_too_many_updates || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture || - ggml_backend_cuda_get_device_count() > 1){ + if (disable_cuda_graphs_due_to_env + || cuda_ctx->cuda_graph->disable_due_to_gpu_arch + || cuda_ctx->cuda_graph->disable_due_to_too_many_updates + || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) { use_cuda_graph = false; } - if(use_cuda_graph) { + if (use_cuda_graph) { + if (cuda_ctx->cuda_graph->instance == nullptr) { + cuda_graph_update_required = true; + } - if(cuda_ctx->cuda_graph->instance == nullptr) cuda_graph_update_required=true; + // Check if the graph size has changed + if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { + cuda_graph_update_required = true; + cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); + } // Loop over nodes in GGML graph to determine if CUDA graph update is required // and store properties to allow this comparison for the next token for (int i = 0; i < cgraph->n_nodes; i++) { bool has_matching_properties = true; - if(!cuda_graph_update_required) { + if (!cuda_graph_update_required) { has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); } - if(!has_matching_properties) cuda_graph_update_required = true; + if (!has_matching_properties) { + cuda_graph_update_required = true; + } set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); } // Loop over nodes in GGML graph to obtain info needed for CUDA graph - int k=0; + cuda_ctx->cuda_graph->updated_kernel_arg.clear(); for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - if(node->op == GGML_OP_MUL_MAT_ID) { + + if (node->src[0] && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) { + use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture +#ifndef NDEBUG + fprintf(stderr, "%s: disabling CUDA graphs due to split buffer\n", __func__); +#endif + } + + if (node->op == GGML_OP_MUL_MAT_ID) { use_cuda_graph = false; // This node type is not supported by CUDA graph capture +#ifndef NDEBUG + fprintf(stderr, "%s: disabling CUDA graphs due to mul_mat_id\n", __func__); +#endif } - if(node->op == GGML_OP_SOFT_MAX) { - if(node->src[1]->ne[1] > 1){ - use_cuda_graph = false; // disable CUDA graphs for batch size > 1 for now. TO DO investigate - } + + if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) { + // disable CUDA graphs for batch size > 1 for now. + // Changes in batch size or context size can cause changes to the grid size of some kernels. + use_cuda_graph = false; +#ifndef NDEBUG + fprintf(stderr, "%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); +#endif } - if(node->op == GGML_OP_CPY) { + + if (node->op == GGML_OP_CPY) { // store the copy op parameter which changes with each token. 
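// (Editor's note, illustrative: in the llama.cpp decode path these CPY nodes write the
//  current token's K/V data into the KV cache. src[1] is a view whose data pointer is
//  offset by the current sequence position, so this kernel argument is the one part of
//  an otherwise identical graph that must be refreshed on every token.)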
- updated_kernel_arg[k++]=(char **) &(node->src[1]->data); - if(ggml_cuda_cpy_fn_ptr == nullptr){ + cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data)); + if (ggml_cuda_cpy_fn_ptr == nullptr) { // store a pointer to the copy op CUDA kernel to identify it later ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); } } + + if (!use_cuda_graph) { + break; + } } // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. - if(cuda_graph_update_required) { + if (cuda_graph_update_required) { cuda_ctx->cuda_graph->number_consecutive_updates++; - } - else { + } else { cuda_ctx->cuda_graph->number_consecutive_updates = 0; } - if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; + + if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) { + cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; +#ifndef NDEBUG + fprintf(stderr, "%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); +#endif + } } - if(use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture - CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeGlobal)); + if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture + CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); } #else @@ -2568,95 +2584,105 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t bool graph_evaluated_or_captured = false; - while(!graph_evaluated_or_captured) { - // Temporarily avoid indenting here (and below the following if) to make code review easier + while (!graph_evaluated_or_captured) { + // Temporarily avoid indenting here (and below the following if) to make code review easier - // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. - // With the use of CUDA graphs, the execution will be performed by the graph launch. - if(!use_cuda_graph || cuda_graph_update_required) { + // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. + // With the use of CUDA graphs, the execution will be performed by the graph launch. 
+ if (!use_cuda_graph || cuda_graph_update_required) { + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - - if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { - continue; - } + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } #ifndef NDEBUG - assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); - for (int j = 0; j < GGML_MAX_SRC; j++) { - if (node->src[j] != nullptr) { - assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer)); - } - } + assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] != nullptr) { + assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer)); + } + } #endif - bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); - if (!ok) { - fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); + if (!ok) { + fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + } } - GGML_ASSERT(ok); - } - } #ifdef USE_CUDA_GRAPH - if(use_cuda_graph && (cuda_graph_update_required)) { // End CUDA graph capture - CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); - if(disable_cuda_graphs_due_to_failed_capture) { - use_cuda_graph = false; - cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true; - } - else { + if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture + if (cuda_ctx->cuda_graph->graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph)); + cuda_ctx->cuda_graph->graph = nullptr; + } + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); + +#if 0 + if (disable_cuda_graphs_due_to_failed_capture) { + use_cuda_graph = false; + cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true; +#ifndef NDEBUG + fprintf(stderr, "%s: disabling CUDA graphs due to failed graph capture\n", __func__); +#endif + } else { + graph_evaluated_or_captured = true; // CUDA graph has been captured + } +#endif graph_evaluated_or_captured = true; // CUDA graph has been captured + } else { + graph_evaluated_or_captured = true; // ggml graph has been directly evaluated } } - else { - graph_evaluated_or_captured = true; // ggml graph has been directly evaluated - } - } - if(use_cuda_graph){ - if(cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. + if (use_cuda_graph) { + if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. 
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); } - // Perform update to graph (if required for this token), and change copy parameter (required for every token) - if(cuda_graph_update_required) { + if (cuda_graph_update_required) { // Extract nodes from graph - if(cuda_ctx->cuda_graph->num_nodes == 0) { + if (cuda_ctx->cuda_graph->num_nodes == 0) { // First call with null argument gets number of nodes in graph CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); } // Subsequent call with non-null argument gets nodes - CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes, &cuda_ctx->cuda_graph->num_nodes)); - - // Loop over nodes, and extract kernel parameters from each node - for(size_t i=0; icuda_graph->num_nodes; i++) { - cudaGraphNodeType node_type; - CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type)); - if (node_type == cudaGraphNodeTypeKernel) { - auto stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime - if(stat == cudaErrorInvalidDeviceFunction) { - // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. - // We don't need to update blas nodes, so clear error and move on. - cudaGetLastError(); - } - else { - GGML_ASSERT(stat == cudaSuccess); + cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes); + cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes); + if (cuda_ctx->cuda_graph->num_nodes > 0) { + CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes)); + + // Loop over nodes, and extract kernel parameters from each node + for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { + cudaGraphNodeType node_type; + CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type)); + if (node_type == cudaGraphNodeTypeKernel) { + cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime + if (stat == cudaErrorInvalidDeviceFunction) { + // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. + // We don't need to update blas nodes, so clear error and move on. 
+ cudaGetLastError(); + } else { + GGML_ASSERT(stat == cudaSuccess); + } } } } } // One of the arguments to the copy kernel is updated for each token, hence we need to - // replace that argument with the updated value in the CUDA graph - if(!cuda_graph_update_required) { // on update steps, the live parameters will already be captured - int k=0; - for(size_t i=0; icuda_graph->num_nodes; i++) { - if(cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) { - char ** updated_kernel_arg_ptr = updated_kernel_arg[k++]; + // replace that argument with the updated value in the CUDA graph + if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured + int k = 0; + for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { + if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) { + char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++); cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr; CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i])); } @@ -2665,21 +2691,25 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // Update graph executable cudaGraphExecUpdateResultInfo result_info; - auto stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); - if(stat == cudaErrorGraphExecUpdateFailure) { + cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); + if (stat == cudaErrorGraphExecUpdateFailure) { +#ifndef NDEBUG + fprintf(stderr, "%s: CUDA graph update failed\n", __func__); +#endif // The pre-existing graph exec cannot be updated due to violated constraints // so instead clear error and re-instantiate cudaGetLastError(); + CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); + cuda_ctx->cuda_graph->instance = nullptr; CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); - } - else { + } else { GGML_ASSERT(stat == cudaSuccess); } // Launch graph CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); } - cuda_ctx->cuda_graph->count++; #endif // USE_CUDA_GRAPH + return GGML_STATUS_SUCCESS; } diff --git a/ggml-cuda/clamp.cu b/ggml-cuda/clamp.cu index 379ded042d897..8009a3e3d8607 100644 --- a/ggml-cuda/clamp.cu +++ b/ggml-cuda/clamp.cu @@ -31,5 +31,4 @@ void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream); - CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 26401d369c417..a4197f11ba779 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -19,6 +19,7 @@ #include #include #include +#include #if defined(GGML_USE_HIPBLAS) #include @@ -174,6 +175,7 @@ #define GGML_CUDA_MAX_STREAMS 8 +[[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg); #define CUDA_CHECK_GEN(err, success, error_fn) \ @@ -525,7 +527,42 @@ struct ggml_tensor_extra_gpu { cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs }; -struct ggml_cuda_graph; + +#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS) +#define USE_CUDA_GRAPH +#endif + +struct ggml_graph_node_properties { + void * node_address; + ggml_op 
node_op; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + void * src_address[GGML_MAX_SRC]; +}; + +struct ggml_cuda_graph { +#ifdef USE_CUDA_GRAPH + ~ggml_cuda_graph() { + if (instance != nullptr) { + CUDA_CHECK(cudaGraphExecDestroy(instance)); + } + if (graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(graph)); + } + } + cudaGraph_t graph = nullptr; + cudaGraphExec_t instance = nullptr; + size_t num_nodes = 0; + std::vector nodes; + std::vector params; + bool disable_due_to_gpu_arch = false; + bool disable_due_to_too_many_updates = false; + bool disable_due_to_failed_graph_capture = false; + int number_consecutive_updates = 0; + std::vector ggml_graph_properties; + std::vector updated_kernel_arg; +#endif +}; struct ggml_backend_cuda_context { int device; @@ -535,7 +572,7 @@ struct ggml_backend_cuda_context { cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } }; cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; - ggml_cuda_graph * cuda_graph = nullptr; + std::unique_ptr cuda_graph; explicit ggml_backend_cuda_context(int device) : device(device), @@ -556,7 +593,6 @@ struct ggml_backend_cuda_context { CUBLAS_CHECK(cublasDestroy(cublas_handles[i])); } } - if(cuda_graph != nullptr) free(cuda_graph); } cudaStream_t stream(int device, int stream) { diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu index 75e50c9856123..830e2d7566162 100644 --- a/ggml-cuda/convert.cu +++ b/ggml-cuda/convert.cu @@ -727,7 +727,6 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_ } to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { - int id; switch (type) { case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; @@ -738,8 +737,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { case GGML_TYPE_Q5_1: return dequantize_block_cuda; case GGML_TYPE_Q8_0: - CUDA_CHECK(cudaGetDevice(&id)); - if (ggml_cuda_info().devices[id].cc >= CC_PASCAL) { + if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) { return dequantize_block_q8_0_f16_cuda; } return dequantize_block_cuda; diff --git a/ggml-cuda/mmq.cu b/ggml-cuda/mmq.cu index 60d6616a860f7..7948f1b1237fa 100644 --- a/ggml-cuda/mmq.cu +++ b/ggml-cuda/mmq.cu @@ -1735,8 +1735,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -1780,8 +1779,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -1825,8 +1823,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -1870,8 +1867,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, 
const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -1915,8 +1911,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -1960,8 +1955,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -2007,8 +2001,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda( #if QK_K == 256 - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -2053,8 +2046,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -2098,8 +2090,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; @@ -2143,8 +2134,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; int mmq_x, mmq_y, nwarps; diff --git a/ggml-cuda/mmvq.cu b/ggml-cuda/mmvq.cu index 3965590017b95..65cc1bcaad697 100644 --- a/ggml-cuda/mmvq.cu +++ b/ggml-cuda/mmvq.cu @@ -89,8 +89,7 @@ static void mul_mat_vec_q_cuda( GGML_ASSERT(ncols_x % qk == 0); GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); int64_t nwarps = 1; int64_t rows_per_cuda_block = 1; @@ -328,8 +327,7 @@ void ggml_cuda_op_mul_mat_vec_q( const int64_t ne0 = dst->ne[0]; - int id; - CUDA_CHECK(cudaGetDevice(&id)); + int id = ggml_cuda_get_device(); // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into diff --git a/ggml-cuda/scale.cu b/ggml-cuda/scale.cu index 6e3617d1cdbd5..1405e066e86a2 100644 --- a/ggml-cuda/scale.cu +++ b/ggml-cuda/scale.cu @@ -28,5 +28,4 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&scale, dst->op_params, sizeof(float)); scale_f32_cuda(src0_d, dst_d, scale, 
ggml_nelements(src0), stream); - CUDA_CHECK(cudaGetLastError()); } From a4c9b9017fc177288ea6b01694e9e8dc941f30d2 Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 8 May 2024 02:09:52 +0200 Subject: [PATCH 15/17] fix build without cuda graphs --- ggml-cuda.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5beb51f46a21e..f7df32267184e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2707,8 +2707,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } // Launch graph CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); - } +#else + graph_evaluated_or_captured = true; #endif // USE_CUDA_GRAPH + } return GGML_STATUS_SUCCESS; } From ab40e667ddbdd8eb4a5b6d0268aedb17e5375db3 Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 8 May 2024 16:37:19 +0200 Subject: [PATCH 16/17] remove outdated comment --- ggml-cuda.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index f7df32267184e..d01797e0cd6b9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2585,8 +2585,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t bool graph_evaluated_or_captured = false; while (!graph_evaluated_or_captured) { - // Temporarily avoid indenting here (and below the following if) to make code review easier - // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // With the use of CUDA graphs, the execution will be performed by the graph launch. if (!use_cuda_graph || cuda_graph_update_required) { From f42312e0a19e4dc5d7e332678bb451a7e9fb0cb3 Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 8 May 2024 16:38:48 +0200 Subject: [PATCH 17/17] replace minimum cc value with a constant --- ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d01797e0cd6b9..6f89a7cc3e900 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2474,7 +2474,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t void * ggml_cuda_cpy_fn_ptr = nullptr; if (cuda_ctx->cuda_graph->graph == nullptr) { - if (ggml_cuda_info().devices[cuda_ctx->device].cc < 800) { + if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) { cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; #ifndef NDEBUG fprintf(stderr, "%s: disabling CUDA graphs due to GPU architecture\n", __func__);
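[Editor's note] For reviewers who want the graph-maintenance logic above in isolation: the sketch below condenses the per-token update path (enumerate graph nodes, patch the copy kernel's changed argument, then update or re-instantiate the executable graph). It is an illustration under assumptions rather than code from this PR; update_graph_arg_sketch, target_fn and new_arg are invented names, and it presumes <vector>, <cuda_runtime.h> and this backend's CUDA_CHECK macro are available.

// Illustrative sketch: patch one kernel argument in a captured CUDA graph and
// propagate the change to the executable graph, re-instantiating on failure.
void update_graph_arg_sketch(cudaGraph_t graph, cudaGraphExec_t & instance,
                             void * target_fn, void * new_arg) {
    size_t num_nodes = 0;
    CUDA_CHECK(cudaGraphGetNodes(graph, nullptr, &num_nodes));      // first call: node count only
    std::vector<cudaGraphNode_t> nodes(num_nodes);
    CUDA_CHECK(cudaGraphGetNodes(graph, nodes.data(), &num_nodes)); // second call: the nodes

    for (size_t i = 0; i < num_nodes; i++) {
        cudaGraphNodeType type;
        CUDA_CHECK(cudaGraphNodeGetType(nodes[i], &type));
        if (type != cudaGraphNodeTypeKernel) {
            continue;
        }
        cudaKernelNodeParams params = {};
        if (cudaGraphKernelNodeGetParams(nodes[i], &params) != cudaSuccess) {
            cudaGetLastError(); // e.g. library-generated (cuBLAS) nodes; nothing to update, skip
            continue;
        }
        if (params.func == target_fn) {
            params.kernelParams[1] = new_arg;  // slot 1 holds the per-token pointer in this sketch
            CUDA_CHECK(cudaGraphKernelNodeSetParams(nodes[i], &params));
        }
    }

    // Try to update the existing executable graph in place; fall back to re-instantiation
    // if the changes violate the constraints of cudaGraphExecUpdate.
    cudaGraphExecUpdateResultInfo result_info;
    if (cudaGraphExecUpdate(instance, graph, &result_info) != cudaSuccess) {
        cudaGetLastError();
        CUDA_CHECK(cudaGraphExecDestroy(instance));
        instance = nullptr;
        CUDA_CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
    }
}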