@@ -487,6 +487,15 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
487
487
return bytes ;
488
488
}
489
489
490
+ // horizontally add 8 floats
+ static inline float hsum_float_8(const __m256 x) {
+     // fold the upper 128-bit lane onto the lower one (4 partial sums)
+     __m128 sum4 = _mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1));
+     // fold 4 -> 2 partial sums
+     sum4 = _mm_add_ps(sum4, _mm_movehl_ps(sum4, sum4));
+     // fold 2 -> 1: movehdup duplicates lane 1 into lane 0, add_ss finishes the sum
+     sum4 = _mm_add_ss(sum4, _mm_movehdup_ps(sum4));
+     return _mm_cvtss_f32(sum4);
+ }
+
490
499
#if __AVX2__ || __AVX512F__
491
500
// Unpack 32 4-bit fields into 32 bytes
492
501
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
@@ -507,6 +516,24 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
507
516
return bytes ;
508
517
}
509
518
519
+ // add int16_t pairwise and return as float vector
+ static inline __m256 sum_i16_pairs_float(const __m256i x) {
+     // madd against a vector of ones sums each adjacent pair of i16 lanes into an i32 lane
+     const __m256i one        = _mm256_set1_epi16(1);
+     const __m256i pair_sums  = _mm256_madd_epi16(one, x);
+     return _mm256_cvtepi32_ps(pair_sums);
+ }
+
526
+ // multiply int8_t, add results pairwise twice and return as float vector
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+     // maddubs wants an (unsigned, signed) operand pair, so move x's sign onto y:
+     // |x| * sign(y, x) == x * y elementwise, since sign_epi8 negates y where x < 0
+     const __m256i abs_x    = _mm256_sign_epi8(x, x);
+     const __m256i signed_y = _mm256_sign_epi8(y, x);
+     // widening multiply-add: 32 i8 products reduced to 16 i16 lanes
+     const __m256i dot16 = _mm256_maddubs_epi16(abs_x, signed_y);
+     // second pairwise reduction (i16 -> i32) and convert to float
+     return sum_i16_pairs_float(dot16);
+ }
+
510
537
static inline __m128i packNibbles ( __m256i bytes )
511
538
{
512
539
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -2366,8 +2393,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2366
2393
const block_q4_0 * restrict x = vx ;
2367
2394
const block_q8_0 * restrict y = vy ;
2368
2395
2369
- float sumf = 0.0 ;
2370
-
2371
2396
#if defined(__ARM_NEON )
2372
2397
float32x4_t sumv0 = vdupq_n_f32 (0.0f );
2373
2398
float32x4_t sumv1 = vdupq_n_f32 (0.0f );
@@ -2436,7 +2461,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2436
2461
#endif
2437
2462
}
2438
2463
2439
- sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2464
+ * s = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2440
2465
#elif defined(__AVX2__ )
2441
2466
// Initialize accumulator with zeros
2442
2467
__m256 acc = _mm256_setzero_ps ();
@@ -2454,32 +2479,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2454
2479
2455
2480
__m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2456
2481
2457
- // Get absolute values of x vectors
2458
- const __m256i ax = _mm256_sign_epi8 (bx , bx );
2459
-
2460
- // Sign the values of the y vectors
2461
- const __m256i sy = _mm256_sign_epi8 (by , bx );
2462
-
2463
- // Perform multiplication and create 16-bit values
2464
- const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
2465
-
2466
- const __m256i ones = _mm256_set1_epi16 (1 );
2467
- __m256i xy_q = _mm256_madd_epi16 (ones , dot );
2468
-
2469
- /* Convert to vectore of 8 int32_t to 8 floats */
2470
- __m256 q = _mm256_cvtepi32_ps ( xy_q );
2482
+ const __m256 q = mul_sum_i8_pairs_float (bx , by );
2471
2483
2472
2484
/* Multiply q with scale and accumulate */
2473
2485
acc = _mm256_fmadd_ps ( d , q , acc );
2474
2486
}
2475
2487
2476
- // Return horizontal sum of the acc vector
2477
- __m128 res = _mm256_extractf128_ps ( acc , 1 );
2478
- res = _mm_add_ps ( res , _mm256_castps256_ps128 ( acc ) );
2479
- res = _mm_add_ps ( res , _mm_movehl_ps ( res , res ) );
2480
- res = _mm_add_ss ( res , _mm_movehdup_ps ( res ) );
2481
-
2482
- sumf = _mm_cvtss_f32 ( res );
2488
+ * s = hsum_float_8 (acc );
2483
2489
#elif defined(__AVX__ )
2484
2490
// Initialize accumulator with zeros
2485
2491
__m256 acc = _mm256_setzero_ps ();
@@ -2518,15 +2524,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2518
2524
acc = _mm256_add_ps (_mm256_mul_ps ( d , p ), acc );
2519
2525
}
2520
2526
2521
- // Return horizontal sum of the acc vector
2522
- __m128 res = _mm256_extractf128_ps ( acc , 1 );
2523
- res = _mm_add_ps ( res , _mm256_castps256_ps128 ( acc ) );
2524
- res = _mm_add_ps ( res , _mm_movehl_ps ( res , res ) );
2525
- res = _mm_add_ss ( res , _mm_movehdup_ps ( res ) );
2526
-
2527
- sumf = _mm_cvtss_f32 ( res );
2527
+ * s = hsum_float_8 (acc );
2528
2528
#else
2529
2529
// scalar
2530
+ float sumf = 0.0 ;
2530
2531
for (int i = 0 ; i < nb ; i ++ ) {
2531
2532
const float d0 = x [i ].d ;
2532
2533
const float d1 = y [i ].d ;
@@ -2548,9 +2549,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2548
2549
}
2549
2550
sumf += d0 * d1 * sumi ;
2550
2551
}
2551
- #endif
2552
-
2553
2552
* s = sumf ;
2553
+ #endif
2554
2554
}
2555
2555
2556
2556
static void ggml_vec_dot_q4_1_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
@@ -2562,8 +2562,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2562
2562
const block_q4_1 * restrict x = vx ;
2563
2563
const block_q8_0 * restrict y = vy ;
2564
2564
2565
- float sumf = 0.0 ;
2566
-
2567
2565
// TODO: add AVX / WASM SIMD / etc
2568
2566
#if defined(__ARM_NEON )
2569
2567
float32x4_t sumv0 = vdupq_n_f32 (0.0f );
@@ -2637,7 +2635,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2637
2635
#endif
2638
2636
}
2639
2637
2640
- sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2638
+ * s = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2641
2639
#elif defined(__AVX2__ )
2642
2640
// Initialize accumulator with zeros
2643
2641
__m256 acc = _mm256_setzero_ps ();
@@ -2660,42 +2658,24 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2660
2658
const __m256i bx = bytes_from_nibbles_32 (x [i ].qs );
2661
2659
const __m256i by = _mm256_loadu_si256 ( (const __m256i * )y [i ].qs );
2662
2660
2663
- // Get absolute values of x vectors
2664
- const __m256i ax = _mm256_sign_epi8 ( bx , bx );
2665
-
2666
- // Sign the values of the y vectors
2667
- const __m256i sy = _mm256_sign_epi8 ( by , bx );
2668
-
2669
- // Perform multiplication and create 16-bit values
2670
- const __m256i dot = _mm256_maddubs_epi16 ( ax , sy );
2671
- const __m256i ones = _mm256_set1_epi16 ( 1 );
2672
- const __m256i xy_q = _mm256_madd_epi16 ( ones , dot );
2673
-
2674
- // Convert to vector of 8 int32_t to 8 floats
2675
- const __m256 xy = _mm256_cvtepi32_ps ( xy_q );
2661
+ const __m256 xy = mul_sum_i8_pairs_float (bx , by );
2676
2662
2677
2663
// Accumulate d0*d1*x*y
2678
2664
acc = _mm256_fmadd_ps ( d0d1 , xy , acc );
2679
2665
2680
2666
// Compute sum of y values
2681
2667
const __m256i y16_l = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( by ) );
2682
2668
const __m256i y16_h = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( by , 1 ) );
2683
- const __m256i ysumi = _mm256_madd_epi16 ( _mm256_add_epi16 (y16_l , y16_h ), ones );
2684
- const __m256 ysum = _mm256_cvtepi32_ps ( ysumi );
2669
+ const __m256 ysum = sum_i16_pairs_float (_mm256_add_epi16 (y16_l , y16_h ));
2685
2670
2686
2671
// Accumulate d1*m0*y
2687
2672
acc = _mm256_fmadd_ps ( d1m0 , ysum , acc );
2688
2673
}
2689
2674
2690
- // Return horizontal sum of the acc vector
2691
- __m128 res = _mm256_extractf128_ps ( acc , 1 );
2692
- res = _mm_add_ps ( res , _mm256_castps256_ps128 ( acc ) );
2693
- res = _mm_add_ps ( res , _mm_movehl_ps ( res , res ) );
2694
- res = _mm_add_ss ( res , _mm_movehdup_ps ( res ) );
2695
-
2696
- sumf = _mm_cvtss_f32 ( res );
2675
+ * s = hsum_float_8 (acc );
2697
2676
#else
2698
2677
// scalar
2678
+ float sumf = 0.0 ;
2699
2679
for (int i = 0 ; i < nb ; i ++ ) {
2700
2680
const float d0 = x [i ].d ;
2701
2681
const float m0 = x [i ].m ;
@@ -2717,9 +2697,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2717
2697
sumf += f0 * f2 + f1 * f3 ;
2718
2698
}
2719
2699
}
2720
- #endif
2721
-
2722
2700
* s = sumf ;
2701
+ #endif
2723
2702
}
2724
2703
2725
2704
static void ggml_vec_dot_q4_2_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
@@ -2732,8 +2711,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
2732
2711
const block_q4_2 * restrict x = vx ;
2733
2712
const block_q8_0 * restrict y = vy ;
2734
2713
2735
- float sumf = 0.0 ;
2736
-
2737
2714
#if defined(__ARM_NEON )
2738
2715
float32x4_t sumv0 = vdupq_n_f32 (0.0f );
2739
2716
float32x4_t sumv1 = vdupq_n_f32 (0.0f );
@@ -2811,7 +2788,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
2811
2788
#endif
2812
2789
}
2813
2790
2814
- sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2791
+ * s = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2815
2792
#elif defined(__AVX2__ )
2816
2793
// Initialize accumulator with zeros
2817
2794
__m256 acc = _mm256_setzero_ps ();
@@ -2833,32 +2810,16 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
2833
2810
2834
2811
__m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2835
2812
2836
- // Get absolute values of x vectors
2837
- const __m256i ax = _mm256_sign_epi8 (bx , bx );
2838
- // Sign the values of the y vectors
2839
- const __m256i sy = _mm256_sign_epi8 (by , bx );
2840
- // Perform multiplication and create 16-bit values
2841
- const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
2842
-
2843
- const __m256i ones = _mm256_set1_epi16 (1 );
2844
- __m256i xy_q = _mm256_madd_epi16 (ones , dot );
2845
-
2846
- /* Convert to vectore of 8 int32_t to 8 floats */
2847
- __m256 q = _mm256_cvtepi32_ps (xy_q );
2813
+ const __m256 q = mul_sum_i8_pairs_float (bx , by );
2848
2814
2849
2815
/* Multiply q with scale and accumulate */
2850
2816
acc = _mm256_fmadd_ps (d , q , acc );
2851
2817
}
2852
2818
2853
- // Return horizontal sum of the acc vector
2854
- __m128 res = _mm256_extractf128_ps (acc , 1 );
2855
- res = _mm_add_ps (res , _mm256_castps256_ps128 (acc ));
2856
- res = _mm_add_ps (res , _mm_movehl_ps (res , res ));
2857
- res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2858
-
2859
- sumf = _mm_cvtss_f32 (res );
2819
+ * s = hsum_float_8 (acc );
2860
2820
#else
2861
2821
// scalar
2822
+ float sumf = 0.0 ;
2862
2823
for (int i = 0 ; i < nb ; i ++ ) {
2863
2824
const uint8_t * restrict x0 = x [2 * i + 0 ].qs ;
2864
2825
const uint8_t * restrict x1 = x [2 * i + 1 ].qs ;
@@ -2893,9 +2854,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
2893
2854
sumf += (d0 * y [i ].d ) * sumi_0 ;
2894
2855
sumf += (d1 * y [i ].d ) * sumi_1 ;
2895
2856
}
2896
- #endif
2897
-
2898
2857
* s = sumf ;
2858
+ #endif
2899
2859
}
2900
2860
2901
2861
static void ggml_vec_dot_q4_3_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
@@ -2908,8 +2868,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
2908
2868
const block_q4_3 * restrict x = vx ;
2909
2869
const block_q8_0 * restrict y = vy ;
2910
2870
2911
- float sumf = 0.0 ;
2912
-
2913
2871
#if defined(__ARM_NEON )
2914
2872
float32x4_t sumv0 = vdupq_n_f32 (0.0f );
2915
2873
float32x4_t sumv1 = vdupq_n_f32 (0.0f );
@@ -2995,9 +2953,41 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
2995
2953
#endif
2996
2954
}
2997
2955
2998
- sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2956
+ * s = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2957
+ #elif defined(__AVX2__ )
2958
+ // Initialize accumulator with zeros
2959
+ __m256 acc = _mm256_setzero_ps ();
2960
+
2961
+ // Main loop
2962
+ for (int i = 0 ; i < nb ; i ++ ) {
2963
+ const __m128 d0 = _mm_set1_ps (GGML_FP16_TO_FP32 (x [2 * i + 0 ].d ));
2964
+ const __m128 d1 = _mm_set1_ps (GGML_FP16_TO_FP32 (x [2 * i + 1 ].d ));
2965
+ const __m256 dx = _mm256_set_m128 (d1 , d0 );
2966
+
2967
+ const __m128 m0 = _mm_set1_ps (GGML_FP16_TO_FP32 (x [2 * i + 0 ].m ));
2968
+ const __m128 m1 = _mm_set1_ps (GGML_FP16_TO_FP32 (x [2 * i + 1 ].m ));
2969
+ const __m256 mx = _mm256_set_m128 (m1 , m0 );
2970
+
2971
+ const __m128i bx0 = bytes_from_nibbles_16 (x [2 * i + 0 ].qs );
2972
+ const __m128i bx1 = bytes_from_nibbles_16 (x [2 * i + 1 ].qs );
2973
+ const __m256i bx = _mm256_set_m128i (bx1 , bx0 );
2974
+
2975
+ const __m256 dy = _mm256_broadcast_ss (& y [i ].d );
2976
+ const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2977
+
2978
+ const __m256i syi = _mm256_maddubs_epi16 (_mm256_set1_epi8 (1 ), by );
2979
+ const __m256 syf = sum_i16_pairs_float (syi );
2980
+
2981
+ const __m256 q = mul_sum_i8_pairs_float (bx , by );
2982
+
2983
+ const __m256 sxy = _mm256_fmadd_ps (q , dx , _mm256_mul_ps (mx , syf ));
2984
+ acc = _mm256_fmadd_ps (sxy , dy , acc );
2985
+ }
2986
+
2987
+ * s = hsum_float_8 (acc );
2999
2988
#else
3000
2989
// scalar
2990
+ float sumf = 0.0 ;
3001
2991
for (int i = 0 ; i < nb ; i ++ ) {
3002
2992
const uint8_t * restrict x0 = x [2 * i + 0 ].qs ;
3003
2993
const uint8_t * restrict x1 = x [2 * i + 1 ].qs ;
@@ -3040,9 +3030,8 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
3040
3030
sumf += (d0 * sxy_0 + m0 * sy_0 )* y [i ].d ;
3041
3031
sumf += (d1 * sxy_1 + m1 * sy_1 )* y [i ].d ;
3042
3032
}
3043
- #endif
3044
-
3045
3033
* s = sumf ;
3034
+ #endif
3046
3035
}
3047
3036
3048
3037
0 commit comments