From 73e7601bf395a73d7c186a99846ed0997efdd3ce Mon Sep 17 00:00:00 2001 From: jon-chuang Date: Sat, 15 Apr 2023 19:57:01 +0800 Subject: [PATCH 01/11] stash --- ggml.c | 121 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 118 insertions(+), 3 deletions(-) diff --git a/ggml.c b/ggml.c index a26b4853f7eae..e933bb944f338 100644 --- a/ggml.c +++ b/ggml.c @@ -20,6 +20,10 @@ #include #include +#if defined(__AVX2__) +#include +#endif + // if C99 - static_assert is noop // ref: https://stackoverflow.com/a/53923785/4039976 #ifndef static_assert @@ -94,7 +98,7 @@ typedef void* thread_ret_t; #define static_assert(cond, msg) _Static_assert(cond, msg) #endif -/*#define GGML_PERF*/ +// #define GGML_PERF #define GGML_DEBUG 0 #define GGML_GELU_FP16 #define GGML_SILU_FP16 @@ -412,9 +416,90 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); #define QK 32 + +// AVX routine provided by GH user jon-chuang +// ref: https://github.com/ggerganov/llama.cpp/issues/956#issuecomment-1508090551 + +#if __AVX2__ || __AVX512F__ + +// Given A = K X M, B = K X N, compute one row of C = A^TB +void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, int M, int N, int K) { + for (int j = 0; j < N; j += 8) { // Process 8 elements of C's row at a time - 256 / size_of(float) + __m256 c_vec = _mm256_setzero_ps(); // Initialize the result vector to all zeros + + for (int k = 0; k < K; ++k) { + __m256 a = _mm256_broadcast_ss(&A[k * M]); // Broadcast the k-th element of the row of A^T + __m256 b_vec = _mm256_load_ps(&B[j + k * N]); // Load the j/8-th segment of the k-th row of B^T (corresponding to the k-th column of B) + c_vec = _mm256_fmadd_ps(a, b_vec, c_vec); // FMA: c_vec += a * b_vec + } + + // Store the result in the corresponding row of C + _mm256_store_ps(&C[j], c_vec); + } + + // Handle the remainder + const int64_t remainder = N & (8l - 1); + if (remainder > 0) { + int j = N - remainder; + __m256i mask_vec = _mm256_set1_epi32(0xffffffff << (8l - remainder)); + + __m256 c_vec = _mm256_setzero_ps(); // Initialize the result vector to all zeros + + for (int k = 0; k < K; ++k) { + __m256 a = _mm256_broadcast_ss(&A[k * M]); // Broadcast the k-th element of the row of A^T + __m256 b_vec = _mm256_maskload_ps(&B[j + k * N], mask_vec); // Load the j/8-th segment of the k-th row of B^T (corresponding to the k-th column of B) + c_vec = _mm256_fmadd_ps(a, b_vec, c_vec); // FMA: c_vec += a * b_vec + } + + // Store the result in the corresponding offset of C + _mm256_maskstore_ps(&C[j], mask_vec, c_vec); + } +} + +#elif __AVX__ + +// Given A = K X M, B = K X N, compute one row of C = A^TB +void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, int M, int N, int K) { + for (int j = 0; j < N; j += 4) { // Process 4 elements of C's row at a time - 128 / size_of(float) + __m128 c_vec = _mm_setzero_ps(); // Initialize the result vector to all zeros + + for (int k = 0; k < K; ++k) { + __m128 a = _mm_broadcast_ss(&A[k * M]); // Broadcast the k-th element of the row of A^T + __m128 b_vec = _mm_load_ps(&B[j + k * N]); // Load the j/4-th segment of the k-th row of B^T (corresponding to the k-th column of B) + c_vec = _mm_fmadd_ps(a, b_vec, c_vec); // FMA: c_vec += a * b_vec + } + + // Store the result in the corresponding row of C + _mm_store_ps(&C[j], c_vec); + } + + // Handle the remainder + const int64_t remainder = N & (4l - 1); + if (remainder > 0) { + int j = N - remainder; + __m128i mask_vec = _mm_set1_epi32(0xffffffff << (4l - remainder)); + + __m128 c_vec = _mm_setzero_ps(); // Initialize the result vector to all zeros + + for (int k = 0; k < K; ++k) { + __m128 a = _mm_broadcast_ss(&A[k * M]); // Broadcast the k-th element of the row of A^T + __m128 b_vec = _mm_maskload_ps(&B[j + k * N], mask_vec); // Load the j/4-th segment of the k-th row of B^T (corresponding to the k-th column of B) + c_vec = _mm_fmadd_ps(a, b_vec, c_vec); // FMA: c_vec += a * b_vec + } + + // Store the result in the corresponding offset of C + _mm_maskstore_ps(&C[j], mask_vec, c_vec); + } +} + +#endif + + // AVX routines provided by GH user Const-me // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600 + #if __AVX2__ || __AVX512F__ + // Unpack 32 4-bit fields into 32 bytes // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval static inline __m256i bytesFromNibbles( const uint8_t* rsi ) @@ -449,6 +534,7 @@ static inline __m128i packNibbles( __m256i bytes ) return _mm_packus_epi16( r0, r1 ); } #elif __AVX__ + static inline __m128i bytesFromNibbles( const uint8_t* rsi ) { // Load 8 bytes from memory @@ -6353,7 +6439,7 @@ static void ggml_compute_forward_mul_mat_f32( const int64_t ne02 = src0->ne[2]; const int64_t ne03 = src0->ne[3]; -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(__AVX2__) || defined(__AVX__) const int64_t ne10 = src1->ne[0]; #endif const int64_t ne11 = src1->ne[1]; @@ -6407,6 +6493,35 @@ static void ggml_compute_forward_mul_mat_f32( assert(ne2 == ne02); assert(ne3 == ne03); +#if defined(__AVX2__) || defined(__AVX__) + if (ne00 <= 32) { + assert(ne00 == ne10); + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // fprintf(stderr, "TS_MUL_MAT: %i %i %i, ir0: %d, ir1: %d, ID: %d\n", ne00, ne11, ne01, ir0, ir1, thrd_current()); + for (int i = ir0; i < ir1; ++i) { + // M = ne01, N = ne11, K = ne00 + ggml_mul_row_f32_tall_skinny( + (float *) src0->data + i, + (float *) src1->data, + (float *) dst->data + i*ne01, + ne01, ne11, ne00); + } + return; + } +#endif + // nb01 >= nb00 - src0 is not transposed // compute by src0 rows @@ -6440,7 +6555,7 @@ static void ggml_compute_forward_mul_mat_f32( } } - //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + // printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne01, ne11, ne02, ne03); return; } From 00e86b97cc12cd4ba15dcdfb9dcbdbbeb649a902 Mon Sep 17 00:00:00 2001 From: jon-chuang Date: Sat, 15 Apr 2023 21:54:37 +0800 Subject: [PATCH 02/11] commit --- Makefile | 2 +- examples/CMakeLists.txt | 1 + examples/benchmark/CMakeLists.txt | 4 +++ examples/benchmark/benchmark-q4_0-matmult.c | 35 +++++++++++++++++++-- ggml.c | 14 ++++++--- 5 files changed, 48 insertions(+), 8 deletions(-) create mode 100644 examples/benchmark/CMakeLists.txt diff --git a/Makefile b/Makefile index a1b99c6f9dfe2..1c450f055b951 100644 --- a/Makefile +++ b/Makefile @@ -176,7 +176,7 @@ libllama.so: llama.o ggml.o # Tests # -benchmark: examples/benchmark/benchmark-q4_0-matmult.c ggml.o +benchmark: examples/benchmark/benchmark-q4_0-matmult.c ggml.o llama.o common.o $(CXX) $(CXXFLAGS) $^ -o benchmark-q4_0-matmult $(LDFLAGS) ./benchmark-q4_0-matmult diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 67a7cea543a40..8f53244f67282 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -34,4 +34,5 @@ else() add_subdirectory(quantize-stats) add_subdirectory(perplexity) add_subdirectory(embedding) + add_subdirectory(benchmark) endif() diff --git a/examples/benchmark/CMakeLists.txt b/examples/benchmark/CMakeLists.txt new file mode 100644 index 0000000000000..31aaa190b9652 --- /dev/null +++ b/examples/benchmark/CMakeLists.txt @@ -0,0 +1,4 @@ +set(TARGET benchmark) +add_executable(${TARGET} benchmark-q4_0-matmul.cpp) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/benchmark/benchmark-q4_0-matmult.c b/examples/benchmark/benchmark-q4_0-matmult.c index 84b06766c15dc..27c00e1cecb32 100644 --- a/examples/benchmark/benchmark-q4_0-matmult.c +++ b/examples/benchmark/benchmark-q4_0-matmult.c @@ -8,6 +8,7 @@ #include #include "ggml.h" +#include "llama.h" #include #include #include @@ -45,7 +46,7 @@ float tensor_sum_elements(struct ggml_tensor * tensor) { #define TENSOR_TYPE_AS_STR(TYPE) TYPE == GGML_TYPE_F32 ? "FP32" : TYPE == GGML_TYPE_F16 ? "FP16" : TYPE == GGML_TYPE_Q4_0 ? "Q4_0" : TYPE == GGML_TYPE_Q4_1 ? "Q4_1" : "UNKNOWN" -#define TENSOR_DUMP(TENSOR) printf("%15s: type = %i (%5s) ne = %5d x %5d x %5d, nb = (%5li, %5li, %5li) - ", #TENSOR, \ +#define TENSOR_DUMP(TENSOR) printf("%15s: type = %i (%5s) ne = %5ld x %5ld x %5ld, nb = (%5li, %5li, %5li) - ", #TENSOR, \ TENSOR->type,TENSOR_TYPE_AS_STR(TENSOR->type),\ TENSOR->ne[0], TENSOR->ne[1], TENSOR->ne[2], TENSOR->nb[0], TENSOR->nb[1], TENSOR->nb[2]); \ { float sum = tensor_sum_elements(TENSOR); printf("Sum of tensor %s is %6.2f\n",#TENSOR, sum); } @@ -170,12 +171,40 @@ int main(int argc, char ** argv) { struct ggml_cgraph gf = ggml_build_forward(m11xm2); gf.n_threads=benchmark_params.n_threads; - printf("cgraph->n_threads=%i\n",gf.n_threads); + fprintf(stderr, "system_info: n_threads = %d | %s\n", + benchmark_params.n_threads, llama_print_system_info()); TENSOR_DUMP(m11); TENSOR_DUMP(m2); ggml_graph_compute(ctx, &gf); + { + const int dimx = sizex; + const int dimy = sizey; + const int dimz = sizez; + long long int flops_per_dot_product = dimy + dimy; + long long int flops_per_matrix = flops_per_dot_product * dimx * dimz; ; + printf("Matrix Multiplication of (%i,%i,%i) x (%i,%i,%i) - about %6.2f gFLOPS\n\n", sizex, sizey, 1, sizex, sizez, 1, 1.0f*flops_per_matrix / 1000 / 1000 / 1000); + + printf("Iteration;NThreads; SizeX; SizeY; SizeZ; Required_FLOPS; Elapsed_u_Seconds; FLOPS_per_u_Second\n"); + printf("==============================================================================================\n"); + + for (int i=0;itype == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; From 6bf6543a6ac169c8070188a61da4a65e92050ddc Mon Sep 17 00:00:00 2001 From: jon-chuang Date: Sat, 15 Apr 2023 21:57:39 +0800 Subject: [PATCH 03/11] format --- ggml.c | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/ggml.c b/ggml.c index df27a984f5cfe..797579926161f 100644 --- a/ggml.c +++ b/ggml.c @@ -433,12 +433,8 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); #define QK 32 - // AVX routine provided by GH user jon-chuang -// ref: https://github.com/ggerganov/llama.cpp/issues/956#issuecomment-1508090551 - -#if false && __AVX2__ || __AVX512F__ - +#if __AVX2__ || __AVX512F__ // Given A = K X M, B = K X N, compute one row of C = A^TB void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, int M, int N, int K) { alignas(32) float res_vec[8]; @@ -476,9 +472,7 @@ void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, i _mm256_maskstore_ps(&C[j], mask_vec, c_vec); } } - #elif __AVX__ - // Given A = K X M, B = K X N, compute one row of C = A^TB void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, int M, int N, int K) { for (int j = 0; j < N; j += 4) { // Process 4 elements of C's row at a time - 128 / size_of(float) @@ -515,12 +509,9 @@ void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, i #endif - // AVX routines provided by GH user Const-me // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600 - #if __AVX2__ || __AVX512F__ - // Unpack 32 4-bit fields into 32 bytes // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval static inline __m256i bytesFromNibbles( const uint8_t* rsi ) From a38b9d7fab84c2082f8ff030083bfe0ea3a520e4 Mon Sep 17 00:00:00 2001 From: jon-chuang Date: Sat, 15 Apr 2023 21:58:10 +0800 Subject: [PATCH 04/11] minor --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 797579926161f..54722cd619e6b 100644 --- a/ggml.c +++ b/ggml.c @@ -6760,7 +6760,7 @@ static void ggml_compute_forward_mul_mat_f32( } } - // printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne01, ne11, ne02, ne03); + //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne01, ne11, ne02, ne03); return; } From 02258616ef9199b7ca0c5bfd28ca6a6575231dbc Mon Sep 17 00:00:00 2001 From: jon-chuang Date: Sat, 15 Apr 2023 22:27:23 +0800 Subject: [PATCH 05/11] minor --- examples/benchmark/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmark/CMakeLists.txt b/examples/benchmark/CMakeLists.txt index 31aaa190b9652..6768ead37c253 100644 --- a/examples/benchmark/CMakeLists.txt +++ b/examples/benchmark/CMakeLists.txt @@ -1,4 +1,4 @@ set(TARGET benchmark) -add_executable(${TARGET} benchmark-q4_0-matmul.cpp) +add_executable(${TARGET} benchmark-q4_0-matmult.c) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) From 5bb53278330b3e6e80ba544503458810fc83bcab Mon Sep 17 00:00:00 2001 From: jon-chuang Date: Wed, 26 Apr 2023 22:48:15 +0800 Subject: [PATCH 06/11] minor --- ggml.c | 47 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 5b2653ba49857..76a435e6cf0e8 100644 --- a/ggml.c +++ b/ggml.c @@ -434,7 +434,51 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); // quantization // -#define QK 32 +#if __AVX__ || __AVX2__ || __AVX512F__ +// Unpack 16 4-bit fields into 16 bytes +// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval +static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi) +{ + // Load 8 bytes from memory + __m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi ); + + // Expand bytes into uint16_t values + __m128i bytes = _mm_cvtepu8_epi16( tmp ); + + // Unpack values into individual bytes + const __m128i lowMask = _mm_set1_epi8( 0xF ); + __m128i high = _mm_andnot_si128( lowMask, bytes ); + __m128i low = _mm_and_si128( lowMask, bytes ); + high = _mm_slli_epi16( high, 4 ); + bytes = _mm_or_si128( low, high ); + return bytes; +} + +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + const __m128i hi64 = _mm_unpackhi_epi64(a, a); + const __m128i sum64 = _mm_add_epi32(hi64, a); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} // AVX routine provided by GH user jon-chuang #if __AVX2__ || __AVX512F__ @@ -509,7 +553,6 @@ void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, i _mm_maskstore_ps(&C[j], mask_vec, c_vec); } } - #endif // AVX routines provided by GH user Const-me From 8ead56c03a33cdc78d0b666f49f9769ea44bf41d Mon Sep 17 00:00:00 2001 From: jon-chuang Date: Wed, 26 Apr 2023 22:58:20 +0800 Subject: [PATCH 07/11] fix --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index 76a435e6cf0e8..2924eaa1994da 100644 --- a/ggml.c +++ b/ggml.c @@ -7755,7 +7755,7 @@ static void ggml_compute_forward_mul_mat_f32( assert(ne3 == ne03); #if defined(__AVX2__) || defined(__AVX__) - if (ggml_cpu_has_avx2() && ne00 <= 48 || ne00 <= 32) { + if ((ggml_cpu_has_avx2() && ne00 <= 48) || ne00 <= 32) { // Handle tall and skinny matrices // TODO(jon-chuang): Also check that we only handle 2D matrices? assert(ne00 == ne10); From 8cead207461e0bb05b5a0cbba5b3b8b617bf2e71 Mon Sep 17 00:00:00 2001 From: jon-chuang Date: Wed, 26 Apr 2023 23:03:54 +0800 Subject: [PATCH 08/11] done --- ggml.c | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ggml.c b/ggml.c index 2924eaa1994da..5eb0e0ceb2483 100644 --- a/ggml.c +++ b/ggml.c @@ -494,7 +494,7 @@ void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, i } // Store the result in the corresponding row of C - _mm256_store_ps(&res_vec, c_vec); + _mm256_store_ps((float *) &res_vec, c_vec); for (int k = 0; k < 8; ++k) { C[j+k] = res_vec[k]; @@ -7700,7 +7700,7 @@ static void ggml_compute_forward_mul_mat_f32( const int64_t ne02 = src0->ne[2]; const int64_t ne03 = src0->ne[3]; -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(__AVX2__) || defined(__AVX__) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) const int64_t ne10 = src1->ne[0]; #endif const int64_t ne11 = src1->ne[1]; @@ -7758,7 +7758,6 @@ static void ggml_compute_forward_mul_mat_f32( if ((ggml_cpu_has_avx2() && ne00 <= 48) || ne00 <= 32) { // Handle tall and skinny matrices // TODO(jon-chuang): Also check that we only handle 2D matrices? - assert(ne00 == ne10); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } From fb469ed972ac280d66707976fee03d1aeea9064f Mon Sep 17 00:00:00 2001 From: jon-chuang Date: Sun, 30 Apr 2023 18:24:56 +0800 Subject: [PATCH 09/11] fma compile only --- ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index b192719a51a2d..a4f4d68bb06bf 100644 --- a/ggml.c +++ b/ggml.c @@ -497,7 +497,7 @@ static inline int hsum_i32_4(const __m128i a) { } // AVX routine provided by GH user jon-chuang -#if __AVX2__ || __AVX512F__ +#if (__AVX2__ || __AVX512F__) && FMA // Given A = K X M, B = K X N, compute one row of C = A^TB void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, int M, int N, int K) { alignas(32) float res_vec[8]; @@ -535,7 +535,7 @@ void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, i _mm256_maskstore_ps(&C[j], mask_vec, c_vec); } } -#elif __AVX__ +#elif __AVX__ && __FMA__ // Given A = K X M, B = K X N, compute one row of C = A^TB void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, int M, int N, int K) { for (int j = 0; j < N; j += 4) { // Process 4 elements of C's row at a time - 128 / size_of(float) From 470cc4c5d19687208f303b3b6f274ba28f87e69b Mon Sep 17 00:00:00 2001 From: jon-chuang Date: Sun, 30 Apr 2023 20:56:46 +0800 Subject: [PATCH 10/11] minor --- ggml.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index a4f4d68bb06bf..0a3537723b40a 100644 --- a/ggml.c +++ b/ggml.c @@ -497,7 +497,7 @@ static inline int hsum_i32_4(const __m128i a) { } // AVX routine provided by GH user jon-chuang -#if (__AVX2__ || __AVX512F__) && FMA +#if (__AVX2__ || __AVX512F__) && __FMA__ // Given A = K X M, B = K X N, compute one row of C = A^TB void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, int M, int N, int K) { alignas(32) float res_vec[8]; @@ -8123,7 +8123,7 @@ static void ggml_compute_forward_mul_mat_f32( assert(ne2 == ne02); assert(ne3 == ne03); -#if defined(__AVX2__) || defined(__AVX__) +#if (__AVX512F__ || __AVX2__ || __AVX__) && __FMA__ if ((ggml_cpu_has_avx2() && ne00 <= 48) || ne00 <= 32) { // Handle tall and skinny matrices // TODO(jon-chuang): Also check that we only handle 2D matrices? From 979010cdba8d147c5fcab2e8dfd7e8d3c9782a14 Mon Sep 17 00:00:00 2001 From: jon-chuang Date: Sun, 30 Apr 2023 21:02:55 +0800 Subject: [PATCH 11/11] minor --- ggml.c | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 0a3537723b40a..31592c758f95e 100644 --- a/ggml.c +++ b/ggml.c @@ -8124,9 +8124,8 @@ static void ggml_compute_forward_mul_mat_f32( assert(ne3 == ne03); #if (__AVX512F__ || __AVX2__ || __AVX__) && __FMA__ + // Handle tall and skinny matrices if ((ggml_cpu_has_avx2() && ne00 <= 48) || ne00 <= 32) { - // Handle tall and skinny matrices - // TODO(jon-chuang): Also check that we only handle 2D matrices? if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; }