#include "norm.hpp"
+ #include "ggml-sycl/common.hpp"
+ #include "ggml-sycl/presets.hpp"

- static void norm_f32(const float* x, float* dst, const int ncols, const float eps,
-                      const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
-     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                     item_ct1.get_local_id(1);
-     const int tid = item_ct1.get_local_id(2);
+ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+                      const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
+
+     const int nrows = item_ct1.get_group_range(2);
+     const int nchannels = item_ct1.get_group_range(1);

    const int nthreads = item_ct1.get_local_range(2);
+     const int sample = item_ct1.get_group(0);
+     const int channel = item_ct1.get_group(1);
+     const int row = item_ct1.get_group(2);
+
+     const int tid = item_ct1.get_local_id(2);
    const int nwarps = nthreads / WARP_SIZE;
+
+     const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
+     const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
+
+     x += strided_offset;
+     dst += packed_offset;
+
    sycl::float2 mean_var = sycl::float2(0.f, 0.f);

    for (int col = tid; col < ncols; col += block_size) {
-         const float xi = x[row * ncols + col];
+         const float xi = x[col];
        mean_var.x() += xi;
        mean_var.y() += xi * xi;
    }

    // sum up partial sums
    mean_var = warp_reduce_sum(mean_var, item_ct1);
-     if (block_size > WARP_SIZE) {
-
-         int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-         int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-         if (lane_id == 0) {
-             s_sum[warp_id] = mean_var;
+     if (block_size > WARP_SIZE) {
+         const auto sub_group = item_ct1.get_sub_group();
+         const auto sg_id = sub_group.get_group_linear_id();
+         const auto wi_in_sg = sub_group.get_local_linear_id();
+         if (wi_in_sg == 0) {
+             s_sum[sg_id] = mean_var;
        }
-         /*
-         DPCT1118:0: SYCL group functions and algorithms must be encountered in
-         converged control flow. You may need to adjust the code.
-         */
        item_ct1.barrier(sycl::access::fence_space::local_space);
        mean_var = 0.f;
-         size_t nreduce = nwarps / WARP_SIZE;
+         const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
        for (size_t i = 0; i < nreduce; i += 1)
        {
-             mean_var += s_sum[lane_id + i * WARP_SIZE];
+             mean_var += s_sum[wi_in_sg + i * WARP_SIZE];
        }
        mean_var = warp_reduce_sum(mean_var, item_ct1);
    }
@@ -44,7 +54,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
    const float inv_std = sycl::rsqrt(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
-         dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std;
+         dst[col] = (x[col] - mean) * inv_std;
    }
}

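Note on the addressing change in norm_f32 (and rms_norm_f32 below): each work-group now derives its read offset from the source tensor's element strides and its write offset from packed sizes, so the kernel can read a non-contiguous src0 while still writing a dense dst. The calculate_offset<3> helper comes from the newly included ggml-sycl/common.hpp; the sketch below only illustrates the behavior this diff appears to rely on (offset_sketch and the stride values are made up for illustration, they are not the repository's implementation):

    #include <array>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Hypothetical sketch of what calculate_offset<3> is assumed to do: a dot
    // product of per-dimension element strides with the (sample, channel, row)
    // index of the current work-group.
    template <int N>
    static size_t offset_sketch(const std::array<int64_t, N> & strides,
                                const std::array<int64_t, N> & indices) {
        size_t offset = 0;
        for (int i = 0; i < N; ++i) {
            offset += static_cast<size_t>(strides[i] * indices[i]); // elements, not bytes
        }
        return offset;
    }

    int main() {
        // Work-group (sample=1, channel=2, row=3) reading a source with element
        // strides {stride_sample=8192, stride_channel=2048, stride_row=64}:
        const size_t strided = offset_sketch<3>({8192, 2048, 64}, {1, 2, 3});
        std::printf("strided offset = %zu floats\n", strided); // 8192 + 4096 + 192 = 12480
        return 0;
    }

For the packed destination offset the same helper is called with {nchannels * nrows * ncols, nrows * ncols, ncols}, which is simply the row-major linear index of (sample, channel, row) scaled by ncols.
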
@@ -135,39 +145,51 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
    }
}

- static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps,
-                          const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
-     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-                     item_ct1.get_local_id(1);
-     const int tid = item_ct1.get_local_id(2);
+ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+                          const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
+
+     const int nrows = item_ct1.get_group_range(2);
+     const int nchannels = item_ct1.get_group_range(1);
+
+     const int sample = item_ct1.get_group(0);
+     const int channel = item_ct1.get_group(1);
+     const int row = item_ct1.get_group(2);
+
    const int nthreads = item_ct1.get_local_range(2);
+
+     const int tid = item_ct1.get_local_id(2);
    const int nwarps = nthreads / WARP_SIZE;
+
+     const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
+     const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
+
+     x += strided_offset;
+     dst += packed_offset;
+
+
    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
-         const float xi = x[row * ncols + col];
+         const float xi = x[col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
-
-         int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-         int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-         if (lane_id == 0) {
-             s_sum[warp_id] = tmp;
+         const auto sub_group = item_ct1.get_sub_group();
+         const auto sg_id = sub_group.get_group_linear_id();
+         const auto wi_in_sg = sub_group.get_local_linear_id();
+         if (wi_in_sg == 0) {
+             s_sum[sg_id] = tmp;
        }
-         /*
-         DPCT1118:3: SYCL group functions and algorithms must be encountered in
-         converged control flow. You may need to adjust the code.
-         */
+
        item_ct1.barrier(sycl::access::fence_space::local_space);
-         size_t nreduce = nwarps / WARP_SIZE;
+         const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1)
        {
-             tmp += s_sum[lane_id + i * WARP_SIZE];
+             tmp += s_sum[wi_in_sg + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }
@@ -176,7 +198,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
    const float scale = sycl::rsqrt(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
-         dst[row * ncols + col] = scale * x[row * ncols + col];
+         dst[col] = scale * x[col];
    }
}

@@ -224,20 +246,20 @@ static void l2_norm_f32(const float* x, float* dst, const int ncols, const float
    }
}

- static void norm_f32_sycl(const float* x, float* dst, const int ncols,
-                           const int nrows, const float eps,
-                           queue_ptr stream, int device) {
+ static void norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+                           const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
+                           const float eps, queue_ptr stream, int device) {
+
+     const sycl::range<3> global_dims(nsamples, nchannels, nrows);
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
-                 sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                   block_dims),
+                 sycl::nd_range<3>(global_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                         norm_f32(x, dst, ncols, eps, item_ct1,
-                                  nullptr, WARP_SIZE);
+                         norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
                    });
        });
    }
@@ -252,15 +274,12 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
    */
    stream->submit([&](sycl::handler& cgh) {
        sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
-             sycl::range<1>(work_group_size / WARP_SIZE), cgh);
-
+             sycl::range<1>(work_group_size / WARP_SIZE), cgh);
        cgh.parallel_for(
-             sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                               block_dims),
+             sycl::nd_range<3>(global_dims * block_dims, block_dims),
            [=](sycl::nd_item<3> item_ct1)
                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                     norm_f32(x, dst, ncols, eps, item_ct1,
-                              get_pointer(s_sum_acc_ct1), work_group_size);
+                     norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                });
        });
    }
@@ -313,21 +332,20 @@ static void group_norm_f32_sycl(const float* x, float* dst,
    }
}

- static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
-                               const int nrows, const float eps,
-                               queue_ptr stream, int device) {
+ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+                               const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
+
+     const sycl::range<3> global_dims(nsamples, nchannels, nrows);
    if (ncols < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
-                 sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                                   block_dims),
+                 sycl::nd_range<3>(global_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                         rms_norm_f32(x, dst, ncols, eps, item_ct1,
-                                      nullptr, WARP_SIZE);
+                         rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
                    });
        });
    }
@@ -344,12 +362,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
        sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                     cgh);
        cgh.parallel_for(
-             sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
-                               block_dims),
+             sycl::nd_range<3>(global_dims * block_dims, block_dims),
            [=](sycl::nd_item<3> item_ct1)
                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                     rms_norm_f32(x, dst, ncols, eps, item_ct1,
-                                  get_pointer(s_sum_acc_ct1), work_group_size);
+                     rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                });
        });
    }
@@ -398,21 +414,27 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
}

void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+     const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

-     const int64_t ne00 = dst->src[0]->ne[0];
-     const int64_t nrows = ggml_nrows(dst->src[0]);
+     GGML_TENSOR_UNARY_OP_LOCALS
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
    const float* src0_dd = static_cast<const float*>(dst->src[0]->data);
    float* dst_dd = static_cast<float*>(dst->data);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
-
-     norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+     GGML_ASSERT(eps >= 0.0f);
+     const size_t ts0 = ggml_type_size(src0->type);
+     GGML_ASSERT(nb00 == ts0);
+     const int64_t s01 = nb01 / ts0;
+     const int64_t s02 = nb02 / ts0;
+     const int64_t s03 = nb03 / ts0;
+
+     norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
}

void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
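Note on the stride setup in ggml_sycl_op_norm above (and ggml_sycl_op_rms_norm below): ggml's nb01/nb02/nb03 are byte strides, so after asserting that the innermost stride nb00 equals the element size they are divided by ggml_type_size(src0->type) to obtain element strides for the launcher. A small worked example with a made-up shape (the numbers are illustrative only, not taken from this PR):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // A contiguous F32 tensor with ne = {64, 4, 2, 1} has byte strides
        // nb = {4, 256, 1024, 2048}.
        const size_t  ts0  = sizeof(float);          // 4 bytes for GGML_TYPE_F32
        const int64_t nb01 = 256, nb02 = 1024, nb03 = 2048;
        const int64_t s01  = nb01 / (int64_t) ts0;   // 64  elements per row step
        const int64_t s02  = nb02 / (int64_t) ts0;   // 256 elements per channel step
        const int64_t s03  = nb03 / (int64_t) ts0;   // 512 elements per sample step
        std::printf("s01=%lld s02=%lld s03=%lld\n",
                    (long long) s01, (long long) s02, (long long) s03);
        return 0;
    }

For a contiguous tensor these element strides are exactly ne00, ne00*ne01 and ne00*ne01*ne02; passing them explicitly is what lets norm_f32_sycl also handle views whose strides differ from those packed values.
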
@@ -436,11 +458,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {

+     const ggml_tensor * src0 = dst->src[0];
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

-     const int64_t ne00 = dst->src[0]->ne[0];
-     const int64_t nrows = ggml_nrows(dst->src[0]);
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

@@ -450,7 +471,13 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

-     rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+     GGML_TENSOR_UNARY_OP_LOCALS
+     const size_t ts0 = ggml_type_size(src0->type);
+     GGML_ASSERT(nb00 == ts0);
+     const int64_t s01 = nb01 / ts0;
+     const int64_t s02 = nb02 / ts0;
+     const int64_t s03 = nb03 / ts0;
+     rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
}

void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {