Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 711282e

Browse files
committedApr 7, 2025
Update base for Update on "[ET-VK][ez] Make squeeze insertion requirements more strict"
## Context Refactor the `SqueezeUnsqueezeInputs` pass to be more clear about its intention. For Llama models, input shapes to 4 bit linear will oftentimes have the shape `[1, seq_len, dim]`; under the current implementation of the pass, the input would be squeezed to `[seq_len, dim]` even though the squeeze is not necessary. The original intention of thispass was to squeeze inputs with shape `[batch_size, 1, dim]` to `[batch_size, dim]` before calling the 4-bit linear operator. ## Changes To avoid inserting unnecessary squeeze/unsqueezes, be more specific about when squeeze/unsqueeze should be added. I would like to consider refactoring this pass in the future, since the logic is currently a bit uninttuitive. Squeeze/unsqueeze is also inserted for gelu and relu, but this is to create a chain of unsqueeze/squeeze that will be eliminated by a later pass (see #8601 / D69673068). I think eventually it will be good to rewrite the pass to make shape management more explicit and self contained within the pass rather than inserting ops which are expected to be removed later on. Differential Revision: [D72480178](https://our.internmc.facebook.com/intern/diff/D72480178/) [ghstack-poisoned]
2 parents f1e2f1a + 6adff9c commit 711282e

File tree

39 files changed

+2195
-740
lines changed

39 files changed

+2195
-740
lines changed
 
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
7ae0ce6360b6e4f944906502d20da24c04debee5
1+
59d5cf083b4f860dea76fe8936076177f9367f10

‎backends/arm/test/models/test_conformer.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class TestConformer(unittest.TestCase):
3131
# .to_executorch step, i.e. after Arm partitioner.
3232
ops_after_partitioner = {
3333
"executorch_exir_dialects_edge__ops_aten_max_default": 1,
34-
"torch.ops.aten._assert_scalar.default": 10,
34+
"torch.ops.aten._assert_scalar.default": 7,
3535
"torch.ops.aten._local_scalar_dense.default": 1,
3636
}
3737

‎backends/arm/test/models/test_llama.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import sys
1212
import unittest
1313

14+
import pytest
1415
import torch
1516

1617
from executorch.backends.arm.test import common, conftest
@@ -102,7 +103,7 @@ def test_llama_tosa_MI(self):
102103
llama_model, llama_inputs, llama_meta = self.prepare_model()
103104

104105
if llama_model is None and llama_inputs is None and llama_meta is None:
105-
return
106+
pytest.skip("Missing model and/or input files")
106107

