@@ -3057,8 +3057,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
3057
3057
float sumf = 0.0 ;
3058
3058
3059
3059
#if defined(__ARM_NEON )
3060
- float sum0 = 0.0f ;
3061
- float sum1 = 0.0f ;
3060
+ float32x4_t sumv0 = vdupq_n_f32 ( 0.0f ) ;
3061
+ float32x4_t sumv1 = vdupq_n_f32 ( 0.0f ) ;
3062
3062
3063
3063
for (int i = 0 ; i < nb ; i += 2 ) {
3064
3064
const block_q4_2 * restrict x0_0 = & x [2 * (i + 0 ) + 0 ];
@@ -3099,10 +3099,21 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
3099
3099
const int8x16_t v1_1h = vld1q_s8 (y1 -> qs + 16 );
3100
3100
3101
3101
#if defined(__ARM_FEATURE_DOTPROD )
3102
- sum0 += (GGML_FP16_TO_FP32 (x0_0 -> d )* y0 -> d )* vaddvq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0lz , v1_0l ));
3103
- sum0 += (GGML_FP16_TO_FP32 (x0_1 -> d )* y0 -> d )* vaddvq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0hz , v1_0h ));
3104
- sum1 += (GGML_FP16_TO_FP32 (x1_0 -> d )* y1 -> d )* vaddvq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1lz , v1_1l ));
3105
- sum1 += (GGML_FP16_TO_FP32 (x1_1 -> d )* y1 -> d )* vaddvq_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1hz , v1_1h ));
3102
+ const float32x4_t x0_0d = vdupq_n_f32 (GGML_FP16_TO_FP32 (x0_0 -> d ));
3103
+ const float32x4_t x0_1d = vdupq_n_f32 (GGML_FP16_TO_FP32 (x0_1 -> d ));
3104
+ const float32x4_t x1_0d = vdupq_n_f32 (GGML_FP16_TO_FP32 (x1_0 -> d ));
3105
+ const float32x4_t x1_1d = vdupq_n_f32 (GGML_FP16_TO_FP32 (x1_1 -> d ));
3106
+
3107
+ const float32x4_t y0d = vdupq_n_f32 (y0 -> d );
3108
+ const float32x4_t y1d = vdupq_n_f32 (y1 -> d );
3109
+
3110
+ sumv0 = vaddq_f32 (sumv0 , vmulq_f32 (y0d , vaddq_f32 (
3111
+ vmulq_f32 (x0_0d , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0lz , v1_0l ))),
3112
+ vmulq_f32 (x0_1d , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0hz , v1_0h ))))));
3113
+
3114
+ sumv1 = vaddq_f32 (sumv1 , vmulq_f32 (y1d , vaddq_f32 (
3115
+ vmulq_f32 (x1_0d , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1lz , v1_1l ))),
3116
+ vmulq_f32 (x1_1d , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1hz , v1_1h ))))));
3106
3117
#else
3107
3118
const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0lz ), vget_low_s8 (v1_0l ));
3108
3119
const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0lz ), vget_high_s8 (v1_0l ));
@@ -3119,14 +3130,16 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
3119
3130
const int32x4_t pl1 = vaddq_s32 (vpaddlq_s16 (pl1l ), vpaddlq_s16 (pl1h ));
3120
3131
const int32x4_t ph1 = vaddq_s32 (vpaddlq_s16 (ph1l ), vpaddlq_s16 (ph1h ));
3121
3132
3122
- sum0 += (GGML_FP16_TO_FP32 (x0_0 -> d )* y0 -> d )* vaddvq_s32 (pl0 );
3123
- sum0 += (GGML_FP16_TO_FP32 (x0_1 -> d )* y0 -> d )* vaddvq_s32 (ph0 );
3124
- sum1 += (GGML_FP16_TO_FP32 (x1_0 -> d )* y1 -> d )* vaddvq_s32 (pl1 );
3125
- sum1 += (GGML_FP16_TO_FP32 (x1_1 -> d )* y1 -> d )* vaddvq_s32 (ph1 );
3133
+ sumv0 = vaddq_f32 (sumv0 , vmulq_f32 (vdupq_n_f32 (y0 -> d ), vaddq_f32 (
3134
+ vmulq_f32 (vdupq_n_f32 (GGML_FP16_TO_FP32 (x0_0 -> d )), vcvtq_f32_s32 (pl0 )),
3135
+ vmulq_f32 (vdupq_n_f32 (GGML_FP16_TO_FP32 (x0_1 -> d )), vcvtq_f32_s32 (ph0 )))));
3136
+ sumv1 = vaddq_f32 (sumv1 , vmulq_f32 (vdupq_n_f32 (y1 -> d ), vaddq_f32 (
3137
+ vmulq_f32 (vdupq_n_f32 (GGML_FP16_TO_FP32 (x1_0 -> d )), vcvtq_f32_s32 (pl1 )),
3138
+ vmulq_f32 (vdupq_n_f32 (GGML_FP16_TO_FP32 (x1_1 -> d )), vcvtq_f32_s32 (ph1 )))));
3126
3139
#endif
3127
3140
}
3128
3141
3129
- sumf = sum0 + sum1 ;
3142
+ sumf = vaddvq_f32 ( sumv0 ) + vaddvq_f32 ( sumv1 ) ;
3130
3143
#else
3131
3144
// scalar
3132
3145
for (int i = 0 ; i < nb ; i ++ ) {
0 commit comments