
Commit 012f120

Authored Mar 25, 2025
Arm backend: Add ComputeConstantOpsAOT pass (#9504)
Operators that output tensors based on constant args are pre-computed and added as buffers.
- The pass currently supports full, arange, linspace, and eye.
- Remove some logic for full now handled by the pass.
- Rename FuseConstantOpsPass to FuseConstantArgsPass and do minor improvements.

Fix retracing in FuseViewCopyTransform: since the pass can change the shapes of ops, the graph needs to be retraced to show this in node.meta["val"].

Signed-off-by: Erik Lundell <[email protected]>
1 parent 766bbdc commit 012f120

13 files changed: +338 additions, -136 deletions
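
To ground the commit message, here is a minimal sketch (a hypothetical module, not code from this commit) of the kind of graph the new pass targets: an op like torch.arange takes no tensor inputs, so its output is fully determined at export time and can be evaluated once and stored as a buffer.

import torch

class M(torch.nn.Module):
    def forward(self, x):
        # arange has only constant args, so its output never changes;
        # ComputeConstantOpsAOT lifts such results into the state_dict.
        return torch.arange(0.0, 10.0, dtype=torch.float32) + x

ep = torch.export.export(M(), (torch.randn(10),))
print(ep.graph)  # shows arange as a call_function with only constant args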
 

backends/arm/_passes/arm_pass_manager.py
Lines changed: 13 additions & 7 deletions

@@ -55,7 +55,10 @@
     RetraceFoldedDtypesPass,
 )
 from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
-from executorch.backends.arm._passes.fuse_constant_ops_pass import FuseConstantOpsPass
+from executorch.backends.arm._passes.fuse_constant_ops_pass import (
+    ComputeConstantOpsAOT,
+    FuseConstantArgsPass,
+)
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import (  # type: ignore[import-not-found]
     FuseQuantizedActivationPass,
 )
@@ -121,21 +124,23 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(ComputeConstantOpsAOT(exported_program))
 
         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
-        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(CastInt64ToInt32Pass(exported_program))
-        self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
 
         self.add_pass(FuseViewCopyTransform())
-        self.add_pass(FuseConstantOpsPass(exported_program))
+        self.add_pass(FuseConstantArgsPass(exported_program))
+
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())
@@ -166,21 +171,22 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(ComputeConstantOpsAOT(exported_program))
 
         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
-        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(CastInt64ToInt32Pass(exported_program))
-        self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
 
         self.add_pass(FuseViewCopyTransform())
-        self.add_pass(FuseConstantOpsPass(exported_program))
+        self.add_pass(FuseConstantArgsPass(exported_program))
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())

backends/arm/_passes/cast_int64_pass.py
Lines changed: 28 additions & 27 deletions

@@ -8,7 +8,6 @@
 import logging
 
 import torch
-from executorch.backends.arm._passes.arm_pass_utils import is_param_node
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._export.utils import is_buffer
 
@@ -25,35 +24,37 @@ def __init__(self, exported_program: torch.export.ExportedProgram):
         super(CastInt64ToInt32Pass, self).__init__()
         self.exported_program = exported_program
 
+    def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node):
+        if torch.min(tensor) < torch.iinfo(torch.int32).min:
+            raise RuntimeError(
+                f"Node {node.name} has value < {torch.iinfo(torch.int32).min}"
+            )
+        if torch.max(tensor) > torch.iinfo(torch.int32).max:
+            raise RuntimeError(
+                f"Node {node.name} has value > {torch.iinfo(torch.int32).max}"
+            )
+
     def _to_int32(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             fake_tensor = node.meta["val"]
-            if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
-                if node.meta["val"].dtype == torch.int64 and is_param_node(
-                    self.exported_program, node
-                ):
-                    if is_buffer(self.exported_program, node):
-                        node.meta["val"] = node.meta["val"].to(torch.int32)
-                        buffer_name = (
-                            self.exported_program.graph_signature.inputs_to_buffers[
-                                node.name
-                            ]
-                        )
-                        buffer = self.exported_program.state_dict[node.name]
-                        logger.warning(
-                            f"Casting buffer {node.name} from torch.int64 to torch.int32"
-                            f" defined in {node.meta['stack_trace']}"
-                        )
-                        if torch.min(buffer) < torch.iinfo(torch.int32).min:
-                            raise RuntimeError(
-                                f"Buffer {node.name} has value < {torch.iinfo(torch.int32).min}"
-                            )
-                        if torch.max(buffer) > torch.iinfo(torch.int32).max:
-                            raise RuntimeError(
-                                f"Buffer {node.name} has value > {torch.iinfo(torch.int32).max}"
-                            )
-                        buffer_int32 = buffer.to(torch.int32)
-                        self.exported_program.state_dict[buffer_name] = buffer_int32
+            if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
+                continue
+            if fake_tensor.dtype != torch.int64:
+                continue
+            if is_buffer(self.exported_program, node):
+                node.meta["val"] = fake_tensor.to(torch.int32)
+                buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
+                    node.name
+                ]
+                buffer = self.exported_program.state_dict[node.name]
+                self._assert_within_int32(buffer, node)
+                logger.warning(
+                    f"Casting buffer {node.name} from torch.int64 to torch.int32"
+                    f" defined in {node.meta.get('stack_trace','[no stack trace found]')}"
+                )
+                buffer_int32 = buffer.to(torch.int32)
+                self.exported_program.state_dict[buffer_name] = buffer_int32
+                continue
 
     def call(self, graph_module: torch.fx.GraphModule):
         self._to_int32(graph_module)
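
For illustration, the extracted range guard as a standalone function (hypothetical name; the pass implements it as the method above):

import torch

def assert_within_int32(tensor: torch.Tensor, name: str) -> None:
    # Narrowing int64 -> int32 is only safe if every value fits in int32.
    info = torch.iinfo(torch.int32)
    if torch.min(tensor) < info.min or torch.max(tensor) > info.max:
        raise RuntimeError(f"{name} has values outside the int32 range")

assert_within_int32(torch.tensor([1, 2], dtype=torch.int64), "ok_buffer")
# assert_within_int32(torch.tensor([2**40]), "bad_buffer")  # would raise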

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
Lines changed: 2 additions & 15 deletions

@@ -174,11 +174,8 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
 class QuantizeOperatorArguments(ExportPass):
     """
-    This pass makes sure that the arguments to full.default and clamp.default are quantized correctly.
+    This pass makes sure that the arguments to clamp.default are quantized correctly.
     More specifically, this pass:
-    - Makes sure the fill_value for full.default is quantized. This pass needs to be run before
-      the folding pass above to make sure that the retraced output of the full.default op is
-      the right dtype.
     - Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator.
     """
 
@@ -189,7 +186,6 @@ def call(self, graph_module: GraphModule) -> PassResult:
             n = cast(Node, n)
             if n.target not in {
                 exir_ops.edge.aten.clamp.default,
-                exir_ops.edge.aten.full.default,
             }:
                 continue
 
@@ -200,16 +196,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
             qargs = QuantArgs.from_operator(user.target, user.args)
 
-            if n.target == exir_ops.edge.aten.full.default:
-                if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
-                    # replace the node arg with a quantized dito and also set dtype
-                    # to get the right output according to the Edge IR specification:
-                    # exir/dialects/edge/edge.yaml:3596
-                    quantized_full_value = qargs.quantize_value(n.args[1]).item()
-                    n.update_arg(1, quantized_full_value)
-                    n.update_kwarg("dtype", qargs.dtype)
-                    modified = True
-            elif n.target == exir_ops.edge.aten.clamp.default:
+            if n.target == exir_ops.edge.aten.clamp.default:
                 # Quantize the min and max arguments of clamp, if they are not None
                 min_val = n.args[1]
                 max_val = None if len(n.args) <= 2 else n.args[2]
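
For reference, a standalone sketch of the affine quantization that QuantArgs.quantize_value applies to clamp's min/max bounds (the helper below is hypothetical, not the QuantArgs API):

def quantize_scalar(x: float, scale: float, zp: int, qmin: int, qmax: int) -> int:
    # q = clamp(round(x / scale) + zp, qmin, qmax)
    q = round(x / scale) + zp
    return max(qmin, min(qmax, q))

# A clamp bound of 0.5 with scale 0.02 and zero point 0 maps to 25 in int8.
print(quantize_scalar(0.5, scale=0.02, zp=0, qmin=-128, qmax=127))  # 25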

backends/arm/_passes/fuse_constant_ops_pass.py
Lines changed: 114 additions & 60 deletions

@@ -8,6 +8,7 @@
 import torch._export.utils
 from executorch.backends.arm._passes.arm_pass_utils import (
     get_constant_placeholder_kind,
+    get_first_fake_tensor,
     get_param_tensor,
     is_persistent_buffer,
 )
@@ -18,11 +19,12 @@
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from torch.export.graph_signature import InputKind
 
 logger = logging.getLogger(__name__)
 
 
-class FuseConstantOpsPass(ExportPass):
+class FuseConstantArgsPass(ExportPass):
     """
     Fuses ops with only placeholder parameters into one placeholder parameter node with the op
     pre-calulcated on its data.
@@ -42,67 +44,38 @@ def __init__(self, exported_program: ExportedProgram) -> None:
         super().__init__()
         self.exported_program = exported_program
 
-    def fuse_nodes(self, node) -> bool:
+    def _fuse_nodes(self, node) -> bool:
         """
         Takes a node with only parameter inputs and replaces it with one constant tensor node with
         the operations already carried out on the data.
         """
 
-        if node.target == exir_ops.edge.aten.full.default:
-            # Create data from args
-            size, fill_value = node.args
-            dtype = node.kwargs["dtype"]
-            data = torch.full(size, float(fill_value), dtype=dtype)
+        # Extract tensors and args from the node
+        data_list = [
+            get_param_tensor(self.exported_program, input_node)
+            for input_node in node.all_input_nodes
+        ]
 
-            insert_pos = list(node.graph.nodes)[0]
-        else:
-            # Extract tensors and args from the node
-
-            if len(node.all_input_nodes) == 0:
-                raise RuntimeError("No inputs found")
+        args = node.args[len(node.all_input_nodes) :]
+        kwargs = node.kwargs
 
-            data_list = [
-                get_param_tensor(self.exported_program, input_node)
-                for input_node in node.all_input_nodes
-            ]
+        if "input_qparams" in node.meta and len(node.meta["input_qparams"]) > 0:
+            for i in range(len(node.all_input_nodes)):
+                q_params = node.meta["input_qparams"][i]
+                data_list[i] = q_params.dequantize_value(data_list[i])
 
-            args = node.args[len(node.all_input_nodes) :]
-            kwargs = node.kwargs
+        # Run the op on the extracted tensor
+        data = node.target(*data_list, *args, **kwargs)
 
-            if "input_qparams" in node.meta and len(node.meta["input_qparams"]) > 0:
-                dequantize_op = (
-                    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
-                )
+        # Only fuse if the tensor does not get bigger.
+        if data.numel() > get_first_fake_tensor(node).numel():
+            return False
 
-                for i in range(len(node.all_input_nodes)):
-                    q_params = node.meta["input_qparams"][i]
-                    data_list[i] = dequantize_op(
-                        data_list[i],
-                        q_params.scale,
-                        q_params.zp,
-                        q_params.qmin,
-                        q_params.qmax,
-                        q_params.dtype,
-                    )
-
-            # Run the op on the extracted tensor
-            data = node.target(*data_list, *args, **kwargs)
-
-            if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0:
-                quantize_op = (
-                    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
-                )
-                q_params = node.meta["output_qparams"][0]
-                data = quantize_op(
-                    data,
-                    q_params.scale,
-                    q_params.zp,
-                    q_params.qmin,
-                    q_params.qmax,
-                    q_params.dtype,
-                )
+        if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0:
+            q_params = node.meta["output_qparams"][0]
+            data = q_params.quantize_value(data)
 
-            insert_pos = list(node.all_input_nodes)[0]
+        insert_pos = list(node.all_input_nodes)[0]
 
         # Make new node the same kind as the first constant input
         input_kind = get_constant_placeholder_kind(self.exported_program, insert_pos)
@@ -124,20 +97,17 @@ def fuse_nodes(self, node) -> bool:
         return True
 
     def call(self, graph_module):
-        modified = True
+        modified = False
         input_nodes_to_delete = []
         for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target == torch.ops.tosa._table.default:
                continue
-            if node.target == exir_ops.edge.aten.repeat.default:
-                _, multiples = node.args
-                # Do not fuse if the repeat creates a larger output, i.e. any multiple > 1
-                if any((multiple > 1 for multiple in multiples)):
-                    continue
 
            input_nodes = node.all_input_nodes
+            if len(input_nodes) == 0:
+                continue
            input_nodes_constant = (
                torch._export.utils.is_param(self.exported_program, input_node)
                or torch._export.utils.is_lifted_tensor_constant(
@@ -152,9 +122,11 @@ def call(self, graph_module):
 
            if all(input_nodes_constant) and all(input_nodes_single_users):
                try:
-                    self.fuse_nodes(node)
-                    graph_module.recompile()  # Recompile needed to catch chains of constant ops
-                    input_nodes_to_delete.extend(input_nodes)
+                    did_fuse = self._fuse_nodes(node)
+                    modified |= did_fuse
+                    if did_fuse:
+                        graph_module.recompile()  # Recompile needed to catch chains of constant ops
+                        input_nodes_to_delete.extend(input_nodes)
                except Exception as e:
                    logger.warning(
                        f"\nFailed to fuse constant op {node.name} due to exception:\n{str(e)}"
@@ -168,3 +140,85 @@ def call(self, graph_module):
         graph_module = super().call(graph_module).graph_module
 
         return PassResult(graph_module, True)
+
+
+class ComputeConstantOpsAOT(ExportPass):
+    """
+    Evaluates call_functions that produce constant tensor outputs and replaces them with placeholders.
+
+    Original:
+        state_dict = {}
+        def f():
+            return torch.arange(0,10)
+    After pass:
+        state_dict = {node_name_pre_computed : torch.arange(0,10)}
+        def f(node_name_pre_computed):
+            return node_name_pre_computed
+    """
+
+    targeted_ops = [
+        exir_ops.edge.aten.full.default,
+        exir_ops.edge.aten.arange.start_step,
+        exir_ops.edge.aten.eye.default,
+        exir_ops.edge.aten.linspace.default,
+    ]
+
+    def __init__(self, exported_program: ExportedProgram) -> None:
+        super().__init__()
+        self.exported_program = exported_program
+
+    def compute_node_aot(self, node: torch.fx.Node) -> bool:
+        """
+        Takes a node with only parameter inputs and replaces it with one constant tensor node with
+        the operations already carried out on the data.
+        """
+
+        # Create data from args
+        output_qparams = node.meta.get("output_qparams", None)
+        if output_qparams:
+            # If we have output_qparams, compute data in fp and quantize
+            data = node.target(*node.args)  # type: ignore
+            output_qparams = output_qparams[0]
+            data = output_qparams.quantize_value(data)
+        else:
+            # If we don't have output_qparams, compute data using kwarg-specified dtype
+            data = node.target(*node.args, **node.kwargs)  # type: ignore
+
+        # Create new node
+        insert_pos = list(node.graph.nodes)[0]
+        input_kind = InputKind.BUFFER
+        persistent_buffer = True
+
+        with node.graph.inserting_before(insert_pos):
+            const_node = create_constant_placeholder(
+                exp_program=self.exported_program,
+                graph=node.graph,
+                kind=input_kind,
+                name=node.name + "_pre_computed",
+                data=data,
+                persistent_buffer=persistent_buffer,
+            )
+            node.replace_all_uses_with(const_node)
+
+        return True
+
+    def call(self, graph_module):
+        modified = False
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target not in self.targeted_ops:
+                continue
+            try:
+                modified |= self.compute_node_aot(node)
+            except Exception as e:
+                logger.warning(
+                    f"\nFailed to pre-compute op {node.name} due to exception:\n{str(e)}"
+                )
+
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
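
To make the fusion criterion concrete, a small sketch with hypothetical tensors (plain PyTorch, not the pass API): when every input of an op is a known constant, the op can be run once at export time, and the result replaces the node, provided it is not larger than the node's original output.

import torch

weight = torch.full((2, 2), 3.0)  # data behind a constant placeholder
# Evaluate the op once ahead of time, just as the pass does with
# data = node.target(*data_list, *args, **kwargs).
data = torch.ops.aten.transpose.int(weight, 0, 1)
# The pass only keeps the result if data.numel() does not exceed the
# numel of the node's original (fake-tensor) output.
assert data.numel() <= weight.numel()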

backends/arm/operator_support/tosa_supported_operators.py
Lines changed: 12 additions & 6 deletions

@@ -14,6 +14,7 @@
 import torch.fx as fx
 
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
     FuseQuantizedActivationPass,
 )
@@ -142,6 +143,7 @@ def is_node_supported(
             exir_ops.edge.aten.logical_or.default,
             exir_ops.edge.aten.logical_xor.default,
             exir_ops.edge.aten.logical_not.default,
+            exir_ops.edge.aten.arange.start_step,
             exir_ops.edge.aten.bitwise_and.Tensor,
             exir_ops.edge.aten.bitwise_or.Tensor,
             exir_ops.edge.aten.bitwise_xor.Tensor,
@@ -201,6 +203,8 @@ def is_node_supported(
             exir_ops.edge.aten.constant_pad_nd.default,
             exir_ops.edge.aten.amax.default,
             exir_ops.edge.aten.amin.default,
+            exir_ops.edge.aten.eye.default,
+            exir_ops.edge.aten.linspace.default,
         ]
 
         return supported
@@ -457,16 +461,18 @@ def is_node_supported(
     ) -> bool:
 
         for input_node in node.all_input_nodes:
-            # We can cast constant placeholders AOT, not call_functions.
+            # We can cast constant placeholders and constant ops AOT, such int64 are ok.
+            # Otherwise, don't partition if one or more inputs are int64.
             if (
                 input_node.name in self.input_names
                 or not input_node.op == "placeholder"
             ):
                 tensor = get_first_fake_tensor(input_node)
                 if tensor.dtype == torch.int64:
-                    self.reporter.report_reject(
-                        node,
-                        f"Had int64 input {input_node.name} that couldn't be handled.",
-                    )
-                    return False
+                    if input_node.target not in ComputeConstantOpsAOT.targeted_ops:
+                        self.reporter.report_reject(
+                            node,
+                            f"Had int64 input {input_node.name} that couldn't be handled.",
+                        )
+                        return False
         return True
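
The special-casing of arange inputs makes sense given PyTorch's dtype defaults; a quick illustration (plain PyTorch, not backend code) of why such nodes show up with int64 outputs in the first place:

import torch

# With all-integer args and no explicit dtype, arange defaults to int64,
# which TOSA cannot represent; the AOT pass can now absorb these nodes.
print(torch.arange(0, 10).dtype)                     # torch.int64
print(torch.arange(0, 10, dtype=torch.int32).dtype)  # torch.int32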

backends/arm/test/models/test_conformer.py
Lines changed: 0 additions & 1 deletion

@@ -30,7 +30,6 @@ class TestConformer(unittest.TestCase):
     # for that is some assert ops are removed by passes in the
     # .to_executorch step, i.e. after Arm partitioner.
     ops_after_partitioner = {
-        "executorch_exir_dialects_edge__ops_aten_arange_start_step": 1,
         "executorch_exir_dialects_edge__ops_aten_max_default": 1,
         "executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,
         "executorch_exir_dialects_edge__ops_aten_where_self": 4,

backends/arm/test/models/test_nn_functional.py
Lines changed: 8 additions & 2 deletions

@@ -79,14 +79,20 @@ def forward(self, *args):
 
 
 @parametrize(
-    "test_data", module_tests, xfails={"max_pool1d": "ValueError: Invalid TOSA graph"}
+    "test_data",
+    module_tests,
+    xfails={
+        "max_pool1d": "ValueError: Invalid TOSA graph",
+        "affine_grid": "Int64 input. Partition handling fails since arange int64 output is split between 2 partitions.",
+    },
 )
 def test_nn_functional_MI(test_data):
     module, inputs = test_data
     pipeline = TosaPipelineMI[input_t](
-        module, inputs, "", use_to_edge_transform_and_lower=True
+        module, inputs, "", use_to_edge_transform_and_lower=False
     )
     pipeline.pop_stage("check.aten")
+    pipeline.dump_artifact("to_edge")
     pipeline.pop_stage("check_count.exir")
     try:
         pipeline.run()

backends/arm/test/ops/test_arange.py
Lines changed: 135 additions & 0 deletions (new file)

@@ -0,0 +1,135 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineBI,
+    EthosU85PipelineBI,
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+
+input_t = tuple[torch.Tensor]
+test_data_t = tuple[Callable[[], input_t], tuple[float, float, float, torch.dtype]]
+
+
+class ArangeAdd(torch.nn.Module):
+    aten_op: str = "torch.ops.aten.arange.start_step"
+    exir_op: str = "executorch_exir_dialects_edge__ops_aten_arange_start_step"
+
+    def __init__(self, start: float, stop: float, step: float, dtype: torch.dtype):
+        super().__init__()
+        self.args = (start, stop, step)
+        self.dtype = dtype
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.arange(*self.args, dtype=self.dtype) + x
+
+    test_data: dict[str, test_data_t] = {
+        "10": (lambda: (torch.randn(10, 1),), (0.0, 10.0, 1.0, torch.float32)),
+        "15": (lambda: (torch.randn(10),), (0.0, 15.0, 1.5, torch.float32)),
+        "100": (lambda: (torch.randn(10, 1),), (0.0, 10.0, 0.1, torch.float32)),
+    }
+
+    test_data_dtypes: dict[str, test_data_t] = {
+        "fp32_int32": (lambda: (torch.randn(10),), (0.0, 10.0, 1.0, torch.int32)),
+        "fp32_int64": (lambda: (torch.randn(10),), (0.0, 10.0, 1.0, torch.int64)),
+        "int32_int32": (
+            lambda: (torch.randint(0, 10, [10], dtype=torch.int32),),
+            (0.0, 10.0, 1.0, torch.int32),
+        ),
+        "int32_int64": (
+            lambda: (torch.randint(0, 10, [10], dtype=torch.int32),),
+            (0.0, 10.0, 1.0, torch.int64),
+        ),
+    }
+
+
+@common.parametrize("test_data", ArangeAdd.test_data)
+def test_arange_start_step_tosa_MI(test_data: test_data_t):
+    input_data, init_data = test_data
+    pipeline = TosaPipelineMI[input_t](
+        ArangeAdd(*init_data), input_data(), ArangeAdd.aten_op, ArangeAdd.exir_op
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", ArangeAdd.test_data_dtypes)
+def test_arange_start_step_dtypes_tosa_MI(test_data: test_data_t):
+    input_data, init_data = test_data
+    pipeline = TosaPipelineMI[input_t](
+        ArangeAdd(*init_data), input_data(), ArangeAdd.aten_op, ArangeAdd.exir_op
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", ArangeAdd.test_data)
+def test_arange_start_step_tosa_BI(test_data: test_data_t):
+    input_data, init_data = test_data
+    pipeline = TosaPipelineBI[input_t](
+        ArangeAdd(*init_data), input_data(), ArangeAdd.aten_op, ArangeAdd.exir_op
+    )
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.run()
+
+
+@common.parametrize("test_data", ArangeAdd.test_data)
+def test_arange_start_step_tosa_u55(test_data: test_data_t):
+    input_data, init_data = test_data
+    pipeline = EthosU55PipelineBI[input_t](
+        ArangeAdd(*init_data), input_data(), ArangeAdd.aten_op
+    )
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.run()
+
+
+@common.parametrize("test_data", ArangeAdd.test_data)
+def test_arange_start_step_tosa_u85(test_data: test_data_t):
+    input_data, init_data = test_data
+    pipeline = EthosU85PipelineBI[input_t](
+        ArangeAdd(*init_data), input_data(), ArangeAdd.aten_op
+    )
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.run()
+
+
+class LinspaceAdd(torch.nn.Module):
+    aten_op: str = "torch.ops.aten.linspace.default"
+    exir_op: str = "executorch_exir_dialects_edge__ops_aten_arange_default"
+
+    def __init__(self, start: float, stop: float, step: int, dtype: torch.dtype):
+        super().__init__()
+        self.args = (start, stop, step)
+        self.dtype = dtype
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.linspace(*self.args, dtype=self.dtype) + x
+
+    test_data: dict[str, test_data_t] = {
+        "10": (lambda: (torch.randn(10, 1),), (0.0, 10.0, 100, torch.float32)),
+        "15": (lambda: (torch.randn(20),), (0.0, 15.0, 20, torch.float32)),
+    }
+
+
+@common.parametrize("test_data", LinspaceAdd.test_data)
+def test_linspace_tosa_MI(test_data):
+    input_data, init_data = test_data
+    pipeline = TosaPipelineMI[input_t](
+        LinspaceAdd(*init_data), input_data(), LinspaceAdd.aten_op, LinspaceAdd.exir_op
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", LinspaceAdd.test_data)
+def test_linspace_tosa_BI(test_data: test_data_t):
+    input_data, init_data = test_data
+    pipeline = TosaPipelineBI[input_t](
+        LinspaceAdd(*init_data), input_data(), LinspaceAdd.aten_op, LinspaceAdd.exir_op
+    )
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.run()

backends/arm/test/ops/test_full.py
Lines changed: 0 additions & 6 deletions

@@ -141,10 +141,6 @@ def test_const_full_tosa_MI(self):
     def test_full_like_tosa_MI(self, test_tensor: Tuple):
         self._test_full_tosa_MI_pipeline(self.FullLike(), test_tensor)
 
-    def test_const_full_nhwc_tosa_BI(self):
-        _input = torch.rand((2, 2, 3, 3)) * 10
-        self._test_full_tosa_BI_pipeline(self.AddConstFull(), (_input,))
-
     @parameterized.expand(AddVariableFull.test_parameters)
     def test_full_tosa_MI(self, test_tensor: Tuple):
         self._test_full_tosa_MI_pipeline(
@@ -175,8 +171,6 @@ def test_full_u85_BI(self, test_tensor: Tuple):
             test_tensor,
         )
 
-    # This fails since full outputs int64 by default if 'fill_value' is integer, which our backend doesn't support.
-    @unittest.expectedFailure
     def test_integer_value(self):
         _input = torch.ones((2, 2))
         integer_fill_value = 1
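
The removed expectedFailure was tied to full's dtype inference, which the new AOT pass now absorbs; a quick illustration (plain PyTorch):

import torch

# An integer fill_value makes full default to int64, which the backend
# previously rejected; pre-computing the op sidesteps the limitation.
print(torch.full((2, 2), 1).dtype)    # torch.int64
print(torch.full((2, 2), 1.0).dtype)  # torch.float32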

backends/arm/test/passes/test_fuse_constant_ops_pass.py
Lines changed: 7 additions & 4 deletions

@@ -8,7 +8,10 @@
 from typing import Tuple
 
 import torch
-from executorch.backends.arm._passes.fuse_constant_ops_pass import FuseConstantOpsPass
+from executorch.backends.arm._passes.fuse_constant_ops_pass import (
+    ComputeConstantOpsAOT,
+    FuseConstantArgsPass,
+)
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
     PassPipeline,
@@ -95,22 +98,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 @common.parametrize("module", modules)
-def test_fuse_batchnorm_tosa_MI(module):
+def test_fuse_const_ops_tosa_MI(module):
     pipeline = PassPipeline[input_t](
         module=module,
         test_data=(torch.rand(1),),
         tosa_version="TOSA-0.80+MI",
         ops_before_pass=module.ops_before_pass,
         ops_after_pass=module.ops_after_pass,
         ops_not_after_pass=module.ops_not_after_pass,
-        passes_with_exported_program=[FuseConstantOpsPass],
+        passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass],
     )
     pipeline.run()
 
 
 @unittest.skip("Test failing on internal CI")
 @common.parametrize("module", modules)
-def test_fuse_batchnorm_tosa_BI(module):
+def test_fuse_const_ops_tosa_BI(module):
     pipeline = TosaPipelineBI[input_t](
         module, (torch.rand(10, 10),), [], [], use_to_edge_transform_and_lower=True
     )

backends/arm/tosa_partitioner.py
Lines changed: 2 additions & 0 deletions

@@ -177,6 +177,8 @@ def filter_fn(node: torch.fx.Node) -> bool:
     ops_to_not_decompose = [
         torch.ops.aten.linear.default,
         torch.ops.aten.upsample_nearest2d.vec,
+        torch.ops.aten.eye.default,
+        torch.ops.aten.linspace.default,
     ] + ops_to_not_decompose_if_quant_op
 
     return (ops_to_not_decompose, filter_fn)

backends/arm/tosa_quant_utils.py
Lines changed: 1 addition & 1 deletion

@@ -130,7 +130,7 @@ def quantize_value(self, x: torch.Tensor | float) -> Tensor:
         ).to(self.dtype)
 
     def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor:
-        return (qx - self.zp) * self.scale
+        return (qx.to(torch.int64) - self.zp) * self.scale
 
     @classmethod
     def from_operator(cls, op, args):
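
The cast to int64 guards against integer overflow during dequantization; a minimal demonstration with illustrative values (plain PyTorch, hypothetical zero point):

import torch

# int8 arithmetic wraps, so subtracting the zero point from an int8
# tensor can overflow before the result is scaled back to float.
qx = torch.tensor([-128], dtype=torch.int8)
zp = 100
print(qx - zp)                  # tensor([28], dtype=torch.int8): wrapped around
print(qx.to(torch.int64) - zp)  # tensor([-228]): correct intermediate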

backends/transforms/fuse_view_copy.py
Lines changed: 16 additions & 7 deletions

@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -11,13 +12,14 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 
 
-def merge_view_copy_chains(graph: torch.fx.Graph) -> torch.fx.Graph:
+def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]:
     """
     Find chains of view_copy nodes and merge them into one view_copy node.
     Only merges view_copy nodes that are not used by any other nodes.
     """
     ops = exir_ops.edge
     view_op = ops.aten.view_copy.default
+    modified = False
     for node in graph.nodes:
         if node.op == "call_function" and node.target == view_op:
             # find ending view_copy node in chain
@@ -35,29 +37,36 @@ def merge_view_copy_chains(graph: torch.fx.Graph) -> torch.fx.Graph:
             new_args = (node.args[0], end_node.args[1])
             node.args = new_args
             end_node.replace_all_uses_with(node)
+            modified = True
 
     graph.eliminate_dead_code()
-    return graph
+    return graph, modified
 
 
-def remove_noop_view_copy(graph: torch.fx.Graph) -> torch.fx.Graph:
+def remove_noop_view_copy(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]:
     """
     Remove view_copy nodes that are no-ops.
     """
     ops = exir_ops.edge
     view_op = ops.aten.view_copy.default
+    modified = False
     for node in graph.nodes:
         if node.op == "call_function" and node.target == view_op:
             input_shape = list(node.args[0].meta["val"].shape)
             target_shape = node.args[1]
             if input_shape == target_shape:
                 node.replace_all_uses_with(node.args[0])
+                modified = True
     graph.eliminate_dead_code()
-    return graph
+    return graph, modified
 
 
 class FuseViewCopyTransform(ExportPass):
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        graph_module.graph = merge_view_copy_chains(graph_module.graph)
-        graph_module.graph = remove_noop_view_copy(graph_module.graph)
-        return PassResult(graph_module, True)
+        graph_module.graph, merge_modified = merge_view_copy_chains(graph_module.graph)
+        graph_module.graph, noop_modified = remove_noop_view_copy(graph_module.graph)
+        modified = merge_modified or noop_modified
+        if modified:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, modified)
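
The merge is sound because chained reshapes compose, so only the final target shape matters; a quick check (plain PyTorch, hypothetical shapes):

import torch

x = torch.randn(2, 3, 4)
chained = x.reshape(6, 4).reshape(24)  # a chain of two view_copy-like ops
merged = x.reshape(24)                 # the single merged view
assert torch.equal(chained, merged)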
