Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -218,21 +218,7 @@ void kernel_impl(
if constexpr (has_clamp) {
res = clamp(res, clamp_min, clamp_max);
}

// Store result
int remaining = n - n_idx;
float* store_loc = output + m_idx * output_m_stride + n_idx;
if (remaining >= 4) {
vst1q_f32(store_loc, res);
} else if (remaining >= 3) {
vst1_f32(store_loc, vget_low_f32(res));
*(store_loc + 2) = res[2];
} else if (remaining >= 2) {
vst1_f32(store_loc, vget_low_f32(res));
} else {
*(store_loc) = res[0];
}

vst1q_f32(output + m_idx * output_m_stride + n_idx, res);
} // n_idx
activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr);
} // m_idx
Original file line number Diff line number Diff line change
@@ -290,34 +290,8 @@ void kernel_impl(
res_0123 = vec_clamp(res_0123, vec_min, vec_max);
res_4567 = vec_clamp(res_4567, vec_min, vec_max);
}

// Store result
int remaining = n - n_idx;
float* store_loc = output + m_idx * output_m_stride + n_idx;
if (remaining >= 8) {
vst1q_f32(store_loc, res_0123);
vst1q_f32(store_loc + 4, res_4567);
} else if (remaining >= 7) {
vst1q_f32(store_loc, res_0123);
vst1_f32(store_loc + 4, vget_low_f32(res_4567));
*(store_loc + 6) = res_4567[2];
} else if (remaining >= 6) {
vst1q_f32(store_loc, res_0123);
vst1_f32(store_loc + 4, vget_low_f32(res_4567));
} else if (remaining >= 5) {
vst1q_f32(store_loc, res_0123);
*(store_loc + 4) = res_4567[0];
} else if (remaining >= 4) {
vst1q_f32(store_loc, res_0123);
} else if (remaining >= 3) {
vst1_f32(store_loc, vget_low_f32(res_0123));
*(store_loc + 2) = res_0123[2];
} else if (remaining >= 2) {
vst1_f32(store_loc, vget_low_f32(res_0123));
} else {
*store_loc = res_0123[0];
}

vst1q_f32(output + m_idx * output_m_stride + n_idx, res_0123);
vst1q_f32(output + m_idx * output_m_stride + n_idx + 4, res_4567);
} // n_idx
activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr);
} // m_idx
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
#include <cassert>

