diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt
new file mode 100644
index 0000000000..198e9ebd44
--- /dev/null
+++ b/torchao/experimental/CMakeLists.txt
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# cmake_minimum_required must precede project() so that policy defaults are
+# established before the project and its languages are configured.
+cmake_minimum_required(VERSION 3.19)
+
+project(torchao)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+# Default to an optimized build when the caller did not choose a build type.
+if (NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE Release)
+endif()
+
+# Source root directory for torchao/experimental
+if(NOT TORCHAO_ROOT)
+  set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
+endif()
+
+# Include root: two levels above TORCHAO_ROOT unless the caller overrides it.
+if(NOT TORCHAO_INCLUDE_DIRS)
+  set(TORCHAO_INCLUDE_DIRS ${TORCHAO_ROOT}/../..)
+endif()
+
+# Pick a default parallel backend from TORCHAO_OP_TARGET when the caller did
+# not set TORCHAO_PARALLEL_BACKEND. Stop at configure time if neither is given:
+# the value is required below and an empty one only fails later and obscurely.
+if (NOT TORCHAO_PARALLEL_BACKEND)
+  if ("${TORCHAO_OP_TARGET}" STREQUAL "ATEN")
+    set(TORCHAO_PARALLEL_BACKEND "ATEN_OPENMP")
+  elseif("${TORCHAO_OP_TARGET}" STREQUAL "EXECUTORCH")
+    set(TORCHAO_PARALLEL_BACKEND "PTHREADPOOL")
+  else()
+    message(FATAL_ERROR "TORCHAO_PARALLEL_BACKEND is not set. Please set it directly or set TORCHAO_OP_TARGET to get a default.")
+  endif()
+endif()
+
+include(CMakePrintHelpers)
+
+add_compile_options("-Wall" "-Werror")
+
+message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
+include_directories(${TORCHAO_INCLUDE_DIRS})
+
+# NOTE(review): only "arm64" is matched; platforms that report "aarch64" skip
+# these targets -- confirm whether that is intended.
+if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+  # Defines target torchao_kernels_aarch64
+  add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64)
+  add_subdirectory(${TORCHAO_ROOT}/ops/linear)
+  add_subdirectory(${TORCHAO_ROOT}/ops/linear/linear_a8wxdq_op)
+endif()
diff --git a/torchao/experimental/kernels/cpu/Utils.cmake b/torchao/experimental/Utils.cmake similarity index 100% rename from torchao/experimental/kernels/cpu/Utils.cmake rename to torchao/experimental/Utils.cmake diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh b/torchao/experimental/build_torchao_ops.sh similarity index 51% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh rename to torchao/experimental/build_torchao_ops.sh index c657857fcc..de6d8e17d8 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh +++ b/torchao/experimental/build_torchao_ops.sh @@ -5,15 +5,14 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_INCLUDE_DIRS=${SCRIPT_DIR}/../../../../../../..
-
 export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')"
 echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
 export CMAKE_OUT=/tmp/cmake-out/torchao
-cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \
-    -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
-    -DPLATFORM="ATEN" \
-    -S ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
+cmake -DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
+    -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
+    -DTORCHAO_OP_TARGET="$1" \
+    -DEXECUTORCH_LIBRARIES="${EXECUTORCH_LIBRARIES}" \
+    -DEXECUTORCH_INCLUDE_DIRS="${EXECUTORCH_INCLUDE_DIRS}" \
+    -S "$(dirname -- "$0")" \
     -B ${CMAKE_OUT}
-cmake --build ${CMAKE_OUT}
+cmake --build "${CMAKE_OUT}" --target install --config Release  # build, then install into CMAKE_OUT
diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt index a13737d874..ec497a1871 100644 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt @@ -4,10 +4,12 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree.
 
