From e8797a9aedad9cd51cdc01c7b269830c9026f931 Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Fri, 21 Apr 2023 00:09:14 +0200
Subject: [PATCH 1/4] Improve cuBLAS performance by using a memory pool

---
 ggml.c | 127 +++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 82 insertions(+), 45 deletions(-)

diff --git a/ggml.c b/ggml.c
index 998602150fe55..fb5fd1f7ea515 100644
--- a/ggml.c
+++ b/ggml.c
@@ -152,25 +152,69 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #include <cuda_runtime.h>
 #include "ggml-cuda.h"
 
-#define CUDA_CHECK(err)                                                        \
-    do {                                                                       \
-        cudaError_t err_ = (err);                                              \
-        if (err_ != cudaSuccess) {                                             \
-            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
-                cudaGetErrorString(err_));                                     \
-            exit(1);                                                           \
-        }                                                                      \
+#define CUDA_CHECK(err)                                                                 \
+    do {                                                                                \
+        cudaError_t err_ = (err);                                                       \
+        if (err_ != cudaSuccess) {                                                      \
+            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
+                cudaGetErrorString(err_));                                              \
+            exit(1);                                                                    \
+        }                                                                               \
     } while (0)
 
-#define CUBLAS_CHECK(err)                                                      \
-    do {                                                                       \
-        cublasStatus_t err_ = (err);                                           \
-        if (err_ != CUBLAS_STATUS_SUCCESS) {                                   \
-            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
-            exit(1);                                                           \
-        }                                                                      \
+#define CUBLAS_CHECK(err)                                                               \
+    do {                                                                                \
+        cublasStatus_t err_ = (err);                                                    \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
+            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
+            exit(1);                                                                    \
+        }                                                                               \
     } while (0)
 
