Skip to content

Commit 90ff059

Browse files
Arm backend: Move q/dq ops constants to backends/arm/constants.py (#13095)
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 Signed-off-by: Sebastian Larsson <[email protected]>
1 parent a89b963 commit 90ff059

9 files changed

+64
-64
lines changed

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch
1313
from executorch.backends.arm._passes.arm_pass_utils import create_node
1414

15-
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
15+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1616
from executorch.exir.dialects._ops import ops as exir_ops
1717
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1818
from executorch.exir.pass_base import ExportPass, PassResult
@@ -62,7 +62,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
6262
}
6363
for partition in matmul_partitions:
6464
quantized_input = all(
65-
input_node.target in dq_ops for input_node in partition.input_nodes
65+
input_node.target in DQ_OPS for input_node in partition.input_nodes
6666
)
6767
matmul_node = [
6868
node for node in partition.nodes if node.target in matmul_targets
@@ -93,7 +93,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
9393
graph_module.graph.erase_node(partition_input)
9494

9595
partition_output = list(partition.output_nodes[0].users)[0]
96-
quantized_output = partition_output.target in q_ops
96+
quantized_output = partition_output.target in Q_OPS
9797
if quantized_output:
9898
with graph_module.graph.inserting_after(matmul_node):
9999
# Create q-node after matmul

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
get_param_tensor,
1616
is_param_node,
1717
)
18+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1819

19-
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops, QuantArgs
20+
from executorch.backends.arm.tosa_quant_utils import QuantArgs
2021

