Skip to content

Commit 2405550

Browse files
authored
Merge pull request #3389 from guowangy/bf16-build-warn-fix
x86_64: BFLOAT16: fix build warning
2 parents 9f52abf + ee5ca8a commit 2405550

File tree

5 files changed

+100
-54
lines changed

5 files changed

+100
-54
lines changed

kernel/x86_64/bf16_common_macros.h

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,25 +56,25 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
5656

5757

5858
#define BF16_MATRIX_LOAD_8x16(regArray, a, lda, idx_m, idx_n) \
59-
regArray##_0 = _mm256_loadu_si256(&a[(idx_m+0)*lda + idx_n]); \
60-
regArray##_1 = _mm256_loadu_si256(&a[(idx_m+1)*lda + idx_n]); \
61-
regArray##_2 = _mm256_loadu_si256(&a[(idx_m+2)*lda + idx_n]); \
62-
regArray##_3 = _mm256_loadu_si256(&a[(idx_m+3)*lda + idx_n]); \
63-
regArray##_4 = _mm256_loadu_si256(&a[(idx_m+4)*lda + idx_n]); \
64-
regArray##_5 = _mm256_loadu_si256(&a[(idx_m+5)*lda + idx_n]); \
65-
regArray##_6 = _mm256_loadu_si256(&a[(idx_m+6)*lda + idx_n]); \
66-
regArray##_7 = _mm256_loadu_si256(&a[(idx_m+7)*lda + idx_n]);
59+
regArray##_0 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+0)*lda + idx_n])); \
60+
regArray##_1 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+1)*lda + idx_n])); \
61+
regArray##_2 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+2)*lda + idx_n])); \
62+
regArray##_3 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+3)*lda + idx_n])); \
63+
regArray##_4 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+4)*lda + idx_n])); \
64+
regArray##_5 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+5)*lda + idx_n])); \
65+
regArray##_6 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+6)*lda + idx_n])); \
66+
regArray##_7 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+7)*lda + idx_n]));
6767

6868

6969
#define BF16_MATRIX_LOAD_8x8(regArray, a, lda, idx_m, idx_n) \
70-
regArray##_0 = _mm_loadu_si128(&a[(idx_m+0)*lda + idx_n]); \
71-
regArray##_1 = _mm_loadu_si128(&a[(idx_m+1)*lda + idx_n]); \
72-
regArray##_2 = _mm_loadu_si128(&a[(idx_m+2)*lda + idx_n]); \
73-
regArray##_3 = _mm_loadu_si128(&a[(idx_m+3)*lda + idx_n]); \
74-
regArray##_4 = _mm_loadu_si128(&a[(idx_m+4)*lda + idx_n]); \
75-
regArray##_5 = _mm_loadu_si128(&a[(idx_m+5)*lda + idx_n]); \
76-
regArray##_6 = _mm_loadu_si128(&a[(idx_m+6)*lda + idx_n]); \
77-
regArray##_7 = _mm_loadu_si128(&a[(idx_m+7)*lda + idx_n]);
70+
regArray##_0 = _mm_loadu_si128((__m128i *)(&a[(idx_m+0)*lda + idx_n])); \
71+
regArray##_1 = _mm_loadu_si128((__m128i *)(&a[(idx_m+1)*lda + idx_n])); \
72+
regArray##_2 = _mm_loadu_si128((__m128i *)(&a[(idx_m+2)*lda + idx_n])); \
73+
regArray##_3 = _mm_loadu_si128((__m128i *)(&a[(idx_m+3)*lda + idx_n])); \
74+
regArray##_4 = _mm_loadu_si128((__m128i *)(&a[(idx_m+4)*lda + idx_n])); \
75+
regArray##_5 = _mm_loadu_si128((__m128i *)(&a[(idx_m+5)*lda + idx_n])); \
76+
regArray##_6 = _mm_loadu_si128((__m128i *)(&a[(idx_m+6)*lda + idx_n])); \
77+
regArray##_7 = _mm_loadu_si128((__m128i *)(&a[(idx_m+7)*lda + idx_n]));
7878

7979

8080
#define BF16_MATRIX_LOAD_1x32(regArray, a, lda, idx_m, idx_n) \
@@ -153,11 +153,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
153153

154154

155155
#define BF16_VECTOR_LOAD_1x16(reg, x, idx_n) \
156-
reg = _mm256_loadu_si256(x + idx_n);
156+
reg = _mm256_loadu_si256((__m256i *)(x + idx_n));
157157

158158

159159
#define BF16_VECTOR_LOAD_1x8(reg, x, idx_n) \
160-
reg = _mm_loadu_si128(x + idx_n);
160+
reg = _mm_loadu_si128((__m128i *)(x + idx_n));
161161

162162

163163
#define BF16_VECTOR_MASKZ_LOAD_1x32(reg, x, idx_n, mask) \

