Skip to content

Commit 02c5b27

Browse files
authored
Add AVX acceleration (#617)
* ggml : add AVX quantize_row_q4_0()
* ggml : add AVX ggml_vec_dot_q4_0()
* ggml : refactor AVX part of ggml_vec_dot_q4_0() — see #617 (comment)
1 parent cbef542 commit 02c5b27

File tree

1 file changed

+153
-0
lines changed

1 file changed

+153
-0
lines changed

ggml.c

Lines changed: 153 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -461,6 +461,39 @@ static inline __m128i packNibbles( __m256i bytes )
461461
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
462462
return _mm_packus_epi16( r0, r1 );
463463
}
464+
#elif __AVX__
465+
static inline __m128i bytesFromNibbles( const uint8_t* rsi )
466+
{
467+
// Load 8 bytes from memory
468+
__m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
469+
470+
// Expand bytes into uint16_t values
471+
__m128i bytes = _mm_cvtepu8_epi16( tmp );
472+
473+
// Unpack values into individual bytes
474+
const __m128i lowMask = _mm_set1_epi8( 0xF );
475+
__m128i high = _mm_andnot_si128( lowMask, bytes );
476+
__m128i low = _mm_and_si128( lowMask, bytes );
477+
high = _mm_slli_epi16( high, 4 );
478+
bytes = _mm_or_si128( low, high );
479+
return bytes;
480+
}
481+
482+
// Pack two vectors of 16-bit nibble lanes (0000_abcd_0000_efgh) into a
// single vector of 16 bytes (abcd_efgh each), bytes1 filling the low
// half of the result and bytes2 the high half.
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
{
    // Within every 16-bit lane, fold 0000_abcd_0000_efgh down to
    // 0000_0000_abcd_efgh by shifting the high byte's nibble into
    // bits 4..7 and OR-ing it with the low byte's nibble.
    const __m128i lowByte = _mm_set1_epi16( 0xFF );

    const __m128i lo1 = _mm_and_si128( lowByte, bytes1 );
    const __m128i hi1 = _mm_srli_epi16( _mm_andnot_si128( lowByte, bytes1 ), 4 );
    const __m128i lo2 = _mm_and_si128( lowByte, bytes2 );
    const __m128i hi2 = _mm_srli_epi16( _mm_andnot_si128( lowByte, bytes2 ), 4 );

    // Each lane now fits in one byte; squeeze both vectors together
    // with unsigned saturation.
    return _mm_packus_epi16( _mm_or_si128( lo1, hi1 ), _mm_or_si128( lo2, hi2 ) );
}
464497
#endif
465498

