Skip to content

Commit 5843b45

Browse files
committed
ggml : optimize q4_2 using vmlaq_n_f32 + vmulq_n_f32
1 parent 3a79089 commit 5843b45

File tree

1 file changed

+13
-20
lines changed

1 file changed

+13
-20
lines changed

ggml.c

Lines changed: 13 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -3099,21 +3099,13 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
30993099
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
31003100

31013101
#if defined(__ARM_FEATURE_DOTPROD)
3102-
const float32x4_t x0_0d = vdupq_n_f32(GGML_FP16_TO_FP32(x0_0->d));
3103-
const float32x4_t x0_1d = vdupq_n_f32(GGML_FP16_TO_FP32(x0_1->d));
3104-
const float32x4_t x1_0d = vdupq_n_f32(GGML_FP16_TO_FP32(x1_0->d));
3105-
const float32x4_t x1_1d = vdupq_n_f32(GGML_FP16_TO_FP32(x1_1->d));
3102+
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
3103+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
3104+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
31063105

3107-
const float32x4_t y0d = vdupq_n_f32(y0->d);
3108-
const float32x4_t y1d = vdupq_n_f32(y1->d);
3109-
3110-
sumv0 = vaddq_f32(sumv0, vmulq_f32(y0d, vaddq_f32(
3111-
vmulq_f32(x0_0d, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l))),
3112-
vmulq_f32(x0_1d, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h))))));
3113-
3114-
sumv1 = vaddq_f32(sumv1, vmulq_f32(y1d, vaddq_f32(
3115-
vmulq_f32(x1_0d, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l))),
3116-
vmulq_f32(x1_1d, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h))))));
3106+
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3107+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
3108+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
31173109
#else
31183110
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
31193111
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
@@ -3130,12 +3122,13 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
31303122
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
31313123
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
31323124

3133-
sumv0 = vaddq_f32(sumv0, vmulq_f32(vdupq_n_f32(y0->d), vaddq_f32(
3134-
vmulq_f32(vdupq_n_f32(GGML_FP16_TO_FP32(x0_0->d)), vcvtq_f32_s32(pl0)),
3135-
vmulq_f32(vdupq_n_f32(GGML_FP16_TO_FP32(x0_1->d)), vcvtq_f32_s32(ph0)))));
3136-
sumv1 = vaddq_f32(sumv1, vmulq_f32(vdupq_n_f32(y1->d), vaddq_f32(
3137-
vmulq_f32(vdupq_n_f32(GGML_FP16_TO_FP32(x1_0->d)), vcvtq_f32_s32(pl1)),
3138-
vmulq_f32(vdupq_n_f32(GGML_FP16_TO_FP32(x1_1->d)), vcvtq_f32_s32(ph1)))));
3125+
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
3126+
vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
3127+
vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3128+
3129+
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3130+
vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
3131+
vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
31393132
#endif
31403133
}
31413134

0 commit comments

Comments
 (0)