Skip to content

Commit a32f581

Browse files
committed
AVX2 optimization for vec_dot_q4_3_q8_0 and refactoring
1 parent 9ff334f commit a32f581

File tree

1 file changed

+78
-89
lines changed

1 file changed

+78
-89
lines changed

ggml.c

Lines changed: 78 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,15 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
487487
return bytes;
488488
}
489489

490+
// horizontally add 8 floats
491+
static inline float hsum_float_8(const __m256 x) {
492+
__m128 res = _mm256_extractf128_ps(x, 1);
493+
res = _mm_add_ps(res, _mm256_castps256_ps128(x));
494+
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
495+
res = _mm_add_ss(res, _mm_movehdup_ps(res));
496+
return _mm_cvtss_f32(res);
497+
}
498+
490499
#if __AVX2__ || __AVX512F__
491500
// Unpack 32 4-bit fields into 32 bytes
492501
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
@@ -507,6 +516,24 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
507516
return bytes;
508517
}
509518

519+
// add int16_t pairwise and return as float vector
520+
static inline __m256 sum_i16_pairs_float(const __m256i x) {
521+
const __m256i ones = _mm256_set1_epi16(1);
522+
const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
523+
return _mm256_cvtepi32_ps(summed_pairs);
524+
}
525+
526+
// multiply int8_t, add results pairwise twice and return as float vector
527+
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
528+
// Get absolute values of x vectors
529+
const __m256i ax = _mm256_sign_epi8(x, x);
530+
// Sign the values of the y vectors
531+
const __m256i sy = _mm256_sign_epi8(y, x);
532+
// Perform multiplication and create 16-bit values
533+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
534+
return sum_i16_pairs_float(dot);
535+
}
536+
510537
static inline __m128i packNibbles( __m256i bytes )
511538
{
512539
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -2366,8 +2393,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
23662393
const block_q4_0 * restrict x = vx;
23672394
const block_q8_0 * restrict y = vy;
23682395

2369-
float sumf = 0.0;
2370-
23712396
#if defined(__ARM_NEON)
23722397
float32x4_t sumv0 = vdupq_n_f32(0.0f);
23732398
float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -2436,7 +2461,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
24362461
#endif
24372462
}
24382463

2439-
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2464+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
24402465
#elif defined(__AVX2__)
24412466
// Initialize accumulator with zeros
24422467
__m256 acc = _mm256_setzero_ps();
@@ -2454,32 +2479,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
24542479

24552480
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
24562481

2457-
// Get absolute values of x vectors
2458-
const __m256i ax = _mm256_sign_epi8(bx, bx);
2459-
2460-
// Sign the values of the y vectors
2461-
const __m256i sy = _mm256_sign_epi8(by, bx);
2462-
2463-
// Perform multiplication and create 16-bit values
2464-
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2465-
2466-
const __m256i ones = _mm256_set1_epi16(1);
2467-
__m256i xy_q = _mm256_madd_epi16(ones, dot);
2468-
2469-
/* Convert to vectore of 8 int32_t to 8 floats */
2470-
__m256 q = _mm256_cvtepi32_ps( xy_q );
2482+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
24712483

24722484
/* Multiply q with scale and accumulate */
24732485
acc = _mm256_fmadd_ps( d, q, acc );
24742486
}
24752487

2476-
// Return horizontal sum of the acc vector
2477-
__m128 res = _mm256_extractf128_ps( acc, 1 );
2478-
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2479-
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2480-
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2481-
2482-
sumf = _mm_cvtss_f32( res );
2488+
*s = hsum_float_8(acc);
24832489
#elif defined(__AVX__)
24842490
// Initialize accumulator with zeros
24852491
__m256 acc = _mm256_setzero_ps();
@@ -2518,15 +2524,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
25182524
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
25192525
}
25202526

