@@ -1582,12 +1582,19 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
1582
1582
const uint8_t * restrict p0 = pb0 + i * QK /2 ;
1583
1583
const uint8_t * restrict p1 = pb1 + i * QK /2 ;
1584
1584
1585
+ const __m256 d0v = _mm256_broadcast_ss ( d0 );
1586
+ const __m256 d1v = _mm256_broadcast_ss ( d1 );
1587
+ const __m256 m0v = _mm256_broadcast_ss ( m0 );
1588
+ const __m256 m1v = _mm256_broadcast_ss ( m1 );
1589
+
1590
+
1585
1591
// Compute combined scale for the block
1586
- const __m256 scale_01 = _mm256_mul_ps ( _mm256_broadcast_ss ( d0 ), _mm256_broadcast_ss ( d1 ) );
1592
+ const __m256 scale_01 = _mm256_mul_ps ( d0v , d1v );
1587
1593
1588
1594
// Compute cross scales for the block
1589
- const __m256 scale_0 = _mm256_mul_ps ( _mm256_broadcast_ss ( d0 ), _mm256_broadcast_ss ( m1 ) );
1590
- const __m256 scale_1 = _mm256_mul_ps ( _mm256_broadcast_ss ( m0 ), _mm256_broadcast_ss ( d1 ) );
1595
+ const __m256 scale_0 = _mm256_mul_ps ( d0v , m1v );
1596
+ const __m256 scale_1 = _mm256_mul_ps ( m0v , d1v );
1597
+ const __m256 cross_scales = _mm256_blend_ps ( scale_0 , scale_1 , 0b10101010 );
1591
1598
1592
1599
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
1593
1600
__m256i bx = bytesFromNibbles ( p0 );
@@ -1608,20 +1615,22 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
1608
1615
i32 = _mm256_add_epi32 ( i32 , _mm256_madd_epi16 ( x16_h , y16_h ) );
1609
1616
1610
1617
// compute sums of unsigned bytes in bx, by in blocks of 8.
1611
- // This results in a layout like S100 0000 S200 0000 S300 0000 S400 0000,
1612
- // so if we then cast to 8 singles, we get 8 floats like [ s0_7, 0.0, s8_15, 0.0, s16_23, 0.0, s24_31, 0.0 ]
1613
- __m256 xsum = _mm256_cvtepi32_ps ( _mm256_sad_epu8 ( bx , _mm256_setzero_si256 () ) );
1614
- __m256 ysum = _mm256_cvtepi32_ps ( _mm256_sad_epu8 ( by , _mm256_setzero_si256 () ) );
1618
+ // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
1619
+ // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
1620
+ // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
1621
+ __m256i xsumi = _mm256_sad_epu8 ( bx , _mm256_setzero_si256 () );
1622
+ __m256i ysumi = _mm256_sad_epu8 ( by , _mm256_setzero_si256 () );
1623
+ __m256i sumsi = _mm256_or_si256 ( xsumi , _mm256_slli_si256 ( ysumi , 4 ) );
1624
+ __m256 sums = _mm256_cvtepi32_ps ( sumsi );
1615
1625
1616
1626
// Convert int32_t to float
1617
1627
__m256 p = _mm256_cvtepi32_ps ( i32 );
1618
1628
// Apply the scale, and accumulate
1619
1629
// acc += d0*d1*x*y + d0*m1*x + d1*m0*y
1620
1630
acc = _mm256_fmadd_ps ( scale_01 , p , acc );
1621
- acc = _mm256_fmadd_ps ( scale_0 , xsum , acc );
1622
- acc = _mm256_fmadd_ps ( scale_1 , ysum , acc );
1631
+ acc = _mm256_fmadd_ps ( cross_scales , sums , acc );
1623
1632
// acc_offset += m0*m1 (for each entry in the block)
1624
- acc_offset += (* m0 )* (* m1 )* QK ;
1633
+ acc_offset += (* m0 )* (* m1 );
1625
1634
}
1626
1635
1627
1636
// Return horizontal sum of the acc vector
@@ -1630,7 +1639,7 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
1630
1639
res = _mm_add_ps ( res , _mm_movehl_ps ( res , res ) );
1631
1640
res = _mm_add_ss ( res , _mm_movehdup_ps ( res ) );
1632
1641
1633
- sumf = _mm_cvtss_f32 ( res ) + acc_offset ;
1642
+ sumf = _mm_cvtss_f32 ( res ) + acc_offset * QK ;
1634
1643
#else
1635
1644
#error "not implemented for QK"
1636
1645
#endif
0 commit comments