int32_t torchao::kernels::cpu::aarch64::reduction::compute_sum(
const int8_t* vals,
int size) {
assert(size >= 1);

int32_t res = 0;
int i = 0;

#pragma unroll(4)
for (; i + 15 < size; i += 16) {
for (; i < size; i += 16) {
int8x16_t vec_vals = vld1q_s8(vals + i);
res += (int)(vaddlvq_s8(vec_vals));
}
Original file line number Diff line number Diff line change
@@ -1,33 +1,23 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
#include <cassert>

void torchao::kernels::cpu::aarch64::reduction::find_min_and_max(
float32_t& min,
float32_t& max,
const float32_t* vals,
int size) {
assert(size > 0);

// Needed in case size < 4 so we don't compare to
// uninitialized min/max values
min = vals[0];
max = min;

float32x4_t mins = vdupq_n_f32(0.0);
float32x4_t maxes = vdupq_n_f32(0.0);
int i = 0;
if (i + 3 < size) {
float32x4_t mins = vld1q_f32(vals + i);
float32x4_t maxes = mins;
i += 4;
for (; i + 3 < size; i += 4) {
float32x4_t v = vld1q_f32(vals + i);
mins = vminq_f32(mins, v);
maxes = vmaxq_f32(maxes, v);
}
min = vminvq_f32(mins);
max = vmaxvq_f32(maxes);
for (; i < size; i += 8) {
float32x4_t v1 = vld1q_f32(vals + i);
float32x4_t v2 = vld1q_f32(vals + i + 4);
mins = vminq_f32(v1, v2);
maxes = vmaxq_f32(v1, v2);
}
min = vminvq_f32(mins);
max = vmaxvq_f32(maxes);

// Remainder
while (i < size) {
9 changes: 0 additions & 9 deletions torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -35,14 +35,6 @@ target_link_libraries(
dep
)

add_executable(test_reduction test_reduction.cpp)
target_link_libraries(
test_reduction
PRIVATE
GTest::gtest_main
dep
)

add_executable(test_bitpacking test_bitpacking.cpp)
target_link_libraries(
test_bitpacking
@@ -69,7 +61,6 @@ target_link_libraries(

include(GoogleTest)
gtest_discover_tests(test_quantization)
gtest_discover_tests(test_reduction)
gtest_discover_tests(test_bitpacking)
gtest_discover_tests(test_linear)
gtest_discover_tests(test_valpacking)
Original file line number Diff line number Diff line change
@@ -7,8 +7,7 @@ cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/e
cmake --build ${CMAKE_OUT}

# Run
${CMAKE_OUT}/test_quantization
${CMAKE_OUT}/test_reduction
${CMAKE_OUT}/test_bitpacking
${CMAKE_OUT}/test_linear
${CMAKE_OUT}/test_valpacking
${CMAKE_OUT}/test_quantization
${CMAKE_OUT}/test_bitpacking
${CMAKE_OUT}/test_linear
${CMAKE_OUT}/test_valpacking
233 changes: 129 additions & 104 deletions torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp
Original file line number Diff line number Diff line change
@@ -10,11 +10,12 @@
float kTol = 0.0001;

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot(
int m,
int k,
int n,
int group_size) {
void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot() {
int m = 7;
int k = 128;
int n = 13;
int group_size = 32;

auto test_case = torchao::
channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate(
m,
@@ -49,7 +50,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot
test_case.weight_scales.data(),
/*weight_zeros=*/test_case.weight_zeros.data());

std::vector<float> output(m * n);
std::vector<float> output(m * k);
kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>(
output.data(),
/*output_m_stride=*/n,
@@ -71,53 +72,70 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot
TEST(
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot,
Standard) {
constexpr int weight_nbit = 4;
constexpr bool has_weight_zeros = false;
constexpr bool has_bias = false;
constexpr bool has_clamp = false;

test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>(
/*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32);
weight_nbit,
has_weight_zeros,
has_bias,
has_clamp>();
}

TEST(
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot,
HasWeightZeros) {
constexpr int weight_nbit = 4;
constexpr bool has_weight_zeros = true;
constexpr bool has_bias = false;
constexpr bool has_clamp = false;

test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot<
4 /*weight_nbit*/,
true /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>(
/*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32);
weight_nbit,
has_weight_zeros,
has_bias,
has_clamp>();
}

TEST(
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot,
HasBias) {
constexpr int weight_nbit = 4;
constexpr bool has_weight_zeros = false;
constexpr bool has_bias = true;
constexpr bool has_clamp = false;

test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
true /*has_bias*/,
false /*has_clamp*/>(
/*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32);
weight_nbit,
has_weight_zeros,
has_bias,
has_clamp>();
}

TEST(
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot,
HasClamp) {
constexpr int weight_nbit = 4;
constexpr bool has_weight_zeros = false;
constexpr bool has_bias = false;
constexpr bool has_clamp = true;

test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
true /*has_clamp*/>(
/*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32);
weight_nbit,
has_weight_zeros,
has_bias,
has_clamp>();
}

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot(
int m,
int k,
int n,
int group_size) {
void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot() {
int m = 7;
int k = 64;
int n = 13;
int group_size = 16;

auto test_case = torchao::
channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate(
m,
@@ -152,7 +170,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot
test_case.weight_scales.data(),
/*weight_zeros=*/test_case.weight_zeros.data());

std::vector<float> output(m * n);
std::vector<float> output(m * k);
kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>(
output.data(),
/*output_m_stride=*/n,
@@ -174,66 +192,70 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot
TEST(
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot,
Standard) {
constexpr int weight_nbit = 4;
constexpr bool has_weight_zeros = false;
constexpr bool has_bias = false;
constexpr bool has_clamp = false;

test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>(
/*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16);
weight_nbit,
has_weight_zeros,
has_bias,
has_clamp>();
}

TEST(
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot,
HasWeightZeros) {
constexpr int weight_nbit = 4;
constexpr bool has_weight_zeros = true;
constexpr bool has_bias = false;
constexpr bool has_clamp = false;

test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot<
4 /*weight_nbit*/,
true /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>(
/*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16);
weight_nbit,
has_weight_zeros,
has_bias,
has_clamp>();
}

TEST(
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot,
HasBias) {
constexpr int weight_nbit = 4;
constexpr bool has_weight_zeros = false;
constexpr bool has_bias = true;
constexpr bool has_clamp = false;

test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
true /*has_bias*/,
false /*has_clamp*/>(
/*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16);
weight_nbit,
has_weight_zeros,
has_bias,
has_clamp>();
}

TEST(
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot,
HasClamp) {
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
true /*has_clamp*/>(
/*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16);
}
constexpr int weight_nbit = 4;
constexpr bool has_weight_zeros = false;
constexpr bool has_bias = false;
constexpr bool has_clamp = true;

TEST(
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot,
NLessThan4) {
for (int n = 1; n < 4; n++) {
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
true /*has_clamp*/>(
/*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16);
}
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot<
weight_nbit,
has_weight_zeros,
has_bias,
has_clamp>();
}

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot(
int m,
int k,
int n,
int group_size) {
void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot() {
int m = 7;
int k = 64;
int n = 13;
int group_size = 16;

auto test_case = torchao::
channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate(
m,
@@ -268,7 +290,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot
test_case.weight_scales.data(),
/*weight_zeros=*/test_case.weight_zeros.data());

std::vector<float> output(m * n);
std::vector<float> output(m * k);
kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>(
output.data(),
/*output_m_stride=*/n,
@@ -290,56 +312,59 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot
TEST(
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot,
Standard) {
constexpr int weight_nbit = 4;
constexpr bool has_weight_zeros = false;
constexpr bool has_bias = false;
constexpr bool has_clamp = false;

test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>(
/*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16);
weight_nbit,
has_weight_zeros,
has_bias,
has_clamp>();
}

TEST(
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot,
HasWeightZeros) {
constexpr int weight_nbit = 4;
constexpr bool has_weight_zeros = true;
constexpr bool has_bias = false;
constexpr bool has_clamp = false;

test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot<
4 /*weight_nbit*/,
true /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>(
/*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16);
weight_nbit,
has_weight_zeros,
has_bias,
has_clamp>();
}

TEST(
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot,
HasBias) {
constexpr int weight_nbit = 4;
constexpr bool has_weight_zeros = false;
constexpr bool has_bias = true;
constexpr bool has_clamp = false;

test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
true /*has_bias*/,
false /*has_clamp*/>(
/*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16);
weight_nbit,
has_weight_zeros,
has_bias,
has_clamp>();
}

TEST(
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot,
HasClamp) {
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
true /*has_clamp*/>(
/*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16);
}
constexpr int weight_nbit = 4;
constexpr bool has_weight_zeros = false;
constexpr bool has_bias = false;
constexpr bool has_clamp = true;

TEST(
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot,
NLessThan8) {
for (int n = 1; n < 8; n++) {
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot<
4 /*weight_nbit*/,
false /*has_weight_zeros*/,
false /*has_bias*/,
true /*has_clamp*/>(
/*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16);
}
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot<
weight_nbit,
has_weight_zeros,
has_bias,
has_clamp>();
}