@@ -9282,6 +9282,14 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
#endif
}

+#ifdef __AVX2__
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return _mm256_maddubs_epi16(ax, sy);
+}
+#endif
+
void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
    assert(n % QK_K == 0);
@@ -9290,6 +9298,59 @@ void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restr
    const int nb = n / QK_K;

+#if defined __AVX2__
+
+    const __m128i m8 = _mm_set1_epi8(0x08);
+    const __m128i m7 = _mm_set1_epi8(0x07);
+    const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
+    const __m128i shuffle_s[4] = {
+        _mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000),
+        _mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404),
+        _mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808),
+        _mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c)
+    };
+
+    uint64_t aux64;
+
+    __m256i v_gindex;
+    const uint16_t * gindex = (const uint16_t *)&v_gindex;
+
+    __m256 accum = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t  * q8 = y[i].qs;
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        __m256i sumi = _mm256_setzero_si256();
+        for (int i128 = 0; i128 < QK_K/128; ++i128) {
+            const __m128i ql = _mm_loadu_si128((const __m128i *)qs); qs += 16;
+            memcpy(&aux64, sc, 8); sc += 8;
+            const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
+            const __m256i hbit = _mm256_cvtepi8_epi16(_mm_and_si128(qh, m8));
+            v_gindex = _mm256_or_si256(_mm256_cvtepi8_epi16(ql), _mm256_slli_epi16(hbit, 5));
+            const __m128i scales = _mm_and_si128(qh, m7);
+
+            for (int i32 = 0; i32 < 4; ++i32) {
+                const __m256i q8b = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+                const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]],
+                                                      iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]);
+                const __m256i dot = mul_add_epi8(q1b, q8b);
+                const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
+                const __m256i p = _mm256_madd_epi16(s16, dot);
+                sumi = _mm256_add_epi32(sumi, p);
+            }
+
+        }
+
+        accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum);
+
+    }
+
+    *s = hsum_float_8(accum);
+
+#else
+
    int db[4];
    uint16_t idx[4];
@@ -9326,6 +9387,8 @@ void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restr
    *s = sumf;

+#endif
+
}
// ================================ IQ2 quantization =============================================
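For readers unfamiliar with the `_mm256_sign_epi8`/`_mm256_maddubs_epi16` pairing used by `mul_add_epi8` in this patch: `_mm256_maddubs_epi16` treats its first operand as unsigned bytes, so the helper first takes `|x|` and moves the sign of `x` onto `y`, which leaves every per-byte product `x[k]*y[k]` unchanged before adjacent pairs are summed into 16-bit lanes. The standalone sketch below is not part of the patch; the `main` driver and the random test data are illustrative only. It checks the trick against a plain scalar reference.

```c
// Minimal check of the sign/maddubs trick against a scalar reference.
// Build (assumption): gcc -O2 -mavx2 check_mul_add_epi8.c -o check_mul_add_epi8
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

// Mirrors the helper added in the patch.
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
    const __m256i ax = _mm256_sign_epi8(x, x);   // |x| per byte
    const __m256i sy = _mm256_sign_epi8(y, x);   // y with the sign of x (0 where x == 0)
    return _mm256_maddubs_epi16(ax, sy);         // 16 lanes of x[2k]*y[2k] + x[2k+1]*y[2k+1]
}

int main(void) {
    int8_t x[32], y[32];
    for (int i = 0; i < 32; ++i) {               // arbitrary test data
        x[i] = (int8_t)(rand() % 256 - 128);
        y[i] = (int8_t)(rand() % 256 - 128);
    }
    const __m256i vx = _mm256_loadu_si256((const __m256i *)x);
    const __m256i vy = _mm256_loadu_si256((const __m256i *)y);
    int16_t out[16];
    _mm256_storeu_si256((__m256i *)out, mul_add_epi8(vx, vy));
    for (int k = 0; k < 16; ++k) {
        const int ref = (int)x[2*k]*y[2*k] + (int)x[2*k+1]*y[2*k+1];
        printf("lane %2d: simd %6d ref %6d %s\n", k, out[k], ref, out[k] == ref ? "ok" : "MISMATCH");
    }
    return 0;
}
```

In the dot product itself, each such 16-bit pair sum is then weighted by the block scales via `_mm256_madd_epi16` and accumulated into 32-bit lanes, which is why the helper stops at the 16-bit stage.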