Skip to content

Commit 421d574

Browse files
committed
softmax: review update
1 parent 7b066d4 commit 421d574

File tree

1 file changed

+12
-20
lines changed

1 file changed

+12
-20
lines changed

ggml/src/ggml-sycl/softmax.cpp

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,5 @@
11
#include "softmax.hpp"
22

3-
template <typename T> static inline float t2f32(T val) {
4-
return static_cast<float>(val);
5-
}
6-
7-
template <> inline float t2f32<sycl::half>(sycl::half val) {
8-
return static_cast<float>(val);
9-
}
10-
113
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
124
static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
135
const int nrows_y, const float scale, const float max_bias, const float m0,
@@ -51,7 +43,7 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
5143
const int ix = rowx*ncols + col;
5244
const int iy = rowy*ncols + col;
5345

54-
const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
46+
const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);
5547

5648
vals[col] = val;
5749
max_val = sycl::max(max_val, val);
@@ -174,60 +166,60 @@ static void soft_max_f32_sycl(const float * x, const T * mask,
174166
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
175167
if (n_local_scratch*sizeof(float) < local_mem_size) {
176168
if (ncols_x > max_block_size) {
177-
soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
169+
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
178170
max_bias, m0, m1, n_head_log2, block_nums,
179171
block_dims, n_local_scratch, stream);
180172
return;
181173
}
182174
switch (ncols_x) {
183175
case 32:
184-
soft_max_f32_submitter<true, 32, 32, T>(x, mask, dst, ncols_x, nrows_y, scale,
176+
soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
185177
max_bias, m0, m1, n_head_log2, block_nums,
186178
block_dims, n_local_scratch, stream);
187179
break;
188180
case 64:
189-
soft_max_f32_submitter<true, 64, 64, T>(x, mask, dst, ncols_x, nrows_y, scale,
181+
soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
190182
max_bias, m0, m1, n_head_log2, block_nums,
191183
block_dims, n_local_scratch, stream);
192184
break;
193185
case 128:
194-
soft_max_f32_submitter<true, 128, 128, T>(x, mask, dst, ncols_x, nrows_y, scale,
186+
soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
195187
max_bias, m0, m1, n_head_log2, block_nums,
196188
block_dims, n_local_scratch, stream);
197189
break;
198190
case 256:
199-
soft_max_f32_submitter<true, 256, 256, T>(x, mask, dst, ncols_x, nrows_y, scale,
191+
soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
200192
max_bias, m0, m1, n_head_log2, block_nums,
201193
block_dims, n_local_scratch, stream);
202194
break;
203195
case 512:
204-
soft_max_f32_submitter<true, 512, 512, T>(x, mask, dst, ncols_x, nrows_y, scale,
196+
soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
205197
max_bias, m0, m1, n_head_log2, block_nums,
206198
block_dims, n_local_scratch, stream);
207199
break;
208200
case 1024:
209-
soft_max_f32_submitter<true, 1024, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
201+
soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
210202
max_bias, m0, m1, n_head_log2, block_nums,
211203
block_dims, n_local_scratch, stream);
212204
break;
213205
case 2048:
214-
soft_max_f32_submitter<true, 2048, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
206+
soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
215207
max_bias, m0, m1, n_head_log2, block_nums,
216208
block_dims, n_local_scratch, stream);
217209
break;
218210
case 4096:
219-
soft_max_f32_submitter<true, 4096, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
211+
soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
220212
max_bias, m0, m1, n_head_log2, block_nums,
221213
block_dims, n_local_scratch, stream);
222214
break;
223215
default:
224-
soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
216+
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
225217
max_bias, m0, m1, n_head_log2, block_nums,
226218
block_dims, n_local_scratch, stream);
227219
break;
228220
}
229221
} else {
230-
soft_max_f32_submitter<false, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
222+
soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
231223
max_bias, m0, m1, n_head_log2, block_nums,
232224
block_dims, WARP_SIZE, stream);
233225
}

0 commit comments

Comments
 (0)