@@ -398,10 +398,10 @@ static void rms_norm_mul_f32_cuda(const float * x,
398
398
return ;
399
399
}
400
400
if (add == nullptr ) {
401
- uint3 mul_ncols_packed = init_fastdiv_values (mul_ncols);
402
- uint3 mul_nrows_packed = init_fastdiv_values (mul_nrows);
403
- uint3 mul_nchannels_packed = init_fastdiv_values (mul_nchannels);
404
- uint3 mul_nsamples_packed = init_fastdiv_values (mul_nsamples);
401
+ const uint3 mul_ncols_packed = init_fastdiv_values (mul_ncols);
402
+ const uint3 mul_nrows_packed = init_fastdiv_values (mul_nrows);
403
+ const uint3 mul_nchannels_packed = init_fastdiv_values (mul_nchannels);
404
+ const uint3 mul_nsamples_packed = init_fastdiv_values (mul_nsamples);
405
405
if (ncols < 1024 ) {
406
406
const dim3 block_dims (256 , 1 , 1 );
407
407
rms_norm_f32<256 , true ><<<blocks_num, block_dims, 0 , stream>>> (
@@ -414,15 +414,15 @@ static void rms_norm_mul_f32_cuda(const float * x,
414
414
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
415
415
}
416
416
} else {
417
- uint3 mul_ncols_packed = init_fastdiv_values (mul_ncols);
418
- uint3 mul_nrows_packed = init_fastdiv_values (mul_nrows);
419
- uint3 mul_nchannels_packed = init_fastdiv_values (mul_nchannels);
420
- uint3 mul_nsamples_packed = init_fastdiv_values (mul_nsamples);
421
-
422
- uint3 add_ncols_packed = init_fastdiv_values (add_ncols);
423
- uint3 add_nrows_packed = init_fastdiv_values (add_nrows);
424
- uint3 add_nchannels_packed = init_fastdiv_values (add_nchannels);
425
- uint3 add_nsamples_packed = init_fastdiv_values (add_nsamples);
417
+ const uint3 mul_ncols_packed = init_fastdiv_values (mul_ncols);
418
+ const uint3 mul_nrows_packed = init_fastdiv_values (mul_nrows);
419
+ const uint3 mul_nchannels_packed = init_fastdiv_values (mul_nchannels);
420
+ const uint3 mul_nsamples_packed = init_fastdiv_values (mul_nsamples);
421
+
422
+ const uint3 add_ncols_packed = init_fastdiv_values (add_ncols);
423
+ const uint3 add_nrows_packed = init_fastdiv_values (add_nrows);
424
+ const uint3 add_nchannels_packed = init_fastdiv_values (add_nchannels);
425
+ const uint3 add_nsamples_packed = init_fastdiv_values (add_nsamples);
426
426
if (ncols < 1024 ) {
427
427
const dim3 block_dims (256 , 1 , 1 );
428
428
rms_norm_f32<256 , true , true ><<<blocks_num, block_dims, 0 , stream>>> (
0 commit comments