From 767cf1b3264c6bcb4fd41f5e1e76f2d606f25728 Mon Sep 17 00:00:00 2001
From: Casey Primozic
Date: Sun, 19 Mar 2023 23:15:42 -0700
Subject: [PATCH 1/2] Add initial AVX512 support for dot product on Linux

* Update Makefile to detect AVX512 support and add the corresponding
  compiler flags when it is available
* Based on the existing AVX2 implementation: computes the dot product of
  one 32-value block of 4-bit quantized ints at a time (a scalar
  reference sketch of this computation follows the patch series below)
* Performs the 8-bit -> 16-bit sign extension and multiply+add on 32
  values at a time instead of 16
* Uses the built-in AVX512 horizontal reduce-add to get the sum at the
  end
---
 Makefile | 20 ++++++++++++++++++
 ggml.c   | 64 +++++++++++++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 81 insertions(+), 3 deletions(-)

diff --git a/Makefile b/Makefile
index fb83d41badc56..fcfe7da292f79 100644
--- a/Makefile
+++ b/Makefile
@@ -95,6 +95,26 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
     ifneq (,$(findstring sse3,$(SSE3_M)))
         CFLAGS += -msse3
     endif
+    AVX512F_M := $(shell grep "avx512f " /proc/cpuinfo)
+    ifneq (,$(findstring avx512f,$(AVX512F_M)))
+        CFLAGS += -mavx512f
+    endif
+    AVX512BW_M := $(shell grep "avx512bw " /proc/cpuinfo)
+    ifneq (,$(findstring avx512bw,$(AVX512BW_M)))
+        CFLAGS += -mavx512bw
+    endif
+    AVX512DQ_M := $(shell grep "avx512dq " /proc/cpuinfo)
+    ifneq (,$(findstring avx512dq,$(AVX512DQ_M)))
+        CFLAGS += -mavx512dq
+    endif
+    AVX512VL_M := $(shell grep "avx512vl " /proc/cpuinfo)
+    ifneq (,$(findstring avx512vl,$(AVX512VL_M)))
+        CFLAGS += -mavx512vl
+    endif
+    AVX512CD_M := $(shell grep "avx512cd " /proc/cpuinfo)
+    ifneq (,$(findstring avx512cd,$(AVX512CD_M)))
+        CFLAGS += -mavx512cd
+    endif
 else ifeq ($(UNAME_S),Haiku)
     AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
     ifneq (,$(findstring avx,$(AVX1_M)))
diff --git a/ggml.c b/ggml.c
index 535c7b7d281dd..9ebdff0217798 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1412,10 +1412,68 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
 #else
 #error "not implemented for QK"
 #endif
-#elif defined(__AVX2__)
+#elif defined(__AVX512F__)
+
+inline __m256i bytesFromNibbles( const uint8_t* rsi ){
+    // Load 16 bytes from memory
+    __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
+
+    // Expand bytes into uint16_t values
+    __m256i bytes = _mm256_cvtepu8_epi16( tmp );
+
+    // Unpack values into individual bytes
+    const __m256i lowMask = _mm256_set1_epi8( 0xF );
+    __m256i high = _mm256_andnot_si256( lowMask, bytes );
+    __m256i low = _mm256_and_si256( lowMask, bytes );
+    high = _mm256_slli_epi16( high, 4 );
+    bytes = _mm256_or_si256( low, high );
+    return bytes;
+}
+
 #if QK == 32
-    const size_t countBlocks = nb;
+    // Initialize accumulator with zeros
+    __m512 acc = _mm512_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        const float * d0_0 = (const float *) (pd0 + i*bs);
+        const float * d1_0 = (const float *) (pd1 + i*bs);
+
+        const uint8_t * restrict p0 = pb0 + i*bs;
+        const uint8_t * restrict p1 = pb1 + i*bs;
+
+        // Compute combined scale for the block
+        float scaleScalar = d0_0[0] * d1_0[0];
+        __m512 scale = _mm512_set1_ps( scaleScalar );
+
+        // Load 16 bytes, and unpack 4-bit fields into bytes, making 32 bytes
+        __m256i bx = bytesFromNibbles( p0 );
+        __m256i by = bytesFromNibbles( p1 );
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
+        by = _mm256_sub_epi8( by, off );
+
+        // Sign-extend 32 signed bytes into int16_t
+        __m512i x32 = _mm512_cvtepi8_epi16( bx );
+        __m512i y32 = _mm512_cvtepi8_epi16( by );
+        // Compute products of int16_t integers, add pairwise into int32_t lanes
+        __m512i i32 = _mm512_madd_epi16( x32, y32 );
+
+        // Convert int32_t to float
+        __m512 p = _mm512_cvtepi32_ps( i32 );
+        // Apply the scale, and accumulate
+        acc = _mm512_fmadd_ps( scale, p, acc );
+    }
+
+    // Horizontal sum of all lanes of the accumulator
+    sumf = _mm512_reduce_add_ps( acc );
+#else
+#error "not implemented for QK"
+#endif
+#elif defined(__AVX2__)
+#if QK == 32
 
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -5485,7 +5543,7 @@ static void ggml_compute_forward_rms_norm_f32(
         mean /= ne00;
 
         float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
-
+        
         memcpy(y, x, ne00 * sizeof(float));
         // for (int i00 = 0; i00 < ne00; i00++) {
         //     y[i00] = x[i00];

From b4d82fb19f1d765028bc73d0af696bd8215f312e Mon Sep 17 00:00:00 2001
From: Casey Primozic
Date: Mon, 20 Mar 2023 03:26:21 -0700
Subject: [PATCH 2/2] Some optimizations to AVX512 code

* Manually unroll the inner dot product loop to reduce loop-counter
  overhead (a scalar sketch of the two-accumulator pattern follows the
  patch series below)
* Add some extra AVX512 compiler flags to the Makefile if they are
  detected
---
 Makefile | 12 ++++++++
 ggml.c   | 90 +++++++++++++++++++++++++++++++++++++------------------
 2 files changed, 72 insertions(+), 30 deletions(-)

diff --git a/Makefile b/Makefile
index fcfe7da292f79..8d28ef6435971 100644
--- a/Makefile
+++ b/Makefile
@@ -115,6 +115,18 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
     ifneq (,$(findstring avx512cd,$(AVX512CD_M)))
         CFLAGS += -mavx512cd
     endif
+    AVX512ER_M := $(shell grep "avx512er " /proc/cpuinfo)
+    ifneq (,$(findstring avx512er,$(AVX512ER_M)))
+        CFLAGS += -mavx512er
+    endif
+    AVX512IFMA_M := $(shell grep "avx512ifma " /proc/cpuinfo)
+    ifneq (,$(findstring avx512ifma,$(AVX512IFMA_M)))
+        CFLAGS += -mavx512ifma
+    endif
+    AVX512PF_M := $(shell grep "avx512pf " /proc/cpuinfo)
+    ifneq (,$(findstring avx512pf,$(AVX512PF_M)))
+        CFLAGS += -mavx512pf
+    endif
 else ifeq ($(UNAME_S),Haiku)
     AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
     ifneq (,$(findstring avx,$(AVX1_M)))
diff --git a/ggml.c b/ggml.c
index 9ebdff0217798..4a453a531833a 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1430,45 +1430,75 @@ inline __m256i bytesFromNibbles( const uint8_t* rsi ){
     return bytes;
 }
 
+inline __m512 process_one_block(
+    __m512 acc,
+    const uint8_t * pd0,
+    const uint8_t * pd1,
+    const uint8_t * pb0,
+    const uint8_t * pb1,
+    int i
+) {
+    // Block stride (a float scale followed by QK/2 packed nibbles); the bs local in ggml_vec_dot_q4_0 is not visible here
+    const size_t bs = sizeof(float) + QK/2;
+    const float * d0_0 = (const float *) (pd0 + i*bs);
+    const float * d1_0 = (const float *) (pd1 + i*bs);
+
+    const uint8_t * restrict p0 = pb0 + i*bs;
+    const uint8_t * restrict p1 = pb1 + i*bs;
+
+    // Compute combined scale for the block
+    float scaleScalar = d0_0[0] * d1_0[0];
+    __m512 scale = _mm512_set1_ps( scaleScalar );
+
+    __m256i bx = bytesFromNibbles( p0 );
+    __m256i by = bytesFromNibbles( p1 );
+
+    // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+    const __m256i off = _mm256_set1_epi8( 8 );
+    bx = _mm256_sub_epi8( bx, off );
+    by = _mm256_sub_epi8( by, off );
+
+    // Sign-extend 32 signed bytes into int16_t
+    __m512i x32 = _mm512_cvtepi8_epi16( bx );
+    __m512i y32 = _mm512_cvtepi8_epi16( by );
+    // Compute products of int16_t integers, add pairwise into int32_t lanes
+    __m512i i32 = _mm512_madd_epi16( x32, y32 );
+
+    // Convert int32_t to float
+    __m512 p = _mm512_cvtepi32_ps( i32 );
+    // Apply the scale, and accumulate
+    return _mm512_fmadd_ps( scale, p, acc );
+}
+
 #if QK == 32
     // Initialize accumulator with zeros
-    __m512 acc = _mm512_setzero_ps();
-
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        const float * d0_0 = (const float *) (pd0 + i*bs);
-        const float * d1_0 = (const float *) (pd1 + i*bs);
+    __m512 acc0 = _mm512_setzero_ps();
+    __m512 acc1 = _mm512_setzero_ps();
 
-        const uint8_t * restrict p0 = pb0 + i*bs;
-        const uint8_t * restrict p1 = pb1 + i*bs;
+    const int superblock_size = 8;
+    const int superblock_count = nb / superblock_size;
+    const int remainder = nb % superblock_size;
 
-        // Compute combined scale for the block
-        float scaleScalar = d0_0[0] * d1_0[0];
-        __m512 scale = _mm512_set1_ps( scaleScalar );
+    for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
+        int i = superblock_ix * superblock_size;
 
-        // Load 16 bytes, and unpack 4-bit fields into bytes, making 32 bytes
-        __m256i bx = bytesFromNibbles( p0 );
-        __m256i by = bytesFromNibbles( p1 );
-
-        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-        const __m256i off = _mm256_set1_epi8( 8 );
-        bx = _mm256_sub_epi8( bx, off );
-        by = _mm256_sub_epi8( by, off );
-
-        // Sign-extend 32 signed bytes into int16_t
-        __m512i x32 = _mm512_cvtepi8_epi16( bx );
-        __m512i y32 = _mm512_cvtepi8_epi16( by );
-        // Compute products of int16_t integers, add pairwise into int32_t lanes
-        __m512i i32 = _mm512_madd_epi16( x32, y32 );
+        acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+0 );
+        acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+1 );
+        acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+2 );
+        acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+3 );
+        acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+4 );
+        acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+5 );
+        acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+6 );
+        acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+7 );
+    }
 
-        // Convert int32_t to float
-        __m512 p = _mm512_cvtepi32_ps( i32 );
-        // Apply the scale, and accumulate
-        acc = _mm512_fmadd_ps( scale, p, acc );
+    // Process the remaining blocks that did not fill a whole superblock
+    for (int i = nb - remainder; i < nb; ++i) {
+        acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i );
     }
 
     // Horizontal sum of all lanes of the accumulator
-    sumf = _mm512_reduce_add_ps( acc );
+    sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
 #else
 #error "not implemented for QK"
 #endif
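
A minimal scalar sketch of the per-block computation the AVX512 path above vectorizes, for reference. It assumes the q4_0 block layout these patches rely on (one float scale followed by QK/2 bytes, each byte packing two 4-bit values); scalar_dot_q4_0_block and the test values in main are hypothetical names made up for this sketch, not identifiers from ggml.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define QK 32

// Dot product of two q4_0 blocks: unpack nibbles, offset them to [-8..7],
// multiply-accumulate in integers, then apply the combined scale once.
static float scalar_dot_q4_0_block(const uint8_t *x, const uint8_t *y) {
    float dx, dy;
    memcpy(&dx, x, sizeof dx);  // block scale of x
    memcpy(&dy, y, sizeof dy);  // block scale of y
    const uint8_t *px = x + sizeof(float);
    const uint8_t *py = y + sizeof(float);

    int sum = 0;
    for (int j = 0; j < QK/2; ++j) {
        // Two 4-bit values per byte, shifted from [0..15] to [-8..7],
        // mirroring bytesFromNibbles plus the _mm256_sub_epi8 offset.
        const int x0 = (px[j] & 0x0F) - 8;
        const int x1 = (px[j] >> 4) - 8;
        const int y0 = (py[j] & 0x0F) - 8;
        const int y1 = (py[j] >> 4) - 8;
        sum += x0*y0 + x1*y1;
    }
    // One combined scale per block, as the _mm512_fmadd_ps applies.
    return dx * dy * (float)sum;
}

int main(void) {
    uint8_t bx[sizeof(float) + QK/2];
    uint8_t by[sizeof(float) + QK/2];
    const float d = 0.5f;
    memcpy(bx, &d, sizeof d);
    memcpy(by, &d, sizeof d);
    for (int j = 0; j < QK/2; ++j) {
        bx[sizeof(float) + j] = 0x98;  // nibbles 8 and 9 become 0 and 1 after the offset
        by[sizeof(float) + j] = 0x98;
    }
    // Each byte contributes 0*0 + 1*1, so sum = 16 and the result is
    // 0.5 * 0.5 * 16 = 4.0.
    printf("%f\n", scalar_dot_q4_0_block(bx, by));
    return 0;
}

Because both operands pack their nibbles the same way, pairing low nibbles with low nibbles and high with high gives the same sum as the lane order bytesFromNibbles produces, so the scalar result matches the vector path.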
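
The second patch keeps two accumulators because each _mm512_fmadd_ps depends on the previous value of its accumulator; alternating between acc0 and acc1 gives the CPU two independent dependency chains to overlap. A scalar sketch of the same pattern (sum_two_accumulators is a hypothetical name for illustration):

#include <stdio.h>

// Sums v[0..n) with two independent accumulators, mirroring how the
// unrolled AVX512 loop alternates blocks between acc0 and acc1.
static float sum_two_accumulators(const float *v, int n) {
    float acc0 = 0.0f;
    float acc1 = 0.0f;
    int i = 0;
    // Unrolled main loop: even elements feed acc0, odd elements feed acc1,
    // so the two running sums form independent dependency chains.
    for (; i + 1 < n; i += 2) {
        acc0 += v[i + 0];
        acc1 += v[i + 1];
    }
    // Remainder loop, like the trailing block loop in the patch.
    for (; i < n; ++i) {
        acc0 += v[i];
    }
    // Combine at the end, like the two _mm512_reduce_add_ps calls.
    return acc0 + acc1;
}

int main(void) {
    const float v[7] = {1, 2, 3, 4, 5, 6, 7};
    printf("%f\n", sum_two_accumulators(v, 7));  // prints 28.000000
    return 0;
}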