+// lock-free, thread safe buffer pool for cuda
+#define MAX_CUDA_BUFFERS 16
+struct cuda_buffer {
+    atomic_uintptr_t ptr;
+    size_t size;
+};
+
+static struct cuda_buffer cuda_buffer_pool[MAX_CUDA_BUFFERS] = {0};
+
+static void * cuda_pool_malloc(size_t size, size_t * actual_size) {
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        struct cuda_buffer * b = &cuda_buffer_pool[i];
+        if (b->size >= size) {
+            uintptr_t ptr = atomic_load(&b->ptr);
+            if (ptr) {
+                if (atomic_compare_exchange_strong(&b->ptr, &ptr, 0)) {
+                    *actual_size = b->size;
+                    return (void *) ptr;
+                }
+            }
+        }
+    }
+
+    void * ptr;
+    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
+    *actual_size = size;
+    return ptr;
+}
+
+static void cuda_pool_free(void * ptr, size_t size) {
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        struct cuda_buffer * b = &cuda_buffer_pool[i];
+        uintptr_t p = atomic_load(&b->ptr);
+        if (p == 0) {
+            if (atomic_compare_exchange_strong(&b->ptr, &p, (uintptr_t) ptr)) {
+                b->size = size;
+                return;
+            }
+        }
+    }
+    fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
+    CUDA_CHECK(cudaFree(ptr));
+}
+
 static cublasHandle_t cublasH = NULL;
 static cudaStream_t cudaStream = NULL;
 static void init_cublas(void) {
@@ -7566,18 +7610,16 @@ static void ggml_compute_forward_mul_mat_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
-        float *d_X = NULL;
-        float *d_Y = NULL;
-        float *d_D = NULL;
         const float alpha = 1.0f;
         const float beta = 0.0f;
         const int x_ne = ne01 * ne10;
         const int y_ne = ne11 * ne10;
         const int d_ne = ne11 * ne01;
 
-        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
-        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
-        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+        size_t x_size, y_size, d_size;
+        float *d_X = cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #endif
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7614,9 +7656,9 @@ static void ggml_compute_forward_mul_mat_f32(
         }
 #if defined(GGML_USE_CUBLAS)
         CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-        CUDA_CHECK(cudaFree(d_X));
-        CUDA_CHECK(cudaFree(d_Y));
-        CUDA_CHECK(cudaFree(d_D));
+        cuda_pool_free(d_X, x_size);
+        cuda_pool_free(d_Y, y_size);
+        cuda_pool_free(d_D, d_size);
 #endif
         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
@@ -7766,18 +7808,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 #if defined(GGML_USE_CUBLAS)
         ggml_fp16_t * const wdata = params->wdata;
 
-        float *d_X = NULL;
-        float *d_Y = NULL;
-        float *d_D = NULL;
         const float alpha = 1.0f;
         const float beta = 0.0f;
         const int x_ne = ne01 * ne10;
         const int y_ne = ne11 * ne10;
         const int d_ne = ne11 * ne01;
 
-        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
-        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
-        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+        size_t x_size, y_size, d_size;
+        float *d_X = cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #else
         float * const wdata = params->wdata;
 #endif
@@ -7844,9 +7884,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 
 #if defined(GGML_USE_CUBLAS)
         CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-        CUDA_CHECK(cudaFree(d_X));
-        CUDA_CHECK(cudaFree(d_Y));
-        CUDA_CHECK(cudaFree(d_D));
+        cuda_pool_free(d_X, x_size);
+        cuda_pool_free(d_Y, y_size);
+        cuda_pool_free(d_D, d_size);
 #endif
         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
@@ -8014,20 +8054,17 @@ static void ggml_compute_forward_mul_mat_q_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
-        float *d_X = NULL;
-        float *d_Y = NULL;
-        float *d_D = NULL;
-        float *d_Q = NULL;
         const float alpha = 1.0f;
         const float beta = 0.0f;
         const int x_ne = ne01 * ne10;
         const int y_ne = ne11 * ne10;
         const int d_ne = ne11 * ne01;
 
-        CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
-        CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
-        CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
-        CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
+        size_t x_size, y_size, d_size, q_size;
+        float *d_X = cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+        float *d_Q = cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
 
         void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream)  = NULL;
         if (type == GGML_TYPE_Q4_0) {
@@ -8100,10 +8137,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_CUBLAS)
         CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-        CUDA_CHECK(cudaFree(d_X));
-        CUDA_CHECK(cudaFree(d_Y));
-        CUDA_CHECK(cudaFree(d_D));
-        CUDA_CHECK(cudaFree(d_Q));
+        cuda_pool_free(d_X, x_size);
+        cuda_pool_free(d_Y, y_size);
+        cuda_pool_free(d_D, d_size);
+        cuda_pool_free(d_Q, q_size);
 #endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 

From 641e9a0c52224fa38fabc47737250ab193622110 Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Fri, 21 Apr 2023 00:58:26 +0200
Subject: [PATCH 2/4] Move cuda specific definitions to ggml-cuda.h/cu

---
 ggml-cuda.cu |  91 ++++++++++++++++++++++++++++++-------
 ggml-cuda.h  |  31 +++++++++++++
 ggml.c       | 123 +++++++++------------------------------------------
 3 files changed, 127 insertions(+), 118 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 0baa989a36ca9..2d2e5a90e0d41 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1,5 +1,7 @@
 #include <stdint.h>
+#include <stdio.h>
 #include <cuda_fp16.h>
+#include <atomic>
 #include "ggml-cuda.h"
 
 typedef uint16_t ggml_fp16_t;
@@ -35,8 +37,6 @@ typedef struct {
 } block_q4_3;
 static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
 
-
-
 static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
@@ -131,24 +131,83 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
     }
 }
 
-extern "C" {
-    __host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_0;
-        dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
-    }
+void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_0;
+    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
+}
 
