@@ -105,45 +105,29 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
105
105
}
106
106
107
107
template <int block_size, bool do_multiply = false , bool do_add = false >
108
- static __global__ void rms_norm_f32 (const float * x,
109
- float * dst,
110
- const int ncols,
111
- const int64_t stride_row,
112
- const int64_t stride_channel,
113
- const int64_t stride_sample,
114
- const float eps,
115
- const float * mul = nullptr ,
116
- const int64_t mul_stride_row = 0 ,
117
- const int64_t mul_stride_channel = 0 ,
118
- const int64_t mul_stride_sample = 0 ,
119
- const uint32_t mul_ncols = 0 ,
120
- const uint32_t mul_nrows = 0 ,
121
- const uint32_t mul_nchannels = 0 ,
122
- const uint32_t mul_nsamples = 0 ,
123
- const uint32_t mp_mul_cols = 0 ,
124
- const uint32_t L_mul_cols = 0 ,
125
- const uint32_t mp_mul_rows = 0 ,
126
- const uint32_t L_mul_rows = 0 ,
127
- const uint32_t mp_mul_channels = 0 ,
128
- const uint32_t L_mul_channels = 0 ,
129
- const uint32_t mp_mul_samples = 0 ,
130
- const uint32_t L_mul_samples = 0 ,
131
- const float * add = nullptr ,
132
- const int64_t add_stride_row = 0 ,
133
- const int64_t add_stride_channel = 0 ,
134
- const int64_t add_stride_sample = 0 ,
135
- const uint32_t add_ncols = 0 ,
136
- const uint32_t add_nrows = 0 ,
137
- const uint32_t add_nchannels = 0 ,
138
- const uint32_t add_nsamples = 0 ,
139
- const uint32_t mp_add_cols = 0 ,
140
- const uint32_t L_add_cols = 0 ,
141
- const uint32_t mp_add_rows = 0 ,
142
- const uint32_t L_add_rows = 0 ,
143
- const uint32_t mp_add_channels = 0 ,
144
- const uint32_t L_add_channels = 0 ,
145
- const uint32_t mp_add_samples = 0 ,
146
- const uint32_t L_add_samples = 0 ) {
108
+ static __global__ void rms_norm_f32 (const float * x,
109
+ float * dst,
110
+ const int ncols,
111
+ const int64_t stride_row,
112
+ const int64_t stride_channel,
113
+ const int64_t stride_sample,
114
+ const float eps,
115
+ const float * mul = nullptr ,
116
+ const int64_t mul_stride_row = 0 ,
117
+ const int64_t mul_stride_channel = 0 ,
118
+ const int64_t mul_stride_sample = 0 ,
119
+ const uint3 mul_ncols_packed = make_uint3(0 , 0 , 0 ),
120
+ const uint3 mul_nrows_packed = make_uint3(0 , 0 , 0 ),
121
+ const uint3 mul_nchannels_packed = make_uint3(0 , 0 , 0 ),
122
+ const uint3 mul_nsamples_packed = make_uint3(0 , 0 , 0 ),
123
+ const float * add = nullptr,
124
+ const int64_t add_stride_row = 0,
125
+ const int64_t add_stride_channel = 0,
126
+ const int64_t add_stride_sample = 0,
127
+ const uint3 add_ncols_packed = make_uint3(0 , 0 , 0 ),
128
+ const uint3 add_nrows_packed = make_uint3(0 , 0 , 0 ),
129
+ const uint3 add_nchannels_packed = make_uint3(0 , 0 , 0 ),
130
+ const uint3 add_nsamples_packed = make_uint3(0 , 0 , 0 )) {
147
131
const int nrows = gridDim .x ;
148
132
const int nchannels = gridDim .y ;
149
133
@@ -158,16 +142,16 @@ static __global__ void rms_norm_f32(const float * x,
158
142
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
159
143
160
144
if constexpr (do_multiply) {
161
- const uint32_t mul_row = fastmodulo (row, mul_nrows, mp_mul_rows, L_mul_rows );
162
- const uint32_t mul_channel = fastmodulo (channel, mul_nchannels, mp_mul_channels, L_mul_channels );
163
- const uint32_t mul_sample = fastmodulo (sample, mul_nsamples, mp_mul_samples, L_mul_samples );
145
+ const uint32_t mul_row = fastmodulo (row, mul_nrows_packed );
146
+ const uint32_t mul_channel = fastmodulo (channel, mul_nchannels_packed );
147
+ const uint32_t mul_sample = fastmodulo (sample, mul_nsamples_packed );
164
148
mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
165
149
}
166
150
167
151
if constexpr (do_add) {
168
- const int add_row = fastmodulo (row, add_nrows, mp_add_rows, L_add_rows );
169
- const int add_channel = fastmodulo (channel, add_nchannels, mp_add_channels, L_add_channels );
170
- const int add_sample = fastmodulo (sample, add_nsamples, mp_add_samples, L_add_samples );
152
+ const int add_row = fastmodulo (row, add_nrows_packed );
153
+ const int add_channel = fastmodulo (channel, add_nchannels_packed );
154
+ const int add_sample = fastmodulo (sample, add_nsamples_packed );
171
155
add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
172
156
}
173
157
@@ -201,11 +185,11 @@ static __global__ void rms_norm_f32(const float * x,
201
185
202
186
for (int col = tid; col < ncols; col += block_size) {
203
187
if constexpr (do_multiply && do_add) {
204
- const int mul_col = fastmodulo (col, mul_ncols, mp_mul_cols, L_mul_cols );
205
- const int add_col = fastmodulo (col, add_ncols, mp_add_cols, L_add_cols );
188
+ const int mul_col = fastmodulo (col, mul_ncols_packed );
189
+ const int add_col = fastmodulo (col, add_ncols_packed );
206
190
dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
207
191
} else if constexpr (do_multiply) {
208
- const int mul_col = fastmodulo (col, mul_ncols, mp_mul_cols, L_mul_cols );
192
+ const int mul_col = fastmodulo (col, mul_ncols_packed );
209
193
dst[col] = scale * x[col] * mul[mul_col];
210
194
} else {
211
195
dst[col] = scale * x[col];
@@ -414,63 +398,45 @@ static void rms_norm_mul_f32_cuda(const float * x,
414
398
return ;
415
399
}
416
400
if (add == nullptr ) {
417
- uint32_t mp_mul_cols, L_mul_cols;
418
- init_fastdiv_values (mul_ncols, mp_mul_cols, L_mul_cols);
419
- uint32_t mp_mul_rows, L_mul_rows;
420
- init_fastdiv_values (mul_nrows, mp_mul_rows, L_mul_rows);
421
- uint32_t mp_mul_channels, L_mul_channels;
422
- init_fastdiv_values (mul_nchannels, mp_mul_channels, L_mul_channels);
423
- uint32_t mp_mul_samples, L_mul_samples;
424
- init_fastdiv_values (mul_nsamples, mp_mul_samples, L_mul_samples);
401
+ uint3 mul_ncols_packed = init_fastmodulo_values (mul_ncols);
402
+ uint3 mul_nrows_packed = init_fastmodulo_values (mul_nrows);
403
+ uint3 mul_nchannels_packed = init_fastmodulo_values (mul_nchannels);
404
+ uint3 mul_nsamples_packed = init_fastmodulo_values (mul_nsamples);
425
405
if (ncols < 1024 ) {
426
406
const dim3 block_dims (256 , 1 , 1 );
427
407
rms_norm_f32<256 , true ><<<blocks_num, block_dims, 0 , stream>>> (
428
408
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
429
- mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
430
- mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples);
409
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
431
410
} else {
432
411
const dim3 block_dims (1024 , 1 , 1 );
433
412
rms_norm_f32<1024 , true ><<<blocks_num, block_dims, 0 , stream>>> (
434
413
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
435
- mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
436
- mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples);
414
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
437
415
}
438
416
} else {
439
- uint32_t mp_mul_cols, L_mul_cols;
440
- init_fastdiv_values (mul_ncols, mp_mul_cols, L_mul_cols);
441
- uint32_t mp_mul_rows, L_mul_rows;
442
- init_fastdiv_values (mul_nrows, mp_mul_rows, L_mul_rows);
443
- uint32_t mp_mul_channels, L_mul_channels;
444
- init_fastdiv_values (mul_nchannels, mp_mul_channels, L_mul_channels);
445
- uint32_t mp_mul_samples, L_mul_samples;
446
- init_fastdiv_values (mul_nsamples, mp_mul_samples, L_mul_samples);
447
-
448
- uint32_t mp_add_cols, L_add_cols;
449
- init_fastdiv_values (add_ncols, mp_add_cols, L_add_cols);
450
- uint32_t mp_add_rows, L_add_rows;
451
- init_fastdiv_values (add_nrows, mp_add_rows, L_add_rows);
452
- uint32_t mp_add_channels, L_add_channels;
453
- init_fastdiv_values (add_nchannels, mp_add_channels, L_add_channels);
454
- uint32_t mp_add_samples, L_add_samples;
455
- init_fastdiv_values (add_nsamples, mp_add_samples, L_add_samples);
417
+ uint3 mul_ncols_packed = init_fastmodulo_values (mul_ncols);
418
+ uint3 mul_nrows_packed = init_fastmodulo_values (mul_nrows);
419
+ uint3 mul_nchannels_packed = init_fastmodulo_values (mul_nchannels);
420
+ uint3 mul_nsamples_packed = init_fastmodulo_values (mul_nsamples);
421
+
422
+ uint3 add_ncols_packed = init_fastmodulo_values (add_ncols);
423
+ uint3 add_nrows_packed = init_fastmodulo_values (add_nrows);
424
+ uint3 add_nchannels_packed = init_fastmodulo_values (add_nchannels);
425
+ uint3 add_nsamples_packed = init_fastmodulo_values (add_nsamples);
456
426
if (ncols < 1024 ) {
457
427
const dim3 block_dims (256 , 1 , 1 );
458
428
rms_norm_f32<256 , true , true ><<<blocks_num, block_dims, 0 , stream>>> (
459
429
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
460
- mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
461
- mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples, add,
462
- add_stride_row, add_stride_channel, add_stride_sample, add_ncols, add_nrows, add_nchannels,
463
- add_nsamples, mp_add_cols, L_add_cols, mp_add_rows, L_add_rows, mp_add_channels, L_add_channels,
464
- mp_add_samples, L_add_samples);
430
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
431
+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
432
+ add_nchannels_packed, add_nsamples_packed);
465
433
} else {
466
434
const dim3 block_dims (1024 , 1 , 1 );
467
435
rms_norm_f32<1024 , true , true ><<<blocks_num, block_dims, 0 , stream>>> (
468
436
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
469
- mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
470
- mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples, add,
471
- add_stride_row, add_stride_channel, add_stride_sample, add_ncols, add_nrows, add_nchannels,
472
- add_nsamples, mp_add_cols, L_add_cols, mp_add_rows, L_add_rows, mp_add_channels, L_add_channels,
473
- mp_add_samples, L_add_samples);
437
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
438
+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
439
+ add_nchannels_packed, add_nsamples_packed);
474
440
}
475
441
}
476
442
}
0 commit comments