@@ -461,6 +461,39 @@ static inline __m128i packNibbles( __m256i bytes )
     __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
     return _mm_packus_epi16( r0, r1 );
 }
+#elif __AVX__
+static inline __m128i bytesFromNibbles( const uint8_t * rsi )
+{
+    // Load 8 bytes from memory
+    __m128i tmp = _mm_loadu_si64( ( const __m128i * )rsi );
+
+    // Expand bytes into uint16_t values
+    __m128i bytes = _mm_cvtepu8_epi16( tmp );
+
+    // Unpack values into individual bytes
+    const __m128i lowMask = _mm_set1_epi8( 0xF );
+    __m128i high = _mm_andnot_si128( lowMask, bytes );
+    __m128i low = _mm_and_si128( lowMask, bytes );
+    high = _mm_slli_epi16( high, 4 );
+    bytes = _mm_or_si128( low, high );
+    return bytes;
+}
+
+static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+    const __m128i lowByte = _mm_set1_epi16( 0xFF );
+    __m128i high = _mm_andnot_si128( lowByte, bytes1 );
+    __m128i low = _mm_and_si128( lowByte, bytes1 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes1 = _mm_or_si128( low, high );
+    high = _mm_andnot_si128( lowByte, bytes2 );
+    low = _mm_and_si128( lowByte, bytes2 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes2 = _mm_or_si128( low, high );
+
+    return _mm_packus_epi16( bytes1, bytes2 );
+}
 #endif
 
 // method 5
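
For reference, a scalar sketch (not part of the diff; the helper names below are illustrative) of what the two SSE helpers above compute. bytesFromNibbles spreads 8 packed bytes into 16 values, putting the low nibble of src[k] into byte 2k and the high nibble into byte 2k+1; packNibbles is the inverse, merging adjacent pairs of 0..15 values back into packed bytes.

#include <stdint.h>

/* Scalar equivalent of bytesFromNibbles: 8 packed bytes -> 16 nibble values. */
static void bytes_from_nibbles_ref(const uint8_t src[8], uint8_t dst[16]) {
    for (int k = 0; k < 8; ++k) {
        dst[2*k + 0] = src[k] & 0x0F;   // low nibble
        dst[2*k + 1] = src[k] >> 4;     // high nibble
    }
}

/* Scalar equivalent of packNibbles( bytes1, bytes2 ), with bytes1 = vals[0..15]
   and bytes2 = vals[16..31]; all inputs are expected to be in 0..15. */
static void pack_nibbles_ref(const uint8_t vals[32], uint8_t dst[16]) {
    for (int k = 0; k < 16; ++k) {
        dst[k] = (uint8_t)((vals[2*k] & 0x0F) | ((vals[2*k + 1] & 0x0F) << 4));
    }
}
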
@@ -660,6 +693,80 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         __m128i res = packNibbles( i0 );
         _mm_storeu_si128( ( __m128i * )y[i].qs, res );
     }
+#elif defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 7.0f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1 );
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1 );
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1 );
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1 );
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
+        const __m128i off = _mm_set1_epi8( 8 );
+        ni0 = _mm_add_epi8( ni0, off );
+        ni4 = _mm_add_epi8( ni4, off );
+
+        // Compress the vector into 4 bit/value, and store
+        __m128i res = packNibbles( ni0, ni4 );
+        _mm_storeu_si128( ( __m128i * )y[i].qs, res );
+    }
 #elif defined(__wasm_simd128__)
     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max
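
The AVX branch above performs the same per-block Q4_0 quantization as the scalar reference path: take the maximum absolute value of the 32 floats, store d = max/7, scale by the inverse, round, offset by 8, and pack two 4-bit values per byte. A minimal scalar sketch of that step, not part of the diff, with an illustrative function name and the block's d and qs passed as separate outputs:

#include <math.h>
#include <stdint.h>

static void quantize_block_q4_0_ref(const float x[32], float *d_out, uint8_t qs[16]) {
    float amax = 0.0f;                        // max(abs(x)) over the block
    for (int k = 0; k < 32; ++k) {
        const float a = fabsf(x[k]);
        if (a > amax) amax = a;
    }

    const float d  = amax / 7.0f;             // scale stored with the block
    const float id = d != 0.0f ? 1.0f / d : 0.0f;
    *d_out = d;

    for (int k = 0; k < 16; ++k) {
        // roundf rounds halves away from zero; the SIMD path rounds to nearest even
        const int v0 = (int)roundf(x[2*k + 0] * id) + 8;
        const int v1 = (int)roundf(x[2*k + 1] * id) + 8;
        qs[k] = (uint8_t)(v0 | (v1 << 4));    // two quantized values per byte
    }
}
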
@@ -1892,6 +1999,52 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
+    sumf = _mm_cvtss_f32( res );
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m128i i32[2];
+        for (int j = 0; j < 2; ++j) {
+            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
+            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+            __m128i by = bytesFromNibbles( y[i].qs + 8*j );
+
+            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+            const __m128i off = _mm_set1_epi8( 8 );
+            bx = _mm_sub_epi8( bx, off );
+            by = _mm_sub_epi8( by, off );
+
+            // Get absolute values of x vectors
+            const __m128i ax = _mm_sign_epi8(bx, bx);
+
+            // Sign the values of the y vectors
+            const __m128i sy = _mm_sign_epi8(by, bx);
+
+            // Perform multiplication and create 16-bit values
+            const __m128i dot = _mm_maddubs_epi16(ax, sy);
+
+            const __m128i ones = _mm_set1_epi16(1);
+            i32[j] = _mm_madd_epi16(ones, dot);
+        }
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
+        // Apply the scale, and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
     sumf = _mm_cvtss_f32( res );
 #elif defined(__wasm_simd128__)
     // wasm simd
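
Per block, the AVX branch above computes d_x * d_y * sum over k of (qx_k - 8)*(qy_k - 8) and accumulates the result in float. A scalar sketch of the same dot product, not part of the diff, assuming the block scales and quantized bytes are passed as flat arrays rather than the x[i].d / x[i].qs struct layout used in ggml:

#include <stdint.h>

static float vec_dot_q4_0_ref(int nb, const float *xd, const uint8_t *xqs,
                              const float *yd, const uint8_t *yqs) {
    float sumf = 0.0f;
    for (int i = 0; i < nb; ++i) {
        int sumi = 0;
        for (int k = 0; k < 16; ++k) {
            // unpack both nibbles and undo the +8 offset before multiplying
            const int x0 = (xqs[16*i + k] & 0x0F) - 8;
            const int x1 = (xqs[16*i + k] >> 4)   - 8;
            const int y0 = (yqs[16*i + k] & 0x0F) - 8;
            const int y1 = (yqs[16*i + k] >> 4)   - 8;
            sumi += x0*y0 + x1*y1;
        }
        sumf += xd[i] * yd[i] * (float)sumi;  // apply the combined block scale
    }
    return sumf;
}
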