Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 032ba6c

Browse files
authoredApr 24, 2025··
Arm backend: Update more node visitors to support TOSA 1.0 (#10425)
### Summary Updates more node visitors to support TOSA 1.0 specification. ### Test plan Tested through public and internal CI. Signed-off-by: Oscar Andersson <[email protected]>
1 parent d31ef13 commit 032ba6c

32 files changed

+1519
-236
lines changed
 

‎backends/arm/_passes/insert_table_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class TableOps:
4848
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
4949
exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
5050
exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
51+
exir_ops.edge.aten.cos.default: torch.cos,
52+
exir_ops.edge.aten.sin.default: torch.sin,
5153
exir_ops.edge.aten.tanh.default: torch.tanh,
5254
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
5355
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,

‎backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
pool_2d_support,
1313
reduce_sum_support,
1414
right_shift_support,
15+
sin_cos_support,
1516
slice_copy_support,
1617
to_copy_support,
1718
tosa_supported_operators,
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
9+
import torch.fx as fx
10+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
11+
register_tosa_support_check,
12+
SupportedTOSAOperatorCheck,
13+
)
14+
from executorch.backends.arm.tosa_specification import TosaSpecification
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
17+
18+
@register_tosa_support_check
19+
class SinCosSupported(SupportedTOSAOperatorCheck):
20+
targets = [
21+
exir_ops.edge.aten.cos.default,
22+
exir_ops.edge.aten.sin.default,
23+
]
24+
25+
tosa_specs = [
26+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
27+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
28+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
29+
]
30+
31+
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
32+
return True

