diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py
index 9b4e3934da6..b500540ffb7 100644
--- a/backends/arm/_passes/insert_table_ops.py
+++ b/backends/arm/_passes/insert_table_ops.py
@@ -31,7 +31,7 @@ class InsertTableOpsPass(ExportPass):
     """
     For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
     edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
-    When loweringthe _table node target_str will be used to find the corresponding torch operator
+    When lowering the _table node target_str will be used to find the corresponding torch operator
     which will be used to produce the table values in operators/op_table.py.
     """
 
@@ -43,6 +43,7 @@ class InsertTableOpsPass(ExportPass):
         exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
         exir_ops.edge.aten.tanh.default: torch.tanh,
         exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
+        exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
     }
 
     def __init__(self, exported_program: ExportedProgram) -> None:
diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py
index 5a5a281ff6a..8fde8dff610 100644
--- a/backends/arm/arm_partitioner.py
+++ b/backends/arm/arm_partitioner.py
@@ -115,6 +115,7 @@ def ops_to_not_decompose(
 ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
     ops_to_not_decompose_if_quant_op = [
         torch.ops.aten.hardsigmoid.default,
+        torch.ops.aten.hardswish.default,
     ]
 
     def filter_fn(node: torch.fx.Node) -> bool:
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index d7398a7b804..36914579fe4 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -81,6 +81,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             exir_ops.edge.aten.permute_copy.default,
             exir_ops.edge.aten.hardsigmoid.default,
             exir_ops.edge.aten.hardtanh.default,
+            exir_ops.edge.aten.hardswish.default,
             exir_ops.edge.aten.convolution.default,
             exir_ops.edge.aten.div.Tensor,
             exir_ops.edge.aten.eq.Tensor,
diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index 4b60b86e0d4..32f64963e87 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -133,6 +133,7 @@ def _match_pattern(
     torch.ops.aten.tanh.default,
     torch.ops.aten.sum.dim_IntList,
     torch.ops.aten.hardsigmoid.default,
+    torch.ops.aten.hardswish.default,
 ]
 
 _one_to_one_shared_input_qspec = [
diff --git a/backends/arm/test/ops/test_hardswish.py b/backends/arm/test/ops/test_hardswish.py
new file mode 100644
index 00000000000..81aba540e3f
--- /dev/null
+++ b/backends/arm/test/ops/test_hardswish.py
@@ -0,0 +1,128 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from typing import Tuple
+
+import pytest
+import torch
+
+from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from parameterized import parameterized
+
+
+test_data_suite = [
+    # (test_name, test_data)
+    ("zeros", torch.zeros(1, 10, 10, 10)),
+    ("ones", torch.ones(10, 10, 10)),
+    ("rand", torch.rand(10, 10) - 0.5),
+    ("randn_pos", torch.randn(10) + 10),
+    ("randn_neg", torch.randn(10) - 10),
+    ("ramp", torch.arange(-16, 16, 0.2)),
+]
+
+
+class TestHardswish(unittest.TestCase):
+    class Hardswish(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.hardswish = torch.nn.Hardswish()
+
+        def forward(self, x):
+            return self.hardswish(x)
+
+    def _test_hardswish_tosa_MI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
+            )
+            .export()
+            .check(["torch.ops.aten.hardswish.default"])
+            .check_not(["torch.ops.quantized_decomposed"])
+            .to_edge_transform_and_lower()
+            .check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_hardswish_tosa_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
+            )
+            .quantize()
+            .export()
+            .check(["torch.ops.aten.hardswish.default"])
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge_transform_and_lower()
+            .check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data)
+        )
+
+    def _test_hardswish_tosa_ethos_BI_pipeline(
+        self,
+        compile_spec: list[CompileSpec],
+        module: torch.nn.Module,
+        test_data: Tuple[torch.tensor],
+    ):
+        tester = (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=compile_spec,
+            )
+            .quantize()
+            .export()
+            .check_count({"torch.ops.aten.hardswish.default": 1})
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge_transform_and_lower()
+            .check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .serialize()
+        )
+        if conftest.is_option_enabled("corstone_fvp"):
+            tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)
+
+    @parameterized.expand(test_data_suite)
+    def test_hardswish_tosa_MI(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+    ):
+        self._test_hardswish_tosa_MI_pipeline(self.Hardswish(), (test_data,))
+
+    @parameterized.expand(test_data_suite)
+    def test_hardswish_tosa_BI(self, test_name: str, test_data: torch.Tensor):
+        self._test_hardswish_tosa_BI_pipeline(self.Hardswish(), (test_data,))
+
+    @parameterized.expand(test_data_suite)
+    @pytest.mark.corstone_fvp
+    def test_hardswish_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
+        self._test_hardswish_tosa_ethos_BI_pipeline(
+            common.get_u55_compile_spec(), self.Hardswish(), (test_data,)
+        )
+
+    @parameterized.expand(test_data_suite)
+    @pytest.mark.corstone_fvp
+    def test_hardswish_tosa_u85_BI(self, test_name: str, test_data: torch.Tensor):
+        self._test_hardswish_tosa_ethos_BI_pipeline(
+            common.get_u85_compile_spec(), self.Hardswish(), (test_data,)
+        )
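
Background on the table-op mapping above: InsertTableOpsPass records torch.nn.functional.hardswish as the reference operator for the new edge op, and the pass docstring says this reference is later used to produce the TOSA TABLE values in operators/op_table.py (not shown in this patch). The sketch below illustrates the general idea of deriving an int8 lookup table from a float reference op; the function name and quantization parameters are hypothetical example values, not code or parameters taken from the backend.

import torch

def hardswish_int8_table(in_scale: float, in_zp: int, out_scale: float, out_zp: int) -> torch.Tensor:
    # Enumerate every representable int8 input value.
    q_in = torch.arange(-128, 128, dtype=torch.int32)
    # Dequantize, apply the float reference op, then requantize back to int8.
    x = (q_in - in_zp).to(torch.float32) * in_scale
    y = torch.nn.functional.hardswish(x)
    q_out = torch.round(y / out_scale) + out_zp
    return torch.clamp(q_out, -128, 127).to(torch.int8)

# Example with arbitrary (hypothetical) quantization parameters.
table = hardswish_int8_table(in_scale=0.1, in_zp=0, out_scale=0.1, out_zp=0)
assert table.numel() == 256  # one entry per int8 input value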