2122
from executorch.exir.dialects._ops import ops as exir_ops
2223
from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -109,7 +110,7 @@ def fold_and_annotate_arg(
109110
return
110111

111112
arg_quant_params = None
112-
if arg.target in dq_ops:
113+
if arg.target in DQ_OPS:
113114
args = arg.args
114115
scales = args[1]
115116
if (
@@ -137,9 +138,9 @@ def fold_and_annotate_arg(
137138
if input_qparams is not None:
138139
node.meta["input_qparams"][i] = input_qparams
139140
for n in nodes_to_remove:
140-
if n.target not in dq_ops:
141+
if n.target not in DQ_OPS:
141142
raise RuntimeError(
142-
f"Expected one of {dq_ops} dq_op, got {n.target}"
143+
f"Expected one of {DQ_OPS} dq_op, got {n.target}"
143144
)
144145

145146
node.replace_input_with(n, cast(Node, n.args[0]))
@@ -154,7 +155,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
154155
if n.op != "call_function":
155156
continue
156157
# Don't fold chains of quant-ops into each other.
157-
if n.target in (*q_ops, *dq_ops):
158+
if n.target in (*Q_OPS, *DQ_OPS):
158159
continue
159160

160161
# Make sure we haven't already set qparams meta information on the node
@@ -184,7 +185,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
184185
# Copy the users, since we are modifying it.
185186
users_copy = copy.copy(n.users)
186187
for i, user in enumerate(users_copy):
187-
if user.target not in q_ops:
188+
if user.target not in Q_OPS:
188189
continue
189190

190191
# quantization node found here, store the quantization parameters in meta value
@@ -221,7 +222,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
221222

222223
# Make sure we have a quantized operator
223224
user = list(n.users)[0]
224-
if user.target not in q_ops:
225+
if user.target not in Q_OPS:
225226
continue
226227

227228
qargs = QuantArgs.from_operator(user.target, user.args)

backends/arm/_passes/fuse_quantized_activation_pass.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
# pyre-unsafe
77

88
import torch
9-
from executorch.backends.arm.tosa_quant_utils import q_ops, QuantArgs
9+
from executorch.backends.arm.constants import Q_OPS
10+
from executorch.backends.arm.tosa_quant_utils import QuantArgs
1011
from executorch.exir.dialects._ops import ops as exir_ops
1112
from executorch.exir.pass_base import ExportPass, PassResult
1213
from torch.fx import Node
@@ -21,7 +22,7 @@ def _is_fuseable_quantized_activation(node: Node):
2122
min_val = node.args[1]
2223
is_fuseable = min_val == 0
2324

24-
is_quantized = len(node.users) == 1 and next(iter(node.users)).target in q_ops
25+
is_quantized = len(node.users) == 1 and next(iter(node.users)).target in Q_OPS
2526
if is_fuseable and is_quantized:
2627
quant_node = next(iter(node.users))
2728
quant_args = QuantArgs.from_operator(quant_node.target, quant_node.args)

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111
from executorch.backends.arm._passes.arm_pass_utils import create_node
12-
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops, QuantArgs
12+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
13+
from executorch.backends.arm.tosa_quant_utils import QuantArgs
1314
from executorch.exir.pass_base import ExportPass, PassResult
1415
from torch import Tensor
1516
from torch.fx import GraphModule, Node
@@ -94,11 +95,11 @@ def call(self, graph_module: GraphModule) -> PassResult:
9495
for node in graph_module.graph.nodes:
9596
node = cast(Node, node)
9697

97-
if node.target not in dq_ops:
98+
if node.target not in DQ_OPS:
9899
continue
99100
# Copy users since we remove them while iterating, modyfing the node.users list.
100101
for user in copy(node.users):
101-
if user.target in q_ops:
102+
if user.target in Q_OPS:
102103
self.fold_dq_q_to_rescale(node, user, graph_module)
103104
modified = True
104105
if len(node.users) == 0:

backends/arm/_passes/mm_to_bmm_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
get_first_fake_tensor,
1313
insert_q_dq_pair,
1414
)
15-
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
15+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1616
from executorch.exir.dialects._ops import ops as exir_ops
1717
from executorch.exir.pass_base import ExportPass, PassResult
1818
from torch.fx import Node
@@ -56,7 +56,7 @@ def call(self, graph_module: torch.fx.GraphModule):
5656
node.replace_input_with(input_node, unsqueeze_before)
5757

5858
# If Quantized we must insert unsqueeze --> q --> dq --> node
59-
if input_node.target in dq_ops:
59+
if input_node.target in DQ_OPS:
6060
q_params = input_node.args[1:]
6161
insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node)
6262

@@ -89,7 +89,7 @@ def call(self, graph_module: torch.fx.GraphModule):
8989
user.replace_input_with(bmm_node, squeeze_after)
9090

9191
# If quantized, insert mm --> q --> dq --> squeeze
92-
if all(original_user.target in q_ops for original_user in original_users):
92+
if all(original_user.target in Q_OPS for original_user in original_users):
9393
q_params = original_users[0].args[1:]
9494
insert_q_dq_pair(graph, bmm_node, q_params, from_node=node)
9595

backends/arm/constants.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, cast, Final
7+
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
10+
exir_ops = cast(Any, exir_ops)
11+
12+
qd = exir_ops.edge.quantized_decomposed
13+
14+
QUANT_PER_TENSOR_OP: Final = qd.quantize_per_tensor.default
15+
QUANT_PER_TENSOR_OP_T: Final = qd.quantize_per_tensor.tensor
16+
QUANT_PER_CHANNEL_OP: Final = qd.quantize_per_channel.default
17+
18+
DEQUANT_PER_TENSOR_OP: Final = qd.dequantize_per_tensor.default
19+
DEQUANT_PER_TENSOR_OP_T: Final = qd.dequantize_per_tensor.tensor
20+
DEQUANT_PER_CHANNEL_OP: Final = qd.dequantize_per_channel.default
21+
22+
Q_OPS: Final = (QUANT_PER_TENSOR_OP, QUANT_PER_TENSOR_OP_T, QUANT_PER_CHANNEL_OP)
23+
DQ_OPS: Final = (DEQUANT_PER_TENSOR_OP, DEQUANT_PER_TENSOR_OP_T, DEQUANT_PER_CHANNEL_OP)
24+
25+
PER_TENSOR_QDQ_OPS: Final = (
26+
QUANT_PER_TENSOR_OP,
27+
QUANT_PER_TENSOR_OP_T,
28+
DEQUANT_PER_TENSOR_OP,
29+
DEQUANT_PER_TENSOR_OP_T,
30+
)
31+
PER_CHANNEL_QDQ_OPS: Final = (QUANT_PER_CHANNEL_OP, DEQUANT_PER_CHANNEL_OP)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919
FuseQuantizedActivationPass,
2020
)
2121
from executorch.backends.arm._passes.insert_table_ops import TableOps
22+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
2223
from executorch.backends.arm.operator_support.ethos_u55_support import (
2324
EthosU55DtypeSupport,
2425
EthosU55NotSupported,
2526
EthosU55TransposeCheck,
2627
EthosU55ViewCheck,
2728
)
28-
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
2929
from executorch.backends.arm.tosa_specification import TosaSpecification
3030
from executorch.exir import ExportedProgram
3131
from executorch.exir.backend.utils import WhyNoPartitionReporter
@@ -368,7 +368,7 @@ def _is_matmul_node_supported(
368368
matched_partition = partition
369369
if matched_partition is not None:
370370
input_quantized = all(
371-
input_node.target in dq_ops
371+
input_node.target in DQ_OPS
372372
for input_node in matched_partition.input_nodes
373373
)
374374
if not input_quantized:
@@ -377,7 +377,7 @@ def _is_matmul_node_supported(
377377
)
378378
return False
379379
output_quantized = all(
380-
output_node_user.target in q_ops
380+
output_node_user.target in Q_OPS
381381
for output_node_user in matched_partition.output_nodes[0].users
382382
)
383383
if not output_quantized:
@@ -413,7 +413,7 @@ def is_node_supported(
413413
users = node.users
414414
output_quantized = all(
415415
user.target == operator.getitem
416-
and all(user_user.target in q_ops for user_user in user.users)
416+
and all(user_user.target in Q_OPS for user_user in user.users)
417417
for user in users
418418
)
419419
elif FuseQuantizedActivationPass._is_fuseable_input(node):
@@ -427,7 +427,7 @@ def is_node_supported(
427427
input_quantized = FuseQuantizedActivationPass._is_fuseable_input(input_node)
428428

429429
input_quantized = input_quantized or all(
430-
(input_node.target in dq_ops)
430+
(input_node.target in DQ_OPS)
431431
or (not get_first_fake_tensor(input_node).dtype.is_floating_point)
432432
for input_node in node.all_input_nodes
433433
)
@@ -436,7 +436,7 @@ def is_node_supported(
436436
self.reporter.report_reject(node, "One or more inputs were not quantized.")
437437
return False
438438

439-
all_q_users = all((output_node.target in q_ops) for output_node in node.users)
439+
all_q_users = all((output_node.target in Q_OPS) for output_node in node.users)
440440
is_floating_point = get_first_fake_tensor(node).dtype.is_floating_point
441441
output_quantized = output_quantized or all_q_users or not is_floating_point
442442

backends/arm/tosa_partitioner.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Callable, List, Optional, Sequence, Tuple
1010

1111
import torch
12+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1213
from executorch.backends.arm.arm_backend import (
1314
get_tosa_spec,
1415
is_tosa,
@@ -25,7 +26,6 @@
2526
PartitionResult,
2627
)
2728
from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter
28-
from executorch.exir.dialects._ops import ops as exir_ops
2929
from torch.export.exported_program import ExportedProgram
3030
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
3131
from torch.fx.passes.operator_support import OperatorSupportBase
@@ -34,22 +34,6 @@
3434
logger = logging.getLogger(__name__)
3535

3636

37-
def is_quant_node(node: torch.fx.node.Node) -> bool:
38-
return node.target in {
39-
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
40-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
41-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
42-
}
43-
44-
45-
def is_dequant_node(node: torch.fx.node.Node) -> bool:
46-
return node.target in {
47-
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
48-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
49-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
50-
}
51-
52-
5337
class TOSAPartitioner(Partitioner):
5438
def __init__(
5539
self,
@@ -99,14 +83,14 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
9983
for node in exported_program.graph_module.graph.nodes:
10084
if not is_partitioned(node):
10185
continue
102-
if is_quant_node(node):
86+
if node.target in Q_OPS:
10387
for input in node.all_input_nodes:
10488
if not is_partitioned(input):
10589
del node.meta["delegation_tag"]
10690
break
10791
continue
10892

109-
if is_dequant_node(node):
93+
if node.target in DQ_OPS:
11094
for user in node.users:
11195
if not is_partitioned(user):
11296
del node.meta["delegation_tag"]

backends/arm/tosa_quant_utils.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import torch.fx
1717
import torch.fx.node
18+
from executorch.backends.arm.constants import PER_CHANNEL_QDQ_OPS, PER_TENSOR_QDQ_OPS
1819

1920
from executorch.backends.arm.tosa_mapping import TosaArg
2021
from executorch.exir.dialects._ops import ops as exir_ops
@@ -23,25 +24,6 @@
2324
from tosa.RoundingMode import RoundingMode # type: ignore
2425

2526

26-
q_ops = (
27-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
28-
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
29-
)
30-
dq_ops = (
31-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
32-
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
33-
)
34-
per_tensor_q_dq_ops = (
35-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
36-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
37-
)
38-
per_channel_q_dq_ops = (
39-
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
40-
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
41-
)
42-
dq_q_ops = (*q_ops, *dq_ops)
43-
44-
4527
def insert_rescale_ops_to_int32(
4628
tosa_graph: Any,
4729
inputs: list[TosaArg],
@@ -185,7 +167,7 @@ def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor:
185167

186168
@classmethod
187169
def from_operator(cls, op, args):
188-
if op in per_tensor_q_dq_ops:
170+
if op in PER_TENSOR_QDQ_OPS:
189171
return cls(
190172
scale=cast(float, args[1]),
191173
zp=cast(int, args[2]),
@@ -195,7 +177,7 @@ def from_operator(cls, op, args):
195177
axis=0,
196178
per_channel=False,
197179
)
198-
elif op in per_channel_q_dq_ops:
180+
elif op in PER_CHANNEL_QDQ_OPS:
199181
return cls(
200182
scale=cast(list[float], args[1].tolist()),
201183
zp=cast(list[int], args[2].tolist()),

0 commit comments

Comments
 (0)