‎backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@
2323
EthosU55NotSupported,
2424
EthosU55TransposeCheck,
2525
)
26-
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
26+
from executorch.backends.arm.tosa_specification import (
27+
Tosa_0_80,
28+
Tosa_1_00,
29+
TosaSpecification,
30+
)
2731
from executorch.exir import ExportedProgram
2832
from executorch.exir.backend.utils import WhyNoPartitionReporter
2933
from executorch.exir.dialects._ops import ops as exir_ops
@@ -124,7 +128,9 @@ def tosa_support_factory(
124128
if not tosa_spec.support_float():
125129
negative_checks.append(NeedsDecompositionCheck(reporter))
126130
negative_checks.append(CheckProperQuantization(reporter))
127-
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
131+
if (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset) or (
132+
isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions
133+
):
128134
negative_checks.append(EthosU55NotSupported(reporter))
129135
negative_checks.append(EthosU55DtypeSupport(reporter))
130136
negative_checks.append(EthosU55TransposeCheck(reporter))

‎backends/arm/operators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
op_clamp,
1919
op_constant_pad_nd,
2020
op_conv2d,
21+
op_cos,
2122
op_eq,
2223
op_erf,
2324
op_exp,
@@ -38,6 +39,7 @@
3839
op_rshift_tensor,
3940
op_rsqrt,
4041
op_sigmoid,
42+
op_sin,
4143
op_slice,
4244
op_sub,
4345
op_sum,

‎backends/arm/operators/op_any.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import cast, List
7+
from typing import Any, cast, List
88

9-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
109
from executorch.backends.arm.operators.node_visitor import ( # type: ignore
1110
NodeVisitor,
1211
register_node_visitor,
@@ -17,16 +16,19 @@
1716

1817

1918
@register_node_visitor
20-
class AnyVisitor(NodeVisitor):
19+
class AnyVisitor_0_80(NodeVisitor):
2120
target = "aten.any.dim"
2221

22+
tosa_specs = NodeVisitor.tosa_specs_0_80
23+
2324
def define_node(
2425
self,
2526
node: Node,
26-
tosa_graph: ts.TosaSerializer,
27+
tosa_graph: Any,
2728
inputs: List[TosaArg],
2829
output: TosaArg,
2930
) -> None:
31+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3032

3133
if not (inputs[0].dtype == output.dtype):
3234
raise ValueError(
@@ -50,3 +52,42 @@ def define_node(
5052
tosa_graph.addOperator(
5153
ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
5254
)
55+
56+
57+
@register_node_visitor
58+
class AnyVisitor(NodeVisitor):
59+
target = "aten.any.dim"
60+
61+
tosa_specs = NodeVisitor.tosa_specs_1_00
62+
63+
def define_node(
64+
self,
65+
node: Node,
66+
tosa_graph: Any,
67+
inputs: List[TosaArg],
68+
output: TosaArg,
69+
) -> None:
70+
import serializer.tosa_serializer as ts
71+
72+
if not (inputs[0].dtype == output.dtype):
73+
raise ValueError(
74+
"All inputs and outputs need same dtype."
75+
f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}."
76+
)
77+
if not (inputs[0].dtype == ts.DType.BOOL):
78+
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")
79+
80+
input_shape = list(inputs[0].shape)
81+
dim = cast(int, inputs[1].number) % len(
82+
input_shape
83+
) # process the negative index
84+
keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
85+
if not keep_dim:
86+
raise ValueError("This case should be handled by ConvertAnyDimDimsPass")
87+
88+
attr = ts.TosaSerializerAttribute()
89+
attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim))
90+
91+
tosa_graph.addOperator(
92+
ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
93+
)

‎backends/arm/operators/op_avg_pool2d.py

Lines changed: 134 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import torch
1010

11-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12-
1311
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1412
get_input_qparams,
1513
get_output_qparams,
@@ -36,14 +34,16 @@ def __init__(self, *args):
3634
def _build_generic_avgpool2d(
3735
self,
3836
node: torch.fx.Node,
39-
tosa_graph: ts.TosaSerializer,
37+
tosa_graph: Any,
4038
inputs: List[TosaArg],
4139
output: TosaArg,
4240
input_zp: int,
4341
output_zp: int,
44-
accumulator_type: ts.DType,
42+
accumulator_type: Any,
4543
) -> None:
4644

45+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
46+
4747
input_tensor = inputs[0]
4848
kernel_size_list = inputs[1].special
4949
stride_size_list = inputs[2].special
@@ -79,10 +79,12 @@ def _build_generic_avgpool2d(
7979
def define_node(
8080
self,
8181
node: torch.fx.Node,
82-
tosa_graph: ts.TosaSerializer,
82+
tosa_graph: Any,
8383
inputs: List[TosaArg],
8484
output: TosaArg,
8585
) -> None:
86+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
87+
8688
input_tensor = inputs[0]
8789
assert input_tensor.dtype == ts.DType.INT8
8890

@@ -110,10 +112,135 @@ class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
110112
def define_node(
111113
self,
112114
node: torch.fx.Node,
113-
tosa_graph: ts.TosaSerializer,
115+
tosa_graph: Any,
114116
inputs: List[TosaArg],
115117
output: TosaArg,
116118
) -> None:
119+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
120+
121+
assert (
122+
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
123+
), "Only FP32 and INT8 supported"
124+
125+
if inputs[0].dtype == ts.DType.INT8:
126+
super().define_node(node, tosa_graph, inputs, output)
127+
128+
if inputs[0].dtype == ts.DType.FP32:
129+
accumulator_type = ts.DType.FP32
130+
# Initilize zero point to zero.
131+
input_zp = 0
132+
output_zp = 0
133+
134+
self._build_generic_avgpool2d(
135+
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
136+
)
137+
138+
139+
@register_node_visitor
140+
class AvgPool2dVisitor(NodeVisitor):
141+
target = "aten.avg_pool2d.default"
142+
143+
tosa_specs = [
144+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
145+
]
146+
147+
def __init__(self, *args):
148+
super().__init__(*args)
149+
150+
def _build_generic_avgpool2d(
151+
self,
152+
node: torch.fx.Node,
153+
tosa_graph: Any,
154+
inputs: List[TosaArg],
155+
output: TosaArg,
156+
input_zp: int,
157+
output_zp: int,
158+
accumulator_type: Any,
159+
) -> None:
160+
161+
import serializer.tosa_serializer as ts # type: ignore
162+
163+
input_tensor = inputs[0]
164+
kernel_size_list = inputs[1].special
165+
stride_size_list = inputs[2].special
166+
167+
try:
168+
pad_size_list = inputs[3].special
169+
pad_size_list = [
170+
pad_size_list[0],
171+
pad_size_list[0],
172+
pad_size_list[1],
173+
pad_size_list[1],
174+
]
175+
except IndexError:
176+
pad_size_list = [0, 0, 0, 0]
177+
178+
attr = ts.TosaSerializerAttribute()
179+
attr.AvgPool2dAttribute(
180+
kernel=kernel_size_list,
181+
stride=stride_size_list,
182+
pad=pad_size_list,
183+
acc_type=accumulator_type,
184+
)
185+
input_zp_tensor = tosa_graph.addConst(
186+
shape=[1], dtype=output.dtype, vals=[input_zp]
187+
)
188+
output_zp_tensor = tosa_graph.addConst(
189+
shape=[1], dtype=output.dtype, vals=[output_zp]
190+
)
191+
192+
tosa_graph.addOperator(
193+
ts.TosaOp.Op().AVG_POOL2D,
194+
[input_tensor.name, input_zp_tensor.name, output_zp_tensor.name],
195+
[output.name],
196+
attr,
197+
)
198+
199+
def define_node(
200+
self,
201+
node: torch.fx.Node,
202+
tosa_graph: Any,
203+
inputs: List[TosaArg],
204+
output: TosaArg,
205+
) -> None:
206+
import serializer.tosa_serializer as ts # type: ignore
207+
208+
input_tensor = inputs[0]
209+
assert input_tensor.dtype == ts.DType.INT8
210+
211+
accumulator_type = ts.DType.INT32
212+
213+
input_qargs = get_input_qparams(node)
214+
input_zp = input_qargs[0].zp
215+
216+
output_qargs = get_output_qparams(node)
217+
output_zp = output_qargs[0].zp
218+
219+
self._build_generic_avgpool2d(
220+
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
221+
)
222+
223+
224+
@register_node_visitor
225+
class AvgPool2dVisitor_FP(AvgPool2dVisitor):
226+
target = "aten.avg_pool2d.default"
227+
228+
tosa_specs = [
229+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
230+
]
231+
232+
def __init__(self, *args):
233+
super().__init__(*args)
234+
235+
def define_node(
236+
self,
237+
node: torch.fx.Node,
238+
tosa_graph: Any,
239+
inputs: List[TosaArg],
240+
output: TosaArg,
241+
) -> None:
242+
import serializer.tosa_serializer as ts # type: ignore
243+
117244
assert (
118245
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
119246
), "Only FP32 and INT8 supported"

‎backends/arm/operators/op_cat.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import Any, List
99

10-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1110
from executorch.backends.arm.operators.node_visitor import (
1211
NodeVisitor,
1312
register_node_visitor,
@@ -17,19 +16,22 @@
1716

1817

1918
@register_node_visitor
20-
class CatVisitor(NodeVisitor):
19+
class CatVisitor_0_80(NodeVisitor):
2120
target = "aten.cat.default"
2221

22+
tosa_specs = NodeVisitor.tosa_specs_0_80
23+
2324
def __init__(self, *args):
2425
super().__init__(*args)
2526

2627
def define_node(
2728
self,
2829
node: Node,
29-
tosa_graph: ts.TosaSerializer,
30+
tosa_graph: Any,
3031
inputs: List[TosaArg],
3132
output: TosaArg,
3233
) -> None:
34+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3335

3436
tensors = inputs[0].special
3537
dim = 0 if len(inputs) < 2 else inputs[1].number
@@ -46,3 +48,38 @@ def define_node(
4648
[output.name],
4749
attr,
4850
)
51+
52+
53+
@register_node_visitor
54+
class CatVisitor(NodeVisitor):
55+
target = "aten.cat.default"
56+
57+
tosa_specs = NodeVisitor.tosa_specs_1_00
58+
59+
def __init__(self, *args):
60+
super().__init__(*args)
61+
62+
def define_node(
63+
self,
64+
node: Node,
65+
tosa_graph: Any,
66+
inputs: List[TosaArg],
67+
output: TosaArg,
68+
) -> None:
69+
import serializer.tosa_serializer as ts
70+
71+
tensors = inputs[0].special
72+
dim = 0 if len(inputs) < 2 else inputs[1].number
73+
rank = len(output.shape)
74+
dim = (dim + rank) % rank
75+
dim = output.dim_order.index(dim)
76+
77+
attr = ts.TosaSerializerAttribute()
78+
attr.ConcatAttribute(dim)
79+
80+
tosa_graph.addOperator(
81+
ts.TosaOp.Op().CONCAT,
82+
[tensor.name for tensor in tensors],
83+
[output.name],
84+
attr,
85+
)

‎backends/arm/operators/op_constant_pad_nd.py

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import Any, List
99

1010
import torch
1111

12-
import tosa_tools.v0_80.serializer.tosa_serializer as ts
13-
1412
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1513
get_input_qparams,
1614
)
@@ -19,20 +17,27 @@
1917
register_node_visitor,
2018
)
2119
from executorch.backends.arm.tosa_mapping import TosaArg
20+
from executorch.backends.arm.tosa_specification import TosaSpecification
2221

2322

2423
@register_node_visitor
25-
class ConstantPadNDVisitor(NodeVisitor):
24+
class ConstantPadNDVisitor_0_80(NodeVisitor):
2625

2726
target = "aten.constant_pad_nd.default"
2827

28+
tosa_specs = [
29+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
30+
TosaSpecification.create_from_string("TOSA-0.80+MI"),
31+
]
32+
2933
def define_node(
3034
self,
3135
node: torch.fx.Node,
32-
tosa_graph: ts.TosaSerializer,
36+
tosa_graph: Any,
3337
inputs: List[TosaArg],
3438
output: TosaArg,
3539
) -> None:
40+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3641

3742
if inputs[0].dtype == ts.DType.INT8:
3843
input_qparams = get_input_qparams(node)
@@ -74,3 +79,72 @@ def define_node(
7479
tosa_graph.addOperator(
7580
ts.TosaOp.Op().PAD, [inputs[0].name], [output.name], attr
7681
)
82+
83+
84+
@register_node_visitor
85+
class ConstantPadNDVisitor(NodeVisitor):
86+
87+
target = "aten.constant_pad_nd.default"
88+
89+
tosa_specs = [
90+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
91+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
92+
]
93+
94+
def define_node(
95+
self,
96+
node: torch.fx.Node,
97+
tosa_graph: Any,
98+
inputs: List[TosaArg],
99+
output: TosaArg,
100+
) -> None:
101+
102+
import serializer.tosa_serializer as ts # type: ignore
103+
104+
if inputs[0].dtype == ts.DType.INT8:
105+
input_qparams = get_input_qparams(node)
106+
qargs = input_qparams[0]
107+
pad_const_val = qargs.quantize_value(inputs[2].number).item()
108+
pad_const_dtype = ts.DType.INT8
109+
else:
110+
pad_const_val = inputs[2].number
111+
pad_const_dtype = inputs[0].dtype
112+
113+
rank = len(output.shape)
114+
# Each dim needs 2 padding values. For example, to pad the last dimension, the pad has the form
115+
# (padding_left, padding_right); to pad the last two dimensions, the pad has the form
116+
# (padding_left, padding_right, padding_top, padding_bottom), and so on. For PyTorch NCHW format, the padding
117+
# values are in the reverse order. So, firstly we need to reverse the input padding parameters.
118+
input_pad = sum(
119+
[
120+
[inputs[1].special[i], inputs[1].special[i + 1]]
121+
for i in range(0, len(inputs[1].special), 2)
122+
][::-1],
123+
[],
124+
)
125+
# Then, add dummy zeros to make sure that both input_pad and output_pad has the same size.
126+
input_pad = [0] * (rank * 2 - len(inputs[1].special)) + input_pad
127+
# For PyTorch NCHW format, dim order is [0,...,rank-1]
128+
input_dim_order = list(range(rank))
129+
output_pad = [0] * rank * 2
130+
131+
# Map input padding parameters into output padding parameters. TOSA is NHWC format.
132+
for input_dim_idx, input_dim in enumerate(input_dim_order):
133+
output_dim_idx = output.dim_order.index(input_dim)
134+
output_pad[output_dim_idx * 2 : (output_dim_idx + 1) * 2] = input_pad[
135+
input_dim_idx * 2 : (input_dim_idx + 1) * 2
136+
]
137+
138+
padding = tosa_graph.addConst(
139+
shape=[len(output_pad)], dtype=ts.DType.SHAPE, vals=output_pad
140+
)
141+
142+
pad_const = tosa_graph.addConst(
143+
shape=[1], dtype=pad_const_dtype, vals=[pad_const_val]
144+
)
145+
146+
tosa_graph.addOperator(
147+
ts.TosaOp.Op().PAD,
148+
[inputs[0].name, padding.name, pad_const.name],
149+
[output.name],
150+
)

‎backends/arm/operators/op_cos.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import serializer.tosa_serializer as ts # type: ignore
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from executorch.backends.arm.tosa_specification import TosaSpecification
16+
from torch.fx import Node
17+
18+
19+
@register_node_visitor
20+
class CosVisitor(NodeVisitor):
21+
target = "aten.cos.default"
22+
23+
# INT case should be handled by op_table
24+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
25+
26+
def __init__(self, *args):
27+
super().__init__(*args)
28+
29+
def define_node(
30+
self,
31+
node: Node,
32+
tosa_graph: ts.TosaSerializer,
33+
inputs: List[TosaArg],
34+
output: TosaArg,
35+
) -> None:
36+
if len(node.all_input_nodes) != 1:
37+
raise ValueError(
38+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
39+
)
40+
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
41+
raise ValueError(
42+
f"Input and output for {self.target} need to be FP32, got input_dtype: "
43+
f"{inputs[0].dtype} and output_dtype: {output.dtype}"
44+
)
45+
46+
tosa_graph.addOperator(ts.TosaOp.Op().COS, [inputs[0].name], [output.name])

‎backends/arm/operators/op_erf.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55
# pyre-unsafe
6-
from typing import List
6+
from typing import Any, List
77

88
import torch.fx
9-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
10-
import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore
119
from executorch.backends.arm.operators.node_visitor import (
1210
NodeVisitor,
1311
register_node_visitor,
@@ -29,16 +27,49 @@ def __init__(self, *args):
2927
def define_node(
3028
self,
3129
node: torch.fx.Node,
32-
tosa_graph: ts.TosaSerializer,
30+
tosa_graph: Any,
3331
inputs: List[TosaArg],
3432
output: TosaArg,
3533
) -> None:
34+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
35+
36+
if not (inputs[0].dtype == output.dtype):
37+
raise ValueError(
38+
"All inputs and output need same dtype."
39+
f"Got {inputs[0].dtype=}, {output.dtype=}"
40+
)
41+
if not (inputs[0].dtype == ts.DType.FP32):
42+
raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}")
43+
# MI lowering
44+
tosa_graph.addOperator(ts.TosaOp.Op().ERF, [inputs[0].name], [output.name])
45+
46+
47+
@register_node_visitor
48+
class ERFVisitor(NodeVisitor):
49+
target = "aten.erf.default"
50+
51+
# INT case handled by op_table
52+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
53+
54+
def __init__(self, *args):
55+
super().__init__(*args)
56+
57+
def define_node(
58+
self,
59+
node: torch.fx.Node,
60+
tosa_graph: Any,
61+
inputs: List[TosaArg],
62+
output: TosaArg,
63+
) -> None:
64+
import serializer.tosa_serializer as ts
65+
3666
if not (inputs[0].dtype == output.dtype):
3767
raise ValueError(
3868
"All inputs and output need same dtype."
3969
f"Got {inputs[0].dtype=}, {output.dtype=}"
4070
)
4171
if not (inputs[0].dtype == ts.DType.FP32):
4272
raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}")
73+
4374
# MI lowering
44-
tosa_graph.addOperator(TosaOp.Op().ERF, [inputs[0].name], [output.name])
75+
tosa_graph.addOperator(ts.TosaOp.Op().ERF, [inputs[0].name], [output.name])

‎backends/arm/operators/op_exp.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

9-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
109
from executorch.backends.arm.operators.node_visitor import (
1110
NodeVisitor,
1211
register_node_visitor,
@@ -29,10 +28,43 @@ def __init__(self, *args):
2928
def define_node(
3029
self,
3130
node: Node,
32-
tosa_graph: ts.TosaSerializer,
31+
tosa_graph: Any,
3332
inputs: List[TosaArg],
3433
output: TosaArg,
3534
) -> None:
35+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
36+
37+
if len(node.all_input_nodes) != 1:
38+
raise ValueError(
39+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
40+
)
41+
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
42+
raise ValueError(
43+
f"Input and output for {self.target} need to be FP32, got input dtype: "
44+
f"{inputs[0].dtype} and output dtype: {output.dtype}"
45+
)
46+
47+
tosa_graph.addOperator(ts.TosaOp.Op().EXP, [inputs[0].name], [output.name])
48+
49+
50+
@register_node_visitor
51+
class ExpVisitor(NodeVisitor):
52+
target = "aten.exp.default"
53+
54+
# BI case should be handled by op_table
55+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
56+
57+
def __init__(self, *args):
58+
super().__init__(*args)
59+
60+
def define_node(
61+
self,
62+
node: Node,
63+
tosa_graph: Any,
64+
inputs: List[TosaArg],
65+
output: TosaArg,
66+
) -> None:
67+
import serializer.tosa_serializer as ts
3668

3769
if len(node.all_input_nodes) != 1:
3870
raise ValueError(

‎backends/arm/operators/op_log.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

9-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
109
from executorch.backends.arm.operators.node_visitor import (
1110
NodeVisitor,
1211
register_node_visitor,
@@ -17,7 +16,7 @@
1716

1817

1918
@register_node_visitor
20-
class LogVisitor(NodeVisitor):
19+
class LogVisitor_0_80_MI(NodeVisitor):
2120
target = "aten.log.default"
2221

2322
# BI case should be handled by op_table
@@ -29,10 +28,44 @@ def __init__(self, *args):
2928
def define_node(
3029
self,
3130
node: Node,
32-
tosa_graph: ts.TosaSerializer,
31+
tosa_graph: Any,
32+
inputs: List[TosaArg],
33+
output: TosaArg,
34+
) -> None:
35+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
36+
37+
if len(node.all_input_nodes) != 1:
38+
raise ValueError(
39+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
40+
)
41+
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
42+
raise ValueError(
43+
f"Input and output for {self.target} need to be FP32, got input_dtype: "
44+
f"{inputs[0].dtype} and output_dtype: {output.dtype}"
45+
)
46+
47+
tosa_graph.addOperator(ts.TosaOp.Op().LOG, [inputs[0].name], [output.name])
48+
49+
50+
@register_node_visitor
51+
class LogVisitor(NodeVisitor):
52+
target = "aten.log.default"
53+
54+
# INT case should be handled by op_table
55+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
56+
57+
def __init__(self, *args):
58+
super().__init__(*args)
59+
60+
def define_node(
61+
self,
62+
node: Node,
63+
tosa_graph: Any,
3364
inputs: List[TosaArg],
3465
output: TosaArg,
3566
) -> None:
67+
import serializer.tosa_serializer as ts
68+
3669
if len(node.all_input_nodes) != 1:
3770
raise ValueError(
3871
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"

‎backends/arm/operators/op_max_pool2d.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import torch
1010

11-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12-
1311
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1412
get_input_qparams,
1513
get_output_qparams,
@@ -19,22 +17,29 @@
1917
register_node_visitor,
2018
)
2119
from executorch.backends.arm.tosa_mapping import TosaArg
20+
from executorch.backends.arm.tosa_specification import TosaSpecification
2221

2322

2423
@register_node_visitor
25-
class MaxPool2dVisitor(NodeVisitor):
24+
class MaxPool2dVisitor_0_80(NodeVisitor):
2625
target = "aten.max_pool2d.default"
2726

27+
tosa_specs = [
28+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
29+
TosaSpecification.create_from_string("TOSA-0.80+MI"),
30+
]
31+
2832
def __init__(self, *args):
2933
super().__init__(*args)
3034

3135
def define_node(
3236
self,
3337
node: torch.fx.Node,
34-
tosa_graph: ts.TosaSerializer,
38+
tosa_graph: Any,
3539
inputs: List[TosaArg],
3640
output: TosaArg,
3741
) -> None:
42+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3843

3944
input_tensor = inputs[0]
4045
kernel_size = inputs[1].special
@@ -80,3 +85,53 @@ def define_node(
8085
[output.name],
8186
attr,
8287
)
88+
89+
90+
@register_node_visitor
91+
class MaxPool2dVisitor(NodeVisitor):
92+
target = "aten.max_pool2d.default"
93+
94+
tosa_specs = [
95+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
96+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
97+
]
98+
99+
def __init__(self, *args):
100+
super().__init__(*args)
101+
102+
def define_node(
103+
self,
104+
node: torch.fx.Node,
105+
tosa_graph: Any,
106+
inputs: List[TosaArg],
107+
output: TosaArg,
108+
) -> None:
109+
110+
import serializer.tosa_serializer as ts # type: ignore
111+
112+
input_tensor = inputs[0]
113+
kernel_size = inputs[1].special
114+
stride = inputs[2].special
115+
116+
try:
117+
pad_size_list = inputs[3].special
118+
pad_size_list = [
119+
pad_size_list[0],
120+
pad_size_list[0],
121+
pad_size_list[1],
122+
pad_size_list[1],
123+
]
124+
except IndexError:
125+
pad_size_list = [0, 0, 0, 0]
126+
127+
attr = ts.TosaSerializerAttribute()
128+
attr.MaxPool2dAttribute(
129+
kernel=kernel_size, stride=stride, pad=pad_size_list, nan_mode=1
130+
)
131+
132+
tosa_graph.addOperator(
133+
ts.TosaOp.Op().MAX_POOL2D,
134+
[input_tensor.name],
135+
[output.name],
136+
attr,
137+
)

‎backends/arm/operators/op_permute.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import Any, List
99

1010
import torch
1111

12-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1312
from executorch.backends.arm.operators.node_visitor import (
1413
NodeVisitor,
1514
register_node_visitor,
@@ -88,20 +87,61 @@ def transform_permutation_vector(permutation_vector: list[int], dim_order: list[
8887
return permutation_vector
8988

9089

90+
@register_node_visitor
91+
class PermuteVisitor_0_80(NodeVisitor):
92+
target = "aten.permute_copy.default"
93+
94+
tosa_specs = NodeVisitor.tosa_specs_0_80
95+
96+
def __init__(self, *args):
97+
super().__init__(*args)
98+
99+
def define_node(
100+
self,
101+
node: torch.fx.Node,
102+
tosa_graph: Any,
103+
inputs: List[TosaArg],
104+
output: TosaArg,
105+
) -> None:
106+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
107+
108+
# The permutation vector describes a permutation P in default Pytorch dim_order.
109+
# For rank 4, the default dim_order NCHW.
110+
# E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h)
111+
permutation_vector = inputs[1].special
112+
113+
if output.dim_order != tuple(range(len(output.dim_order))):
114+
# the permutation vector can't be used directly if we are not in NCHW dim_order.
115+
# Transform to dim_order.
116+
permutation_vector = transform_permutation_vector(
117+
permutation_vector, output.dim_order
118+
)
119+
120+
attr = ts.TosaSerializerAttribute()
121+
attr.TransposeAttribute(permutation_vector)
122+
tosa_graph.addOperator(
123+
ts.TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr
124+
)
125+
126+
91127
@register_node_visitor
92128
class PermuteVisitor(NodeVisitor):
93129
target = "aten.permute_copy.default"
94130

131+
tosa_specs = NodeVisitor.tosa_specs_1_00
132+
95133
def __init__(self, *args):
96134
super().__init__(*args)
97135

98136
def define_node(
99137
self,
100138
node: torch.fx.Node,
101-
tosa_graph: ts.TosaSerializer,
139+
tosa_graph: Any,
102140
inputs: List[TosaArg],
103141
output: TosaArg,
104142
) -> None:
143+
import serializer.tosa_serializer as ts
144+
105145
# The permutation vector describes a permutation P in default Pytorch dim_order.
106146
# For rank 4, the default dim_order NCHW.
107147
# E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h)

‎backends/arm/operators/op_pow.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import Any, List
99

10-
import tosa_tools.v0_80.serializer.tosa_serializer as ts
1110
from executorch.backends.arm.operators.node_visitor import (
1211
NodeVisitor,
1312
register_node_visitor,
@@ -31,10 +30,53 @@ def __init__(self, *args):
3130
def define_node(
3231
self,
3332
node: Node,
34-
tosa_graph: ts.TosaSerializer,
33+
tosa_graph: Any,
3534
inputs: List[TosaArg],
3635
output: TosaArg,
3736
) -> None:
37+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
38+
39+
if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
40+
raise ValueError(
41+
"All inputs and outputs need same dtype."
42+
f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}"
43+
)
44+
if inputs[0].dtype not in [ts.DType.FP32, ts.DType.FP16]:
45+
raise ValueError(
46+
f"All inputs need to be FP32 or FP16. Got {inputs[0].dtype}"
47+
)
48+
49+
tosa_graph.addOperator(
50+
ts.TosaOp.Op().POW,
51+
[
52+
inputs[0].name,
53+
inputs[1].name,
54+
],
55+
[output.name],
56+
None,
57+
)
58+
59+
60+
@register_node_visitor
61+
class PowVisitor(NodeVisitor):
62+
target = "aten.pow.Tensor_Tensor"
63+
64+
tosa_specs = [
65+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
66+
]
67+
68+
def __init__(self, *args):
69+
super().__init__(*args)
70+
71+
def define_node(
72+
self,
73+
node: Node,
74+
tosa_graph: Any,
75+
inputs: List[TosaArg],
76+
output: TosaArg,
77+
) -> None:
78+
import serializer.tosa_serializer as ts
79+
3880
if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
3981
raise ValueError(
4082
"All inputs and outputs need same dtype."

‎backends/arm/operators/op_reciprocal.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import torch
1010

11-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1211
from executorch.backends.arm.operators.node_visitor import (
1312
NodeVisitor,
1413
register_node_visitor,
@@ -30,10 +29,46 @@ def __init__(self, *args):
3029
def define_node(
3130
self,
3231
node: torch.fx.Node,
33-
tosa_graph: ts.TosaSerializer,
32+
tosa_graph: Any,
3433
inputs: List[TosaArg],
3534
output: TosaArg,
3635
) -> None:
36+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
37+
38+
if len(node.all_input_nodes) != 1:
39+
raise ValueError(
40+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
41+
)
42+
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
43+
raise ValueError(
44+
f"Input and output for {self.target} need to be FP32, got "
45+
f"{inputs[0].dtype=} and {output.dtype=}"
46+
)
47+
48+
tosa_graph.addOperator(
49+
ts.TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name]
50+
)
51+
52+
53+
@register_node_visitor
54+
class ReciprocalVisitor(NodeVisitor):
55+
target = "aten.reciprocal.default"
56+
57+
# INT case should be handled by op_table
58+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
59+
60+
def __init__(self, *args):
61+
super().__init__(*args)
62+
63+
def define_node(
64+
self,
65+
node: torch.fx.Node,
66+
tosa_graph: Any,
67+
inputs: List[TosaArg],
68+
output: TosaArg,
69+
) -> None:
70+
import serializer.tosa_serializer as ts
71+
3772
if len(node.all_input_nodes) != 1:
3873
raise ValueError(
3974
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"

‎backends/arm/operators/op_rshift_tensor.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,32 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import Any, List
99

1010
import torch
1111

12-
import tosa_tools.v0_80.serializer.tosa_serializer as ts
1312
from executorch.backends.arm.operators.node_visitor import (
1413
NodeVisitor,
1514
register_node_visitor,
1615
)
1716
from executorch.backends.arm.tosa_mapping import TosaArg
18-
from executorch.backends.arm.tosa_specification import Tosa_0_80
17+
from executorch.backends.arm.tosa_specification import Tosa_0_80, Tosa_1_00
1918

2019

2120
@register_node_visitor
22-
class RshiftVisitor(NodeVisitor):
21+
class RshiftVisitor_0_80(NodeVisitor):
2322
target = "aten.bitwise_right_shift.Tensor"
2423

24+
tosa_specs = NodeVisitor.tosa_specs_0_80
25+
2526
def define_node(
2627
self,
2728
node: torch.fx.Node,
28-
tosa_graph: ts.TosaSerializer,
29+
tosa_graph: Any,
2930
inputs: List[TosaArg],
3031
output: TosaArg,
3132
) -> None:
33+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3234

3335
attr = ts.TosaSerializerAttribute()
3436
round = False
@@ -44,3 +46,34 @@ def define_node(
4446
[output.name],
4547
attr,
4648
)
49+
50+
51+
@register_node_visitor
52+
class RshiftVisitor(NodeVisitor):
53+
target = "aten.bitwise_right_shift.Tensor"
54+
55+
tosa_specs = NodeVisitor.tosa_specs_1_00
56+
57+
def define_node(
58+
self,
59+
node: torch.fx.Node,
60+
tosa_graph: Any,
61+
inputs: List[TosaArg],
62+
output: TosaArg,
63+
) -> None:
64+
import serializer.tosa_serializer as ts
65+
66+
attr = ts.TosaSerializerAttribute()
67+
round = False
68+
if isinstance(self.tosa_spec, Tosa_1_00) and "u55" in self.tosa_spec.extensions:
69+
# U55 only supports INT32 and round == True
70+
# TODO MLETORCH-525 Emulate round == False with different decomposition
71+
round = True
72+
attr.ArithmeticRightShiftAttribute(round=round)
73+
74+
tosa_graph.addOperator(
75+
ts.TosaOp.Op().ARITHMETIC_RIGHT_SHIFT,
76+
[inputs[0].name, inputs[1].name],
77+
[output.name],
78+
attr,
79+
)

‎backends/arm/operators/op_rsqrt.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import torch
1010

11-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1211
from executorch.backends.arm.operators.node_visitor import (
1312
NodeVisitor,
1413
register_node_visitor,
@@ -30,10 +29,44 @@ def __init__(self, *args):
3029
def define_node(
3130
self,
3231
node: torch.fx.Node,
33-
tosa_graph: ts.TosaSerializer,
32+
tosa_graph: Any,
3433
inputs: List[TosaArg],
3534
output: TosaArg,
3635
) -> None:
36+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
37+
38+
if len(node.all_input_nodes) != 1:
39+
raise ValueError(
40+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
41+
)
42+
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
43+
raise ValueError(
44+
f"Input and output for {self.target} need to be FP32, got "
45+
f"{inputs[0].dtype=} and {output.dtype=}"
46+
)
47+
48+
tosa_graph.addOperator(ts.TosaOp.Op().RSQRT, [inputs[0].name], [output.name])
49+
50+
51+
@register_node_visitor
52+
class RsqrtVisitor(NodeVisitor):
53+
target = "aten.rsqrt.default"
54+
55+
# INT case should be handled by op_table
56+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
57+
58+
def __init__(self, *args):
59+
super().__init__(*args)
60+
61+
def define_node(
62+
self,
63+
node: torch.fx.Node,
64+
tosa_graph: Any,
65+
inputs: List[TosaArg],
66+
output: TosaArg,
67+
) -> None:
68+
import serializer.tosa_serializer as ts
69+
3770
if len(node.all_input_nodes) != 1:
3871
raise ValueError(
3972
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"

‎backends/arm/operators/op_sigmoid.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

9-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
109
from executorch.backends.arm.operators.node_visitor import (
1110
NodeVisitor,
1211
register_node_visitor,
@@ -29,10 +28,43 @@ def __init__(self, *args):
2928
def define_node(
3029
self,
3130
node: Node,
32-
tosa_graph: ts.TosaSerializer,
31+
tosa_graph: Any,
3332
inputs: List[TosaArg],
3433
output: TosaArg,
3534
) -> None:
35+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
36+
37+
if len(node.all_input_nodes) != 1:
38+
raise ValueError(
39+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
40+
)
41+
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
42+
raise ValueError(
43+
f"Input and output for {self.target} need to be FP32, got input_dtype: "
44+
f"{inputs[0].dtype} and output_dtype: {output.dtype}"
45+
)
46+
47+
tosa_graph.addOperator(ts.TosaOp.Op().SIGMOID, [inputs[0].name], [output.name])
48+
49+
50+
@register_node_visitor
51+
class SigmoidVisitor(NodeVisitor):
52+
target = "aten.sigmoid.default"
53+
54+
# INT case should be handled by op_table
55+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
56+
57+
def __init__(self, *args):
58+
super().__init__(*args)
59+
60+
def define_node(
61+
self,
62+
node: Node,
63+
tosa_graph: Any,
64+
inputs: List[TosaArg],
65+
output: TosaArg,
66+
) -> None:
67+
import serializer.tosa_serializer as ts
3668

3769
if len(node.all_input_nodes) != 1:
3870
raise ValueError(

‎backends/arm/operators/op_sin.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import serializer.tosa_serializer as ts # type: ignore
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from executorch.backends.arm.tosa_specification import TosaSpecification
16+
from torch.fx import Node
17+
18+
19+
@register_node_visitor
20+
class SinVisitor(NodeVisitor):
21+
target = "aten.sin.default"
22+
23+
# INT case should be handled by op_table
24+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
25+
26+
def __init__(self, *args):
27+
super().__init__(*args)
28+
29+
def define_node(
30+
self,
31+
node: Node,
32+
tosa_graph: ts.TosaSerializer,
33+
inputs: List[TosaArg],
34+
output: TosaArg,
35+
) -> None:
36+
if len(node.all_input_nodes) != 1:
37+
raise ValueError(
38+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
39+
)
40+
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
41+
raise ValueError(
42+
f"Input and output for {self.target} need to be FP32, got input_dtype: "
43+
f"{inputs[0].dtype} and output_dtype: {output.dtype}"
44+
)
45+
46+
tosa_graph.addOperator(ts.TosaOp.Op().SIN, [inputs[0].name], [output.name])

‎backends/arm/operators/op_tanh.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

9-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
109
from executorch.backends.arm.operators.node_visitor import (
1110
NodeVisitor,
1211
register_node_visitor,
@@ -17,7 +16,7 @@
1716

1817

1918
@register_node_visitor
20-
class TanhVisitor_080_MI(NodeVisitor):
19+
class TanhVisitor_0_80_MI(NodeVisitor):
2120
target = "aten.tanh.default"
2221

2322
# BI case should be handled by op_table
@@ -29,10 +28,44 @@ def __init__(self, *args):
2928
def define_node(
3029
self,
3130
node: Node,
32-
tosa_graph: ts.TosaSerializer,
31+
tosa_graph: Any,
3332
inputs: List[TosaArg],
3433
output: TosaArg,
3534
) -> None:
35+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
36+
37+
if len(node.all_input_nodes) != 1:
38+
raise ValueError(
39+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
40+
)
41+
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
42+
raise ValueError(
43+
f"Input and output for {self.target} need to be FP32, got input_dtype: "
44+
f"{inputs[0].dtype} and output_dtype: {output.dtype}"
45+
)
46+
47+
tosa_graph.addOperator(ts.TosaOp.Op().TANH, [inputs[0].name], [output.name])
48+
49+
50+
@register_node_visitor
51+
class TanhVisitor(NodeVisitor):
52+
target = "aten.tanh.default"
53+
54+
# INT case should be handled by op_table
55+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
56+
57+
def __init__(self, *args):
58+
super().__init__(*args)
59+
60+
def define_node(
61+
self,
62+
node: Node,
63+
tosa_graph: Any,
64+
inputs: List[TosaArg],
65+
output: TosaArg,
66+
) -> None:
67+
import serializer.tosa_serializer as ts
68+
3669
if len(node.all_input_nodes) != 1:
3770
raise ValueError(
3871
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"

‎backends/arm/operators/op_upsample_nearest2d.py

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import torch
1010

11-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1211
from executorch.backends.arm.operators.node_visitor import (
1312
NodeVisitor,
1413
register_node_visitor,
@@ -20,19 +19,23 @@
2019

2120

2221
@register_node_visitor
23-
class UpsampleNearest2dVisitor(NodeVisitor):
22+
class UpsampleNearest2dVisitor_0_80(NodeVisitor):
2423
target = "aten.upsample_nearest2d.vec"
2524

25+
tosa_specs = NodeVisitor.tosa_specs_0_80
26+
2627
def __init__(self, *args):
2728
super().__init__(*args)
2829

2930
def define_node(
3031
self,
3132
node: torch.fx.Node,
32-
tosa_graph: ts.TosaSerializer,
33+
tosa_graph: Any,
3334
inputs: List[TosaArg],
3435
output: TosaArg,
3536
) -> None:
37+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
38+
3639
assert (
3740
inputs[0].shape is not None and output.shape is not None
3841
), "Only static shapes are supported"
@@ -67,3 +70,74 @@ def in_int16_range(x):
6770
tosa_graph.addOperator(
6871
ts.TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr
6972
)
73+
74+
75+
@register_node_visitor
76+
class UpsampleNearest2dVisitor(NodeVisitor):
77+
target = "aten.upsample_nearest2d.vec"
78+
79+
tosa_specs = NodeVisitor.tosa_specs_1_00
80+
81+
def __init__(self, *args):
82+
super().__init__(*args)
83+
84+
def define_node(
85+
self,
86+
node: torch.fx.Node,
87+
tosa_graph: Any,
88+
inputs: List[TosaArg],
89+
output: TosaArg,
90+
) -> None:
91+
import serializer.tosa_serializer as ts
92+
93+
assert (
94+
inputs[0].shape is not None and output.shape is not None
95+
), "Only static shapes are supported"
96+
97+
# tosa_shape output is NHWC, take HW
98+
input_size_yx = torch.tensor(
99+
tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3]
100+
)
101+
# Ignore scale and size parameters, directly use the output size as
102+
# we only support static shapes currently
103+
output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3])
104+
105+
scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
106+
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True
107+
)
108+
109+
def in_int16_range(x):
110+
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)
111+
112+
assert in_int16_range(scale_n_yx)
113+
assert in_int16_range(scale_d_yx)
114+
assert in_int16_range(border_yx)
115+
116+
scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]]
117+
scales_tensor = tosa_graph.addConst(
118+
[len(scales)], ts.DType.SHAPE, scales, node.name + "_scales"
119+
)
120+
offset = offset_yx.tolist()
121+
offset_tensor = tosa_graph.addConst(
122+
[len(offset)], ts.DType.SHAPE, offset, node.name + "_offset"
123+
)
124+
border = border_yx.tolist()
125+
border_tensor = tosa_graph.addConst(
126+
[len(border)], ts.DType.SHAPE, border, node.name + "_border"
127+
)
128+
attr = ts.TosaSerializerAttribute()
129+
attr.ResizeAttribute(
130+
mode=ResizeMode.NEAREST,
131+
)
132+
133+
tosa_graph.addOperator(
134+
ts.TosaOp.Op().RESIZE,
135+
[
136+
inputs[0].name,
137+
scales_tensor.name,
138+
offset_tensor.name,
139+
border_tensor.name,
140+
],
141+
[output.name],
142+
attr,
143+
)

‎backends/arm/operators/op_where.py

Lines changed: 130 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from typing import List, Sequence
7-
8-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
9-
import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore
6+
from typing import Any, List, Sequence
107

118
from executorch.backends.arm.operators.node_visitor import (
129
NodeVisitor,
@@ -17,69 +14,163 @@
1714
from torch.fx import Node
1815

1916

20-
def _add_node_to_tosa_graph(
21-
tosa_graph: ts.TosaSerializer,
22-
inputs: List[TosaArg],
23-
output: TosaArg,
24-
supported_dtypes: Sequence,
25-
) -> None:
26-
if len(inputs) != 3:
27-
raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}")
28-
29-
if inputs[0].dtype is not ts.DType.BOOL:
30-
raise ValueError("Input 0 needs to have dtype BOOL")
31-
if inputs[1].dtype != inputs[2].dtype:
32-
raise ValueError(
33-
"Non-condition tensors must have same data type, got "
34-
f"{inputs[1].dtype} and {inputs[2].dtype}"
35-
)
36-
for input_ in inputs[1:]:
37-
if input_.dtype not in supported_dtypes:
17+
@register_node_visitor
18+
class WhereVisitor_0_80_BI(NodeVisitor):
19+
target = "aten.where.self"
20+
21+
tosa_specs = [
22+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
23+
]
24+
25+
def __init__(self, *args):
26+
super().__init__(*args)
27+
28+
def _add_node_to_tosa_graph(
29+
self,
30+
tosa_graph: Any,
31+
inputs: List[TosaArg],
32+
output: TosaArg,
33+
supported_dtypes: Sequence,
34+
) -> None:
35+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
36+
37+
if len(inputs) != 3:
38+
raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}")
39+
40+
if inputs[0].dtype is not ts.DType.BOOL:
41+
raise ValueError("Input 0 needs to have dtype BOOL")
42+
if inputs[1].dtype != inputs[2].dtype:
3843
raise ValueError(
39-
f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}"
44+
"Non-condition tensors must have same data type, got "
45+
f"{inputs[1].dtype} and {inputs[2].dtype}"
4046
)
47+
for input_ in inputs[1:]:
48+
if input_.dtype not in supported_dtypes:
49+
raise ValueError(
50+
f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}"
51+
)
52+
53+
tosa_graph.addOperator(
54+
ts.TosaOp.Op().SELECT,
55+
[inputs[0].name, inputs[1].name, inputs[2].name],
56+
[output.name],
57+
None,
58+
)
4159

42-
tosa_graph.addOperator(
43-
TosaOp.Op().SELECT,
44-
[inputs[0].name, inputs[1].name, inputs[2].name],
45-
[output.name],
46-
None,
47-
)
60+
def define_node(
61+
self,
62+
node: Node,
63+
tosa_graph: Any,
64+
inputs: List[TosaArg],
65+
output: TosaArg,
66+
) -> None:
67+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
68+
69+
bi_supported_dtypes = [
70+
ts.DType.INT8,
71+
ts.DType.INT16,
72+
ts.DType.INT32,
73+
ts.DType.BOOL,
74+
]
75+
self._add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes)
4876

4977

5078
@register_node_visitor
51-
class WhereVisitor_080_BI(NodeVisitor):
79+
class WhereVisitor_0_80_MI(WhereVisitor_0_80_BI):
80+
81+
tosa_specs = [
82+
TosaSpecification.create_from_string("TOSA-0.80+MI"),
83+
]
84+
85+
def __init__(self, *args):
86+
super().__init__(*args)
87+
88+
def define_node(
89+
self,
90+
node: Node,
91+
tosa_graph: Any,
92+
inputs: List[TosaArg],
93+
output: TosaArg,
94+
) -> None:
95+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
96+
97+
mi_supported_dtypes = [
98+
ts.DType.FP16,
99+
ts.DType.FP32,
100+
ts.DType.INT8,
101+
ts.DType.INT16,
102+
ts.DType.INT32,
103+
ts.DType.BOOL,
104+
]
105+
self._add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes)
106+
107+
108+
@register_node_visitor
109+
class WhereVisitor_INT(NodeVisitor):
52110
target = "aten.where.self"
53111

54112
tosa_specs = [
55-
TosaSpecification.create_from_string("TOSA-0.80+BI"),
113+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
56114
]
57115

58116
def __init__(self, *args):
59117
super().__init__(*args)
60118

119+
def _add_node_to_tosa_graph(
120+
self,
121+
tosa_graph: Any,
122+
inputs: List[TosaArg],
123+
output: TosaArg,
124+
supported_dtypes: Sequence,
125+
) -> None:
126+
import serializer.tosa_serializer as ts
127+
128+
if len(inputs) != 3:
129+
raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}")
130+
131+
if inputs[0].dtype is not ts.DType.BOOL:
132+
raise ValueError("Input 0 needs to have dtype BOOL")
133+
if inputs[1].dtype != inputs[2].dtype:
134+
raise ValueError(
135+
"Non-condition tensors must have same data type, got "
136+
f"{inputs[1].dtype} and {inputs[2].dtype}"
137+
)
138+
for input_ in inputs[1:]:
139+
if input_.dtype not in supported_dtypes:
140+
raise ValueError(
141+
f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}"
142+
)
143+
144+
tosa_graph.addOperator(
145+
ts.TosaOp.Op().SELECT,
146+
[inputs[0].name, inputs[1].name, inputs[2].name],
147+
[output.name],
148+
None,
149+
)
150+
61151
def define_node(
62152
self,
63153
node: Node,
64-
tosa_graph: ts.TosaSerializer,
154+
tosa_graph: Any,
65155
inputs: List[TosaArg],
66156
output: TosaArg,
67157
) -> None:
158+
import serializer.tosa_serializer as ts
68159

69160
bi_supported_dtypes = [
70161
ts.DType.INT8,
71162
ts.DType.INT16,
72163
ts.DType.INT32,
73164
ts.DType.BOOL,
74165
]
75-
_add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes)
166+
self._add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes)
76167

