Commit 9cd27fa

hoangmit authored and blackhole89 committed
Add RMS norm and use it (#187)
* add ggml_rms_norm
* update op num
1 parent 23ac443 commit 9cd27fa
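
For reference, the RMS norm added here rescales each row of n elements by the reciprocal root-mean-square of the row, with the epsilon hard-coded in the kernel below:

    y_i = \frac{x_i}{\sqrt{\tfrac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}}, \qquad \epsilon = 10^{-5}

Unlike ggml_norm, no mean is subtracted; the learned per-channel weights are applied separately via ggml_mul in main.cpp.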

File tree

3 files changed: +134 −5 lines changed

ggml.c
ggml.h
main.cpp

ggml.c

Lines changed: 126 additions & 2 deletions
@@ -2156,6 +2156,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "GELU",
     "SILU",
     "NORM",
+    "RMS_NORM",

     "MUL_MAT",

@@ -2176,7 +2177,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_FF",
 };

-static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
+static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2199,6 +2200,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "gelu(x)",
     "silu(x)",
     "norm(x)",
+    "rms_norm(x)",

     "X*Y",

@@ -2219,7 +2221,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_ff(x)",
 };

-static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
+static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");

 //
 // ggml object
@@ -3705,6 +3707,39 @@ struct ggml_tensor * ggml_norm_inplace(
     return ggml_norm_impl(ctx, a, true);
 }

+struct ggml_tensor * ggml_rms_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_RMS_NORM;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store epsilon here?
+
+    return result;
+}
+
+struct ggml_tensor * ggml_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_rms_norm_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_rms_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_rms_norm_impl(ctx, a, true);
+}
+
 // ggml_mul_mat

 struct ggml_tensor * ggml_mul_mat(
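
A minimal usage sketch of the new API, assuming the ggml context setup of this era (the buffer size and thread count are arbitrary illustration values, not part of the patch):

#include "ggml.h"

int main(void) {
    // set up a small ggml context (size chosen arbitrarily for illustration)
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
    };
    struct ggml_context * ctx = ggml_init(params);

    // a 1-D f32 tensor of 8 elements, normalized with the new op
    struct ggml_tensor * x = ggml_set_f32(ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8), 2.0f);
    struct ggml_tensor * y = ggml_rms_norm(ctx, x);

    // build and run the forward graph
    struct ggml_cgraph gf = ggml_build_forward(y);
    gf.n_threads = 1;
    ggml_graph_compute(ctx, &gf);

    ggml_free(ctx);
    return 0;
}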
@@ -5493,6 +5528,87 @@ static void ggml_compute_forward_norm(
     }
 }

+static void ggml_compute_forward_rms_norm_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    const ggml_float eps = 1e-5f; // TODO: make this a parameter
+
+    // TODO: optimize
+    for (int i03 = 0; i03 < ne03; i03++) {
+        for (int i02 = 0; i02 < ne02; i02++) {
+            for (int i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float mean = 0.0;
+                for (int i00 = 0; i00 < ne00; i00++) {
+                    mean += x[i00] * x[i00];
+                }
+
+                mean /= ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+                // for (int i00 = 0; i00 < ne00; i00++) {
+                //     y[i00] = x[i00];
+                // }
+
+                const float scale = 1.0/sqrt(mean + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rms_norm(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rms_norm_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+
 // ggml_compute_forward_mul_mat

 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -8609,6 +8725,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_norm(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_RMS_NORM:
+            {
+                ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
@@ -8851,6 +8971,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_RMS_NORM:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 if (src0->grad) {
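
For testing, the per-row computation above reduces to a short scalar reference (rms_norm_ref is a hypothetical helper for illustration, not part of the patch):

#include <math.h>
#include <stddef.h>

// scalar reference for one row of n floats, mirroring
// ggml_compute_forward_rms_norm_f32 (eps matches the hard-coded 1e-5)
static void rms_norm_ref(float * y, const float * x, size_t n) {
    const float eps = 1e-5f;

    double mean = 0.0;
    for (size_t i = 0; i < n; i++) {
        mean += (double) x[i] * x[i];  // sum of squares
    }
    mean /= (double) n;                // mean square

    const float scale = 1.0f / sqrtf((float) mean + eps);
    for (size_t i = 0; i < n; i++) {
        y[i] = x[i] * scale;           // scale row by reciprocal RMS
    }
}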

ggml.h

Lines changed: 5 additions & 0 deletions
@@ -230,6 +230,7 @@ enum ggml_op {
     GGML_OP_GELU,
     GGML_OP_SILU,
     GGML_OP_NORM, // normalize
+    GGML_OP_RMS_NORM,

     GGML_OP_MUL_MAT,

@@ -482,6 +483,10 @@ struct ggml_tensor * ggml_norm(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);

+struct ggml_tensor * ggml_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
 // A: m rows, n columns
 // B: p rows, n columns (i.e. we transpose it internally)
 // result is m columns, p rows

main.cpp

Lines changed: 3 additions & 3 deletions
@@ -588,7 +588,7 @@ bool llama_eval(

         // norm
         {
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL);

             // cur = attention_norm*cur
             cur = ggml_mul(ctx0,
@@ -678,7 +678,7 @@ bool llama_eval(
         {
             // norm
             {
-                cur = ggml_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF);

                 // cur = ffn_norm*cur
                 cur = ggml_mul(ctx0,
@@ -713,7 +713,7 @@ bool llama_eval(

     // norm
     {
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL);

         // inpL = norm*inpL
         inpL = ggml_mul(ctx0,
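
Together with the ggml_mul that follows each call, every normalization site in llama_eval now computes, with w denoting the corresponding attention_norm, ffn_norm, or output-norm weights:

    \text{out} = w \odot \frac{x}{\sqrt{\tfrac{1}{n}\sum_{j} x_j^2 + \epsilon}}

i.e. a weighted RMSNorm rather than the mean-centered ggml_norm used previously.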