kernel/x86_64/sbdot_microk_cooperlake.c

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,21 +79,21 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
7979
__m256 accum256_1 = _mm256_setzero_ps();
8080
int tail_index_32 = n&(~31);
8181
for (int j = 0; j < tail_index_32; j += 32) {
82-
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[j+ 0]), (__m256bh) _mm256_loadu_si256(&y[j+ 0]));
83-
accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256(&x[j+16]), (__m256bh) _mm256_loadu_si256(&y[j+16]));
82+
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((__m256i *)&x[j+ 0]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[j+ 0]));
83+
accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256((__m256i *)&x[j+16]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[j+16]));
8484
}
8585
accum256 = _mm256_add_ps(accum256, accum256_1);
8686

8787
/* Processing the remaining <32 chunk with 16-elements processing */
8888
if ((n&16) != 0) {
89-
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[tail_index_32]), (__m256bh) _mm256_loadu_si256(&y[tail_index_32]));
89+
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((__m256i *)&x[tail_index_32]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[tail_index_32]));
9090
}
9191
accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
9292

9393
/* Processing the remaining <16 chunk with 8-elements processing */
9494
if ((n&8) != 0) {
9595
int tail_index_16 = n&(~15);
96-
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16]));
96+
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128((__m128i *)&x[tail_index_16]), (__m128bh) _mm_loadu_si128((__m128i *)&y[tail_index_16]));
9797
}
9898

9999
/* Processing the remaining <8 chunk with masked 8-elements processing */
@@ -108,13 +108,13 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
108108
} else if (n > 15) { /* n range from 16 to 31 */
109109
/* Processing <32 chunk with 16-elements processing */
110110
__m256 accum256 = _mm256_setzero_ps();
111-
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[0]), (__m256bh) _mm256_loadu_si256(&y[0]));
111+
accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((__m256i *)&x[0]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[0]));
112112
accum128 += _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
113113

114114
/* Processing the remaining <16 chunk with 8-elements processing */
115115
if ((n&8) != 0) {
116116
int tail_index_16 = n&(~15);
117-
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16]));
117+
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128((__m128i *)&x[tail_index_16]), (__m128bh) _mm_loadu_si128((__m128i *)&y[tail_index_16]));
118118
}
119119

120120
/* Processing the remaining <8 chunk with masked 8-elements processing */
@@ -128,7 +128,7 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
128128
}
129129
} else if (n > 7) { /* n range from 8 to 15 */
130130
/* Processing <16 chunk with 8-elements processing */
131-
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[0]), (__m128bh) _mm_loadu_si128(&y[0]));
131+
accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128((__m128i *)&x[0]), (__m128bh) _mm_loadu_si128((__m128i *)&y[0]));
132132

133133
/* Processing the remaining <8 chunk with masked 8-elements processing */
134134
if ((n&7) != 0) {

kernel/x86_64/sbgemm_block_microk_cooperlake.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1246,7 +1246,7 @@ void COL_MAJOR_ITCOPY_KERNEL_Kx16(BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat
12461246
// K=Any number but will be processed based on 32, M<=16
12471247
void COL_MAJOR_ITCOPY_KERNEL_Kx16m(BLASLONG m, BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
12481248
{
1249-
bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
1249+
bfloat16 * src_addr0;
12501250
bfloat16 * dst_addr0, * dst_addr1;
12511251

12521252
BLASLONG tag_k_32x = k & (~31);

kernel/x86_64/sbgemv_n_microk_cooperlake_template.c

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3030
// Include common macros for BF16 based operations with IA intrinsics
3131
#include "bf16_common_macros.h"
3232

33+
#undef STORE16_COMPLETE_RESULT
34+
#undef STORE16_MASK_COMPLETE_RESULT
35+
#undef STORE8_COMPLETE_RESULT
36+
#undef STORE8_MASK_COMPLETE_RESULT
37+
#undef STORE4_COMPLETE_RESULT
38+
#undef STORE4_MASK_COMPLETE_RESULT
39+
3340
#ifndef ZERO_BETA // Beta is non-zero
3441

3542
#ifndef ONE_BETA // BETA is not ONE
@@ -103,7 +110,9 @@ static int sbgemv_kernel_32xN_lda_direct(BLASLONG m, BLASLONG n, float alpha, bf
103110
__m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
104111
#endif
105112
#ifndef ZERO_BETA
113+
#ifndef ONE_BETA
106114
__m512 BETAVECTOR = _mm512_set1_ps(beta);
115+
#endif
107116
#endif
108117

109118
__m512i matrixArray_seed_0, matrixArray_seed_1, matrixArray_seed_2, matrixArray_seed_3;
@@ -202,7 +211,7 @@ static int sbgemv_kernel_32xN_lda_direct(BLASLONG m, BLASLONG n, float alpha, bf
202211
unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(m&31)));
203212
__mmask32 tail_mask = *((__mmask32*) &tail_mask_value);
204213

205-
unsigned short store_tail_mask_value = (((unsigned int)0xffff) >> (16-(m&15)));
214+
unsigned int store_tail_mask_value = (((unsigned int)0xffff) >> (16-(m&15)));
206215
__mmask32 store_tail_mask = *((__mmask32*) &store_tail_mask_value);
207216

208217
accum512_0 = _mm512_setzero_ps();

0 commit comments

Comments
 (0)