@@ -1,7 +1,15 @@
-#include "norm.hpp"
+#include "softmax.hpp"
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+template <typename T> static inline float t2f32(T val) {
+    return static_cast<float>(val);
+}
+
+template <> inline float t2f32<sycl::half>(sycl::half val) {
+    return static_cast<float>(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
                          const int nrows_y, const float scale, const float max_bias, const float m0,
                          const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
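
Note: the new t2f32 helper above is what lets a single kernel template consume either an fp32 or an fp16 mask; the generic template covers float, and the sycl::half specialization makes the half-to-float conversion explicit. A self-contained sketch of the same pattern in plain C++ (the fx8 type, masked_val, and all values below are invented for illustration and are not part of the patch):

    #include <cstdio>

    // generic path: any element type that converts with static_cast
    template <typename T> static inline float t2f32(T val) {
        return static_cast<float>(val);
    }

    // a type needing its own conversion gets an explicit specialization,
    // the same hook the patch uses for sycl::half (fx8 is a toy format)
    struct fx8 { signed char v; };               // 8-bit fixed point, scale 1/16
    template <> inline float t2f32<fx8>(fx8 val) {
        return static_cast<float>(val.v) / 16.0f;
    }

    // same shape as the kernel's masked load: optional mask, converted on read
    template <typename T>
    static float masked_val(float x, float scale, const T * mask, int i, float slope) {
        return x * scale + (mask ? slope * t2f32(mask[i]) : 0.0f);
    }

    int main() {
        const float m32[1] = { -1.0f };
        const fx8   m8[1]  = { { -16 } };        // -1.0 in the toy format
        printf("%f\n", masked_val(2.0f, 1.0f, m32, 0, 1.0f));                     // 1.000000
        printf("%f\n", masked_val(2.0f, 1.0f, m8, 0, 1.0f));                      // 1.000000
        printf("%f\n", masked_val(2.0f, 1.0f, (const float *) nullptr, 0, 1.0f)); // 2.000000
        return 0;
    }
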
@@ -29,9 +37,10 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         slope = sycl::pow(base, float(exp));
     }
 
-    float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
+    float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
     float max_val = -INFINITY;
 
+#pragma unroll
     for (int col0 = 0; col0 < ncols; col0 += block_size) {
         const int col = col0 + tid;
 
@@ -42,7 +51,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+        const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
 
         vals[col] = val;
         max_val = sycl::max(max_val, val);
@@ -65,7 +74,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         item_ct1.barrier(sycl::access::fence_space::local_space);
         max_val = buf[lane_id];
         for (size_t i = 1; i < nreduce; i += 1) {
-            max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
+            max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
         }
         max_val = warp_reduce_max(max_val, item_ct1);
     }
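
Note: this hunk and the sycl::max change at line 40 of the new file both move device code off std::max and onto SYCL's built-in, the safer choice inside kernels. The patched loop is the combine step of a two-stage max reduction: each sub-group reduces its own lanes, partials land in local memory (buf), and this loop folds the partials together. A serial plain-C++ sketch of that combine step (the buffer layout and names are imitated here, not the real device code):

    #include <algorithm>
    #include <cstdio>

    constexpr int WARP_SIZE = 32;

    // layout assumed for this sketch: the partial from warp i is read at
    // buf[lane_id + i * WARP_SIZE]; nreduce is the number of warps
    static float combine_partials(const float * buf, int lane_id, size_t nreduce) {
        float max_val = buf[lane_id];
        for (size_t i = 1; i < nreduce; i += 1) {
            max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);  // sycl::max on device
        }
        return max_val;
    }

    int main() {
        float buf[2 * WARP_SIZE];
        for (int i = 0; i < 2 * WARP_SIZE; ++i) buf[i] = (float) i;
        // lane 0 sees buf[0] and buf[32]; the larger partial wins
        printf("%f\n", combine_partials(buf, 0, 2));   // 32.000000
        return 0;
    }
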
@@ -122,8 +131,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
     }
 }
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                    const size_t n_local_scratch, queue_ptr stream) {
@@ -133,15 +142,16 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
         cgh.parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
+                soft_max_f32<vals_smem, ncols_template, block_size_template, T>(x, mask, dst, ncols_par,
                                                                              nrows_y, scale, max_bias, m0,
                                                                              m1, n_head_log2, item_ct1,
                                                                              get_pointer(local_buf_acc));
             });
     });
 }
 
-static void soft_max_f32_sycl(const float * x, const float * mask,
+template <typename T>
+static void soft_max_f32_sycl(const float * x, const T * mask,
                               float * dst, const int ncols_x, const int nrows_x,
                               const int nrows_y, const float scale, const float max_bias,
                               queue_ptr stream, int device) {
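
Note: soft_max_f32_submitter, shown above gaining the extra T parameter, is a thin handler submission that allocates the local scratch and pins the sub-group size before invoking the kernel. A minimal standalone sketch of that structure (SYCL 2020 spellings; the queue setup, sizes, and names are illustrative, not the ggml ones):

    #include <sycl/sycl.hpp>

    // minimal shape of a submitter: local scratch + fixed sub-group size
    template <typename T>
    static void submit_sketch(sycl::queue & q, const float * x, const T * mask,
                              float * dst, size_t n_local_scratch) {
        q.submit([&](sycl::handler & cgh) {
            // dynamically sized local memory, the storage behind the kernel's `buf`
            sycl::local_accessor<float, 1> local_buf(sycl::range<1>(n_local_scratch), cgh);
            cgh.parallel_for(
                sycl::nd_range<1>(sycl::range<1>(1024), sycl::range<1>(256)),
                [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(32)]] {
                    float * buf = &local_buf[0];
                    // the real kernel body (soft_max_f32<..., T>) would run here
                    (void) buf; (void) x; (void) mask; (void) dst; (void) it;
                });
        });
    }
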
@@ -164,88 +174,99 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
     const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
     if (n_local_scratch*sizeof(float) < local_mem_size) {
         if (ncols_x > max_block_size) {
-            soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+            soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                max_bias, m0, m1, n_head_log2, block_nums,
                                                block_dims, n_local_scratch, stream);
             return;
         }
         switch (ncols_x) {
             case 32:
-                soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 32, 32, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                      max_bias, m0, m1, n_head_log2, block_nums,
                                                      block_dims, n_local_scratch, stream);
                 break;
             case 64:
-                soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 64, 64, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                      max_bias, m0, m1, n_head_log2, block_nums,
                                                      block_dims, n_local_scratch, stream);
                 break;
             case 128:
-                soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 128, 128, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 256:
-                soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 256, 256, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 512:
-                soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 512, 512, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 1024:
-                soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 1024, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             case 2048:
-                soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 2048, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             case 4096:
-                soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 4096, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             default:
-                soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                    max_bias, m0, m1, n_head_log2, block_nums,
                                                    block_dims, n_local_scratch, stream);
                 break;
         }
     } else {
-        soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+        soft_max_f32_submitter<false, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                             max_bias, m0, m1, n_head_log2, block_nums,
                                             block_dims, WARP_SIZE, stream);
     }
 }
 
-void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                           const ggml_tensor *src1, ggml_tensor *dst,
-                           const float *src0_dd, const float *src1_dd,
-                           float *dst_dd,
-                           const queue_ptr &main_stream) {
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
-#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00    = src0->ne[0];
-    const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src0->ne[1];
+    const int64_t ne00    = dst->src[0]->ne[0];
+    const int64_t nrows_x = ggml_nrows(dst->src[0]);
+    const int64_t nrows_y = dst->src[0]->ne[1];
 
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
     memcpy(&scale,    dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, dst->op_params + 1, sizeof(float));
 
-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
-                      nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
+
+    ggml_sycl_set_device(ctx.device);
+    dpct::queue_ptr main_stream = ctx.stream();
+
+    if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
+        // printf("%s: fp16 mask\n", __func__);
+        const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
+        soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
+                                      main_stream, ctx.device);
+    } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
+        // printf("%s: fp32 mask\n", __func__);
+        const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
+        soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    } else {
+        /* mask unavailable */
+        soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    }
 }
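
Note: the rewritten tail of ggml_sycl_op_soft_max is a standard runtime-dtype-to-template dispatch: inspect the mask tensor's type once on the host, then call the templated launcher with a correctly typed pointer. Reduced to its essentials (the dtype enum, tensor struct, and softmax_launch below are placeholders, with double standing in for sycl::half):

    #include <cstdio>

    enum class dtype { f32, f64 };            // stand-ins for GGML_TYPE_F32 / GGML_TYPE_F16

    struct tensor {                           // reduced stand-in for ggml_tensor
        dtype  type;
        void * data;
    };

    // stands in for soft_max_f32_sycl<T>: one instantiation per mask element type
    template <typename T>
    static void softmax_launch(const float * x, const T * mask, int n) {
        float acc = 0.0f;
        for (int i = 0; i < n; ++i) acc += x[i] + (mask ? static_cast<float>(mask[i]) : 0.0f);
        printf("acc = %f\n", acc);
    }

    // mirrors the tail of the patched ggml_sycl_op_soft_max: branch once on the
    // runtime dtype, then hand a correctly typed pointer to the template
    static void dispatch(const float * x, const tensor * mask, int n) {
        if (mask && mask->type == dtype::f64) {
            softmax_launch<double>(x, static_cast<const double *>(mask->data), n);
        } else if (mask && mask->type == dtype::f32) {
            softmax_launch<float>(x, static_cast<const float *>(mask->data), n);
        } else {
            softmax_launch<float>(x, nullptr, n); // mask is optional
        }
    }

    int main() {
        float  x[2] = { 1.0f, 2.0f };
        double m[2] = { -0.5, -0.5 };
        tensor t { dtype::f64, m };
        dispatch(x, &t, 2);       // acc = 2.000000
        dispatch(x, nullptr, 2);  // acc = 3.000000
        return 0;
    }

One instantiation per supported mask type keeps per-element branching out of the kernel; the optional-mask case reuses the float instantiation with a null pointer, exactly as the patch does.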