@@ -79,21 +79,21 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
79
79
__m256 accum256_1 = _mm256_setzero_ps ();
80
80
int tail_index_32 = n & (~31 );
81
81
for (int j = 0 ; j < tail_index_32 ; j += 32 ) {
82
- accum256 = _mm256_dpbf16_ps (accum256 , (__m256bh ) _mm256_loadu_si256 (& x [j + 0 ]), (__m256bh ) _mm256_loadu_si256 (& y [j + 0 ]));
83
- accum256_1 = _mm256_dpbf16_ps (accum256_1 , (__m256bh ) _mm256_loadu_si256 (& x [j + 16 ]), (__m256bh ) _mm256_loadu_si256 (& y [j + 16 ]));
82
+ accum256 = _mm256_dpbf16_ps (accum256 , (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & x [j + 0 ]), (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & y [j + 0 ]));
83
+ accum256_1 = _mm256_dpbf16_ps (accum256_1 , (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & x [j + 16 ]), (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & y [j + 16 ]));
84
84
}
85
85
accum256 = _mm256_add_ps (accum256 , accum256_1 );
86
86
87
87
/* Processing the remaining <32 chunk with 16-elements processing */
88
88
if ((n & 16 ) != 0 ) {
89
- accum256 = _mm256_dpbf16_ps (accum256 , (__m256bh ) _mm256_loadu_si256 (& x [tail_index_32 ]), (__m256bh ) _mm256_loadu_si256 (& y [tail_index_32 ]));
89
+ accum256 = _mm256_dpbf16_ps (accum256 , (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & x [tail_index_32 ]), (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & y [tail_index_32 ]));
90
90
}
91
91
accum128 = _mm_add_ps (_mm256_castps256_ps128 (accum256 ), _mm256_extractf128_ps (accum256 , 1 ));
92
92
93
93
/* Processing the remaining <16 chunk with 8-elements processing */
94
94
if ((n & 8 ) != 0 ) {
95
95
int tail_index_16 = n & (~15 );
96
- accum128 = _mm_dpbf16_ps (accum128 , (__m128bh ) _mm_loadu_si128 (& x [tail_index_16 ]), (__m128bh ) _mm_loadu_si128 (& y [tail_index_16 ]));
96
+ accum128 = _mm_dpbf16_ps (accum128 , (__m128bh ) _mm_loadu_si128 (( __m128i * ) & x [tail_index_16 ]), (__m128bh ) _mm_loadu_si128 (( __m128i * ) & y [tail_index_16 ]));
97
97
}
98
98
99
99
/* Processing the remaining <8 chunk with masked 8-elements processing */
@@ -108,13 +108,13 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
108
108
} else if (n > 15 ) { /* n range from 16 to 31 */
109
109
/* Processing <32 chunk with 16-elements processing */
110
110
__m256 accum256 = _mm256_setzero_ps ();
111
- accum256 = _mm256_dpbf16_ps (accum256 , (__m256bh ) _mm256_loadu_si256 (& x [0 ]), (__m256bh ) _mm256_loadu_si256 (& y [0 ]));
111
+ accum256 = _mm256_dpbf16_ps (accum256 , (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & x [0 ]), (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & y [0 ]));
112
112
accum128 += _mm_add_ps (_mm256_castps256_ps128 (accum256 ), _mm256_extractf128_ps (accum256 , 1 ));
113
113
114
114
/* Processing the remaining <16 chunk with 8-elements processing */
115
115
if ((n & 8 ) != 0 ) {
116
116
int tail_index_16 = n & (~15 );
117
- accum128 = _mm_dpbf16_ps (accum128 , (__m128bh ) _mm_loadu_si128 (& x [tail_index_16 ]), (__m128bh ) _mm_loadu_si128 (& y [tail_index_16 ]));
117
+ accum128 = _mm_dpbf16_ps (accum128 , (__m128bh ) _mm_loadu_si128 (( __m128i * ) & x [tail_index_16 ]), (__m128bh ) _mm_loadu_si128 (( __m128i * ) & y [tail_index_16 ]));
118
118
}
119
119
120
120
/* Processing the remaining <8 chunk with masked 8-elements processing */
@@ -128,7 +128,7 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
128
128
}
129
129
} else if (n > 7 ) { /* n range from 8 to 15 */
130
130
/* Processing <16 chunk with 8-elements processing */
131
- accum128 = _mm_dpbf16_ps (accum128 , (__m128bh ) _mm_loadu_si128 (& x [0 ]), (__m128bh ) _mm_loadu_si128 (& y [0 ]));
131
+ accum128 = _mm_dpbf16_ps (accum128 , (__m128bh ) _mm_loadu_si128 (( __m128i * ) & x [0 ]), (__m128bh ) _mm_loadu_si128 (( __m128i * ) & y [0 ]));
132
132
133
133
/* Processing the remaining <8 chunk with masked 8-elements processing */
134
134
if ((n & 7 ) != 0 ) {
0 commit comments