extension/llm/custom_ops/op_sdpa_impl.h (122 changes: 108 additions & 14 deletions)
@@ -120,6 +120,88 @@ void _q_at_k_gemm(
}
}

// Refactor op_dequantize.cpp to avoid code duplication
void dequantize_optimized(
    const int8_t* in,
    const float scale,
    const int8_t zero_point,
    float* out,
    int64_t quant_min,
    int64_t quant_max,
    size_t numel) {
  size_t i = 0;
#if defined(__aarch64__) || defined(__ARM_NEON)
  int8x8_t zero_point_vec = vdup_n_s8(zero_point);
  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
  constexpr int32_t kVecSize = 16;
  const size_t num_vecs = numel / kVecSize;
  const int8_t* in_copy = in;
  float* out_copy = out;
  // NEON fast path: dequantize 16 int8 values per iteration by subtracting the
  // zero point while widening to int16/int32, converting to float, and scaling.
  for (; i < num_vecs; i++) {
    int8x16_t in_vec = vld1q_s8(in_copy);
    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);

    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
    vst1q_f32(out_copy + 0, out_vec_0_3);
    vst1q_f32(out_copy + 4, out_vec_4_7);
    vst1q_f32(out_copy + 8, out_vec_8_11);
    vst1q_f32(out_copy + 12, out_vec_12_15);
    in_copy += kVecSize;
    out_copy += kVecSize;
  }
  // Convert the vector-loop counter into an element index for the scalar tail.
  i = i * kVecSize;
#endif
  for (; i < numel; i++) {
    out[i] = (static_cast<int16_t>(in[i]) - static_cast<int16_t>(zero_point)) *
        scale;
  }
}
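
// Illustrative usage (not part of this diff): a minimal, hypothetical check
// that the NEON fast path and the scalar tail of dequantize_optimized agree
// when numel is not a multiple of 16. Exact float comparison is safe here
// because both paths perform the same int-to-float conversion and multiply.
inline void dequantize_optimized_example() {
  constexpr size_t kNumel = 19; // 16 vectorized elements + 3 scalar-tail ones
  const float scale = 0.05f;
  const int8_t zero_point = 3;
  int8_t in[kNumel];
  float out[kNumel];
  for (size_t j = 0; j < kNumel; ++j) {
    in[j] = static_cast<int8_t>(static_cast<int>(j) - 9);
  }
  dequantize_optimized(in, scale, zero_point, out, -128, 127, kNumel);
  for (size_t j = 0; j < kNumel; ++j) {
    const float expected = (static_cast<int>(in[j]) - zero_point) * scale;
    ET_CHECK_MSG(out[j] == expected, "dequantize mismatch at index %zu", j);
  }
}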

void dequantize_per_channel_optimized(
    const int8_t* in_data,
    const float* scales_data,
    const int8_t* zero_points_data,
    float* out_data,
    int64_t quant_min,
    int64_t quant_max,
    size_t outer_size,
    size_t in_outer_stride,
    size_t out_outer_stride,
    size_t num_channels,
    size_t in_channel_stride,
    size_t out_channel_stride,
    size_t channel_size,
    size_t qparams_stride) {
  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
    // Loop through the quantized dim: each channel has its own scale and
    // zero point, read with a stride of qparams_stride.
    for (size_t channel_idx = 0; channel_idx < num_channels; ++channel_idx) {
      const int8_t* in_data_local = in_data + outer_idx * in_outer_stride +
          channel_idx * in_channel_stride;
      const float scale = *(scales_data + channel_idx * qparams_stride);
      const int8_t zero_point =
          *(zero_points_data + channel_idx * qparams_stride);
      float* out_data_local = out_data + outer_idx * out_outer_stride +
          channel_idx * out_channel_stride;
      dequantize_optimized(
          in_data_local,
          scale,
          zero_point,
          out_data_local,
          quant_min,
          quant_max,
          channel_size);
    }
  }
}
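
// Illustrative usage (not part of this diff): dequantizing a hypothetical
// contiguous [num_channels x channel_size] int8 matrix that was quantized per
// row, i.e. one scale/zero point per channel stored contiguously
// (qparams_stride = 1). This mirrors how the SDPA code below calls it, with
// outer_size = 1 and outer strides of 0.
inline void dequantize_per_channel_optimized_example() {
  constexpr size_t kChannels = 2;
  constexpr size_t kChannelSize = 4;
  const int8_t in[kChannels * kChannelSize] = {-4, -2, 0, 2, 10, 20, 30, 40};
  const float scales[kChannels] = {0.5f, 0.1f};
  const int8_t zero_points[kChannels] = {0, 10};
  float out[kChannels * kChannelSize];
  dequantize_per_channel_optimized(
      in,
      scales,
      zero_points,
      out,
      /*quant_min=*/-128,
      /*quant_max=*/127,
      /*outer_size=*/1,
      /*in_outer_stride=*/0,
      /*out_outer_stride=*/0,
      /*num_channels=*/kChannels,
      /*in_channel_stride=*/kChannelSize,
      /*out_channel_stride=*/kChannelSize,
      /*channel_size=*/kChannelSize,
      /*qparams_stride=*/1);
  // Row 0: (x - 0) * 0.5 -> {-2, -1, 0, 1}
  // Row 1: (x - 10) * 0.1 -> {0, 1, 2, 3}
}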

template <typename accum_t>
void _qk_at_v_gemm(
const int64_t m,
@@ -134,24 +216,36 @@ void _qk_at_v_gemm(
     const accum_t beta) {
   if (v_data.dtype == ScalarType::Char) {
     if constexpr (std::is_same<accum_t, float>::value) {
-      int a_stride_m_tmp, b_stride_n_tmp;
-      auto kernel = torchao::kernels::cpu::quantized_matmul::
-          get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
-              m, n, k, false, false, a_stride_m_tmp, b_stride_n_tmp);
-      kernel(
-          m,
+      std::vector<float> dequantized_v_data(v_data.m * v_data.n);
+      dequantize_per_channel_optimized(
+          static_cast<const int8_t*>(v_data.data),
+          static_cast<const float*>(v_data.scales),
+          static_cast<const int8_t*>(v_data.zero_points),
+          dequantized_v_data.data(),
+          -128,
+          127,
+          1,
+          0,
+          0,
+          v_data.m,
+          v_stride_n,
+          v_data.n,
+          v_data.n,
+          v_data.zero_points_stride);
+      ::executorch::cpublas::gemm(
+          ::executorch::cpublas::TransposeType::NoTranspose,
+          ::executorch::cpublas::TransposeType::NoTranspose,
           n,
+          m,
           k,
+          static_cast<accum_t>(1),
+          dequantized_v_data.data(),
+          v_data.n,
           qk_data,
-          qk_stride_m /*lhs_stride_m*/,
-          static_cast<const int8_t*>(v_data.data),
-          v_stride_n /*rhs_stride_n*/,
-          o_data,
-          o_stride_m /*out_stride_n*/,
-          static_cast<const int8_t*>(v_data.zero_points),
-          static_cast<const float*>(v_data.scales),
+          qk_stride_m,
           beta,
-          v_data.zero_points_stride);
+          o_data,
+          o_stride_m);
     } else {
       ET_CHECK_MSG(
           false, "Accumulation in dtype other than float not supported yet");