diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py
new file mode 100644
index 00000000000..eff7f513cb9
--- /dev/null
+++ b/backends/nxp/quantizer/neutron_quantizer.py
@@ -0,0 +1,205 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024-2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from executorch.backends.nxp.quantizer.patterns import (
+    AddmmPattern,
+    AvgPoolPattern,
+    Conv1dPattern,
+    Conv2dPattern,
+    LinearPattern,
+    MaxPoolPattern,
+    PadPattern,
+    PermutePattern,
+    QuantizationPattern,
+    ReluInPlacePattern,
+    ReluPattern,
+    ReshapePattern,
+    SoftMaxPattern,
+)
+from executorch.backends.nxp.quantizer.utils import (
+    find_sequential_partitions_aten,
+    is_annotated,
+    no_outside_users,
+)
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
+    OperatorConfig,
+    QuantizationAnnotation,
+    QuantizationConfig,
+    QuantizationSpec,
+)
+from torch import fx
+from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
+from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
+from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
+
+
+class NeutronAtenQuantizer(Quantizer):
+    def __init__(
+        self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
+    ) -> None:
+        super().__init__()
+        self.pattern = pattern
+        self.quantization_config = quantization_config
+
+    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        fused_partitions = find_sequential_partitions_aten(
+            model,
+            self.pattern.partition_types(),
+        )
+
+        input_act_qspec = self.quantization_config.input_activation
+        weight_qspec = self.quantization_config.weight
+        bias_qspec = self.quantization_config.bias
+        output_act_qspec = self.quantization_config.output_activation
+
+        for fused_partition in fused_partitions:
+            if not no_outside_users(fused_partition):
+                continue
+
+            anchors = self.pattern.get_anchors(model, fused_partition)
+            if not anchors or anchors.empty:
+                continue
+            if is_annotated(
+                [
+                    x[0]
+                    for x in anchors.inputs
+                    + anchors.weights
+                    + anchors.biases
+                    + anchors.output
+                ]
+            ):
+                continue
+
+            for output, *custom_spec in anchors.output:
+                # pyre-ignore[16]: no attribute
+                output.meta["quantization_annotation"] = QuantizationAnnotation(
+                    # pyre-ignore[6]: incompatible parameter type
+                    output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
+                    _annotated=True,
+                )
+
+            def annotate_inputs(
+                inputs: Union[
+                    List[Tuple[fx.Node, int]],
+                    List[Tuple[fx.Node, int, DerivedQuantizationSpec]],
+                ],
+                spec: Optional[QuantizationSpec],
+            ) -> None:
+                for node, idx, *custom_spec in inputs:
+                    # pyre-ignore[16]: no attribute
+                    annotation = node.meta.get(
+                        "quantization_annotation",
+                        QuantizationAnnotation(_annotated=True),
+                    )
+                    arg = (
+                        # pyre-ignore[16]: no attribute
+                        node.args[idx]
+                        if isinstance(idx, int)
+                        # pyre-ignore[16]: no attribute
+                        else node.args[idx[0]][idx[1]]
+                    )
+                    annotation.input_qspec_map[arg] = (
+                        custom_spec[0] if custom_spec else spec
+                    )
+                    # pyre-ignore[16]: no attribute
+                    node.meta["quantization_annotation"] = annotation
+
+            def annotate_weights_or_biases(
+                weights_or_biases: List[Tuple[fx.Node, int]],
+                spec: Optional[QuantizationSpec],
+            ) -> None:
+                for node, idx, *custom_spec in weights_or_biases:
+                    annotation = node.meta.get(
+                        "quantization_annotation",
+                        QuantizationAnnotation(_annotated=True),
+                    )
+                    annotation.input_qspec_map[node.args[idx]] = (
+                        custom_spec[0] if custom_spec else spec
+                    )
+                    node.meta["quantization_annotation"] = annotation
+
+            # pyre-ignore[6]: incompatible parameter type
+            annotate_inputs(anchors.inputs, input_act_qspec)
+            annotate_weights_or_biases(anchors.weights, weight_qspec)
+            # pyre-ignore[6]: incompatible parameter type
+            annotate_weights_or_biases(anchors.biases, bias_qspec)
+        return model
+
+    def validate(self, model: fx.GraphModule) -> None:
+        pass
+
+    @classmethod
+    def get_supported_operators(cls) -> List[OperatorConfig]:
+        return []
+
+
+# Quantization specification used by the Neutron NPU
+act_qspec = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_affine,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
+)
+
+wgt_qspec = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-127,
+    quant_max=127,
+    qscheme=torch.per_tensor_symmetric,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=MinMaxObserver,
+    ch_axis=0,
+)
+
+wgt_fc_qspec = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-127,
+    quant_max=127,
+    qscheme=torch.per_tensor_symmetric,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=MinMaxObserver,
+)
+
+# The bias quantization spec is derived by the individual pattern quantizers directly.
+bias_qspec = None
+
+
+class NeutronQuantizer(ComposableQuantizer):
+    def __init__(self):
+        static_qconfig = QuantizationConfig(
+            act_qspec,
+            act_qspec,
+            wgt_qspec,
+            None,
+        )
+        static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
+        super().__init__(
+            [
+                NeutronAtenQuantizer(AddmmPattern(), static_fc_qconfig),
+                NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
+                NeutronAtenQuantizer(Conv2dPattern(), static_qconfig),
+                NeutronAtenQuantizer(LinearPattern(), static_fc_qconfig),
+                NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
+                NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
+                NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
+                NeutronAtenQuantizer(PermutePattern(), static_qconfig),
+                NeutronAtenQuantizer(PadPattern(), static_qconfig),
+                NeutronAtenQuantizer(ReluPattern(), static_qconfig),
+                NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
+                NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
+            ]
+        )
+
+    def transform_for_annotation(
+        self, model: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        return model
diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py
new file mode 100644
index 00000000000..6797447c50c
--- /dev/null
+++ b/backends/nxp/quantizer/patterns.py
@@ -0,0 +1,342 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2025 NXP
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
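+
+"""Quantization patterns consumed by the NXP Neutron quantizer.
+
+Each pattern names the aten ops it matches and returns ``PartitionAnchors``
+describing which inputs, weights, biases and outputs the quantizer should
+annotate. As a minimal sketch, a new shared-spec pattern only needs a
+``partition_types`` override (``TanhPattern`` is a hypothetical example,
+not part of this change):
+
+    class TanhPattern(SharedSpecPattern):
+        def partition_types(self):
+            return [torch.ops.aten.tanh.default]
+
+plus a matching ``NeutronAtenQuantizer(TanhPattern(), static_qconfig)`` entry
+in ``NeutronQuantizer.__init__``.
+"""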
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import List, Optional, Tuple, Type, Union
+
+import torch
+
+from executorch.backends.nxp.quantizer.utils import get_bias_qparams
+from torch import fx
+from torch._ops import OpOverload
+from torch.ao.quantization.quantizer import (
+    DerivedQuantizationSpec,
+    FixedQParamsQuantizationSpec,
+    SharedQuantizationSpec,
+)
+
+
+@dataclass
+class PartitionAnchors:
+    """
+    All fields except output are lists of (node, args_index) pairs, where node is from
+    the given partition and node.args[args_index] is an input to the partition. Assumes
+    a single output.
+
+    The quantizer uses inputs, weights and biases for quantization annotation. The
+    others field contains tensor inputs that aren't quantized, and the literals field
+    is used for other types of input values as well as for handling default parameters.
+    """
+
+    # Inputs can share quantization parameters
+    inputs: List[
+        Union[
+            Tuple[fx.Node, Union[int, Tuple[int, int]]],
+            Tuple[
+                fx.Node,
+                Union[int, Tuple[int, int]],
+                SharedQuantizationSpec,
+            ],
+        ]
+    ] = field(default_factory=list)
+    weights: List[Tuple[fx.Node, int]] = field(default_factory=list)
+    biases: List[
+        Union[Tuple[fx.Node, int], Tuple[fx.Node, int, DerivedQuantizationSpec]]
+    ] = field(default_factory=list)
+    others: List[Tuple[fx.Node, int]] = field(default_factory=list)
+    literals: List[Tuple[fx.Node, int]] = field(default_factory=list)
+    output: List[Union[Tuple[fx.Node], Tuple[fx.Node, SharedQuantizationSpec]]] = field(
+        default_factory=list
+    )
+    empty: bool = False
+
+
+class QuantizationPattern(ABC):
+    @abstractmethod
+    def partition_types(self) -> list[OpOverload]:
+        """
+        List of types to be passed to find_sequential_partitions_aten.
+        """
+        pass
+
+    @abstractmethod
+    def get_anchors(
+        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> Optional[PartitionAnchors]:
+        pass
+
+
+class SharedSpecPattern(QuantizationPattern):
+    """
+    Quantization pattern for shared quantization.
+
+    The quantization parameters are derived from the previous node, and the input and
+    output share the same quantization parameters (scale and zero-point).
+    """
+
+    def partition_types(self) -> List[Type[torch.nn.Module]]:
+        pass
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors | None:
+        node = fused_partition[0].nodes[-1]
+        assert len(fused_partition[0].input_nodes) == 1
+        prev_node = fused_partition[0].input_nodes[0]
+
+        # Previous node was not quantized => we are not able to share q-params
+        if "quantization_annotation" not in prev_node.meta:
+            return None
+
+        qspec = SharedQuantizationSpec(prev_node)
+
+        return PartitionAnchors(
+            inputs=[(node, 0)],
+            weights=[],
+            biases=[],
+            output=[
+                (node, qspec),
+            ],
+        )
+
+
+class AddmmPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.addmm.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        addmm_node = fused_partition[0].nodes[-1]
+
+        bias_qspec = DerivedQuantizationSpec(
+            derived_from=[
+                (addmm_node.args[1], addmm_node),
+                (addmm_node.args[2], addmm_node),
+            ],
+            derive_qparams_fn=get_bias_qparams,
+            dtype=torch.int32,
+            quant_min=-(2**31),
+            quant_max=2**31 - 1,
+            qscheme=torch.per_tensor_affine,
+        )
+
+        return PartitionAnchors(
+            inputs=[(addmm_node, 1)],
+            weights=[(addmm_node, 2)],
+            biases=[(addmm_node, 0, bias_qspec)],
+            output=[(addmm_node,)],
+        )
+
+
+class AvgPoolPattern(SharedSpecPattern):
+    """
+    Quantizer for AvgPool2D operator.
+    """
+
+    def partition_types(self):
+        return [torch.ops.aten.avg_pool2d.default]
+
+
+class Conv1dPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv1d.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        conv1d_node = fused_partition[0].nodes[-1]
+
+        bias_qspec = DerivedQuantizationSpec(
+            derived_from=[
+                (conv1d_node.args[0], conv1d_node),
+                (conv1d_node.args[1], conv1d_node),
+            ],
+            derive_qparams_fn=get_bias_qparams,
+            dtype=torch.int32,
+            quant_min=-(2**31),
+            quant_max=2**31 - 1,
+            qscheme=torch.per_tensor_affine,
+        )
+
+        # Keep bias empty if not supplied
+        bias = []
+        if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
+            bias = [(conv1d_node, 2, bias_qspec)]
+
+        return PartitionAnchors(
+            inputs=[(conv1d_node, 0)],
+            weights=[(conv1d_node, 1)],
+            # pyre-fixme[6]: Incompatible parameter type
+            biases=bias,
+            output=[(conv1d_node,)],
+        )
+
+
+class Conv2dPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.conv2d.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        conv2d_node = fused_partition[0].nodes[-1]
+
+        bias_qspec = DerivedQuantizationSpec(
+            derived_from=[
+                (conv2d_node.args[0], conv2d_node),
+                (conv2d_node.args[1], conv2d_node),
+            ],
+            derive_qparams_fn=get_bias_qparams,
+            dtype=torch.int32,
+            quant_min=-(2**31),
+            quant_max=2**31 - 1,
+            qscheme=torch.per_tensor_affine,
+        )
+
+        # Keep bias empty if not supplied
+        bias = []
+        if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
+            bias = [(conv2d_node, 2, bias_qspec)]
+
+        return PartitionAnchors(
+            inputs=[(conv2d_node, 0)],
+            weights=[(conv2d_node, 1)],
+            # pyre-fixme[6]: Incompatible parameter type
+            biases=bias,
+            output=[(conv2d_node,)],
+        )
+
+
+class LinearPattern(QuantizationPattern):
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.linear.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
+        linear_node = fused_partition[0].nodes[-1]
+
+        bias_qspec = DerivedQuantizationSpec(
+            derived_from=[
+                (linear_node.args[0], linear_node),
+                (linear_node.args[1], linear_node),
+            ],
+            derive_qparams_fn=get_bias_qparams,
+            dtype=torch.int32,
+            quant_min=-(2**31),
+            quant_max=2**31 - 1,
+            qscheme=torch.per_tensor_affine,
+        )
+
+        # Keep bias empty if not supplied
+        bias = []
+        if len(linear_node.args) > 2:
+            bias = [(linear_node, 2, bias_qspec)]
+
+        return PartitionAnchors(
+            inputs=[(linear_node, 0)],
+            weights=[(linear_node, 1)],
+            # pyre-fixme[6]: Incompatible parameter type
+            biases=bias,
+            output=[(linear_node,)],
+        )
+
+
+class MaxPoolPattern(SharedSpecPattern):
+    """
+    Quantizer for MaxPool2D operator.
+    """
+
+    def partition_types(self):
+        return [torch.ops.aten.max_pool2d.default]
+
+
+class PadPattern(SharedSpecPattern):
+    """
+    Quantizer for Pad operator.
+    """
+
+    def partition_types(self):
+        return [torch.ops.aten.pad.default]
+
+
+class PermutePattern(SharedSpecPattern):
+    """
+    Quantizer for Permute operator.
+    """
+
+    def partition_types(self):
+        return [torch.ops.aten.permute.default]
+
+
+class ReluPattern(SharedSpecPattern):
+    """
+    Quantizer for Relu operator. A shared quantization spec is selected, as ReLU
+    usually follows a computation layer.
+    """
+
+    def partition_types(self):
+        return [torch.ops.aten.relu.default]
+
+
+class ReluInPlacePattern(SharedSpecPattern):
+    """
+    Quantizer for the Relu operator with inplace=True. A shared quantization spec is
+    selected, as ReLU usually follows a computation layer.
+    """
+
+    def partition_types(self):
+        return [torch.ops.aten.relu_.default]
+
+
+class ReshapePattern(SharedSpecPattern):
+    """
+    Quantizer for Reshape operator.
+    """
+
+    def partition_types(self):
+        return [torch.ops.aten.reshape.default]
+
+
+class SoftMaxPattern(QuantizationPattern):
+    """
+    Quantizer for Softmax operator.
+
+    The quantization of the Softmax output is fixed to scale 1/256, zero point -128,
+    dtype int8.
+    """
+
+    def partition_types(self) -> List[OpOverload]:
+        return [torch.ops.aten.softmax.int]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors:
+        node = fused_partition[0].nodes[-1]
+        assert len(fused_partition[0].input_nodes) == 1
+
+        qspec = FixedQParamsQuantizationSpec(
+            dtype=torch.int8,
+            scale=1.0 / 256.0,
+            zero_point=-128,
+            quant_min=-128,
+            quant_max=127,
+            qscheme=torch.per_tensor_affine,
+        )
+
+        return PartitionAnchors(
+            inputs=[(node, 0)],
+            weights=[],
+            biases=[],
+            output=[
+                (node, qspec),
+            ],
+        )
diff --git a/backends/nxp/quantizer/utils.py b/backends/nxp/quantizer/utils.py
new file mode 100644
index 00000000000..1effcdff25a
--- /dev/null
+++ b/backends/nxp/quantizer/utils.py
@@ -0,0 +1,151 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024-2025 NXP
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
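+
+# Partitioning helpers used by the Neutron quantizer. The main entry point is
+# find_sequential_partitions_aten; a minimal sketch of the call made from
+# NeutronAtenQuantizer.annotate (names taken from this change):
+#
+#     fused = find_sequential_partitions_aten(
+#         graph_module, [torch.ops.aten.conv2d.default]
+#     )
+#     for candidate in fused:
+#         ...  # each candidate is a tuple of connected SourcePartitions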
+
+# pyre-unsafe
+
+import itertools
+from collections import OrderedDict
+from typing import Any, Dict, List, Tuple, Type
+
+import torch
+from torch import fx
+from torch._ops import OpOverload
+from torch.ao.quantization import ObserverOrFakeQuantize
+from torch.fx.passes.utils.source_matcher_utils import (
+    check_subgraphs_connected,
+    SourcePartition,
+)
+
+
+def is_annotated(nodes: List[fx.Node]) -> bool:
+    annotated = False
+    for node in nodes:
+        annotated = annotated or (
+            "quantization_annotation" in node.meta
+            and node.meta["quantization_annotation"]._annotated
+        )
+    return annotated
+
+
+def no_outside_users(fused_partition) -> bool:
+    """
+    Checks that each partition other than the last has no users outside the
+    fused partition.
+    """
+    for source_partition in fused_partition[:-1]:
+        if len(source_partition.output_nodes) != 1:
+            return False
+        if len(source_partition.output_nodes[0].users) != 1:
+            return False
+    return True
+
+
+def get_bias_qparams(
+    obs_or_fqs: List[ObserverOrFakeQuantize],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    act_scale, _ = obs_or_fqs[0].calculate_qparams()
+    weight_scale, _ = obs_or_fqs[1].calculate_qparams()
+    bias_scale = act_scale * weight_scale
+    bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
+    return bias_scale, bias_zero_point
+
+
+def get_aten_node_target_partitions(
+    graph: torch.fx.Graph,
+    wanted_original_aten_op: List[OpOverload],
+):
+    """
+    Args:
+        graph: The graph we want to partition
+        wanted_original_aten_op: List of original_aten ops (OpOverload)
+
+    Returns:
+        Dictionary mapping the source (module type or function) of each matched
+        node to a list of SourcePartitions grouping the nodes that were decomposed
+        from the given aten ops.
+    """
+    modules: Dict[Type, Dict[str, List[torch.fx.Node]]] = {}
+
+    for node in graph.nodes:
+        # The metadata source_fn_stack should contain a tuple of a unique name
+        # for the source, and the source function if the node is decomposed from
+        # a function, or the type of module if the node is decomposed from a leaf
+        # module
+        # TODO(matthiascremon): look into ways to avoid using source_fn_stack
+        if (source_fn_st := node.meta.get("source_fn_stack")) is None:
+            continue
+
+        source_fn = source_fn_st[-1]
+        if node.target not in wanted_original_aten_op:
+            continue
+
+        diff_modules = modules.setdefault(source_fn[1], {})
+        partition = diff_modules.setdefault(node.name, [])
+        partition.append(node)
+
+    def make_partition(
+        nodes: List[torch.fx.Node], module_type: Type
+    ) -> SourcePartition:
+        input_nodes = set()
+        output_nodes = set()
+        params = set()
+        for node in nodes:
+            for arg in node.args:
+                if isinstance(arg, torch.fx.Node) and arg not in nodes:
+                    input_nodes.add(arg)
+
+            if node.op == "get_attr":
+                params.add(node)
+
+            for user in node.users.keys():
+                if user not in nodes:
+                    output_nodes.add(node)
+
+        return SourcePartition(
+            nodes,
+            module_type,
+            list(input_nodes),
+            list(output_nodes),
+            list(params),  # type: ignore[arg-type]
+        )
+
+    ret: Dict[Type[Any], List[SourcePartition]] = {}
+
+    for k, v in modules.items():
+        ret[k] = [make_partition(partition, k) for partition in v.values()]
+
+    return ret
+
+
+def _partitions_sequential(partitions: Tuple[SourcePartition]) -> bool:
+    prev_partition = None
+    for partition in partitions:
+        if prev_partition is not None and not check_subgraphs_connected(
+            prev_partition, partition
+        ):
+            return False
+        prev_partition = partition
+    return True
+
+
+def find_sequential_partitions_aten(
+    gm: torch.fx.GraphModule,
+    partition_types: List[Any],
+):
+    typed_partitions: OrderedDict[Any, List[SourcePartition]] = OrderedDict()
+    for partition_type in partition_types:
+        partitions = get_aten_node_target_partitions(gm.graph, [partition_type])
+        typed_partitions[partition_type] = list(
+            itertools.chain.from_iterable(partitions.values())
+        )
+
+    typed_partitions_list = list(typed_partitions.values())
+    fusion_candidates = itertools.product(*typed_partitions_list)
+    fused_partitions = []
+    for candidate in fusion_candidates:
+        if _partitions_sequential(candidate):
+            fused_partitions.append(candidate)
+    return fused_partitions
diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py
new file mode 100644
index 00000000000..741e64a28a1
--- /dev/null
+++ b/backends/nxp/tests/models.py
@@ -0,0 +1,238 @@
+# Copyright 2024 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Collection, Union
+
+import torch
+
+
+class Conv2dModule(torch.nn.Module):
+    def __init__(
+        self,
+        bias: bool = True,
+        dilation: Union[int, tuple[int, int]] = 1,
+        in_channels: int = 4,
+        kernel_size: Union[int, tuple[int, int]] = 3,
+        out_channels: int = 8,
+        padding: Union[str, int, Collection[int]] = 0,
+        stride: Union[int, tuple[int, int]] = 2,
+    ):
+        super().__init__()
+
+        self.conv = torch.nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class Conv2dAndMaxPool2DModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.conv = torch.nn.Conv2d(
+            in_channels=8, out_channels=32, kernel_size=5, bias=True
+        )
+        self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return self.maxpool(x)
+
+
+class Conv2dConstantPadNDModule(torch.nn.Module):
+    def __init__(self, paddings: Collection[int], constant: float | int | None = None):
+        super().__init__()
+        self.pad = ConstantPadNDModule(paddings, constant)
+        self.conv = Conv2dModule()
+
+    def forward(self, x):
+        x = self.conv(x)
+        return self.pad(x)
+
+
+class SoftmaxModule(torch.nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+
+        self.softmax = torch.nn.Softmax(dim=dim)
+
+    def forward(self, x):
+        return self.softmax(x)
+
+
+class SoftmaxConvModule(torch.nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+
+        self.conv = Conv2dModule()
+        self.softmax = SoftmaxModule(dim=dim)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return self.softmax(x)
+
+
+class LinearModule(torch.nn.Module):
+    def __init__(self, bias: bool):
+        super().__init__()
+        self.linear = torch.nn.Linear(32, 16, bias=bias)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
+class LinearSoftmaxModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.linear = torch.nn.Linear(12, 10)
+        self.softmax = torch.nn.Softmax(1)
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = self.softmax(x)
+
+        return x
+
+
+class ConvFCSoftmaxModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.conv = torch.nn.Conv2d(4, 64, 2, bias=False)
+        self.fc = torch.nn.Linear(1024, 10)
+        self.softmax = torch.nn.Softmax(1)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = torch.reshape(x, (-1, 1024))
+        x = self.fc(x)
+        x = self.softmax(x)
+
+        return x
+
+
+class ConstantPadNDModule(torch.nn.Module):
+    def __init__(self, paddings: Collection[int], constant: float | int | None = None):
+        super().__init__()
+        self.paddings = paddings
+        self.constant = constant
+
+    def forward(self, x):
+        if self.constant is None:
+            return torch.nn.functional.pad(x, tuple(self.paddings), "constant")
+        else:
+            return torch.nn.functional.pad(
+                x, tuple(self.paddings), "constant", self.constant
+            )
+
+
+class ConstantPadNDConvModule(torch.nn.Module):
+    def __init__(self, paddings: Collection[int], constant: float | int | None = None):
+        super().__init__()
+        self.pad = ConstantPadNDModule(paddings, constant)
+        self.conv = Conv2dModule()
+
+    def forward(self, x):
+        x = self.pad(x)
+        return self.conv(x)
+
+
+class MaxPool2dModule(torch.nn.Module):
+    def __init__(self, padding=0):
+        super().__init__()
+
+        self.max_pool2d = torch.nn.MaxPool2d(
+            kernel_size=3, stride=2, padding=padding, dilation=1
+        )
+
+    def forward(self, x):
+        return self.max_pool2d(x)
+
+
+class MaxPool2dConvModule(torch.nn.Module):
+    def __init__(self, padding=0):
+        super().__init__()
+
+        self.conv = Conv2dModule()
+        self.max_pool2d = torch.nn.MaxPool2d(
+            kernel_size=3, stride=2, padding=padding, dilation=1
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        return self.max_pool2d(x)
+
+
+class AvgPool2dModule(torch.nn.Module):
+    def __init__(self, count_include_pad, padding=0):
+        super().__init__()
+
+        self.avg_pool = torch.nn.AvgPool2d(
+            kernel_size=3,
+            stride=2,
+            padding=padding,
+            count_include_pad=count_include_pad,
+        )
+
+    def forward(self, x):
+        return self.avg_pool(x)
+
+
+class AvgPool2dConvModule(torch.nn.Module):
+    def __init__(self, count_include_pad, padding=0):
+        super().__init__()
+
+        self.conv = Conv2dModule()
+        self.avg_pool = torch.nn.AvgPool2d(
+            kernel_size=3,
+            stride=1,
+            padding=padding,
+            count_include_pad=count_include_pad,
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        return self.avg_pool(x)
+
+
+class ReLUModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        return self.relu(x)
+
+
+class Conv2dReLUModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.conv = torch.nn.Conv2d(4, 64, 2, bias=False)
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        x = self.conv(x)
+        return self.relu(x)
+
+
+class Conv2dPermuteModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(4, 64, 2, bias=False)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return torch.permute(x, [0, 2, 1, 3])
diff --git a/backends/nxp/tests/test_quantizer.py b/backends/nxp/tests/test_quantizer.py
new file mode 100644
index 00000000000..868a94059b5
--- /dev/null
+++ b/backends/nxp/tests/test_quantizer.py
@@ -0,0 +1,273 @@
+# Copyright 2024 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Tests for NeutronQuantizer.
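+#
+# Every test below follows the same PT2E flow, sketched here once:
+#
+#     graph_module = torch.export.export_for_training(
+#         model, example_input, strict=True
+#     ).module()
+#     m = prepare_pt2e(graph_module, NeutronQuantizer())
+#     m(*example_input)  # calibrate the observers
+#     m = convert_pt2e(m)  # materializes the quantize/dequantize ops asserted on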
+
+import executorch.backends.nxp.tests.models as models
+import torch
+from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+
+def _get_target_name(node):
+    return node._pretty_print_target(node.target)
+
+
+def test_quantizer_conv2d():
+    model = models.Conv2dModule()
+    model.eval()
+
+    example_input = (torch.ones(1, 4, 32, 32),)
+    quantizer = NeutronQuantizer()
+    graph_module = torch.export.export_for_training(
+        model, example_input, strict=True
+    ).module()
+
+    # noinspection PyTypeChecker
+    m = prepare_pt2e(graph_module, quantizer)
+    m(*example_input)
+    m = convert_pt2e(m)
+
+    # Dry run
+    m(*example_input)
+
+    nodes = list(m.graph.nodes)
+    assert len(nodes) == 11
+    assert nodes[7].name == "conv2d"
+    # [0]: input, [1]: weights, [2]: bias
+    assert (
+        _get_target_name(nodes[7].args[0])
+        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
+    )
+    assert (
+        _get_target_name(nodes[7].args[1])
+        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
+    )
+    assert (
+        _get_target_name(nodes[7].args[2])
+        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
+    )
+    assert (
+        _get_target_name(nodes[8])
+        == "torch.ops.quantized_decomposed.quantize_per_tensor.default"
+    )
+    assert nodes[8].args[0].name == "conv2d"
+
+
+def test_quantizer_linear():
+    model = models.LinearModule(bias=True)
+    model.eval()
+
+    example_input = (torch.ones(10, 32),)
+    quantizer = NeutronQuantizer()
+    graph_module = torch.export.export_for_training(
+        model, example_input, strict=True
+    ).module()
+
+    # noinspection PyTypeChecker
+    m = prepare_pt2e(graph_module, quantizer)
+    m(*example_input)
+    m = convert_pt2e(m)
+
+    # Dry run
+    m(*example_input)
+
+    nodes = list(m.graph.nodes)
+    assert len(nodes) == 11
+    assert nodes[7].name == "linear"
+    # [0]: input, [1]: weights, [2]: bias
+    assert (
+        _get_target_name(nodes[7].args[0])
+        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
+    )
+    assert (
+        _get_target_name(nodes[7].args[1])
+        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
+    )
+    assert (
+        _get_target_name(nodes[7].args[2])
+        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
+    )
+    assert (
+        _get_target_name(nodes[8])
+        == "torch.ops.quantized_decomposed.quantize_per_tensor.default"
+    )
+    assert nodes[8].args[0].name == "linear"
+
+
+def test_quantizer_maxpool2d():
+    model = models.Conv2dAndMaxPool2DModule()
+    model.eval()
+
+    example_input = (torch.ones(1, 8, 32, 32),)
+    quantizer = NeutronQuantizer()
+    graph_module = torch.export.export_for_training(
+        model, example_input, strict=True
+    ).module()
+
+    # noinspection PyTypeChecker
+    m = prepare_pt2e(graph_module, quantizer)
+    m(*example_input)
+    m = convert_pt2e(m)
+
+    # Dry run
+    m(*example_input)
+
+    nodes = list(m.graph.nodes)
+    assert len(nodes) == 14
+    # Check the QDQ pattern:
+    assert nodes[10].name == "max_pool2d"
+    assert (
+        _get_target_name(nodes[10].args[0])
+        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
+    )
+    assert (
+        _get_target_name(nodes[11])
+        == "torch.ops.quantized_decomposed.quantize_per_tensor.default"
+    )
+    assert nodes[11].args[0].name == "max_pool2d"
+
+    # Check that the input and output quantization parameters are the same
+    input_quant = nodes[10].args[0].args[1:]
+    output_quant = nodes[11].args[1:]
+    assert input_quant == output_quant
+
+
+def test_quantizer_softmax():
+    model = models.SoftmaxModule(dim=0)
+    model.eval()
+
+    example_input = (torch.ones(1, 10),)
+    quantizer = NeutronQuantizer()
+    graph_module = torch.export.export_for_training(
+        model, example_input, strict=True
+    ).module()
+
+    # noinspection PyTypeChecker
+    m = prepare_pt2e(graph_module, quantizer)
+    m(*example_input)
+    m = convert_pt2e(m)
+
+    # Dry run
+    m(*example_input)
+
+    nodes = list(m.graph.nodes)
+    assert len(nodes) == 7
+    # Check the QDQ pattern:
+    assert nodes[3].name == "softmax"
+    assert (
+        _get_target_name(nodes[3].args[0])
+        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
+    )
+    assert (
+        _get_target_name(nodes[4])
+        == "torch.ops.quantized_decomposed.quantize_per_tensor.default"
+    )
+    assert nodes[4].args[0].name == "softmax"
+
+    # Check the output quantization parameters
+    scale, zp, _, _, dtype = nodes[4].args[1:]
+    assert scale == 1.0 / 256.0
+    assert zp == -128
+    assert dtype == torch.int8
+
+
+def test_quantizer_single_maxpool2d():
+    model = models.MaxPool2dModule()
+    model.eval()
+
+    example_input = (torch.ones(1, 4, 32, 32),)
+    quantizer = NeutronQuantizer()
+    graph_module = torch.export.export_for_training(
+        model, example_input, strict=True
+    ).module()
+
+    # noinspection PyTypeChecker
+    m = prepare_pt2e(graph_module, quantizer)
+    m(*example_input)
+    m = convert_pt2e(m)
+
+    # Dry run
+    m(*example_input)
+
+    nodes = list(m.graph.nodes)
+    assert len(nodes) == 3
+    assert nodes[1].name == "max_pool2d"
+    assert "quantization_annotation" not in nodes[1].meta
+
+
+def test_quantizer_conv2d_relu():
+    model = models.Conv2dReLUModule()
+    model.eval()
+
+    example_input = (torch.ones(1, 4, 32, 32),)
+    quantizer = NeutronQuantizer()
+    graph_module = torch.export.export_for_training(
+        model, example_input, strict=True
+    ).module()
+
+    # noinspection PyTypeChecker
+    m = prepare_pt2e(graph_module, quantizer)
+    m(*example_input)
+    m = convert_pt2e(m)
+
+    # Dry run
+    m(*example_input)
+
+    nodes = list(m.graph.nodes)
+    assert len(nodes) == 12
+    assert nodes[7].name == "dequantize_per_tensor_default_2"
+    assert nodes[8].name == "relu"
+    assert nodes[9].name == "quantize_per_tensor_default_3"
+
+
+def test_quantizer_conv2d_avg_pool2d():
+    model = models.AvgPool2dConvModule(count_include_pad=False)
+    model.eval()
+
+    example_input = (torch.ones(1, 4, 16, 16),)
+    quantizer = NeutronQuantizer()
+    graph_module = torch.export.export_for_training(
+        model, example_input, strict=True
+    ).module()
+
+    # noinspection PyTypeChecker
+    m = prepare_pt2e(graph_module, quantizer)
+    m(*example_input)
+    m = convert_pt2e(m)
+
+    # Dry run
+    m(*example_input)
+
+    nodes = list(m.graph.nodes)
+    assert len(nodes) == 14
+    assert nodes[9].name == "dequantize_per_tensor_default_3"
+    assert nodes[10].name == "avg_pool2d"
+    assert nodes[11].name == "quantize_per_tensor_default_4"
+
+
+def test_quantizer_conv2d_permute():
+    model = models.Conv2dPermuteModule()
+    model.eval()
+
+    example_input = (torch.ones(1, 4, 16, 16),)
+    quantizer = NeutronQuantizer()
+    graph_module = torch.export.export_for_training(
+        model, example_input, strict=True
+    ).module()
+
+    # noinspection PyTypeChecker
+    m = prepare_pt2e(graph_module, quantizer)
+    m(*example_input)
+    m = convert_pt2e(m)
+
+    # Dry run
+    m(*example_input)
+
+    nodes = list(m.graph.nodes)
+    assert len(nodes) == 12
+    assert nodes[7].name == "dequantize_per_tensor_default_2"
+    assert nodes[8].name == "permute"
+    assert nodes[9].name == "quantize_per_tensor_default_3"
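+
+
+# The quantize/dequantize nodes asserted above carry their quantization
+# parameters as trailing args. A sketch for inspecting them on any converted
+# module (assumes the arg layout used by the assertions in this file):
+#
+#     for node in m.graph.nodes:
+#         if "quantize_per_tensor" in node.name:
+#             scale, zero_point, qmin, qmax, dtype = node.args[1:]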