
Commit 1e2fe41

SYCL: SOFTMAX F16 mask support and other fixes
1 parent 99487b5 commit 1e2fe41

6 files changed, +62 -53 lines changed


ggml/src/ggml-sycl/common.cpp

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
     float * src0_ddf = (float *) src0->data;
     float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
     float * dst_ddf = (float *) dst->data;
-
+    /* These are never used */
     ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
     ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
     ggml_sycl_pool_alloc<float> dst_f(ctx.pool());

ggml/src/ggml-sycl/dmmv.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "dmmv.hpp"
33
#include "dequantize.hpp"
44
#include "presets.hpp"
5+
#include "ggml-impl.h"
56

67

78
static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
@@ -973,6 +974,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
973974
}
974975
#else
975976
const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
977+
GGML_UNUSED(ctx);
976978
#endif // GGML_SYCL_F16
977979

978980
switch (src0->type) {
@@ -1010,7 +1012,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
10101012
convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
10111013
break;
10121014
default:
1013-
printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
1015+
GGML_LOG_ERROR("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
10141016
GGML_ABORT("fatal error");
10151017
break;
10161018
}
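
For context, a minimal standalone sketch of the two idioms this file switches to: routing unsupported-type errors through an error-logging macro instead of a bare printf (presumably why "ggml-impl.h", which provides GGML_LOG_ERROR, is now included), and marking ctx as intentionally unused, since it is only needed on the GGML_SYCL_F16 path. The macros below are local stand-ins, not the real ggml definitions.

#include <cstdio>

#define LOG_ERROR(...) fprintf(stderr, __VA_ARGS__)  // stand-in for GGML_LOG_ERROR
#define UNUSED(x) (void)(x)                          // stand-in for GGML_UNUSED

static void report_unsupported(int type, void * ctx) {
    UNUSED(ctx);  // ctx would only be touched when the F16 branch is compiled in
    LOG_ERROR("unsupported GGML_TYPE %d\n", type);
}

int main() {
    report_unsupported(42, nullptr);  // goes to stderr instead of stdout
    return 0;
}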

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 1 addition & 5 deletions
@@ -3878,10 +3878,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
 }
 
-static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max);
-}
-
 static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
@@ -4090,7 +4086,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
             ggml_sycl_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            ggml_sycl_soft_max(ctx, dst);
+            ggml_sycl_op_soft_max(ctx, dst);
             break;
         case GGML_OP_ROPE:
             ggml_sycl_rope(ctx, dst);

ggml/src/ggml-sycl/softmax.cpp

Lines changed: 56 additions & 35 deletions
@@ -1,7 +1,15 @@
-#include "norm.hpp"
+#include "softmax.hpp"
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+template <typename T> static inline float t2f32(T val) {
+    return static_cast<float>(val);
+}
+
+template <> inline float t2f32<sycl::half>(sycl::half val) {
+    return static_cast<float>(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
                          const int nrows_y, const float scale, const float max_bias, const float m0,
                          const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -29,9 +37,10 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         slope = sycl::pow(base, float(exp));
     }
 
-    float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
+    float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
     float max_val = -INFINITY;
 
+#pragma unroll
     for (int col0 = 0; col0 < ncols; col0 += block_size) {
         const int col = col0 + tid;
 
@@ -42,7 +51,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+        const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
 
         vals[col] = val;
         max_val = sycl::max(max_val, val);
@@ -65,7 +74,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         item_ct1.barrier(sycl::access::fence_space::local_space);
         max_val = buf[lane_id];
         for (size_t i = 1; i < nreduce; i += 1) {
-            max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
+            max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
         }
         max_val = warp_reduce_max(max_val, item_ct1);
     }
@@ -122,8 +131,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
     }
 }
 
-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                    const size_t n_local_scratch, queue_ptr stream) {
@@ -133,15 +142,16 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
         cgh.parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
+                soft_max_f32<vals_smem, ncols_template, block_size_template, T>(x, mask, dst, ncols_par,
                                                                              nrows_y, scale, max_bias, m0,
                                                                              m1, n_head_log2, item_ct1,
                                                                              get_pointer(local_buf_acc));
             });
     });
 }
 
-static void soft_max_f32_sycl(const float * x, const float * mask,
+template<typename T>
+static void soft_max_f32_sycl(const float * x, const T * mask,
                               float * dst, const int ncols_x, const int nrows_x,
                               const int nrows_y, const float scale, const float max_bias,
                               queue_ptr stream, int device) {
@@ -164,88 +174,99 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
     const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
     if (n_local_scratch*sizeof(float) < local_mem_size) {
         if (ncols_x > max_block_size) {
-            soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+            soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                max_bias, m0, m1, n_head_log2, block_nums,
                                                block_dims, n_local_scratch, stream);
             return;
         }
         switch (ncols_x) {
             case 32:
-                soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 32, 32, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                      max_bias, m0, m1, n_head_log2, block_nums,
                                                      block_dims, n_local_scratch, stream);
                 break;
             case 64:
-                soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 64, 64, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                      max_bias, m0, m1, n_head_log2, block_nums,
                                                      block_dims, n_local_scratch, stream);
                 break;
             case 128:
-                soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 128, 128, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 256:
-                soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 256, 256, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 512:
-                soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 512, 512, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 1024:
-                soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 1024, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             case 2048:
-                soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 2048, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             case 4096:
-                soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 4096, 1024, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             default:
-                soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                                    max_bias, m0, m1, n_head_log2, block_nums,
                                                    block_dims, n_local_scratch, stream);
                 break;
         }
     } else {
-        soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+        soft_max_f32_submitter<false, 0, 0, T>(x, mask, dst, ncols_x, nrows_y, scale,
                                             max_bias, m0, m1, n_head_log2, block_nums,
                                             block_dims, WARP_SIZE, stream);
     }
 }
 
-void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                           const ggml_tensor *src1, ggml_tensor *dst,
-                           const float *src0_dd, const float *src1_dd,
-                           float *dst_dd,
-                           const queue_ptr &main_stream) {
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
-#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src0->ne[1];
+    const int64_t ne00 = dst->src[0]->ne[0];
+    const int64_t nrows_x = ggml_nrows(dst->src[0]);
+    const int64_t nrows_y = dst->src[0]->ne[1];
 
     float scale = 1.0f;
     float max_bias = 0.0f;
 
     memcpy(&scale, dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, dst->op_params + 1, sizeof(float));
 
-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
-                      nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
+
+    ggml_sycl_set_device(ctx.device);
+    dpct::queue_ptr main_stream = ctx.stream();
+
+    if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
+        //printf("%s: fp16 mask\n", __func__);
+        const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
+        soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
+                                      main_stream, ctx.device);
+    } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
+        //printf("%s: fp32 mask\n", __func__);
+        const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
+        soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    } else {
+        /* mask unavailable */
+        soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    }
 }
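
To make the new dispatch easier to follow, here is a minimal, self-contained host-side sketch of the pattern the patch introduces: the soft-max helpers are templated on the mask element type T, mask values are read through t2f32(), and the caller picks the float or half instantiation from the mask tensor's type. This is plain C++ with a made-up half_t standing in for sycl::half; it is not the SYCL kernel and leaves out the ALiBi slope computation, the block/warp reductions, and the local-memory handling.

#include <cmath>
#include <cstdio>

// stand-in for sycl::half so the sketch builds without SYCL
struct half_t { float v; };

// generic conversion helper, mirroring t2f32() from the patch
template <typename T> static inline float t2f32(T val) {
    return static_cast<float>(val);
}
template <> inline float t2f32<half_t>(half_t val) {
    return val.v;
}

// row-wise softmax templated on the mask element type, mirroring how
// soft_max_f32<..., T> consumes the optional mask in the patched kernel
template <typename T>
static void soft_max_row(const float * x, const T * mask, float * dst,
                         int ncols, float scale, float slope) {
    float max_val = -INFINITY;
    for (int c = 0; c < ncols; ++c) {
        const float val = x[c]*scale + (mask ? slope*t2f32(mask[c]) : 0.0f);
        dst[c] = val;
        max_val = std::fmax(max_val, val);
    }
    float sum = 0.0f;
    for (int c = 0; c < ncols; ++c) {
        dst[c] = std::exp(dst[c] - max_val);
        sum += dst[c];
    }
    for (int c = 0; c < ncols; ++c) {
        dst[c] /= sum;
    }
}

int main() {
    const float  x[4]   = { 0.5f, 1.0f, -0.5f, 2.0f };
    const half_t m16[4] = { {0.0f}, {0.0f}, {-1e30f}, {0.0f} }; // -1e30f ~ masked-out position
    const float  m32[4] = {  0.0f,   0.0f,  -1e30f,   0.0f  };
    float out[4];

    // half-mask path (what the commit adds) next to the existing float-mask path
    soft_max_row<half_t>(x, m16, out, 4, 1.0f, 1.0f);
    std::printf("half mask:  %.3f %.3f %.3f %.3f\n", out[0], out[1], out[2], out[3]);
    soft_max_row<float>(x, m32, out, 4, 1.0f, 1.0f);
    std::printf("float mask: %.3f %.3f %.3f %.3f\n", out[0], out[1], out[2], out[3]);
    return 0;
}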

ggml/src/ggml-sycl/softmax.hpp

Lines changed: 1 addition & 5 deletions
@@ -15,10 +15,6 @@
 
 #include "common.hpp"
 
-void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0,
-                           const ggml_tensor *src1, ggml_tensor *dst,
-                           const float *src0_dd, const float *src1_dd,
-                           float *dst_dd,
-                           const queue_ptr &main_stream);
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);
 
 #endif // GGML_SYCL_SOFTMAX_HPP

ggml/src/ggml-sycl/wkv6.cpp

Lines changed: 0 additions & 6 deletions
@@ -97,9 +97,6 @@ static void rwkv_wkv_f32_kernel(
 
 void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
-    const ggml_tensor *src0 = dst->src[0];
-    const ggml_tensor *src1 = dst->src[1];
-
     const float* k_d = (const float*)dst->src[0]->data;
     const float* v_d = (const float*)dst->src[1]->data;
     const float* r_d = (const float*)dst->src[2]->data;
@@ -137,7 +134,4 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
             );
         });
     });
-
-    GGML_UNUSED(src0);
-    GGML_UNUSED(src1);
 }