-    __host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_1;
-        dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
+void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_1;
+    dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_2;
+    dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_3;
+    dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+// lock-free, thread safe buffer pool for cuda
+#define MAX_CUDA_BUFFERS 16
+struct cuda_buffer {
+    std::atomic_uintptr_t ptr;
+    size_t size;
+};
+
+static struct cuda_buffer cuda_buffer_pool[MAX_CUDA_BUFFERS] = {0};
+
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        struct cuda_buffer * b = &cuda_buffer_pool[i];
+        if (b->size >= size) {
+            uintptr_t ptr = atomic_load(&b->ptr);
+            if (ptr) {
+                if (std::atomic_compare_exchange_strong(&b->ptr, &ptr, 0)) {
+                    *actual_size = b->size;
+                    return (void *) ptr;
+                }
+            }
+        }
     }
 
-    __host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_2;
-        dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
+    void * ptr;
+    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
+    *actual_size = size;
+    return ptr;
+}
+
+void ggml_cuda_pool_free(void * ptr, size_t size) {
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        struct cuda_buffer * b = &cuda_buffer_pool[i];
+        uintptr_t p = std::atomic_load(&b->ptr);
+        if (p == 0) {
+            if (std::atomic_compare_exchange_strong(&b->ptr, &p, (uintptr_t) ptr)) {
+                b->size = size;
+                return;
+            }
+        }
     }
+    fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
+    CUDA_CHECK(cudaFree(ptr));
+}
+
+cublasHandle_t cublasH = NULL;
+cudaStream_t cudaStream = NULL;
+
+void ggml_init_cublas(void) {
+    if (cublasH == NULL) {
+        // create cublas handle, bind a stream
+        CUBLAS_CHECK(cublasCreate(&cublasH));
+
+        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+
+        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
 
-    __host__ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_3;
-        dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
     }
 }
diff --git a/ggml-cuda.h b/ggml-cuda.h
index be140606aa2d4..40877ecd5500f 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -1,7 +1,38 @@
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+
 #ifdef  __cplusplus
 extern "C" {
 #endif
 
+#define CUDA_CHECK(err)                                                                 \
+    do {                                                                                \
+        cudaError_t err_ = (err);                                                       \
+        if (err_ != cudaSuccess) {                                                      \
+            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
+                cudaGetErrorString(err_));                                              \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+
+#define CUBLAS_CHECK(err)                                                               \
+    do {                                                                                \
+        cublasStatus_t err_ = (err);                                                    \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
+            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+
+
+
+extern cublasHandle_t cublasH;
+extern cudaStream_t cudaStream;
+
+void   ggml_init_cublas(void);
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
+void   ggml_cuda_pool_free(void * ptr, size_t size);
+
 void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
diff --git a/ggml.c b/ggml.c
index fb5fd1f7ea515..6d8796e1ad894 100644
--- a/ggml.c
+++ b/ggml.c
@@ -148,88 +148,7 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
 #elif defined(GGML_USE_CUBLAS)
-#include <cublas_v2.h>
-#include <cuda_runtime.h>
 #include "ggml-cuda.h"
-
-#define CUDA_CHECK(err)                                                                 \
-    do {                                                                                \
-        cudaError_t err_ = (err);                                                       \
-        if (err_ != cudaSuccess) {                                                      \
-            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
-                cudaGetErrorString(err_));                                              \
-            exit(1);                                                                    \
-        }                                                                               \
-    } while (0)
-
-#define CUBLAS_CHECK(err)                                                               \
-    do {                                                                                \
-        cublasStatus_t err_ = (err);                                                    \
-        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
-            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
-            exit(1);                                                                    \
-        }                                                                               \
-    } while (0)
-
-// lock-free, thread safe buffer pool for cuda
-#define MAX_CUDA_BUFFERS 16
-struct cuda_buffer {
-    atomic_uintptr_t ptr;
-    size_t size;
-};
-
-static struct cuda_buffer cuda_buffer_pool[MAX_CUDA_BUFFERS] = {0};
-
-static void * cuda_pool_malloc(size_t size, size_t * actual_size) {
-    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        struct cuda_buffer * b = &cuda_buffer_pool[i];
-        if (b->size >= size) {
-            uintptr_t ptr = atomic_load(&b->ptr);
-            if (ptr) {
-                if (atomic_compare_exchange_strong(&b->ptr, &ptr, 0)) {
-                    *actual_size = b->size;
-                    return (void *) ptr;
-                }
-            }
-        }
-    }
-
-    void * ptr;
-    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
-    *actual_size = size;
-    return ptr;
-}
-
-static void cuda_pool_free(void * ptr, size_t size) {
-    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        struct cuda_buffer * b = &cuda_buffer_pool[i];
-        uintptr_t p = atomic_load(&b->ptr);
-        if (p == 0) {
-            if (atomic_compare_exchange_strong(&b->ptr, &p, (uintptr_t) ptr)) {
-                b->size = size;
-                return;
-            }
-        }
-    }
-    fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
-    CUDA_CHECK(cudaFree(ptr));
-}
-
-static cublasHandle_t cublasH = NULL;
-static cudaStream_t cudaStream = NULL;
-static void init_cublas(void) {
-    if (cublasH == NULL) {
-        // create cublas handle, bind a stream
-        CUBLAS_CHECK(cublasCreate(&cublasH));
-
-        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
-
-        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
-
-        // configure logging to stdout
-        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
-    }
-}
 #endif
 
 #undef MIN
