23 changes: 10 additions & 13 deletions backends/arm/operators/op_addmm.py
@@ -73,7 +73,7 @@ def define_node(
                 quant_node = input_node.all_input_nodes[0]
             else:
                 quant_node = input_node
-            input_zp = get_quant_node_args(quant_node)[1]
+            input_zp = get_quant_node_args(quant_node).zp
             attr.ConvAttribute(
                 pad=pad_attr,
                 stride=stride_attr,
@@ -111,24 +111,21 @@ def define_node(
             # rank > 2 linear layer
             if input_node.target == exir_ops.edge.aten.view_copy.default:
                 quant_node = input_node.all_input_nodes[0]
-                input_scale, _ = get_quant_node_args(quant_node)
+                input_scale = get_quant_node_args(quant_node).scale
                 consumer_node = list(node.users)[0]
                 consumer_consumer_node = list(consumer_node.users)[0]
-                (
-                    consumer_node_scale,
-                    consumer_node_node_zp,
-                ) = get_quant_node_args(consumer_consumer_node)
-
+                quant_args = get_quant_node_args(consumer_consumer_node)
+                consumer_node_scale = quant_args.scale
+                consumer_node_node_zp = quant_args.zp
             else:
-                input_scale, _ = get_quant_node_args(input_node)
+                input_scale = get_quant_node_args(input_node).scale
                 consumer_node = list(node.users)[0]
-                (
-                    consumer_node_scale,
-                    consumer_node_node_zp,
-                ) = get_quant_node_args(consumer_node)
+                quant_args = get_quant_node_args(consumer_node)
+                consumer_node_scale = quant_args.scale
+                consumer_node_node_zp = quant_args.zp

             weight_node_q_node = weight_node.all_input_nodes[0]
-            weight_scale, _ = get_quant_node_args(weight_node_q_node)
+            weight_scale = get_quant_node_args(weight_node_q_node).scale

             output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale
             (
4 changes: 2 additions & 2 deletions backends/arm/operators/op_common.py
@@ -31,8 +31,8 @@ def build_avg_pool_2d_common(
     output_zp = 0

     if is_quant_node:
-        _, input_zp = get_quant_node_args(node.args[0])
-        _, output_zp = get_quant_node_args(list(node.users)[0])
+        input_zp = get_quant_node_args(node.args[0]).zp
+        output_zp = get_quant_node_args(list(node.users)[0]).zp

     attr = ts.TosaSerializerAttribute()
     attr.PoolAttribute(
2 changes: 1 addition & 1 deletion backends/arm/operators/op_conv2d.py
@@ -80,7 +80,7 @@ def define_node(
         )

         input_zp = (
-            get_quant_node_args(node.all_input_nodes[0])[1] if is_quant_node else 0
+            get_quant_node_args(node.all_input_nodes[0]).zp if is_quant_node else 0
         )

         attr.ConvAttribute(
31 changes: 26 additions & 5 deletions backends/arm/operators/op_hardtanh.py
@@ -1,4 +1,4 @@
-# Copyright 2023 Arm Limited and/or its affiliates.
+# Copyright 2023-2024 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -11,6 +11,8 @@
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
+
+from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
 from serializer.tosa_serializer import TosaOp


@@ -30,12 +32,31 @@ def define_node(
         is_quant_node: bool,
     ) -> None:
         attr = ts.TosaSerializerAttribute()
+
+        if is_quant_node:
+            # Get quant parameters
+            scale, zp, qmin, qmax = get_quant_node_args(node.all_input_nodes[0])
+            # Convert to quantized representation
+            clamp_min_qs = round((inputs[1].number / scale) + zp)
+            clamp_min_qs = max(clamp_min_qs, qmin)
+            clamp_max_qs = round((inputs[2].number / scale) + zp)
+            clamp_max_qs = min(clamp_max_qs, qmax)
+            # Set fp values to 0.0 since they are not used
+            clamp_min_fp = 0.0
+            clamp_max_fp = 0.0
+        else:
+            clamp_min_fp = inputs[1].number
+            clamp_max_fp = inputs[2].number
+            # Set qs values to 0 since they are not used
+            clamp_min_qs = 0
+            clamp_max_qs = 0
+
         attr.ClampAttribute(
             tosa_graph.builder,
-            int(inputs[1].number),
-            int(inputs[2].number),
-            inputs[1].number,
-            inputs[2].number,
+            clamp_min_qs,
+            clamp_max_qs,
+            clamp_min_fp,
+            clamp_max_fp,
         )

         tosa_graph.addOperator(TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr)
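A note on the new quantized path: the clamp bounds are mapped into the integer domain with the usual affine rule q = round(x / scale) + zp, then clipped to [qmin, qmax]. A minimal standalone sketch of that mapping (the helper name and the scale/zero-point values are illustrative, not part of the patch):

```python
# Sketch of the affine quantization applied to the clamp bounds above.
# quantize_clamp_bound is a hypothetical helper; ranges assume int8.
def quantize_clamp_bound(x: float, scale: float, zp: int, qmin: int, qmax: int) -> int:
    q = round(x / scale) + zp       # float bound -> integer domain
    return max(qmin, min(q, qmax))  # clip to the representable range

# ReLU6 lowers to hardtanh(0.0, 6.0); with scale=6/255 and zp=-128 (int8):
print(quantize_clamp_bound(0.0, 6 / 255, -128, -128, 127))  # -128
print(quantize_clamp_bound(6.0, 6 / 255, -128, -128, 127))  # 127
```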
12 changes: 7 additions & 5 deletions backends/arm/operators/op_placeholder.py
@@ -50,11 +50,13 @@ def process_placeholder(
         weight_node = weight_node_permuted.all_input_nodes[0]

         if input_node.target == exir_ops.edge.aten.view_copy.default:
-            input_node_scale, _ = get_quant_node_args(input_node.all_input_nodes[0])
+            input_node_scale = get_quant_node_args(
+                input_node.all_input_nodes[0]
+            ).scale
         else:
-            input_node_scale, _ = get_quant_node_args(input_node)
+            input_node_scale = get_quant_node_args(input_node).scale

-        weight_node_scale, _ = get_quant_node_args(weight_node)
+        weight_node_scale = get_quant_node_args(weight_node).scale

         bias_values_quantized = (
             (parameter_values / (input_node_scale * weight_node_scale))
@@ -81,8 +83,8 @@
             bias_node,
         ) = consumer_node.all_input_nodes

-        input_node_scale, _ = get_quant_node_args(input_node)
-        weight_node_scale, _ = get_quant_node_args(weight_node)
+        input_node_scale = get_quant_node_args(input_node).scale
+        weight_node_scale = get_quant_node_args(weight_node).scale

         bias_scales = input_node_scale * weight_node_scale
         parameter_values_quantized = (
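For context on the bias math in this file: in a quantized conv/linear, the bias is carried in int32 with the effective scale input_scale * weight_scale, so it can be added directly to the int32 accumulator. A standalone sketch with illustrative values:

```python
import torch

input_scale, weight_scale = 0.02, 0.005  # illustrative quantization scales
bias_fp = torch.tensor([0.15, -0.03])

# Requantize the fp bias with scale = input_scale * weight_scale (zero point 0),
# matching the scale of the int32 accumulator of the quantized matmul/conv.
bias_q = torch.round(bias_fp / (input_scale * weight_scale)).to(torch.int32)
print(bias_q)  # tensor([1500, -300], dtype=torch.int32)
```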
67 changes: 54 additions & 13 deletions backends/arm/test/ops/test_conv_combos.py
@@ -12,6 +12,7 @@
 import torch
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from parameterized import parameterized

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -126,6 +127,32 @@ def forward(self, x):
         return x


+class ComboConvRelu6(torch.nn.Module):
+    edge_op_list = [
+        "executorch_exir_dialects_edge__ops_aten_convolution_default",
+        "executorch_exir_dialects_edge__ops_aten_hardtanh_default",
+    ]
+
+    test_data = [
+        (20 * torch.randn(1, 3, 256, 256),),
+        (5 * torch.randn(1, 3, 256, 256),),
+        (torch.randn(1, 3, 256, 256),),
+        (-5 * torch.randn(1, 3, 256, 256),),
+    ]
+
+    def __init__(self):
+        super().__init__()
+        self.conv2d = torch.nn.Conv2d(
+            in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
+        )
+        self.relu6 = torch.nn.ReLU6()
+
+    def forward(self, x):
+        x = self.conv2d(x)
+        x = self.relu6(x)
+        return x
+
+
 class TestConvCombos(unittest.TestCase):
     def _test_conv_combo_tosa_MI_pipeline(
         self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
@@ -222,15 +249,9 @@ def test_conv_batchnorm_relu_tosa_MI(self):
         model = ComboConvBatchnormRelu()
         self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())

-    # TODO(MLETORCH-85): Investigate numerical issue. This diff is present in legacy
-    # testcase as well (and also not tested). For now, just increase the
-    # tolerance, such that we don't skip the test entirely (i.e. we maintain
-    # functionality).
     def test_conv_batchnorm_relu_tosa_BI(self):
         model = ComboConvBatchnormRelu()
-        self._test_conv_combo_tosa_BI_pipeline(
-            model, model.get_inputs(), atol=1.0, rtol=1.0
-        )
+        self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())

     @unittest.skipIf(
         not common.VELA_INSTALLED,
@@ -240,21 +261,41 @@ def test_conv_batchnorm_relu_u55_BI(self):
         model = ComboConvBatchnormRelu()
         self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs())

+    ##################
+    ## Conv + ReLU6 ##
+    ##################
+    @parameterized.expand(ComboConvRelu6.test_data)
+    def test_conv_relu6_tosa_MI(self, test_data: torch.Tensor):
+        model = ComboConvRelu6()
+        test_data = (test_data,)
+        self._test_conv_combo_tosa_MI_pipeline(model, test_data)
+
+    @parameterized.expand(ComboConvRelu6.test_data)
+    def test_conv_relu6_tosa_BI(self, test_data: torch.Tensor):
+        model = ComboConvRelu6()
+        test_data = (test_data,)
+        self._test_conv_combo_tosa_BI_pipeline(model, test_data)
+
+    @parameterized.expand(ComboConvRelu6.test_data)
+    @unittest.skipIf(
+        not common.VELA_INSTALLED,
+        "There is no point in running U55 tests if the Vela tool is not installed",
+    )
+    def test_conv_relu6_u55_BI(self, test_data: torch.Tensor):
+        model = ComboConvRelu6()
+        test_data = (test_data,)
+        self._test_conv_combo_u55_BI_pipeline(model, test_data)
+
     ###############################
     ## Block bottleneck residual ##
     ###############################
     def test_block_bottleneck_residual_tosa_MI(self):
         model = ComboBlockBottleneckResidual()
         self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())

-    # TODO(MLETORCH-85): Investigate numerical issue. This diff was present in legacy
-    # testcase as well. For now, just increase the tolerance, such that
-    # we don't skip the test entirely (i.e. we maintain functionality).
     def test_block_bottleneck_residual_tosa_BI(self):
         model = ComboBlockBottleneckResidual()
-        self._test_conv_combo_tosa_BI_pipeline(
-            model, model.get_inputs(), atol=1.0, rtol=1.0
-        )
+        self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())

     @unittest.skipIf(
         not common.VELA_INSTALLED,
29 changes: 25 additions & 4 deletions backends/arm/tosa_quant_utils.py
@@ -6,8 +6,10 @@
 # Utiliy functions for TOSA quantized lowerings

 import math
+from typing import NamedTuple

 import serializer.tosa_serializer as ts
+import torch.fx
 from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.exir.dialects._ops import ops as exir_ops
 from serializer.tosa_serializer import TosaOp, TosaSerializerTensor
@@ -17,7 +19,14 @@
 dq_q_ops = [q_op, dq_op]


-def is_quant_node(node):
+class QuantArgs(NamedTuple):
+    scale: float
+    zp: int
+    qmin: int
+    qmax: int
+
+
+def is_quant_node(node: torch.fx.Node):
     consumer_node = list(node.users)[0]
     input = node.all_input_nodes[0]

@@ -41,10 +50,22 @@ def is_quant_arg(arg):
     return consumer_node.target == q_op


-def get_quant_node_args(node):
+def get_quant_node_args(node: torch.fx.Node):
+    """
+    Get the quantization parameters from a quant node.
+
+    Args:
+        node: The quant node.
+    Returns:
+        QuantArgs: scale, zp, qmin, qmax
+    """
     quant_args = [TosaArg(arg) for arg in node.args]
-    # Return the scale and zp
-    return quant_args[1].number, quant_args[2].number
+    return QuantArgs(
+        quant_args[1].number,
+        quant_args[2].number,
+        quant_args[3].number,
+        quant_args[4].number,
+    )


 # Check if scale32 mode is used for given output element type
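The switch from positional tuples to a NamedTuple makes call sites self-describing and lets the new qmin/qmax fields ride along without breaking callers that only need scale and zp. A small self-contained sketch of the access pattern (values are illustrative):

```python
from typing import NamedTuple

class QuantArgs(NamedTuple):
    scale: float
    zp: int
    qmin: int
    qmax: int

args = QuantArgs(scale=0.1, zp=-128, qmin=-128, qmax=127)  # illustrative values

# Old style: positional unpacking, easy to misorder silently.
scale, zp = args[0], args[1]

# New style: named access reads as what it means.
input_zp = args.zp
input_scale = args.scale
# New consumers can also pick up the integer range:
span = args.qmax - args.qmin  # 255
```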
2 changes: 1 addition & 1 deletion backends/xnnpack/test/tester/tester.py
@@ -595,7 +595,7 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
                 f"Output {i} does not match reference output.\n"
                 f"\tGiven atol: {atol}, rtol: {rtol}.\n"
                 f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
-                f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}.\n"
+                f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n"
                 f"\t-- Model vs. Reference --\n"
                 f"\t Numel: {model.numel()}, {ref.numel()}\n"
                 f"\tMedian: {model.median()}, {ref.median()}\n"
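The extended message adds a mean-absolute-error figure next to the existing max statistics, which helps distinguish a single outlier from a systematic offset. A standalone sketch of the three metrics (tensor values are illustrative):

```python
import torch

model = torch.tensor([1.0, 2.0, 3.5])
ref = torch.tensor([1.0, 2.1, 3.0])

max_diff = torch.max(model - ref)              # largest signed difference
max_abs = torch.max(torch.abs(model - ref))    # worst-case absolute error
mean_abs = torch.mean(torch.abs(model - ref))  # new: average absolute error
print(f"max: {max_diff}, abs: {max_abs}, mean abs error: {mean_abs}")
```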