77168

78169
@register_node_visitor
79-
class WhereVisitor_080_MI(WhereVisitor_080_BI):
170+
class WhereVisitor_FP(WhereVisitor_INT):
80171

81172
tosa_specs = [
82-
TosaSpecification.create_from_string("TOSA-0.80+MI"),
173+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
83174
]
84175

85176
def __init__(self, *args):
@@ -88,10 +179,12 @@ def __init__(self, *args):
88179
def define_node(
89180
self,
90181
node: Node,
91-
tosa_graph: ts.TosaSerializer,
182+
tosa_graph: Any,
92183
inputs: List[TosaArg],
93184
output: TosaArg,
94185
) -> None:
186+
import serializer.tosa_serializer as ts
187+
95188
mi_supported_dtypes = [
96189
ts.DType.FP16,
97190
ts.DType.FP32,
@@ -100,4 +193,4 @@ def define_node(
100193
ts.DType.INT32,
101194
ts.DType.BOOL,
102195
]
103-
_add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes)
196+
self._add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes)

‎backends/arm/operators/ops_binary.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,62 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import Any, List
99

1010
import torch
1111
import torch.fx
1212

13-
import tosa_tools.v0_80.serializer.tosa_serializer as ts
14-
1513
from executorch.backends.arm.operators.node_visitor import (
1614
NodeVisitor,
1715
register_node_visitor,
1816
)
1917
from executorch.backends.arm.tosa_mapping import TosaArg
2018

