Commit 40f2c98

hoangmit authored and bitRAKE committed
Add RMS norm and use it (ggml-org#187)
* add ggml_rms_norm
* update op num
1 parent e47da88 commit 40f2c98
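For reference, the new GGML_OP_RMS_NORM op scales each row by the reciprocal root of its mean squared value, with no mean subtraction and no bias term (unlike GGML_OP_NORM). With the epsilon hard-coded in this commit, a row of length n is mapped to:

\mathrm{rms\_norm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \varepsilon}}, \qquad \varepsilon = 10^{-5}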

File tree

3 files changed: +134, -5 lines


ggml.c

Lines changed: 126 additions & 2 deletions
@@ -2069,6 +2069,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "GELU",
     "SILU",
     "NORM",
+    "RMS_NORM",
 
     "MUL_MAT",
 
@@ -2089,7 +2090,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_FF",
 };
 
-static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
+static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2112,6 +2113,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "gelu(x)",
     "silu(x)",
     "norm(x)",
+    "rms_norm(x)",
 
     "X*Y",
 
@@ -2132,7 +2134,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_ff(x)",
 };
 
-static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
+static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
 
 //
 // ggml object
@@ -3618,6 +3620,39 @@ struct ggml_tensor * ggml_norm_inplace(
     return ggml_norm_impl(ctx, a, true);
 }
 
+struct ggml_tensor * ggml_rms_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_RMS_NORM;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store epsilon here?
+
+    return result;
+}
+
+struct ggml_tensor * ggml_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_rms_norm_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_rms_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_rms_norm_impl(ctx, a, true);
+}
+
 // ggml_mul_mat
 
 struct ggml_tensor * ggml_mul_mat(
@@ -5406,6 +5441,87 @@ static void ggml_compute_forward_norm(
     }
 }
 
+static void ggml_compute_forward_rms_norm_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    const ggml_float eps = 1e-5f; // TODO: make this a parameter
+
+    // TODO: optimize
+    for (int i03 = 0; i03 < ne03; i03++) {
+        for (int i02 = 0; i02 < ne02; i02++) {
+            for (int i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float mean = 0.0;
+                for (int i00 = 0; i00 < ne00; i00++) {
+                    mean += x[i00] * x[i00];
+                }
+
+                mean /= ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+                // for (int i00 = 0; i00 < ne00; i00++) {
+                //     y[i00] = x[i00];
+                // }
+
+                const float scale = 1.0/sqrt(mean + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rms_norm(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rms_norm_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+
 // ggml_compute_forward_mul_mat
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -8522,6 +8638,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_norm(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_RMS_NORM:
+            {
+                ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
@@ -8764,6 +8884,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
             {
                 GGML_ASSERT(false); // TODO: not implemented
            } break;
+        case GGML_OP_RMS_NORM:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 if (src0->grad) {
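For readers skimming the diff, here is a minimal, self-contained sketch of what the new forward kernel does to a single contiguous row of floats, without ggml's stride handling or row-wise threading. The function name rms_norm_row and the test values are illustrative only, not part of the commit:

#include <math.h>
#include <stdio.h>

// Scale one row of n floats by 1/sqrt(mean(x^2) + eps), mirroring
// ggml_compute_forward_rms_norm_f32 above for the contiguous, single-thread case.
static void rms_norm_row(const float * x, float * y, int n) {
    const float eps = 1e-5f; // same hard-coded epsilon as the kernel above

    double mean = 0.0;
    for (int i = 0; i < n; i++) {
        mean += (double) x[i] * x[i]; // sum of squares
    }
    mean /= n;                        // mean of squares

    const float scale = 1.0f / sqrtf((float) mean + eps);
    for (int i = 0; i < n; i++) {
        y[i] = x[i] * scale;          // no mean subtraction, no bias
    }
}

int main(void) {
    const float x[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
    float y[4];
    rms_norm_row(x, y, 4);
    for (int i = 0; i < 4; i++) {
        printf("%.3f\n", y[i]);       // ~0.365 0.730 1.095 1.461
    }
    return 0;
}

Like the kernel, it divides by sqrt(mean(x^2) + eps) and never subtracts the mean, which is what distinguishes RMS norm from the existing GGML_OP_NORM.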

ggml.h

Lines changed: 5 additions & 0 deletions
@@ -230,6 +230,7 @@ enum ggml_op {
     GGML_OP_GELU,
     GGML_OP_SILU,
     GGML_OP_NORM, // normalize
+    GGML_OP_RMS_NORM,
 
     GGML_OP_MUL_MAT,
 
@@ -482,6 +483,10 @@ struct ggml_tensor * ggml_norm(
         struct ggml_context * ctx,
         struct ggml_tensor * a);
 
+struct ggml_tensor * ggml_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a);
+
 // A: m rows, n columns
 // B: p rows, n columns (i.e. we transpose it internally)
 // result is m columns, p rows
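A minimal usage sketch for the new entry point. It assumes the ggml context/graph API as it existed around this commit (ggml_init, ggml_build_forward, ggml_graph_compute); the 16 MB buffer size and the 8-element tensor are arbitrary illustrative choices:

#include <stdio.h>
#include "ggml.h"

int main(void) {
    // small scratch context; 16 MB is an arbitrary size for this example
    struct ggml_init_params params = {
        .mem_size   = 16*1024*1024,
        .mem_buffer = NULL,
    };
    struct ggml_context * ctx = ggml_init(params);

    // one row of 8 floats, all set to 2.0
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_f32(x, 2.0f);

    // y = rms_norm(x)
    struct ggml_tensor * y = ggml_rms_norm(ctx, x);

    // build and run the forward graph
    struct ggml_cgraph gf = ggml_build_forward(y);
    gf.n_threads = 1;
    ggml_graph_compute(ctx, &gf);

    // mean(x^2) = 4.0, so every output element should be ~1.0
    printf("y[0] = %f\n", ggml_get_f32_1d(y, 0));

    ggml_free(ctx);
    return 0;
}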

main.cpp

Lines changed: 3 additions & 3 deletions
@@ -588,7 +588,7 @@ bool llama_eval(
 
         // norm
         {
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL);
 
             // cur = attention_norm*cur
             cur = ggml_mul(ctx0,
@@ -678,7 +678,7 @@ bool llama_eval(
         {
             // norm
             {
-                cur = ggml_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF);
 
                 // cur = ffn_norm*cur
                 cur = ggml_mul(ctx0,
@@ -713,7 +713,7 @@ bool llama_eval(
 
     // norm
     {
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL);
 
         // inpL = norm*inpL
        inpL = ggml_mul(ctx0,
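Each of these call sites is immediately followed by an elementwise ggml_mul with the corresponding learned weight tensor (attention_norm, ffn_norm, or norm, per the comments above), so every normalization block in llama_eval now computes

\mathrm{out}_i = w_i \cdot \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j} x_j^2 + \varepsilon}}

which is the RMSNorm formulation used by the LLaMA reference model, rather than the mean-subtracting ggml_norm used before.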
