diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index 9f9168d9238..8156ca0b89d 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -12,7 +12,7 @@ import torch from executorch.backends.arm._passes.arm_pass_utils import create_node -from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops +from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass, PassResult @@ -62,7 +62,7 @@ def call(self, graph_module: GraphModule) -> PassResult: } for partition in matmul_partitions: quantized_input = all( - input_node.target in dq_ops for input_node in partition.input_nodes + input_node.target in DQ_OPS for input_node in partition.input_nodes ) matmul_node = [ node for node in partition.nodes if node.target in matmul_targets @@ -93,7 +93,7 @@ def call(self, graph_module: GraphModule) -> PassResult: graph_module.graph.erase_node(partition_input) partition_output = list(partition.output_nodes[0].users)[0] - quantized_output = partition_output.target in q_ops + quantized_output = partition_output.target in Q_OPS if quantized_output: with graph_module.graph.inserting_after(matmul_node): # Create q-node after matmul diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 215bf21db2d..cb9fb8a50c7 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -15,8 +15,9 @@ get_param_tensor, is_param_node, ) +from executorch.backends.arm.constants import DQ_OPS, Q_OPS -from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops, QuantArgs +from executorch.backends.arm.tosa_quant_utils import QuantArgs from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload @@ -109,7 +110,7 @@ def fold_and_annotate_arg( return arg_quant_params = None - if arg.target in dq_ops: + if arg.target in DQ_OPS: args = arg.args scales = args[1] if ( @@ -137,9 +138,9 @@ def fold_and_annotate_arg( if input_qparams is not None: node.meta["input_qparams"][i] = input_qparams for n in nodes_to_remove: - if n.target not in dq_ops: + if n.target not in DQ_OPS: raise RuntimeError( - f"Expected one of {dq_ops} dq_op, got {n.target}" + f"Expected one of {DQ_OPS} dq_op, got {n.target}" ) node.replace_input_with(n, cast(Node, n.args[0])) @@ -154,7 +155,7 @@ def call(self, graph_module: GraphModule) -> PassResult: if n.op != "call_function": continue # Don't fold chains of quant-ops into each other. - if n.target in (*q_ops, *dq_ops): + if n.target in (*Q_OPS, *DQ_OPS): continue # Make sure we haven't already set qparams meta information on the node @@ -184,7 +185,7 @@ def call(self, graph_module: GraphModule) -> PassResult: # Copy the users, since we are modifying it. users_copy = copy.copy(n.users) for i, user in enumerate(users_copy): - if user.target not in q_ops: + if user.target not in Q_OPS: continue # quantization node found here, store the quantization parameters in meta value @@ -221,7 +222,7 @@ def call(self, graph_module: GraphModule) -> PassResult: # Make sure we have a quantized operator user = list(n.users)[0] - if user.target not in q_ops: + if user.target not in Q_OPS: continue qargs = QuantArgs.from_operator(user.target, user.args) diff --git a/backends/arm/_passes/fuse_quantized_activation_pass.py b/backends/arm/_passes/fuse_quantized_activation_pass.py index f70d6d8755b..fb52aab9071 100644 --- a/backends/arm/_passes/fuse_quantized_activation_pass.py +++ b/backends/arm/_passes/fuse_quantized_activation_pass.py @@ -6,7 +6,8 @@ # pyre-unsafe import torch -from executorch.backends.arm.tosa_quant_utils import q_ops, QuantArgs +from executorch.backends.arm.constants import Q_OPS +from executorch.backends.arm.tosa_quant_utils import QuantArgs from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import Node @@ -21,7 +22,7 @@ def _is_fuseable_quantized_activation(node: Node): min_val = node.args[1] is_fuseable = min_val == 0 - is_quantized = len(node.users) == 1 and next(iter(node.users)).target in q_ops + is_quantized = len(node.users) == 1 and next(iter(node.users)).target in Q_OPS if is_fuseable and is_quantized: quant_node = next(iter(node.users)) quant_args = QuantArgs.from_operator(quant_node.target, quant_node.args) diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 97b8fb15711..8a2e10b6b2d 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -9,7 +9,8 @@ import torch from executorch.backends.arm._passes.arm_pass_utils import create_node -from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops, QuantArgs +from executorch.backends.arm.constants import DQ_OPS, Q_OPS +from executorch.backends.arm.tosa_quant_utils import QuantArgs from executorch.exir.pass_base import ExportPass, PassResult from torch import Tensor from torch.fx import GraphModule, Node @@ -94,11 +95,11 @@ def call(self, graph_module: GraphModule) -> PassResult: for node in graph_module.graph.nodes: node = cast(Node, node) - if node.target not in dq_ops: + if node.target not in DQ_OPS: continue # Copy users since we remove them while iterating, modyfing the node.users list. for user in copy(node.users): - if user.target in q_ops: + if user.target in Q_OPS: self.fold_dq_q_to_rescale(node, user, graph_module) modified = True if len(node.users) == 0: diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py index 519b755080c..69d8573013e 100644 --- a/backends/arm/_passes/mm_to_bmm_pass.py +++ b/backends/arm/_passes/mm_to_bmm_pass.py @@ -12,7 +12,7 @@ get_first_fake_tensor, insert_q_dq_pair, ) -from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops +from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import Node @@ -56,7 +56,7 @@ def call(self, graph_module: torch.fx.GraphModule): node.replace_input_with(input_node, unsqueeze_before) # If Quantized we must insert unsqueeze --> q --> dq --> node - if input_node.target in dq_ops: + if input_node.target in DQ_OPS: q_params = input_node.args[1:] insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node) @@ -89,7 +89,7 @@ def call(self, graph_module: torch.fx.GraphModule): user.replace_input_with(bmm_node, squeeze_after) # If quantized, insert mm --> q --> dq --> squeeze - if all(original_user.target in q_ops for original_user in original_users): + if all(original_user.target in Q_OPS for original_user in original_users): q_params = original_users[0].args[1:] insert_q_dq_pair(graph, bmm_node, q_params, from_node=node) diff --git a/backends/arm/constants.py b/backends/arm/constants.py new file mode 100644 index 00000000000..fd8710d3ead --- /dev/null +++ b/backends/arm/constants.py @@ -0,0 +1,31 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, cast, Final + +from executorch.exir.dialects._ops import ops as exir_ops + +exir_ops = cast(Any, exir_ops) + +qd = exir_ops.edge.quantized_decomposed + +QUANT_PER_TENSOR_OP: Final = qd.quantize_per_tensor.default +QUANT_PER_TENSOR_OP_T: Final = qd.quantize_per_tensor.tensor +QUANT_PER_CHANNEL_OP: Final = qd.quantize_per_channel.default + +DEQUANT_PER_TENSOR_OP: Final = qd.dequantize_per_tensor.default +DEQUANT_PER_TENSOR_OP_T: Final = qd.dequantize_per_tensor.tensor +DEQUANT_PER_CHANNEL_OP: Final = qd.dequantize_per_channel.default + +Q_OPS: Final = (QUANT_PER_TENSOR_OP, QUANT_PER_TENSOR_OP_T, QUANT_PER_CHANNEL_OP) +DQ_OPS: Final = (DEQUANT_PER_TENSOR_OP, DEQUANT_PER_TENSOR_OP_T, DEQUANT_PER_CHANNEL_OP) + +PER_TENSOR_QDQ_OPS: Final = ( + QUANT_PER_TENSOR_OP, + QUANT_PER_TENSOR_OP_T, + DEQUANT_PER_TENSOR_OP, + DEQUANT_PER_TENSOR_OP_T, +) +PER_CHANNEL_QDQ_OPS: Final = (QUANT_PER_CHANNEL_OP, DEQUANT_PER_CHANNEL_OP) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 508b673ef5f..fb38c07b8a6 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -19,13 +19,13 @@ FuseQuantizedActivationPass, ) from executorch.backends.arm._passes.insert_table_ops import TableOps +from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.backends.arm.operator_support.ethos_u55_support import ( EthosU55DtypeSupport, EthosU55NotSupported, EthosU55TransposeCheck, EthosU55ViewCheck, ) -from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir import ExportedProgram from executorch.exir.backend.utils import WhyNoPartitionReporter @@ -368,7 +368,7 @@ def _is_matmul_node_supported( matched_partition = partition if matched_partition is not None: input_quantized = all( - input_node.target in dq_ops + input_node.target in DQ_OPS for input_node in matched_partition.input_nodes ) if not input_quantized: @@ -377,7 +377,7 @@ def _is_matmul_node_supported( ) return False output_quantized = all( - output_node_user.target in q_ops + output_node_user.target in Q_OPS for output_node_user in matched_partition.output_nodes[0].users ) if not output_quantized: @@ -413,7 +413,7 @@ def is_node_supported( users = node.users output_quantized = all( user.target == operator.getitem - and all(user_user.target in q_ops for user_user in user.users) + and all(user_user.target in Q_OPS for user_user in user.users) for user in users ) elif FuseQuantizedActivationPass._is_fuseable_input(node): @@ -427,7 +427,7 @@ def is_node_supported( input_quantized = FuseQuantizedActivationPass._is_fuseable_input(input_node) input_quantized = input_quantized or all( - (input_node.target in dq_ops) + (input_node.target in DQ_OPS) or (not get_first_fake_tensor(input_node).dtype.is_floating_point) for input_node in node.all_input_nodes ) @@ -436,7 +436,7 @@ def is_node_supported( self.reporter.report_reject(node, "One or more inputs were not quantized.") return False - all_q_users = all((output_node.target in q_ops) for output_node in node.users) + all_q_users = all((output_node.target in Q_OPS) for output_node in node.users) is_floating_point = get_first_fake_tensor(node).dtype.is_floating_point output_quantized = output_quantized or all_q_users or not is_floating_point diff --git a/backends/arm/tosa_partitioner.py b/backends/arm/tosa_partitioner.py index 0a0b0f33b6c..8c923568265 100644 --- a/backends/arm/tosa_partitioner.py +++ b/backends/arm/tosa_partitioner.py @@ -9,6 +9,7 @@ from typing import Callable, List, Optional, Sequence, Tuple import torch +from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.backends.arm.arm_backend import ( get_tosa_spec, is_tosa, @@ -25,7 +26,6 @@ PartitionResult, ) from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter -from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupportBase @@ -34,22 +34,6 @@ logger = logging.getLogger(__name__) -def is_quant_node(node: torch.fx.node.Node) -> bool: - return node.target in { - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, - } - - -def is_dequant_node(node: torch.fx.node.Node) -> bool: - return node.target in { - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, - } - - class TOSAPartitioner(Partitioner): def __init__( self, @@ -99,14 +83,14 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: for node in exported_program.graph_module.graph.nodes: if not is_partitioned(node): continue - if is_quant_node(node): + if node.target in Q_OPS: for input in node.all_input_nodes: if not is_partitioned(input): del node.meta["delegation_tag"] break continue - if is_dequant_node(node): + if node.target in DQ_OPS: for user in node.users: if not is_partitioned(user): del node.meta["delegation_tag"] diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index f6324efb401..d6a2d7bbe59 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -15,6 +15,7 @@ import torch.fx import torch.fx.node +from executorch.backends.arm.constants import PER_CHANNEL_QDQ_OPS, PER_TENSOR_QDQ_OPS from executorch.backends.arm.tosa_mapping import TosaArg from executorch.exir.dialects._ops import ops as exir_ops @@ -23,25 +24,6 @@ from tosa.RoundingMode import RoundingMode # type: ignore -q_ops = ( - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, -) -dq_ops = ( - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, -) -per_tensor_q_dq_ops = ( - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, -) -per_channel_q_dq_ops = ( - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, -) -dq_q_ops = (*q_ops, *dq_ops) - - def insert_rescale_ops_to_int32( tosa_graph: Any, inputs: list[TosaArg], @@ -185,7 +167,7 @@ def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor: @classmethod def from_operator(cls, op, args): - if op in per_tensor_q_dq_ops: + if op in PER_TENSOR_QDQ_OPS: return cls( scale=cast(float, args[1]), zp=cast(int, args[2]), @@ -195,7 +177,7 @@ def from_operator(cls, op, args): axis=0, per_channel=False, ) - elif op in per_channel_q_dq_ops: + elif op in PER_CHANNEL_QDQ_OPS: return cls( scale=cast(list[float], args[1].tolist()), zp=cast(list[int], args[2].tolist()),