2119

20+
def binary_operator_factory_0_80(bw_target: str, tosa_op):
21+
"""Creates and registers NodeVisitors for operators that have two inputs and map directly to a TOSA op."""
22+
23+
class BinaryOperator_0_80(NodeVisitor):
24+
target = bw_target
25+
tosa_specs = NodeVisitor.tosa_specs_0_80
26+
27+
def define_node(
28+
self,
29+
node: torch.fx.Node,
30+
tosa_graph: Any,
31+
inputs: List[TosaArg],
32+
output: TosaArg,
33+
) -> None:
34+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore # noqa: F401
35+
36+
if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
37+
raise ValueError(
38+
"All inputs and outputs need same dtype."
39+
f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}."
40+
)
41+
42+
tosa_graph.addOperator(
43+
tosa_op, [inputs[0].name, inputs[1].name], [output.name]
44+
)
45+
46+
register_node_visitor(BinaryOperator_0_80)
47+
48+
2249
def binary_operator_factory(bw_target: str, tosa_op):
2350
"""Creates and registers NodeVisitors for operators that have two inputs and map directly to a TOSA op."""
2451

2552
class BinaryOperator(NodeVisitor):
2653
target = bw_target
54+
tosa_specs = NodeVisitor.tosa_specs_1_00
2755

