Skip to content

Commit c691722

Browse files
XNNPack: Ignore Alpha!=1 for add nodes (fixes: #11683)
1 parent 0c12dcd commit c691722

File tree

3 files changed

+37
-2
lines changed

3 files changed

+37
-2
lines changed

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
from typing import cast, List, Optional
1111

12+
import numpy as np
1213
import torch
1314
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
1415
ConfigPrecisionType,
@@ -106,6 +107,15 @@ def __init__(self, **kwargs):
106107
def supported_precision_types(self) -> List[ConfigPrecisionType]:
107108
return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]
108109

110+
def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
111+
# No support for add nodes with alpha != 1
112+
if "alpha" in node.kwargs and not np.isclose(
113+
node.kwargs["alpha"], 1.0, atol=1e-9, rtol=1e-9
114+
):
115+
why(node, reason="Add node doesn't support alpha != 1")
116+
return False
117+
return True
118+
109119

110120
class ReLUConfig(GenericNodePartitionerConfig):
111121
target_name = "relu.default"

backends/xnnpack/test/ops/test_add.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,27 @@ def forward(self, x, z):
240240
.serialize()
241241
.run_method_and_compare_outputs()
242242
)
243+
244+
class AddWithAlpha(torch.nn.Module):
245+
def forward(self, x, y):
246+
# node with alpha = 1.0 will be partitioned
247+
out1 = torch.add(x, y, alpha=1)
248+
# node with alpha != 1.0 will not be partitioned
249+
out2 = torch.add(x, y, alpha=2)
250+
return out1, out2
251+
252+
def test_add_with_alpha(self):
253+
inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
254+
(
255+
Tester(self.AddWithAlpha(), inputs)
256+
.export()
257+
.check_count({"torch.ops.aten.add.Tensor": 2})
258+
.to_edge_transform_and_lower()
259+
# unpartitioned node
260+
.check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1})
261+
# partitioned node
262+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
263+
.to_executorch()
264+
.serialize()
265+
.run_method_and_compare_outputs()
266+
)

backends/xnnpack/test/ops/test_linear.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,8 +551,9 @@ def test_qd8_f32_per_channel_shared_dq_chain(self):
551551
inputs,
552552
dynamic_shapes=None,
553553
is_per_channel=True,
554-
linear_count=2,
554+
linear_count=3, # 2 linear + 1 add
555555
uses_bias=use_bias,
556+
atol=1e-1,
556557
)
557558

558559
def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):
@@ -872,7 +873,7 @@ def test_qd8_per_channel_linear_parallel(self):
872873
),
873874
inputs,
874875
dynamic_shapes=dynamic_shapes,
875-
linear_count=2,
876+
linear_count=3, # 2 linear + 1 add
876877
is_per_channel=True,
877878
uses_bias=True,
878879
)

0 commit comments

Comments
 (0)