@@ -2156,6 +2156,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2156
2156
"GELU" ,
2157
2157
"SILU" ,
2158
2158
"NORM" ,
2159
+ "RMS_NORM" ,
2159
2160
2160
2161
"MUL_MAT" ,
2161
2162
@@ -2176,7 +2177,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2176
2177
"FLASH_FF" ,
2177
2178
};
2178
2179
2179
- static_assert (GGML_OP_COUNT == 34 , "GGML_OP_COUNT != 34 " );
2180
+ static_assert (GGML_OP_COUNT == 35 , "GGML_OP_COUNT != 35 " );
2180
2181
2181
2182
static const char * GGML_OP_SYMBOL [GGML_OP_COUNT ] = {
2182
2183
"none" ,
@@ -2199,6 +2200,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2199
2200
"gelu(x)" ,
2200
2201
"silu(x)" ,
2201
2202
"norm(x)" ,
2203
+ "rms_norm(x)" ,
2202
2204
2203
2205
"X*Y" ,
2204
2206
@@ -2219,7 +2221,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2219
2221
"flash_ff(x)" ,
2220
2222
};
2221
2223
2222
- static_assert (GGML_OP_COUNT == 34 , "GGML_OP_COUNT != 34 " );
2224
+ static_assert (GGML_OP_COUNT == 35 , "GGML_OP_COUNT != 35 " );
2223
2225
2224
2226
//
2225
2227
// ggml object
@@ -3705,6 +3707,39 @@ struct ggml_tensor * ggml_norm_inplace(
3705
3707
return ggml_norm_impl (ctx , a , true);
3706
3708
}
3707
3709
3710
+ struct ggml_tensor * ggml_rms_norm_impl (
3711
+ struct ggml_context * ctx ,
3712
+ struct ggml_tensor * a ,
3713
+ bool inplace ) {
3714
+ bool is_node = false;
3715
+
3716
+ if (!inplace && (a -> grad )) {
3717
+ GGML_ASSERT (false); // TODO: implement backward
3718
+ is_node = true;
3719
+ }
3720
+
3721
+ struct ggml_tensor * result = inplace ? ggml_view_tensor (ctx , a ) : ggml_dup_tensor (ctx , a );
3722
+
3723
+ result -> op = GGML_OP_RMS_NORM ;
3724
+ result -> grad = is_node ? ggml_dup_tensor (ctx , result ) : NULL ;
3725
+ result -> src0 = a ;
3726
+ result -> src1 = NULL ; // TODO: maybe store epsilon here?
3727
+
3728
+ return result ;
3729
+ }
3730
+
3731
// RMS-normalize `a` into a freshly allocated result tensor.
struct ggml_tensor * ggml_rms_norm(
        struct ggml_context * ctx,
        struct ggml_tensor * a) {
    return ggml_rms_norm_impl(ctx, a, /*inplace=*/false);
}
3736
+
3737
// RMS-normalize `a` in place; the result is a view of `a`.
struct ggml_tensor * ggml_rms_norm_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor * a) {
    return ggml_rms_norm_impl(ctx, a, /*inplace=*/true);
}
3742
+
3708
3743
// ggml_mul_mat
3709
3744
3710
3745
struct ggml_tensor * ggml_mul_mat (
@@ -5493,6 +5528,87 @@ static void ggml_compute_forward_norm(
5493
5528
}
5494
5529
}
5495
5530
5531
+ static void ggml_compute_forward_rms_norm_f32 (
5532
+ const struct ggml_compute_params * params ,
5533
+ const struct ggml_tensor * src0 ,
5534
+ struct ggml_tensor * dst ) {
5535
+ GGML_ASSERT (ggml_are_same_shape (src0 , dst ));
5536
+
5537
+ if (params -> type == GGML_TASK_INIT || params -> type == GGML_TASK_FINALIZE ) {
5538
+ return ;
5539
+ }
5540
+
5541
+ GGML_ASSERT (src0 -> nb [0 ] == sizeof (float ));
5542
+
5543
+ const int ith = params -> ith ;
5544
+ const int nth = params -> nth ;
5545
+
5546
+ const int ne00 = src0 -> ne [0 ];
5547
+ const int ne01 = src0 -> ne [1 ];
5548
+ const int ne02 = src0 -> ne [2 ];
5549
+ const int ne03 = src0 -> ne [3 ];
5550
+
5551
+ const size_t nb01 = src0 -> nb [1 ];
5552
+ const size_t nb02 = src0 -> nb [2 ];
5553
+ const size_t nb03 = src0 -> nb [3 ];
5554
+
5555
+ const size_t nb1 = dst -> nb [1 ];
5556
+ const size_t nb2 = dst -> nb [2 ];
5557
+ const size_t nb3 = dst -> nb [3 ];
5558
+
5559
+ const ggml_float eps = 1e-5f ; // TODO: make this a parameter
5560
+
5561
+ // TODO: optimize
5562
+ for (int i03 = 0 ; i03 < ne03 ; i03 ++ ) {
5563
+ for (int i02 = 0 ; i02 < ne02 ; i02 ++ ) {
5564
+ for (int i01 = ith ; i01 < ne01 ; i01 += nth ) {
5565
+ const float * x = (float * ) ((char * ) src0 -> data + i01 * nb01 + i02 * nb02 + i03 * nb03 );
5566
+
5567
+ ggml_float mean = 0.0 ;
5568
+ for (int i00 = 0 ; i00 < ne00 ; i00 ++ ) {
5569
+ mean += x [i00 ] * x [i00 ];
5570
+ }
5571
+
5572
+ mean /= ne00 ;
5573
+
5574
+ float * y = (float * ) ((char * ) dst -> data + i01 * nb1 + i02 * nb2 + i03 * nb3 );
5575
+
5576
+ memcpy (y , x , ne00 * sizeof (float ));
5577
+ // for (int i00 = 0; i00 < ne00; i00++) {
5578
+ // y[i00] = x[i00];
5579
+ // }
5580
+
5581
+ const float scale = 1.0 /sqrt (mean + eps );
5582
+
5583
+ ggml_vec_scale_f32 (ne00 , y , scale );
5584
+ }
5585
+ }
5586
+ }
5587
+ }
5588
+
5589
+ static void ggml_compute_forward_rms_norm (
5590
+ const struct ggml_compute_params * params ,
5591
+ const struct ggml_tensor * src0 ,
5592
+ struct ggml_tensor * dst ) {
5593
+ switch (src0 -> type ) {
5594
+ case GGML_TYPE_F32 :
5595
+ {
5596
+ ggml_compute_forward_rms_norm_f32 (params , src0 , dst );
5597
+ } break ;
5598
+ case GGML_TYPE_Q4_0 :
5599
+ case GGML_TYPE_Q4_1 :
5600
+ case GGML_TYPE_I8 :
5601
+ case GGML_TYPE_I16 :
5602
+ case GGML_TYPE_I32 :
5603
+ case GGML_TYPE_F16 :
5604
+ case GGML_TYPE_COUNT :
5605
+ {
5606
+ GGML_ASSERT (false);
5607
+ } break ;
5608
+ }
5609
+ }
5610
+
5611
+
5496
5612
// ggml_compute_forward_mul_mat
5497
5613
5498
5614
#if defined(GGML_USE_ACCELERATE ) || defined(GGML_USE_OPENBLAS )
@@ -8609,6 +8725,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
8609
8725
{
8610
8726
ggml_compute_forward_norm (params , tensor -> src0 , tensor );
8611
8727
} break ;
8728
+ case GGML_OP_RMS_NORM :
8729
+ {
8730
+ ggml_compute_forward_rms_norm (params , tensor -> src0 , tensor );
8731
+ } break ;
8612
8732
case GGML_OP_MUL_MAT :
8613
8733
{
8614
8734
ggml_compute_forward_mul_mat (params , tensor -> src0 , tensor -> src1 , tensor );
@@ -8851,6 +8971,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
8851
8971
{
8852
8972
GGML_ASSERT (false); // TODO: not implemented
8853
8973
} break ;
8974
+ case GGML_OP_RMS_NORM :
8975
+ {
8976
+ GGML_ASSERT (false); // TODO: not implemented
8977
+ } break ;
8854
8978
case GGML_OP_MUL_MAT :
8855
8979
{
8856
8980
if (src0 -> grad ) {
0 commit comments