2856
def define_node(
2957
self,
3058
node: torch.fx.Node,
31-
tosa_graph: ts.TosaSerializer,
59+
tosa_graph: Any,
3260
inputs: List[TosaArg],
3361
output: TosaArg,
3462
) -> None:
63+
import serializer.tosa_serializer as ts # type: ignore # noqa: F401
3564

3665
if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
3766
raise ValueError(
@@ -46,6 +75,20 @@ def define_node(
4675
register_node_visitor(BinaryOperator)
4776

4877

78+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
79+
80+
binary_operator_factory_0_80("aten.bitwise_and.Tensor", ts.TosaOp.Op().BITWISE_AND)
81+
binary_operator_factory_0_80("aten.bitwise_xor.Tensor", ts.TosaOp.Op().BITWISE_XOR)
82+
binary_operator_factory_0_80("aten.bitwise_or.Tensor", ts.TosaOp.Op().BITWISE_OR)
83+
binary_operator_factory_0_80("aten.logical_and.default", ts.TosaOp.Op().LOGICAL_AND)
84+
binary_operator_factory_0_80("aten.logical_xor.default", ts.TosaOp.Op().LOGICAL_XOR)
85+
binary_operator_factory_0_80("aten.logical_or.default", ts.TosaOp.Op().LOGICAL_OR)
86+
binary_operator_factory_0_80(
87+
"aten.bitwise_left_shift.Tensor", ts.TosaOp.Op().LOGICAL_LEFT_SHIFT
88+
)
89+
90+
import serializer.tosa_serializer as ts # type: ignore
91+
4992
binary_operator_factory("aten.bitwise_and.Tensor", ts.TosaOp.Op().BITWISE_AND)
5093
binary_operator_factory("aten.bitwise_xor.Tensor", ts.TosaOp.Op().BITWISE_XOR)
5194
binary_operator_factory("aten.bitwise_or.Tensor", ts.TosaOp.Op().BITWISE_OR)

‎backends/arm/operators/ops_unary.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import torch.fx
1010

11-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1211
from executorch.backends.arm.operators.node_visitor import (
1312
NodeVisitor,
1413
register_node_visitor,
@@ -17,6 +16,44 @@
1716
from executorch.backends.arm.tosa_mapping import TosaArg
1817

1918

19+
def unary_operator_factory_0_80(unary_target: str, tosa_op):
20+
"Creates and registers NodeVisitors for operations that have one input and map directly into a TOSA op."
21+
22+
# Some TOSA unary operators only support float
23+
fp_only_ops = ["aten.floor.default"]
24+
25+
class UnaryOperator_0_80(NodeVisitor):
26+
target = unary_target
27+
tosa_specs = NodeVisitor.tosa_specs_0_80
28+
29+
def __init__(self, *args):
30+
super().__init__(*args)
31+
32+
def define_node(
33+
self,
34+
node: torch.fx.Node,
35+
tosa_graph: Any,
36+
inputs: List[TosaArg],
37+
output: TosaArg,
38+
) -> None:
39+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore # noqa: F401
40+
41+
if not (inputs[0].dtype == output.dtype):
42+
raise ValueError(
43+
"All inputs and output need same dtype."
44+
f"Got {inputs[0].dtype=}, {output.dtype=}"
45+
)
46+
47+
if self.target in fp_only_ops and not (inputs[0].dtype == ts.DType.FP32):
48+
raise ValueError(
49+
"All inputs need to be FP32." f"Got {inputs[0].dtype=}"
50+
)
51+
52+
tosa_graph.addOperator(tosa_op, [inputs[0].name], [output.name])
53+
54+
register_node_visitor(UnaryOperator_0_80)
55+
56+
2057
def unary_operator_factory(unary_target: str, tosa_op):
2158
"Creates and registers NodeVisitors for operations that have one input and map directly into a TOSA op."
2259

@@ -25,17 +62,19 @@ def unary_operator_factory(unary_target: str, tosa_op):
2562

2663
class UnaryOperator(NodeVisitor):
2764
target = unary_target
65+
tosa_specs = NodeVisitor.tosa_specs_1_00
2866

2967
def __init__(self, *args):
3068
super().__init__(*args)
3169

3270
def define_node(
3371
self,
3472
node: torch.fx.Node,
35-
tosa_graph: ts.TosaSerializer,
73+
tosa_graph: Any,
3674
inputs: List[TosaArg],
3775
output: TosaArg,
3876
) -> None:
77+
import serializer.tosa_serializer as ts # type: ignore # noqa: F401
3978

4079
if not (inputs[0].dtype == output.dtype):
4180
raise ValueError(
@@ -53,6 +92,14 @@ def define_node(
5392
register_node_visitor(UnaryOperator)
5493

5594

95+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
96+
97+
unary_operator_factory_0_80("aten.ceil.default", ts.TosaOp.Op().CEIL)
98+
unary_operator_factory_0_80("aten.floor.default", ts.TosaOp.Op().FLOOR)
99+
unary_operator_factory_0_80("aten.logical_not.default", ts.TosaOp.Op().LOGICAL_NOT)
100+
101+
import serializer.tosa_serializer as ts # type: ignore
102+
56103
unary_operator_factory("aten.ceil.default", ts.TosaOp.Op().CEIL)
57104
unary_operator_factory("aten.floor.default", ts.TosaOp.Op().FLOOR)
58105
unary_operator_factory("aten.logical_not.default", ts.TosaOp.Op().LOGICAL_NOT)

‎backends/arm/quantizer/quantization_annotator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ def _match_pattern(
171171
torch.ops.aten.reciprocal.default,
172172
torch.ops.aten.rsqrt.default,
173173
torch.ops.aten.sigmoid.default,
174+
torch.ops.aten.cos.default,
175+
torch.ops.aten.sin.default,
174176
torch.ops.aten.tanh.default,
175177
torch.ops.aten.sum.dim_IntList,
176178
torch.ops.aten.hardsigmoid.default,

‎backends/arm/test/misc/test_multiple_delegates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def get_inputs(self):
2020

2121
def forward(self, x: torch.Tensor, y: torch.Tensor):
2222
z = x + y
23-
s = torch.sin(z)
23+
s = torch.tan(z)
2424
return s * z
2525

2626
def test_tosa_MI(self):

‎backends/arm/test/ops/test_constant_pad_nd.py

Lines changed: 55 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -2,143 +2,74 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5-
65
#
76
# Test the pad_constant_nd op which pads the input tensor at specific dimension(s).
87
#
9-
import unittest
108
from typing import Tuple
119

1210
import torch
13-
import torch.nn as nn
1411
import torch.nn.functional as F
1512
from executorch.backends.arm.test import common
16-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
17-
from parameterized import parameterized
18-
19-
test_data_suite = [
20-
("4dim_last1dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1),
21-
("4dim_last2dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2),
22-
("4dim_last3dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3),
23-
("4dim_last4dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4),
24-
("3dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1),
25-
("3dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2),
26-
("3dim_last3dim", torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3),
27-
("2dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0), 1),
28-
("2dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1), 2),
29-
]
30-
31-
32-
class TestConstantPadND(unittest.TestCase):
33-
"""Tests pad."""
34-
35-
class ConstantPadND(torch.nn.Module):
36-
def __init__(self, pad: Tuple, value: float | None = None):
37-
super().__init__()
38-
self.dim = len(pad) // 2
39-
self.value = value
40-
in_channels = 1
41-
# Only apply conv2d when the input dim = 4.
42-
if self.dim == 4:
43-
in_channels += pad[-3] + pad[-4]
44-
45-
self.conv2d = nn.Conv2d(
46-
in_channels=in_channels,
47-
out_channels=3,
48-
kernel_size=3,
49-
bias=True,
50-
stride=(2, 2),
51-
padding=0,
52-
)
13+
from executorch.backends.arm.test.tester.test_pipeline import (
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
5317

54-
in_channels = 3
55-
in_channels += pad[-3] + pad[-4]
56-
self.conv2d_1 = nn.Conv2d(
57-
in_channels=in_channels,
58-
out_channels=3,
59-
kernel_size=3,
60-
bias=True,
61-
padding="same",
62-
)
18+
aten_op = "torch.ops.aten.pad.default"
19+
exir_op = "executorch_exir_dialects_edge__ops_aten_pad_default"
20+
input_t1 = Tuple[torch.Tensor] # Input x
21+
test_data_suite = {
22+
"4dim_last1dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1),
23+
"4dim_last2dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2),
24+
"4dim_last3dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3),
25+
"4dim_last4dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4),
26+
"3dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1),
27+
"3dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2),
28+
"3dim_last3dim": (torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3),
29+
"2dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0), 1),
30+
"2dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1), 2),
31+
}
32+
"""Tests pad."""
6333

64-
nonzero_idx = len(pad)
65-
for i in range(0, len(pad), 2):
66-
if pad[i] + pad[i + 1] == 0:
67-
nonzero_idx = i
68-
break
69-
self.pad = pad[:nonzero_idx]
70-
self.relu = nn.ReLU()
71-
self.sigmoid = nn.Sigmoid()
7234

73-
def forward(self, x: torch.Tensor):
74-
x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
75-
if self.dim == 4:
76-
x = self.conv2d(x)
77-
x = self.relu(x)
35+
class ConstantPadND(torch.nn.Module):
36+
def __init__(self, pad: Tuple, value: float | None = None):
37+
super().__init__()
38+
self.value = value
39+
nonzero_idx = len(pad)
40+
for i in range(0, len(pad), 2):
41+
if pad[i] + pad[i + 1] == 0:
42+
nonzero_idx = i
43+
break
44+
self.pad = pad[:nonzero_idx]
7845

79-
x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
80-
if self.dim == 4:
81-
x = self.conv2d_1(x)
82-
x = self.sigmoid(x)
83-
return x
46+
def forward(self, x: torch.Tensor):
47+
x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
48+
return x
8449

85-
def _test_constant_pad_nd_tosa_MI_pipeline(
86-
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
87-
):
88-
(
89-
ArmTester(
90-
module,
91-
example_inputs=test_data,
92-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
93-
)
94-
.export()
95-
.check_count({"torch.ops.aten.pad.default": 2})
96-
.to_edge()
97-
.partition()
98-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
99-
.to_executorch()
100-
.run_method_and_compare_outputs(inputs=test_data)
101-
)
10250

103-
def _test_constant_pad_nd_tosa_BI_pipeline(
104-
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
105-
):
106-
(
107-
ArmTester(
108-
module,
109-
example_inputs=test_data,
110-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
111-
)
112-
.quantize()
113-
.export()
114-
.check_count({"torch.ops.aten.pad.default": 2})
115-
.to_edge()
116-
.partition()
117-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
118-
.to_executorch()
119-
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
120-
)
51+
@common.parametrize(
52+
"test_data",
53+
test_data_suite,
54+
)
55+
def test_constant_pad_nd_tosa_MI(test_data: Tuple):
56+
test_data, padding, value = test_data
57+
pipeline = TosaPipelineMI[input_t1](
58+
ConstantPadND(padding, value),
59+
(test_data,),
60+
aten_op,
61+
exir_op,
62+
)
63+
pipeline.run()
12164

122-
@parameterized.expand(test_data_suite)
123-
def test_constant_pad_nd_tosa_MI(
124-
self,
125-
test_name: str,
126-
test_data: torch.Tensor,
127-
padding: Tuple,
128-
value: float | None = None,
129-
):
130-
self._test_constant_pad_nd_tosa_MI_pipeline(
131-
self.ConstantPadND(padding, value), (test_data,)
132-
)
13365

134-
@parameterized.expand(test_data_suite)
135-
def test_constant_pad_nd_tosa_BI(
136-
self,
137-
test_name: str,
138-
test_data: torch.Tensor,
139-
padding: Tuple,
140-
value: float | None = None,
141-
):
142-
self._test_constant_pad_nd_tosa_BI_pipeline(
143-
self.ConstantPadND(padding, value), (test_data,)
144-
)
66+
@common.parametrize("test_data", test_data_suite)
67+
def test_constant_pad_nd_tosa_BI(test_data: Tuple):
68+
test_data, padding, value = test_data
69+
pipeline = TosaPipelineBI[input_t1](
70+
ConstantPadND(padding, value),
71+
(test_data,),
72+
aten_op,
73+
exir_op,
74+
)
75+
pipeline.run()
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
#
7+
# Test the pad_constant_nd op which pads the input tensor at specific dimension(s).
8+
#
9+
10+
from typing import Tuple
11+
12+
import torch
13+
import torch.nn as nn
14+
import torch.nn.functional as F
15+
from executorch.backends.arm.test import common
16+
from executorch.backends.arm.test.tester.test_pipeline import (
17+
TosaPipelineBI,
18+
TosaPipelineMI,
19+
)
20+
21+
aten_op = "torch.ops.aten.pad.default"
22+
exir_op = "executorch_exir_dialects_edge__ops_aten_pad_default"
23+
24+
input_t1 = Tuple[torch.Tensor] # Input x
25+
26+
test_data_suite = {
27+
"4dim_last1dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1),
28+
"4dim_last2dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2),
29+
"4dim_last3dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3),
30+
"4dim_last4dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4),
31+
"3dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1),
32+
"3dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2),
33+
"3dim_last3dim": (torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3),
34+
"2dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0), 1),
35+
"2dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1), 2),
36+
}
37+
38+
39+
"""Tests conv + pad."""
40+
41+
42+
class ConstantPadND(torch.nn.Module):
43+
def __init__(self, pad: Tuple, value: float | None = None):
44+
super().__init__()
45+
self.dim = len(pad) // 2
46+
self.value = value
47+
in_channels = 1
48+
# Only apply conv2d when the input dim = 4.
49+
if self.dim == 4:
50+
in_channels += pad[-3] + pad[-4]
51+
52+
self.conv2d = nn.Conv2d(
53+
in_channels=in_channels,
54+
out_channels=3,
55+
kernel_size=3,
56+
bias=True,
57+
stride=(2, 2),
58+
padding=0,
59+
)
60+
61+
in_channels = 3
62+
in_channels += pad[-3] + pad[-4]
63+
self.conv2d_1 = nn.Conv2d(
64+
in_channels=in_channels,
65+
out_channels=3,
66+
kernel_size=3,
67+
bias=True,
68+
padding="same",
69+
)
70+
71+
nonzero_idx = len(pad)
72+
for i in range(0, len(pad), 2):
73+
if pad[i] + pad[i + 1] == 0:
74+
nonzero_idx = i
75+
break
76+
self.pad = pad[:nonzero_idx]
77+
self.relu = nn.ReLU()
78+
self.sigmoid = nn.Sigmoid()
79+
80+
def forward(self, x: torch.Tensor):
81+
x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
82+
if self.dim == 4:
83+
x = self.conv2d(x)
84+
x = self.relu(x)
85+
86+
x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
87+
if self.dim == 4:
88+
x = self.conv2d_1(x)
89+
x = self.sigmoid(x)
90+
return x
91+
92+
93+
@common.parametrize("test_data", test_data_suite)
94+
def test_constant_pad_nd_tosa_MI(test_data: Tuple):
95+
test_data, padding, value = test_data
96+
pipeline = TosaPipelineMI[input_t1](
97+
ConstantPadND(padding, value),
98+
(test_data,),
99+
aten_op,
100+
exir_op,
101+
)
102+
pipeline.run()
103+
104+
105+
@common.parametrize("test_data", test_data_suite)
106+
def test_constant_pad_nd_tosa_BI(test_data: Tuple):
107+
test_data, padding, value = test_data
108+
pipeline = TosaPipelineBI[input_t1](
109+
ConstantPadND(padding, value),
110+
(test_data,),
111+
aten_op,
112+
exir_op,
113+
)
114+
pipeline.run()

‎backends/arm/test/ops/test_cos.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Tuple
8+
9+
import torch
10+
11+
from executorch.backends.arm.test import common, conftest
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU55PipelineBI,
14+
EthosU85PipelineBI,
15+
TosaPipelineBI,
16+
TosaPipelineMI,
17+
)
18+
19+
aten_op = "torch.ops.aten.cos.default"
20+
input_t1 = Tuple[torch.Tensor] # Input x
21+
22+
test_data_suite = {
23+
# (test_name, test_data)
24+
"zeros": torch.zeros(10, 10, 10, 10),
25+
"ones": torch.ones(10, 10, 10),
26+
"rand": torch.rand(10, 10) - 0.5,
27+
"randn_pos": torch.randn(10) + 10,
28+
"randn_neg": torch.randn(10) - 10,
29+
"ramp": torch.arange(-16, 16, 0.2),
30+
}
31+
32+
33+
class Cos(torch.nn.Module):
34+
35+
def forward(self, x: torch.Tensor):
36+
return torch.cos(x)
37+
38+
39+
@common.parametrize("test_data", test_data_suite)
40+
def test_cos_tosa_MI(test_data: Tuple):
41+
pipeline = TosaPipelineMI[input_t1](
42+
Cos(),
43+
(test_data,),
44+
aten_op,
45+
exir_op=[],
46+
)
47+
if conftest.get_option("tosa_version") == "1.0":
48+
pipeline.run()
49+
50+
51+
@common.parametrize("test_data", test_data_suite)
52+
def test_cos_tosa_BI(test_data: Tuple):
53+
pipeline = TosaPipelineBI[input_t1](
54+
Cos(),
55+
(test_data,),
56+
aten_op,
57+
exir_op=[],
58+
)
59+
pipeline.run()
60+
61+
62+
@common.parametrize("test_data", test_data_suite)
63+
def test_cos_tosa_u55_BI(test_data: Tuple):
64+
pipeline = EthosU55PipelineBI[input_t1](
65+
Cos(),
66+
(test_data,),
67+
aten_op,
68+
exir_ops=[],
69+
run_on_fvp=False,
70+
)
71+
pipeline.run()
72+
73+
74+
@common.parametrize("test_data", test_data_suite)
75+
def test_cos_tosa_u85_BI(test_data: Tuple):
76+
pipeline = EthosU85PipelineBI[input_t1](
77+
Cos(),
78+
(test_data,),
79+
aten_op,
80+
exir_ops=[],
81+
run_on_fvp=False,
82+
)
83+
pipeline.run()

‎backends/arm/test/ops/test_sin.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Tuple
8+
9+
import torch
10+
11+
from executorch.backends.arm.test import common, conftest
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU55PipelineBI,
14+
EthosU85PipelineBI,
15+
TosaPipelineBI,
16+
TosaPipelineMI,
17+
)
18+
19+
aten_op = "torch.ops.aten.sin.default"
20+
input_t1 = Tuple[torch.Tensor] # Input x
21+
22+
test_data_suite = {
23+
# (test_name, test_data)
24+
"zeros": torch.zeros(10, 10, 10, 10),
25+
"ones": torch.ones(10, 10, 10),
26+
"rand": torch.rand(10, 10) - 0.5,
27+
"randn_pos": torch.randn(10) + 10,
28+
"randn_neg": torch.randn(10) - 10,
29+
"ramp": torch.arange(-16, 16, 0.2),
30+
}
31+
32+
33+
class Sin(torch.nn.Module):
34+
35+
def forward(self, x: torch.Tensor):
36+
return torch.sin(x)
37+
38+
39+
@common.parametrize("test_data", test_data_suite)
40+
def test_sin_tosa_MI(test_data: Tuple):
41+
pipeline = TosaPipelineMI[input_t1](
42+
Sin(),
43+
(test_data,),
44+
aten_op,
45+
exir_op=[],
46+
)
47+
if conftest.get_option("tosa_version") == "1.0":
48+
pipeline.run()
49+
50+
51+
@common.parametrize("test_data", test_data_suite)
52+
def test_sin_tosa_BI(test_data: Tuple):
53+
pipeline = TosaPipelineBI[input_t1](
54+
Sin(),
55+
(test_data,),
56+
aten_op,
57+
exir_op=[],
58+
)
59+
pipeline.run()
60+
61+
62+
@common.parametrize("test_data", test_data_suite)
63+
def test_sin_tosa_u55_BI(test_data: Tuple):
64+
pipeline = EthosU55PipelineBI[input_t1](
65+
Sin(),
66+
(test_data,),
67+
aten_op,
68+
exir_ops=[],
69+
run_on_fvp=False,
70+
)
71+
pipeline.run()
72+
73+
74+
@common.parametrize("test_data", test_data_suite)
75+
def test_sin_tosa_u85_BI(test_data: Tuple):
76+
pipeline = EthosU85PipelineBI[input_t1](
77+
Sin(),
78+
(test_data,),
79+
aten_op,
80+
exir_ops=[],
81+
run_on_fvp=False,
82+
)
83+
pipeline.run()

0 commit comments

Comments
 (0)
Please sign in to comment.