@@ -3764,7 +3683,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
         // initialize cuBLAS
         #if defined(GGML_USE_CUBLAS)
-        init_cublas();
+        ggml_init_cublas();
         #endif
 
         is_first_call = false;
@@ -7617,9 +7536,9 @@ static void ggml_compute_forward_mul_mat_f32(
         const int d_ne = ne11 * ne01;
 
         size_t x_size, y_size, d_size;
-        float *d_X = cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-        float *d_Y = cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-        float *d_D = cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #endif
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7656,9 +7575,9 @@ static void ggml_compute_forward_mul_mat_f32(
         }
 #if defined(GGML_USE_CUBLAS)
         CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-        cuda_pool_free(d_X, x_size);
-        cuda_pool_free(d_Y, y_size);
-        cuda_pool_free(d_D, d_size);
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
 #endif
         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
@@ -7815,9 +7734,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         const int d_ne = ne11 * ne01;
 
         size_t x_size, y_size, d_size;
-        float *d_X = cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-        float *d_Y = cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-        float *d_D = cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #else
         float * const wdata = params->wdata;
 #endif
@@ -7884,9 +7803,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 
 #if defined(GGML_USE_CUBLAS)
         CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-        cuda_pool_free(d_X, x_size);
-        cuda_pool_free(d_Y, y_size);
-        cuda_pool_free(d_D, d_size);
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
 #endif
         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
@@ -8061,10 +7980,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
         const int d_ne = ne11 * ne01;
 
         size_t x_size, y_size, d_size, q_size;
-        float *d_X = cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-        float *d_Y = cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-        float *d_D = cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
-        float *d_Q = cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
+        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+        float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
 
         void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream)  = NULL;
         if (type == GGML_TYPE_Q4_0) {
@@ -8137,10 +8056,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_CUBLAS)
         CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-        cuda_pool_free(d_X, x_size);
-        cuda_pool_free(d_Y, y_size);
-        cuda_pool_free(d_D, d_size);
-        cuda_pool_free(d_Q, q_size);
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
+        ggml_cuda_pool_free(d_Q, q_size);
 #endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 

From c832e7c793077aff82856dc101a10962ab65501e Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Fri, 21 Apr 2023 03:39:04 +0200
Subject: [PATCH 3/4] Add CXX flags to nvcc

---
 Makefile     | 10 ++++++----
 ggml-cuda.cu |  6 +++---
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/Makefile b/Makefile
index f267d086415ee..3b48eec9906e4 100644
--- a/Makefile
+++ b/Makefile
@@ -101,11 +101,13 @@ ifdef LLAMA_OPENBLAS
 	LDFLAGS += -lopenblas
 endif
 ifdef LLAMA_CUBLAS
