From 142c38a4f396cc785e7d71273f80049453099afc Mon Sep 17 00:00:00 2001 From: Slaren <2141330+slaren@users.noreply.github.com> Date: Wed, 19 Apr 2023 03:13:20 +0200 Subject: [PATCH] AVX2 implementation of ggml_vec_dot_q4_1_q8_0 --- ggml.c | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/ggml.c b/ggml.c index 2a54cd6f72e51..8f493cb1e3389 100644 --- a/ggml.c +++ b/ggml.c @@ -2518,6 +2518,62 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + const float * d0 = &x[i].d; + const float * d1 = &y[i].d; + const float * m0 = &x[i].m; + + const __m256 d0v = _mm256_broadcast_ss( d0 ); + const __m256 d1v = _mm256_broadcast_ss( d1 ); + const __m256 m0v = _mm256_broadcast_ss( m0 ); + + // Compute combined scales + const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); + const __m256 d1m0 = _mm256_mul_ps( d1v, m0v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i bx = bytesFromNibbles( x[i].qs ); + const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); + + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8( bx, bx ); + + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8( by, bx ); + + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16( ax, sy ); + const __m256i ones = _mm256_set1_epi16( 1 ); + const __m256i xy_q = _mm256_madd_epi16( ones, dot ); + + // Convert to vector of 8 int32_t to 8 floats + const __m256 xy = _mm256_cvtepi32_ps( xy_q ); + + // Accumulate d0*d1*x*y + acc = _mm256_fmadd_ps( d0d1, xy, acc ); + + // Compute sum of y values + const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) ); + const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) ); + const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones ); + const __m256 ysum = _mm256_cvtepi32_ps( ysumi ); + + // Accumulate d1*m0*y + acc = _mm256_fmadd_ps( d1m0, ysum, acc ); + } + + // Return horizontal sum of the acc vector + __m128 res = _mm256_extractf128_ps( acc, 1 ); + res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); + res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); + res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); + + sumf = _mm_cvtss_f32( res ); #else // scalar for (int i = 0; i < nb; i++) {