
Commit cc73c7e

Authored Apr 24, 2025
Revert "Arm backend: Update more node visitors to support TOSA 1.0" (#10455)
Reverts #10425
1 parent 1432243 commit cc73c7e

32 files changed: +236, -1519 lines
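For context, the change being reverted (#10425) registered TOSA 1.0 variants of many node visitors next to the existing TOSA 0.80 ones and moved the 0.80 serializer import back to module level. A minimal sketch of that dual-registration pattern, pieced together from the visitor pairs visible in the diff below; the target name and the IDENTITY op are placeholders, not taken from this commit:

    from typing import Any, List

    from executorch.backends.arm.operators.node_visitor import (
        NodeVisitor,
        register_node_visitor,
    )
    from executorch.backends.arm.tosa_mapping import TosaArg
    from torch.fx import Node


    @register_node_visitor
    class ExampleVisitor_0_80(NodeVisitor):
        # TOSA 0.80 variant -- this is what the revert keeps.
        target = "aten.example.default"  # placeholder target
        tosa_specs = NodeVisitor.tosa_specs_0_80

        def define_node(
            self, node: Node, tosa_graph: Any, inputs: List[TosaArg], output: TosaArg
        ) -> None:
            import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore

            tosa_graph.addOperator(ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name])


    @register_node_visitor
    class ExampleVisitor(NodeVisitor):
        # TOSA 1.0 variant -- this is what the revert removes again.
        target = "aten.example.default"  # placeholder target
        tosa_specs = NodeVisitor.tosa_specs_1_00

        def define_node(
            self, node: Node, tosa_graph: Any, inputs: List[TosaArg], output: TosaArg
        ) -> None:
            import serializer.tosa_serializer as ts  # TOSA 1.0 serializer

            tosa_graph.addOperator(ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name])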
 

‎backends/arm/_passes/insert_table_ops.py

Lines changed: 0 additions & 2 deletions
@@ -48,8 +48,6 @@ class TableOps:
         exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
         exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
         exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
-        exir_ops.edge.aten.cos.default: torch.cos,
-        exir_ops.edge.aten.sin.default: torch.sin,
         exir_ops.edge.aten.tanh.default: torch.tanh,
         exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
         exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
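The two entries deleted above are what routed aten.cos/aten.sin through the table-operator path for quantized graphs, i.e. the op is realized as an 8-bit lookup table built from the reference torch function. A rough, self-contained sketch of that idea, assuming a plain affine quantization scheme; the function name and scale/zero-point handling below are illustrative, not the pass's actual API:

    import torch

    def build_int8_table(fn, in_scale, in_zp, out_scale, out_zp):
        # Evaluate the reference op (e.g. torch.sin) for all 256 INT8 codes
        # and requantize the results into an INT8 lookup table.
        codes = torch.arange(-128, 128, dtype=torch.int32)
        x = (codes - in_zp).to(torch.float32) * in_scale   # dequantize
        y = fn(x)                                          # reference op
        q = torch.round(y / out_scale + out_zp)            # requantize
        return q.clamp(-128, 127).to(torch.int8)

    # Hypothetical usage: build_int8_table(torch.sin, 0.025, 0, 0.0079, 0)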

‎backends/arm/operator_support/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -12,7 +12,6 @@
     pool_2d_support,
     reduce_sum_support,
     right_shift_support,
-    sin_cos_support,
     slice_copy_support,
     to_copy_support,
     tosa_supported_operators,

‎backends/arm/operator_support/sin_cos_support.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

‎backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 8 deletions
@@ -23,11 +23,7 @@
     EthosU55NotSupported,
     EthosU55TransposeCheck,
 )
-from executorch.backends.arm.tosa_specification import (
-    Tosa_0_80,
-    Tosa_1_00,
-    TosaSpecification,
-)
+from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.utils import WhyNoPartitionReporter
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -128,9 +124,7 @@ def tosa_support_factory(
     if not tosa_spec.support_float():
         negative_checks.append(NeedsDecompositionCheck(reporter))
     negative_checks.append(CheckProperQuantization(reporter))
-    if (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset) or (
-        isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions
-    ):
+    if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
         negative_checks.append(EthosU55NotSupported(reporter))
         negative_checks.append(EthosU55DtypeSupport(reporter))
         negative_checks.append(EthosU55TransposeCheck(reporter))
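The condition removed in the second hunk is how the partitioner decided whether the compile target is the Ethos-U55 subset under either spec generation: the TOSA 0.80 spec carries an is_U55_subset flag, while the TOSA 1.0 spec models U55 as an entry in extensions. A small helper capturing both checks as they appear in the removed lines; the helper itself is illustrative and not a function in the codebase:

    from executorch.backends.arm.tosa_specification import (
        Tosa_0_80,
        Tosa_1_00,
        TosaSpecification,
    )

    def targets_u55(tosa_spec: TosaSpecification) -> bool:
        # Mirrors the check used before this revert; the Tosa_1_00 branch
        # is what this commit removes again.
        if isinstance(tosa_spec, Tosa_0_80):
            return tosa_spec.is_U55_subset
        if isinstance(tosa_spec, Tosa_1_00):
            return "u55" in tosa_spec.extensions
        return False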

‎backends/arm/operators/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -18,7 +18,6 @@
     op_clamp,
     op_constant_pad_nd,
     op_conv2d,
-    op_cos,
     op_eq,
     op_erf,
     op_exp,
@@ -39,7 +38,6 @@
     op_rshift_tensor,
     op_rsqrt,
     op_sigmoid,
-    op_sin,
     op_slice,
     op_sub,
     op_sum,

‎backends/arm/operators/op_any.py

Lines changed: 4 additions & 45 deletions
@@ -4,8 +4,9 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-unsafe
-from typing import Any, cast, List
+from typing import cast, List
 
+import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 from executorch.backends.arm.operators.node_visitor import (  # type: ignore
     NodeVisitor,
     register_node_visitor,
@@ -15,59 +16,17 @@
 from torch.fx import Node
 
 
-@register_node_visitor
-class AnyVisitor_0_80(NodeVisitor):
-    target = "aten.any.dim"
-
-    tosa_specs = NodeVisitor.tosa_specs_0_80
-
-    def define_node(
-        self,
-        node: Node,
-        tosa_graph: Any,
-        inputs: List[TosaArg],
-        output: TosaArg,
-    ) -> None:
-        import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
-
-        if not (inputs[0].dtype == output.dtype):
-            raise ValueError(
-                "All inputs and outputs need same dtype."
-                f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}."
-            )
-        if not (inputs[0].dtype == ts.DType.BOOL):
-            raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")
-
-        input_shape = list(inputs[0].shape)
-        dim = cast(int, inputs[1].number) % len(
-            input_shape
-        )  # process the negative index
-        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
-        if not keep_dim:
-            raise ValueError("This case should be handled by ConvertAnyDimDimsPass")
-
-        attr = ts.TosaSerializerAttribute()
-        attr.AxisAttribute(inputs[0].dim_order.index(dim))
-
-        tosa_graph.addOperator(
-            ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
-        )
-
-
 @register_node_visitor
 class AnyVisitor(NodeVisitor):
     target = "aten.any.dim"
 
-    tosa_specs = NodeVisitor.tosa_specs_1_00
-
     def define_node(
         self,
         node: Node,
-        tosa_graph: Any,
+        tosa_graph: ts.TosaSerializer,
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
-        import serializer.tosa_serializer as ts
 
         if not (inputs[0].dtype == output.dtype):
             raise ValueError(
@@ -86,7 +45,7 @@ def define_node(
             raise ValueError("This case should be handled by ConvertAnyDimDimsPass")
 
         attr = ts.TosaSerializerAttribute()
-        attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim))
+        attr.AxisAttribute(inputs[0].dim_order.index(dim))
 
         tosa_graph.addOperator(
             ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr

‎backends/arm/operators/op_avg_pool2d.py

Lines changed: 7 additions & 134 deletions
@@ -4,10 +4,12 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import Any, List
7+
from typing import List
88

99
import torch
1010

11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12+
1113
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1214
get_input_qparams,
1315
get_output_qparams,
@@ -34,16 +36,14 @@ def __init__(self, *args):
3436
def _build_generic_avgpool2d(
3537
self,
3638
node: torch.fx.Node,
37-
tosa_graph: Any,
39+
tosa_graph: ts.TosaSerializer,
3840
inputs: List[TosaArg],
3941
output: TosaArg,
4042
input_zp: int,
4143
output_zp: int,
42-
accumulator_type: Any,
44+
accumulator_type: ts.DType,
4345
) -> None:
4446

45-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
46-
4747
input_tensor = inputs[0]
4848
kernel_size_list = inputs[1].special
4949
stride_size_list = inputs[2].special
@@ -79,12 +79,10 @@ def _build_generic_avgpool2d(
7979
def define_node(
8080
self,
8181
node: torch.fx.Node,
82-
tosa_graph: Any,
82+
tosa_graph: ts.TosaSerializer,
8383
inputs: List[TosaArg],
8484
output: TosaArg,
8585
) -> None:
86-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
87-
8886
input_tensor = inputs[0]
8987
assert input_tensor.dtype == ts.DType.INT8
9088

@@ -112,135 +110,10 @@ class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
112110
def define_node(
113111
self,
114112
node: torch.fx.Node,
115-
tosa_graph: Any,
113+
tosa_graph: ts.TosaSerializer,
116114
inputs: List[TosaArg],
117115
output: TosaArg,
118116
) -> None:
119-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
120-
121-
assert (
122-
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
123-
), "Only FP32 and INT8 supported"
124-
125-
if inputs[0].dtype == ts.DType.INT8:
126-
super().define_node(node, tosa_graph, inputs, output)
127-
128-
if inputs[0].dtype == ts.DType.FP32:
129-
accumulator_type = ts.DType.FP32
130-
# Initilize zero point to zero.
131-
input_zp = 0
132-
output_zp = 0
133-
134-
self._build_generic_avgpool2d(
135-
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
136-
)
137-
138-
139-
@register_node_visitor
140-
class AvgPool2dVisitor(NodeVisitor):
141-
target = "aten.avg_pool2d.default"
142-
143-
tosa_specs = [
144-
TosaSpecification.create_from_string("TOSA-1.0+INT"),
145-
]
146-
147-
def __init__(self, *args):
148-
super().__init__(*args)
149-
150-
def _build_generic_avgpool2d(
151-
self,
152-
node: torch.fx.Node,
153-
tosa_graph: Any,
154-
inputs: List[TosaArg],
155-
output: TosaArg,
156-
input_zp: int,
157-
output_zp: int,
158-
accumulator_type: Any,
159-
) -> None:
160-
161-
import serializer.tosa_serializer as ts # type: ignore
162-
163-
input_tensor = inputs[0]
164-
kernel_size_list = inputs[1].special
165-
stride_size_list = inputs[2].special
166-
167-
try:
168-
pad_size_list = inputs[3].special
169-
pad_size_list = [
170-
pad_size_list[0],
171-
pad_size_list[0],
172-
pad_size_list[1],
173-
pad_size_list[1],
174-
]
175-
except IndexError:
176-
pad_size_list = [0, 0, 0, 0]
177-
178-
attr = ts.TosaSerializerAttribute()
179-
attr.AvgPool2dAttribute(
180-
kernel=kernel_size_list,
181-
stride=stride_size_list,
182-
pad=pad_size_list,
183-
acc_type=accumulator_type,
184-
)
185-
input_zp_tensor = tosa_graph.addConst(
186-
shape=[1], dtype=output.dtype, vals=[input_zp]
187-
)
188-
output_zp_tensor = tosa_graph.addConst(
189-
shape=[1], dtype=output.dtype, vals=[output_zp]
190-
)
191-
192-
tosa_graph.addOperator(
193-
ts.TosaOp.Op().AVG_POOL2D,
194-
[input_tensor.name, input_zp_tensor.name, output_zp_tensor.name],
195-
[output.name],
196-
attr,
197-
)
198-
199-
def define_node(
200-
self,
201-
node: torch.fx.Node,
202-
tosa_graph: Any,
203-
inputs: List[TosaArg],
204-
output: TosaArg,
205-
) -> None:
206-
import serializer.tosa_serializer as ts # type: ignore
207-
208-
input_tensor = inputs[0]
209-
assert input_tensor.dtype == ts.DType.INT8
210-
211-
accumulator_type = ts.DType.INT32
212-
213-
input_qargs = get_input_qparams(node)
214-
input_zp = input_qargs[0].zp
215-
216-
output_qargs = get_output_qparams(node)
217-
output_zp = output_qargs[0].zp
218-
219-
self._build_generic_avgpool2d(
220-
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
221-
)
222-
223-
224-
@register_node_visitor
225-
class AvgPool2dVisitor_FP(AvgPool2dVisitor):
226-
target = "aten.avg_pool2d.default"
227-
228-
tosa_specs = [
229-
TosaSpecification.create_from_string("TOSA-1.0+FP"),
230-
]
231-
232-
def __init__(self, *args):
233-
super().__init__(*args)
234-
235-
def define_node(
236-
self,
237-
node: torch.fx.Node,
238-
tosa_graph: Any,
239-
inputs: List[TosaArg],
240-
output: TosaArg,
241-
) -> None:
242-
import serializer.tosa_serializer as ts # type: ignore
243-
244117
assert (
245118
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
246119
), "Only FP32 and INT8 supported"

‎backends/arm/operators/op_cat.py

Lines changed: 4 additions & 41 deletions
@@ -5,8 +5,9 @@
55

66
# pyre-unsafe
77

8-
from typing import Any, List
8+
from typing import List
99

10+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1011
from executorch.backends.arm.operators.node_visitor import (
1112
NodeVisitor,
1213
register_node_visitor,
@@ -15,58 +16,20 @@
1516
from torch.fx import Node
1617

1718

18-
@register_node_visitor
19-
class CatVisitor_0_80(NodeVisitor):
20-
target = "aten.cat.default"
21-
22-
tosa_specs = NodeVisitor.tosa_specs_0_80
23-
24-
def __init__(self, *args):
25-
super().__init__(*args)
26-
27-
def define_node(
28-
self,
29-
node: Node,
30-
tosa_graph: Any,
31-
inputs: List[TosaArg],
32-
output: TosaArg,
33-
) -> None:
34-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
35-
36-
tensors = inputs[0].special
37-
dim = 0 if len(inputs) < 2 else inputs[1].number
38-
rank = len(output.shape)
39-
dim = (dim + rank) % rank
40-
dim = output.dim_order.index(dim)
41-
42-
attr = ts.TosaSerializerAttribute()
43-
attr.AxisAttribute(dim)
44-
45-
tosa_graph.addOperator(
46-
ts.TosaOp.Op().CONCAT,
47-
[tensor.name for tensor in tensors],
48-
[output.name],
49-
attr,
50-
)
51-
52-
5319
@register_node_visitor
5420
class CatVisitor(NodeVisitor):
5521
target = "aten.cat.default"
5622

57-
tosa_specs = NodeVisitor.tosa_specs_1_00
58-
5923
def __init__(self, *args):
6024
super().__init__(*args)
6125

6226
def define_node(
6327
self,
6428
node: Node,
65-
tosa_graph: Any,
29+
tosa_graph: ts.TosaSerializer,
6630
inputs: List[TosaArg],
6731
output: TosaArg,
6832
) -> None:
69-
import serializer.tosa_serializer as ts
7033

7134
tensors = inputs[0].special
7235
dim = 0 if len(inputs) < 2 else inputs[1].number
@@ -75,7 +38,7 @@ def define_node(
7538
dim = output.dim_order.index(dim)
7639

7740
attr = ts.TosaSerializerAttribute()
78-
attr.ConcatAttribute(dim)
41+
attr.AxisAttribute(dim)
7942

8043
tosa_graph.addOperator(
8144
ts.TosaOp.Op().CONCAT,

‎backends/arm/operators/op_constant_pad_nd.py

Lines changed: 5 additions & 79 deletions
@@ -5,10 +5,12 @@
55

66
# pyre-unsafe
77

8-
from typing import Any, List
8+
from typing import List
99

1010
import torch
1111

12+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
13+
1214
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1315
get_input_qparams,
1416
)
@@ -17,27 +19,20 @@
1719
register_node_visitor,
1820
)
1921
from executorch.backends.arm.tosa_mapping import TosaArg
20-
from executorch.backends.arm.tosa_specification import TosaSpecification
2122

2223

2324
@register_node_visitor
24-
class ConstantPadNDVisitor_0_80(NodeVisitor):
25+
class ConstantPadNDVisitor(NodeVisitor):
2526

2627
target = "aten.constant_pad_nd.default"
2728

28-
tosa_specs = [
29-
TosaSpecification.create_from_string("TOSA-0.80+BI"),
30-
TosaSpecification.create_from_string("TOSA-0.80+MI"),
31-
]
32-
3329
def define_node(
3430
self,
3531
node: torch.fx.Node,
36-
tosa_graph: Any,
32+
tosa_graph: ts.TosaSerializer,
3733
inputs: List[TosaArg],
3834
output: TosaArg,
3935
) -> None:
40-
import tosa_tools.v0_80.serializer.tosa_serializer as ts
4136

4237
if inputs[0].dtype == ts.DType.INT8:
4338
input_qparams = get_input_qparams(node)
@@ -79,72 +74,3 @@ def define_node(
7974
tosa_graph.addOperator(
8075
ts.TosaOp.Op().PAD, [inputs[0].name], [output.name], attr
8176
)
82-
83-
84-
@register_node_visitor
85-
class ConstantPadNDVisitor(NodeVisitor):
86-
87-
target = "aten.constant_pad_nd.default"
88-
89-
tosa_specs = [
90-
TosaSpecification.create_from_string("TOSA-1.0+INT"),
91-
TosaSpecification.create_from_string("TOSA-1.0+FP"),
92-
]
93-
94-
def define_node(
95-
self,
96-
node: torch.fx.Node,
97-
tosa_graph: Any,
98-
inputs: List[TosaArg],
99-
output: TosaArg,
100-
) -> None:
101-
102-
import serializer.tosa_serializer as ts # type: ignore
103-
104-
if inputs[0].dtype == ts.DType.INT8:
105-
input_qparams = get_input_qparams(node)
106-
qargs = input_qparams[0]
107-
pad_const_val = qargs.quantize_value(inputs[2].number).item()
108-
pad_const_dtype = ts.DType.INT8
109-
else:
110-
pad_const_val = inputs[2].number
111-
pad_const_dtype = inputs[0].dtype
112-
113-
rank = len(output.shape)
114-
# Each dim needs 2 padding values. For example, to pad the last dimension, the pad has the form
115-
# (padding_left, padding_right); to pad the last two dimensions, the pad has the form
116-
# (padding_left, padding_right, padding_top, padding_bottom), and so on. For PyTorch NCHW format, the padding
117-
# values are in the reverse order. So, firstly we need to reverse the input padding parameters.
118-
input_pad = sum(
119-
[
120-
[inputs[1].special[i], inputs[1].special[i + 1]]
121-
for i in range(0, len(inputs[1].special), 2)
122-
][::-1],
123-
[],
124-
)
125-
# Then, add dummy zeros to make sure that both input_pad and output_pad has the same size.
126-
input_pad = [0] * (rank * 2 - len(inputs[1].special)) + input_pad
127-
# For PyTorch NCHW format, dim order is [0,...,rank-1]
128-
input_dim_order = list(range(rank))
129-
output_pad = [0] * rank * 2
130-
131-
# Map input padding parameters into output padding parameters. TOSA is NHWC format.
132-
for input_dim_idx, input_dim in enumerate(input_dim_order):
133-
output_dim_idx = output.dim_order.index(input_dim)
134-
output_pad[output_dim_idx * 2 : (output_dim_idx + 1) * 2] = input_pad[
135-
input_dim_idx * 2 : (input_dim_idx + 1) * 2
136-
]
137-
138-
padding = tosa_graph.addConst(
139-
shape=[len(output_pad)], dtype=ts.DType.SHAPE, vals=output_pad
140-
)
141-
142-
pad_const = tosa_graph.addConst(
143-
shape=[1], dtype=pad_const_dtype, vals=[pad_const_val]
144-
)
145-
146-
tosa_graph.addOperator(
147-
ts.TosaOp.Op().PAD,
148-
[inputs[0].name, padding.name, pad_const.name],
149-
[output.name],
150-
)

‎backends/arm/operators/op_cos.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

‎backends/arm/operators/op_erf.py

Lines changed: 5 additions & 36 deletions
@@ -3,9 +3,11 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55
# pyre-unsafe
6-
from typing import Any, List
6+
from typing import List
77

88
import torch.fx
9+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
10+
import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore
911
from executorch.backends.arm.operators.node_visitor import (
1012
NodeVisitor,
1113
register_node_visitor,
@@ -27,49 +29,16 @@ def __init__(self, *args):
2729
def define_node(
2830
self,
2931
node: torch.fx.Node,
30-
tosa_graph: Any,
32+
tosa_graph: ts.TosaSerializer,
3133
inputs: List[TosaArg],
3234
output: TosaArg,
3335
) -> None:
34-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
35-
36-
if not (inputs[0].dtype == output.dtype):
37-
raise ValueError(
38-
"All inputs and output need same dtype."
39-
f"Got {inputs[0].dtype=}, {output.dtype=}"
40-
)
41-
if not (inputs[0].dtype == ts.DType.FP32):
42-
raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}")
43-
# MI lowering
44-
tosa_graph.addOperator(ts.TosaOp.Op().ERF, [inputs[0].name], [output.name])
45-
46-
47-
@register_node_visitor
48-
class ERFVisitor(NodeVisitor):
49-
target = "aten.erf.default"
50-
51-
# INT case handled by op_table
52-
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
53-
54-
def __init__(self, *args):
55-
super().__init__(*args)
56-
57-
def define_node(
58-
self,
59-
node: torch.fx.Node,
60-
tosa_graph: Any,
61-
inputs: List[TosaArg],
62-
output: TosaArg,
63-
) -> None:
64-
import serializer.tosa_serializer as ts
65-
6636
if not (inputs[0].dtype == output.dtype):
6737
raise ValueError(
6838
"All inputs and output need same dtype."
6939
f"Got {inputs[0].dtype=}, {output.dtype=}"
7040
)
7141
if not (inputs[0].dtype == ts.DType.FP32):
7242
raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}")
73-
7443
# MI lowering
75-
tosa_graph.addOperator(ts.TosaOp.Op().ERF, [inputs[0].name], [output.name])
44+
tosa_graph.addOperator(TosaOp.Op().ERF, [inputs[0].name], [output.name])

‎backends/arm/operators/op_exp.py

Lines changed: 3 additions & 35 deletions
@@ -4,8 +4,9 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import Any, List
7+
from typing import List
88

9+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
910
from executorch.backends.arm.operators.node_visitor import (
1011
NodeVisitor,
1112
register_node_visitor,
@@ -28,43 +29,10 @@ def __init__(self, *args):
2829
def define_node(
2930
self,
3031
node: Node,
31-
tosa_graph: Any,
32+
tosa_graph: ts.TosaSerializer,
3233
inputs: List[TosaArg],
3334
output: TosaArg,
3435
) -> None:
35-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
36-
37-
if len(node.all_input_nodes) != 1:
38-
raise ValueError(
39-
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
40-
)
41-
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
42-
raise ValueError(
43-
f"Input and output for {self.target} need to be FP32, got input dtype: "
44-
f"{inputs[0].dtype} and output dtype: {output.dtype}"
45-
)
46-
47-
tosa_graph.addOperator(ts.TosaOp.Op().EXP, [inputs[0].name], [output.name])
48-
49-
50-
@register_node_visitor
51-
class ExpVisitor(NodeVisitor):
52-
target = "aten.exp.default"
53-
54-
# BI case should be handled by op_table
55-
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
56-
57-
def __init__(self, *args):
58-
super().__init__(*args)
59-
60-
def define_node(
61-
self,
62-
node: Node,
63-
tosa_graph: Any,
64-
inputs: List[TosaArg],
65-
output: TosaArg,
66-
) -> None:
67-
import serializer.tosa_serializer as ts
6836

6937
if len(node.all_input_nodes) != 1:
7038
raise ValueError(

‎backends/arm/operators/op_log.py

Lines changed: 4 additions & 37 deletions
@@ -4,8 +4,9 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import Any, List
7+
from typing import List
88

9+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
910
from executorch.backends.arm.operators.node_visitor import (
1011
NodeVisitor,
1112
register_node_visitor,
@@ -16,7 +17,7 @@
1617

1718

1819
@register_node_visitor
19-
class LogVisitor_0_80_MI(NodeVisitor):
20+
class LogVisitor(NodeVisitor):
2021
target = "aten.log.default"
2122

2223
# BI case should be handled by op_table
@@ -28,44 +29,10 @@ def __init__(self, *args):
2829
def define_node(
2930
self,
3031
node: Node,
31-
tosa_graph: Any,
32-
inputs: List[TosaArg],
33-
output: TosaArg,
34-
) -> None:
35-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
36-
37-
if len(node.all_input_nodes) != 1:
38-
raise ValueError(
39-
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
40-
)
41-
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
42-
raise ValueError(
43-
f"Input and output for {self.target} need to be FP32, got input_dtype: "
44-
f"{inputs[0].dtype} and output_dtype: {output.dtype}"
45-
)
46-
47-
tosa_graph.addOperator(ts.TosaOp.Op().LOG, [inputs[0].name], [output.name])
48-
49-
50-
@register_node_visitor
51-
class LogVisitor(NodeVisitor):
52-
target = "aten.log.default"
53-
54-
# INT case should be handled by op_table
55-
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
56-
57-
def __init__(self, *args):
58-
super().__init__(*args)
59-
60-
def define_node(
61-
self,
62-
node: Node,
63-
tosa_graph: Any,
32+
tosa_graph: ts.TosaSerializer,
6433
inputs: List[TosaArg],
6534
output: TosaArg,
6635
) -> None:
67-
import serializer.tosa_serializer as ts
68-
6936
if len(node.all_input_nodes) != 1:
7037
raise ValueError(
7138
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"

‎backends/arm/operators/op_max_pool2d.py

Lines changed: 5 additions & 60 deletions
@@ -4,10 +4,12 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import Any, List
7+
from typing import List
88

99
import torch
1010

11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12+
1113
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1214
get_input_qparams,
1315
get_output_qparams,
@@ -17,29 +19,22 @@
1719
register_node_visitor,
1820
)
1921
from executorch.backends.arm.tosa_mapping import TosaArg
20-
from executorch.backends.arm.tosa_specification import TosaSpecification
2122

2223

2324
@register_node_visitor
24-
class MaxPool2dVisitor_0_80(NodeVisitor):
25+
class MaxPool2dVisitor(NodeVisitor):
2526
target = "aten.max_pool2d.default"
2627

27-
tosa_specs = [
28-
TosaSpecification.create_from_string("TOSA-0.80+BI"),
29-
TosaSpecification.create_from_string("TOSA-0.80+MI"),
30-
]
31-
3228
def __init__(self, *args):
3329
super().__init__(*args)
3430

3531
def define_node(
3632
self,
3733
node: torch.fx.Node,
38-
tosa_graph: Any,
34+
tosa_graph: ts.TosaSerializer,
3935
inputs: List[TosaArg],
4036
output: TosaArg,
4137
) -> None:
42-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4338

4439
input_tensor = inputs[0]
4540
kernel_size = inputs[1].special
@@ -85,53 +80,3 @@ def define_node(
8580
[output.name],
8681
attr,
8782
)
88-
89-
90-
@register_node_visitor
91-
class MaxPool2dVisitor(NodeVisitor):
92-
target = "aten.max_pool2d.default"
93-
94-
tosa_specs = [
95-
TosaSpecification.create_from_string("TOSA-1.0+INT"),
96-
TosaSpecification.create_from_string("TOSA-1.0+FP"),
97-
]
98-
99-
def __init__(self, *args):
100-
super().__init__(*args)
101-
102-
def define_node(
103-
self,
104-
node: torch.fx.Node,
105-
tosa_graph: Any,
106-
inputs: List[TosaArg],
107-
output: TosaArg,
108-
) -> None:
109-
110-
import serializer.tosa_serializer as ts # type: ignore
111-
112-
input_tensor = inputs[0]
113-
kernel_size = inputs[1].special
114-
stride = inputs[2].special
115-
116-
try:
117-
pad_size_list = inputs[3].special
118-
pad_size_list = [
119-
pad_size_list[0],
120-
pad_size_list[0],
121-
pad_size_list[1],
122-
pad_size_list[1],
123-
]
124-
except IndexError:
125-
pad_size_list = [0, 0, 0, 0]
126-
127-
attr = ts.TosaSerializerAttribute()
128-
attr.MaxPool2dAttribute(
129-
kernel=kernel_size, stride=stride, pad=pad_size_list, nan_mode=1
130-
)
131-
132-
tosa_graph.addOperator(
133-
ts.TosaOp.Op().MAX_POOL2D,
134-
[input_tensor.name],
135-
[output.name],
136-
attr,
137-
)

‎backends/arm/operators/op_permute.py

Lines changed: 3 additions & 43 deletions
@@ -5,10 +5,11 @@
55

66
# pyre-unsafe
77

8-
from typing import Any, List
8+
from typing import List
99

1010
import torch
1111

12+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1213
from executorch.backends.arm.operators.node_visitor import (
1314
NodeVisitor,
1415
register_node_visitor,
@@ -87,61 +88,20 @@ def transform_permutation_vector(permutation_vector: list[int], dim_order: list[
8788
return permutation_vector
8889

8990

90-
@register_node_visitor
91-
class PermuteVisitor_0_80(NodeVisitor):
92-
target = "aten.permute_copy.default"
93-
94-
tosa_specs = NodeVisitor.tosa_specs_0_80
95-
96-
def __init__(self, *args):
97-
super().__init__(*args)
98-
99-
def define_node(
100-
self,
101-
node: torch.fx.Node,
102-
tosa_graph: Any,
103-
inputs: List[TosaArg],
104-
output: TosaArg,
105-
) -> None:
106-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
107-
108-
# The permutation vector describes a permutation P in default Pytorch dim_order.
109-
# For rank 4, the default dim_order NCHW.
110-
# E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h)
111-
permutation_vector = inputs[1].special
112-
113-
if output.dim_order != tuple(range(len(output.dim_order))):
114-
# the permutation vector can't be used directly if we are not in NCHW dim_order.
115-
# Transform to dim_order.
116-
permutation_vector = transform_permutation_vector(
117-
permutation_vector, output.dim_order
118-
)
119-
120-
attr = ts.TosaSerializerAttribute()
121-
attr.TransposeAttribute(permutation_vector)
122-
tosa_graph.addOperator(
123-
ts.TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr
124-
)
125-
126-
12791
@register_node_visitor
12892
class PermuteVisitor(NodeVisitor):
12993
target = "aten.permute_copy.default"
13094

131-
tosa_specs = NodeVisitor.tosa_specs_1_00
132-
13395
def __init__(self, *args):
13496
super().__init__(*args)
13597

13698
def define_node(
13799
self,
138100
node: torch.fx.Node,
139-
tosa_graph: Any,
101+
tosa_graph: ts.TosaSerializer,
140102
inputs: List[TosaArg],
141103
output: TosaArg,
142104
) -> None:
143-
import serializer.tosa_serializer as ts
144-
145105
# The permutation vector describes a permutation P in default Pytorch dim_order.
146106
# For rank 4, the default dim_order NCHW.
147107
# E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h)

‎backends/arm/operators/op_pow.py

Lines changed: 3 additions & 45 deletions
@@ -5,8 +5,9 @@
55

66
# pyre-unsafe
77

8-
from typing import Any, List
8+
from typing import List
99

10+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
1011
from executorch.backends.arm.operators.node_visitor import (
1112
NodeVisitor,
1213
register_node_visitor,
@@ -30,53 +31,10 @@ def __init__(self, *args):
3031
def define_node(
3132
self,
3233
node: Node,
33-
tosa_graph: Any,
34+
tosa_graph: ts.TosaSerializer,
3435
inputs: List[TosaArg],
3536
output: TosaArg,
3637
) -> None:
37-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
38-
39-
if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
40-
raise ValueError(
41-
"All inputs and outputs need same dtype."
42-
f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}"
43-
)
44-
if inputs[0].dtype not in [ts.DType.FP32, ts.DType.FP16]:
45-
raise ValueError(
46-
f"All inputs need to be FP32 or FP16. Got {inputs[0].dtype}"
47-
)
48-
49-
tosa_graph.addOperator(
50-
ts.TosaOp.Op().POW,
51-
[
52-
inputs[0].name,
53-
inputs[1].name,
54-
],
55-
[output.name],
56-
None,
57-
)
58-
59-
60-
@register_node_visitor
61-
class PowVisitor(NodeVisitor):
62-
target = "aten.pow.Tensor_Tensor"
63-
64-
tosa_specs = [
65-
TosaSpecification.create_from_string("TOSA-1.0+FP"),
66-
]
67-
68-
def __init__(self, *args):
69-
super().__init__(*args)
70-
71-
def define_node(
72-
self,
73-
node: Node,
74-
tosa_graph: Any,
75-
inputs: List[TosaArg],
76-
output: TosaArg,
77-
) -> None:
78-
import serializer.tosa_serializer as ts
79-
8038
if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
8139
raise ValueError(
8240
"All inputs and outputs need same dtype."

‎backends/arm/operators/op_reciprocal.py

Lines changed: 3 additions & 38 deletions
@@ -4,10 +4,11 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import Any, List
7+
from typing import List
88

99
import torch
1010

11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1112
from executorch.backends.arm.operators.node_visitor import (
1213
NodeVisitor,
1314
register_node_visitor,
@@ -29,46 +30,10 @@ def __init__(self, *args):
2930
def define_node(
3031
self,
3132
node: torch.fx.Node,
32-
tosa_graph: Any,
33+
tosa_graph: ts.TosaSerializer,
3334
inputs: List[TosaArg],
3435
output: TosaArg,
3536
) -> None:
36-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
37-
38-
if len(node.all_input_nodes) != 1:
39-
raise ValueError(
40-
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
41-
)
42-
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
43-
raise ValueError(
44-
f"Input and output for {self.target} need to be FP32, got "
45-
f"{inputs[0].dtype=} and {output.dtype=}"
46-
)
47-
48-
tosa_graph.addOperator(
49-
ts.TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name]
50-
)
51-
52-
53-
@register_node_visitor
54-
class ReciprocalVisitor(NodeVisitor):
55-
target = "aten.reciprocal.default"
56-
57-
# INT case should be handled by op_table
58-
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
59-
60-
def __init__(self, *args):
61-
super().__init__(*args)
62-
63-
def define_node(
64-
self,
65-
node: torch.fx.Node,
66-
tosa_graph: Any,
67-
inputs: List[TosaArg],
68-
output: TosaArg,
69-
) -> None:
70-
import serializer.tosa_serializer as ts
71-
7237
if len(node.all_input_nodes) != 1:
7338
raise ValueError(
7439
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"

‎backends/arm/operators/op_rshift_tensor.py

Lines changed: 5 additions & 38 deletions
@@ -5,67 +5,34 @@
55

66
# pyre-unsafe
77

8-
from typing import Any, List
8+
from typing import List
99

1010
import torch
1111

12+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
1213
from executorch.backends.arm.operators.node_visitor import (
1314
NodeVisitor,
1415
register_node_visitor,
1516
)
1617
from executorch.backends.arm.tosa_mapping import TosaArg
17-
from executorch.backends.arm.tosa_specification import Tosa_0_80, Tosa_1_00
18-
19-
20-
@register_node_visitor
21-
class RshiftVisitor_0_80(NodeVisitor):
22-
target = "aten.bitwise_right_shift.Tensor"
23-
24-
tosa_specs = NodeVisitor.tosa_specs_0_80
25-
26-
def define_node(
27-
self,
28-
node: torch.fx.Node,
29-
tosa_graph: Any,
30-
inputs: List[TosaArg],
31-
output: TosaArg,
32-
) -> None:
33-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
34-
35-
attr = ts.TosaSerializerAttribute()
36-
round = False
37-
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
38-
# U55 only supports INT32 and round == True
39-
# TODO MLETORCH-525 Emulate round == False with different decomposition
40-
round = True
41-
attr.ArithmeticRightShiftAttribute(round=round)
42-
43-
tosa_graph.addOperator(
44-
ts.TosaOp.Op().ARITHMETIC_RIGHT_SHIFT,
45-
[inputs[0].name, inputs[1].name],
46-
[output.name],
47-
attr,
48-
)
18+
from executorch.backends.arm.tosa_specification import Tosa_0_80
4919

5020

5121
@register_node_visitor
5222
class RshiftVisitor(NodeVisitor):
5323
target = "aten.bitwise_right_shift.Tensor"
5424

55-
tosa_specs = NodeVisitor.tosa_specs_1_00
56-
5725
def define_node(
5826
self,
5927
node: torch.fx.Node,
60-
tosa_graph: Any,
28+
tosa_graph: ts.TosaSerializer,
6129
inputs: List[TosaArg],
6230
output: TosaArg,
6331
) -> None:
64-
import serializer.tosa_serializer as ts
6532

6633
attr = ts.TosaSerializerAttribute()
6734
round = False
68-
if isinstance(self.tosa_spec, Tosa_1_00) and "u55" in self.tosa_spec.extensions:
35+
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
6936
# U55 only supports INT32 and round == True
7037
# TODO MLETORCH-525 Emulate round == False with different decomposition
7138
round = True

‎backends/arm/operators/op_rsqrt.py

Lines changed: 3 additions & 36 deletions
@@ -4,10 +4,11 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import Any, List
7+
from typing import List
88

99
import torch
1010

11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1112
from executorch.backends.arm.operators.node_visitor import (
1213
NodeVisitor,
1314
register_node_visitor,
@@ -29,44 +30,10 @@ def __init__(self, *args):
2930
def define_node(
3031
self,
3132
node: torch.fx.Node,
32-
tosa_graph: Any,
33+
tosa_graph: ts.TosaSerializer,
3334
inputs: List[TosaArg],
3435
output: TosaArg,
3536
) -> None:
36-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
37-
38-
if len(node.all_input_nodes) != 1:
39-
raise ValueError(
40-
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
41-
)
42-
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
43-
raise ValueError(
44-
f"Input and output for {self.target} need to be FP32, got "
45-
f"{inputs[0].dtype=} and {output.dtype=}"
46-
)
47-
48-
tosa_graph.addOperator(ts.TosaOp.Op().RSQRT, [inputs[0].name], [output.name])
49-
50-
51-
@register_node_visitor
52-
class RsqrtVisitor(NodeVisitor):
53-
target = "aten.rsqrt.default"
54-
55-
# INT case should be handled by op_table
56-
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
57-
58-
def __init__(self, *args):
59-
super().__init__(*args)
60-
61-
def define_node(
62-
self,
63-
node: torch.fx.Node,
64-
tosa_graph: Any,
65-
inputs: List[TosaArg],
66-
output: TosaArg,
67-
) -> None:
68-
import serializer.tosa_serializer as ts
69-
7037
if len(node.all_input_nodes) != 1:
7138
raise ValueError(
7239
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"

‎backends/arm/operators/op_sigmoid.py

Lines changed: 3 additions & 35 deletions
@@ -4,8 +4,9 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import Any, List
7+
from typing import List
88

9+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
910
from executorch.backends.arm.operators.node_visitor import (
1011
NodeVisitor,
1112
register_node_visitor,
@@ -28,43 +29,10 @@ def __init__(self, *args):
2829
def define_node(
2930
self,
3031
node: Node,
31-
tosa_graph: Any,
32+
tosa_graph: ts.TosaSerializer,
3233
inputs: List[TosaArg],
3334
output: TosaArg,
3435
) -> None:
35-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
36-
37-
if len(node.all_input_nodes) != 1:
38-
raise ValueError(
39-
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
40-
)
41-
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
42-
raise ValueError(
43-
f"Input and output for {self.target} need to be FP32, got input_dtype: "
44-
f"{inputs[0].dtype} and output_dtype: {output.dtype}"
45-
)
46-
47-
tosa_graph.addOperator(ts.TosaOp.Op().SIGMOID, [inputs[0].name], [output.name])
48-
49-
50-
@register_node_visitor
51-
class SigmoidVisitor(NodeVisitor):
52-
target = "aten.sigmoid.default"
53-
54-
# INT case should be handled by op_table
55-
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
56-
57-
def __init__(self, *args):
58-
super().__init__(*args)
59-
60-
def define_node(
61-
self,
62-
node: Node,
63-
tosa_graph: Any,
64-
inputs: List[TosaArg],
65-
output: TosaArg,
66-
) -> None:
67-
import serializer.tosa_serializer as ts
6836

6937
if len(node.all_input_nodes) != 1:
7038
raise ValueError(

‎backends/arm/operators/op_sin.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

‎backends/arm/operators/op_tanh.py

Lines changed: 4 additions & 37 deletions
@@ -4,8 +4,9 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import Any, List
7+
from typing import List
88

9+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
910
from executorch.backends.arm.operators.node_visitor import (
1011
NodeVisitor,
1112
register_node_visitor,
@@ -16,7 +17,7 @@
1617

1718

1819
@register_node_visitor
19-
class TanhVisitor_0_80_MI(NodeVisitor):
20+
class TanhVisitor_080_MI(NodeVisitor):
2021
target = "aten.tanh.default"
2122

2223
# BI case should be handled by op_table
@@ -28,44 +29,10 @@ def __init__(self, *args):
2829
def define_node(
2930
self,
3031
node: Node,
31-
tosa_graph: Any,
32+
tosa_graph: ts.TosaSerializer,
3233
inputs: List[TosaArg],
3334
output: TosaArg,
3435
) -> None:
35-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
36-
37-
if len(node.all_input_nodes) != 1:
38-
raise ValueError(
39-
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
40-
)
41-
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
42-
raise ValueError(
43-
f"Input and output for {self.target} need to be FP32, got input_dtype: "
44-
f"{inputs[0].dtype} and output_dtype: {output.dtype}"
45-
)
46-
47-
tosa_graph.addOperator(ts.TosaOp.Op().TANH, [inputs[0].name], [output.name])
48-
49-
50-
@register_node_visitor
51-
class TanhVisitor(NodeVisitor):
52-
target = "aten.tanh.default"
53-
54-
# INT case should be handled by op_table
55-
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
56-
57-
def __init__(self, *args):
58-
super().__init__(*args)
59-
60-
def define_node(
61-
self,
62-
node: Node,
63-
tosa_graph: Any,
64-
inputs: List[TosaArg],
65-
output: TosaArg,
66-
) -> None:
67-
import serializer.tosa_serializer as ts
68-
6936
if len(node.all_input_nodes) != 1:
7037
raise ValueError(
7138
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"

‎backends/arm/operators/op_upsample_nearest2d.py

Lines changed: 4 additions & 78 deletions
@@ -4,10 +4,11 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import Any, List
7+
from typing import List
88

99
import torch
1010

11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1112
from executorch.backends.arm.operators.node_visitor import (
1213
NodeVisitor,
1314
register_node_visitor,
@@ -19,23 +20,19 @@
1920

2021

2122
@register_node_visitor
22-
class UpsampleNearest2dVisitor_0_80(NodeVisitor):
23+
class UpsampleNearest2dVisitor(NodeVisitor):
2324
target = "aten.upsample_nearest2d.vec"
2425

25-
tosa_specs = NodeVisitor.tosa_specs_0_80
26-
2726
def __init__(self, *args):
2827
super().__init__(*args)
2928

3029
def define_node(
3130
self,
3231
node: torch.fx.Node,
33-
tosa_graph: Any,
32+
tosa_graph: ts.TosaSerializer,
3433
inputs: List[TosaArg],
3534
output: TosaArg,
3635
) -> None:
37-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
38-
3936
assert (
4037
inputs[0].shape is not None and output.shape is not None
4138
), "Only static shapes are supported"
@@ -70,74 +67,3 @@ def in_int16_range(x):
7067
tosa_graph.addOperator(
7168
ts.TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr
7269
)
73-
74-
75-
@register_node_visitor
76-
class UpsampleNearest2dVisitor(NodeVisitor):
77-
target = "aten.upsample_nearest2d.vec"
78-
79-
tosa_specs = NodeVisitor.tosa_specs_1_00
80-
81-
def __init__(self, *args):
82-
super().__init__(*args)
83-
84-
def define_node(
85-
self,
86-
node: torch.fx.Node,
87-
tosa_graph: Any,
88-
inputs: List[TosaArg],
89-
output: TosaArg,
90-
) -> None:
91-
import serializer.tosa_serializer as ts
92-
93-
assert (
94-
inputs[0].shape is not None and output.shape is not None
95-
), "Only static shapes are supported"
96-
97-
# tosa_shape output is NHWC, take HW
98-
input_size_yx = torch.tensor(
99-
tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3]
100-
)
101-
# Ignore scale and size parameters, directly use the output size as
102-
# we only support static shapes currently
103-
output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3])
104-
105-
scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
106-
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True
107-
)
108-
109-
def in_int16_range(x):
110-
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)
111-
112-
assert in_int16_range(scale_n_yx)
113-
assert in_int16_range(scale_d_yx)
114-
assert in_int16_range(border_yx)
115-
116-
scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]]
117-
scales_tensor = tosa_graph.addConst(
118-
[len(scales)], ts.DType.SHAPE, scales, node.name + "_scales"
119-
)
120-
offset = offset_yx.tolist()
121-
offset_tensor = tosa_graph.addConst(
122-
[len(offset)], ts.DType.SHAPE, offset, node.name + "_offset"
123-
)
124-
border = border_yx.tolist()
125-
border_tensor = tosa_graph.addConst(
126-
[len(border)], ts.DType.SHAPE, border, node.name + "_border"
127-
)
128-
attr = ts.TosaSerializerAttribute()
129-
attr.ResizeAttribute(
130-
mode=ResizeMode.NEAREST,
131-
)
132-
133-
tosa_graph.addOperator(
134-
ts.TosaOp.Op().RESIZE,
135-
[
136-
inputs[0].name,
137-
scales_tensor.name,
138-
offset_tensor.name,
139-
border_tensor.name,
140-
],
141-
[output.name],
142-
attr,
143-
)

‎backends/arm/operators/op_where.py

Lines changed: 37 additions & 130 deletions
@@ -3,7 +3,10 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from typing import Any, List, Sequence
6+
from typing import List, Sequence
7+
8+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
9+
import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore
710

811
from executorch.backends.arm.operators.node_visitor import (
912
NodeVisitor,
@@ -14,163 +17,69 @@
1417
from torch.fx import Node
1518

1619

17-
@register_node_visitor
18-
class WhereVisitor_0_80_BI(NodeVisitor):
19-
target = "aten.where.self"
20-
21-
tosa_specs = [
22-
TosaSpecification.create_from_string("TOSA-0.80+BI"),
23-
]
24-
25-
def __init__(self, *args):
26-
super().__init__(*args)
27-
28-
def _add_node_to_tosa_graph(
29-
self,
30-
tosa_graph: Any,
31-
inputs: List[TosaArg],
32-
output: TosaArg,
33-
supported_dtypes: Sequence,
34-
) -> None:
35-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
36-
37-
if len(inputs) != 3:
38-
raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}")
39-
40-
if inputs[0].dtype is not ts.DType.BOOL:
41-
raise ValueError("Input 0 needs to have dtype BOOL")
42-
if inputs[1].dtype != inputs[2].dtype:
20+
def _add_node_to_tosa_graph(
21+
tosa_graph: ts.TosaSerializer,
22+
inputs: List[TosaArg],
23+
output: TosaArg,
24+
supported_dtypes: Sequence,
25+
) -> None:
26+
if len(inputs) != 3:
27+
raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}")
28+
29+
if inputs[0].dtype is not ts.DType.BOOL:
30+
raise ValueError("Input 0 needs to have dtype BOOL")
31+
if inputs[1].dtype != inputs[2].dtype:
32+
raise ValueError(
33+
"Non-condition tensors must have same data type, got "
34+
f"{inputs[1].dtype} and {inputs[2].dtype}"
35+
)
36+
for input_ in inputs[1:]:
37+
if input_.dtype not in supported_dtypes:
4338
raise ValueError(
44-
"Non-condition tensors must have same data type, got "
45-
f"{inputs[1].dtype} and {inputs[2].dtype}"
39+
f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}"
4640
)
47-
for input_ in inputs[1:]:
48-
if input_.dtype not in supported_dtypes:
49-
raise ValueError(
50-
f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}"
51-
)
52-
53-
tosa_graph.addOperator(
54-
ts.TosaOp.Op().SELECT,
55-
[inputs[0].name, inputs[1].name, inputs[2].name],
56-
[output.name],
57-
None,
58-
)
5941

60-
def define_node(
61-
self,
62-
node: Node,
63-
tosa_graph: Any,
64-
inputs: List[TosaArg],
65-
output: TosaArg,
66-
) -> None:
67-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
68-
69-
bi_supported_dtypes = [
70-
ts.DType.INT8,
71-
ts.DType.INT16,
72-
ts.DType.INT32,
73-
ts.DType.BOOL,
74-
]
75-
self._add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes)
42+
tosa_graph.addOperator(
43+
TosaOp.Op().SELECT,
44+
[inputs[0].name, inputs[1].name, inputs[2].name],
45+
[output.name],
46+
None,
47+
)
7648

7749

7850
@register_node_visitor
79-
class WhereVisitor_0_80_MI(WhereVisitor_0_80_BI):
80-
81-
tosa_specs = [
82-
TosaSpecification.create_from_string("TOSA-0.80+MI"),
83-
]
84-
85-
def __init__(self, *args):
86-
super().__init__(*args)
87-
88-
def define_node(
89-
self,
90-
node: Node,
91-
tosa_graph: Any,
92-
inputs: List[TosaArg],
93-
output: TosaArg,
94-
) -> None:
95-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
96-
97-
mi_supported_dtypes = [
98-
ts.DType.FP16,
99-
ts.DType.FP32,
100-
ts.DType.INT8,
101-
ts.DType.INT16,
102-
ts.DType.INT32,
103-
ts.DType.BOOL,
104-
]
105-
self._add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes)
106-
107-
108-
@register_node_visitor
109-
class WhereVisitor_INT(NodeVisitor):
51+
class WhereVisitor_080_BI(NodeVisitor):
11052
target = "aten.where.self"
11153

11254
tosa_specs = [
113-
TosaSpecification.create_from_string("TOSA-1.0+INT"),
55+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
11456
]
11557

11658
def __init__(self, *args):
11759
super().__init__(*args)
11860

119-
def _add_node_to_tosa_graph(
120-
self,
121-
tosa_graph: Any,
122-
inputs: List[TosaArg],
123-
output: TosaArg,
124-
supported_dtypes: Sequence,
125-
) -> None:
126-
import serializer.tosa_serializer as ts
127-
128-
if len(inputs) != 3:
129-
raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}")
130-
131-
if inputs[0].dtype is not ts.DType.BOOL:
132-
raise ValueError("Input 0 needs to have dtype BOOL")
133-
if inputs[1].dtype != inputs[2].dtype:
134-
raise ValueError(
135-
"Non-condition tensors must have same data type, got "
136-
f"{inputs[1].dtype} and {inputs[2].dtype}"
137-
)
138-
for input_ in inputs[1:]:
139-
if input_.dtype not in supported_dtypes:
140-
raise ValueError(
141-
f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}"
142-
)
143-
144-
tosa_graph.addOperator(
145-
ts.TosaOp.Op().SELECT,
146-
[inputs[0].name, inputs[1].name, inputs[2].name],
147-
[output.name],
148-
None,
149-
)
150-
15161
def define_node(
15262
self,
15363
node: Node,
154-
tosa_graph: Any,
64+
tosa_graph: ts.TosaSerializer,
15565
inputs: List[TosaArg],
15666
output: TosaArg,
15767
) -> None:
158-
import serializer.tosa_serializer as ts
15968

16069
bi_supported_dtypes = [
16170
ts.DType.INT8,
16271
ts.DType.INT16,
16372
ts.DType.INT32,
16473
ts.DType.BOOL,
16574
]
166-
self._add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes)
75+
_add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes)
16776

16877

16978
@register_node_visitor
170-
class WhereVisitor_FP(WhereVisitor_INT):
79+
class WhereVisitor_080_MI(WhereVisitor_080_BI):
17180

17281
tosa_specs = [
173-
TosaSpecification.create_from_string("TOSA-1.0+FP"),
82+
TosaSpecification.create_from_string("TOSA-0.80+MI"),
17483
]
17584

17685
def __init__(self, *args):
@@ -179,12 +88,10 @@ def __init__(self, *args):
17988
def define_node(
18089
self,
18190
node: Node,
182-
tosa_graph: Any,
91+
tosa_graph: ts.TosaSerializer,
18392
inputs: List[TosaArg],
18493
output: TosaArg,
18594
) -> None:
186-
import serializer.tosa_serializer as ts
187-
18895
mi_supported_dtypes = [
18996
ts.DType.FP16,
19097
ts.DType.FP32,
@@ -193,4 +100,4 @@ def define_node(
193100
ts.DType.INT32,
194101
ts.DType.BOOL,
195102
]
196-
self._add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes)
103+
_add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes)

‎backends/arm/operators/ops_binary.py

Lines changed: 4 additions & 47 deletions
@@ -5,62 +5,33 @@
55

66
# pyre-unsafe
77

8-
from typing import Any, List
8+
from typing import List
99

1010
import torch
1111
import torch.fx
1212

13+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
14+
1315
from executorch.backends.arm.operators.node_visitor import (
1416
NodeVisitor,
1517
register_node_visitor,
1618
)
1719
from executorch.backends.arm.tosa_mapping import TosaArg
1820

1921

20-
def binary_operator_factory_0_80(bw_target: str, tosa_op):
21-
"""Creates and registers NodeVisitors for operators that have two inputs and map directly to a TOSA op."""
22-
23-
class BinaryOperator_0_80(NodeVisitor):
24-
target = bw_target
25-
tosa_specs = NodeVisitor.tosa_specs_0_80
26-
27-
def define_node(
28-
self,
29-
node: torch.fx.Node,
30-
tosa_graph: Any,
31-
inputs: List[TosaArg],
32-
output: TosaArg,
33-
) -> None:
34-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore # noqa: F401
35-
36-
if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
37-
raise ValueError(
38-
"All inputs and outputs need same dtype."
39-
f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}."
40-
)
41-
42-
tosa_graph.addOperator(
43-
tosa_op, [inputs[0].name, inputs[1].name], [output.name]
44-
)
45-
46-
register_node_visitor(BinaryOperator_0_80)
47-
48-
4922
def binary_operator_factory(bw_target: str, tosa_op):
5023
"""Creates and registers NodeVisitors for operators that have two inputs and map directly to a TOSA op."""
5124

5225
class BinaryOperator(NodeVisitor):
5326
target = bw_target
54-
tosa_specs = NodeVisitor.tosa_specs_1_00
5527

5628
def define_node(
5729
self,
5830
node: torch.fx.Node,
59-
tosa_graph: Any,
31+
tosa_graph: ts.TosaSerializer,
6032
inputs: List[TosaArg],
6133
output: TosaArg,
6234
) -> None:
63-
import serializer.tosa_serializer as ts # type: ignore # noqa: F401
6435

6536
if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
6637
raise ValueError(
@@ -75,20 +46,6 @@ def define_node(
7546
register_node_visitor(BinaryOperator)
7647

7748

78-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
79-
80-
binary_operator_factory_0_80("aten.bitwise_and.Tensor", ts.TosaOp.Op().BITWISE_AND)
81-
binary_operator_factory_0_80("aten.bitwise_xor.Tensor", ts.TosaOp.Op().BITWISE_XOR)
82-
binary_operator_factory_0_80("aten.bitwise_or.Tensor", ts.TosaOp.Op().BITWISE_OR)
83-
binary_operator_factory_0_80("aten.logical_and.default", ts.TosaOp.Op().LOGICAL_AND)
84-
binary_operator_factory_0_80("aten.logical_xor.default", ts.TosaOp.Op().LOGICAL_XOR)
85-
binary_operator_factory_0_80("aten.logical_or.default", ts.TosaOp.Op().LOGICAL_OR)
86-
binary_operator_factory_0_80(
87-
"aten.bitwise_left_shift.Tensor", ts.TosaOp.Op().LOGICAL_LEFT_SHIFT
88-
)
89-
90-
import serializer.tosa_serializer as ts # type: ignore
91-
9249
binary_operator_factory("aten.bitwise_and.Tensor", ts.TosaOp.Op().BITWISE_AND)
9350
binary_operator_factory("aten.bitwise_xor.Tensor", ts.TosaOp.Op().BITWISE_XOR)
9451
binary_operator_factory("aten.bitwise_or.Tensor", ts.TosaOp.Op().BITWISE_OR)

‎backends/arm/operators/ops_unary.py

Lines changed: 3 additions & 50 deletions
@@ -4,10 +4,11 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import Any, List
7+
from typing import List
88

99
import torch.fx
1010

11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1112
from executorch.backends.arm.operators.node_visitor import (
1213
NodeVisitor,
1314
register_node_visitor,
@@ -16,44 +17,6 @@
1617
from executorch.backends.arm.tosa_mapping import TosaArg
1718

1819

19-
def unary_operator_factory_0_80(unary_target: str, tosa_op):
20-
"Creates and registers NodeVisitors for operations that have one input and map directly into a TOSA op."
21-
22-
# Some TOSA unary operators only support float
23-
fp_only_ops = ["aten.floor.default"]
24-
25-
class UnaryOperator_0_80(NodeVisitor):
26-
target = unary_target
27-
tosa_specs = NodeVisitor.tosa_specs_0_80
28-
29-
def __init__(self, *args):
30-
super().__init__(*args)
31-
32-
def define_node(
33-
self,
34-
node: torch.fx.Node,
35-
tosa_graph: Any,
36-
inputs: List[TosaArg],
37-
output: TosaArg,
38-
) -> None:
39-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore # noqa: F401
40-
41-
if not (inputs[0].dtype == output.dtype):
42-
raise ValueError(
43-
"All inputs and output need same dtype."
44-
f"Got {inputs[0].dtype=}, {output.dtype=}"
45-
)
46-
47-
if self.target in fp_only_ops and not (inputs[0].dtype == ts.DType.FP32):
48-
raise ValueError(
49-
"All inputs need to be FP32." f"Got {inputs[0].dtype=}"
50-
)
51-
52-
tosa_graph.addOperator(tosa_op, [inputs[0].name], [output.name])
53-
54-
register_node_visitor(UnaryOperator_0_80)
55-
56-
5720
def unary_operator_factory(unary_target: str, tosa_op):
5821
"Creates and registers NodeVisitors for operations that have one input and map directly into a TOSA op."
5922

@@ -62,19 +25,17 @@ def unary_operator_factory(unary_target: str, tosa_op):
6225

6326
class UnaryOperator(NodeVisitor):
6427
target = unary_target
65-
tosa_specs = NodeVisitor.tosa_specs_1_00
6628

6729
def __init__(self, *args):
6830
super().__init__(*args)
6931

7032
def define_node(
7133
self,
7234
node: torch.fx.Node,
73-
tosa_graph: Any,
35+
tosa_graph: ts.TosaSerializer,
7436
inputs: List[TosaArg],
7537
output: TosaArg,
7638
) -> None:
77-
import serializer.tosa_serializer as ts # type: ignore # noqa: F401
7839

7940
if not (inputs[0].dtype == output.dtype):
8041
raise ValueError(
@@ -92,14 +53,6 @@ def define_node(
9253
register_node_visitor(UnaryOperator)
9354

9455

95-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
96-
97-
unary_operator_factory_0_80("aten.ceil.default", ts.TosaOp.Op().CEIL)
98-
unary_operator_factory_0_80("aten.floor.default", ts.TosaOp.Op().FLOOR)
99-
unary_operator_factory_0_80("aten.logical_not.default", ts.TosaOp.Op().LOGICAL_NOT)
100-
101-
import serializer.tosa_serializer as ts # type: ignore
102-
10356
unary_operator_factory("aten.ceil.default", ts.TosaOp.Op().CEIL)
10457
unary_operator_factory("aten.floor.default", ts.TosaOp.Op().FLOOR)
10558
unary_operator_factory("aten.logical_not.default", ts.TosaOp.Op().LOGICAL_NOT)

‎backends/arm/quantizer/quantization_annotator.py

Lines changed: 0 additions & 2 deletions
@@ -177,8 +177,6 @@ def _match_pattern(
     torch.ops.aten.reciprocal.default,
     torch.ops.aten.rsqrt.default,
     torch.ops.aten.sigmoid.default,
-    torch.ops.aten.cos.default,
-    torch.ops.aten.sin.default,
     torch.ops.aten.tanh.default,
     torch.ops.aten.sum.dim_IntList,
     torch.ops.aten.hardsigmoid.default,
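
Dropping aten.cos and aten.sin from this annotation list means the quantizer no longer marks them for int8 quantization. As a rough, self-contained illustration of why such one-input ops are straightforward to quantize once annotated, the sketch below builds an int8 lookup table for sin from assumed input/output scales and zero-points; the helper name and parameter values are made up for illustration and this is not the backend's actual pass.

import math

def build_int8_sin_table(in_scale, in_zp, out_scale, out_zp):
    """Precompute sin() for every possible int8 input code."""
    table = []
    for q in range(-128, 128):
        x = (q - in_zp) * in_scale                # dequantize the int8 code
        y = math.sin(x)                           # reference float op
        q_out = round(y / out_scale) + out_zp     # requantize the result
        table.append(max(-128, min(127, q_out)))  # clamp to the int8 range
    return table

# Assumed example parameters: inputs span roughly [-pi, pi), outputs [-1, 1].
table = build_int8_sin_table(in_scale=math.pi / 128, in_zp=0,
                             out_scale=1 / 127, out_zp=0)
print(table[0], table[64])  # sin(-pi) ~ 0, sin(-pi/2) = -1 -> -127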

‎backends/arm/test/misc/test_multiple_delegates.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ def get_inputs(self):
 
         def forward(self, x: torch.Tensor, y: torch.Tensor):
             z = x + y
-            s = torch.tan(z)
+            s = torch.sin(z)
             return s * z
 
     def test_tosa_MI(self):
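
In both versions this test picks an op the partitioner does not lower (tan while sin/cos were supported, sin again after this revert), so the graph still splits into more than one delegate. The dependency-free sketch below illustrates that partitioning effect; the op names and supported set are illustrative, not the Arm partitioner.

def count_delegates(ops, supported):
    """Count contiguous runs of supported ops; each run becomes one delegate."""
    runs, in_run = 0, False
    for op in ops:
        if op in supported:
            if not in_run:
                runs += 1
                in_run = True
        else:
            in_run = False
    return runs

graph = ["add", "sin", "mul"]                         # mirrors x + y -> sin -> mul
print(count_delegates(graph, {"add", "mul", "sin"}))  # 1 delegate if sin is lowered
print(count_delegates(graph, {"add", "mul"}))         # 2 delegates if it is not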

‎backends/arm/test/ops/test_constant_pad_nd.py

Lines changed: 124 additions & 55 deletions
@@ -2,74 +2,143 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
 #
 # Test the pad_constant_nd op which pads the input tensor at specific dimension(s).
 #
+import unittest
 from typing import Tuple
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from executorch.backends.arm.test import common
-from executorch.backends.arm.test.tester.test_pipeline import (
-    TosaPipelineBI,
-    TosaPipelineMI,
-)
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from parameterized import parameterized
+
+test_data_suite = [
+    ("4dim_last1dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1),
+    ("4dim_last2dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2),
+    ("4dim_last3dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3),
+    ("4dim_last4dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4),
+    ("3dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1),
+    ("3dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2),
+    ("3dim_last3dim", torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3),
+    ("2dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0), 1),
+    ("2dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1), 2),
+]
+
+
+class TestConstantPadND(unittest.TestCase):
+    """Tests pad."""
+
+    class ConstantPadND(torch.nn.Module):
+        def __init__(self, pad: Tuple, value: float | None = None):
+            super().__init__()
+            self.dim = len(pad) // 2
+            self.value = value
+            in_channels = 1
+            # Only apply conv2d when the input dim = 4.
+            if self.dim == 4:
+                in_channels += pad[-3] + pad[-4]
+
+                self.conv2d = nn.Conv2d(
+                    in_channels=in_channels,
+                    out_channels=3,
+                    kernel_size=3,
+                    bias=True,
+                    stride=(2, 2),
+                    padding=0,
+                )
 
-aten_op = "torch.ops.aten.pad.default"
-exir_op = "executorch_exir_dialects_edge__ops_aten_pad_default"
-input_t1 = Tuple[torch.Tensor]  # Input x
-test_data_suite = {
-    "4dim_last1dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1),
-    "4dim_last2dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2),
-    "4dim_last3dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3),
-    "4dim_last4dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4),
-    "3dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1),
-    "3dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2),
-    "3dim_last3dim": (torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3),
-    "2dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0), 1),
-    "2dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1), 2),
-}
-"""Tests pad."""
+                in_channels = 3
+                in_channels += pad[-3] + pad[-4]
+                self.conv2d_1 = nn.Conv2d(
+                    in_channels=in_channels,
+                    out_channels=3,
+                    kernel_size=3,
+                    bias=True,
+                    padding="same",
+                )
 
+            nonzero_idx = len(pad)
+            for i in range(0, len(pad), 2):
+                if pad[i] + pad[i + 1] == 0:
+                    nonzero_idx = i
+                    break
+            self.pad = pad[:nonzero_idx]
+            self.relu = nn.ReLU()
+            self.sigmoid = nn.Sigmoid()
 
-class ConstantPadND(torch.nn.Module):
-    def __init__(self, pad: Tuple, value: float | None = None):
-        super().__init__()
-        self.value = value
-        nonzero_idx = len(pad)
-        for i in range(0, len(pad), 2):
-            if pad[i] + pad[i + 1] == 0:
-                nonzero_idx = i
-                break
-        self.pad = pad[:nonzero_idx]
+        def forward(self, x: torch.Tensor):
+            x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
+            if self.dim == 4:
+                x = self.conv2d(x)
+            x = self.relu(x)
 
-    def forward(self, x: torch.Tensor):
-        x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
-        return x
+            x = F.pad(x, pad=self.pad, mode="constant", value=self.value)
+            if self.dim == 4:
+                x = self.conv2d_1(x)
+            x = self.sigmoid(x)
+            return x
 
+    def _test_constant_pad_nd_tosa_MI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
+            )
+            .export()
+            .check_count({"torch.ops.aten.pad.default": 2})
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
 
-@common.parametrize(
-    "test_data",
-    test_data_suite,
-)
-def test_constant_pad_nd_tosa_MI(test_data: Tuple):
-    test_data, padding, value = test_data
-    pipeline = TosaPipelineMI[input_t1](
-        ConstantPadND(padding, value),
-        (test_data,),
-        aten_op,
-        exir_op,
-    )
-    pipeline.run()
+    def _test_constant_pad_nd_tosa_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
+            )
+            .quantize()
+            .export()
+            .check_count({"torch.ops.aten.pad.default": 2})
+            .to_edge()
+            .partition()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
+        )
 
+    @parameterized.expand(test_data_suite)
+    def test_constant_pad_nd_tosa_MI(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        padding: Tuple,
+        value: float | None = None,
+    ):
+        self._test_constant_pad_nd_tosa_MI_pipeline(
+            self.ConstantPadND(padding, value), (test_data,)
+        )
 
-@common.parametrize("test_data", test_data_suite)
-def test_constant_pad_nd_tosa_BI(test_data: Tuple):
-    test_data, padding, value = test_data
-    pipeline = TosaPipelineBI[input_t1](
-        ConstantPadND(padding, value),
-        (test_data,),
-        aten_op,
-        exir_op,
-    )
-    pipeline.run()
+    @parameterized.expand(test_data_suite)
+    def test_constant_pad_nd_tosa_BI(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        padding: Tuple,
+        value: float | None = None,
+    ):
+        self._test_constant_pad_nd_tosa_BI_pipeline(
+            self.ConstantPadND(padding, value), (test_data,)
+        )
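
As a quick reminder of the pad-tuple convention these test cases rely on, here is a torch-only example (unrelated to the Arm test harness): each pair in the tuple pads one dimension, starting from the last.

import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 16, 16)
# Pairs apply to trailing dims first: (left, right) for dim -1,
# then (top, bottom) for dim -2, and so on.
y = F.pad(x, pad=(1, 1, 0, 2), mode="constant", value=0.0)
print(y.shape)  # torch.Size([1, 1, 18, 18])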

‎backends/arm/test/ops/test_conv_constant_pad_nd.py

Lines changed: 0 additions & 114 deletions
This file was deleted.

‎backends/arm/test/ops/test_cos.py

Lines changed: 0 additions & 83 deletions
This file was deleted.

‎backends/arm/test/ops/test_sin.py

Lines changed: 0 additions & 83 deletions
This file was deleted.

0 commit comments
