// Refactor op_dequantize.cpp to avoid code duplication

/// Dequantizes `numel` contiguous int8 values into floats using a single
/// scale / zero-point pair:
///
///   out[i] = (in[i] - zero_point) * scale
///
/// Declared `inline` because the definition lives in a header.
///
/// @param in         Pointer to `numel` readable int8 values.
/// @param scale      Multiplier applied after the zero point is subtracted.
/// @param zero_point Value subtracted from every input element.
/// @param out        Pointer to `numel` writable floats.
/// @param quant_min  Not read by this implementation; kept so the signature
///                   stays unchanged for callers.
/// @param quant_max  Not read by this implementation; kept so the signature
///                   stays unchanged for callers.
/// @param numel      Number of elements to convert. Zero is a no-op.
inline void dequantize_optimized(
    const int8_t* in,
    const float scale,
    const int8_t zero_point,
    float* out,
    int64_t quant_min,
    int64_t quant_max,
    size_t numel) {
  (void)quant_min;
  (void)quant_max;

  size_t i = 0;
#if defined(__aarch64__) || defined(__ARM_NEON)
  // Process 16 int8 inputs per iteration; `i` is always an element index, so
  // the scalar loop below picks up exactly where this one stops.
  constexpr size_t kVecSize = 16;
  const int8x8_t zero_point_vec = vdup_n_s8(zero_point);
  const float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
  for (; numel - i >= kVecSize; i += kVecSize) {
    const int8x16_t in_vec = vld1q_s8(in + i);

    // Widening subtract (int8 -> int16) so `in - zero_point` cannot wrap,
    // then widen to int32 and convert to float before scaling.
    const int16x8_t sub_vec_0_7 =
        vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
    const int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
    const int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
    const float32x4_t out_vec_0_3 =
        vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
    const float32x4_t out_vec_4_7 =
        vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);

    const int16x8_t sub_vec_8_15 =
        vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
    const int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
    const int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
    const float32x4_t out_vec_8_11 =
        vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
    const float32x4_t out_vec_12_15 =
        vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);

    float* out_ptr = out + i;
    vst1q_f32(out_ptr + 0, out_vec_0_3);
    vst1q_f32(out_ptr + 4, out_vec_4_7);
    vst1q_f32(out_ptr + 8, out_vec_8_11);
    vst1q_f32(out_ptr + 12, out_vec_12_15);
  }
#endif
  // Scalar path: the whole range on non-NEON builds, the tail otherwise.
  for (; i < numel; ++i) {
    out[i] =
        (static_cast<float>(in[i]) - static_cast<float>(zero_point)) * scale;
  }
}

/// Dequantizes int8 data where every channel has its own scale and zero
/// point. For each (outer_idx, channel_idx) pair, `channel_size` contiguous
/// elements are converted with `dequantize_optimized`.
///
/// Declared `inline` because the definition lives in a header.
///
/// @param in_data            Base pointer of the int8 input.
/// @param scales_data        Per-channel scales; channel `c` reads
///                           `scales_data[c * qparams_stride]`.
/// @param zero_points_data   Per-channel zero points; channel `c` reads
///                           `zero_points_data[c * qparams_stride]`.
/// @param out_data           Base pointer of the float output.
/// @param quant_min          Forwarded unchanged to `dequantize_optimized`.
/// @param quant_max          Forwarded unchanged to `dequantize_optimized`.
/// @param outer_size         Number of outer iterations.
/// @param in_outer_stride    Input element offset between outer iterations.
/// @param out_outer_stride   Output element offset between outer iterations.
/// @param num_channels       Number of channels per outer iteration.
/// @param in_channel_stride  Input element offset between channels.
/// @param out_channel_stride Output element offset between channels.
/// @param channel_size       Contiguous elements converted per channel.
/// @param qparams_stride     Element stride used to index both the scales
///                           and the zero points.
inline void dequantize_per_channel_optimized(
    const int8_t* in_data,
    const float* scales_data,
    const int8_t* zero_points_data,
    float* out_data,
    int64_t quant_min,
    int64_t quant_max,
    size_t outer_size,
    size_t in_outer_stride,
    size_t out_outer_stride,
    size_t num_channels,
    size_t in_channel_stride,
    size_t out_channel_stride,
    size_t channel_size,
    size_t qparams_stride) {
  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
    const int8_t* in_outer = in_data + outer_idx * in_outer_stride;
    float* out_outer = out_data + outer_idx * out_outer_stride;
    for (size_t channel_idx = 0; channel_idx < num_channels; ++channel_idx) {
      // Quantization parameters depend only on the channel, not on the
      // outer index.
      const size_t qparams_offset = channel_idx * qparams_stride;
      dequantize_optimized(
          in_outer + channel_idx * in_channel_stride,
          scales_data[qparams_offset],
          zero_points_data[qparams_offset],
          out_outer + channel_idx * out_channel_stride,
          quant_min,
          quant_max,
          channel_size);
    }
  }
}
static_cast(v_data.data), - v_stride_n /*rhs_stride_n*/, - o_data, - o_stride_m /*out_stride_n*/, - static_cast(v_data.zero_points), - static_cast(v_data.scales), + qk_stride_m, beta, - v_data.zero_points_stride); + o_data, + o_stride_m); } else { ET_CHECK_MSG( false, "Accumulation in dtype other than float not supported yet");