-	CFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
-	LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64
-	OBJS	+= ggml-cuda.o
+	CFLAGS    += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
+	LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64
+	OBJS      += ggml-cuda.o
+	NVCC      = nvcc
+	NVCCFLAGS = --forward-unknown-to-host-linker -arch=native
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
-	nvcc -arch=native -c -o $@ $<
+	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -c $< -o $@
 endif
 ifdef LLAMA_GPROF
 	CFLAGS   += -pg
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 2d2e5a90e0d41..dc8f486f20584 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -154,11 +154,11 @@ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t st
 // lock-free, thread safe buffer pool for cuda
 #define MAX_CUDA_BUFFERS 16
 struct cuda_buffer {
-    std::atomic_uintptr_t ptr;
-    size_t size;
+    std::atomic_uintptr_t ptr { 0 };
+    size_t size { 0 };
 };
 
-static struct cuda_buffer cuda_buffer_pool[MAX_CUDA_BUFFERS] = {0};
+static cuda_buffer cuda_buffer_pool[MAX_CUDA_BUFFERS];
 
 void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {

From d774e05428f410d407aa79fb05c016a41c18a4cd Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Fri, 21 Apr 2023 21:02:17 +0200
Subject: [PATCH 4/4] Change memory pool synchronization mechanism to a spin
 lock General code cleanup

---
 ggml-cuda.cu | 75 +++++++++++++++++++++++++++++++---------------------
 ggml-cuda.h  |  6 ++---
 ggml.c       | 32 +++++++++++-----------
 3 files changed, 63 insertions(+), 50 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index dc8f486f20584..fa511c1dc5d3d 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -31,9 +31,9 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
 
 #define QK4_3 16
 typedef struct {
-    __half  d;         // delta
-    __half  m;         // min
-    uint8_t qs[QK4_3 / 2]; // nibbles / quants
+    __half  d;              // delta
+    __half  m;              // min
+    uint8_t qs[QK4_3 / 2];  // nibbles / quants
 } block_q4_3;
 static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
 
@@ -151,29 +151,44 @@ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t st
     dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-// lock-free, thread safe buffer pool for cuda
+// buffer pool for cuda
 #define MAX_CUDA_BUFFERS 16
+
+struct scoped_spin_lock {
+    std::atomic_flag& lock;
+    scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
+        while (lock.test_and_set(std::memory_order_acquire)) {
+            ; // spin
+        }
+    }
+    ~scoped_spin_lock() {
+        lock.clear(std::memory_order_release);
+    }
+    scoped_spin_lock(const scoped_spin_lock&) = delete;
+    scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
+};
+
 struct cuda_buffer {
-    std::atomic_uintptr_t ptr { 0 };
-    size_t size { 0 };
+    void * ptr = nullptr;
+    size_t size = 0;
 };
 
-static cuda_buffer cuda_buffer_pool[MAX_CUDA_BUFFERS];
+static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
+static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
 
 void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        struct cuda_buffer * b = &cuda_buffer_pool[i];
-        if (b->size >= size) {
-            uintptr_t ptr = atomic_load(&b->ptr);
-            if (ptr) {
-                if (std::atomic_compare_exchange_strong(&b->ptr, &ptr, 0)) {
-                    *actual_size = b->size;
-                    return (void *) ptr;
-                }
-            }
+        cuda_buffer& b = g_cuda_buffer_pool[i];
+        if (b.size >= size && b.ptr != nullptr) {
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            b.ptr = nullptr;
+            b.size = 0;
+            return ptr;
         }
     }
-
     void * ptr;
     CUDA_CHECK(cudaMalloc((void **) &ptr, size));
     *actual_size = size;
