From e268aa50dbde7c6ff36f1e776fca478d51cf6801 Mon Sep 17 00:00:00 2001
From: haowhsu-quic
Date: Fri, 7 Feb 2025 16:49:57 +0800
Subject: [PATCH] Qualcomm AI Engine Direct - op support

- where / logical_not
- test cases
---
 backends/qualcomm/_passes/layout_transform.py |  2 +
 backends/qualcomm/builders/__init__.py        |  4 +
 backends/qualcomm/builders/op_logical_not.py  | 55 +++++++++++++
 backends/qualcomm/builders/op_where.py        | 81 +++++++++++++++++++
 backends/qualcomm/builders/qnn_constants.py   | 10 +++
 backends/qualcomm/partition/common_defs.py    |  3 -
 backends/qualcomm/quantizer/annotators.py     | 23 ++++++
 backends/qualcomm/tests/models.py             | 26 ++++++
 backends/qualcomm/tests/test_qnn_delegate.py  | 36 +++++++++
 9 files changed, 237 insertions(+), 3 deletions(-)
 create mode 100644 backends/qualcomm/builders/op_logical_not.py
 create mode 100644 backends/qualcomm/builders/op_where.py

diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py
index 6ca0b512643..574633fddcc 100644
--- a/backends/qualcomm/_passes/layout_transform.py
+++ b/backends/qualcomm/_passes/layout_transform.py
@@ -68,6 +68,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.le.Tensor,
         exir_ops.edge.aten.linear.default,
         exir_ops.edge.aten.log.default,
+        exir_ops.edge.aten.logical_not.default,
         exir_ops.edge.aten.lt.Scalar,
         exir_ops.edge.aten.lt.Tensor,
         exir_ops.edge.aten._log_softmax.default,
@@ -88,6 +89,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.sum.dim_IntList,
         exir_ops.edge.aten.topk.default,
         exir_ops.edge.aten._to_copy.default,
+        exir_ops.edge.aten.where.self,
         *q_ops,
         *dq_ops,
         _operator.getitem,
diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
index ce19b6dbc73..76abf66ff3d 100644
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -40,6 +40,7 @@
     op_linear,
     op_log,
     op_log_softmax,
+    op_logical_not,
     op_lt,
     op_matmul,
     op_max,
@@ -76,6 +77,7 @@
     op_unsqueeze,
     op_upsample_bilinear2d,
     op_upsample_nearest2d,
+    op_where,
 )
 
 __all__ = [
@@ -113,6 +115,7 @@
     op_le,
     op_linear,
     op_log,
     op_log_softmax,
+    op_logical_not,
     op_lt,
     op_matmul,
@@ -150,4 +153,5 @@
     op_unsqueeze,
     op_upsample_bilinear2d,
     op_upsample_nearest2d,
+    op_where,
 ]
diff --git a/backends/qualcomm/builders/op_logical_not.py b/backends/qualcomm/builders/op_logical_not.py
new file mode 100644
index 00000000000..457a1007ada
--- /dev/null
+++ b/backends/qualcomm/builders/op_logical_not.py
@@ -0,0 +1,55 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseNot, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Not(NodeVisitor):
+    target = ["aten.logical_not.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        logical_not_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseNot.op_name,
+        )
+        logical_not_op.AddInputTensors([input_tensor_wrapper])
+        logical_not_op.AddOutputTensors([output_tensor_wrapper])
+
+        return logical_not_op
diff --git a/backends/qualcomm/builders/op_where.py b/backends/qualcomm/builders/op_where.py
new file mode 100644
index 00000000000..ecac45a7a6f
--- /dev/null
+++ b/backends/qualcomm/builders/op_where.py
@@ -0,0 +1,81 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseSelect, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Where(NodeVisitor):
+    target = ["aten.where.self"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        conditional_input_node = node.args[0]
+        conditional_input_tensor = self.get_tensor(conditional_input_node, node)
+        conditional_input_tensor_wrapper = self.define_tensor(
+            conditional_input_node,
+            node,
+            conditional_input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        true_input_node = node.args[1]
+        true_input_tensor = self.get_tensor(true_input_node, node)
+        true_input_tensor_wrapper = self.define_tensor(
+            true_input_node,
+            node,
+            true_input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        false_input_node = node.args[2]
+        false_input_tensor = self.get_tensor(false_input_node, node)
+        false_input_tensor_wrapper = self.define_tensor(
+            false_input_node,
+            node,
+            false_input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        where_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseSelect.op_name,
+        )
+        where_op.AddInputTensors(
+            [
+                conditional_input_tensor_wrapper,
+                true_input_tensor_wrapper,
+                false_input_tensor_wrapper,
+            ]
+        )
+        where_op.AddOutputTensors([output_tensor_wrapper])
+
+        return where_op
diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index d53e6792869..514a9393cc9 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -158,6 +158,11 @@ class OpElementWiseNeuron:
     param_beta: str = "beta"
 
 
+@dataclass(init=False, frozen=True)
+class OpElementWiseNot:
+    op_name: str = "ElementWiseNot"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWisePower:
     op_name: str = "ElementWisePower"
@@ -173,6 +178,11 @@ class OpElementWiseSin:
     op_name: str = "ElementWiseSin"
 
 
+@dataclass(init=False, frozen=True)
+class OpElementWiseSelect:
+    op_name: str = "ElementWiseSelect"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseSubtract:
     op_name = "ElementWiseSubtract"
diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py
index 7f49cfb7867..d1756e44281 100644
--- a/backends/qualcomm/partition/common_defs.py
+++ b/backends/qualcomm/partition/common_defs.py
@@ -7,7 +7,6 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 
-
 not_supported_operator = [
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.full.default,
@@ -18,8 +17,6 @@
 
 to_be_implemented_operator = [
     exir_ops.edge.aten.any.dim,
-    exir_ops.edge.aten.logical_not.default,
-    exir_ops.edge.aten.where.self,
 ]
 
 constant_operator = [
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index fe1729d19b8..b2e6aeb2994 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -1070,3 +1070,26 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
         output_qspec=quantization_config.output_activation,
         _annotated=True,
     )
+
+
+@register_annotator([torch.ops.aten.where.self])
+def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None:
+    true_input_act = node.args[1]
+    false_input_act = node.args[2]
+    if _is_annotated([node]):
+        return
+
+    _annotate_input_qspec_map(
+        node,
+        true_input_act,
+        quantization_config.input_activation,
+    )
+
+    _annotate_input_qspec_map(
+        node,
+        false_input_act,
+        quantization_config.input_activation,
+    )
+
+    _annotate_output_qspec(node, quantization_config.output_activation)
+    _mark_nodes_as_annotated([node])
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 4e733087808..6b5d27f68e8 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -793,6 +793,14 @@ def forward(self, x):
         return torch.log(x)
 
 
+class LogicalNot(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.logical_not(x > 0)
+
+
 class LogSoftmax(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1306,3 +1314,21 @@ def forward(self, x, y):
         x = x.view(new_shape)
         x = x.permute(0, 2, 1, 3)
         return torch.matmul(x, y.transpose(-1, -2))
+
+
+class Where(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y, z):
+        return torch.where(x >= torch.zeros(x.shape), y, z)
+
+
+class WhereConstant(torch.nn.Module):
+    def __init__(self, pos, neg):
+        super().__init__()
+        self.register_buffer("pos", pos)
+        self.register_buffer("neg", neg)
+
+    def forward(self, x):
+        return torch.where(x >= torch.zeros(x.shape), self.pos, self.neg)
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 0d1e80904fe..c4c9cbc859c 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -513,6 +513,11 @@ def test_qnn_backend_log(self):
         sample_input = (torch.rand([1, 2, 3, 4]),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_logical_not(self):
+        module = LogicalNot()  # noqa: F405
+        sample_input = (torch.rand([1, 2, 3, 4]),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_log_softmax(self):
         module = LogSoftmax()  # noqa: F405
         sample_input = (torch.randn([1, 4, 8, 8]),)
@@ -692,6 +697,18 @@ def test_qnn_backend_view(self):
         sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256]))
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_where(self):
+        modules = [
+            Where(),  # noqa: F405
+            WhereConstant(torch.randn(3, 2), torch.randn(3, 2)),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)),
+            (torch.randn(3, 2),),
+        ]
+        for i, module in enumerate(modules):
+            self.lower_module_and_test_output(module, sample_inputs[i])
+
 
 class TestQNNFloatingPointModel(TestQNN):
     # TODO: refactor to support different backends
@@ -1396,6 +1413,12 @@ def test_qnn_backend_log(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_logical_not(self):
+        module = LogicalNot()  # noqa: F405
+        sample_input = (torch.rand([1, 2, 3, 4]),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_log_softmax(self):
         module = LogSoftmax()  # noqa: F405
         sample_input = (torch.randn([1, 4, 8, 8]),)
@@ -1609,6 +1632,19 @@ def test_qnn_backend_view(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_where(self):
+        modules = [
+            Where(),  # noqa: F405
+            WhereConstant(torch.randn(3, 2), torch.randn(3, 2)),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)),
+            (torch.randn(3, 2),),
+        ]
+        for i, module in enumerate(modules):
+            module = self.get_qdq_module(module, sample_inputs[i])
+            self.lower_module_and_test_output(module, sample_inputs[i])
+
 
 class TestQNNQuantizedModel(TestQNN):
     # TODO: refactor to support different backends
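
For reference, the new builders map directly onto the eager PyTorch semantics exercised by the Where / LogicalNot test modules above: QNN ElementWiseSelect behaves like torch.where(cond, y, z), and QNN ElementWiseNot behaves like torch.logical_not. A minimal standalone sketch of those semantics (illustrative only, not part of the diff):

import torch

x = torch.randn(3, 2)
y, z = torch.randn(3, 2), torch.randn(3, 2)

# ElementWiseSelect: pick from y where the condition holds, from z elsewhere,
# matching what the Where module in tests/models.py feeds the backend.
cond = x >= torch.zeros(x.shape)
out = torch.where(cond, y, z)

# ElementWiseNot: elementwise boolean negation, as in the LogicalNot module.
flipped = torch.logical_not(x > 0)
assert torch.equal(flipped, ~(x > 0))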