Skip to content

Commit 7889c0f

Browse files
committed
Arm backend: Add check for unsupported dtypes on Ethos-U55
Move all Ethos-U55 support checks into a single file. Signed-off-by: Erik Lundell <[email protected]> Change-Id: Ib6444abdbe1cc15d7ec1a91efa15362022f57895
1 parent 77c35f5 commit 7889c0f

File tree

5 files changed

+198
-125
lines changed

5 files changed

+198
-125
lines changed

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from . import ( # noqa
99
convolution_support,
10+
ethos_u55_support,
1011
minmax_support,
1112
pool_2d_support,
1213
reduce_sum_support,
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import typing
9+
10+
import torch
11+
import torch.fx as fx
12+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
13+
from executorch.backends.arm._passes.insert_table_ops import TableOps
14+
from executorch.exir.backend.utils import WhyNoPartitionReporter
15+
16+
from executorch.exir.dialects._ops import ops as exir_ops
17+
from torch.fx.passes.operator_support import OperatorSupportBase
18+
19+
20+
class EthosU55DtypeSupport(OperatorSupportBase):
21+
22+
def __init__(self, reporter: WhyNoPartitionReporter):
23+
super().__init__()
24+
self.reporter = reporter
25+
26+
targeted_ops_i8_i16_i32 = [
27+
exir_ops.edge.aten.cat.default,
28+
exir_ops.edge.aten.repeat.default,
29+
exir_ops.edge.aten.constant_pad_nd.default,
30+
exir_ops.edge.aten.view.default,
31+
exir_ops.edge.aten.permute.default,
32+
]
33+
34+
target_ops_i8 = tuple(TableOps.included_ops())
35+
36+
def _try_determine_dtype(self, node: fx.Node) -> torch.dtype | None:
37+
"""Attempt to figure out the quantized data type of node. On failure, return None."""
38+
39+
dtype = get_first_fake_tensor(node).dtype
40+
if not dtype.is_floating_point:
41+
return dtype
42+
43+
if (
44+
node.target
45+
is exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
46+
):
47+
return get_first_fake_tensor(node.all_input_nodes[0]).dtype
48+
49+
if len(node.users) == 0:
50+
return None
51+
52+
q_node = list(node.users)[0]
53+
if (
54+
q_node.target
55+
is exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
56+
):
57+
return typing.cast(torch.dtype, q_node.args[-1])
58+
59+
# We can't easily figure out dtype, return None
60+
return None
61+
62+
def is_node_supported( # noqa: C901
63+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
64+
) -> bool:
65+
66+
dtype = self._try_determine_dtype(node)
67+
if dtype is None:
68+
# If we couldn't determine dtype, just return ok.
69+
return True
70+
71+
if node.target in self.targeted_ops_i8_i16_i32:
72+
if dtype not in (torch.int8, torch.int16, torch.int32):
73+
self.reporter.report_reject(
74+
node, f"Unsupported dtype {dtype} (Supports i8, i16, i32)."
75+
)
76+
return False
77+
78+
if node.target in self.target_ops_i8:
79+
if dtype not in (torch.int8,):
80+
self.reporter.report_reject(
81+
node, f"Unsupported dtype {dtype} (Supports i8)."
82+
)
83+
return False
84+
85+
if node.target == exir_ops.edge.aten.convolution.default:
86+
ifm, weight = node.all_input_nodes[0:2]
87+
ifm_dtype = self._try_determine_dtype(ifm)
88+
if ifm_dtype is not None and ifm_dtype not in (torch.int8, torch.int16):
89+
self.reporter.report_reject(
90+
node, f"Unsupported input dtype {dtype} (Supports i8, i16)."
91+
)
92+
return False
93+
weight_dtype = self._try_determine_dtype(weight)
94+
if weight_dtype is not None and weight_dtype not in (torch.int8,):
95+
self.reporter.report_reject(
96+
node, f"Unsupported weight dtype {dtype} (Supports i8)."
97+
)
98+
return False
99+
if len(node.all_input_nodes) > 2:
100+
bias = node.all_input_nodes[2]
101+
bias_dtype = self._try_determine_dtype(bias)
102+
if bias_dtype is not None and bias_dtype not in (torch.int32,):
103+
self.reporter.report_reject(
104+
node, f"Unsupported bias dtype {dtype} (Supports i32)."
105+
)
106+
return False
107+
108+
if node.target in (
109+
exir_ops.edge.aten.mm.default,
110+
exir_ops.edge.aten.bmm.default,
111+
):
112+
for input_node in node.all_input_nodes:
113+
dtype = self._try_determine_dtype(input_node)
114+
if dtype is not None and dtype != torch.int8:
115+
self.reporter.report_reject(
116+
input_node,
117+
f"Input {input_node.name} has unsupported dtype {dtype} (Supports i8).",
118+
)
119+
return False
120+
121+
return True
122+
123+
124+
class EthosU55NotSupported(OperatorSupportBase):
125+
"""
126+
Certain operators are not supported on U55. These are listed in `unsupported_ops`.
127+
The comment mentions the unsupported TOSA operator that the aten operator maps to where it is not obvious.
128+
For unimplemented operators, this is the anticipated mapping, and it might be incorrect.
129+
"""
130+
131+
unsupported_ops = [
132+
exir_ops.edge.aten.any.default, # REDUCE_ANY
133+
exir_ops.edge.aten.any.dim, # REDUCE_ANY
134+
exir_ops.edge.aten.any.dims, # REDUCE_ANY
135+
exir_ops.edge.aten.bitwise_and.Tensor,
136+
exir_ops.edge.aten.bitwise_or.Tensor,
137+
exir_ops.edge.aten.bitwise_xor.Tensor,
138+
exir_ops.edge.aten.bitwise_not,
139+
exir_ops.edge.aten.logical_and.default,
140+
exir_ops.edge.aten.logical_or.default,
141+
exir_ops.edge.aten.logical_xor.default,
142+
exir_ops.edge.aten.logical_not.default,
143+
exir_ops.edge.aten.amax.default, # REDUCE_MAX
144+
exir_ops.edge.aten.amin.default, # REDUCE_MIN
145+
exir_ops.edge.aten.eq.Tensor,
146+
exir_ops.edge.aten.eq.Scalar,
147+
exir_ops.edge.aten.ge.Tensor,
148+
exir_ops.edge.aten.gt.Tensor,
149+
exir_ops.edge.aten.le.Tensor,
150+
exir_ops.edge.aten.lt.Tensor,
151+
exir_ops.edge.aten.flip.default, # REVERSE
152+
exir_ops.edge.aten.grid_sampler_2d, # GATHER
153+
exir_ops.edge.aten.scatter.src,
154+
exir_ops.edge.aten.scatter.value,
155+
exir_ops.edge.aten.select_scatter.default,
156+
exir_ops.edge.aten.scatter_reduce.two,
157+
exir_ops.edge.aten.scatter_add.default,
158+
exir_ops.edge.aten.upsample_nearest2d.vec, # RESIZE
159+
exir_ops.edge.aten.upsample_bilinear2d.vec, # RESIZE
160+
exir_ops.edge.aten.reflection_pad1d.default, # REVERSE
161+
exir_ops.edge.aten.reflection_pad2d.default, # REVERSE
162+
exir_ops.edge.aten.reflection_pad3d.default, # REVERSE
163+
]
164+
165+
def __init__(self, reporter: WhyNoPartitionReporter):
166+
self.reporter = reporter
167+
168+
def is_node_supported(
169+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
170+
) -> bool:
171+
172+
if node.target in self.unsupported_ops:
173+
self.reporter.report_reject(node, "Op is not supported on U55.")
174+
return False
175+
176+
return True

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 5 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
1919
FuseQuantizedActivationPass,
2020
)
21+
from executorch.backends.arm.operator_support.ethos_u55_support import (
22+
EthosU55DtypeSupport,
23+
EthosU55NotSupported,
24+
)
2125
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
2226
from executorch.exir import ExportedProgram
2327
from executorch.exir.backend.utils import WhyNoPartitionReporter
@@ -118,6 +122,7 @@ def tosa_support_factory(
118122
negative_checks.append(CheckProperQuantization(reporter))
119123
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
120124
negative_checks.append(EthosU55NotSupported(reporter))
125+
negative_checks.append(EthosU55DtypeSupport(reporter))
121126

122127
return chain(
123128
reporter.wrap_check(
@@ -216,61 +221,6 @@ def is_node_supported(
216221
return supported
217222

218223

219-
class EthosU55NotSupported(OperatorSupportBase):
220-
"""
221-
Certain operators are not supported on U55. These are listed in `unsupported_ops`.
222-
The comment mentions the unsupported TOSA operator that the aten operator maps to where it is not obvious.
223-
For unimplemented operators, this is the anticipated mapping, and it might be incorrect.
224-
"""
225-
226-
unsupported_ops = [
227-
exir_ops.edge.aten.any.default, # REDUCE_ANY
228-
exir_ops.edge.aten.any.dim, # REDUCE_ANY
229-
exir_ops.edge.aten.any.dims, # REDUCE_ANY
230-
exir_ops.edge.aten.bitwise_and.Tensor,
231-
exir_ops.edge.aten.bitwise_or.Tensor,
232-
exir_ops.edge.aten.bitwise_xor.Tensor,
233-
exir_ops.edge.aten.bitwise_not,
234-
exir_ops.edge.aten.logical_and.default,
235-
exir_ops.edge.aten.logical_or.default,
236-
exir_ops.edge.aten.logical_xor.default,
237-
exir_ops.edge.aten.logical_not.default,
238-
exir_ops.edge.aten.amax.default, # REDUCE_MAX
239-
exir_ops.edge.aten.amin.default, # REDUCE_MIN
240-
exir_ops.edge.aten.eq.Tensor,
241-
exir_ops.edge.aten.eq.Scalar,
242-
exir_ops.edge.aten.ge.Tensor,
243-
exir_ops.edge.aten.gt.Tensor,
244-
exir_ops.edge.aten.le.Tensor,
245-
exir_ops.edge.aten.lt.Tensor,
246-
exir_ops.edge.aten.flip.default, # REVERSE
247-
exir_ops.edge.aten.grid_sampler_2d, # GATHER
248-
exir_ops.edge.aten.scatter.src,
249-
exir_ops.edge.aten.scatter.value,
250-
exir_ops.edge.aten.select_scatter.default,
251-
exir_ops.edge.aten.scatter_reduce.two,
252-
exir_ops.edge.aten.scatter_add.default,
253-
exir_ops.edge.aten.upsample_nearest2d.vec, # RESIZE
254-
exir_ops.edge.aten.upsample_bilinear2d.vec, # RESIZE
255-
exir_ops.edge.aten.reflection_pad1d.default, # REVERSE
256-
exir_ops.edge.aten.reflection_pad2d.default, # REVERSE
257-
exir_ops.edge.aten.reflection_pad3d.default, # REVERSE
258-
]
259-
260-
def __init__(self, reporter: WhyNoPartitionReporter):
261-
self.reporter = reporter
262-
263-
def is_node_supported(
264-
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
265-
) -> bool:
266-
267-
if node.target in self.unsupported_ops:
268-
self.reporter.report_reject(node, "Op is not supported on U55.")
269-
return False
270-
271-
return True
272-
273-
274224
class NeedsDecompositionCheck(OperatorSupportBase):
275225
"""
276226
Targeted operators need to be decomposed prior to quantization in order to get a pair of q-dq-nodes surrounding

backends/arm/test/ops/test_sigmoid_16bit.py

Lines changed: 7 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
1414
from executorch.backends.arm.test import common
1515
from executorch.backends.arm.test.tester.test_pipeline import (
16-
EthosU55PipelineBI,
1716
EthosU85PipelineBI,
17+
OpNotSupportedPipeline,
1818
TosaPipelineBI,
1919
)
2020
from executorch.backends.xnnpack.test.tester import Quantize
@@ -109,22 +109,10 @@ def test_sigmoid_add_sigmoid_tosa_BI(test_data):
109109
@common.parametrize(
110110
"test_data",
111111
test_data_suite,
112-
xfails={
113-
"ones": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
114-
"rand": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
115-
"rand_4d": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
116-
"randn_pos": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
117-
"randn_neg": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
118-
"ramp": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
119-
},
120-
# int16 tables are not supported, but some tests happen to pass regardless.
121-
# Set them to xfail but strict=False -> ok if they pass.
122-
strict=False,
123112
)
124-
@common.XfailIfNoCorstone300
125113
def test_sigmoid_tosa_u55(test_data):
126-
pipeline = EthosU55PipelineBI(
127-
Sigmoid(), (test_data(),), Sigmoid.aten_op, Sigmoid.exir_op, run_on_fvp=True
114+
pipeline = OpNotSupportedPipeline(
115+
Sigmoid(), (test_data(),), "TOSA-0.80+BI+u55", {Sigmoid.exir_op: 1}
128116
)
129117
pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI+u55"))
130118
pipeline.run()
@@ -133,26 +121,14 @@ def test_sigmoid_tosa_u55(test_data):
133121
@common.parametrize(
134122
"test_data",
135123
test_data_suite,
136-
xfails={
137-
"ones": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
138-
"rand": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
139-
"rand_4d": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
140-
"randn_neg": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
141-
"randn_pos": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770",
142-
"ramp": "AsssertionError: Output 0 does not match reference output. MLBEDSW-9770",
143-
},
144-
# int16 tables are not supported, but some tests happen to pass regardless.
145-
# Set them to xfail but strict=False -> ok if they pass.
146-
strict=False,
147124
)
148-
@common.XfailIfNoCorstone300
149125
def test_sigmoid_add_sigmoid_tosa_u55(test_data):
150-
pipeline = EthosU55PipelineBI(
126+
pipeline = OpNotSupportedPipeline(
151127
SigmoidAddSigmoid(),
152128
(test_data(),),
153-
Sigmoid.aten_op,
154-
Sigmoid.exir_op,
155-
run_on_fvp=True,
129+
"TOSA-0.80+BI+u55",
130+
{Sigmoid.exir_op: 3},
131+
n_expected_delegates=1,
156132
)
157133
pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI+u55"))
158134
pipeline.run()

0 commit comments

Comments
 (0)