2521-
// Return horizontal sum of the acc vector
2522-
__m128 res = _mm256_extractf128_ps( acc, 1 );
2523-
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2524-
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2525-
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2526-
2527-
sumf = _mm_cvtss_f32( res );
2527+
*s = hsum_float_8(acc);
25282528
#else
25292529
// scalar
2530+
float sumf = 0.0;
25302531
for (int i = 0; i < nb; i++) {
25312532
const float d0 = x[i].d;
25322533
const float d1 = y[i].d;
@@ -2548,9 +2549,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
25482549
}
25492550
sumf += d0*d1*sumi;
25502551
}
2551-
#endif
2552-
25532552
*s = sumf;
2553+
#endif
25542554
}
25552555

25562556
static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -2562,8 +2562,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
25622562
const block_q4_1 * restrict x = vx;
25632563
const block_q8_0 * restrict y = vy;
25642564

2565-
float sumf = 0.0;
2566-
25672565
// TODO: add AVX / WASM SIMD / etc
25682566
#if defined(__ARM_NEON)
25692567
float32x4_t sumv0 = vdupq_n_f32(0.0f);
@@ -2637,7 +2635,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
26372635
#endif
26382636
}
26392637

2640-
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2638+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
26412639
#elif defined(__AVX2__)
26422640
// Initialize accumulator with zeros
26432641
__m256 acc = _mm256_setzero_ps();
@@ -2660,42 +2658,24 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
26602658
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
26612659
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
26622660

2663-
// Get absolute values of x vectors
2664-
const __m256i ax = _mm256_sign_epi8( bx, bx );
2665-
2666-
// Sign the values of the y vectors
2667-
const __m256i sy = _mm256_sign_epi8( by, bx );
2668-
2669-
// Perform multiplication and create 16-bit values
2670-
const __m256i dot = _mm256_maddubs_epi16( ax, sy );
2671-
const __m256i ones = _mm256_set1_epi16( 1 );
2672-
const __m256i xy_q = _mm256_madd_epi16( ones, dot );
2673-
2674-
// Convert to vector of 8 int32_t to 8 floats
2675-
const __m256 xy = _mm256_cvtepi32_ps( xy_q );
2661+
const __m256 xy = mul_sum_i8_pairs_float(bx, by);
26762662

26772663
// Accumulate d0*d1*x*y
26782664
acc = _mm256_fmadd_ps( d0d1, xy, acc );
26792665

26802666
// Compute sum of y values
26812667
const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
26822668
const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2683-
const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
2684-
const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
2669+
const __m256 ysum = sum_i16_pairs_float(_mm256_add_epi16(y16_l, y16_h));
26852670

26862671
// Accumulate d1*m0*y
26872672
acc = _mm256_fmadd_ps( d1m0, ysum, acc );
26882673
}
26892674

2690-
// Return horizontal sum of the acc vector
2691-
__m128 res = _mm256_extractf128_ps( acc, 1 );
2692-
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2693-
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2694-
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2695-
2696-
sumf = _mm_cvtss_f32( res );
2675+
*s = hsum_float_8(acc);
26972676
#else
26982677
// scalar
2678+
float sumf = 0.0;
26992679
for (int i = 0; i < nb; i++) {
27002680
const float d0 = x[i].d;
27012681
const float m0 = x[i].m;
@@ -2717,9 +2697,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
27172697
sumf += f0*f2 + f1*f3;
27182698
}
27192699
}
2720-
#endif
2721-
27222700
*s = sumf;
2701+
#endif
27232702
}
27242703

27252704
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -2732,8 +2711,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
27322711
const block_q4_2 * restrict x = vx;
27332712
const block_q8_0 * restrict y = vy;
27342713

2735-
float sumf = 0.0;
2736-
27372714
#if defined(__ARM_NEON)
27382715
float32x4_t sumv0 = vdupq_n_f32(0.0f);
27392716
float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -2811,7 +2788,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
28112788
#endif
28122789
}
28132790

