
Commit c28053b

Authored on Apr 10, 2025
Merge branch 'main' into fix_logging
2 parents: 99fd467 + 4bd7798

File tree: 26 files changed (+205 / -74 lines)
 

‎.ci/scripts/build_android_instrumentation.sh

Lines changed: 3 additions & 3 deletions

@@ -12,10 +12,10 @@ if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
 fi
 which "${PYTHON_EXECUTABLE}"
 
-mkdir -p "${BUILD_AAR_DIR}"/executorch_android/src/androidTest/resources
-cp extension/module/test/resources/add.pte "${BUILD_AAR_DIR}"/executorch_android/src/androidTest/resources
+mkdir -p extension/android/executorch_android/src/androidTest/resources
+cp extension/module/test/resources/add.pte extension/android/executorch_android/src/androidTest/resources
 
-pushd "${BUILD_AAR_DIR}"
+pushd extension/android
 ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew :executorch_android:testDebugUnitTest
 ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew :executorch_android:assembleAndroidTest
 popd

‎.github/workflows/_android.yml

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ jobs:
 
 mkdir -p ${ARTIFACTS_DIR_NAME}/library_test_dir
 bash .ci/scripts/build_android_instrumentation.sh
-cp ${BUILD_AAR_DIR}/executorch_android/build/outputs/apk/androidTest/debug/executorch_android-debug-androidTest.apk "${ARTIFACTS_DIR_NAME}/library_test_dir"
+cp extension/android/executorch_android/build/outputs/apk/androidTest/debug/executorch_android-debug-androidTest.apk "${ARTIFACTS_DIR_NAME}/library_test_dir"
 
 mkdir -p ${ARTIFACTS_DIR_NAME}/fp32-xnnpack-custom
 bash examples/models/llama/install_requirements.sh

‎CONTRIBUTING.md

Lines changed: 6 additions & 10 deletions

@@ -1,7 +1,6 @@
 Thank you for your interest in contributing to ExecuTorch! We want to make
 it easy to contribute to this project.
 
-
 
 ## Dev Install
 
@@ -91,7 +90,7 @@ executorch
 │ └── <a href="runtime/platform">platform</a> - Layer between architecture specific code and portable C++.
 ├── <a href="schema">schema</a> - ExecuTorch PTE file format flatbuffer schemas.
 ├── <a href="scripts">scripts</a> - Utility scripts for building libs, size management, dependency management, etc.
-├── <a href="shim">shim</a> - Compatibility layer between OSS and Internal builds.
+├── <a href="shim_et">shim_et</a> - Compatibility layer between OSS and Internal builds.
 ├── <a href="test">test</a> - Broad scoped end-to-end tests.
 ├── <a href="third-party">third-party</a> - Third-party dependencies.
 ├── <a href="tools">tools</a> - Tools for building ExecuTorch from source, for different built tools (CMake, Buck).
@@ -192,9 +191,6 @@ in the Github repo.
 
 ## Coding Style
 
-Goal: Encourage standards that make it easier to read, edit, maintain, and debug
-the ExecuTorch code.
-
 ### lintrunner
 
 We use [`lintrunner`](https://pypi.org/project/lintrunner/) to help make sure the
@@ -259,7 +255,7 @@ toolchains, and having access to relatively modern C++ features.
 
 #### C/C++ standard library usage
 
-**Restricted usage of the C++ standard library.**
+**Restricted usage of the C++ standard library**
 
 Rationale: ExecuTorch is intended to be portable to bare-metal systems that lack
 certain features, like dynamic memory, threading, and locking, required by parts
@@ -280,7 +276,7 @@ careful to also manually destroy objects initialized in this way.
 
 #### C++ language features
 
-**Exceptions: Do not use.**
+**Exceptions: Do not use**
 - Rationale: Exceptions are not widely supported on some classes of
   microcontrollers and DSPs, and they can significantly increase binary size.
 
@@ -289,12 +285,12 @@ must work with threading**
 - Rationale: The core runtime must work on systems that do not have threading
   support.
 
-**RTTI, dynamic_cast, and `<typeid>`: Do not use.**
+**RTTI, dynamic_cast, and `<typeid>`: Do not use**
 - Rationale: RTTI adds extra data to every virtual class. ExecuTorch doesn't
   have a strong need for `dynamic_cast` and friends, so it's better to reduce
   the binary size.
 
-**Templates and template metaprogramming: Be careful and avoid if possible.**
+**Templates and template metaprogramming: Be careful and avoid if possible**
 - Rationale: Most templating results in code generation, and is one of the most
   common sources of binary bloat. Some use of templates is fine (e.g. an
   `ArrayRef<T>`, or code that handles multiple `ScalarType` types), but for the
@@ -359,7 +355,7 @@ docs](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/
 for basics.
 
 1. Push your branch to your fork of `pytorch/executorch`. Most people do not
-   have permission to push a branch directoy to the upstream repo.
+   have permission to push a branch directory to the upstream repo.
 1. Create your PR
    - Use the `main` branch as the base.
    - Give the PR a clear and descriptive title. It will become the title of the

‎README.md

Lines changed: 2 additions & 2 deletions

@@ -49,9 +49,9 @@ Key value propositions of ExecuTorch are:
 ## Getting Started
 To get started you can:
 
-- Visit the [Step by Step Tutorial](https://pytorch.org/executorch/main/index.html) on getting things running locally and deploy a model to a device
+- Visit the [Step by Step Tutorial](https://pytorch.org/executorch/main/index.html) to get things running locally and deploy a model to a device
 - Use this [Colab Notebook](https://pytorch.org/executorch/stable/getting-started-setup.html#quick-setup-colab-jupyter-notebook-prototype) to start playing around right away
-- Jump straight into LLMs use cases by following specific instructions for [Llama](./examples/models/llama/README.md) and [Llava](./examples/models/llava/README.md)
+- Jump straight into LLM use cases by following specific instructions for [Llama](./examples/models/llama/README.md) and [Llava](./examples/models/llava/README.md)
 
 ## Feedback and Engagement
 

‎backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions

@@ -47,6 +47,7 @@ class LayoutTransform(ExportPass):
     layout_agnostic_ops = {
         exir_ops.edge.aten.abs.default,
         exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.amax.default,
         exir_ops.edge.aten.bitwise_or.Tensor,
         exir_ops.edge.aten.bmm.default,
         exir_ops.edge.aten.bitwise_and.Tensor,

‎backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -9,6 +9,7 @@
     op_abs,
     op_adaptive_avg_pool2d,
     op_add,
+    op_amax,
     op_and,
     op_arange,
     op_argmin,
@@ -95,6 +96,7 @@
     op_abs,
     op_adaptive_avg_pool2d,
     op_add,
+    op_amax,
     op_and,
     op_arange,
     op_argmin,

‎backends/qualcomm/builders/op_amax.py

Lines changed: 84 additions & 0 deletions

@@ -0,0 +1,84 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict, List
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpAmax, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class AMax(NodeVisitor):
+    target = ["aten.amax.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        # mean dims and keep dims
+        mean_dims = cast(List[int], node.args[1])
+        mean_dims = [
+            mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims
+        ]
+        if QCOM_AXIS_ORDER in node.meta:
+            mean_dims = [
+                node.meta[QCOM_AXIS_ORDER].index(mean_dim) for mean_dim in mean_dims
+            ]
+        mean_dims_shape = [len(mean_dims)]
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        reduce_max_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpAmax.op_name,
+        )
+        reduce_max_op.AddInputTensors([input_tensor_wrapper])
+        reduce_max_op.AddOutputTensors([output_tensor_wrapper])
+        reduce_max_op.AddTensorParam(
+            OpAmax.param_axes,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(mean_dims_shape),
+            mean_dims_shape,
+            np.array(mean_dims, dtype=np.uint32),
+            True,
+        )
+        if len(node.args) > 2:
+            keep_dims = cast(bool, node.args[2])
+            reduce_max_op.AddScalarParam(
+                OpAmax.param_keep_dims,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+                {QCOM_DATA: keep_dims},
+            )
+
+        return reduce_max_op
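For context, the `mean_dims` handling in `define_node` first wraps any negative reduction dims into the positive range and then remaps them through the recorded axis order. A minimal standalone sketch of that arithmetic (the NHWC permutation below is an assumed example for illustration, not taken from this commit):

```python
# Sketch of the dim normalization performed in AMax.define_node above.
rank = 4                                     # rank of the input tensor
dims = [-1, 1]                               # dims as they might be passed to torch.amax
dims = [d % rank for d in dims]              # -> [3, 1]: negative dims wrapped to positive
axis_order = (0, 2, 3, 1)                    # assumed NCHW -> NHWC permutation
dims = [axis_order.index(d) for d in dims]   # -> [2, 3]: positions of those dims in the new layout
```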

‎backends/qualcomm/builders/qnn_constants.py

Lines changed: 7 additions & 0 deletions

@@ -14,6 +14,13 @@
 # instead of replicating them here.
 
 
+@dataclass(init=False, frozen=True)
+class OpAmax:
+    op_name: str = "ReduceMax"
+    param_axes: str = "axes"
+    param_keep_dims: str = "keep_dims"
+
+
 @dataclass(init=False, frozen=True)
 class OpBatchnorm:
     op_name: str = "Batchnorm"

‎backends/qualcomm/quantizer/annotators.py

Lines changed: 5 additions & 0 deletions

@@ -182,6 +182,11 @@ def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)
 
 
+@register_annotator([torch.ops.aten.amax.default])
+def annotate_amax(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_binary(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.argmin.default])
 def annotate_argmin(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):

‎backends/qualcomm/tests/models.py

Lines changed: 10 additions & 0 deletions

@@ -72,6 +72,16 @@ def forward(self, x):
         return torch.any(x, dim=self.dim, keepdim=self.keepdim)
 
 
+class AMax(torch.nn.Module):
+    def __init__(self, dim=None, keepdim=False):
+        super().__init__()
+        self.dim = dim
+        self.keepdim = keepdim
+
+    def forward(self, x):
+        return torch.amax(x, dim=self.dim, keepdim=self.keepdim)
+
+
 class Arange(torch.nn.Module):
     def __init__(self, start, end, step, dtype):
         super().__init__()

‎backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 15 additions & 0 deletions

@@ -113,6 +113,13 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
         sample_input = (torch.randn(1, 512, 7, 7),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_amax(self):
+        modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)]  # noqa: F405
+        sample_input = (torch.randn(4, 4),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_any(self):
         modules = [Any(), Any(dim=[0, 1]), Any(dim=1, keepdim=True)]  # noqa: F405
         sample_input = (torch.randn(3, 3, 3) > 0,)
@@ -1111,6 +1118,14 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_amax(self):
+        modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)]  # noqa: F405
+        sample_input = (torch.randn(4, 4),)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_any(self):
         modules = [Any(), Any(dim=[0, 1]), Any(dim=1, keepdim=True)]  # noqa: F405
         sample_input = (torch.randn(3, 3, 3) > 0,)
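As a quick reference for what these tests exercise: `torch.amax` reduces over the given dims and either drops or keeps them depending on `keepdim`. A small eager-mode sketch with the same `(4, 4)` input shape used above:

```python
import torch

x = torch.randn(4, 4)
print(torch.amax(x, dim=1, keepdim=False).shape)  # torch.Size([4])
print(torch.amax(x, dim=1, keepdim=True).shape)   # torch.Size([4, 1])
```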

‎docs/source/backends-coreml.md

Lines changed: 2 additions & 2 deletions

@@ -172,8 +172,8 @@ add_subdirectory("executorch")
 target_link_libraries(
     my_target
     PRIVATE executorch
-    executorch_module_static
-    executorch_tensor
+    extension_module_static
+    extension_tensor
     optimized_native_cpu_ops_lib
     coremldelegate)
 ```

‎docs/source/backends-xnnpack.md

Lines changed: 2 additions & 2 deletions

@@ -128,8 +128,8 @@ add_subdirectory("executorch")
 target_link_libraries(
     my_target
     PRIVATE executorch
-    executorch_module_static
-    executorch_tensor
+    extension_module_static
+    extension_tensor
     optimized_native_cpu_ops_lib
     xnnpack_backend)
 ```

‎docs/source/getting-started-architecture.md

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ There are three phases to deploy a PyTorch model to on-device: program preparation
 
 ExecuTorch extends the flexibility and usability of PyTorch to edge devices. It
 leverages PyTorch 2 compiler and export functionality
-([TorchDynamo](https://pytorch.org/docs/stable/dynamo/index.html),
+([TorchDynamo](https://pytorch.org/docs/stable/torch.compiler_dynamo_overview.html),
 [AOTAutograd](https://pytorch.org/functorch/stable/notebooks/aot_autograd_optimizations.html),
 [Quantization](https://pytorch.org/docs/main/quantization.html),
 [dynamic shapes](https://pytorch.org/get-started/pytorch-2.0/#pytorch-2x-faster-more-pythonic-and-as-dynamic-as-ever),

‎docs/source/getting-started.md

Lines changed: 4 additions & 2 deletions

@@ -121,6 +121,8 @@ To add the library to your app, add the following dependency to gradle build rul
 dependencies {
   implementation("org.pytorch:executorch-android:0.5.1")
 }
+
+# See latest available versions in https://mvnrepository.com/artifact/org.pytorch/executorch-android
 ```
 
 #### Runtime APIs
@@ -170,8 +172,8 @@ add_subdirectory("executorch")
 target_link_libraries(
     my_target
     PRIVATE executorch
-    executorch_module_static
-    executorch_tensor
+    extension_module_static
+    extension_tensor
     optimized_native_cpu_ops_lib
     xnnpack_backend)
 ```

‎docs/source/using-executorch-android.md

Lines changed: 13 additions & 9 deletions

@@ -172,18 +172,22 @@ public class MainActivity extends Activity {
     protected void onCreate(Bundle savedInstanceState) {
         super.onCreate(savedInstanceState);
         // Load the ExecuTorch module
-        module = Module.load("/path/to/module.pte");
-    }
-    public void runInference(View view) {
-        // Prepare input data
-        Tensor input = Tensor.fromBlob(getInputData());
-        // Run inference
-        Tensor output = module.forward(EValue.from(input))[0].toTensor();
-        // Process output data
-        processOutput(output);
+        Module module = Module.load("/data/local/tmp/add.pte");
+        Tensor tensor1 = Tensor.fromBlob(new float[] {1.0f}, new long[] {1});
+        Tensor tensor2 = Tensor.fromBlob(new float[] {20.0f}, new long[] {1});
+
+        EValue eValue1 = EValue.from(tensor1);
+        EValue eValue2 = EValue.from(tensor2);
+        float result = module.forward(eValue1, eValue2)[0].toTensor().getDataAsFloatArray()[0];
     }
 }
 ```
+
+Push the corresponding pte file to the phone:
+```sh
+adb push extension/module/test/resources/add.pte /data/local/tmp/
+```
+
 This example loads an ExecuTorch module, prepares input data, runs inference, and processes the output data.
 
 Please use [DeepLabV3AndroidDemo](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo)

‎docs/source/using-executorch-cpp.md

Lines changed: 3 additions & 3 deletions

@@ -38,7 +38,7 @@ Running a model using the low-level runtime APIs allows for a high-degree of control
 
 ## Building with CMake
 
-ExecuTorch uses CMake as the primary build system. Inclusion of the module and tensor APIs are controlled by the `EXECUTORCH_BUILD_EXTENSION_MODULE` and `EXECUTORCH_BUILD_EXTENSION_TENSOR` CMake options. As these APIs may not be supported on embedded systems, they are disabled by default when building from source. The low-level API surface is always included. To link, add the `executorch` target as a CMake dependency, along with `executorch_module_static` and `executorch_tensor`, if desired.
+ExecuTorch uses CMake as the primary build system. Inclusion of the module and tensor APIs are controlled by the `EXECUTORCH_BUILD_EXTENSION_MODULE` and `EXECUTORCH_BUILD_EXTENSION_TENSOR` CMake options. As these APIs may not be supported on embedded systems, they are disabled by default when building from source. The low-level API surface is always included. To link, add the `executorch` target as a CMake dependency, along with `extension_module_static` and `extension_tensor`, if desired.
 
 ```
 # CMakeLists.txt
@@ -47,8 +47,8 @@ add_subdirectory("executorch")
 target_link_libraries(
     my_target
     PRIVATE executorch
-    executorch_module_static
-    executorch_tensor
+    extension_module_static
+    extension_tensor
     optimized_native_cpu_ops_lib
     xnnpack_backend)
 ```
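For readers following along, a minimal consumer `CMakeLists.txt` consistent with the corrected target names might look like the sketch below; the option toggles and the XNNPACK backend choice are assumptions for illustration, not part of this commit:

```cmake
# Hypothetical consumer project; adjust paths and options to your setup.
cmake_minimum_required(VERSION 3.19)
project(my_app)

set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)   # Module API (assumed off by default)
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)   # Tensor API (assumed off by default)
set(EXECUTORCH_BUILD_XNNPACK ON)            # assumed backend selection

add_subdirectory("executorch")

add_executable(my_target main.cpp)
target_link_libraries(
    my_target
    PRIVATE executorch
            extension_module_static
            extension_tensor
            optimized_native_cpu_ops_lib
            xnnpack_backend)
```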

‎examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md

Lines changed: 1 addition & 1 deletion

@@ -135,7 +135,7 @@ You may also wonder what the "--metadata" flag is doing. This flag helps export
 
 Convert tokenizer for Llama 2
 ```
-python -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin
+python -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin
 ```
 Rename tokenizer for Llama 3 with command: `mv tokenizer.model tokenizer.bin`. We are updating the demo app to support tokenizer in original format directly.
 

‎examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md

Lines changed: 1 addition & 1 deletion

@@ -103,7 +103,7 @@ You may wonder what the ‘--metadata’ flag is doing. This flag helps export t
 
 * Convert tokenizer for Llama 2 and Llava (skip this for Llama 3.x)
 ```
-python -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin
+python -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin
 ```
 
 ### For LLaVA model

‎examples/models/llama2/README.md

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ You can export and run the original Llama 2 7B model.
 ```
 4. Create tokenizer.bin.
 ```
-python -m extension.llm.tokenizer.tokenizer -t <tokenizer.model> -o tokenizer.bin
+python -m pytorch_tokenizers.tools.llama2c.convert -t <tokenizer.model> -o tokenizer.bin
 ```
 
 Pass the converted `tokenizer.bin` file instead of `tokenizer.model` for subsequent steps.

‎examples/models/phi-3-mini/README.md

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ pip uninstall -y transformers ; pip install transformers==4.44.2
 ```
 cd executorch
 wget -O tokenizer.model "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/tokenizer.model?download=true"
-python -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin
+python -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin
 ```
 2. Export the model. This step will take a few minutes to finish.
 ```

‎examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ wget "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
 wget "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model"
 
 # tokenizer.bin:
-python -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin
+python -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin
 
 # params.json:
 echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json

‎exir/program/test/test_program.py

Lines changed: 5 additions & 7 deletions

@@ -725,17 +725,17 @@ def count_nodes(graph_module, target):
         )
 
     def test_edge_dialect_non_core_aten_ops(self):
-        class LinalgNorm(torch.nn.Module):
+        class LinalgRank(torch.nn.Module):
             def __init__(self):
                 super().__init__()
 
             def forward(self, x: torch.Tensor) -> torch.Tensor:
-                return torch.linalg.norm(x)
+                return torch.linalg.matrix_rank(x)
 
         from torch._export.verifier import SpecViolationError
 
-        input = torch.arange(9, dtype=torch.float) - 4
-        ep = torch.export.export(LinalgNorm(), (input,), strict=True)
+        input = torch.ones((9, 9, 9), dtype=torch.float)
+        ep = torch.export.export(LinalgRank(), (input,), strict=True)
 
         # aten::linalg_norm is not a core op, so it should error out
         with self.assertRaises(SpecViolationError):
@@ -748,9 +748,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 ep,
                 compile_config=EdgeCompileConfig(
                     _check_ir_validity=True,
-                    _core_aten_ops_exception_list=[
-                        torch.ops.aten.linalg_vector_norm.default
-                    ],
+                    _core_aten_ops_exception_list=[torch.ops.aten._linalg_svd.default],
                 ),
             )
         except SpecViolationError:

‎exir/tracer.py

Lines changed: 11 additions & 1 deletion

@@ -631,8 +631,18 @@ def _default_decomposition_table(
         ]
         # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
         return get_decompositions(decomp_opset)
+
+    decomps = default_decompositions()
+    # Add edge specific decompositions
+    additional_decomp_ops = [
+        # TODO: Eventually this op should be added to the core decompo table, and will not
+        # need to be added here.
+        torch.ops.aten.linalg_vector_norm.default,
+    ]
+    additional_decomps = get_decompositions(additional_decomp_ops)
+    decomps.update(additional_decomps)
     # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
-    return default_decompositions()
+    return decomps
 
 
 def dynamo_trace(
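The practical effect of registering `aten.linalg_vector_norm.default` here is that the op is rewritten into core ATen primitives during lowering instead of being flagged as a non-core op. A rough standalone sketch of that behavior using the upstream decomposition machinery (API usage below is illustrative, not this commit's exact code path):

```python
import torch
from torch._decomp import get_decompositions


class VectorNorm(torch.nn.Module):
    def forward(self, x):
        return torch.linalg.vector_norm(x)


ep = torch.export.export(VectorNorm(), (torch.randn(4),), strict=True)
table = get_decompositions([torch.ops.aten.linalg_vector_norm.default])
ep = ep.run_decompositions(table)
# The decomposed graph now uses primitive ops instead of aten.linalg_vector_norm,
# so an edge-dialect verifier would no longer see a non-core op here.
print(ep.graph_module.code)
```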

‎extension/android/executorch_android/build.gradle

Lines changed: 3 additions & 0 deletions

@@ -27,6 +27,9 @@ android {
     }
 
     sourceSets {
+        main {
+            jniLibs.srcDirs = ['../../../cmake-out-android-so/']
+        }
         androidTest {
             resources.srcDirs += [ 'src/androidTest/resources' ]
         }

‎scripts/build_android_library.sh

Lines changed: 20 additions & 26 deletions

@@ -12,11 +12,6 @@ if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
 fi
 which "${PYTHON_EXECUTABLE}"
 
-copy_src() {
-  cp -r extension/android/build.gradle extension/android/settings.gradle extension/android/gradlew extension/android/gradle extension/android/gradlew.bat extension/android/gradle.properties "${BUILD_AAR_DIR}"
-  cp -r extension/android/executorch_android "${BUILD_AAR_DIR}/executorch_android"
-}
-
 build_android_native_library() {
   ANDROID_ABI="$1"
   ANDROID_NDK="${ANDROID_NDK:-/opt/ndk}"
@@ -93,54 +88,53 @@ build_android_native_library() {
   cmake --build "${CMAKE_OUT}"/extension/android -j "${CMAKE_JOBS}" --config "${EXECUTORCH_CMAKE_BUILD_TYPE}"
 
   # Copy artifacts to ABI specific directory
-  mkdir -p "${BUILD_AAR_DIR}/executorch_android/src/main/jniLibs/${ANDROID_ABI}"
-  cp "${CMAKE_OUT}"/extension/android/*.so "${BUILD_AAR_DIR}/executorch_android/src/main/jniLibs/${ANDROID_ABI}/"
+  local SO_STAGE_DIR="cmake-out-android-so/${ANDROID_ABI}"
+  mkdir -p ${SO_STAGE_DIR}
+  cp "${CMAKE_OUT}"/extension/android/*.so "${SO_STAGE_DIR}/libexecutorch.so"
 
   # Copy QNN related so library
   if [ -n "$QNN_SDK_ROOT" ] && [ "$ANDROID_ABI" == "arm64-v8a" ]; then
-    cp "${CMAKE_OUT}"/lib/libqnn_executorch_backend.so "${BUILD_AAR_DIR}/executorch_android/src/main/jniLibs/${ANDROID_ABI}/"
-    cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtp.so "${BUILD_AAR_DIR}/executorch_android/src/main/jniLibs/${ANDROID_ABI}/"
-    cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnSystem.so "${BUILD_AAR_DIR}/executorch_android/src/main/jniLibs/${ANDROID_ABI}/"
-    cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtpV69Stub.so "${BUILD_AAR_DIR}/executorch_android/src/main/jniLibs/${ANDROID_ABI}/"
-    cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtpV73Stub.so "${BUILD_AAR_DIR}/executorch_android/src/main/jniLibs/${ANDROID_ABI}/"
-    cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtpV75Stub.so "${BUILD_AAR_DIR}/executorch_android/src/main/jniLibs/${ANDROID_ABI}/"
-    cp "${QNN_SDK_ROOT}"/lib/hexagon-v69/unsigned/libQnnHtpV69Skel.so "${BUILD_AAR_DIR}/executorch_android/src/main/jniLibs/${ANDROID_ABI}/"
-    cp "${QNN_SDK_ROOT}"/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so "${BUILD_AAR_DIR}/executorch_android/src/main/jniLibs/${ANDROID_ABI}/"
-    cp "${QNN_SDK_ROOT}"/lib/hexagon-v75/unsigned/libQnnHtpV75Skel.so "${BUILD_AAR_DIR}/executorch_android/src/main/jniLibs/${ANDROID_ABI}/"
+    cp "${CMAKE_OUT}"/lib/libqnn_executorch_backend.so ${SO_STAGE_DIR}
+    cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtp.so ${SO_STAGE_DIR}
+    cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnSystem.so ${SO_STAGE_DIR}
+    cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtpV69Stub.so ${SO_STAGE_DIR}
+    cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtpV73Stub.so ${SO_STAGE_DIR}
+    cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtpV75Stub.so ${SO_STAGE_DIR}
+    cp "${QNN_SDK_ROOT}"/lib/hexagon-v69/unsigned/libQnnHtpV69Skel.so ${SO_STAGE_DIR}
+    cp "${QNN_SDK_ROOT}"/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so ${SO_STAGE_DIR}
+    cp "${QNN_SDK_ROOT}"/lib/hexagon-v75/unsigned/libQnnHtpV75Skel.so ${SO_STAGE_DIR}
   fi
 
   # Copy MTK related so library
   if [ -n "$NEURON_BUFFER_ALLOCATOR_LIB" ] && [ -n "$NEURON_USDK_ADAPTER_LIB" ] && [ "$ANDROID_ABI" == "arm64-v8a" ]; then
-    cp "${CMAKE_OUT}"/backends/mediatek/libneuron_backend.so ${BUILD_AAR_DIR}/executorch_android/src/main/jniLibs/${ANDROID_ABI}/
-    cp "${NEURON_BUFFER_ALLOCATOR_LIB}" ${BUILD_AAR_DIR}/executorch_android/src/main/jniLibs/${ANDROID_ABI}/
-    cp "${NEURON_USDK_ADAPTER_LIB}" ${BUILD_AAR_DIR}/executorch_android/src/main/jniLibs/${ANDROID_ABI}/
+    cp "${CMAKE_OUT}"/backends/mediatek/libneuron_backend.so ${SO_STAGE_DIR}
+    cp "${NEURON_BUFFER_ALLOCATOR_LIB}" ${SO_STAGE_DIR}
+    cp "${NEURON_USDK_ADAPTER_LIB}" ${SO_STAGE_DIR}
  fi
 }
 
 build_aar() {
-  pushd "${BUILD_AAR_DIR}"
-  # Rename libexecutorch_jni.so to libexecutorch.so for soname consistency
-  # between Java and JNI
-  find . -type f -name "libexecutorch_jni.so" -exec bash -c 'mv "$1" "${1/_jni/}"' bash {} \;
   if [ "$EXECUTORCH_CMAKE_BUILD_TYPE" == "Release" ]; then
-    find . -type f -name "*.so" -exec "$ANDROID_NDK"/toolchains/llvm/prebuilt/*/bin/llvm-strip {} \;
+    find cmake-out-android-so -type f -name "*.so" -exec "$ANDROID_NDK"/toolchains/llvm/prebuilt/*/bin/llvm-strip {} \;
   fi
+  pushd extension/android/
   ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew build
-  cp executorch_android/build/outputs/aar/executorch_android-debug.aar executorch.aar
   popd
+  cp extension/android/executorch_android/build/outputs/aar/executorch_android-debug.aar "${BUILD_AAR_DIR}/executorch.aar"
 }
 
 main() {
   if [[ -z "${BUILD_AAR_DIR:-}" ]]; then
     BUILD_AAR_DIR="$(mktemp -d)"
   fi
   export BUILD_AAR_DIR
+  mkdir -p $BUILD_AAR_DIR
   if [ -z "$ANDROID_ABIS" ]; then
     ANDROID_ABIS=("arm64-v8a" "x86_64")
   fi
   export ANDROID_ABIS
 
-  copy_src
+  mkdir -p cmake-out-android-so/
   for ANDROID_ABI in "${ANDROID_ABIS[@]}"; do
     build_android_native_library ${ANDROID_ABI}

0 commit comments