-add_library(
-  kernel_aarch64
-  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
-  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
-  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
-  ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
-)
+if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+  add_library(
+    torchao_kernels_aarch64
+    ${CMAKE_CURRENT_SOURCE_DIR}/reduction/find_min_and_max.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/reduction/compute_sum.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/quantization/quantize.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/valpacking/interleave.cpp
+  )
+endif()  # on any other processor the target torchao_kernels_aarch64 is not defined
diff --git a/torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt deleted file mode 100644 index 61e5eeae27..0000000000 --- a/torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree.
- -cmake_minimum_required(VERSION 3.19) -project(benchmarks) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) - -include(FetchContent) -FetchContent_Declare(googlebenchmark - GIT_REPOSITORY https://github.com/google/benchmark.git - GIT_TAG main) # need main for benchmark::benchmark - -set(BENCHMARK_ENABLE_TESTING OFF) -FetchContent_MakeAvailable( - googlebenchmark) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) - -add_library( - dep - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp -) - -add_executable(benchmark_linear_operator benchmark_linear_operator.cpp) -target_link_libraries( - benchmark_linear_operator - PRIVATE - benchmark::benchmark - dep -) - -option(TORCHAO_PARALLEL_OMP "" OFF) -option(TORCHAO_PARALLEL_SINGLE_THREADED "" ON) - -if (TORCHAO_PARALLEL_OMP) - message("OpenMP_ROOT: ${OpenMP_ROOT}") - add_definitions(-DTORCHAO_PARALLEL_OMP=1) - find_package(OpenMP REQUIRED) - if(OpenMP_CXX_FOUND) - target_link_libraries(benchmark_linear_operator PUBLIC OpenMP::OpenMP_CXX) - endif() -endif() - -if (TORCHAO_PARALLEL_SINGLE_THREADED) - add_definitions(-DTORCHAO_PARALLEL_SINGLE_THREADED=1) -endif() diff --git a/torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt deleted file mode 100644 index 4489dc7c36..0000000000 --- a/torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -project(examples) - -cmake_minimum_required(VERSION 3.19) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) - -add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64) - -add_executable(separate_function_wrappers separate_function_wrappers.cpp) -target_link_libraries( - separate_function_wrappers - PRIVATE - kernel_aarch64 -) - -add_executable(stateful_class_wrapper stateful_class_wrapper.cpp) -target_link_libraries( - stateful_class_wrapper - PRIVATE - kernel_aarch64 -) - -include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake) - -target_link_torchao_parallel_backend(stateful_class_wrapper "openmp") -target_link_torchao_parallel_backend(separate_function_wrappers "openmp") diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt deleted file mode 100644 index 10e44a79a8..0000000000 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -project(torch_custom_op) - -cmake_minimum_required(VERSION 3.19) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") -include_directories(${TORCHAO_INCLUDE_DIRS}) - -add_subdirectory(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64) - -include(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/Utils.cmake) - -set(PLATFORM "ATEN" CACHE STRING "Choose platform surface: ATEN, EXECUTORCH") -string(TOUPPER ${PLATFORM} PLATFORM_TO_UPPER) - -if(PLATFORM_TO_UPPER STREQUAL "ATEN") -message(STATUS "Building with PLATFORM=ATEN") - -find_package(Torch REQUIRED) -add_library(lowbit_op_aten SHARED lowbit_op_aten.cpp) -target_link_libraries(lowbit_op_aten PRIVATE kernel_aarch64) -target_include_directories(lowbit_op_aten PRIVATE "${TORCH_INCLUDE_DIRS}") -target_link_libraries(lowbit_op_aten PRIVATE "${TORCH_LIBRARIES}") -target_compile_definitions(lowbit_op_aten PRIVATE USE_ATEN=1) -target_link_torchao_parallel_backend(lowbit_op_aten "ATEN_OPENMP") - -elseif(PLATFORM_TO_UPPER STREQUAL "EXECUTORCH") -message(STATUS "Building with PLATFORM=EXECUTORCH") - -add_library(lowbit_op_executorch SHARED - lowbit_op_executorch/w2s.cpp - lowbit_op_executorch/w2sz.cpp - lowbit_op_executorch/w3s.cpp - lowbit_op_executorch/w3sz.cpp - lowbit_op_executorch/w4s.cpp - lowbit_op_executorch/w4sz.cpp - lowbit_op_executorch/w5s.cpp - lowbit_op_executorch/w5sz.cpp -) -target_include_directories(lowbit_op_executorch PRIVATE ${EXECUTORCH_INCLUDE_DIRS}) -target_compile_definitions(lowbit_op_executorch PRIVATE USE_EXECUTORCH=1) -target_link_torchao_parallel_backend(lowbit_op_executorch "SINGLE_THREADED") -target_link_libraries(lowbit_op_executorch PRIVATE ${EXECUTORCH_LIBRARIES}) -target_link_libraries(lowbit_op_executorch PRIVATE kernel_aarch64) - -else() -message(FATAL_ERROR "Unknown PLATFORM: 
${PLATFORM}. Please choose one of: ATEN, EXECUTORCH.") -endif() diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py deleted file mode 100644 index e3d96df63c..0000000000 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import copy -import glob -import os - -import sys - -import torch - -sys.path.insert( - 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")) -) -from quant_api import Int8DynActIntxWeightQuantizer - -libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*") -libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) -torch.ops.load_library(libs[0]) - -group_size = 256 -m = 1 -n = 4096 -k = 4096 -nbit = 4 -has_weight_zeros = False -n_layers = 5 - -print("Creating random model") -layers = [torch.nn.Linear(k, n, bias=False) for _ in range(n_layers)] -model = torch.nn.Sequential(*layers) -model = model.eval() - -print("Quantizing random model") -quantized_model = copy.deepcopy(model) -quantizer = Int8DynActIntxWeightQuantizer( - device="cpu", - precision=torch.float32, - bitwidth=nbit, - groupsize=group_size, - has_weight_zeros=has_weight_zeros, -) -quantized_model = quantizer.quantize(quantized_model) -quantized_model = quantized_model.eval() - -print("Creating random activations") -activations = torch.randn(m, k, dtype=torch.float32) - -print("Exporting quantized model") -exported = torch.export.export(quantized_model, (activations,)) - -print("Using torch.compile on quantized model") -quantized_model_compiled = torch.compile(quantized_model) -with torch.no_grad(): - quantized_model_compiled(activations) - 
-print("Compiling quantized model with AOTI") -torch._export.aot_compile( - quantized_model, - (activations,), - options={"aot_inductor.output_path": "/tmp/torch_custom_op_example_model.so"}, -) - -print("Running AOTI") -fn = torch._export.aot_load("/tmp/torch_custom_op_example_model.so", "cpu") -fn(activations) diff --git a/torchao/experimental/kernels/cpu/linear/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/tests/CMakeLists.txt deleted file mode 100644 index 3a415d8edd..0000000000 --- a/torchao/experimental/kernels/cpu/linear/tests/CMakeLists.txt +++ /dev/null @@ -1,41 +0,0 @@ -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -cmake_minimum_required(VERSION 3.19) -project(tests) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Debug) - -include(FetchContent) -FetchContent_Declare( - googletest - URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip -) -FetchContent_MakeAvailable(googletest) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) - -add_library( - dep - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp -) - -enable_testing() - -add_definitions(-DTORCHAO_PARALLEL_TEST_DUMMY=1) -add_executable(test_linear_operator test_linear_operator.cpp) -target_link_libraries( - test_linear_operator - PRIVATE - GTest::gtest_main - dep -) - -include(GoogleTest) -gtest_discover_tests(test_linear_operator) diff --git a/torchao/experimental/kernels/cpu/linear/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/linear/tests/build_and_run_tests.sh deleted 
file mode 100644 index ad9a855084..0000000000 --- a/torchao/experimental/kernels/cpu/linear/tests/build_and_run_tests.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. -export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests -cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/tests -B ${CMAKE_OUT} - -cmake --build ${CMAKE_OUT} - -# Run -${CMAKE_OUT}/test_linear_operator
diff --git a/torchao/experimental/ops/linear/CMakeLists.txt b/torchao/experimental/ops/linear/CMakeLists.txt
new file mode 100644
index 0000000000..2f7b91bbf9
--- /dev/null
+++ b/torchao/experimental/ops/linear/CMakeLists.txt
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+cmake_minimum_required(VERSION 3.19)
+
+# Fall back to torchao/experimental (two levels up) when the including
+# project did not set TORCHAO_ROOT, so Utils.cmake is always resolvable.
+if(NOT TORCHAO_ROOT)
+  set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
+endif()
+include(${TORCHAO_ROOT}/Utils.cmake)
+
+# One static library per parallel backend; the backend name is part of the
+# target name (torchao_ops_linear_<BACKEND>).
+add_library(torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} STATIC channelwise_8bit_activation_groupwise_lowbit_weight.cpp)
+target_link_torchao_parallel_backend(torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} "${TORCHAO_PARALLEL_BACKEND}")
diff --git a/torchao/experimental/ops/linear/benchmarks/CMakeLists.txt b/torchao/experimental/ops/linear/benchmarks/CMakeLists.txt
new file mode 100644
index 0000000000..70d6bf2cba
--- /dev/null
+++ b/torchao/experimental/ops/linear/benchmarks/CMakeLists.txt
@@ -0,0 +1,50 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+cmake_minimum_required(VERSION 3.19)
+project(benchmarks)
+
+set(CMAKE_CXX_STANDARD 17)
+# Benchmarks default to an optimized build, but an explicit choice wins.
+if(NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE Release)
+endif()
+
+# TORCHAO_ROOT is three directories above this file; the include root is five.
+set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
+set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..)
+
+include(FetchContent)
+FetchContent_Declare(googlebenchmark
+        GIT_REPOSITORY https://github.com/google/benchmark.git
+        GIT_TAG main) # need main for benchmark::benchmark
+
+set(BENCHMARK_ENABLE_TESTING OFF)
+FetchContent_MakeAvailable(
+  googlebenchmark)
+
+# Enable strict warnings only after the third-party benchmark targets have been
+# created, so -Werror applies to torchao code and not to the fetched sources.
+add_compile_options("-Wall" "-Werror")
+
+include_directories(${TORCHAO_INCLUDE_DIRS})
+
+# Included explicitly because target_link_torchao_parallel_backend() is called below.
+include(${TORCHAO_ROOT}/Utils.cmake)
+
+set(TORCHAO_PARALLEL_BACKEND "OPENMP")
+add_subdirectory(${TORCHAO_ROOT}/ops/linear ${CMAKE_CURRENT_BINARY_DIR}/torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND})
+add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)
+
+add_executable(benchmark_linear_operator benchmark_linear_operator.cpp)
+target_link_libraries(
+  benchmark_linear_operator
+  PRIVATE
+  benchmark::benchmark
+  torchao_kernels_aarch64
+  torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}
+)
+target_link_torchao_parallel_backend(benchmark_linear_operator "${TORCHAO_PARALLEL_BACKEND}")
diff --git a/torchao/experimental/kernels/cpu/linear/benchmarks/benchmark_linear_operator.cpp b/torchao/experimental/ops/linear/benchmarks/benchmark_linear_operator.cpp similarity index 77% rename from torchao/experimental/kernels/cpu/linear/benchmarks/benchmark_linear_operator.cpp rename to torchao/experimental/ops/linear/benchmarks/benchmark_linear_operator.cpp index ad6563eabe..8d7cd4a908 100644 --- a/torchao/experimental/kernels/cpu/linear/benchmarks/benchmark_linear_operator.cpp +++ b/torchao/experimental/ops/linear/benchmarks/benchmark_linear_operator.cpp @@ -5,11 +5,40 @@ // LICENSE file in the root directory of this source tree.
#include +#include #include -#include -#include +#include +#include +#include #include +using namespace torchao::ops::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight; + +template +UKernelConfig get_ukernel_config() { + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + template static void channelwise_8bit_activation_groupwise_lowbit_weight( benchmark::State& state) { @@ -24,9 +53,6 @@ static void channelwise_8bit_activation_groupwise_lowbit_weight( int num_test_cases = state.range(5); // Initialize config and tiling params - using namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight; - auto ukernel_config = get_ukernel_config(); auto pack_weight_data_tiling_params = @@ -66,7 +92,7 @@ static void channelwise_8bit_activation_groupwise_lowbit_weight( std::vector> packed_weight_data; for (int i = 0; i < test_cases.size(); i++) { - packed_weight_data.emplace_back(torchao::make_aligned_byte_array_unique_ptr( + packed_weight_data.emplace_back(torchao::make_aligned_byte_ptr( packed_weight_data_alignment, packed_weight_data_size)); pack_weight_data_operator( ukernel_config, @@ -91,7 +117,7 @@ static void channelwise_8bit_activation_groupwise_lowbit_weight( size_t activation_data_buffer_alignment = get_activation_data_buffer_alignment(ukernel_config); - auto activation_data_buffer = 
torchao::make_aligned_byte_array_unique_ptr( + auto activation_data_buffer = torchao::make_aligned_byte_ptr( activation_data_buffer_alignment, activation_data_buffer_size); auto output = std::vector(m * n); diff --git a/torchao/experimental/kernels/cpu/linear/benchmarks/build_and_run_benchmarks.sh b/torchao/experimental/ops/linear/benchmarks/build_and_run_benchmarks.sh similarity index 70% rename from torchao/experimental/kernels/cpu/linear/benchmarks/build_and_run_benchmarks.sh rename to torchao/experimental/ops/linear/benchmarks/build_and_run_benchmarks.sh index 18da0e992d..ed80d34e2f 100644 --- a/torchao/experimental/kernels/cpu/linear/benchmarks/build_and_run_benchmarks.sh +++ b/torchao/experimental/ops/linear/benchmarks/build_and_run_benchmarks.sh @@ -7,11 +7,9 @@ # Call script with sh build_and_run_benchmarks.sh {BENCHAMRK} -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. -export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks +export CMAKE_OUT=/tmp/cmake-out/torchao/benchmarks cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/benchmarks \ + -S . 
\ -B ${CMAKE_OUT} \ -DOpenMP_ROOT=$(brew --prefix libomp) \ -DTORCHAO_PARALLEL_OMP=ON diff --git a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h b/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp similarity index 83% rename from torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h rename to torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp index 37ad74b0f0..ae611d3ccc 100644 --- a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h +++ b/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp @@ -4,18 +4,18 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. -#pragma once #include -#include -#include +#include +#include +#include #include #include #include -namespace torchao::operators::cpu::linear:: +namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight { -inline PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( +PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( const UKernelConfig& ukernel_config, int n, int target_panels_per_thread) { @@ -40,7 +40,7 @@ inline PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( return tiling_params; } -inline void pack_weight_data_operator( +void pack_weight_data_operator( const UKernelConfig& ukernel_config, const PackWeightDataTilingParams& tiling_params, // Outputs @@ -81,7 +81,7 @@ inline void pack_weight_data_operator( } // This default mimics XNNPACK behavior if target_tiles_per_thread = 5 -inline LinearTilingParams get_default_linear_tiling_params( +LinearTilingParams get_default_linear_tiling_params( const UKernelConfig& ukernel_config, int m, int n, @@ -118,8 +118,7 @@ inline LinearTilingParams 
get_default_linear_tiling_params( namespace internal { -inline int -get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( +inline int get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( const UKernelConfig& ukernel_config, const LinearTilingParams& tiling_params, int m, @@ -273,7 +272,7 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( } } // namespace internal -inline void linear_operator( +void linear_operator( const UKernelConfig& ukernel_config, const LinearTilingParams& tiling_params, LinearTileSchedulingPolicy scheduling_policy, @@ -333,7 +332,7 @@ inline void linear_operator( } } -inline int get_activation_data_buffer_size( +int get_activation_data_buffer_size( const UKernelConfig& ukernel_config, const LinearTilingParams& tiling_params, LinearTileSchedulingPolicy scheduling_policy, @@ -355,38 +354,4 @@ inline int get_activation_data_buffer_size( } } // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight - -// TODO: may move to different fil or namespace. 
This method is not part of the -// high-level interface, but specific to the universal kernels we wrote in -// torchao -#include -namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight { -template - -inline UKernelConfig get_ukernel_config() { - UKernelConfig config; - - namespace ukernel = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - config.mr = 1; - config.nr = 8; - config.activation_data_size_fn = - &ukernel::activation_data_size; - config.activation_data_alignment = 16; // size of neon register - config.prepare_activation_data_fn = - &ukernel::prepare_activation_data; - config.weight_data_size_fn = - &ukernel::weight_data_size; - config.weight_data_alignment = 16; // size of neon register - config.prepare_weight_data_fn = - &ukernel::prepare_weight_data; - config.kernel_fn = - &ukernel::kernel; - - return config; -} -} // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight - // torchao::kernels::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight + // torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight diff --git a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h b/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h similarity index 81% rename from torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h rename to torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h index 5d8f11b821..c92c94acfb 100644 --- a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h +++ b/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h @@ -7,8 +7,7 @@ #pragma once #include -// TODO: maybe move to operator directory -namespace torchao::operators::cpu::linear:: 
+namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight { struct UKernelConfig { @@ -147,20 +146,4 @@ void linear_operator( float clamp_max); } // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight - -// TODO: may move to different file or namespace -// It is not part of the high-level interface, but specific to the universal -// kernels in torchao. -// Kleidi will need to implement their own get_ukernel_config -// In future, we may build a high-level get_ukernel_config with CPU-runtime -// selection -namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight { -template -UKernelConfig get_ukernel_config(); - -} // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight - -#include + // torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight
diff --git a/torchao/experimental/ops/linear/examples/CMakeLists.txt b/torchao/experimental/ops/linear/examples/CMakeLists.txt
new file mode 100644
index 0000000000..2b69adb3d8
--- /dev/null
+++ b/torchao/experimental/ops/linear/examples/CMakeLists.txt
@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# cmake_minimum_required must come before project() so policies are set first.
+cmake_minimum_required(VERSION 3.19)
+
+project(examples)
+
+set(CMAKE_CXX_STANDARD 17)
+# Examples default to an optimized build, but an explicit choice wins.
+if(NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE Release)
+endif()
+
+include(CMakePrintHelpers)
+
+# TORCHAO_ROOT is three directories above this file; the include root is five.
+set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
+set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..)
+
+include_directories(${TORCHAO_INCLUDE_DIRS})
+
+# Both example executables are built against the OPENMP parallel backend.
+set(TORCHAO_PARALLEL_BACKEND "OPENMP")
+add_subdirectory(${TORCHAO_ROOT}/ops/linear ${CMAKE_CURRENT_BINARY_DIR}/torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND})
+add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)
+
+include(${TORCHAO_ROOT}/Utils.cmake)
+
+add_executable(separate_function_wrappers separate_function_wrappers.cpp)
+target_link_libraries(
+  separate_function_wrappers
+  PRIVATE
+  torchao_kernels_aarch64
+  torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}
+)
+target_link_torchao_parallel_backend(separate_function_wrappers "${TORCHAO_PARALLEL_BACKEND}")
+
+add_executable(stateful_class_wrapper stateful_class_wrapper.cpp)
+target_link_libraries(
+  stateful_class_wrapper
+  PRIVATE
+  torchao_kernels_aarch64
+  torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}
+)
+target_link_torchao_parallel_backend(stateful_class_wrapper "${TORCHAO_PARALLEL_BACKEND}")
diff --git a/torchao/experimental/kernels/cpu/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h b/torchao/experimental/ops/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h similarity index 91% rename from torchao/experimental/kernels/cpu/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h rename to torchao/experimental/ops/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h index 575093f21b..a7755dadf4 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h +++ b/torchao/experimental/ops/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h @@ -5,26 +5,22 @@ // LICENSE file in the root directory of this source tree.
#pragma once -#include -#include -#include +#include +#include +#include #include #include -namespace torchao::operators::cpu::linear:: +namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight { class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator { private: - torchao::aligned_byte_ptr packed_weight_data_{ - nullptr, - nullptr}; + torchao::aligned_byte_ptr packed_weight_data_{nullptr, nullptr}; int packed_weight_data_size_{0}; int packed_weight_data_alignment_{0}; - torchao::aligned_byte_ptr activation_data_buffer_{ - nullptr, - nullptr}; + torchao::aligned_byte_ptr activation_data_buffer_{nullptr, nullptr}; int m_{0}; int n_{0}; @@ -114,7 +110,7 @@ class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator { get_packed_weight_data_size(ukernel_config_, n_, k_, group_size_); auto packed_weight_data_alignment = get_packed_weight_data_alignment(ukernel_config_); - + packed_weight_data_size_ = packed_weight_data_size; packed_weight_data_alignment_ = packed_weight_data_alignment; packed_weight_data_ = torchao::make_aligned_byte_ptr( @@ -199,4 +195,4 @@ class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator { } }; } // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight + // torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight diff --git a/torchao/experimental/kernels/cpu/linear/examples/build_and_run_examples.sh b/torchao/experimental/ops/linear/examples/build_and_run_examples.sh similarity index 67% rename from torchao/experimental/kernels/cpu/linear/examples/build_and_run_examples.sh rename to torchao/experimental/ops/linear/examples/build_and_run_examples.sh index 9c244e54cc..01185fdd3f 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/build_and_run_examples.sh +++ b/torchao/experimental/ops/linear/examples/build_and_run_examples.sh @@ -5,15 +5,11 @@ # This source code is licensed under the license found in the # LICENSE 
file in the root directory of this source tree. -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. - export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" -export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples -cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ - -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples \ +export CMAKE_OUT=/tmp/cmake-out/torchao/examples +cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ + -S . \ -B ${CMAKE_OUT} \ -DOpenMP_ROOT=$(brew --prefix libomp) cmake --build ${CMAKE_OUT} diff --git a/torchao/experimental/kernels/cpu/linear/examples/separate_function_wrappers.cpp b/torchao/experimental/ops/linear/examples/separate_function_wrappers.cpp similarity index 80% rename from torchao/experimental/kernels/cpu/linear/examples/separate_function_wrappers.cpp rename to torchao/experimental/ops/linear/examples/separate_function_wrappers.cpp index ba3e5b29b3..144fe5c08d 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/separate_function_wrappers.cpp +++ b/torchao/experimental/ops/linear/examples/separate_function_wrappers.cpp @@ -4,9 +4,11 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. +#include #include -#include -#include +#include +#include +#include #include // This file contains an example of wrapping the torchao weight packing and // linear operators into two operators: one for weight packing and another @@ -20,9 +22,33 @@ // one stateful class, but not all surfaces support this (see // examples/stateful_class_wrapper.cpp for an example of this). 
-namespace torchao::operators::cpu::linear:: +namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight { +template +UKernelConfig get_ukernel_config() { + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + torchao::aligned_byte_ptr pack_weight_data_operator( UKernelConfig ukernel_config, int n, @@ -115,10 +141,10 @@ void linear_operator( } } // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight + // torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight int main() { - using namespace torchao::operators::cpu::linear:: + using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; torchao::set_num_threads(8); diff --git a/torchao/experimental/kernels/cpu/linear/examples/stateful_class_wrapper.cpp b/torchao/experimental/ops/linear/examples/stateful_class_wrapper.cpp similarity index 71% rename from torchao/experimental/kernels/cpu/linear/examples/stateful_class_wrapper.cpp rename to torchao/experimental/ops/linear/examples/stateful_class_wrapper.cpp index 5fb24c683d..c1cd2d110b 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/stateful_class_wrapper.cpp +++ b/torchao/experimental/ops/linear/examples/stateful_class_wrapper.cpp @@ -4,9 +4,10 @@ // This source code is licensed under the license found in the // LICENSE file in the root 
directory of this source tree. +#include #include -#include -#include +#include +#include #include #include @@ -21,9 +22,33 @@ // examples/separate_function_wrappers.cpp for an example of how to split the // operations into two steps. -using namespace torchao::operators::cpu::linear:: +using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; +template +UKernelConfig get_ukernel_config() { + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + int main() { int m = 13; int n = 4096 + 1; @@ -54,6 +79,7 @@ int main() { std::cout << "Initializing linear_operator." << std::endl; auto ukernel_config = get_ukernel_config(); + auto linear_operator = Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator( ukernel_config, diff --git a/torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt b/torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt new file mode 100644 index 0000000000..f69d884cd8 --- /dev/null +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +cmake_minimum_required(VERSION 3.19) + +include(${TORCHAO_ROOT}/Utils.cmake) + +if(TORCHAO_OP_TARGET STREQUAL "ATEN") + message(STATUS "Building with TORCHAO_OP_TARGET=ATEN") + find_package(Torch REQUIRED) + add_library(linear_a8wxdq_${TORCHAO_OP_TARGET} SHARED linear_a8wxdq_aten.cpp) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE torchao_kernels_aarch64) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}) + target_include_directories(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE "${TORCH_INCLUDE_DIRS}") + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE "${TORCH_LIBRARIES}") + target_compile_definitions(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE USE_ATEN=1) +elseif(TORCHAO_OP_TARGET STREQUAL "EXECUTORCH") + message(STATUS "Building with TORCHAO_OP_TARGET=EXECUTORCH") + add_library(linear_a8wxdq_${TORCHAO_OP_TARGET} SHARED + linear_a8wxdq_executorch/w2s.cpp + linear_a8wxdq_executorch/w2sz.cpp + linear_a8wxdq_executorch/w3s.cpp + linear_a8wxdq_executorch/w3sz.cpp + linear_a8wxdq_executorch/w4s.cpp + linear_a8wxdq_executorch/w4sz.cpp + linear_a8wxdq_executorch/w5s.cpp + linear_a8wxdq_executorch/w5sz.cpp + ) + target_include_directories(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE ${EXECUTORCH_INCLUDE_DIRS}) + target_compile_definitions(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE USE_EXECUTORCH=1) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE ${EXECUTORCH_LIBRARIES}) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE torchao_kernels_aarch64) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}) +else() + message(FATAL_ERROR "Unknown TORCHAO_OP_TARGET: ${TORCHAO_OP_TARGET}. 
Please choose one of: ATEN, EXECUTORCH.") +endif() + + +install( + TARGETS linear_a8wxdq_${TORCHAO_OP_TARGET} + DESTINATION lib +) diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op-impl.h b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h similarity index 87% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op-impl.h rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h index 01b1836981..eee51eafc6 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op-impl.h +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h @@ -5,7 +5,8 @@ // LICENSE file in the root directory of this source tree. #pragma once -#include +#include +#include #include #include @@ -28,6 +29,35 @@ using RuntimeContext = torch::executor::KernelRuntimeContext; #error "Must define either USE_ATEN or USE_EXECUTORCH" #endif +namespace { + +template +inline torchao::ops::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight::UKernelConfig + get_ukernel_config() { + torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight:: + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + #ifdef USE_ATEN template Tensor pack_weights_cpu( @@ -69,7 +99,7 @@ Tensor pack_weights_cpu( weight_zeros_ptr = 
weight_zeros.value().const_data_ptr(); } - using namespace torchao::operators::cpu::linear:: + using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< @@ -137,7 +167,7 @@ Tensor pack_weights_meta( int n = weight_qvals.size(0); int k = weight_qvals.size(1); - using namespace torchao::operators::cpu::linear:: + using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< @@ -221,7 +251,7 @@ Tensor linear_out_cpu( CHECK_MSG(out.size(1) == n, "out shape is incorrect"); #endif // USE_EXECUTORCH - using namespace torchao::operators::cpu::linear:: + using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< @@ -311,3 +341,5 @@ Tensor linear_meta( return torch::empty({m, n}).to("meta"); } #endif // USE_ATEN + +} // namespace diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_aten.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_aten.cpp similarity index 98% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_aten.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_aten.cpp index 626b3e769f..b1d464e5b5 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_aten.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_aten.cpp @@ -4,7 +4,7 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. 
-#include +#include #define DEFINE_OP(weight_nbit) \ m.def( \ diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2s.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2s.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2s.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2s.cpp index 592a0190a9..c6ef089995 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2s.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2s.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2sz.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2sz.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2sz.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2sz.cpp index d2683b36ce..e569e05812 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2sz.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2sz.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3s.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3s.cpp similarity index 90% rename from 
torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3s.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3s.cpp index d59db3e1c7..9f236bd7b3 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3s.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3s.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3sz.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3sz.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3sz.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3sz.cpp index 7458311b91..24a381fdcc 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3sz.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3sz.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4s.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4s.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4s.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4s.cpp index 75143050fa..67263d209d 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4s.cpp +++ 
b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4s.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4sz.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4sz.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4sz.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4sz.cpp index 714192a19b..530ff44370 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4sz.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4sz.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5s.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5s.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5s.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5s.cpp index 08c2d42ee8..de04a09f6a 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5s.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5s.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5sz.cpp 
b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5sz.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5sz.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5sz.cpp index c1e3e953d3..91c5a16312 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5sz.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5sz.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/ops/linear/tests/CMakeLists.txt b/torchao/experimental/ops/linear/tests/CMakeLists.txt new file mode 100644 index 0000000000..866d832ccd --- /dev/null +++ b/torchao/experimental/ops/linear/tests/CMakeLists.txt @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) +project(tests) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_BUILD_TYPE Debug) +add_compile_options("-Wall" "-Werror") + +set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..) 
+ +include(FetchContent) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip +) +FetchContent_MakeAvailable(googletest) +enable_testing() + +include_directories(${TORCHAO_INCLUDE_DIRS}) + +set(TORCHAO_PARALLEL_BACKEND "TEST_DUMMY") +add_subdirectory(${TORCHAO_ROOT}/ops/linear ${CMAKE_CURRENT_BINARY_DIR}/torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}) +add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) + +include(${TORCHAO_ROOT}/Utils.cmake) +add_executable(test_linear_operator test_linear_operator.cpp) +target_link_libraries( + test_linear_operator + PRIVATE + GTest::gtest_main + torchao_kernels_aarch64 + torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} +) +target_link_torchao_parallel_backend(test_linear_operator "${TORCHAO_PARALLEL_BACKEND}") + +include(GoogleTest) +gtest_discover_tests(test_linear_operator) diff --git a/torchao/experimental/ops/linear/tests/build_and_run_tests.sh b/torchao/experimental/ops/linear/tests/build_and_run_tests.sh new file mode 100644 index 0000000000..3fbe78c172 --- /dev/null +++ b/torchao/experimental/ops/linear/tests/build_and_run_tests.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +export CMAKE_OUT=/tmp/cmake-out/torchao/tests +cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S . 
-B ${CMAKE_OUT} + +cmake --build ${CMAKE_OUT} + +# Run +${CMAKE_OUT}/test_linear_operator diff --git a/torchao/experimental/kernels/cpu/linear/tests/test_linear_operator.cpp b/torchao/experimental/ops/linear/tests/test_linear_operator.cpp similarity index 78% rename from torchao/experimental/kernels/cpu/linear/tests/test_linear_operator.cpp rename to torchao/experimental/ops/linear/tests/test_linear_operator.cpp index 5408e426bf..6d563111cc 100644 --- a/torchao/experimental/kernels/cpu/linear/tests/test_linear_operator.cpp +++ b/torchao/experimental/ops/linear/tests/test_linear_operator.cpp @@ -1,22 +1,52 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #include // TODO: move test_utils.h out of aarch64 +#include #include -#include -#include -#include +#include +#include +#include const float kTol = 1.0e-5; +using namespace torchao::ops::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight; + +template +UKernelConfig get_ukernel_config() { + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + template void test_channelwise_8bit_activation_groupwise_lowbit_weight( int m, int n, int k, int group_size) { - using namespace 
torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config(); @@ -47,7 +77,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight( get_packed_weight_data_size(ukernel_config, n, k, group_size); auto packed_weight_data_alignment = get_packed_weight_data_alignment(ukernel_config); - auto packed_weight_data = torchao::make_aligned_byte_array_unique_ptr( + auto packed_weight_data = torchao::make_aligned_byte_ptr( packed_weight_data_alignment, packed_weight_data_size); pack_weight_data_operator( @@ -74,7 +104,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight( group_size); auto activation_data_buffer_alignment = get_activation_data_buffer_alignment(ukernel_config); - auto activation_data_buffer = torchao::make_aligned_byte_array_unique_ptr( + auto activation_data_buffer = torchao::make_aligned_byte_ptr( activation_data_buffer_alignment, activation_data_buffer_size); // Run linear @@ -153,9 +183,6 @@ TEST( int n = 1; int k = 16 + 1; int group_size = 16; - - using namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< 3 /*weight_nbit*/, true /*has_weight_zeros*/, @@ -187,8 +214,6 @@ TEST( int k = 20; int group_size = 10; - using namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< 3 /*weight_nbit*/, true /*has_weight_zeros*/, diff --git a/torchao/experimental/kernels/cpu/macro.h b/torchao/experimental/ops/macro.h similarity index 100% rename from torchao/experimental/kernels/cpu/macro.h rename to torchao/experimental/ops/macro.h diff --git a/torchao/experimental/kernels/cpu/memory.h b/torchao/experimental/ops/memory.h similarity index 100% rename from torchao/experimental/kernels/cpu/memory.h rename to torchao/experimental/ops/memory.h diff --git a/torchao/experimental/kernels/cpu/parallel-aten-impl.h 
b/torchao/experimental/ops/parallel-aten-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-aten-impl.h rename to torchao/experimental/ops/parallel-aten-impl.h diff --git a/torchao/experimental/kernels/cpu/parallel-openmp-impl.h b/torchao/experimental/ops/parallel-openmp-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-openmp-impl.h rename to torchao/experimental/ops/parallel-openmp-impl.h diff --git a/torchao/experimental/kernels/cpu/parallel-pthreadpool-impl.h b/torchao/experimental/ops/parallel-pthreadpool-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-pthreadpool-impl.h rename to torchao/experimental/ops/parallel-pthreadpool-impl.h diff --git a/torchao/experimental/kernels/cpu/parallel-single_threaded-impl.h b/torchao/experimental/ops/parallel-single_threaded-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-single_threaded-impl.h rename to torchao/experimental/ops/parallel-single_threaded-impl.h diff --git a/torchao/experimental/kernels/cpu/parallel-test_dummy-impl.h b/torchao/experimental/ops/parallel-test_dummy-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-test_dummy-impl.h rename to torchao/experimental/ops/parallel-test_dummy-impl.h diff --git a/torchao/experimental/kernels/cpu/parallel.h b/torchao/experimental/ops/parallel.h similarity index 73% rename from torchao/experimental/kernels/cpu/parallel.h rename to torchao/experimental/ops/parallel.h index 0d12c3acf9..e3949b8551 100644 --- a/torchao/experimental/kernels/cpu/parallel.h +++ b/torchao/experimental/ops/parallel.h @@ -10,7 +10,7 @@ namespace torchao { // F has signature [&](int64_t idx) template -void parallel_1d(const int64_t begin, const int64_t end, const F& f); +void parallel_1d(const int64_t begin, const int64_t end, const F& f); void set_num_threads(int num_threads); @@ -18,16 +18,17 @@ int get_num_threads(); } // namespace 
torchao - #ifdef TORCHAO_PARALLEL_ATEN #pragma message("TORCHAO_PARALLEL_ATEN is set. Using ATen parallel backend.") #ifndef INTRA_OP_PARALLEL - #pragma message("INTRA_OP_PARALLEL is not set; TORCHAO_PARALLEL_ATEN may be single-threaded.") +#pragma message( \ + "INTRA_OP_PARALLEL is not set; TORCHAO_PARALLEL_ATEN may be single-threaded.") #endif #ifndef AT_PARALLEL_OPENMP - #pragma message("AT_PARALLEL_OPENMP is not set; TORCHAO_PARALLEL_ATEN may be single-threaded.") +#pragma message( \ + "AT_PARALLEL_OPENMP is not set; TORCHAO_PARALLEL_ATEN may be single-threaded.") #endif -#include +#include #else #ifdef TORCHAO_PARALLEL_EXECUTORCH @@ -40,24 +41,25 @@ int get_num_threads(); #ifdef TORCHAO_PARALLEL_PTHREADPOOL #pragma message( \ "TORCHAO_PARALLEL_PTHREADPOOL is set. Using pthreadpool parallel backend.") -#include +#include #else #ifdef TORCHAO_PARALLEL_OPENMP -#pragma message("TORCHAO_PARALLEL_OPENMP is set. Using OPENMP parallel backend.") -#include +#pragma message( \ + "TORCHAO_PARALLEL_OPENMP is set. Using OPENMP parallel backend.") +#include #else #if defined TORCHAO_PARALLEL_SINGLE_THREADED #pragma message( \ "TORCHAO_PARALLEL_SINGLE_THREADED is set. Using single-threaded parallel backend.") -#include +#include #else #if defined TORCHAO_PARALLEL_TEST_DUMMY #pragma message( \ "TORCHAO_PARALLEL_TEST_DUMMY is set. 
Using test dummy parallel backend.") -#include +#include #else #error \ diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py b/torchao/experimental/tests/test_int8_dyn_act_intx_weight_quantizer.py similarity index 63% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py rename to torchao/experimental/tests/test_int8_dyn_act_intx_weight_quantizer.py index 513088d2f0..d431d26939 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py +++ b/torchao/experimental/tests/test_int8_dyn_act_intx_weight_quantizer.py @@ -11,18 +11,19 @@ import sys import unittest +import tempfile import torch sys.path.insert( - 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")) + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) ) from quant_api import ( _Int8DynActIntxWeightQuantizedLinearFallback, Int8DynActIntxWeightQuantizer, ) -libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*") +libs = glob.glob("/tmp/cmake-out/torchao/lib/liblinear_a8wxdq_ATEN.*") libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) if len(libs) == 0: print( @@ -73,7 +74,49 @@ def test_accuracy(self): # Assert at most 5% of entries are not close at a low tolerance self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05) + + def test_export_compile_aoti(self): # smoke test: has no assertions, passes if export/compile/AOTI run without raising + group_size = 32 + m = 1 + n = 256 + k = 256 + nbit = 4 + has_weight_zeros = False + n_layers = 3 + layers = [torch.nn.Linear(k, n, bias=False) for _ in range(n_layers)] # n == k, so the layers chain in Sequential + model = torch.nn.Sequential(*layers) + + activations = torch.randn(m, k, dtype=torch.float32) + + print("Quantizing model") + quantizer = Int8DynActIntxWeightQuantizer( + device="cpu", + precision=torch.float32, + bitwidth=nbit, + groupsize=group_size, + has_weight_zeros=has_weight_zeros, + ) + quantized_model = 
quantizer.quantize(model) + + print("Exporting quantized model") + exported = torch.export.export(quantized_model, (activations,)) # result is not used afterwards; only the call succeeding is checked + + print("Compiling quantized model") + quantized_model_compiled = torch.compile(quantized_model) + with torch.no_grad(): + quantized_model_compiled(activations) + + with tempfile.TemporaryDirectory() as tmpdirname: # model.so is written under a temp dir that is removed when the block exits + print("Exporting quantized model with AOTI") + torch._export.aot_compile( + quantized_model, + (activations,), + options={"aot_inductor.output_path": f"{tmpdirname}/model.so"}, + ) + print("Running quantized model in AOTI") + fn = torch._export.aot_load(f"{tmpdirname}/model.so", "cpu") # loads from the same path passed as aot_inductor.output_path above + fn(activations) if __name__ == "__main__": unittest.main()