diff --git a/kernel/x86_64/bf16_common_macros.h b/kernel/x86_64/bf16_common_macros.h index 78db7abb2a..cdb4beff6e 100644 --- a/kernel/x86_64/bf16_common_macros.h +++ b/kernel/x86_64/bf16_common_macros.h @@ -56,25 +56,25 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define BF16_MATRIX_LOAD_8x16(regArray, a, lda, idx_m, idx_n) \ - regArray##_0 = _mm256_loadu_si256(&a[(idx_m+0)*lda + idx_n]); \ - regArray##_1 = _mm256_loadu_si256(&a[(idx_m+1)*lda + idx_n]); \ - regArray##_2 = _mm256_loadu_si256(&a[(idx_m+2)*lda + idx_n]); \ - regArray##_3 = _mm256_loadu_si256(&a[(idx_m+3)*lda + idx_n]); \ - regArray##_4 = _mm256_loadu_si256(&a[(idx_m+4)*lda + idx_n]); \ - regArray##_5 = _mm256_loadu_si256(&a[(idx_m+5)*lda + idx_n]); \ - regArray##_6 = _mm256_loadu_si256(&a[(idx_m+6)*lda + idx_n]); \ - regArray##_7 = _mm256_loadu_si256(&a[(idx_m+7)*lda + idx_n]); + regArray##_0 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+0)*lda + idx_n])); \ + regArray##_1 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+1)*lda + idx_n])); \ + regArray##_2 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+2)*lda + idx_n])); \ + regArray##_3 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+3)*lda + idx_n])); \ + regArray##_4 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+4)*lda + idx_n])); \ + regArray##_5 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+5)*lda + idx_n])); \ + regArray##_6 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+6)*lda + idx_n])); \ + regArray##_7 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+7)*lda + idx_n])); #define BF16_MATRIX_LOAD_8x8(regArray, a, lda, idx_m, idx_n) \ - regArray##_0 = _mm_loadu_si128(&a[(idx_m+0)*lda + idx_n]); \ - regArray##_1 = _mm_loadu_si128(&a[(idx_m+1)*lda + idx_n]); \ - regArray##_2 = _mm_loadu_si128(&a[(idx_m+2)*lda + idx_n]); \ - regArray##_3 = _mm_loadu_si128(&a[(idx_m+3)*lda + idx_n]); \ - regArray##_4 = _mm_loadu_si128(&a[(idx_m+4)*lda + idx_n]); \ - regArray##_5 = _mm_loadu_si128(&a[(idx_m+5)*lda + idx_n]); \ - regArray##_6 = _mm_loadu_si128(&a[(idx_m+6)*lda + idx_n]); \ - regArray##_7 = _mm_loadu_si128(&a[(idx_m+7)*lda + idx_n]); + regArray##_0 = _mm_loadu_si128((__m128i *)(&a[(idx_m+0)*lda + idx_n])); \ + regArray##_1 = _mm_loadu_si128((__m128i *)(&a[(idx_m+1)*lda + idx_n])); \ + regArray##_2 = _mm_loadu_si128((__m128i *)(&a[(idx_m+2)*lda + idx_n])); \ + regArray##_3 = _mm_loadu_si128((__m128i *)(&a[(idx_m+3)*lda + idx_n])); \ + regArray##_4 = _mm_loadu_si128((__m128i *)(&a[(idx_m+4)*lda + idx_n])); \ + regArray##_5 = _mm_loadu_si128((__m128i *)(&a[(idx_m+5)*lda + idx_n])); \ + regArray##_6 = _mm_loadu_si128((__m128i *)(&a[(idx_m+6)*lda + idx_n])); \ + regArray##_7 = _mm_loadu_si128((__m128i *)(&a[(idx_m+7)*lda + idx_n])); #define BF16_MATRIX_LOAD_1x32(regArray, a, lda, idx_m, idx_n) \ @@ -153,11 +153,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define BF16_VECTOR_LOAD_1x16(reg, x, idx_n) \ - reg = _mm256_loadu_si256(x + idx_n); + reg = _mm256_loadu_si256((__m256i *)(x + idx_n)); #define BF16_VECTOR_LOAD_1x8(reg, x, idx_n) \ - reg = _mm_loadu_si128(x + idx_n); + reg = _mm_loadu_si128((__m128i *)(x + idx_n)); #define BF16_VECTOR_MASKZ_LOAD_1x32(reg, x, idx_n, mask) \ diff --git a/kernel/x86_64/sbdot_microk_cooperlake.c b/kernel/x86_64/sbdot_microk_cooperlake.c index 067726cb1c..2aefe46ffb 100644 --- a/kernel/x86_64/sbdot_microk_cooperlake.c +++ b/kernel/x86_64/sbdot_microk_cooperlake.c @@ -79,21 +79,21 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y) __m256 accum256_1 = _mm256_setzero_ps(); int tail_index_32 = n&(~31); for (int j = 0; j < tail_index_32; j += 32) { - accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[j+ 0]), (__m256bh) _mm256_loadu_si256(&y[j+ 0])); - accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256(&x[j+16]), (__m256bh) _mm256_loadu_si256(&y[j+16])); + accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((__m256i *)&x[j+ 0]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[j+ 0])); + accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256((__m256i *)&x[j+16]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[j+16])); } accum256 = _mm256_add_ps(accum256, accum256_1); /* Processing the remaining <32 chunk with 16-elements processing */ if ((n&16) != 0) { - accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[tail_index_32]), (__m256bh) _mm256_loadu_si256(&y[tail_index_32])); + accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((__m256i *)&x[tail_index_32]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[tail_index_32])); } accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1)); /* Processing the remaining <16 chunk with 8-elements processing */ if ((n&8) != 0) { int tail_index_16 = n&(~15); - accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16])); + accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128((__m128i *)&x[tail_index_16]), (__m128bh) _mm_loadu_si128((__m128i *)&y[tail_index_16])); } /* Processing the remaining <8 chunk with masked 8-elements processing */ @@ -108,13 +108,13 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y) } else if (n > 15) { /* n range from 16 to 31 */ /* Processing <32 chunk with 16-elements processing */ __m256 accum256 = _mm256_setzero_ps(); - accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[0]), (__m256bh) _mm256_loadu_si256(&y[0])); + accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((__m256i *)&x[0]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[0])); accum128 += _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1)); /* Processing the remaining <16 chunk with 8-elements processing */ if ((n&8) != 0) { int tail_index_16 = n&(~15); - accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16])); + accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128((__m128i *)&x[tail_index_16]), (__m128bh) _mm_loadu_si128((__m128i *)&y[tail_index_16])); } /* Processing the remaining <8 chunk with masked 8-elements processing */ @@ -128,7 +128,7 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y) } } else if (n > 7) { /* n range from 8 to 15 */ /* Processing <16 chunk with 8-elements processing */ - accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[0]), (__m128bh) _mm_loadu_si128(&y[0])); + accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128((__m128i *)&x[0]), (__m128bh) _mm_loadu_si128((__m128i *)&y[0])); /* Processing the remaining <8 chunk with masked 8-elements processing */ if ((n&7) != 0) { diff --git a/kernel/x86_64/sbgemm_block_microk_cooperlake.c b/kernel/x86_64/sbgemm_block_microk_cooperlake.c index 2c27221ac9..b8c41f4f72 100644 --- a/kernel/x86_64/sbgemm_block_microk_cooperlake.c +++ b/kernel/x86_64/sbgemm_block_microk_cooperlake.c @@ -1246,7 +1246,7 @@ void COL_MAJOR_ITCOPY_KERNEL_Kx16(BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat // K=Any number but will be processed based on 32, M<=16 void COL_MAJOR_ITCOPY_KERNEL_Kx16m(BLASLONG m, BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A) { - bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3; + bfloat16 * src_addr0; bfloat16 * dst_addr0, * dst_addr1; BLASLONG tag_k_32x = k & (~31); diff --git a/kernel/x86_64/sbgemv_n_microk_cooperlake_template.c b/kernel/x86_64/sbgemv_n_microk_cooperlake_template.c index 46e6d0ff9a..4711e9720c 100644 --- a/kernel/x86_64/sbgemv_n_microk_cooperlake_template.c +++ b/kernel/x86_64/sbgemv_n_microk_cooperlake_template.c @@ -30,6 +30,13 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // Include common macros for BF16 based operations with IA intrinsics #include "bf16_common_macros.h" +#undef STORE16_COMPLETE_RESULT +#undef STORE16_MASK_COMPLETE_RESULT +#undef STORE8_COMPLETE_RESULT +#undef STORE8_MASK_COMPLETE_RESULT +#undef STORE4_COMPLETE_RESULT +#undef STORE4_MASK_COMPLETE_RESULT + #ifndef ZERO_BETA // Beta is non-zero #ifndef ONE_BETA // BETA is not ONE @@ -103,7 +110,9 @@ static int sbgemv_kernel_32xN_lda_direct(BLASLONG m, BLASLONG n, float alpha, bf __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i matrixArray_seed_0, matrixArray_seed_1, matrixArray_seed_2, matrixArray_seed_3; @@ -202,7 +211,7 @@ static int sbgemv_kernel_32xN_lda_direct(BLASLONG m, BLASLONG n, float alpha, bf unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(m&31))); __mmask32 tail_mask = *((__mmask32*) &tail_mask_value); - unsigned short store_tail_mask_value = (((unsigned int)0xffff) >> (16-(m&15))); + unsigned int store_tail_mask_value = (((unsigned int)0xffff) >> (16-(m&15))); __mmask32 store_tail_mask = *((__mmask32*) &store_tail_mask_value); accum512_0 = _mm512_setzero_ps(); diff --git a/kernel/x86_64/sbgemv_t_microk_cooperlake_template.c b/kernel/x86_64/sbgemv_t_microk_cooperlake_template.c index 51e681add3..8a3a022fb3 100644 --- a/kernel/x86_64/sbgemv_t_microk_cooperlake_template.c +++ b/kernel/x86_64/sbgemv_t_microk_cooperlake_template.c @@ -29,6 +29,13 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // Include common macros for BF16 based operations with IA intrinsics #include "bf16_common_macros.h" +#undef STORE16_COMPLETE_RESULT +#undef STORE16_MASK_COMPLETE_RESULT +#undef STORE8_COMPLETE_RESULT +#undef STORE8_MASK_COMPLETE_RESULT +#undef STORE4_COMPLETE_RESULT +#undef STORE4_MASK_COMPLETE_RESULT + #ifndef ZERO_BETA // Beta is non-zero #ifndef ONE_BETA // BETA is not ONE @@ -231,7 +238,9 @@ static int sbgemv_kernel_32x2(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif unsigned char load_mask_value = (((unsigned char)0xff) >> 6); @@ -280,7 +289,7 @@ static int sbgemv_kernel_32x2(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, } else if (tail_num == 8) { __m256 result256 = _mm256_setzero_ps(); - __m256i matrixArray256 = _mm256_loadu_si256(&a[(tag_m_32x)*2]); // Load 8 rows with n=2 + __m256i matrixArray256 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*2]); // Load 8 rows with n=2 __m256i xArray256 = _mm512_castsi512_si256(xArray); result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) xArray256); @@ -323,7 +332,9 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5); @@ -395,9 +406,9 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, result256_0 = _mm256_setzero_ps(); result256_1 = _mm256_setzero_ps(); - matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element - matrixArray256_1 = _mm256_loadu_si256(&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element - matrixArray256_2 = _mm256_loadu_si256(&a[((tag_m_32x+10)*3 + 2)]); // Load 5 rows with n=3 plus 1 element + matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element + matrixArray256_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element + matrixArray256_2 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+10)*3 + 2)]); // Load 5 rows with n=3 plus 1 element matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2); // Select the first 2 elements for each row @@ -423,8 +434,8 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, if (tail_num > 10) { unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-10-1)*3+1))); __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); - matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element - matrixArray256_1 = _mm256_loadu_si256(&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element + matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element + matrixArray256_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element matrixArray256_2 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+10)*3 + 2)]); // Load m-tag_m_32x-10 rows matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row @@ -439,7 +450,7 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, } else if (tail_num > 5) { unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-5-1)*3+2))); __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); - matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element + matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element matrixArray256_1 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+5)*3+1)]); // Load m-tag_m_32x-5 rows matrixArray256_2 = _mm256_setzero_si256(); @@ -499,7 +510,9 @@ static int sbgemv_kernel_16x4(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_1 = _mm512_set1_epi32(1); @@ -591,7 +604,9 @@ static int sbgemv_kernel_30x5(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512 result_0, result_1; @@ -782,7 +797,9 @@ static int sbgemv_kernel_16x6(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_1 = _mm512_set1_epi32(1); @@ -866,9 +883,9 @@ static int sbgemv_kernel_16x6(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, result256_0 = _mm256_setzero_ps(); - matrixArray_0 = _mm256_loadu_si256(&a[(tag_m_16x)*6]); // Load 2 rows with n=6 plus 4 element - matrixArray_1 = _mm256_loadu_si256(&a[((tag_m_16x+2)*6 + 4)]); // Load 2 rows with n=6 plus 4 element - matrixArray_2 = _mm256_loadu_si256(&a[((tag_m_16x+5)*6 + 2)]); // Load 2 rows with n=6 plus 4 element + matrixArray_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_16x)*6]); // Load 2 rows with n=6 plus 4 element + matrixArray_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_16x+2)*6 + 4)]); // Load 2 rows with n=6 plus 4 element + matrixArray_2 = _mm256_loadu_si256((__m256i *)&a[((tag_m_16x+5)*6 + 2)]); // Load 2 rows with n=6 plus 4 element // Process the 0|1 elements // Select the 0|1 elements for each row @@ -957,7 +974,9 @@ static int sbgemv_kernel_16x7(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_2 = _mm512_set1_epi32(2); @@ -1110,7 +1129,7 @@ static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, { BLASLONG tag_m_16x = m & (~15); - __m128i x128 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7| + __m128i x128 = _mm_loadu_si128((__m128i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7| if (tag_m_16x > 0) { __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3; @@ -1122,7 +1141,9 @@ static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_2 = _mm512_set1_epi32(2); @@ -1214,7 +1235,7 @@ static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m128 result128, tmp128; for (BLASLONG i = tag_m_16x; i < m; i++) { result128 = _mm_setzero_ps(); - matrixArray128 = _mm_loadu_si128(&a[(i)*8]); // Load 1 rows with n=8 + matrixArray128 = _mm_loadu_si128((__m128i *)&a[(i)*8]); // Load 1 rows with n=8 result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128); tmp128 = _mm_shuffle_ps(result128, result128, 14); result128 = _mm_add_ps(result128, tmp128); @@ -1258,7 +1279,7 @@ static int sbgemv_kernel_14x9(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, unsigned char x_load_mask_value = (((unsigned char)0xff) >> 7); __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value); - __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7| + __m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7| __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|0 |0 | 0| 0| 0| 0| 0| if (tag_m_14x > 0) { @@ -1271,7 +1292,9 @@ static int sbgemv_kernel_14x9(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m256i M256_EPI16_2 = _mm256_set1_epi16(2); @@ -1390,7 +1413,7 @@ static int sbgemv_kernel_12x10(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x unsigned char x_load_mask_value = (((unsigned char)0xf) >> 3); __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value); - __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7| + __m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7| __m128i x128_1 = _mm_maskz_loadu_epi32(x_load_mask, (x+8)); // |x8|x9|0 | 0| 0| 0| 0| 0| if (tag_m_12x > 0) { @@ -1403,7 +1426,9 @@ static int sbgemv_kernel_12x10(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m256i M256_EPI32_1 = _mm256_set1_epi32(1); @@ -1522,7 +1547,7 @@ static int sbgemv_kernel_15x11(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5); __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value); - __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1| x2|x3|x4|x5|x6|x7| + __m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1| x2|x3|x4|x5|x6|x7| __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|x9|x10| 0| 0| 0| 0| 0| if (tag_m_15x > 0) { @@ -1535,7 +1560,9 @@ static int sbgemv_kernel_15x11(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, idx_stage1_base_5; @@ -1690,7 +1717,7 @@ static int sbgemv_kernel_15x12(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x unsigned char x_load_mask_value = (((unsigned char)0xff) >> 4); __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value); - __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1| x2| x3|x4|x5|x6|x7| + __m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1| x2| x3|x4|x5|x6|x7| __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|x9|x10|x11| 0| 0| 0| 0| if (tag_m_15x > 0) { @@ -1703,7 +1730,9 @@ static int sbgemv_kernel_15x12(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, idx_stage1_base_5; @@ -1873,16 +1902,15 @@ static int sbgemv_kernel_16x13(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_4 = _mm512_set1_epi32(4); __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0); __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4); - unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 6); - __mmask32 load_mask = *((__mmask32*) &load_mask_value); - // Prepare X with 2-step interleave way xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1); BF16_INTERLEAVE_1x32(xArray) @@ -2045,7 +2073,9 @@ static int sbgemv_kernel_16x14(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_4 = _mm512_set1_epi32(4); @@ -2207,16 +2237,15 @@ static int sbgemv_kernel_16x15(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_4 = _mm512_set1_epi32(4); __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0); __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4); - unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 2); - __mmask32 load_mask = *((__mmask32*) &load_mask_value); - // Prepare X with 2-step interleave way xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1); BF16_INTERLEAVE_1x32(xArray) @@ -2364,7 +2393,7 @@ static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x { BLASLONG tag_m_16x = m & (~15); - __m256i x256 = _mm256_loadu_si256(x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15| + __m256i x256 = _mm256_loadu_si256((__m256i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15| if (tag_m_16x > 0) { __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \ @@ -2377,7 +2406,9 @@ static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_4 = _mm512_set1_epi32(4); @@ -2484,7 +2515,7 @@ static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m128 accum128, tmp128; for (BLASLONG i = tag_m_16x; i < m; i++) { accum256 = _mm256_setzero_ps(); - matrixArray256 = _mm256_loadu_si256(&a[(i)*16]); // Load 1 rows with n=16 + matrixArray256 = _mm256_loadu_si256((__m256i *)&a[(i)*16]); // Load 1 rows with n=16 accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256); accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1)); tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e); @@ -2535,7 +2566,9 @@ static int sbgemv_kernel_8x16p_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \ @@ -2647,8 +2680,6 @@ static int sbgemv_kernel_1x128_lda_direct(BLASLONG m, BLASLONG n, float alpha, b BLASLONG tag_n_32x = n & (~31); BLASLONG tag_n_128x = n & (~127); - __m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7, \ - accum512_8, accum512_9, accum512_10, accum512_11, accum512_12, accum512_13, accum512_14, accum512_15; __m512 accum512_bridge[8]; __m512 accum512_t_0, accum512_t_1, accum512_t_2, accum512_t_3; __m256 accum256_0; @@ -2658,7 +2689,9 @@ static int sbgemv_kernel_1x128_lda_direct(BLASLONG m, BLASLONG n, float alpha, b __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3; @@ -2825,7 +2858,9 @@ static int sbgemv_kernel_8x32_lda_direct(BLASLONG m, BLASLONG n, float alpha, bf __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7; @@ -2961,7 +2996,9 @@ static int sbgemv_kernel_8x16m_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 __m512 ALPHAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(alpha)); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(beta)); +#endif #endif __m256 accum256_0, accum256_1, accum256_2, accum256_3, accum256_4, accum256_5, accum256_6, accum256_7, \ @@ -3012,7 +3049,7 @@ static int sbgemv_kernel_8x16m_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 __m128 accum128, tmp128; for (BLASLONG i = tag_m_8x; i < m; i++) { accum256_0 = _mm256_setzero_ps(); - matrixArray_0 = _mm256_loadu_si256(&a[(i)*lda]); // Load 1 rows with n=16 + matrixArray_0 = _mm256_loadu_si256((__m256i *)&a[(i)*lda]); // Load 1 rows with n=16 accum256_0 = _mm256_dpbf16_ps(accum256_0, (__m256bh) matrixArray_0, (__m256bh) xArray256); accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1)); tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e);