diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index c920aa4b97a..26457259a93 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -194,6 +194,7 @@ def is_node_supported( exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.ne.Tensor, exir_ops.edge.aten.ne.Scalar, + exir_ops.edge.aten.neg.default, exir_ops.edge.aten.add.Scalar, exir_ops.edge.aten.sub.Scalar, exir_ops.edge.aten.mul.Scalar, @@ -311,6 +312,7 @@ class CheckProperQuantization(OperatorSupportBase): exir_ops.edge.aten.max_pool2d_with_indices.default, exir_ops.edge.aten.mm.default, exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.neg.default, exir_ops.edge.aten.relu.default, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.upsample_bilinear2d.vec, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index e496fe74d54..a58e443c6a4 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -31,6 +31,7 @@ op_maximum, op_minimum, op_mul, + op_neg, op_permute, op_pow, op_reciprocal, diff --git a/backends/arm/operators/op_neg.py b/backends/arm/operators/op_neg.py new file mode 100644 index 00000000000..a5fefe25db6 --- /dev/null +++ b/backends/arm/operators/op_neg.py @@ -0,0 +1,78 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import List + +import torch.fx + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) + +from executorch.backends.arm.tosa_mapping import TosaArg + + +def get_negate_zero_points(node: torch.fx.Node, dtype: ts.DType) -> tuple[int, int]: + """ + Returns (input1_zp, output_zp) for TOSA NEGATE. + Must be zero for non-int8 types. + """ + if dtype == ts.DType.INT8: + return ( + get_input_qparams(node)[0].zp, + get_output_qparams(node)[0].zp, + ) + return (0, 0) + + +@register_node_visitor +class NegVisitor(NodeVisitor): + target = "aten.neg.default" + + supported_dtypes = { + ts.DType.INT8, + ts.DType.INT16, + ts.DType.INT32, + ts.DType.FP16, + ts.DType.BF16, + ts.DType.FP32, + } + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + if inputs[0].dtype not in self.supported_dtypes: + raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}") + + if inputs[0].dtype != output.dtype: + raise ValueError( + "All inputs and output need same dtype." + f"Got {inputs[0].dtype=}, {output.dtype=}" + ) + input_zp, output_zp = get_negate_zero_points(node, inputs[0].dtype) + + attr = ts.TosaSerializerAttribute() + attr.NegateAttribute(input1_zp=input_zp, output_zp=output_zp) + tosa_graph.addOperator( + ts.TosaOp.Op().NEGATE, + [inputs[0].name], + [output.name], + attributes=attr, + ) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 5ac747177be..330f1a57c45 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -375,6 +375,9 @@ def any_or_hardtanh_min_zero(n: Node): ) ] quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type] + elif node.target in (torch.ops.aten.neg.default,): + quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)] + quant_properties.quant_output = _QuantProperty(0, input_act_qspec) elif node.target in _one_to_one: quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) diff --git a/backends/arm/test/ops/test_neg.py b/backends/arm/test/ops/test_neg.py new file mode 100644 index 00000000000..e4d705dfba9 --- /dev/null +++ b/backends/arm/test/ops/test_neg.py @@ -0,0 +1,66 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Dict, Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +input_t1 = Tuple[torch.Tensor] + + +class Neg(torch.nn.Module): + + aten_op = "torch.ops.aten.neg.default" + exir_op = "executorch_exir_dialects_edge__ops_aten_neg_default" + + test_data: Dict[str, input_t1] = { + "rank_1_ramp": (torch.arange(-16, 16, 0.2),), + "rank_2_rand_uniform": (torch.rand(10, 10) - 0.5,), + "rank_3_all_ones": (torch.ones(10, 10, 10),), + "rank_4_all_zeros": (torch.zeros(1, 10, 10, 10),), + "rank_4_randn_pos": (torch.randn(1, 4, 4, 4) + 10,), + "rank_4_randn_neg": (torch.randn(1, 4, 4, 4) - 10,), + } + + def forward(self, x: torch.Tensor): + return torch.neg(x) + + +@common.parametrize("test_data", Neg.test_data) +def test_neg_tosa_MI(test_data: input_t1): + pipeline = TosaPipelineMI[input_t1](Neg(), test_data, Neg.aten_op, Neg.exir_op) + pipeline.run() + + +@common.parametrize("test_data", Neg.test_data) +def test_neg_tosa_BI(test_data: input_t1): + pipeline = TosaPipelineBI[input_t1](Neg(), test_data, Neg.aten_op, Neg.exir_op) + pipeline.run() + + +@common.parametrize("test_data", Neg.test_data) +@common.XfailIfNoCorstone300 +def test_neg_u55_BI(test_data: input_t1): + pipeline = EthosU55PipelineBI[input_t1]( + Neg(), test_data, Neg.aten_op, Neg.exir_op, run_on_fvp=True + ) + pipeline.run() + + +@common.parametrize("test_data", Neg.test_data) +@common.XfailIfNoCorstone320 +def test_neg_u85_BI(test_data: input_t1): + pipeline = EthosU85PipelineBI[input_t1]( + Neg(), test_data, Neg.aten_op, Neg.exir_op, run_on_fvp=True + ) + pipeline.run()