
Commit 5e66b6b

Small optimisations to q4_1 dot product (@Const-me)
1 parent f765cce commit 5e66b6b

1 file changed: 20 additions (+), 11 deletions (-)


ggml.c

Lines changed: 20 additions & 11 deletions
@@ -1582,12 +1582,19 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
         const uint8_t * restrict p0 = pb0 + i*QK/2;
         const uint8_t * restrict p1 = pb1 + i*QK/2;
 
+        const __m256 d0v = _mm256_broadcast_ss( d0 );
+        const __m256 d1v = _mm256_broadcast_ss( d1 );
+        const __m256 m0v = _mm256_broadcast_ss( m0 );
+        const __m256 m1v = _mm256_broadcast_ss( m1 );
+
+
         // Compute combined scale for the block
-        const __m256 scale_01 = _mm256_mul_ps( _mm256_broadcast_ss( d0 ), _mm256_broadcast_ss( d1 ) );
+        const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
 
         // Compute cross scales for the block
-        const __m256 scale_0 = _mm256_mul_ps( _mm256_broadcast_ss( d0 ), _mm256_broadcast_ss( m1 ) );
-        const __m256 scale_1 = _mm256_mul_ps( _mm256_broadcast_ss( m0 ), _mm256_broadcast_ss( d1 ) );
+        const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
+        const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
+        const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0b10101010 );
 
         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         __m256i bx = bytesFromNibbles( p0 );
@@ -1608,20 +1615,22 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
         i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
 
         // compute sums of unsigned bytes in bx, by in blocks of 8.
-        // This results in a layout like S100 0000 S200 0000 S300 0000 S400 0000,
-        // so if we then cast to 8 singles, we get 8 floats like [ s0_7, 0.0, s8_15, 0.0, s16_23, 0.0, s24_31, 0.0 ]
-        __m256 xsum = _mm256_cvtepi32_ps( _mm256_sad_epu8( bx, _mm256_setzero_si256() ) );
-        __m256 ysum = _mm256_cvtepi32_ps( _mm256_sad_epu8( by, _mm256_setzero_si256() ) );
+        // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
+        // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
+        // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
+        __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
+        __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
+        __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
+        __m256  sums  = _mm256_cvtepi32_ps( sumsi );
 
         // Convert int32_t to float
         __m256 p = _mm256_cvtepi32_ps( i32 );
         // Apply the scale, and accumulate
         // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
         acc = _mm256_fmadd_ps( scale_01, p, acc );
-        acc = _mm256_fmadd_ps( scale_0, xsum, acc );
-        acc = _mm256_fmadd_ps( scale_1, ysum, acc );
+        acc = _mm256_fmadd_ps( cross_scales, sums, acc );
         // acc_offset += m0*m1 (for each entry in the block)
-        acc_offset += (*m0)*(*m1)*QK;
+        acc_offset += (*m0)*(*m1);
     }
 
     // Return horizontal sum of the acc vector
@@ -1630,7 +1639,7 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
     res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
-    sumf = _mm_cvtss_f32( res ) + acc_offset;
+    sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
 #else
 #error "not implemented for QK"
 #endif
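For context, the comments in the diff spell out what each block contributes to the dot product: acc += d0*d1*x*y + d0*m1*x + d1*m0*y per element, plus an m0*m1 offset that is now multiplied by QK once at the end rather than inside the loop. Below is a minimal scalar sketch of that per-block math, assuming the block holds QK already-unpacked 4-bit quants and the usual (d, m) scale/offset pairs; the function name and signature are illustrative only, not part of ggml.c.

// Illustrative scalar reference (not from the commit): what the AVX2 loop accumulates
// for one q4_1 block, following the comments in the diff. x[] and y[] stand for the
// QK unpacked 4-bit quants of the two blocks; (d0, m0) and (d1, m1) are their
// scale/offset pairs.
static float q4_1_block_dot_ref(int QK,
                                const uint8_t * x, float d0, float m0,
                                const uint8_t * y, float d1, float m1) {
    int sum_xy = 0, sum_x = 0, sum_y = 0;
    for (int i = 0; i < QK; i++) {
        sum_xy += x[i] * y[i];
        sum_x  += x[i];
        sum_y  += y[i];
    }
    // acc        += d0*d1*x*y + d0*m1*x + d1*m0*y      (the vectorised part)
    // acc_offset += m0*m1, with the *QK factor applied once after the loop
    //               (the change in the last hunk above)
    return d0*d1*(float)sum_xy
         + d0*m1*(float)sum_x
         + d1*m0*(float)sum_y
         + m0*m1*(float)QK;
}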

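The other change folds two FMAs into one by interleaving the x and y byte sums and blending the two cross-scale vectors to match. The standalone demo below is not part of the commit; the file name, test values, and variable names are made up. It reproduces the _mm256_sad_epu8 / _mm256_slli_si256 / _mm256_or_si256 / _mm256_blend_ps sequence from the diff and prints the resulting lanes, assuming an AVX2+FMA build, e.g. gcc -O2 -mavx2 -mfma demo.c.

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void)
{
    // Made-up input: every x byte is 1, every y byte is 2, so each 8-byte group
    // sums to 8 for x and 16 for y.
    uint8_t xb[32], yb[32];
    for (int i = 0; i < 32; i++) { xb[i] = 1; yb[i] = 2; }

    __m256i bx = _mm256_loadu_si256( (const __m256i *) xb );
    __m256i by = _mm256_loadu_si256( (const __m256i *) yb );

    // Sum unsigned bytes in groups of 8; each 64-bit chunk holds one sum, so the
    // 32-bit lanes look like [ x0_7, 0, x8_15, 0, x16_23, 0, x24_31, 0 ].
    __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
    __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );

    // Shift the y sums up by one 32-bit slot (within each 128-bit half) and OR,
    // giving [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ].
    __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
    __m256  sums  = _mm256_cvtepi32_ps( sumsi );

    // Blend the cross scales the same way the commit does: even lanes get d0*m1
    // (multiplying the x sums), odd lanes get d1*m0 (multiplying the y sums),
    // so a single FMA replaces the two used before this change.
    float d0 = 0.5f, m0 = 1.0f, d1 = 0.25f, m1 = 2.0f;   // made-up block constants
    __m256 scale_0 = _mm256_mul_ps( _mm256_broadcast_ss( &d0 ), _mm256_broadcast_ss( &m1 ) );
    __m256 scale_1 = _mm256_mul_ps( _mm256_broadcast_ss( &m0 ), _mm256_broadcast_ss( &d1 ) );
    __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0b10101010 );
    __m256 acc = _mm256_fmadd_ps( cross_scales, sums, _mm256_setzero_ps() );

    float out[8];
    _mm256_storeu_ps( out, acc );
    for (int i = 0; i < 8; i++) printf( "%g ", out[i] );   // expect: 8 4 8 4 8 4 8 4
    printf( "\n" );
    return 0;
}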