@@ -2069,6 +2069,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "GELU",
     "SILU",
     "NORM",
+    "RMS_NORM",
 
     "MUL_MAT",
 
@@ -2089,7 +2090,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_FF",
 };
 
-static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
+static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2112,6 +2113,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "gelu(x)",
     "silu(x)",
     "norm(x)",
+    "rms_norm(x)",
 
     "X*Y",
 
@@ -2132,7 +2134,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_ff(x)",
 };
 
-static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
+static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
 
 //
 // ggml object
@@ -3618,6 +3620,39 @@ struct ggml_tensor * ggml_norm_inplace(
     return ggml_norm_impl(ctx, a, true);
 }
 
+struct ggml_tensor * ggml_rms_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_RMS_NORM;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store epsilon here?
+
+    return result;
+}
+
+struct ggml_tensor * ggml_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_rms_norm_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_rms_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_rms_norm_impl(ctx, a, true);
+}
+
 // ggml_mul_mat
 
 struct ggml_tensor * ggml_mul_mat(
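For context, here is a minimal usage sketch of the new public API. This is not
part of the commit; it assumes the ggml context/graph helpers as they exist at
this point in the tree (ggml_init, ggml_new_tensor_1d, ggml_build_forward,
ggml_graph_compute), and the buffer size is an arbitrary choice:

// hypothetical standalone example, for illustration only
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024, // 16 MB arena for tensors + graph
        /*.mem_buffer =*/ NULL,         // let ggml allocate it
    };
    struct ggml_context * ctx = ggml_init(params);

    // one row of 8 floats
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    for (int i = 0; i < 8; i++) {
        ((float *) x->data)[i] = (float) (i + 1);
    }

    // y[i] = x[i] / sqrt(mean(x^2) + 1e-5)
    struct ggml_tensor * y = ggml_rms_norm(ctx, x);

    struct ggml_cgraph gf = ggml_build_forward(y);
    ggml_graph_compute(ctx, &gf);

    ggml_free(ctx);
    return 0;
}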
@@ -5406,6 +5441,87 @@ static void ggml_compute_forward_norm(
     }
 }
 
+static void ggml_compute_forward_rms_norm_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    const ggml_float eps = 1e-5f; // TODO: make this a parameter
+
+    // TODO: optimize
+    for (int i03 = 0; i03 < ne03; i03++) {
+        for (int i02 = 0; i02 < ne02; i02++) {
+            for (int i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float mean = 0.0;
+                for (int i00 = 0; i00 < ne00; i00++) {
+                    mean += x[i00] * x[i00];
+                }
+
+                mean /= ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+                // for (int i00 = 0; i00 < ne00; i00++) {
+                //     y[i00] = x[i00];
+                // }
+
+                const float scale = 1.0/sqrt(mean + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rms_norm(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rms_norm_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+
 // ggml_compute_forward_mul_mat
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
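Unlike ggml_norm (LayerNorm-style: subtract the row mean, then divide by the
standard deviation), the kernel above skips mean subtraction and only rescales
each row by the reciprocal of its root mean square, i.e.
y[i] = x[i] / sqrt(mean(x^2) + eps). A reference scalar version of the per-row
computation, as a sketch for illustration (not part of the commit):

#include <math.h>

// normalize one contiguous row of n floats by its RMS
static void rms_norm_row(const float * x, float * y, int n, float eps) {
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        sum += x[i]*x[i];               // sum of squares
    }
    const float scale = 1.0f/sqrtf(sum/n + eps);
    for (int i = 0; i < n; i++) {
        y[i] = x[i]*scale;              // scale only, no mean subtraction
    }
}

The real kernel instead copies the row with memcpy and calls
ggml_vec_scale_f32, so the existing vectorized scale routine does the work,
and it strides rows by nth so the threads of the compute pool split ne01.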
@@ -8522,6 +8638,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_norm(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_RMS_NORM:
+            {
+                ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
@@ -8764,6 +8884,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_RMS_NORM:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 if (src0->grad) {
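The backward pass is left as a TODO (both ggml_rms_norm_impl and
ggml_compute_backward assert on it). For reference, the gradient a future
implementation would need, a sketch using the same per-row mean and epsilon as
the forward kernel, is:

\[
\frac{\partial L}{\partial x_j}
  = \frac{1}{r}\,\frac{\partial L}{\partial y_j}
  - \frac{x_j}{n\,r^{3}} \sum_{i=1}^{n} \frac{\partial L}{\partial y_i}\, x_i,
\qquad
r = \sqrt{\frac{1}{n}\sum_{k=1}^{n} x_k^{2} + \varepsilon}.
\]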