|
1 | 1 | #include "softmax.hpp"
|
2 | 2 |
|
3 |
| -template <typename T> static inline float t2f32(T val) { |
4 |
| - return static_cast<float>(val); |
5 |
| -} |
6 |
| - |
7 |
| -template <> inline float t2f32<sycl::half>(sycl::half val) { |
8 |
| - return static_cast<float>(val); |
9 |
| -} |
10 |
| - |
11 | 3 | template <bool vals_smem, int ncols_template, int block_size_template, typename T>
|
12 | 4 | static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
|
13 | 5 | const int nrows_y, const float scale, const float max_bias, const float m0,
|
@@ -51,7 +43,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
|
51 | 43 | const int ix = rowx*ncols + col;
|
52 | 44 | const int iy = rowy*ncols + col;
|
53 | 45 |
|
54 |
| - const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f); |
| 46 | + const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f); |
55 | 47 |
|
56 | 48 | vals[col] = val;
|
57 | 49 | max_val = sycl::max(max_val, val);
|
@@ -174,60 +166,60 @@ static void soft_max_f32_sycl(const float * x, const T * mask,
|
174 | 166 | const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
|
175 | 167 | if (n_local_scratch*sizeof(float) < local_mem_size) {
|
176 | 168 | if (ncols_x > max_block_size) {
|
177 |
| - soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 169 | + soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale, |
178 | 170 | max_bias, m0, m1, n_head_log2, block_nums,
|
179 | 171 | block_dims, n_local_scratch, stream);
|
180 | 172 | return;
|
181 | 173 | }
|
182 | 174 | switch (ncols_x) {
|
183 | 175 | case 32:
|
184 |
| - soft_max_f32_submitter<true, 32, 32, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 176 | + soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale, |
185 | 177 | max_bias, m0, m1, n_head_log2, block_nums,
|
186 | 178 | block_dims, n_local_scratch, stream);
|
187 | 179 | break;
|
188 | 180 | case 64:
|
189 |
| - soft_max_f32_submitter<true, 64, 64, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 181 | + soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale, |
190 | 182 | max_bias, m0, m1, n_head_log2, block_nums,
|
191 | 183 | block_dims, n_local_scratch, stream);
|
192 | 184 | break;
|
193 | 185 | case 128:
|
194 |
| - soft_max_f32_submitter<true, 128, 128, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 186 | + soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale, |
195 | 187 | max_bias, m0, m1, n_head_log2, block_nums,
|
196 | 188 | block_dims, n_local_scratch, stream);
|
197 | 189 | break;
|
198 | 190 | case 256:
|
199 |
| - soft_max_f32_submitter<true, 256, 256, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 191 | + soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale, |
200 | 192 | max_bias, m0, m1, n_head_log2, block_nums,
|
201 | 193 | block_dims, n_local_scratch, stream);
|
202 | 194 | break;
|
203 | 195 | case 512:
|
204 |
| - soft_max_f32_submitter<true, 512, 512, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 196 | + soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale, |
205 | 197 | max_bias, m0, m1, n_head_log2, block_nums,
|
206 | 198 | block_dims, n_local_scratch, stream);
|
207 | 199 | break;
|
208 | 200 | case 1024:
|
209 |
| - soft_max_f32_submitter<true, 1024, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 201 | + soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale, |
210 | 202 | max_bias, m0, m1, n_head_log2, block_nums,
|
211 | 203 | block_dims, n_local_scratch, stream);
|
212 | 204 | break;
|
213 | 205 | case 2048:
|
214 |
| - soft_max_f32_submitter<true, 2048, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 206 | + soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale, |
215 | 207 | max_bias, m0, m1, n_head_log2, block_nums,
|
216 | 208 | block_dims, n_local_scratch, stream);
|
217 | 209 | break;
|
218 | 210 | case 4096:
|
219 |
| - soft_max_f32_submitter<true, 4096, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 211 | + soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale, |
220 | 212 | max_bias, m0, m1, n_head_log2, block_nums,
|
221 | 213 | block_dims, n_local_scratch, stream);
|
222 | 214 | break;
|
223 | 215 | default:
|
224 |
| - soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 216 | + soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale, |
225 | 217 | max_bias, m0, m1, n_head_log2, block_nums,
|
226 | 218 | block_dims, n_local_scratch, stream);
|
227 | 219 | break;
|
228 | 220 | }
|
229 | 221 | } else {
|
230 |
| - soft_max_f32_submitter<false, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale, |
| 222 | + soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale, |
231 | 223 | max_bias, m0, m1, n_head_log2, block_nums,
|
232 | 224 | block_dims, WARP_SIZE, stream);
|
233 | 225 | }
|
|
0 commit comments