2814-
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2791+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
28152792
#elif defined(__AVX2__)
28162793
// Initialize accumulator with zeros
28172794
__m256 acc = _mm256_setzero_ps();
@@ -2833,32 +2810,16 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
28332810

28342811
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
28352812

2836-
// Get absolute values of x vectors
2837-
const __m256i ax = _mm256_sign_epi8(bx, bx);
2838-
// Sign the values of the y vectors
2839-
const __m256i sy = _mm256_sign_epi8(by, bx);
2840-
// Perform multiplication and create 16-bit values
2841-
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2842-
2843-
const __m256i ones = _mm256_set1_epi16(1);
2844-
__m256i xy_q = _mm256_madd_epi16(ones, dot);
2845-
2846-
/* Convert to vectore of 8 int32_t to 8 floats */
2847-
__m256 q = _mm256_cvtepi32_ps(xy_q);
2813+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
28482814

28492815
/* Multiply q with scale and accumulate */
28502816
acc = _mm256_fmadd_ps(d, q, acc);
28512817
}
28522818

2853-
// Return horizontal sum of the acc vector
2854-
__m128 res = _mm256_extractf128_ps(acc, 1);
2855-
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2856-
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
2857-
res = _mm_add_ss(res, _mm_movehdup_ps(res));
2858-
2859-
sumf = _mm_cvtss_f32(res);
2819+
*s = hsum_float_8(acc);
28602820
#else
28612821
// scalar
2822+
float sumf = 0.0;
28622823
for (int i = 0; i < nb; i++) {
28632824
const uint8_t * restrict x0 = x[2*i + 0].qs;
28642825
const uint8_t * restrict x1 = x[2*i + 1].qs;
@@ -2893,9 +2854,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
28932854
sumf += (d0 * y[i].d) * sumi_0;
28942855
sumf += (d1 * y[i].d) * sumi_1;
28952856
}
2896-
#endif
2897-
28982857
*s = sumf;
2858+
#endif
28992859
}
29002860

29012861
static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -2908,8 +2868,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
29082868
const block_q4_3 * restrict x = vx;
29092869
const block_q8_0 * restrict y = vy;
29102870

2911-
float sumf = 0.0;
2912-
29132871
#if defined(__ARM_NEON)
29142872
float32x4_t sumv0 = vdupq_n_f32(0.0f);
29152873
float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -2995,9 +2953,41 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
29952953
#endif
29962954
}
29972955

2998-
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2956+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2957+
#elif defined(__AVX2__)
2958+
// Initialize accumulator with zeros
2959+
__m256 acc = _mm256_setzero_ps();
2960+
2961+
// Main loop
2962+
for (int i = 0; i < nb; i++) {
2963+
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
2964+
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
2965+
const __m256 dx = _mm256_set_m128(d1, d0);
2966+
2967+
const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m));
2968+
const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m));
2969+
const __m256 mx = _mm256_set_m128(m1, m0);
2970+
2971+
const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
2972+
const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
2973+
const __m256i bx = _mm256_set_m128i(bx1, bx0);
2974+
2975+
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
2976+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2977+
2978+
const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by);
2979+
const __m256 syf = sum_i16_pairs_float(syi);
2980+
2981+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
2982+
2983+
const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf));
2984+
acc = _mm256_fmadd_ps(sxy, dy, acc);
2985+
}
2986+
2987+
*s = hsum_float_8(acc);
29992988
#else
30002989
// scalar
2990+
float sumf = 0.0;
30012991
for (int i = 0; i < nb; i++) {
30022992
const uint8_t * restrict x0 = x[2*i + 0].qs;
30032993
const uint8_t * restrict x1 = x[2*i + 1].qs;
@@ -3040,9 +3030,8 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
30403030
sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
30413031
sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
30423032
}
3043-
#endif
3044-
30453033
*s = sumf;
3034+
#endif
30463035
}
30473036

30483037

0 commit comments

Comments
 (0)