Skip to content

Commit 3f5ef3c

Browse files
Merge branch 'main' into add-native-group-norm
2 parents f5d8911 + 5e8295e commit 3f5ef3c

File tree

165 files changed

+3639
-650
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

165 files changed

+3639
-650
lines changed

.ci/docker/ci_commit_pins/buck2.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2024-12-16
1+
2025-05-06

.github/workflows/apple.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ on:
55
branches:
66
- main
77
- release/*
8+
tags:
9+
- ciflow/trunk/*
810
pull_request:
911
paths:
1012
- .ci/scripts/setup-ios.sh

.github/workflows/build-presets.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
name: Build Presets
2+
3+
on:
4+
pull_request:
5+
push:
6+
branches:
7+
- main
8+
- release/*
9+
workflow_dispatch:
10+
11+
concurrency:
12+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
13+
cancel-in-progress: true

CMakeLists.txt

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@
4444

4545
cmake_minimum_required(VERSION 3.24)
4646
project(executorch)
47+
48+
# MARK: - Start EXECUTORCH_H12025_BUILD_MIGRATION --------------------------------------------------
49+
50+
include(${PROJECT_SOURCE_DIR}/tools/cmake/common/preset.cmake)
51+
include(${PROJECT_SOURCE_DIR}/tools/cmake/preset/default.cmake)
52+
53+
# MARK: - End EXECUTORCH_H12025_BUILD_MIGRATION ----------------------------------------------------
54+
4755
include(tools/cmake/Utils.cmake)
4856
include(CMakeDependentOption)
4957

@@ -96,9 +104,6 @@ set(EXECUTORCH_PAL_DEFAULT
96104
"Which PAL default implementation to use: one of {posix, minimal}"
97105
)
98106

99-
option(EXECUTORCH_ENABLE_LOGGING "Build with ET_LOG_ENABLED"
100-
${_default_release_disabled_options}
101-
)
102107
if(NOT EXECUTORCH_ENABLE_LOGGING)
103108
# Avoid pulling in the logging strings, which can be large. Note that this
104109
# will set the compiler flag for all targets in this directory, and for all

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2020
from .convert_to_clamp import ConvertToClampPass # noqa
2121
from .decompose_batchnorm_pass import DecomposeBatchNormPass # noqa
22+
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
2223
from .decompose_div_pass import DecomposeDivPass # noqa
2324
from .decompose_gelu_pass import DecomposeGeluPass # noqa
2425
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
65

76
# pyre-unsafe
87

98
import itertools
10-
9+
import operator
1110
from typing import List
1211

1312
import torch
@@ -22,7 +21,7 @@
2221

2322
class AnnotateDecomposedMatmulPass(ExportPass):
2423
"""
25-
torch.matmul can be decomposed in many ways, for instance:
24+
torch.matmul and it's equivalent operator @ can be decomposed in many ways, for instance:
2625
dq -> matmul -> q can become
2726
dq -> repeat -> view -> bmm -> view -> dq which makes quantization folding
2827
difficult. This helper function find all matmul partitions and annotate its
@@ -50,6 +49,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
5049
graph_module.graph,
5150
[
5251
torch.matmul,
52+
operator.matmul,
5353
],
5454
None,
5555
)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ConvertSqueezesToViewPass,
2525
ConvertToClampPass,
2626
DecomposeBatchNormPass,
27+
DecomposeCosineSimilarityPass,
2728
DecomposeDivPass,
2829
DecomposeGeluPass,
2930
DecomposeGroupNormPass,
@@ -208,6 +209,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
208209
self.add_pass(DecomposeVarPass())
209210
self.add_pass(DecomposeMeanDimPass())
210211
self.add_pass(DecomposeNotEqualPass())
212+
self.add_pass(DecomposeCosineSimilarityPass())
211213
self.add_pass(DecomposeDivPass())
212214
self.add_pass(DecomposeLeakyReLUPass())
213215
self.add_pass(DecomposeSqrtPass())
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.pass_base import ExportPass
8+
9+
torch_cosine_similarity = (torch.ops.aten.cosine_similarity.default,)
10+
11+
12+
class DecomposeCosineSimilarityPass(ExportPass):
13+
"""
14+
Decomposition of aten.cosine_similarity:
15+
16+
dot = sum(mul(x1, x2), dims, keepdim=False)
17+
norm = pow( sum(mul(x, x), dims, keepdim=False), 0.5 )
18+
eps = full( (), eps_scalar )
19+
n1c = max(norm1, eps)
20+
n2c = max(norm2, eps)
21+
denom = mul(n1c, n2c)
22+
out = div(dot, denom)
23+
"""
24+
25+
def call_operator(self, op, args, kwargs, meta):
26+
if op not in torch_cosine_similarity:
27+
return super().call_operator(op, args, kwargs, meta)
28+
29+
x1, x2 = args[0], args[1]
30+
dim = kwargs.get("dim", 1)
31+
eps = kwargs.get("eps", 1e-8)
32+
dims = [dim] if isinstance(dim, int) else list(dim)
33+
34+
# 1) dot
35+
prod = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x2), {}, meta)
36+
dot = super().call_operator(
37+
torch.ops.aten.sum.dim_IntList, (prod, dims, False), {}, meta
38+
)
39+
40+
# 2a) norm1 = pow(sum(x1*x1), 0.5)
41+
x1_sq = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x1), {}, meta)
42+
s1 = super().call_operator(
43+
torch.ops.aten.sum.dim_IntList, (x1_sq, dims, False), {}, meta
44+
)
45+
norm1 = super().call_operator(
46+
torch.ops.aten.pow.Tensor_Scalar, (s1, 0.5), {}, meta
47+
)
48+
49+
# 2b) norm2 = pow(sum(x2*x2), 0.5)
50+
x2_sq = super().call_operator(torch.ops.aten.mul.Tensor, (x2, x2), {}, meta)
51+
s2 = super().call_operator(
52+
torch.ops.aten.sum.dim_IntList, (x2_sq, dims, False), {}, meta
53+
)
54+
norm2 = super().call_operator(
55+
torch.ops.aten.pow.Tensor_Scalar, (s2, 0.5), {}, meta
56+
)
57+
58+
# 3) eps scalar - we need to broadcast ourselves as TOSA dont do this for scalar
59+
eps_t = super().call_operator(
60+
torch.ops.aten.full_like.default, (norm1, eps), {}, meta
61+
)
62+
63+
# 4) clamp to avoid zero division
64+
n1c = super().call_operator(
65+
torch.ops.aten.maximum.default, (norm1, eps_t), {}, meta
66+
)
67+
n2c = super().call_operator(
68+
torch.ops.aten.maximum.default, (norm2, eps_t), {}, meta
69+
)
70+
71+
# 5) denom and divide
72+
denom = super().call_operator(torch.ops.aten.mul.Tensor, (n1c, n2c), {}, meta)
73+
out = super().call_operator(torch.ops.aten.div.Tensor, (dot, denom), {}, meta)
74+
75+
return out

backends/arm/operator_support/pool_2d_support.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,11 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
5454
kernel = cast(tuple[int, int], node.args[1])
5555
stride = cast(tuple[int, int], node.args[2])
5656
if len(node.args) > 3:
57+
padding = cast(tuple[int, int], node.args[3])
5758
# Padding case
58-
if not all(1 <= k <= 8 for k in kernel):
59+
if not all(1 <= k <= 8 for k in kernel) and not all(
60+
v == 0 for v in padding
61+
):
5962
self.reporter.report_reject(
6063
node, f"Avgpool2d with padding needs kernel dims < 8, got {kernel}"
6164
)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ def _is_matmul_node_supported(
337337
graph_module.graph,
338338
[
339339
torch.matmul,
340+
operator.matmul,
340341
],
341342
None,
342343
)
@@ -387,7 +388,7 @@ def is_node_supported(
387388
):
388389
source_fn_stack: tuple[typing.Any] = node.meta.get("source_fn_stack", [])
389390
if len(source_fn_stack) > 0:
390-
if source_fn_stack[-1][1] in (torch.matmul,):
391+
if source_fn_stack[-1][1] in (torch.matmul, operator.matmul):
391392
return self._is_matmul_node_supported(submodules, node)
392393

393394
elif node.target in (exir_ops.edge.aten.max_pool2d_with_indices.default,):

backends/arm/operators/TARGETS

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,19 @@ python_library(
1010
],
1111
)
1212

13+
python_library(
14+
name = "operator_validation_utils",
15+
srcs = ["operator_validation_utils.py"],
16+
)
17+
1318
python_library(
1419
name = "ops",
1520
srcs = glob(["op_*.py", "ops_*.py"]),
1621
deps = [
1722
"fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/tosa:tosa",
1823
"fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/tosa:tosa",
1924
":node_visitor",
25+
":operator_validation_utils",
2026
"//executorch/backends/arm:tosa_mapping",
2127
"//executorch/backends/arm:tosa_quant_utils",
2228
"//executorch/backends/arm:tosa_utils",

backends/arm/operators/op_abs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
NodeVisitor,
1414
register_node_visitor,
1515
)
16+
from executorch.backends.arm.operators.operator_validation_utils import (
17+
validate_num_inputs,
18+
)
1619
from executorch.backends.arm.tosa_mapping import TosaArg
1720
from executorch.backends.arm.tosa_specification import TosaSpecification
1821
from torch.fx import Node
@@ -39,6 +42,7 @@ def define_node(
3942

4043
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4144

45+
validate_num_inputs(self.target, inputs, 1)
4246
# Specification (0.80) states that input and output types
4347
# should all be the same
4448
if not (inputs[0].dtype == output.dtype):
@@ -105,6 +109,7 @@ def define_node(
105109

106110
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
107111

112+
validate_num_inputs(self.target, inputs, 1)
108113
# Specification (0.80) states that input and output types
109114
# should all be the same
110115
if not (inputs[0].dtype == output.dtype):
@@ -157,6 +162,8 @@ def define_node(
157162

158163
import serializer.tosa_serializer as ts # type: ignore
159164

165+
validate_num_inputs(self.target, inputs, 1)
166+
160167
# Specification (1.0) states that input and output types
161168
# should all be the same
162169
if not (inputs[0].dtype == output.dtype):
@@ -224,6 +231,8 @@ def define_node(
224231

225232
import serializer.tosa_serializer as ts # type: ignore
226233

234+
validate_num_inputs(self.target, inputs, 1)
235+
227236
# Specification (1.0) states that input and output types
228237
# should all be the same
229238
if not (inputs[0].dtype == output.dtype):

backends/arm/operators/op_add.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
NodeVisitor,
1515
register_node_visitor,
1616
)
17+
from executorch.backends.arm.operators.operator_validation_utils import (
18+
validate_num_inputs,
19+
)
1720
from executorch.backends.arm.tosa_mapping import TosaArg
1821
from executorch.backends.arm.tosa_specification import TosaSpecification
1922
from torch.fx import Node
@@ -40,6 +43,7 @@ def define_node(
4043

4144
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4245

46+
validate_num_inputs(self.target, inputs, 2)
4347
# Specification (0.80) states that input and output types
4448
# should all be the same
4549
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -118,6 +122,7 @@ def define_node(
118122

119123
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
120124

125+
validate_num_inputs(self.target, inputs, 2)
121126
# Specification (0.80) states that input and output types
122127
# should all be the same
123128
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -169,6 +174,8 @@ def define_node(
169174

170175
import serializer.tosa_serializer as ts # type: ignore
171176

177+
validate_num_inputs(self.target, inputs, 2)
178+
172179
# Specification (1.0) states that input and output types
173180
# should all be the same
174181
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -237,6 +244,8 @@ def define_node(
237244

238245
import serializer.tosa_serializer as ts # type: ignore
239246

247+
validate_num_inputs(self.target, inputs, 2)
248+
240249
# Specification (1.0) states that input and output types
241250
# should all be the same
242251
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:

backends/arm/operators/op_amax.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
NodeVisitor,
1010
register_node_visitor,
1111
)
12+
from executorch.backends.arm.operators.operator_validation_utils import (
13+
validate_num_inputs,
14+
)
1215
from executorch.backends.arm.tosa_mapping import TosaArg
1316
from torch.fx import Node
1417

@@ -31,6 +34,8 @@ def define_node(
3134
) -> None:
3235
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3336

37+
validate_num_inputs(self.target, inputs, 3)
38+
3439
input = inputs[0]
3540
dim = inputs[1].number
3641

@@ -71,6 +76,8 @@ def define_node(
7176
) -> None:
7277
import serializer.tosa_serializer as ts
7378

79+
validate_num_inputs(self.target, inputs, 3)
80+
7481
input = inputs[0]
7582
dim = inputs[1].number
7683

backends/arm/operators/op_amin.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
NodeVisitor,
1010
register_node_visitor,
1111
)
12+
from executorch.backends.arm.operators.operator_validation_utils import (
13+
validate_num_inputs,
14+
)
1215
from executorch.backends.arm.tosa_mapping import TosaArg
1316
from torch.fx import Node
1417

@@ -31,6 +34,8 @@ def define_node(
3134
) -> None:
3235
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3336

37+
validate_num_inputs(self.target, inputs, 3)
38+
3439
input = inputs[0]
3540
dim = inputs[1].number
3641

@@ -71,6 +76,8 @@ def define_node(
7176
) -> None:
7277
import serializer.tosa_serializer as ts
7378

79+
validate_num_inputs(self.target, inputs, 3)
80+
7481
input = inputs[0]
7582
dim = inputs[1].number
7683

0 commit comments

Comments
 (0)