466499
// method 5
@@ -660,6 +693,80 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
660693
__m128i res = packNibbles( i0 );
661694
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
662695
}
696+
#elif defined(__AVX__)
697+
for (int i = 0; i < nb; i++) {
698+
// Load elements into 4 AVX vectors
699+
__m256 v0 = _mm256_loadu_ps( x );
700+
__m256 v1 = _mm256_loadu_ps( x + 8 );
701+
__m256 v2 = _mm256_loadu_ps( x + 16 );
702+
__m256 v3 = _mm256_loadu_ps( x + 24 );
703+
x += 32;
704+
705+
// Compute max(abs(e)) for the block
706+
const __m256 signBit = _mm256_set1_ps( -0.0f );
707+
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
708+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
709+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
710+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
711+
712+
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
713+
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
714+
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
715+
const float maxScalar = _mm_cvtss_f32( max4 );
716+
717+
// Quantize these floats
718+
const float d = maxScalar / 7.0f;
719+
y[i].d = d;
720+
const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
721+
const __m256 mul = _mm256_set1_ps( id );
722+
723+
// Apply the multiplier
724+
v0 = _mm256_mul_ps( v0, mul );
725+
v1 = _mm256_mul_ps( v1, mul );
726+
v2 = _mm256_mul_ps( v2, mul );
727+
v3 = _mm256_mul_ps( v3, mul );
728+
729+
// Round to nearest integer
730+
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
731+
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
732+
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
733+
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
734+
735+
// Convert floats to integers
736+
__m256i i0 = _mm256_cvtps_epi32( v0 );
737+
__m256i i1 = _mm256_cvtps_epi32( v1 );
738+
__m256i i2 = _mm256_cvtps_epi32( v2 );
739+
__m256i i3 = _mm256_cvtps_epi32( v3 );
740+
741+
// Since we don't have in AVX some necessary functions,
742+
// we split the registers in half and call AVX2 analogs from SSE
743+
__m128i ni0 = _mm256_castsi256_si128( i0 );
744+
__m128i ni1 = _mm256_extractf128_si256( i0, 1);
745+
__m128i ni2 = _mm256_castsi256_si128( i1 );
746+
__m128i ni3 = _mm256_extractf128_si256( i1, 1);
747+
__m128i ni4 = _mm256_castsi256_si128( i2 );
748+
__m128i ni5 = _mm256_extractf128_si256( i2, 1);
749+
__m128i ni6 = _mm256_castsi256_si128( i3 );
750+
__m128i ni7 = _mm256_extractf128_si256( i3, 1);
751+
752+
// Convert int32 to int16
753+
ni0 = _mm_packs_epi32( ni0, ni1 );
754+
ni2 = _mm_packs_epi32( ni2, ni3 );
755+
ni4 = _mm_packs_epi32( ni4, ni5 );
756+
ni6 = _mm_packs_epi32( ni6, ni7 );
757+
// Convert int16 to int8
758+
ni0 = _mm_packs_epi16( ni0, ni2 );
759+
ni4 = _mm_packs_epi16( ni4, ni6 );
760+
761+
// Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
762+
const __m128i off = _mm_set1_epi8( 8);
763+
ni0 = _mm_add_epi8( ni0, off );
764+
ni4 = _mm_add_epi8( ni4, off );
765+
766+
// Compress the vector into 4 bit/value, and store
767+
__m128i res = packNibbles( ni0, ni4 );
768+
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
769+
}
663770
#elif defined(__wasm_simd128__)
664771
for (int i = 0; i < nb; i++) {
665772
float amax = 0.0f; // absolute max
@@ -1892,6 +1999,52 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
18921999
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
18932000
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
18942001

2002+
sumf = _mm_cvtss_f32( res );
2003+
#elif defined(__AVX__)
2004+
// Initialize accumulator with zeros
2005+
__m256 acc = _mm256_setzero_ps();
2006+
2007+
// Main loop
2008+
for (int i = 0; i < nb; ++i) {
2009+
// Compute combined scale for the block
2010+
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2011+
2012+
__m128i i32[2];
2013+
for (int j = 0; j < 2; ++j) {
2014+
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2015+
__m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2016+
__m128i by = bytesFromNibbles( y[i].qs + 8*j );
2017+
2018+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2019+
const __m128i off = _mm_set1_epi8( 8 );
2020+
bx = _mm_sub_epi8( bx, off );
2021+
by = _mm_sub_epi8( by, off );
2022+
2023+
// Get absolute values of x vectors
2024+
const __m128i ax = _mm_sign_epi8(bx, bx);
2025+
2026+
// Sign the values of the y vectors
2027+
const __m128i sy = _mm_sign_epi8(by, bx);
2028+
2029+
// Perform multiplication and create 16-bit values
2030+
const __m128i dot = _mm_maddubs_epi16(ax, sy);
2031+
2032+
const __m128i ones = _mm_set1_epi16(1);
2033+
i32[j] = _mm_madd_epi16(ones, dot);
2034+
}
2035+
2036+
// Convert int32_t to float
2037+
__m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
2038+
// Apply the scale, and accumulate
2039+
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2040+
}
2041+
2042+
// Return horizontal sum of the acc vector
2043+
__m128 res = _mm256_extractf128_ps( acc, 1 );
2044+
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2045+
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2046+
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2047+
18952048
sumf = _mm_cvtss_f32( res );
18962049
#elif defined(__wasm_simd128__)
18972050
// wasm simd

0 commit comments

Comments (0)