Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backends/arm/_passes/annotate_decomposed_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node

from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
Expand Down Expand Up @@ -62,7 +62,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
}
for partition in matmul_partitions:
quantized_input = all(
input_node.target in dq_ops for input_node in partition.input_nodes
input_node.target in DQ_OPS for input_node in partition.input_nodes
)
matmul_node = [
node for node in partition.nodes if node.target in matmul_targets
Expand Down Expand Up @@ -93,7 +93,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
graph_module.graph.erase_node(partition_input)

partition_output = list(partition.output_nodes[0].users)[0]
quantized_output = partition_output.target in q_ops
quantized_output = partition_output.target in Q_OPS
if quantized_output:
with graph_module.graph.inserting_after(matmul_node):
# Create q-node after matmul
Expand Down
15 changes: 8 additions & 7 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
get_param_tensor,
is_param_node,
)
from executorch.backends.arm.constants import DQ_OPS, Q_OPS

from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops, QuantArgs
from executorch.backends.arm.tosa_quant_utils import QuantArgs

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
Expand Down Expand Up @@ -109,7 +110,7 @@ def fold_and_annotate_arg(
return

arg_quant_params = None
if arg.target in dq_ops:
if arg.target in DQ_OPS:
args = arg.args
scales = args[1]
if (
Expand Down Expand Up @@ -137,9 +138,9 @@ def fold_and_annotate_arg(
if input_qparams is not None:
node.meta["input_qparams"][i] = input_qparams
for n in nodes_to_remove:
if n.target not in dq_ops:
if n.target not in DQ_OPS:
raise RuntimeError(
f"Expected one of {dq_ops} dq_op, got {n.target}"
f"Expected one of {DQ_OPS} dq_op, got {n.target}"
)

node.replace_input_with(n, cast(Node, n.args[0]))
Expand All @@ -154,7 +155,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
if n.op != "call_function":
continue
# Don't fold chains of quant-ops into each other.
if n.target in (*q_ops, *dq_ops):
if n.target in (*Q_OPS, *DQ_OPS):
continue

# Make sure we haven't already set qparams meta information on the node
Expand Down Expand Up @@ -184,7 +185,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
# Copy the users, since we are modifying it.
users_copy = copy.copy(n.users)
for i, user in enumerate(users_copy):
if user.target not in q_ops:
if user.target not in Q_OPS:
continue

# quantization node found here, store the quantization parameters in meta value
Expand Down Expand Up @@ -221,7 +222,7 @@ def call(self, graph_module: GraphModule) -> PassResult:

# Make sure we have a quantized operator
user = list(n.users)[0]
if user.target not in q_ops:
if user.target not in Q_OPS:
continue

qargs = QuantArgs.from_operator(user.target, user.args)
Expand Down
5 changes: 3 additions & 2 deletions backends/arm/_passes/fuse_quantized_activation_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
# pyre-unsafe

import torch
from executorch.backends.arm.tosa_quant_utils import q_ops, QuantArgs
from executorch.backends.arm.constants import Q_OPS
from executorch.backends.arm.tosa_quant_utils import QuantArgs
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import Node
Expand All @@ -21,7 +22,7 @@ def _is_fuseable_quantized_activation(node: Node):
min_val = node.args[1]
is_fuseable = min_val == 0

is_quantized = len(node.users) == 1 and next(iter(node.users)).target in q_ops
is_quantized = len(node.users) == 1 and next(iter(node.users)).target in Q_OPS
if is_fuseable and is_quantized:
quant_node = next(iter(node.users))
quant_args = QuantArgs.from_operator(quant_node.target, quant_node.args)
Expand Down
7 changes: 4 additions & 3 deletions backends/arm/_passes/insert_rescales_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops, QuantArgs
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.backends.arm.tosa_quant_utils import QuantArgs
from executorch.exir.pass_base import ExportPass, PassResult
from torch import Tensor
from torch.fx import GraphModule, Node
Expand Down Expand Up @@ -94,11 +95,11 @@ def call(self, graph_module: GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
node = cast(Node, node)

if node.target not in dq_ops:
if node.target not in DQ_OPS:
continue
# Copy users since we remove them while iterating, modifying the node.users list.
for user in copy(node.users):
if user.target in q_ops:
if user.target in Q_OPS:
self.fold_dq_q_to_rescale(node, user, graph_module)
modified = True
if len(node.users) == 0:
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/_passes/mm_to_bmm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
get_first_fake_tensor,
insert_q_dq_pair,
)
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import Node
Expand Down Expand Up @@ -56,7 +56,7 @@ def call(self, graph_module: torch.fx.GraphModule):
node.replace_input_with(input_node, unsqueeze_before)

# If Quantized we must insert unsqueeze --> q --> dq --> node
if input_node.target in dq_ops:
if input_node.target in DQ_OPS:
q_params = input_node.args[1:]
insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node)

Expand Down Expand Up @@ -89,7 +89,7 @@ def call(self, graph_module: torch.fx.GraphModule):
user.replace_input_with(bmm_node, squeeze_after)

# If quantized, insert mm --> q --> dq --> squeeze
if all(original_user.target in q_ops for original_user in original_users):
if all(original_user.target in Q_OPS for original_user in original_users):
q_params = original_users[0].args[1:]
insert_q_dq_pair(graph, bmm_node, q_params, from_node=node)

Expand Down
31 changes: 31 additions & 0 deletions backends/arm/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, cast, Final

from executorch.exir.dialects._ops import ops as exir_ops

# Re-bind the ops namespace as ``Any`` for static type checkers; ``cast`` has
# no effect at runtime.
exir_ops = cast(Any, exir_ops)

# Short alias for the edge-dialect ``quantized_decomposed`` namespace.
qd = exir_ops.edge.quantized_decomposed

# Individual quantize overloads.
QUANT_PER_TENSOR_OP: Final = qd.quantize_per_tensor.default
QUANT_PER_TENSOR_OP_T: Final = qd.quantize_per_tensor.tensor
QUANT_PER_CHANNEL_OP: Final = qd.quantize_per_channel.default

# Individual dequantize overloads.
DEQUANT_PER_TENSOR_OP: Final = qd.dequantize_per_tensor.default
DEQUANT_PER_TENSOR_OP_T: Final = qd.dequantize_per_tensor.tensor
DEQUANT_PER_CHANNEL_OP: Final = qd.dequantize_per_channel.default

# Every quantize overload listed above.
Q_OPS: Final = (
    QUANT_PER_TENSOR_OP,
    QUANT_PER_TENSOR_OP_T,
    QUANT_PER_CHANNEL_OP,
)

# Every dequantize overload listed above.
DQ_OPS: Final = (
    DEQUANT_PER_TENSOR_OP,
    DEQUANT_PER_TENSOR_OP_T,
    DEQUANT_PER_CHANNEL_OP,
)

# Quantize and dequantize overloads grouped by granularity.
PER_TENSOR_QDQ_OPS: Final = (
    QUANT_PER_TENSOR_OP,
    QUANT_PER_TENSOR_OP_T,
    DEQUANT_PER_TENSOR_OP,
    DEQUANT_PER_TENSOR_OP_T,
)
PER_CHANNEL_QDQ_OPS: Final = (
    QUANT_PER_CHANNEL_OP,
    DEQUANT_PER_CHANNEL_OP,
)
12 changes: 6 additions & 6 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
FuseQuantizedActivationPass,
)
from executorch.backends.arm._passes.insert_table_ops import TableOps
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.backends.arm.operator_support.ethos_u55_support import (
EthosU55DtypeSupport,
EthosU55NotSupported,
EthosU55TransposeCheck,
EthosU55ViewCheck,
)
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir import ExportedProgram
from executorch.exir.backend.utils import WhyNoPartitionReporter
Expand Down Expand Up @@ -368,7 +368,7 @@ def _is_matmul_node_supported(
matched_partition = partition
if matched_partition is not None:
input_quantized = all(
input_node.target in dq_ops
input_node.target in DQ_OPS
for input_node in matched_partition.input_nodes
)
if not input_quantized:
Expand All @@ -377,7 +377,7 @@ def _is_matmul_node_supported(
)
return False
output_quantized = all(
output_node_user.target in q_ops
output_node_user.target in Q_OPS
for output_node_user in matched_partition.output_nodes[0].users
)
if not output_quantized:
Expand Down Expand Up @@ -413,7 +413,7 @@ def is_node_supported(
users = node.users
output_quantized = all(
user.target == operator.getitem
and all(user_user.target in q_ops for user_user in user.users)
and all(user_user.target in Q_OPS for user_user in user.users)
for user in users
)
elif FuseQuantizedActivationPass._is_fuseable_input(node):
Expand All @@ -427,7 +427,7 @@ def is_node_supported(
input_quantized = FuseQuantizedActivationPass._is_fuseable_input(input_node)

input_quantized = input_quantized or all(
(input_node.target in dq_ops)
(input_node.target in DQ_OPS)
or (not get_first_fake_tensor(input_node).dtype.is_floating_point)
for input_node in node.all_input_nodes
)
Expand All @@ -436,7 +436,7 @@ def is_node_supported(
self.reporter.report_reject(node, "One or more inputs were not quantized.")
return False

all_q_users = all((output_node.target in q_ops) for output_node in node.users)
all_q_users = all((output_node.target in Q_OPS) for output_node in node.users)
is_floating_point = get_first_fake_tensor(node).dtype.is_floating_point
output_quantized = output_quantized or all_q_users or not is_floating_point

Expand Down
22 changes: 3 additions & 19 deletions backends/arm/tosa_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Callable, List, Optional, Sequence, Tuple

import torch
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.backends.arm.arm_backend import (
get_tosa_spec,
is_tosa,
Expand All @@ -25,7 +26,6 @@
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase
Expand All @@ -34,22 +34,6 @@
logger = logging.getLogger(__name__)


def is_quant_node(node: torch.fx.node.Node) -> bool:
    """Report whether ``node`` targets a ``quantized_decomposed`` quantize op.

    Args:
        node: FX graph node whose ``target`` is inspected.

    Returns:
        True when the target is ``quantize_per_channel.default``,
        ``quantize_per_tensor.default`` or ``quantize_per_tensor.tensor``
        from the edge dialect; False otherwise.
    """
    quantized_decomposed = exir_ops.edge.quantized_decomposed
    quantize_targets = {
        quantized_decomposed.quantize_per_channel.default,
        quantized_decomposed.quantize_per_tensor.default,
        quantized_decomposed.quantize_per_tensor.tensor,
    }
    return node.target in quantize_targets


def is_dequant_node(node: torch.fx.node.Node) -> bool:
    """Report whether ``node`` targets a ``quantized_decomposed`` dequantize op.

    Args:
        node: FX graph node whose ``target`` is inspected.

    Returns:
        True when the target is ``dequantize_per_channel.default``,
        ``dequantize_per_tensor.default`` or ``dequantize_per_tensor.tensor``
        from the edge dialect; False otherwise.
    """
    quantized_decomposed = exir_ops.edge.quantized_decomposed
    dequantize_targets = {
        quantized_decomposed.dequantize_per_channel.default,
        quantized_decomposed.dequantize_per_tensor.default,
        quantized_decomposed.dequantize_per_tensor.tensor,
    }
    return node.target in dequantize_targets


class TOSAPartitioner(Partitioner):
def __init__(
self,
Expand Down Expand Up @@ -99,14 +83,14 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
for node in exported_program.graph_module.graph.nodes:
if not is_partitioned(node):
continue
if is_quant_node(node):
if node.target in Q_OPS:
for input in node.all_input_nodes:
if not is_partitioned(input):
del node.meta["delegation_tag"]
break
continue

if is_dequant_node(node):
if node.target in DQ_OPS:
for user in node.users:
if not is_partitioned(user):
del node.meta["delegation_tag"]
Expand Down
24 changes: 3 additions & 21 deletions backends/arm/tosa_quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import torch.fx
import torch.fx.node
from executorch.backends.arm.constants import PER_CHANNEL_QDQ_OPS, PER_TENSOR_QDQ_OPS

from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.exir.dialects._ops import ops as exir_ops
Expand All @@ -23,25 +24,6 @@
from tosa.RoundingMode import RoundingMode # type: ignore


# Quantize ops: per-tensor and per-channel, ``.default`` overloads only.
q_ops = (
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
)

# Dequantize ops: per-tensor and per-channel, ``.default`` overloads only.
dq_ops = (
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
)

# The same ops regrouped by granularity instead of by direction.
per_tensor_q_dq_ops = (
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
)
per_channel_q_dq_ops = (
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
)

# All quantize ops followed by all dequantize ops.
dq_q_ops = q_ops + dq_ops


def insert_rescale_ops_to_int32(
tosa_graph: Any,
inputs: list[TosaArg],
Expand Down Expand Up @@ -185,7 +167,7 @@ def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor:

@classmethod
def from_operator(cls, op, args):
if op in per_tensor_q_dq_ops:
if op in PER_TENSOR_QDQ_OPS:
return cls(
scale=cast(float, args[1]),
zp=cast(int, args[2]),
Expand All @@ -195,7 +177,7 @@ def from_operator(cls, op, args):
axis=0,
per_channel=False,
)
elif op in per_channel_q_dq_ops:
elif op in PER_CHANNEL_QDQ_OPS:
return cls(
scale=cast(list[float], args[1].tolist()),
zp=cast(list[int], args[2].tolist()),
Expand Down
Loading