@@ -181,31 +196,31 @@ void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
 }
 
 void ggml_cuda_pool_free(void * ptr, size_t size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        struct cuda_buffer * b = &cuda_buffer_pool[i];
-        uintptr_t p = std::atomic_load(&b->ptr);
-        if (p == 0) {
-            if (std::atomic_compare_exchange_strong(&b->ptr, &p, (uintptr_t) ptr)) {
-                b->size = size;
-                return;
-            }
+        cuda_buffer& b = g_cuda_buffer_pool[i];
+        if (b.ptr == nullptr) {
+            b.ptr = ptr;
+            b.size = size;
+            return;
         }
     }
     fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
     CUDA_CHECK(cudaFree(ptr));
 }
 
-cublasHandle_t cublasH = NULL;
-cudaStream_t cudaStream = NULL;
+cublasHandle_t g_cublasH = NULL;
+cudaStream_t g_cudaStream = NULL;
 
 void ggml_init_cublas(void) {
-    if (cublasH == NULL) {
+    if (g_cublasH == NULL) {
         // create cublas handle, bind a stream
-        CUBLAS_CHECK(cublasCreate(&cublasH));
+        CUBLAS_CHECK(cublasCreate(&g_cublasH));
 
-        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
 
-        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
+        CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
 
         // configure logging to stdout
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
diff --git a/ggml-cuda.h b/ggml-cuda.h
index 40877ecd5500f..370bbc75f5f76 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -24,10 +24,8 @@ extern "C" {
         }                                                                               \
     } while (0)
 
-
-
-extern cublasHandle_t cublasH;
-extern cudaStream_t cudaStream;
+extern cublasHandle_t g_cublasH;
+extern cudaStream_t   g_cudaStream;
 
 void   ggml_init_cublas(void);
 void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
diff --git a/ggml.c b/ggml.c
index 6d8796e1ad894..8beca30fe6cca 100644
--- a/ggml.c
+++ b/ggml.c
@@ -7550,19 +7550,19 @@ static void ggml_compute_forward_mul_mat_f32(
 
 #if defined(GGML_USE_CUBLAS)
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
-                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
                 // compute
                 CUBLAS_CHECK(
-                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
                             &alpha, d_X, ne00,
                                     d_Y, ne10,
                             &beta,  d_D, ne01));
 
                 // copy data to host
-                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
 #else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -7574,7 +7574,7 @@ static void ggml_compute_forward_mul_mat_f32(
             }
         }
 #if defined(GGML_USE_CUBLAS)
-        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
         ggml_cuda_pool_free(d_X, x_size);
         ggml_cuda_pool_free(d_Y, y_size);
         ggml_cuda_pool_free(d_D, d_size);
@@ -7770,12 +7770,12 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
-                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
                 // compute
                 CUBLAS_CHECK(
-                    cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                    cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
                             &alpha, d_X, CUDA_R_16F, ne00,
                                     d_Y, CUDA_R_16F, ne10,
@@ -7784,7 +7784,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                             CUBLAS_GEMM_DEFAULT));
 
                 // copy data to host
-                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
 #else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -7802,7 +7802,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
-        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
         ggml_cuda_pool_free(d_X, x_size);
         ggml_cuda_pool_free(d_Y, y_size);
         ggml_cuda_pool_free(d_D, d_size);
@@ -8013,9 +8013,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
                 // copy and dequantize on device
                 CUDA_CHECK(
                     cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
-                        GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
+                        GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
 
-                dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
+                dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
                 CUDA_CHECK(cudaGetLastError());
 #else
                 {
@@ -8031,18 +8031,18 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_CUBLAS)
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
                 // compute
                 CUBLAS_CHECK(
-                    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
                             &alpha, d_X, ne00,
                                     d_Y, ne10,
                             &beta,  d_D, ne01));
 
                 // copy data to host
-                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+                CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
 #else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -8055,7 +8055,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
-        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
         ggml_cuda_pool_free(d_X, x_size);
         ggml_cuda_pool_free(d_Y, y_size);
         ggml_cuda_pool_free(d_D, d_size);