Skip to content

Commit 7c632b4

Browse files
authored
Reuse GELU implementation from PyTorch core
Differential Revision: D66335522 Pull Request resolved: #7041
1 parent 9911992 commit 7c632b4

File tree

17 files changed

+102
-53
lines changed

17 files changed

+102
-53
lines changed

.ci/scripts/build_llama_android.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ set -exu
1010
# shellcheck source=/dev/null
1111
source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"
1212

13+
if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
14+
PYTHON_EXECUTABLE=python3
15+
fi
16+
which "${PYTHON_EXECUTABLE}"
17+
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
18+
1319
install_executorch_and_backend_lib() {
1420
echo "Installing executorch and xnnpack backend"
1521
clean_executorch_install_folders
@@ -22,6 +28,7 @@ install_executorch_and_backend_lib() {
2228
-DANDROID_ABI="${ANDROID_ABI}" \
2329
-DCMAKE_INSTALL_PREFIX=cmake-android-out \
2430
-DCMAKE_BUILD_TYPE=Release \
31+
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
2532
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
2633
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
2734
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
@@ -47,6 +54,7 @@ build_llama_runner() {
4754
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
4855
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
4956
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
57+
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
5058
-Bcmake-android-out/examples/models/llama examples/models/llama
5159

5260
cmake --build cmake-android-out/examples/models/llama -j4 --config Release

.ci/scripts/test_llama.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ cmake_install_executorch_libraries() {
154154
rm -rf cmake-out
155155
retry cmake \
156156
-DCMAKE_INSTALL_PREFIX=cmake-out \
157+
-DCMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')" \
157158
-DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \
158159
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
159160
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \

.ci/scripts/test_llava.sh

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ fi
3030
NPROC=8
3131
if hash nproc &> /dev/null; then NPROC=$(nproc); fi
3232

33+
python_lib=$($PYTHON_EXECUTABLE -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')
34+
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
3335
EXECUTORCH_COMMON_CMAKE_ARGS=" \
3436
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
35-
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
37+
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
3638
-DEXECUTORCH_ENABLE_LOGGING=ON \
3739
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
3840
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
@@ -46,6 +48,7 @@ EXECUTORCH_COMMON_CMAKE_ARGS=" \
4648
cmake_install_executorch_libraries() {
4749
cmake \
4850
${EXECUTORCH_COMMON_CMAKE_ARGS} \
51+
"-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" \
4952
-B${BUILD_DIR} .
5053

5154
cmake --build ${BUILD_DIR} -j${NPROC} --target install --config ${CMAKE_BUILD_TYPE}
@@ -56,6 +59,7 @@ cmake_install_executorch_libraries_for_android() {
5659
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
5760
-DANDROID_ABI=arm64-v8a \
5861
${EXECUTORCH_COMMON_CMAKE_ARGS} \
62+
"-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" \
5963
-B${BUILD_DIR} .
6064

6165
cmake --build ${BUILD_DIR} -j${NPROC} --target install --config ${CMAKE_BUILD_TYPE}
@@ -76,7 +80,7 @@ cmake_build_llava_runner() {
7680

7781
cmake \
7882
${LLAVA_COMMON_CMAKE_ARGS} \
79-
-DCMAKE_PREFIX_PATH="$python_lib" \
83+
-DCMAKE_PREFIX_PATH="$python_lib;${CMAKE_PREFIX_PATH}" \
8084
-B${BUILD_DIR}/${dir} \
8185
${dir}
8286

@@ -92,7 +96,7 @@ cmake_build_llava_runner_for_android() {
9296
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
9397
-DANDROID_ABI=arm64-v8a \
9498
${LLAVA_COMMON_CMAKE_ARGS} \
95-
-DCMAKE_PREFIX_PATH="$python_lib" \
99+
-DCMAKE_PREFIX_PATH="$python_lib;${CMAKE_PREFIX_PATH}" \
96100
-DLLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE=ON \
97101
-B${BUILD_DIR}/${dir} \
98102
${dir}

.ci/scripts/test_model.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,12 @@ prepare_artifacts_upload() {
5050

5151
build_cmake_executor_runner() {
5252
echo "Building executor_runner"
53+
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
5354
rm -rf ${CMAKE_OUTPUT_DIR}
5455
cmake -DCMAKE_BUILD_TYPE=Debug \
5556
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
5657
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
58+
-DCMAKE_PREFIX_PATH="$CMAKE_PREFIX_PATH" \
5759
-B${CMAKE_OUTPUT_DIR} .
5860

5961
cmake --build ${CMAKE_OUTPUT_DIR} -j4 --config Debug
@@ -98,8 +100,7 @@ test_model() {
98100

99101
build_cmake_xnn_executor_runner() {
100102
echo "Building xnn_executor_runner"
101-
SITE_PACKAGES="$(${PYTHON_EXECUTABLE} -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
102-
CMAKE_PREFIX_PATH="${SITE_PACKAGES}/torch"
103+
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
103104

104105
(rm -rf ${CMAKE_OUTPUT_DIR} \
105106
&& mkdir ${CMAKE_OUTPUT_DIR} \

.ci/scripts/test_phi_3_mini.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ NPROC=8
2222
if hash nproc &> /dev/null; then NPROC=$(nproc); fi
2323

2424
cmake_install_executorch_libraries() {
25+
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
2526
cmake -DPYTHON_EXECUTABLE=python \
2627
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
28+
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
2729
-DEXECUTORCH_ENABLE_LOGGING=1 \
2830
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
2931
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
@@ -39,8 +41,10 @@ cmake_install_executorch_libraries() {
3941
}
4042

4143
cmake_build_phi_3_mini() {
44+
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
4245
cmake -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
4346
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
47+
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
4448
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
4549
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
4650
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \

.ci/scripts/utils.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ cmake_install_executorch_lib() {
136136
clean_executorch_install_folders
137137
retry cmake -DBUCK2="$BUCK" \
138138
-DCMAKE_INSTALL_PREFIX=cmake-out \
139+
-DCMAKE_PREFIX_PATH="$($PYTHON_EXECUTABLE -c 'import torch as _; print(_.__path__[0])')" \
139140
-DCMAKE_BUILD_TYPE=Release \
140141
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
141142
-Bcmake-out .

.github/workflows/pull.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ jobs:
147147
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
148148
conda activate "${CONDA_ENV}"
149149
150+
source .ci/scripts/utils.sh
151+
install_executorch "use-pt-pinned-commit"
150152
BUILD_TOOL="cmake"
151153
PYTHON_EXECUTABLE=python \
152154
bash .ci/scripts/build_llama_android.sh "${BUILD_TOOL}"

.github/workflows/trunk.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ jobs:
394394
rm -rf cmake-out
395395
cmake \
396396
-DCMAKE_INSTALL_PREFIX=cmake-out \
397+
-DCMAKE_PREFIX_PATH="$(python -c 'import torch as _; print(_.__path__[0])')" \
397398
-DCMAKE_BUILD_TYPE=Release \
398399
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
399400
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
@@ -411,6 +412,7 @@ jobs:
411412
cmake \
412413
-DCMAKE_INSTALL_PREFIX=cmake-out \
413414
-DCMAKE_BUILD_TYPE=Release \
415+
-DCMAKE_PREFIX_PATH="$(python -c 'import torch as _; print(_.__path__[0])')" \
414416
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
415417
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
416418
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \

build/Utils.cmake

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,20 @@ function(resolve_python_executable)
321321
)
322322
endif()
323323
endfunction()
324+
325+
# find_package(Torch CONFIG REQUIRED) replacement for targets that
326+
# have a header-only Torch dependency. Because find_package sets
327+
# variables in the parent scope, we use a macro to preserve this
328+
# rather than maintaining our own list of those variables.
329+
macro(find_package_torch_headers)
330+
# We cannot simply use CMAKE_FIND_ROOT_PATH_BOTH, because that does
331+
# not propagate into TorchConfig.cmake.
332+
foreach(mode_kind IN ITEMS PACKAGE LIBRARY INCLUDE)
333+
set(OLD_CMAKE_FIND_ROOT_PATH_MODE_${mode_kind} ${CMAKE_FIND_ROOT_PATH_MODE_${mode_kind}})
334+
set(CMAKE_FIND_ROOT_PATH_MODE_${mode_kind} BOTH)
335+
endforeach()
336+
find_package(Torch CONFIG REQUIRED)
337+
foreach(mode_kind IN ITEMS PACKAGE LIBRARY INCLUDE)
338+
set(CMAKE_FIND_ROOT_PATH_MODE_${mode_kind} ${OLD_CMAKE_FIND_ROOT_PATH_MODE_${mode_kind}})
339+
endforeach()
340+
endmacro()

build/build_android_llm_demo.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77

88
set -ex
99

10+
if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
11+
PYTHON_EXECUTABLE=python3
12+
fi
13+
which "${PYTHON_EXECUTABLE}"
14+
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
15+
1016
build_jar() {
1117
pushd extension/android
1218
./gradlew build
@@ -36,6 +42,7 @@ build_android_native_library() {
3642
fi
3743

3844
cmake . -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
45+
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
3946
-DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \
4047
-DANDROID_ABI="${ANDROID_ABI}" \
4148
-DANDROID_PLATFORM=android-26 \
@@ -69,6 +76,7 @@ build_android_native_library() {
6976
-DANDROID_ABI="${ANDROID_ABI}" \
7077
-DANDROID_PLATFORM=android-26 \
7178
-DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
79+
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
7280
-DEXECUTORCH_ENABLE_LOGGING=ON \
7381
-DEXECUTORCH_LOG_LEVEL=Info \
7482
-DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \

kernels/optimized/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ message("Generated files ${gen_command_sources}")
6161

6262
list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
6363
add_library(optimized_kernels ${_optimized_kernels__srcs})
64+
find_package_torch_headers()
65+
target_include_directories(optimized_kernels PRIVATE ${TORCH_INCLUDE_DIRS})
6466
target_link_libraries(
6567
optimized_kernels PRIVATE executorch_core cpublas extension_threadpool
6668
)

kernels/optimized/cpu/op_gelu.cpp

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include <cmath>
1515

16+
#include <ATen/native/cpu/Gelu.h>
1617
#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
1718
#include <executorch/runtime/kernel/kernel_includes.h>
1819
#include <executorch/runtime/platform/assert.h>
@@ -47,48 +48,26 @@ void gelu(
4748
CTYPE* out_data = output.mutable_data_ptr<CTYPE>();
4849
size_t lim = input.numel();
4950

50-
// TODO: Add fast path for tanh using sleef's tanh
5151
if (approximate == "tanh") {
52-
// 0.5 * x * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))
53-
for (size_t i = 0; i < lim; ++i) {
54-
const CTYPE x = in_data[i];
55-
const CTYPE kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
56-
const CTYPE kKappa = 0.044715;
57-
auto x_cube = x * x * x;
58-
auto inner = kBeta * (x + kKappa * x_cube);
59-
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::tanh(inner));
52+
using Vec = at::vec::Vectorized<CTYPE>;
53+
int i = 0;
54+
for (; i < lim - (lim % Vec::size()); i += Vec::size()) {
55+
Vec x = Vec::loadu(in_data + i);
56+
at::native::vectorized_gelu_approximated_with_tanh(x).store(out_data + i);
6057
}
61-
} else if (approximate == "none") { // dont appx
62-
// GELU(x) = x * Φ(x) where Φ(x) is the is the Cumulative Distribution
63-
// Function for Gaussian Distribution.
64-
65-
#ifndef __aarch64__
66-
for (size_t i = 0; i < lim; ++i) {
67-
const CTYPE x = in_data[i];
68-
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
58+
for (; i < lim; ++i) {
59+
out_data[i] = at::native::scalar_gelu_approximated_with_tanh(in_data[i]);
6960
}
70-
#else
71-
size_t i = 0;
72-
if constexpr (std::is_same_v<CTYPE, float>) {
73-
for (; i + 4 < lim; i += 4) {
74-
const float32x4_t in =
75-
vld1q_f32(static_cast<const float*>(&in_data[i]));
76-
const float32x4_t m_sqrt1_2x4 = {
77-
M_SQRT1_2, M_SQRT1_2, M_SQRT1_2, M_SQRT1_2};
78-
const float32x4_t ones = vmovq_n_f32(1.0);
79-
const float32x4_t halves = vmovq_n_f32(0.5);
80-
float32x4_t out = Sleef_erff4_u10(vmulq_f32(in, m_sqrt1_2x4));
81-
vst1q_f32(
82-
static_cast<float*>(&out_data[i]),
83-
vmulq_f32(vmulq_f32(vaddq_f32(out, ones), in), halves));
84-
}
61+
} else if (approximate == "none") {
62+
using Vec = at::vec::Vectorized<CTYPE>;
63+
int i = 0;
64+
for (; i < lim - (lim % Vec::size()); i += Vec::size()) {
65+
Vec x = Vec::loadu(in_data + i);
66+
at::native::vectorized_gelu(x).store(out_data + i);
8567
}
8668
for (; i < lim; ++i) {
87-
const CTYPE x = in_data[i];
88-
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
69+
out_data[i] = at::native::scalar_gelu(in_data[i]);
8970
}
90-
#endif // __aarch64__
91-
9271
} else {
9372
ET_KERNEL_CHECK_MSG(
9473
context,

kernels/optimized/cpu/targets.bzl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,9 @@ _OPTIMIZED_ATEN_OPS = (
2828
op_target(name = "op_sigmoid"),
2929
op_target(
3030
name = "op_gelu",
31-
deps = select({
32-
"DEFAULT": [],
33-
"ovr_config//cpu:arm64": [
34-
"fbsource//third-party/sleef:sleef_arm",
35-
],
36-
}) + [
31+
deps = [
3732
"//executorch/kernels/portable/cpu/util:activation_ops_util",
33+
"//executorch/runtime/core/portable_type/c10:aten_headers_for_executorch",
3834
],
3935
),
4036
op_target(
@@ -96,6 +92,13 @@ _OPTIMIZED_ATEN_OPS = (
9692
),
9793
)
9894

95+
96+
def get_sleef_preprocessor_flags():
97+
if runtime.is_oss:
98+
return []
99+
return ["-DAT_BUILD_ARM_VEC256_WITH_SLEEF"]
100+
101+
99102
def define_common_targets():
100103
"""Defines targets that should be shared between fbcode and xplat.
101104

kernels/optimized/op_registration_util.bzl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,14 @@ def define_op_library(name, deps):
9090
"//executorch/kernels/test/...",
9191
"@EXECUTORCH_CLIENTS",
9292
],
93-
# kernels often have helpers with no prototypes just disabling the warning here as the headers
94-
# are codegend and linked in later
95-
compiler_flags = ["-Wno-missing-prototypes"] + get_compiler_optimization_flags(),
93+
compiler_flags = [
94+
# kernels often have helpers with no prototypes just disabling the warning here as the headers
95+
# are codegend and linked in later
96+
"-Wno-missing-prototypes",
97+
# pragma unroll fails with -Os, don't need to warn us and
98+
# fail Werror builds; see https://godbolt.org/z/zvf85vTsr
99+
"-Wno-pass-failed",
100+
] + get_compiler_optimization_flags(),
96101
deps = [
97102
"//executorch/runtime/kernel:kernel_includes",
98103
] + augmented_deps + get_vec_deps(),

kernels/optimized/optimized-oss.yaml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
#
33
# This yaml file contains operators that have optimized kernels available.
4-
# Note that this is a copy of optimized.yaml that does not include gelu and
5-
# log_softmax, due to the OSS build not currently including sleef.
4+
# Note that this is a copy of optimized.yaml that does not include log_softmax,
5+
# due to the OSS build not currently including sleef.
66
# TODO (T183193812)
77

88
- op: add.out
@@ -40,6 +40,11 @@
4040
- arg_meta: null
4141
kernel_name: torch::executor::opt_sigmoid_out
4242

43+
- op: gelu.out
44+
kernels:
45+
- arg_meta: null
46+
kernel_name: torch::executor::opt_gelu_out
47+
4348
- op: le.Scalar_out
4449
kernels:
4550
- arg_meta: null

shim/xplat/executorch/kernels/optimized/op_registration_util.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,5 +134,5 @@ def define_op_target(name, deps):
134134

135135
def is_op_disabled(name):
136136
# TODO (gjcomer) Enable ops with sleef dependency in OSS
137-
disabled_ops = ["op_gelu", "op_log_softmax"]
137+
disabled_ops = ["op_log_softmax"]
138138
return name in disabled_ops

test/run_oss_cpp_tests.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,20 @@ elif [[ $(uname) == "Linux" ]]; then
2222
export LLVM_COV="${LLVM_COV:-llvm-cov}"
2323
fi
2424

25+
if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
26+
PYTHON_EXECUTABLE=python3
27+
fi
28+
which "${PYTHON_EXECUTABLE}"
29+
2530
build_executorch() {
2631
BUILD_VULKAN="OFF"
2732
if [ -x "$(command -v glslc)" ]; then
2833
BUILD_VULKAN="ON"
2934
fi
35+
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
3036
cmake . \
3137
-DCMAKE_INSTALL_PREFIX=cmake-out \
38+
-DCMAKE_PREFIX_PATH="${CMAKE_PREFIX_PATH}" \
3239
-DEXECUTORCH_USE_CPP_CODE_COVERAGE=ON \
3340
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
3441
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \

0 commit comments

Comments
 (0)