@@ -3099,21 +3099,13 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
3099
3099
const int8x16_t v1_1h = vld1q_s8 (y1 -> qs + 16 );
3100
3100
3101
3101
#if defined(__ARM_FEATURE_DOTPROD )
3102
- const float32x4_t x0_0d = vdupq_n_f32 (GGML_FP16_TO_FP32 (x0_0 -> d ));
3103
- const float32x4_t x0_1d = vdupq_n_f32 (GGML_FP16_TO_FP32 (x0_1 -> d ));
3104
- const float32x4_t x1_0d = vdupq_n_f32 (GGML_FP16_TO_FP32 (x1_0 -> d ));
3105
- const float32x4_t x1_1d = vdupq_n_f32 (GGML_FP16_TO_FP32 (x1_1 -> d ));
3102
+ sumv0 = vmlaq_n_f32 (sumv0 , vaddq_f32 (
3103
+ vmulq_n_f32 (vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0lz , v1_0l )), GGML_FP16_TO_FP32 (x0_0 -> d )),
3104
+ vmulq_n_f32 (vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0hz , v1_0h )), GGML_FP16_TO_FP32 (x0_1 -> d ))), y0 -> d );
3106
3105
3107
- const float32x4_t y0d = vdupq_n_f32 (y0 -> d );
3108
- const float32x4_t y1d = vdupq_n_f32 (y1 -> d );
3109
-
3110
- sumv0 = vaddq_f32 (sumv0 , vmulq_f32 (y0d , vaddq_f32 (
3111
- vmulq_f32 (x0_0d , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0lz , v1_0l ))),
3112
- vmulq_f32 (x0_1d , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0hz , v1_0h ))))));
3113
-
3114
- sumv1 = vaddq_f32 (sumv1 , vmulq_f32 (y1d , vaddq_f32 (
3115
- vmulq_f32 (x1_0d , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1lz , v1_1l ))),
3116
- vmulq_f32 (x1_1d , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1hz , v1_1h ))))));
3106
+ sumv1 = vmlaq_n_f32 (sumv1 , vaddq_f32 (
3107
+ vmulq_n_f32 (vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1lz , v1_1l )), GGML_FP16_TO_FP32 (x1_0 -> d )),
3108
+ vmulq_n_f32 (vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1hz , v1_1h )), GGML_FP16_TO_FP32 (x1_1 -> d ))), y1 -> d );
3117
3109
#else
3118
3110
const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0lz ), vget_low_s8 (v1_0l ));
3119
3111
const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0lz ), vget_high_s8 (v1_0l ));
@@ -3130,12 +3122,13 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
3130
3122
const int32x4_t pl1 = vaddq_s32 (vpaddlq_s16 (pl1l ), vpaddlq_s16 (pl1h ));
3131
3123
const int32x4_t ph1 = vaddq_s32 (vpaddlq_s16 (ph1l ), vpaddlq_s16 (ph1h ));
3132
3124
3133
- sumv0 = vaddq_f32 (sumv0 , vmulq_f32 (vdupq_n_f32 (y0 -> d ), vaddq_f32 (
3134
- vmulq_f32 (vdupq_n_f32 (GGML_FP16_TO_FP32 (x0_0 -> d )), vcvtq_f32_s32 (pl0 )),
3135
- vmulq_f32 (vdupq_n_f32 (GGML_FP16_TO_FP32 (x0_1 -> d )), vcvtq_f32_s32 (ph0 )))));
3136
- sumv1 = vaddq_f32 (sumv1 , vmulq_f32 (vdupq_n_f32 (y1 -> d ), vaddq_f32 (
3137
- vmulq_f32 (vdupq_n_f32 (GGML_FP16_TO_FP32 (x1_0 -> d )), vcvtq_f32_s32 (pl1 )),
3138
- vmulq_f32 (vdupq_n_f32 (GGML_FP16_TO_FP32 (x1_1 -> d )), vcvtq_f32_s32 (ph1 )))));
3125
+ sumv0 = vmlaq_n_f32 (sumv0 , vaddq_f32 (
3126
+ vmulq_n_f32 (vcvtq_f32_s32 (pl0 ), GGML_FP16_TO_FP32 (x0_0 -> d )),
3127
+ vmulq_n_f32 (vcvtq_f32_s32 (ph0 ), GGML_FP16_TO_FP32 (x0_1 -> d ))), y0 -> d );
3128
+
3129
+ sumv1 = vmlaq_n_f32 (sumv1 , vaddq_f32 (
3130
+ vmulq_n_f32 (vcvtq_f32_s32 (pl1 ), GGML_FP16_TO_FP32 (x1_0 -> d )),
3131
+ vmulq_n_f32 (vcvtq_f32_s32 (ph1 ), GGML_FP16_TO_FP32 (x1_1 -> d ))), y1 -> d );
3139
3132
#endif
3140
3133
}
3141
3134
0 commit comments