NXP backend: Add NeutronQuantizer #9876

Merged

205 changes: 205 additions & 0 deletions backends/nxp/quantizer/neutron_quantizer.py
@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple, Union

import torch

from executorch.backends.nxp.quantizer.patterns import (
AddmmPattern,
AvgPoolPattern,
Conv1dPattern,
Conv2dPattern,
LinearPattern,
MaxPoolPattern,
PadPattern,
PermutePattern,
QuantizationPattern,
ReluInPlacePattern,
ReluPattern,
ReshapePattern,
SoftMaxPattern,
)
from executorch.backends.nxp.quantizer.utils import (
find_sequential_partitions_aten,
is_annotated,
no_outside_users,
)
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
OperatorConfig,
QuantizationAnnotation,
QuantizationConfig,
QuantizationSpec,
)
from torch import fx
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer


class NeutronAtenQuantizer(Quantizer):
def __init__(
self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
) -> None:
super().__init__()
self.pattern = pattern
self.quantization_config = quantization_config

def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
fused_partitions = find_sequential_partitions_aten(
model,
self.pattern.partition_types(),
)

input_act_qspec = self.quantization_config.input_activation
weight_qspec = self.quantization_config.weight
bias_qspec = self.quantization_config.bias
output_act_qspec = self.quantization_config.output_activation

for fused_partition in fused_partitions:
if not no_outside_users(fused_partition):
continue

anchors = self.pattern.get_anchors(model, fused_partition)
if not anchors or anchors.empty:
continue
if is_annotated(
[
x[0]
for x in anchors.inputs
+ anchors.weights
+ anchors.biases
+ anchors.output
]
):
continue

for output, *custom_spec in anchors.output:
# pyre-ignore[16]: no attribute
output.meta["quantization_annotation"] = QuantizationAnnotation(
# pyre-ignore[6]: incompatible parameter type
output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
_annotated=True,
)

def annotate_inputs(
inputs: Union[
List[Tuple[fx.Node, int]],
List[Tuple[fx.Node, int, DerivedQuantizationSpec]],
],
spec: Optional[QuantizationSpec],
) -> None:
for node, idx, *custom_spec in inputs:
# pyre-ignore[16]: no attribute
annotation = node.meta.get(
"quantization_annotation",
QuantizationAnnotation(_annotated=True),
)
arg = (
# pyre-ignore[16]: no attribute
node.args[idx]
if isinstance(idx, int)
# pyre-ignore[16]: no attribute
else node.args[idx[0]][idx[1]]
)
annotation.input_qspec_map[arg] = (
custom_spec[0] if custom_spec else spec
)
# pyre-ignore[16]: no attribute
node.meta["quantization_annotation"] = annotation

def annotate_weights_or_biases(
weights_or_biases: List[Tuple[fx.Node, int]],
spec: Optional[QuantizationSpec],
) -> None:
for node, idx, *custom_spec in weights_or_biases:
annotation = node.meta.get(
"quantization_annotation",
QuantizationAnnotation(_annotated=True),
)
annotation.input_qspec_map[node.args[idx]] = (
custom_spec[0] if custom_spec else spec
)
node.meta["quantization_annotation"] = annotation

# pyre-ignore[6]: incompatible parameter type
annotate_inputs(anchors.inputs, input_act_qspec)
annotate_weights_or_biases(anchors.weights, weight_qspec)
# pyre-ignore[6]: incompatible parameter type
annotate_weights_or_biases(anchors.biases, bias_qspec)
return model

def validate(self, model: fx.GraphModule) -> None:
pass

@classmethod
def get_supported_operators(cls) -> List[OperatorConfig]:
return []


# Quantization specification used by the Neutron NPU
act_qspec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)

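# Weights use a symmetric scheme with quant_min=-127 (rather than -128) so the
# int8 range stays symmetric around zero.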
wgt_qspec = QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=torch.per_tensor_symmetric,
is_dynamic=False,
observer_or_fake_quant_ctr=MinMaxObserver,
ch_axis=0,
)

wgt_fc_qspec = QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
qscheme=torch.per_tensor_symmetric,
is_dynamic=False,
observer_or_fake_quant_ctr=MinMaxObserver,
)

# Set directly by the individual *Pattern quantizers.
bias_qspec = None


class NeutronQuantizer(ComposableQuantizer):
def __init__(self):
static_qconfig = QuantizationConfig(
act_qspec,
act_qspec,
wgt_qspec,
None,
)
static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
super().__init__(
[
NeutronAtenQuantizer(AddmmPattern(), static_fc_qconfig),
NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
NeutronAtenQuantizer(Conv2dPattern(), static_qconfig),
NeutronAtenQuantizer(LinearPattern(), static_fc_qconfig),
NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
NeutronAtenQuantizer(PermutePattern(), static_qconfig),
NeutronAtenQuantizer(PadPattern(), static_qconfig),
NeutronAtenQuantizer(ReluPattern(), static_qconfig),
NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
]
)

def transform_for_annotation(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
return model
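
For reference, a minimal end-to-end usage sketch of the quantizer above, following the standard PT2E flow used in the tests later in this PR (the model and input below are placeholders, not part of this diff):

# Usage sketch (illustrative): quantize a toy model with NeutronQuantizer.
import torch
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Sequential(torch.nn.Conv2d(4, 8, 3), torch.nn.ReLU()).eval()
example_input = (torch.randn(1, 4, 32, 32),)

exported = torch.export.export_for_training(model, example_input, strict=True).module()
prepared = prepare_pt2e(exported, NeutronQuantizer())  # insert observers
prepared(*example_input)                               # calibrate on sample data
quantized = convert_pt2e(prepared)                     # rewrite into QDQ form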
342 changes: 342 additions & 0 deletions backends/nxp/quantizer/patterns.py
@@ -0,0 +1,342 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union

import torch

from executorch.backends.nxp.quantizer.utils import get_bias_qparams
from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization.quantizer import (
DerivedQuantizationSpec,
FixedQParamsQuantizationSpec,
SharedQuantizationSpec,
)


@dataclass
class PartitionAnchors:
"""
All fields except output are lists of (node, args_index) pairs, where node is from
the given partition and node.args[args_index] is an input to the partition. Assumes
a single output.

The quantizer uses inputs, weights and biases for quantization annotation. The
others field contains tensor inputs that aren't quantized, and the literals field
is used for other types of input values as well as for handling default parameters.
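
Example (illustrative): for a conv2d node `n` with a bias, a pattern typically
produces
    PartitionAnchors(inputs=[(n, 0)], weights=[(n, 1)],
                     biases=[(n, 2, derived_bias_qspec)], output=[(n,)])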
"""

# Inputs can share quantization parameters
inputs: List[
Union[
Tuple[fx.Node, Union[int, Tuple[int, int]]],
Tuple[
fx.Node,
Union[int, Tuple[int, int]],
SharedQuantizationSpec,
],
]
] = field(default_factory=list)
weights: List[Tuple[fx.Node, int]] = field(default_factory=list)
biases: List[
Union[Tuple[fx.Node, int], Tuple[fx.Node, int, DerivedQuantizationSpec]]
] = field(default_factory=list)
others: List[Tuple[fx.Node, int]] = field(default_factory=list)
literals: List[Tuple[fx.Node, int]] = field(default_factory=list)
output: List[Union[Tuple[fx.Node], Tuple[fx.Node, SharedQuantizationSpec]]] = field(
default_factory=list
)
empty: bool = False


class QuantizationPattern(ABC):
@abstractmethod
def partition_types(self) -> List[OpOverload]:
"""
List of types to be passed to find_sequential_partitions_aten.
"""
pass

@abstractmethod
def get_anchors(
self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> Optional[PartitionAnchors]:
pass


class SharedSpecPattern(QuantizationPattern):
"""
Quantization pattern for shared quantization.
The quantization parameters are derived from the previous node, and the input and
output share the same quantization parameters (scale and zero-point).
"""

def partition_types(self) -> List[OpOverload]:
pass

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors | None:
node = fused_partition[0].nodes[-1]
assert len(fused_partition[0].input_nodes) == 1
prev_node = fused_partition[0].input_nodes[0]

# Previous node was not quantized => we are not able to share q-params
if "quantization_annotation" not in prev_node.meta:
return None

qspec = SharedQuantizationSpec(prev_node)

return PartitionAnchors(
inputs=[(node, 0)],
weights=[],
biases=[],
output=[
(node, qspec),
],
)


class AddmmPattern(QuantizationPattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.addmm.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
addmm_node = fused_partition[0].nodes[-1]

bias_qspec = DerivedQuantizationSpec(
derived_from=[
(addmm_node.args[1], addmm_node),
(addmm_node.args[2], addmm_node),
],
derive_qparams_fn=get_bias_qparams,
dtype=torch.int32,
quant_min=-(2**31),
quant_max=2**31 - 1,
qscheme=torch.per_tensor_affine,
)

return PartitionAnchors(
inputs=[(addmm_node, 1)],
weights=[(addmm_node, 2)],
biases=[(addmm_node, 0, bias_qspec)],
output=[(addmm_node,)],
)


class AvgPoolPattern(SharedSpecPattern):
"""
Quantizer for AvgPool2D operator.
"""

def partition_types(self):
return [torch.ops.aten.avg_pool2d.default]


class Conv1dPattern(QuantizationPattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv1d.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
conv1d_node = fused_partition[0].nodes[-1]

bias_qspec = DerivedQuantizationSpec(
derived_from=[
(conv1d_node.args[0], conv1d_node),
(conv1d_node.args[1], conv1d_node),
],
derive_qparams_fn=get_bias_qparams,
dtype=torch.int32,
quant_min=-(2**31),
quant_max=2**31 - 1,
qscheme=torch.per_tensor_affine,
)

# Keep bias empty if not supplied
bias = []
if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
bias = [(conv1d_node, 2, bias_qspec)]

return PartitionAnchors(
inputs=[(conv1d_node, 0)],
weights=[(conv1d_node, 1)],
# pyre-fixme[6]: Incompatible parameter type
biases=bias,
output=[(conv1d_node,)],
)


class Conv2dPattern(QuantizationPattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv2d.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
conv2d_node = fused_partition[0].nodes[-1]

bias_qspec = DerivedQuantizationSpec(
derived_from=[
(conv2d_node.args[0], conv2d_node),
(conv2d_node.args[1], conv2d_node),
],
derive_qparams_fn=get_bias_qparams,
dtype=torch.int32,
quant_min=-(2**31),
quant_max=2**31 - 1,
qscheme=torch.per_tensor_affine,
)

# Keep bias empty if not supplied
bias = []
if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
bias = [(conv2d_node, 2, bias_qspec)]

return PartitionAnchors(
inputs=[(conv2d_node, 0)],
weights=[(conv2d_node, 1)],
# pyre-fixme[6]: Incompatible parameter type
biases=bias,
output=[(conv2d_node,)],
)


class LinearPattern(QuantizationPattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.linear.default]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
linear_node = fused_partition[0].nodes[-1]

bias_qspec = DerivedQuantizationSpec(
derived_from=[
(linear_node.args[0], linear_node),
(linear_node.args[1], linear_node),
],
derive_qparams_fn=get_bias_qparams,
dtype=torch.int32,
quant_min=-(2**31),
quant_max=2**31 - 1,
qscheme=torch.per_tensor_affine,
)

# Keep bias empty if not supplied
bias = []
if len(linear_node.args) > 2:
bias = [(linear_node, 2, bias_qspec)]

return PartitionAnchors(
inputs=[(linear_node, 0)],
weights=[(linear_node, 1)],
# pyre-fixme[6]: Incompatible parameter type
biases=bias,
output=[(linear_node,)],
)


class MaxPoolPattern(SharedSpecPattern):
"""
Quantizer for MaxPool2D operator.
"""

def partition_types(self):
return [torch.ops.aten.max_pool2d.default]


class PadPattern(SharedSpecPattern):
"""
Quantizer for Pad operator.
"""

def partition_types(self):
return [torch.ops.aten.pad.default]


class PermutePattern(SharedSpecPattern):
"""
Quantizer for Permute operator.
"""

def partition_types(self):
return [torch.ops.aten.permute.default]


class ReluPattern(SharedSpecPattern):
"""
Quantizer for Relu operator. Shared quantization spec is selected, as ReLU usually follows a computation layer.
"""

def partition_types(self):
return [torch.ops.aten.relu.default]


class ReluInPlacePattern(SharedSpecPattern):
"""
Quantizer for Relu operator with inplace=True. Shared quantization spec is selected, as ReLU usually
follows a computation layer.
"""

def partition_types(self):
return [torch.ops.aten.relu_.default]


class ReshapePattern(SharedSpecPattern):
"""
Quantizer for Reshape operator.
"""

def partition_types(self):
return [torch.ops.aten.reshape.default]


class SoftMaxPattern(QuantizationPattern):
"""
Quantizer for Softmax operator.
The quantization of Softmax output is fixed to scale 1/256, zero point -128, dtype int8.
"""

def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.softmax.int]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
node = fused_partition[0].nodes[-1]
assert len(fused_partition[0].input_nodes) == 1

qspec = FixedQParamsQuantizationSpec(
dtype=torch.int8,
scale=1.0 / 256.0,
zero_point=-128,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
)

return PartitionAnchors(
inputs=[(node, 0)],
weights=[],
biases=[],
output=[
(node, qspec),
],
)
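
To illustrate the pattern API above, support for another value-preserving operator could be added by subclassing SharedSpecPattern and listing its aten overload (a hypothetical sketch, not part of this diff):

# Hypothetical example: shared-spec quantization for flatten.
import torch
from executorch.backends.nxp.quantizer.patterns import SharedSpecPattern

class FlattenPattern(SharedSpecPattern):
    """Quantizer for Flatten operator; shares q-params with the preceding node."""

    def partition_types(self):
        return [torch.ops.aten.flatten.using_ints]

Such a pattern would also have to be registered in NeutronQuantizer's quantizer list, e.g. NeutronAtenQuantizer(FlattenPattern(), static_qconfig).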
151 changes: 151 additions & 0 deletions backends/nxp/quantizer/utils.py
@@ -0,0 +1,151 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools
from collections import OrderedDict
from typing import Any, Dict, List, Tuple, Type

import torch
from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization import ObserverOrFakeQuantize
from torch.fx.passes.utils.source_matcher_utils import (
check_subgraphs_connected,
SourcePartition,
)


def is_annotated(nodes: List[fx.Node]) -> bool:
annotated = False
for node in nodes:
annotated = annotated or (
"quantization_annotation" in node.meta
and node.meta["quantization_annotation"]._annotated
)
return annotated


def no_outside_users(fused_partition) -> bool:
"""
Checks that every partition except the last has a single output node with a single
user, i.e. no partition output escapes the fused sequence.
"""
for source_partition in fused_partition[:-1]:
if len(source_partition.output_nodes) != 1:
return False
if len(source_partition.output_nodes[0].users) != 1:
return False
return True


def get_bias_qparams(
obs_or_fqs: List[ObserverOrFakeQuantize],
) -> Tuple[torch.Tensor, torch.Tensor]:
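    # Standard int32 bias derivation: bias_scale = act_scale * weight_scale with
    # zero point 0, so the quantized bias adds directly into the int32
    # accumulator of the int8 activation/weight product.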
act_scale, _ = obs_or_fqs[0].calculate_qparams()
weight_scale, _ = obs_or_fqs[1].calculate_qparams()
bias_scale = act_scale * weight_scale
bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
return bias_scale, bias_zero_point


def get_aten_node_target_partitions(
graph: torch.fx.Graph,
wanted_original_aten_op: List[OpOverload],
):
"""
Args:
graph: The graph we want to partition
wanted_original_aten_op: List of original_aten ops (OpOverload)
Returns:
Dictionary mapping the given aten ops to lists of SourcePartitions that
correspond to the nodes decomposed from those ops.
"""
modules: Dict[Type, Dict[str, List[torch.fx.Node]]] = {}

for node in graph.nodes:
# The metadata source_fn should contain a tuple of a unique name for the
# source, and the source function if the node is decomposed from a
# function, or the type of module if the node is decomposed from a leaf
# module
# TODO(matthiascremon): look into ways to avoid using source_fn_stack
if (source_fn_st := node.meta.get("source_fn_stack")) is None:
continue

source_fn = source_fn_st[-1]
if node.target not in wanted_original_aten_op:
continue

diff_modules = modules.setdefault(source_fn[1], {})
partition = diff_modules.setdefault(node.name, [])
partition.append(node)

def make_partition(
nodes: List[torch.fx.Node], module_type: Type
) -> SourcePartition:
input_nodes = set()
output_nodes = set()
params = set()
for node in nodes:
for arg in node.args:
if isinstance(arg, torch.fx.Node) and arg not in nodes:
input_nodes.add(arg)

if node.op == "get_attr":
params.add(node)

for user in node.users.keys():
if user not in nodes:
output_nodes.add(node)

return SourcePartition(
nodes,
module_type,
list(input_nodes),
list(output_nodes),
list(params), # type: ignore[arg-type]
)

ret: Dict[Type[Any], List[SourcePartition]] = {}

for k, v in modules.items():
ret[k] = [make_partition(partition, k) for partition in v.values()]

return ret


def _partitions_sequential(partitions: Tuple[SourcePartition]) -> bool:
prev_partition = None
for partition in partitions:
if prev_partition is not None and not check_subgraphs_connected(
prev_partition, partition
):
return False
prev_partition = partition
return True


def find_sequential_partitions_aten(
gm: torch.fx.GraphModule,
partition_types: List[Any],
):
typed_partitions: OrderedDict[Any, List[SourcePartition]] = OrderedDict()
for partition_type in partition_types:
partitions = get_aten_node_target_partitions(gm.graph, [partition_type])
typed_partitions[partition_type] = list(
itertools.chain.from_iterable(partitions.values())
)

typed_partitions_list = list(typed_partitions.values())
fusion_candidates = itertools.product(*typed_partitions_list)
fused_partitions = []
for candidate in fusion_candidates:
if _partitions_sequential(candidate):
fused_partitions.append(candidate)
return fused_partitions
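
As an illustration of the helpers above, sequential Conv2d -> ReLU partitions would be located as follows (a sketch assuming gm is a torch.fx.GraphModule produced by torch.export.export_for_training(...).module()):

# Illustrative only: pair each conv2d partition with the relu that consumes it.
fused = find_sequential_partitions_aten(
    gm,
    [torch.ops.aten.conv2d.default, torch.ops.aten.relu.default],
)
for conv_part, relu_part in fused:
    # Candidates survive only when check_subgraphs_connected() confirms the
    # relu partition consumes the conv partition's output.
    print(conv_part.nodes, relu_part.nodes)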
238 changes: 238 additions & 0 deletions backends/nxp/tests/models.py
@@ -0,0 +1,238 @@
# Copyright 2024 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Collection, Union

import torch


class Conv2dModule(torch.nn.Module):
def __init__(
self,
bias: bool = True,
dilation: Union[int, tuple[int, int]] = 1,
in_channels: int = 4,
kernel_size: Union[int, tuple[int, int]] = 3,
out_channels: int = 8,
padding: Union[str, int, Collection[int]] = 0,
stride: Union[int, tuple[int, int]] = 2,
):
super().__init__()

self.conv = torch.nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
)

def forward(self, x):
return self.conv(x)


class Conv2dAndMaxPool2DModule(torch.nn.Module):
def __init__(self):
super().__init__()

self.conv = torch.nn.Conv2d(
in_channels=8, out_channels=32, kernel_size=5, bias=True
)
self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2)

def forward(self, x):
x = self.conv(x)
return self.maxpool(x)


class Conv2dConstantPadNDModule(torch.nn.Module):
def __init__(self, paddings: Collection[int], constant: float | int | None = None):
super().__init__()
self.pad = ConstantPadNDModule(paddings, constant)
self.conv = Conv2dModule()

def forward(self, x):
x = self.conv(x)
return self.pad(x)


class SoftmaxModule(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()

self.softmax = torch.nn.Softmax(dim=dim)

def forward(self, x):
return self.softmax(x)


class SoftmaxConvModule(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()

self.conv = Conv2dModule()
self.softmax = SoftmaxModule(dim=dim)

def forward(self, x):
x = self.conv(x)
return self.softmax(x)


class LinearModule(torch.nn.Module):
def __init__(self, bias: bool):
super().__init__()
self.linear = torch.nn.Linear(32, 16, bias=bias)

def forward(self, x):
return self.linear(x)


class LinearSoftmaxModule(torch.nn.Module):
def __init__(self):
super().__init__()

self.linear = torch.nn.Linear(12, 10)
self.softmax = torch.nn.Softmax(1)

def forward(self, x):
x = self.linear(x)
x = self.softmax(x)

return x


class ConvFCSoftmaxModule(torch.nn.Module):
def __init__(self):
super().__init__()

self.conv = torch.nn.Conv2d(4, 64, 2, bias=False)
self.fc = torch.nn.Linear(1024, 10)
self.softmax = torch.nn.Softmax(1)

def forward(self, x):
x = self.conv(x)
x = torch.reshape(x, (-1, 1024))
x = self.fc(x)
x = self.softmax(x)

return x


class ConstantPadNDModule(torch.nn.Module):
def __init__(self, paddings: Collection[int], constant: float | int | None = None):
super().__init__()
self.paddings = paddings
self.constant = constant

def forward(self, x):
if self.constant is None:
return torch.nn.functional.pad(x, tuple(self.paddings), "constant")
else:
return torch.nn.functional.pad(
x, tuple(self.paddings), "constant", self.constant
)


class ConstantPadNDConvModule(torch.nn.Module):
def __init__(self, paddings: Collection[int], constant: float | int | None = None):
super().__init__()
self.pad = ConstantPadNDModule(paddings, constant)
self.conv = Conv2dModule()

def forward(self, x):
x = self.pad(x)
return self.conv(x)


class MaxPool2dModule(torch.nn.Module):
def __init__(self, padding=0):
super().__init__()

self.max_pool2d = torch.nn.MaxPool2d(
kernel_size=3, stride=2, padding=padding, dilation=1
)

def forward(self, x):
return self.max_pool2d(x)


class MaxPool2dConvModule(torch.nn.Module):
def __init__(self, padding=0):
super().__init__()

self.conv = Conv2dModule()
self.max_pool2d = torch.nn.MaxPool2d(
kernel_size=3, stride=2, padding=padding, dilation=1
)

def forward(self, x):
x = self.conv(x)
return self.max_pool2d(x)


class AvgPool2dModule(torch.nn.Module):
def __init__(self, count_include_pad, padding=0):
super().__init__()

self.avg_pool = torch.nn.AvgPool2d(
kernel_size=3,
stride=2,
padding=padding,
count_include_pad=count_include_pad,
)

def forward(self, x):
return self.avg_pool(x)


class AvgPool2dConvModule(torch.nn.Module):
def __init__(self, count_include_pad, padding=0):
super().__init__()

self.conv = Conv2dModule()
self.avg_pool = torch.nn.AvgPool2d(
kernel_size=3,
stride=1,
padding=padding,
count_include_pad=count_include_pad,
)

def forward(self, x):
x = self.conv(x)
return self.avg_pool(x)


class ReLUModule(torch.nn.Module):
def __init__(self):
super().__init__()

self.relu = torch.nn.ReLU()

def forward(self, x):
return self.relu(x)


class Conv2dReLUModule(torch.nn.Module):
def __init__(self):
super().__init__()

self.conv = torch.nn.Conv2d(4, 64, 2, bias=False)
self.relu = torch.nn.ReLU()

def forward(self, x):
x = self.conv(x)
return self.relu(x)


class Conv2dPermuteModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(4, 64, 2, bias=False)

def forward(self, x):
x = self.conv(x)
return torch.permute(x, [0, 2, 1, 3])
273 changes: 273 additions & 0 deletions backends/nxp/tests/test_quantizer.py
@@ -0,0 +1,273 @@
# Copyright 2024 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Tests for NeutronQuantizer.

import executorch.backends.nxp.tests.models as models
import torch
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


def _get_target_name(node):
return node._pretty_print_target(node.target)


def test_quantizer_conv2d():
model = models.Conv2dModule()
model.eval()

example_input = (torch.ones(1, 4, 32, 32),)
quantizer = NeutronQuantizer()
graph_module = torch.export.export_for_training(
model, example_input, strict=True
).module()

# noinspection PyTypeChecker
m = prepare_pt2e(graph_module, quantizer)
m(*example_input)
m = convert_pt2e(m)

# Dry run
m(*example_input)

nodes = list(m.graph.nodes)
assert len(nodes) == 11
assert nodes[7].name == "conv2d"
# [0]: input, [1]: weights, [2]: bias
assert (
_get_target_name(nodes[7].args[0])
== "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
)
assert (
_get_target_name(nodes[7].args[1])
== "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
)
assert (
_get_target_name(nodes[7].args[2])
== "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
)
assert (
_get_target_name(nodes[8])
== "torch.ops.quantized_decomposed.quantize_per_tensor.default"
)
assert nodes[8].args[0].name == "conv2d"


def test_quantizer_linear():
model = models.LinearModule(bias=True)
model.eval()

example_input = (torch.ones(10, 32),)
quantizer = NeutronQuantizer()
graph_module = torch.export.export_for_training(
model, example_input, strict=True
).module()

# noinspection PyTypeChecker
m = prepare_pt2e(graph_module, quantizer)
m(*example_input)
m = convert_pt2e(m)

# Dry run
m(*example_input)

nodes = list(m.graph.nodes)
assert len(nodes) == 11
assert nodes[7].name == "linear"
# [0]: input, [1]: weights, [2]: bias
assert (
_get_target_name(nodes[7].args[0])
== "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
)
assert (
_get_target_name(nodes[7].args[1])
== "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
)
assert (
_get_target_name(nodes[7].args[2])
== "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
)
assert (
_get_target_name(nodes[8])
== "torch.ops.quantized_decomposed.quantize_per_tensor.default"
)
assert nodes[8].args[0].name == "linear"


def test_quantizer_maxpool2d():
model = models.Conv2dAndMaxPool2DModule()
model.eval()

example_input = (torch.ones(1, 8, 32, 32),)
quantizer = NeutronQuantizer()
graph_module = torch.export.export_for_training(
model, example_input, strict=True
).module()

# noinspection PyTypeChecker
m = prepare_pt2e(graph_module, quantizer)
m(*example_input)
m = convert_pt2e(m)

# Dry run
m(*example_input)

nodes = list(m.graph.nodes)
assert len(nodes) == 14
# Check the QDQ pattern:
assert nodes[10].name == "max_pool2d"
assert (
_get_target_name(nodes[10].args[0])
== "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
)
assert (
_get_target_name(nodes[11])
== "torch.ops.quantized_decomposed.quantize_per_tensor.default"
)
assert nodes[11].args[0].name == "max_pool2d"

# Check that input and output quantization parameters are the same
input_quant = nodes[10].args[0].args[1:]
output_quant = nodes[11].args[1:]
assert input_quant == output_quant


def test_quantizer_softmax():
model = models.SoftmaxModule(dim=0)
model.eval()

example_input = (torch.ones(1, 10),)
quantizer = NeutronQuantizer()
graph_module = torch.export.export_for_training(
model, example_input, strict=True
).module()

# noinspection PyTypeChecker
m = prepare_pt2e(graph_module, quantizer)
m(*example_input)
m = convert_pt2e(m)

# Dry run
m(*example_input)

nodes = list(m.graph.nodes)
assert len(nodes) == 7
# Check the QDQ pattern:
assert nodes[3].name == "softmax"
assert (
_get_target_name(nodes[3].args[0])
== "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
)
assert (
_get_target_name(nodes[4])
== "torch.ops.quantized_decomposed.quantize_per_tensor.default"
)
assert nodes[4].args[0].name == "softmax"

# Check output quantization
scale, zp, _, _, dtype = nodes[4].args[1:]
assert scale == 1.0 / 256.0
assert zp == -128
assert dtype == torch.int8


def test_quantizer_single_maxpool2d():
model = models.MaxPool2dModule()
model.eval()

example_input = (torch.ones(1, 4, 32, 32),)
quantizer = NeutronQuantizer()
graph_module = torch.export.export_for_training(
model, example_input, strict=True
).module()

# noinspection PyTypeChecker
m = prepare_pt2e(graph_module, quantizer)
m(*example_input)
m = convert_pt2e(m)

# Dry run
m(*example_input)

nodes = list(m.graph.nodes)
assert len(nodes) == 3
assert nodes[1].name == "max_pool2d"
assert "quantization_annotation" not in nodes[1].meta


def test_quantizer_conv2d_relu():
model = models.Conv2dReLUModule()
model.eval()

example_input = (torch.ones(1, 4, 32, 32),)
quantizer = NeutronQuantizer()
graph_module = torch.export.export_for_training(
model, example_input, strict=True
).module()

# noinspection PyTypeChecker
m = prepare_pt2e(graph_module, quantizer)
m(*example_input)
m = convert_pt2e(m)

# Dry run
m(*example_input)

nodes = list(m.graph.nodes)
assert len(nodes) == 12
assert nodes[7].name == "dequantize_per_tensor_default_2"
assert nodes[8].name == "relu"
assert nodes[9].name == "quantize_per_tensor_default_3"


def test_quantizer_conv2d_avg_pool2d():
model = models.AvgPool2dConvModule(count_include_pad=False)
model.eval()

example_input = (torch.ones(1, 4, 16, 16),)
quantizer = NeutronQuantizer()
graph_module = torch.export.export_for_training(
model, example_input, strict=True
).module()

# noinspection PyTypeChecker
m = prepare_pt2e(graph_module, quantizer)
m(*example_input)
m = convert_pt2e(m)

# Dry run
m(*example_input)

nodes = list(m.graph.nodes)
assert len(nodes) == 14
assert nodes[9].name == "dequantize_per_tensor_default_3"
assert nodes[10].name == "avg_pool2d"
assert nodes[11].name == "quantize_per_tensor_default_4"


def test_quantizer_conv2d_permute():
model = models.Conv2dPermuteModule()
model.eval()

example_input = (torch.ones(1, 4, 16, 16),)
quantizer = NeutronQuantizer()
graph_module = torch.export.export_for_training(
model, example_input, strict=True
).module()

# noinspection PyTypeChecker
m = prepare_pt2e(graph_module, quantizer)
m(*example_input)
m = convert_pt2e(m)

# Dry run
m(*example_input)

nodes = list(m.graph.nodes)
assert len(nodes) == 12
assert nodes[7].name == "dequantize_per_tensor_default_2"
assert nodes[8].name == "permute"
assert nodes[9].name == "quantize_per_tensor_default_3"
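
The fixed softmax quantization asserted in test_quantizer_softmax can be sanity-checked numerically (a standalone sketch, not part of the test file):

# With scale=1/256 and zero_point=-128, int8 dequantization spans [0, 255/256],
# matching the [0, 1) range of softmax outputs.
import torch

scale, zero_point = 1.0 / 256.0, -128
q = torch.tensor([-128, 0, 127], dtype=torch.int32)
print((q - zero_point) * scale)  # tensor([0.0000, 0.5000, 0.9961])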