Skip to content

Commit 3c21e3a

Browse files
authored
Arm Backend: Update unit tests for TOSA 1.0 (#10776)
### Summary Refactoring of unit tests to allow for testing of TOSA 1.0 Adds command-line argument --arm_run_tosa_version to run tests on particular version
1 parent d9c6f80 commit 3c21e3a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

81 files changed

+5412
-6841
lines changed

backends/arm/scripts/parse_test_names.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,15 @@
55
from executorch.exir.dialects.edge.spec.utils import SAMPLE_INPUT
66

77
# Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here.
8-
CUSTOM_EDGE_OPS = ["linspace.default", "eye.default"]
8+
CUSTOM_EDGE_OPS = [
9+
"linspace.default",
10+
"eye.default",
11+
"hardsigmoid.default",
12+
"hardswish.default",
13+
"linear.default",
14+
"maximum.default",
15+
"adaptive_avg_pool2d.default",
16+
]
917
ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS
1018

1119
# Add all targets and TOSA profiles we support here.

backends/arm/test/common.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -259,17 +259,15 @@ def decorator_func(func):
259259
raise RuntimeError(
260260
"xfail info needs to be str, or tuple[str, type[Exception]]"
261261
)
262-
pytest_param = pytest.param(
263-
test_parameters,
264-
id=id,
265-
marks=pytest.mark.xfail(
266-
reason=reason, raises=raises, strict=strict
267-
),
262+
# Set up our fail marker
263+
marker = (
264+
pytest.mark.xfail(reason=reason, raises=raises, strict=strict),
268265
)
269266
else:
270-
pytest_param = pytest.param(test_parameters, id=id)
271-
pytest_testsuite.append(pytest_param)
267+
marker = ()
272268

269+
pytest_param = pytest.param(test_parameters, id=id, marks=marker)
270+
pytest_testsuite.append(pytest_param)
273271
return pytest.mark.parametrize(arg_name, pytest_testsuite)(func)
274272

275273
return decorator_func

backends/arm/test/conftest.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,6 @@
1212

1313
import pytest
1414

15-
try:
16-
import tosa_tools.v0_80.tosa_reference_model as tosa_reference_model
17-
except ImportError:
18-
logging.warning("tosa_reference_model not found, can't run reference model tests")
19-
tosa_reference_model = None
20-
2115
"""
2216
This file contains the pytest hooks, fixtures etc. for the Arm test suite.
2317
"""
@@ -50,10 +44,11 @@ def pytest_configure(config):
5044
if getattr(config.option, "fast_fvp", False):
5145
pytest._test_options["fast_fvp"] = config.option.fast_fvp # type: ignore[attr-defined]
5246

53-
# TODO: remove this flag once we have a way to run the reference model tests with Buck
54-
pytest._test_options["tosa_ref_model"] = False # type: ignore[attr-defined]
55-
if tosa_reference_model is not None:
56-
pytest._test_options["tosa_ref_model"] = True # type: ignore[attr-defined]
47+
if config.option.arm_run_tosa_version:
48+
pytest._test_options["tosa_version"] = config.option.arm_run_tosa_version
49+
50+
pytest._test_options["tosa_ref_model"] = True # type: ignore[attr-defined]
51+
5752
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
5853

5954

@@ -76,6 +71,7 @@ def try_addoption(*args, **kwargs):
7671
nargs="+",
7772
help="List of two files. Firstly .pt file. Secondly .json",
7873
)
74+
try_addoption("--arm_run_tosa_version", action="store", default="0.80")
7975

8076

8177
def pytest_sessionstart(session):

backends/arm/test/ops/test_abs.py

Lines changed: 58 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,125 +1,68 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2025 Arm Limited and/or its affiliates.
32
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
import unittest
98

109
from typing import Tuple
1110

12-
import pytest
13-
1411
import torch
15-
from executorch.backends.arm.test import common, conftest
16-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
17-
from executorch.exir.backend.compile_spec_schema import CompileSpec
18-
from parameterized import parameterized
19-
20-
21-
class TestAbs(unittest.TestCase):
22-
class Abs(torch.nn.Module):
23-
test_parameters = [
24-
(torch.zeros(5),),
25-
(torch.full((5,), -1, dtype=torch.float32),),
26-
(torch.ones(5) * -1,),
27-
(torch.randn(8),),
28-
(torch.randn(2, 3, 4),),
29-
(torch.randn(1, 2, 3, 4),),
30-
(torch.normal(mean=0, std=10, size=(2, 3, 4)),),
31-
]
32-
33-
def forward(self, x):
34-
return torch.abs(x)
35-
36-
def _test_abs_tosa_MI_pipeline(
37-
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
38-
):
39-
(
40-
ArmTester(
41-
module,
42-
example_inputs=test_data,
43-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
44-
)
45-
.export()
46-
.check_count({"torch.ops.aten.abs.default": 1})
47-
.check_not(["torch.ops.quantized_decomposed"])
48-
.to_edge()
49-
.partition()
50-
.check_not(["torch.ops.aten.abs.default"])
51-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
52-
.to_executorch()
53-
.run_method_and_compare_outputs(inputs=test_data)
54-
)
55-
56-
def _test_abs_tosa_BI_pipeline(
57-
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
58-
):
59-
(
60-
ArmTester(
61-
module,
62-
example_inputs=test_data,
63-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
64-
)
65-
.quantize()
66-
.export()
67-
.check_count({"torch.ops.aten.abs.default": 1})
68-
.check(["torch.ops.quantized_decomposed"])
69-
.to_edge()
70-
.partition()
71-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
72-
.to_executorch()
73-
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
74-
)
75-
76-
def _test_abs_ethosu_BI_pipeline(
77-
self,
78-
compile_spec: list[CompileSpec],
79-
module: torch.nn.Module,
80-
test_data: Tuple[torch.Tensor],
81-
):
82-
tester = (
83-
ArmTester(
84-
module,
85-
example_inputs=test_data,
86-
compile_spec=compile_spec,
87-
)
88-
.quantize()
89-
.export()
90-
.check_count({"torch.ops.aten.abs.default": 1})
91-
.check(["torch.ops.quantized_decomposed"])
92-
.to_edge()
93-
.partition()
94-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
95-
.to_executorch()
96-
.serialize()
97-
)
98-
if conftest.is_option_enabled("corstone_fvp"):
99-
tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)
100-
101-
@parameterized.expand(Abs.test_parameters)
102-
def test_abs_tosa_MI(self, test_data: torch.Tensor):
103-
test_data = (test_data,)
104-
self._test_abs_tosa_MI_pipeline(self.Abs(), test_data)
105-
106-
@parameterized.expand(Abs.test_parameters)
107-
def test_abs_tosa_BI(self, test_data: torch.Tensor):
108-
test_data = (test_data,)
109-
self._test_abs_tosa_BI_pipeline(self.Abs(), test_data)
110-
111-
@parameterized.expand(Abs.test_parameters)
112-
@pytest.mark.corstone_fvp
113-
def test_abs_u55_BI(self, test_data: torch.Tensor):
114-
test_data = (test_data,)
115-
self._test_abs_ethosu_BI_pipeline(
116-
common.get_u55_compile_spec(), self.Abs(), test_data
117-
)
118-
119-
@parameterized.expand(Abs.test_parameters)
120-
@pytest.mark.corstone_fvp
121-
def test_abs_u85_BI(self, test_data: torch.Tensor):
122-
test_data = (test_data,)
123-
self._test_abs_ethosu_BI_pipeline(
124-
common.get_u85_compile_spec(), self.Abs(), test_data
125-
)
12+
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.test.tester.test_pipeline import (
14+
EthosU55PipelineBI,
15+
EthosU85PipelineBI,
16+
TosaPipelineBI,
17+
TosaPipelineMI,
18+
)
19+
20+
aten_op = "torch.ops.aten.abs.default"
21+
exir_op = "executorch_exir_dialects_edge__ops_aten_abs_default"
22+
23+
input_t1 = Tuple[torch.Tensor] # Input x
24+
25+
26+
class Abs(torch.nn.Module):
27+
test_parameters = {
28+
"zeros": lambda: (torch.zeros(5),),
29+
"full": lambda: (torch.full((5,), -1, dtype=torch.float32),),
30+
"ones": lambda: (torch.ones(5) * -1,),
31+
"randn_1d": lambda: (torch.randn(8),),
32+
"randn_3d": lambda: (torch.randn(2, 3, 4),),
33+
"randn_4d": lambda: (torch.randn(1, 2, 3, 4),),
34+
"torch_normal": lambda: (torch.normal(mean=0, std=10, size=(2, 3, 4)),),
35+
}
36+
37+
def forward(self, x):
38+
return torch.abs(x)
39+
40+
41+
@common.parametrize("test_data", Abs.test_parameters)
42+
def test_abs_tosa_MI(test_data: torch.Tensor):
43+
pipeline = TosaPipelineMI[input_t1](Abs(), test_data(), aten_op, exir_op)
44+
pipeline.run()
45+
46+
47+
@common.parametrize("test_data", Abs.test_parameters)
48+
def test_abs_tosa_BI(test_data: torch.Tensor):
49+
pipeline = TosaPipelineBI[input_t1](Abs(), test_data(), aten_op, exir_op)
50+
pipeline.run()
51+
52+
53+
@common.parametrize("test_data", Abs.test_parameters)
54+
@common.XfailIfNoCorstone300
55+
def test_abs_u55_BI(test_data: torch.Tensor):
56+
pipeline = EthosU55PipelineBI[input_t1](
57+
Abs(), test_data(), aten_op, exir_op, run_on_fvp=True
58+
)
59+
pipeline.run()
60+
61+
62+
@common.parametrize("test_data", Abs.test_parameters)
63+
@common.XfailIfNoCorstone320
64+
def test_abs_u85_BI(test_data: torch.Tensor):
65+
pipeline = EthosU85PipelineBI[input_t1](
66+
Abs(), test_data(), aten_op, exir_op, run_on_fvp=True
67+
)
68+
pipeline.run()

0 commit comments

Comments
 (0)