107108
with torch.no_grad():
108109
(

‎backends/xnnpack/operators/op_slice_copy.py‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def define_node(
6969
output_shape = [output_shape[i] for i in PERM_NCHW_TO_NHWC]
7070
dim_of_slice = PERM_NHWC_TO_NCHW[dim_of_slice]
7171

72-
slice_begin_index = cast(int, node.args[2])
72+
slice_begin_index = 0
73+
if len(node.args) > 2 and node.args[2]:
74+
slice_begin_index = cast(int, node.args[2])
7375
if slice_begin_index < 0:
7476
slice_begin_index = input_shape[dim_of_slice] + slice_begin_index
7577

‎backends/xnnpack/test/ops/test_slice_copy.py‎

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,18 @@ def forward(self, x):
6969
# Note that two of the slices are optimized away as they are identity.
7070
self._test_slice_copy(ConvSlice(), inputs, 4, 2)
7171

72+
def test_fp32_slice_copy_default_start(self):
73+
"""
74+
XNNPACK supports default start in slice op.
75+
"""
76+
77+
class Slice(torch.nn.Module):
78+
def forward(self, x):
79+
return torch.ops.aten.slice.Tensor(x, 0, None, 2)
80+
81+
inputs = (torch.randn(5, 5),)
82+
self._test_slice_copy(Slice(), inputs, 1, 1)
83+
7284
def test_fp32_slice_copy_stride_non_1(self):
7385
"""
7486
XNNPACK does not support strided slicing.

‎devtools/etdump/etdump_filter.cpp‎

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/devtools/etdump/etdump_filter.h>
10+
11+
#include <executorch/runtime/core/error.h>
12+
13+
using ::executorch::runtime::DelegateDebugIntId;
14+
using ::executorch::runtime::Error;
15+
using ::executorch::runtime::kUnsetDelegateDebugIntId;
16+
17+
namespace executorch {
18+
namespace etdump {
19+
20+
ETDumpFilter::ETDumpFilter() = default;
21+
22+
Result<bool> ETDumpFilter::add_regex(string_view pattern) {
23+
auto regex = std::make_unique<re2::RE2>(pattern.data());
24+
if (!regex->ok()) {
25+
return Error::InvalidArgument; // Error during regex compilation
26+
}
27+
regex_patterns_.emplace_back(std::move(regex));
28+
return true;
29+
}
30+
31+
Result<bool> ETDumpFilter::set_debug_handle_range(size_t start, size_t end) {
32+
if (start >= end) {
33+
return Error::InvalidArgument; // Start is greater than end
34+
}
35+
if (start < 0 || end < 0) {
36+
return Error::InvalidArgument; // Start or end is negative
37+
}
38+
range_start_ = start;
39+
range_end_ = end;
40+
return true;
41+
}
42+
43+
Result<bool> ETDumpFilter::filter_name_(const char* name) {
44+
if (name == nullptr) {
45+
return Error::InvalidArgument;
46+
}
47+
if (regex_patterns_.empty()) {
48+
return true;
49+
}
50+
for (const auto& regex : regex_patterns_) {
51+
if (RE2::FullMatch(name, *regex)) {
52+
return true;
53+
}
54+
}
55+
return false;
56+
}
57+
58+
Result<bool> ETDumpFilter::filter_delegate_debug_index_(
59+
DelegateDebugIntId debug_handle) {
60+
if (debug_handle == kUnsetDelegateDebugIntId) {
61+
return Error::InvalidArgument; // Delegate debug index is unset
62+
}
63+
64+
if (range_start_ == 0 && range_end_ == 0) {
65+
return true;
66+
}
67+
68+
if (debug_handle < range_start_ || debug_handle >= range_end_) {
69+
return false;
70+
}
71+
72+
return true;
73+
}
74+
75+
Result<bool> ETDumpFilter::filter(
76+
const char* name,
77+
DelegateDebugIntId delegate_debug_index) {
78+
if ((name == nullptr) == (delegate_debug_index == kUnsetDelegateDebugIntId)) {
79+
return Error::InvalidArgument; // Name and delegate debug index should be
80+
// both set or unset
81+
}
82+
83+
if (name) {
84+
return filter_name_(name);
85+
} else {
86+
return filter_delegate_debug_index_(delegate_debug_index);
87+
}
88+
}
89+
90+
size_t ETDumpFilter::get_n_regex() const {
91+
return regex_patterns_.size();
92+
}
93+
94+
} // namespace etdump
95+
} // namespace executorch

‎devtools/etdump/etdump_filter.h‎

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <re2/re2.h>
12+
#include <memory>
13+
14+
#include <executorch/runtime/core/event_tracer.h>
15+
#include <executorch/runtime/core/result.h>
16+
#include <executorch/runtime/platform/platform.h>
17+
18+
namespace executorch::etdump {
19+
20+
using ::executorch::aten::string_view;
21+
using ::executorch::runtime::Result;
22+
23+
/**
24+
* ETDumpFilter is a class that filters intermediate output based on output's
25+
* name by full regex filtering, or delegate debug indices by range-based
26+
* filtering.
27+
*
28+
* Note that this filter supports up to MAX_REGEX_PATTERNS regex patterns with a
29+
* maximum length of MAX_PATTERN_LENGTH characters each.
30+
*/
31+
32+
class ETDumpFilter : public ::executorch::runtime::EventTracerFilterBase {
33+
public:
34+
ETDumpFilter();
35+
~ETDumpFilter() override = default;
36+
/**
37+
* Adds a regex pattern to the filter.
38+
*
39+
* @param[in] pattern A c string representing the regex pattern to be added.
40+
*
41+
* @return A Result<bool> indicating the success or failure of adding the
42+
* regex pattern.
43+
* - True if the pattern is successfully added.
44+
* - False if the pattern could not be added or if the maximum number
45+
* of patterns is exceeded.
46+
* - An error code if the number of patterns has reached the cap, or any
47+
* error occurs during regex compilation.
48+
*/
49+
Result<bool> add_regex(string_view pattern);
50+
/**
51+
* Sets the range for the delegate debug index filtering as [start, end).
52+
* Note that this function will flush the existing range.
53+
*
54+
* @param[in] start The start of the range for filtering.
55+
* @param[in] end The end of the range for filtering.
56+
*
57+
* @return A Result<bool> indicating the success or failure of setting the
58+
* range.
59+
* - True if the range is successfully set.
60+
* - An error code if an error occurs.
61+
*/
62+
Result<bool> set_debug_handle_range(size_t start, size_t end);
63+
64+
/**
65+
* Filters events based on the given name or delegate debug index.
66+
*
67+
* Note that every time only one of either the name or delegate_debug_index
68+
* should be passed in.
69+
*
70+
* @param[in] name A pointer to a string representing the `name` of the
71+
* event. If `delegate_debug_index` is not set to kUnsetDebugHandle, `name`
72+
* should be set to nullptr.
73+
*
74+
* @param[in] delegate_debug_index A DebugHandle representing the debug index
75+
* of the delegate. If `name` is not nullptr, this should be set to
76+
* kUnsetDebugHandle.
77+
*
78+
* @return A Result<bool> indicating whether the event matches the filter
79+
* criteria.
80+
* - True if the event matches the filter, or filter is unset.
81+
* - False if the event does not match or is unknown.
82+
* - An error code if an error occurs during filtering.
83+
*/
84+
Result<bool> filter(
85+
const char* name,
86+
::executorch::runtime::DelegateDebugIntId delegate_debug_index) override;
87+
88+
/**
89+
* Returns the number of regex patterns in the filter.
90+
*/
91+
size_t get_n_regex() const;
92+
93+
private:
94+
std::vector<std::unique_ptr<re2::RE2>> regex_patterns_;
95+
size_t range_start_ = 0;
96+
size_t range_end_ = 0;
97+
Result<bool> filter_name_(const char* name);
98+
Result<bool> filter_delegate_debug_index_(
99+
::executorch::runtime::DelegateDebugIntId delegate_debug_index);
100+
};
101+
102+
} // namespace executorch::etdump

‎devtools/etdump/etdump_flatcc.cpp‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <executorch/devtools/etdump/etdump_schema_flatcc_builder.h>
1616
#include <executorch/devtools/etdump/etdump_schema_flatcc_reader.h>
1717
#include <executorch/devtools/etdump/utils.h>
18+
#include <executorch/runtime/core/error.h>
1819
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1920
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
2021
#include <executorch/runtime/platform/assert.h>
@@ -28,6 +29,7 @@ using ::executorch::runtime::ChainID;
2829
using ::executorch::runtime::DebugHandle;
2930
using ::executorch::runtime::DelegateDebugIdType;
3031
using ::executorch::runtime::DelegateDebugIntId;
32+
using ::executorch::runtime::Error;
3133
using ::executorch::runtime::EValue;
3234
using ::executorch::runtime::EventTracerEntry;
3335
using ::executorch::runtime::kUnsetDelegateDebugIntId;

‎devtools/etdump/etdump_flatcc.h‎

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#pragma once
1010

1111
#include <cstdint>
12-
#include <memory>
1312

1413
#include <executorch/devtools/etdump/data_sinks/buffer_data_sink.h>
1514
#include <executorch/devtools/etdump/data_sinks/data_sink_base.h>

‎devtools/etdump/targets.bzl‎

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,27 @@ def define_common_targets():
101101
for aten_mode in get_aten_mode_options():
102102
aten_suffix = "_aten" if aten_mode else ""
103103

104+
runtime.cxx_library(
105+
name = "etdump_filter" + aten_suffix,
106+
srcs = [
107+
"etdump_filter.cpp",
108+
],
109+
exported_headers = [
110+
"etdump_filter.h",
111+
],
112+
deps = [
113+
"//executorch/runtime/platform:platform",
114+
],
115+
exported_deps = [
116+
"fbsource//third-party/re2:re2",
117+
"//executorch/runtime/core:event_tracer" + aten_suffix,
118+
],
119+
visibility = [
120+
"//executorch/...",
121+
"@EXECUTORCH_CLIENTS",
122+
],
123+
)
124+
104125
runtime.cxx_library(
105126
name = "etdump_flatcc" + aten_suffix,
106127
srcs = [
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <gtest/gtest.h>
10+
11+
#include <executorch/devtools/etdump/etdump_filter.h>
12+
#include <executorch/runtime/platform/runtime.h>
13+
14+
#include <cstring>
15+
16+
using ::executorch::etdump::ETDumpFilter;
17+
using ::executorch::runtime::Error;
18+
using ::executorch::runtime::kUnsetDelegateDebugIntId;
19+
using ::executorch::runtime::Result;
20+
21+
class ETDumpFilterTest : public ::testing::Test {
22+
protected:
23+
ETDumpFilter filter;
24+
25+
void SetUp() override {
26+
torch::executor::runtime_init();
27+
}
28+
29+
void TearDown() override {}
30+
};
31+
32+
TEST_F(ETDumpFilterTest, AddRegexPatternSuccess) {
33+
Result<bool> result = filter.add_regex("test.*");
34+
EXPECT_TRUE(result.ok());
35+
EXPECT_TRUE(result.get());
36+
}
37+
38+
TEST_F(ETDumpFilterTest, SetDebugHandleRangeSuccess) {
39+
Result<bool> result = filter.set_debug_handle_range(10, 20);
40+
EXPECT_TRUE(result.ok());
41+
EXPECT_TRUE(result.get());
42+
}
43+
44+
TEST_F(ETDumpFilterTest, SetDebugHandleRangeFailure) {
45+
Result<bool> result = filter.set_debug_handle_range(20, 10);
46+
EXPECT_EQ(result.error(), Error::InvalidArgument);
47+
}
48+
49+
TEST_F(ETDumpFilterTest, FilterByNameSuccess) {
50+
filter.add_regex("event.*");
51+
Result<bool> result = filter.filter("event_name", kUnsetDelegateDebugIntId);
52+
EXPECT_TRUE(result.ok());
53+
EXPECT_TRUE(result.get());
54+
}
55+
56+
TEST_F(ETDumpFilterTest, PartialMatchingFailed) {
57+
filter.add_regex("event.*");
58+
Result<bool> result =
59+
filter.filter("non_matching_event", kUnsetDelegateDebugIntId);
60+
EXPECT_TRUE(result.ok());
61+
EXPECT_FALSE(result.get());
62+
}
63+
64+
TEST_F(ETDumpFilterTest, FilterByDelegateDebugIndexSuccess) {
65+
filter.set_debug_handle_range(10, 20);
66+
Result<bool> result = filter.filter(nullptr, 15);
67+
EXPECT_TRUE(result.ok());
68+
EXPECT_TRUE(result.get());
69+
}
70+
71+
TEST_F(ETDumpFilterTest, FilterByDelegateDebugIndexFailure) {
72+
filter.set_debug_handle_range(10, 20);
73+
Result<bool> result = filter.filter(nullptr, 25);
74+
EXPECT_TRUE(result.ok());
75+
EXPECT_FALSE(result.get());
76+
}
77+
78+
TEST_F(ETDumpFilterTest, NaiveFilterNameInputCanSucceed) {
79+
Result<bool> result = filter.filter("any_input", kUnsetDelegateDebugIntId);
80+
EXPECT_TRUE(result.ok());
81+
EXPECT_TRUE(result.get());
82+
}
83+
84+
TEST_F(ETDumpFilterTest, NaiveFilterDebugHandleInputCanSucceed) {
85+
Result<bool> result = filter.filter(nullptr, 12345);
86+
EXPECT_TRUE(result.ok());
87+
EXPECT_TRUE(result.get());
88+
}
89+
90+
TEST_F(ETDumpFilterTest, IllegalInput) {
91+
filter.add_regex("pattern");
92+
Result<bool> result = filter.filter("matching_event", 1);
93+
EXPECT_EQ(result.error(), Error::InvalidArgument);
94+
}
95+
96+
TEST_F(ETDumpFilterTest, NoMatchFirstThenMatch) {
97+
filter.add_regex("non_matching_pattern");
98+
Result<bool> result_1 =
99+
filter.filter("matching_event", kUnsetDelegateDebugIntId);
100+
EXPECT_TRUE(result_1.ok());
101+
EXPECT_FALSE(result_1.get());
102+
filter.add_regex("matching_.*");
103+
Result<bool> result_2 =
104+
filter.filter("matching_event", kUnsetDelegateDebugIntId);
105+
EXPECT_TRUE(result_2.ok());
106+
EXPECT_TRUE(result_2.get());
107+
}
108+
109+
TEST_F(ETDumpFilterTest, MatchRegexFirstThen) {
110+
filter.add_regex("matching.*");
111+
Result<bool> result_1 =
112+
filter.filter("matching_event", kUnsetDelegateDebugIntId);
113+
EXPECT_TRUE(result_1.ok());
114+
EXPECT_TRUE(result_1.get());
115+
filter.add_regex("non_matching_pattern");
116+
Result<bool> result_2 =
117+
filter.filter("matching_event", kUnsetDelegateDebugIntId);
118+
EXPECT_TRUE(result_2.ok());
119+
EXPECT_TRUE(result_2.get());
120+
}

‎devtools/etdump/tests/targets.bzl‎

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,14 @@ def define_common_targets():
2121
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
2222
],
2323
)
24+
25+
runtime.cxx_test(
26+
name = "etdump_filter_test",
27+
srcs = [
28+
"etdump_filter_test.cpp",
29+
],
30+
deps = [
31+
"//executorch/devtools/etdump:etdump_filter",
32+
"//executorch/runtime/platform:platform",
33+
],
34+
)

‎examples/arm/README.md‎

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@ $ source executorch/examples/arm/ethos-u-scratch/setup_path.sh
3232
$ executorch/examples/arm/run.sh --model_name=mv2 --target=ethos-u85-128 [--scratch-dir=same-optional-scratch-dir-as-before]
3333
```
3434

35+
### Ethos-U minimal example
36+
37+
See the jupyter notebook `ethos_u_minimal_example.ipynb` for an explained minimal example of the full flow for running a
38+
PyTorch module on the EthosUDelegate. The notebook runs directly in some IDEs such as VS Code, otherwise it can be run in
39+
your browser using
40+
```
41+
pip install jupyter
42+
jupyter notebook ethos_u_minimal_example.ipynb
43+
```
44+
3545
### Online Tutorial
3646

3747
We also have a [tutorial](https://pytorch.org/executorch/stable/executorch-arm-delegate-tutorial.html) explaining the steps performed in these
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
From 23712ff626db16793d428dddcb530f9e5faaa073 Mon Sep 17 00:00:00 2001
2+
From: Adrian Lundell <adrian.lundell@arm.com>
3+
Date: Thu, 3 Apr 2025 14:25:52 +0200
4+
Subject: [PATCH] Move input_data_sec to NOLOAD area
5+
6+
---
7+
targets/corstone-300/platform.ld | 10 ++++++++--
8+
targets/corstone-320/platform.ld | 8 ++++++--
9+
2 files changed, 14 insertions(+), 4 deletions(-)
10+
11+
diff --git a/targets/corstone-300/platform.ld b/targets/corstone-300/platform.ld
12+
index 1733509..3ccce64 100644
13+
--- a/targets/corstone-300/platform.ld
14+
+++ b/targets/corstone-300/platform.ld
15+
@@ -272,13 +272,12 @@ SECTIONS
16+
*(.bss.tensor_arena)
17+
#endif
18+
19+
- . = ALIGN(4);
20+
- *(input_data_sec)
21+
. = ALIGN(16);
22+
#if (ETHOSU_MODEL == 1)
23+
*(network_model_sec)
24+
#endif
25+
* (expected_output_data_sec)
26+
+ . = ALIGN(16);
27+
* (sec_command_stream, sec_weight_data, sec_input_data)
28+
*(.got*)
29+
*(.rodata*)
30+
@@ -287,6 +286,13 @@ SECTIONS
31+
. = ALIGN(4);
32+
} > DDR :rom_dram
33+
34+
+ .ddr_noload (NOLOAD) :
35+
+ {
36+
+ . = ALIGN(16);
37+
+ *(input_data_sec)
38+
+ . = ALIGN(16);
39+
+ } > DDR :null
40+
+
41+
__eddr_data = ALIGN (4) ;
42+
.sram.data : {
43+
__sram_data_start__ = .;
44+
diff --git a/targets/corstone-320/platform.ld b/targets/corstone-320/platform.ld
45+
index c8261c0..9b7e071 100644
46+
--- a/targets/corstone-320/platform.ld
47+
+++ b/targets/corstone-320/platform.ld
48+
@@ -268,8 +268,6 @@ SECTIONS
49+
*(network_model_sec)
50+
#endif
51+
52+
- . = ALIGN(4);
53+
- *(input_data_sec)
54+
*(expected_output_data_sec)
55+
*(output_data_sec)
56+
57+
@@ -279,6 +277,12 @@ SECTIONS
58+
__etext = .;
59+
} > DDR :rom_dram
60+
61+
+ .ddr_noload (NOLOAD) :
62+
+ {
63+
+ . = ALIGN(16);
64+
+ *(input_data_sec)
65+
+ } > DDR :null
66+
+
67+
.bss :
68+
{
69+
. = ALIGN(4);
70+
--
71+
2.43.0
72+
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"# Copyright 2025 Arm Limited and/or its affiliates.\n",
10+
"#\n",
11+
"# This source code is licensed under the BSD-style license found in the\n",
12+
"# LICENSE file in the root directory of this source tree."
13+
]
14+
},
15+
{
16+
"cell_type": "markdown",
17+
"metadata": {},
18+
"source": [
19+
"# Ethos-U delegate flow example\n",
20+
"\n",
21+
"This guide demonstrates the full flow for running a module on Arm Ethos-U using ExecuTorch. \n",
22+
"Tested on Linux x86_64 and macOS aarch64. If something is not working for you, please raise a GitHub issue and tag Arm.\n",
23+
"\n",
24+
"Before you begin:\n",
25+
"1. (In a clean virtual environment with a compatible Python version) Install executorch using `./install_executorch.sh`\n",
26+
"2. Install Arm cross-compilation toolchain and simulators using `examples/arm/setup.sh --i-agree-to-the-contained-eula`\n",
27+
"3. Add Arm cross-compilation toolchain and simulators to PATH using `examples/arm/ethos-u-scratch/setup_path.sh` \n",
28+
"\n",
29+
"With all commands executed from the base `executorch` folder.\n",
30+
"\n",
31+
"\n",
32+
"\n",
33+
"*Some scripts in this notebook produces long output logs: Configuring the 'Customizing Notebook Layout' settings to enable 'Output:scrolling' and setting 'Output:Text Line Limit' makes this more manageable*"
34+
]
35+
},
36+
{
37+
"cell_type": "markdown",
38+
"metadata": {},
39+
"source": [
40+
"## AOT Flow\n",
41+
"\n",
42+
"The first step is creating the PyTorch module and exporting it. Exporting converts the python code in the module into a graph structure. The result is still runnable python code, which can be displayed by printing the `graph_module` of the exported program. "
43+
]
44+
},
45+
{
46+
"cell_type": "code",
47+
"execution_count": null,
48+
"metadata": {},
49+
"outputs": [],
50+
"source": [
51+
"import torch\n",
52+
"\n",
53+
"class Add(torch.nn.Module):\n",
54+
" def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n",
55+
" return x + y\n",
56+
"\n",
57+
"example_inputs = (torch.ones(1,1,1,1),torch.ones(1,1,1,1))\n",
58+
"\n",
59+
"model = Add()\n",
60+
"model = model.eval()\n",
61+
"exported_program = torch.export.export_for_training(model, example_inputs)\n",
62+
"graph_module = exported_program.module()\n",
63+
"\n",
64+
"_ = graph_module.print_readable()"
65+
]
66+
},
67+
{
68+
"cell_type": "markdown",
69+
"metadata": {},
70+
"source": [
71+
"To run on Ethos-U the `graph_module` must be quantized using the `arm_quantizer`. Quantization can be done in multiple ways and it can be customized for different parts of the graph; shown here is the recommended path for the EthosUBackend. Quantization also requires calibrating the module with example inputs.\n",
72+
"\n",
73+
"Again printing the module, it can be seen that the quantization wraps the node in quantization/dequantization nodes which contain the computed quantization parameters."
74+
]
75+
},
76+
{
77+
"cell_type": "code",
78+
"execution_count": null,
79+
"metadata": {},
80+
"outputs": [],
81+
"source": [
82+
"from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder\n",
83+
"from executorch.backends.arm.quantizer.arm_quantizer import (\n",
84+
" EthosUQuantizer,\n",
85+
" get_symmetric_quantization_config,\n",
86+
")\n",
87+
"from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e\n",
88+
"\n",
89+
"target = \"ethos-u55-128\"\n",
90+
"\n",
91+
"# Create a compilation spec describing the target for configuring the quantizer\n",
92+
"# Some args are used by the Arm Vela graph compiler later in the example. Refer to Arm Vela documentation for an \n",
93+
"# explanation of its flags: https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/OPTIONS.md\n",
94+
"spec_builder = ArmCompileSpecBuilder().ethosu_compile_spec(\n",
95+
" target,\n",
96+
" system_config=\"Ethos_U55_High_End_Embedded\",\n",
97+
" memory_mode=\"Shared_Sram\",\n",
98+
" extra_flags=\"--output-format=raw --debug-force-regor\"\n",
99+
" )\n",
100+
"compile_spec = spec_builder.build()\n",
101+
"\n",
102+
"# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n",
103+
"quantizer = EthosUQuantizer(compile_spec) \n",
104+
"operator_config = get_symmetric_quantization_config(is_per_channel=False)\n",
105+
"quantizer.set_global(operator_config)\n",
106+
"\n",
107+
"# Post training quantization\n",
108+
"quantized_graph_module = prepare_pt2e(graph_module, quantizer) \n",
109+
"quantized_graph_module(*example_inputs) # Calibrate the graph module with the example input\n",
110+
"quantized_graph_module = convert_pt2e(quantized_graph_module)\n",
111+
"\n",
112+
"_ = quantized_graph_module.print_readable()\n",
113+
"\n",
114+
"# Create a new exported program using the quantized_graph_module\n",
115+
"quantized_exported_program = torch.export.export_for_training(quantized_graph_module, example_inputs)"
116+
]
117+
},
118+
{
119+
"cell_type": "markdown",
120+
"metadata": {},
121+
"source": [
122+
"The quantization nodes created in the previous cell are not built by default with ExecuTorch but must be included in the .pte-file, and so they need to be built separately. `backends/arm/scripts/build_quantized_ops_aot_lib.sh` is a utility script which does this. "
123+
]
124+
},
125+
{
126+
"cell_type": "code",
127+
"execution_count": null,
128+
"metadata": {},
129+
"outputs": [],
130+
"source": [
131+
"import subprocess \n",
132+
"import os \n",
133+
"\n",
134+
"# Setup paths\n",
135+
"cwd_dir = os.getcwd()\n",
136+
"et_dir = os.path.join(cwd_dir, \"..\", \"..\")\n",
137+
"et_dir = os.path.abspath(et_dir)\n",
138+
"script_dir = os.path.join(et_dir, \"backends\", \"arm\", \"scripts\")\n",
139+
"\n",
140+
"# Run build_quantized_ops_aot_lib.sh\n",
141+
"subprocess.run(os.path.join(script_dir, \"build_quantized_ops_aot_lib.sh\"), shell=True, cwd=et_dir)"
142+
]
143+
},
144+
{
145+
"cell_type": "markdown",
146+
"metadata": {},
147+
"source": [
148+
"The lowering in the EthosUBackend happens in five steps:\n",
149+
"\n",
150+
"1. **Lowering to core Aten operator set**: Transform module to use a subset of operators applicable to edge devices. \n",
151+
"2. **Partitioning**: Find subgraphs which are supported for running on Ethos-U\n",
152+
"3. **Lowering to TOSA compatible operator set**: Perform transforms to make the Ethos-U subgraph(s) compatible with TOSA \n",
153+
"4. **Serialization to TOSA**: Compiles the graph module into a TOSA graph \n",
154+
"5. **Compilation to NPU**: Compiles the TOSA graph into an EthosU command stream using the Arm Vela graph compiler. This makes use of the `compile_spec` created earlier.\n",
155+
"Step 5 also prints a Network summary for each processed subgraph.\n",
156+
"\n",
157+
"All of this happens behind the scenes in `to_edge_transform_and_lower`. Printing the graph module shows that what is left in the graph is two quantization nodes for `x` and `y` going into an `executorch_call_delegate` node, followed by a dequantization node."
158+
]
159+
},
160+
{
161+
"cell_type": "code",
162+
"execution_count": null,
163+
"metadata": {},
164+
"outputs": [],
165+
"source": [
166+
"from executorch.backends.arm.ethosu_partitioner import EthosUPartitioner\n",
167+
"from executorch.exir import (\n",
168+
" EdgeCompileConfig,\n",
169+
" ExecutorchBackendConfig,\n",
170+
" to_edge_transform_and_lower,\n",
171+
")\n",
172+
"from executorch.extension.export_util.utils import save_pte_program\n",
173+
"import platform \n",
174+
"\n",
175+
"# Create partitioner from compile spec \n",
176+
"partitioner = EthosUPartitioner(compile_spec)\n",
177+
"\n",
178+
"# Lower the exported program to the Ethos-U backend\n",
179+
"edge_program_manager = to_edge_transform_and_lower(\n",
180+
" quantized_exported_program,\n",
181+
" partitioner=[partitioner],\n",
182+
" compile_config=EdgeCompileConfig(\n",
183+
" _check_ir_validity=False,\n",
184+
" ),\n",
185+
" )\n",
186+
"\n",
187+
"# Load quantization ops library\n",
188+
"os_aot_lib_names = {\"Darwin\" : \"libquantized_ops_aot_lib.dylib\", \n",
189+
" \"Linux\" : \"libquantized_ops_aot_lib.so\", \n",
190+
" \"Windows\": \"libquantized_ops_aot_lib.dll\"}\n",
191+
"aot_lib_name = os_aot_lib_names[platform.system()]\n",
192+
"\n",
193+
"libquantized_ops_aot_lib_path = os.path.join(et_dir, \"cmake-out-aot-lib\", \"kernels\", \"quantized\", aot_lib_name)\n",
194+
"torch.ops.load_library(libquantized_ops_aot_lib_path)\n",
195+
"\n",
196+
"# Convert edge program to executorch\n",
197+
"executorch_program_manager = edge_program_manager.to_executorch(\n",
198+
" config=ExecutorchBackendConfig(extract_delegate_segments=False)\n",
199+
" )\n",
200+
"\n",
201+
"executorch_program_manager.exported_program().module().print_readable()\n",
202+
"\n",
203+
"# Save pte file\n",
204+
"pte_base_name = \"simple_example\"\n",
205+
"pte_name = pte_base_name + \".pte\"\n",
206+
"pte_path = os.path.join(cwd_dir, pte_name)\n",
207+
"save_pte_program(executorch_program_manager, pte_name)\n",
208+
"assert os.path.exists(pte_path), \"Build failed; no .pte-file found\""
209+
]
210+
},
211+
{
212+
"cell_type": "markdown",
213+
"metadata": {},
214+
"source": [
215+
"## Build executor runtime\n",
216+
"\n",
217+
"After the AOT compilation flow is done, the runtime can be cross compiled and linked to the produced .pte-file using the Arm cross-compilation toolchain. This is done in three steps:\n",
218+
"1. Build the executorch library and EthosUDelegate.\n",
219+
"2. Build any external kernels required. In this example this is not needed as the graph is fully delegated, but it's included for completeness.\n",
220+
"3. Build and link the `arm_executor_runner`."
221+
]
222+
},
223+
{
224+
"cell_type": "code",
225+
"execution_count": null,
226+
"metadata": {},
227+
"outputs": [],
228+
"source": [
229+
"# Build executorch \n",
230+
"subprocess.run(os.path.join(script_dir, \"build_executorch.sh\"), shell=True, cwd=et_dir)\n",
231+
"\n",
232+
"# Build portable kernels\n",
233+
"subprocess.run(os.path.join(script_dir, \"build_portable_kernels.sh\"), shell=True, cwd=et_dir)\n",
234+
"\n",
235+
"# Build executorch runner\n",
236+
"args = f\"--pte={pte_path} --target={target}\"\n",
237+
"subprocess.run(os.path.join(script_dir, \"build_executorch_runner.sh\") + \" \" + args, shell=True, cwd=et_dir)\n",
238+
"\n",
239+
"elf_path = os.path.join(cwd_dir, pte_base_name, \"cmake-out\", \"arm_executor_runner\")\n",
240+
"assert os.path.exists(elf_path), \"Build failed; no .elf-file found\""
241+
]
242+
},
243+
{
244+
"cell_type": "markdown",
245+
"metadata": {},
246+
"source": [
247+
"# Run on simulated model\n",
248+
"\n",
249+
"We can finally use the `backends/arm/scripts/run_fvp.sh` utility script to run the .elf-file on simulated Arm hardware. This script runs the model with an input of ones, so the expected result of the addition should be close to 2."
250+
]
251+
},
252+
{
253+
"cell_type": "code",
254+
"execution_count": null,
255+
"metadata": {},
256+
"outputs": [],
257+
"source": [
258+
"args = f\"--elf={elf_path} --target={target}\"\n",
259+
"subprocess.run(os.path.join(script_dir, \"run_fvp.sh\") + \" \" + args, shell=True, cwd=et_dir)"
260+
]
261+
}
262+
],
263+
"metadata": {
264+
"kernelspec": {
265+
"display_name": "venv",
266+
"language": "python",
267+
"name": "python3"
268+
},
269+
"language_info": {
270+
"codemirror_mode": {
271+
"name": "ipython",
272+
"version": 3
273+
},
274+
"file_extension": ".py",
275+
"mimetype": "text/x-python",
276+
"name": "python",
277+
"nbconvert_exporter": "python",
278+
"pygments_lexer": "ipython3",
279+
"version": "3.10.15"
280+
}
281+
},
282+
"nbformat": 4,
283+
"nbformat_minor": 4
284+
}

‎exir/backend/canonical_partitioners/TARGETS‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ runtime.python_library(
77
srcs = [
88
"duplicate_dequant_node_pass.py",
99
"pattern_op_partitioner.py",
10+
"all_node_partitioner.py",
1011
],
1112
visibility = [
1213
"//executorch/...",
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict, List
8+
9+
import torch
10+
from executorch.exir.backend.backend_details import ExportedProgram
11+
from executorch.exir.backend.compile_spec_schema import CompileSpec
12+
from executorch.exir.backend.partitioner import (
13+
DelegationSpec,
14+
Partitioner,
15+
PartitionResult,
16+
)
17+
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
18+
19+
20+
def is_non_tensor_placeholder(node: torch.fx.Node, ep: ExportedProgram) -> bool:
    """
    Return True when ``node`` is a placeholder that does not carry tensor data.

    A placeholder is considered tensor-backed when the exported program maps it
    to a parameter, a buffer, or a lifted tensor constant; every other
    placeholder (i.e. a genuine user input) is "non-tensor" for partitioning
    purposes.
    """
    if node.op != "placeholder":
        return False
    is_tensor_backed = (
        is_param(ep, node)
        or is_buffer(ep, node)
        or is_lifted_tensor_constant(ep, node)
    )
    return not is_tensor_backed
27+
28+
29+
class AllNodePartitioner(Partitioner):
    """Partitioner that tags every node in the graph for a single backend."""

    def __init__(
        self,
        backend_id: str,
        compile_specs: List[CompileSpec],
    ):
        """
        Partitioner that lowers every single node in the graph module unconditionally
        to the specified backend_id
        """
        super().__init__()
        self.delegation_spec = DelegationSpec(backend_id, compile_specs)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        # Tag every node except non-tensor placeholders (user inputs) and the
        # output node, so the entire graph is delegated to the backend.
        partition_tags: Dict[str, DelegationSpec] = {}
        tag = self.delegation_spec.backend_id
        for node in exported_program.graph_module.graph.nodes:
            should_skip = (
                is_non_tensor_placeholder(node, exported_program)
                or node.op == "output"
            )
            if should_skip:
                continue
            node.meta["delegation_tag"] = tag
            partition_tags[tag] = self.delegation_spec

        return PartitionResult(
            tagged_exported_program=exported_program,
            partition_tags=partition_tags,
        )

‎exir/backend/test/test_backends.py‎

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010

1111
import executorch.exir as exir
1212
import torch
13+
from executorch.exir import to_edge
1314
from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
15+
from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
16+
AllNodePartitioner,
17+
)
1418
from executorch.exir.backend.compile_spec_schema import CompileSpec
1519
from executorch.exir.backend.partitioner import (
1620
DelegationSpec,
@@ -1266,3 +1270,178 @@ def forward(self, x: List[torch.Tensor]):
12661270

12671271
gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
12681272
gm(*inputs)
1273+
1274+
def test_to_backend_delegation_spec(self):
    # End-to-end check of the EdgeProgramManager.to_backend(partitioner) flow:
    # lower a single-op (sin) module to the demo backend with
    # AllNodePartitioner, then serialize, load and execute the result.

    class SinModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return [torch.sin(x)]

    sin_module = SinModule()
    model_inputs = (torch.ones(1),)
    # The demo backend takes a "max_value" compile spec; here it is the
    # input's leading dimension (1).
    max_value = model_inputs[0].shape[0]

    partitioner = AllNodePartitioner(
        "BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))]
    )

    edgeir_m = to_edge(torch.export.export(sin_module, model_inputs))
    edgeir_m = edgeir_m.to_backend(partitioner)
    exec_prog = edgeir_m.to_executorch()
    graph_module = exec_prog.exported_program().graph_module
    # Check that there is not an aten.sin node.
    self.assertTrue(
        exir_ops.edge.aten.sin
        not in {node.target for node in graph_module.graph.nodes}
    )

    # Check that there exists a call_delegate, representing the call to the
    # delegated function
    FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
        graph_module.code
    )
    lowered_submodules = get_lowered_submodules(graph_module)
    self.assertEqual(len(lowered_submodules), 1)

    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == executorch_call_delegate:
            # Check that first arg is lowered_module_{unique_id}
            self.assertEqual(node.args[0].target, "lowered_module_0")

    program = exec_prog.executorch_program

    # Check the program can be printed
    print_program(program)

    # Check the backend delegate
    self.check_backend_delegate(
        program=program,
        delegate=program.execution_plan[0].delegates[0],
        expected_id=BackendWithCompilerDemo.__name__,
        expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
    )

    # Check the delegate instruction
    self.assertTrue(
        isinstance(
            program.execution_plan[0].chains[0].instructions[0].instr_args,
            DelegateCall,
        )
    )
    buff = exec_prog.buffer

    executorch_module = _load_for_executorch_from_buffer(buff)
    model_inputs = torch.ones(1)
    model_outputs = executorch_module.forward([model_inputs])
    # Input must not be mutated by execution.
    self.assertEqual(
        model_inputs,
        torch.ones(1),
    )
    # The demo backend computes sin via an approximation, so compare against
    # its expected approximate value (~0.8333 for input 1.0), not torch.sin.
    expected_output = 0.8333 * torch.ones(1)

    self.assertTrue(
        torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
    )
1347+
1348+
def test_to_backend_multimethod_delegation_spec(self):
    # Multi-method variant: export two modules as methods "sin" and
    # "add_mul", lower each with its own AllNodePartitioner, then run both
    # methods from the serialized program.

    class SinModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.sin(x)

        def inputs(self):
            # Example inputs for export and for the runtime comparison below.
            return (torch.ones(1),)

    class AddMulModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, a, x, b):
            y = torch.mm(a, x)
            z = torch.add(y, b)
            return z

        def inputs(self):
            return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))

    sin_module = SinModule()
    max_value_sin = sin_module.inputs()[0].shape[0]
    sin_partitioner = AllNodePartitioner(
        "BackendWithCompilerDemo",
        [CompileSpec("max_value", bytes([max_value_sin]))],
    )

    add_mul_module = AddMulModule()
    max_value_add_mul = add_mul_module.inputs()[0].shape[0]
    add_mul_partitioner = AllNodePartitioner(
        "BackendWithCompilerDemo",
        [CompileSpec("max_value", bytes([max_value_add_mul]))],
    )

    # to_edge / to_backend accept dicts keyed by method name.
    edgeir_m = to_edge(
        {
            "sin": torch.export.export(sin_module, sin_module.inputs()),
            "add_mul": torch.export.export(add_mul_module, add_mul_module.inputs()),
        }
    )
    edgeir_m = edgeir_m.to_backend(
        {
            "sin": sin_partitioner,
            "add_mul": add_mul_partitioner,
        }
    )
    exec_prog = edgeir_m.to_executorch()

    for method_name in ["sin", "add_mul"]:
        graph_module = exec_prog.exported_program(method_name).graph_module
        # Check delegated nodes are gone
        self.assertTrue(
            exir_ops.edge.aten.sin
            not in {node.target for node in graph_module.graph.nodes}
        )
        self.assertTrue(
            exir_ops.edge.aten.add
            not in {node.target for node in graph_module.graph.nodes}
        )
        self.assertTrue(
            exir_ops.edge.aten.mm
            not in {node.target for node in graph_module.graph.nodes}
        )
        # Check that there exists a call_delegate, representing the call to the
        # delegated function
        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
            graph_module.code
        )
        lowered_submodules = get_lowered_submodules(graph_module)
        self.assertEqual(len(lowered_submodules), 1)

    program = exec_prog.executorch_program

    # Check the program can be printed
    print_program(program)

    buff = exec_prog.buffer

    executorch_module = _load_for_executorch_from_buffer(buff)

    # Run each method at runtime and compare against a reference.
    for method_name, module in {
        "sin": sin_module,
        "add_mul": add_mul_module,
    }.items():
        inputs_flattened, _ = tree_flatten(module.inputs())
        model_outputs = executorch_module.run_method(
            method_name, tuple(inputs_flattened)
        )

        if method_name == "sin":
            # backend with compiler demo does a taylor approximation of sin
            ref_output = 0.8333 * torch.ones(1)
        else:
            ref_output = module(*module.inputs())
        self.assertTrue(
            torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03)
        )

‎exir/backend/test/test_backends_lifted.py‎

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
import torch
1212
from executorch.exir import to_edge
1313
from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
14+
from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
15+
AllNodePartitioner,
16+
)
1417
from executorch.exir.backend.compile_spec_schema import CompileSpec
1518
from executorch.exir.backend.partitioner import (
1619
DelegationSpec,
@@ -138,6 +141,18 @@ def forward(self, x):
138141

139142
self.assertTrue(torch.allclose(new_res, expected_res))
140143

144+
# Test same flow but through edge_program_manager
145+
edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
146+
loweredir_m = edgeir_m.to_backend(
147+
AllNodePartitioner(BackendWithCompilerDemo.__name__, [])
148+
)
149+
lowered_sin_module = get_lowered_submodules(
150+
loweredir_m.exported_program().graph_module
151+
)[0][1]
152+
153+
new_res = lowered_sin_module(*model_inputs)[0]
154+
155+
self.assertTrue(torch.allclose(new_res, expected_res))
141156
# TODO(tkaruturi): emitting single LoweredBackendModule
142157
# program = to_edge(export(graph_module)).to_exectorch()._emitter_output.program
143158

‎exir/backend/test/test_compatibility.py‎

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from executorch.exir import to_edge
1111
from executorch.exir._serialize import _serialize_pte_binary
1212
from executorch.exir.backend.backend_api import to_backend
13+
from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
14+
AllNodePartitioner,
15+
)
1316
from executorch.exir.backend.compile_spec_schema import CompileSpec
1417
from executorch.exir.backend.test.backend_with_compiler_demo import (
1518
BackendWithCompilerDemo,
@@ -65,3 +68,49 @@ def forward(self, x):
6568
"loading method forward failed with error 0x30",
6669
):
6770
executorch_module = _load_for_executorch_from_buffer(buff)
71+
72+
def test_compatibility_in_runtime_edge_program_manager(self):
    # Verify the runtime rejects a delegate blob whose version does not match:
    # lower via AllNodePartitioner, corrupt the serialized delegate's version
    # field, and expect load failure with error 0x30 (delegate incompatible).

    class SinModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.sin(x)

    sin_module = SinModule()
    model_inputs = (torch.ones(1),)
    edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
    max_value = model_inputs[0].shape[0]
    compile_specs = [CompileSpec("max_value", bytes([max_value]))]
    lowered_edge_irm = edgeir_m.to_backend(
        AllNodePartitioner("BackendWithCompilerDemo", compile_specs)
    )
    exec_prog = lowered_edge_irm.to_executorch()

    buff = exec_prog.buffer

    # The demo backend works well
    executorch_module = _load_for_executorch_from_buffer(buff)
    model_inputs = torch.ones(1)
    _ = executorch_module.forward([model_inputs])

    prog = exec_prog.executorch_program
    # Rewrite the delegate version number from 0 to 1.
    prog.backend_delegate_data[0].data = bytes(
        "1version:1#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
        encoding="utf8",
    )

    # Generate the .pte file with the wrong version.
    buff = bytes(
        _serialize_pte_binary(
            program=prog,
        )
    )

    # Throw runtime error with error code 0x30, meaning delegate is incompatible.
    with self.assertRaisesRegex(
        RuntimeError,
        "loading method forward failed with error 0x30",
    ):
        executorch_module = _load_for_executorch_from_buffer(buff)

‎exir/program/TARGETS‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ python_library(
3131
"//executorch/exir/_serialize:lib",
3232
"//executorch/exir/backend:backend_api",
3333
"//executorch/exir/backend:partitioner",
34+
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
3435
"//executorch/exir/capture:config",
3536
"//executorch/exir/emit:emit",
3637
"//executorch/exir/emit:lib",

‎extension/llm/custom_ops/op_sdpa.cpp‎

Lines changed: 215 additions & 715 deletions
Large diffs are not rendered by default.

‎extension/llm/custom_ops/op_sdpa_impl.h‎

Lines changed: 772 additions & 0 deletions
Large diffs are not rendered by default.

‎extension/llm/custom_ops/targets.bzl‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ def define_common_targets():
3030
"op_sdpa.h",
3131
"op_update_cache.h",
3232
],
33+
headers = [
34+
"op_sdpa_impl.h",
35+
],
3336
preprocessor_flags = get_vec_preprocessor_flags(),
3437
exported_deps = [
3538
"//executorch/runtime/kernel:kernel_includes",

‎extension/parallel/targets.bzl‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@ def define_common_targets():
1717
"@EXECUTORCH_CLIENTS",
1818
],
1919
deps = [
20-
"//executorch/runtime/kernel:thread_parallel_interface",
20+
"//executorch/extension/threadpool:threadpool",
2121
],
2222
)

‎extension/threadpool/targets.bzl‎

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def define_common_targets():
2020
] + (["fb/threadpool_use_n_threads.h"] if not runtime.is_oss else [])
2121

2222
runtime.cxx_library(
23-
name = "threadpool",
23+
name = "threadpool_lib",
2424
srcs = _THREADPOOL_SRCS,
2525
deps = [
2626
"//executorch/runtime/core:core",
@@ -45,6 +45,38 @@ def define_common_targets():
4545
],
4646
)
4747

48+
runtime.cxx_library(
49+
name = "threadpool",
50+
# TODO: OSS doesn't have os:iphoneos. Sync buck2 prelude
51+
# update to add it and remove duplication.
52+
exported_deps = (select({
53+
# Major operating systems should be able to use threadpool.
54+
"ovr_config//os:linux": [":threadpool_lib"],
55+
"ovr_config//os:macos": [":threadpool_lib"],
56+
"ovr_config//os:windows": [":threadpool_lib"],
57+
"ovr_config//os:android": [":threadpool_lib"],
58+
"ovr_config//os:iphoneos": [":threadpool_lib"],
59+
# Machines without an operating system shouldn't.
60+
"ovr_config//os:none": ["//executorch/runtime/kernel:thread_parallel_interface"],
61+
# If we don't know what it is, disable threadpool out of caution.
62+
"DEFAULT": ["//executorch/runtime/kernel:thread_parallel_interface"],
63+
}) if not runtime.is_oss else select({
64+
# Major operating systems should be able to use threadpool.
65+
"ovr_config//os:linux": [":threadpool_lib"],
66+
"ovr_config//os:macos": [":threadpool_lib"],
67+
"ovr_config//os:windows": [":threadpool_lib"],
68+
"ovr_config//os:android": [":threadpool_lib"],
69+
# Machines without an operating system shouldn't.
70+
"ovr_config//os:none": ["//executorch/runtime/kernel:thread_parallel_interface"],
71+
# If we don't know what it is, disable threadpool out of caution.
72+
"DEFAULT": ["//executorch/runtime/kernel:thread_parallel_interface"],
73+
})),
74+
visibility = [
75+
"//executorch/...",
76+
"@EXECUTORCH_CLIENTS",
77+
],
78+
)
79+
4880
runtime.cxx_library(
4981
name = "cpuinfo_utils",
5082
srcs = [

‎install_requirements.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def python_is_compatible():
7171
#
7272
# NOTE: If you're changing, make the corresponding change in .ci/docker/ci_commit_pins/pytorch.txt
7373
# by picking the hash from the same date in https://hud.pytorch.org/hud/pytorch/pytorch/nightly/
74-
NIGHTLY_VERSION = "dev20250310"
74+
NIGHTLY_VERSION = "dev20250325"
7575

7676

7777
def install_requirements(use_pytorch_nightly):
@@ -80,7 +80,7 @@ def install_requirements(use_pytorch_nightly):
8080
# Setting use_pytorch_nightly to false to test the pinned PyTorch commit. Note
8181
# that we don't need to set any version number there because they have already
8282
# been installed on CI before this step, so pip won't reinstall them
83-
f"torch==2.7.0.{NIGHTLY_VERSION}" if use_pytorch_nightly else "torch",
83+
f"torch==2.8.0.{NIGHTLY_VERSION}" if use_pytorch_nightly else "torch",
8484
(
8585
f"torchvision==0.22.0.{NIGHTLY_VERSION}"
8686
if use_pytorch_nightly

‎kernels/optimized/cpu/op_elu.cpp‎

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <ATen/native/cpu/Elu.h>
10+
11+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
12+
#include <executorch/runtime/kernel/kernel_includes.h>
13+
#include <executorch/runtime/kernel/thread_parallel_interface.h>
14+
#include <executorch/runtime/platform/assert.h>
15+
16+
namespace torch::executor::native {
17+
18+
namespace {
// Elementwise ELU applied over the flattened tensor, parallelized across
// element ranges. The per-element math is delegated to ATen's scalar and
// vectorized ELU helpers, so results match ATen's CPU implementation.
template <typename CTYPE>
void elu(
    KernelRuntimeContext& context,
    const Tensor& input,
    const Scalar& alpha,
    const Scalar& scale,
    const Scalar& input_scale,
    Tensor& out) {
  const CTYPE* in_data = input.const_data_ptr<CTYPE>();
  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
  // For reduced-precision float element types (half/bfloat16) the scalar
  // parameters are computed in float; otherwise in CTYPE itself.
  using MathT =
      std::conditional_t<c10::is_reduced_floating_point_v<CTYPE>, float, CTYPE>;
  MathT math_alpha = 0;
  MathT math_scale = 0;
  MathT math_input_scale = 0;
  // Extract alpha/scale/input_scale from their Scalar wrappers.
  ET_EXTRACT_SCALAR(alpha, math_alpha);
  ET_EXTRACT_SCALAR(scale, math_scale);
  ET_EXTRACT_SCALAR(input_scale, math_input_scale);
  // One functor for single elements, one for Vectorized<CTYPE> lanes.
  const auto scalar_func =
      at::native::get_scalar_elu_elementwise_func<CTYPE, MathT>(
          math_alpha, math_scale, math_input_scale);
  const auto vec_func = at::native::get_vectorized_elu_elementwise_func<CTYPE>(
      math_alpha, math_scale, math_input_scale);

  ::executorch::extension::parallel_for(
      0,
      out.numel(),
      ::executorch::extension::internal::GRAIN_SIZE,
      [&](const auto begin, const auto end) {
        using Vec = at::vec::Vectorized<CTYPE>;
        // Round `begin` up and `end` down to multiples of Vec::size() so the
        // main loop always does full-width vector loads/stores; the ragged
        // edges of the [begin, end) range are handled scalar-wise.
        const auto vectorized_begin =
            begin + (Vec::size() - begin % Vec::size()) % Vec::size();
        const auto vectorized_end = end - (end % Vec::size());
        // Scalar prologue.
        for (const auto idx : c10::irange(begin, vectorized_begin)) {
          out_data[idx] = scalar_func(in_data[idx]);
        }

        // Main vectorized loop.
        for (auto idx = vectorized_begin; idx < vectorized_end;
             idx += Vec::size()) {
          auto result_vec = vec_func(Vec::loadu(&in_data[idx]));
          result_vec.store(&out_data[idx]);
        }

        // Scalar epilogue.
        for (const auto idx : c10::irange(vectorized_end, end)) {
          out_data[idx] = scalar_func(in_data[idx]);
        }
      });
}
} // namespace
71+
72+
// elu.out entry point: validates arguments, resizes `out` to `in`'s shape,
// and dispatches to the parallel/vectorized elu<CTYPE>() kernel above based
// on the input's (floating-point) dtype. Returns `out`.
//
// Fix: the original body checked tensors_have_same_dtype(in, out) twice
// (once before resize, once after the floating-type check); the redundant
// second check is removed.
Tensor& opt_elu_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const Scalar& alpha,
    const Scalar& scale,
    const Scalar& input_scale,
    Tensor& out) {
  // Input and output must share a dtype, and `out` must be resizable to
  // `in`'s shape.
  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  // ELU is only defined for floating-point types (incl. half/bfloat16,
  // covered by the FLOATHBF16 dispatch below).
  ET_KERNEL_CHECK(ctx, tensor_is_floating_type(in), InvalidArgument, out);

  ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, "elu.out", CTYPE, [&]() {
    elu<CTYPE>(ctx, in, alpha, scale, input_scale, out);
  });
  return out;
}
95+
96+
} // namespace torch::executor::native

‎kernels/optimized/cpu/targets.bzl‎

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ _OPTIMIZED_ATEN_OPS = (
2525
"//executorch/kernels/portable/cpu/util:broadcast_util",
2626
],
2727
),
28+
op_target(
29+
name = "op_elu",
30+
deps = [
31+
"//executorch/extension/threadpool:threadpool",
32+
"//executorch/kernels/portable/cpu:scalar_utils",
33+
"//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
34+
],
35+
),
2836
op_target(name = "op_exp"),
2937
op_target(
3038
name = "op_fft_r2c",
@@ -99,8 +107,8 @@ _OPTIMIZED_ATEN_OPS = (
99107
op_target(
100108
name = "op_where",
101109
deps = [
110+
"//executorch/extension/threadpool:threadpool",
102111
"//executorch/kernels/portable/cpu/util:elementwise_util",
103-
"//executorch/runtime/kernel:thread_parallel_interface",
104112
],
105113
),
106114
)

‎kernels/optimized/lib_defs.bzl‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,9 @@ def define_libs(is_fbcode=False):
232232
"DEFAULT": [],
233233
}) + LIBBLAS_DEPS,
234234
exported_deps = [
235+
"//executorch/extension/threadpool:threadpool",
235236
"//executorch/kernels/optimized:libutils",
236237
"//executorch/runtime/core/exec_aten:lib",
237-
"//executorch/runtime/kernel:thread_parallel_interface",
238238
],
239239
**get_apple_framework_deps_kwargs(is_fbcode),
240240
)

‎kernels/optimized/optimized.yaml‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@
3737
- arg_meta: null
3838
kernel_name: torch::executor::opt_div_scalar_out
3939

40+
- op: elu.out
41+
kernels:
42+
- arg_meta: null
43+
kernel_name: torch::executor::opt_elu_out
44+
4045
- op: exp.out
4146
kernels:
4247
- arg_meta: null

‎kernels/portable/CMakeLists.txt‎

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,7 @@ gen_operators_lib(
6666
# Portable kernels support optional parallelization (and, in the
6767
# future, perhaps other performance features). If support is present,
6868
# produce an optimized version.
69-
set(BUILD_OPTIMIZED_PORTABLE_KERNELS EXECUTORCH_BUILD_PTHREADPOOL)
70-
71-
if(BUILD_OPTIMIZED_PORTABLE_KERNELS)
69+
if(EXECUTORCH_BUILD_PTHREADPOOL AND EXECUTORCH_BUILD_KERNELS_OPTIMIZED)
7270
add_library(optimized_portable_kernels ${_portable_kernels__srcs})
7371
target_link_libraries(optimized_portable_kernels PRIVATE executorch)
7472
target_link_libraries(optimized_portable_kernels PUBLIC extension_threadpool)

‎kernels/portable/cpu/util/targets.bzl‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def define_common_targets():
1212
runtime.cxx_library(
1313
name = "all_deps",
1414
deps = [
15+
"//executorch/extension/threadpool:threadpool",
1516
"//executorch/kernels/portable/cpu/util:functional_util",
1617
"//executorch/kernels/portable/cpu/util:broadcast_util",
1718
"//executorch/kernels/portable/cpu/util:kernel_ops_util",
@@ -32,7 +33,6 @@ def define_common_targets():
3233
"//executorch/kernels/portable/cpu/util:slice_util",
3334
"//executorch/kernels/portable/cpu/util:elementwise_util",
3435
"//executorch/kernels/portable/cpu/util:upsample_util",
35-
"//executorch/runtime/kernel:thread_parallel_interface",
3636
],
3737
visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
3838
)
@@ -111,7 +111,7 @@ def define_common_targets():
111111
":broadcast_util",
112112
":dtype_util",
113113
"//executorch/runtime/kernel:kernel_runtime_context",
114-
"//executorch/runtime/kernel:thread_parallel_interface",
114+
"//executorch/extension/threadpool:threadpool",
115115
],
116116
deps = [
117117
"//executorch/kernels/portable/cpu:scalar_utils",
@@ -245,7 +245,7 @@ def define_common_targets():
245245
srcs = [],
246246
exported_headers = ["functional_util.h"],
247247
exported_deps = [
248-
"//executorch/runtime/kernel:thread_parallel_interface",
248+
"//executorch/extension/threadpool:threadpool",
249249
],
250250
deps = [
251251
"//executorch/runtime/kernel:kernel_includes",
@@ -319,7 +319,7 @@ def define_common_targets():
319319
"//executorch/runtime/core/exec_aten/util:tensor_util{}".format(suffix),
320320
],
321321
exported_deps = [
322-
"//executorch/runtime/kernel:thread_parallel_interface",
322+
"//executorch/extension/threadpool:threadpool",
323323
],
324324
exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
325325
visibility = [

‎kernels/test/CMakeLists.txt‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ set(_optimized_kernels_test_sources
274274
"op_add_test.cpp"
275275
"op_bmm_test.cpp"
276276
"op_div_test.cpp"
277+
"op_elu_test.cpp"
277278
"op_exp_test.cpp"
278279
"op_fft_r2c_test.cpp"
279280
"op_gelu_test.cpp"

‎kernels/test/targets.bzl‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def define_common_targets():
215215
_common_op_test("op_detach_copy_test", ["aten", "portable"])
216216
_common_op_test("op_diagonal_copy_test", ["aten", "portable"])
217217
_common_op_test("op_div_test", ["aten", "portable", "optimized"])
218-
_common_op_test("op_elu_test", ["aten", "portable"])
218+
_common_op_test("op_elu_test", ["aten", "portable", "optimized"])
219219
_common_op_test("op_embedding_test", ["aten", "portable"])
220220
_common_op_test("op_empty_test", ["aten", "portable"])
221221
_common_op_test("op_eq_test", ["aten", "portable"])

‎runtime/core/event_tracer.h‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,14 @@ class EventTracerFilterBase {
101101
* - An error code if an error occurs during filtering.
102102
*/
103103
virtual Result<bool> filter(
104-
char* name,
105-
DelegateDebugIntId delegate_debug_index);
104+
const char* name,
105+
DelegateDebugIntId delegate_debug_index) = 0;
106106

107107
/**
108108
* Virtual destructor for the EventTracerFilterBase class.
109109
* Ensures proper cleanup of derived class objects.
110110
*/
111-
virtual ~EventTracerFilterBase();
111+
virtual ~EventTracerFilterBase() = default;
112112
};
113113

114114
/**

‎runtime/kernel/targets.bzl‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ def define_common_targets():
5959
"//executorch/runtime/core/portable_type/c10/c10:c10",
6060
"//executorch/runtime/platform:platform",
6161
],
62+
# Don't depend on this target, depend on //executorch/extension/threadpool:threadpool.
6263
visibility = [
63-
"//executorch/...",
64-
"@EXECUTORCH_CLIENTS",
64+
"//executorch/extension/threadpool/...",
6565
],
6666
)
6767

‎shim_et/xplat/executorch/build/runtime_wrapper.bzl‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@ def _patch_build_mode_flags(kwargs):
112112
# @oss-disable: "ovr_config//build_mode:code-coverage": ["-D__ET_BUILD_MODE_COV=1"],
113113
})
114114

115+
kwargs["compiler_flags"] = kwargs["compiler_flags"] + select({
116+
"DEFAULT": [],
117+
"ovr_config//os:macos": ["-fvisibility=default"],
118+
})
119+
115120
return kwargs
116121

117122
def _patch_test_compiler_flags(kwargs):

‎tools/cmake/executorch-config.cmake‎

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ endif()
149149
if(TARGET coremldelegate)
150150
set_target_properties(
151151
coremldelegate PROPERTIES INTERFACE_LINK_LIBRARIES
152-
"coreml_inmemoryfs;coreml_util"
152+
"coreml_inmemoryfs;coreml_util"
153153
)
154154
endif()
155155

@@ -167,4 +167,8 @@ if(TARGET optimized_native_cpu_ops_lib)
167167
endif()
168168
if(TARGET extension_threadpool)
169169
target_compile_definitions(extension_threadpool INTERFACE ET_USE_THREADPOOL)
170+
set_target_properties(
171+
extension_threadpool PROPERTIES INTERFACE_LINK_LIBRARIES
172+
"cpuinfo;pthreadpool"
173+
)
170174
endif()

0 commit comments

Comments
 (0)
Please sign in to comment.