NXP backend: Add NeutronQuantizer #9876
Merged
digantdesai merged 1 commit into pytorch:main from nxp-upstream:feature/nxf93343/neutron-quantizer on Apr 14, 2025
+1,209 −0
@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple, Union

import torch

from executorch.backends.nxp.quantizer.patterns import (
    AddmmPattern,
    AvgPoolPattern,
    Conv1dPattern,
    Conv2dPattern,
    LinearPattern,
    MaxPoolPattern,
    PadPattern,
    PermutePattern,
    QuantizationPattern,
    ReluInPlacePattern,
    ReluPattern,
    ReshapePattern,
    SoftMaxPattern,
)
from executorch.backends.nxp.quantizer.utils import (
    find_sequential_partitions_aten,
    is_annotated,
    no_outside_users,
)
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
    OperatorConfig,
    QuantizationAnnotation,
    QuantizationConfig,
    QuantizationSpec,
)
from torch import fx
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer


class NeutronAtenQuantizer(Quantizer):
    def __init__(
        self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
    ) -> None:
        super().__init__()
        self.pattern = pattern
        self.quantization_config = quantization_config

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        fused_partitions = find_sequential_partitions_aten(
            model,
            self.pattern.partition_types(),
        )

        input_act_qspec = self.quantization_config.input_activation
        weight_qspec = self.quantization_config.weight
        bias_qspec = self.quantization_config.bias
        output_act_qspec = self.quantization_config.output_activation

        for fused_partition in fused_partitions:
            if not no_outside_users(fused_partition):
                continue

            anchors = self.pattern.get_anchors(model, fused_partition)
            if not anchors or anchors.empty:
                continue
            if is_annotated(
                [
                    x[0]
                    for x in anchors.inputs
                    + anchors.weights
                    + anchors.biases
                    + anchors.output
                ]
            ):
                continue

            for output, *custom_spec in anchors.output:
                # pyre-ignore[16]: no attribute
                output.meta["quantization_annotation"] = QuantizationAnnotation(
                    # pyre-ignore[6]: incompatible parameter type
                    output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
                    _annotated=True,
                )

            def annotate_inputs(
                inputs: Union[
                    List[Tuple[fx.Node, int]],
                    List[Tuple[fx.Node, int, DerivedQuantizationSpec],],
                ],
                spec: Optional[QuantizationSpec],
            ) -> None:
                for node, idx, *custom_spec in inputs:
                    # pyre-ignore[16]: no attribute
                    annotation = node.meta.get(
                        "quantization_annotation",
                        QuantizationAnnotation(_annotated=True),
                    )
                    arg = (
                        # pyre-ignore[16]: no attribute
                        node.args[idx]
                        if isinstance(idx, int)
                        # pyre-ignore[16]: no attribute
                        else node.args[idx[0]][idx[1]]
                    )
                    annotation.input_qspec_map[arg] = (
                        custom_spec[0] if custom_spec else spec
                    )
                    # pyre-ignore[16]: no attribute
                    node.meta["quantization_annotation"] = annotation

            def annotate_weights_or_biases(
                weights_or_biases: List[Tuple[fx.Node, int]],
                spec: Optional[QuantizationSpec],
            ) -> None:
                for node, idx, *custom_spec in weights_or_biases:
                    annotation = node.meta.get(
                        "quantization_annotation",
                        QuantizationAnnotation(_annotated=True),
                    )
                    annotation.input_qspec_map[node.args[idx]] = (
                        custom_spec[0] if custom_spec else spec
                    )
                    node.meta["quantization_annotation"] = annotation

            # pyre-ignore[6]: incompatible parameter type
            annotate_inputs(anchors.inputs, input_act_qspec)
            annotate_weights_or_biases(anchors.weights, weight_qspec)
            # pyre-ignore[6]: incompatible parameter type
            annotate_weights_or_biases(anchors.biases, bias_qspec)
        return model

    def validate(self, model: fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return []


# Quantization Specification used by Neutron NPU
act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)

wgt_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
    ch_axis=0,
)

wgt_fc_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
)

# The bias quantization spec is set directly by the individual pattern quantizers.
bias_qspec = None


class NeutronQuantizer(ComposableQuantizer):
    def __init__(self):
        static_qconfig = QuantizationConfig(
            act_qspec,
            act_qspec,
            wgt_qspec,
            None,
        )
        static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
        super().__init__(
            [
                NeutronAtenQuantizer(AddmmPattern(), static_fc_qconfig),
                NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
                NeutronAtenQuantizer(Conv2dPattern(), static_qconfig),
                NeutronAtenQuantizer(LinearPattern(), static_fc_qconfig),
                NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
                NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
                NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
                NeutronAtenQuantizer(PermutePattern(), static_qconfig),
                NeutronAtenQuantizer(PadPattern(), static_qconfig),
                NeutronAtenQuantizer(ReluPattern(), static_qconfig),
                NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
                NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
            ]
        )

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        return model
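For readers new to the PT2E flow, here is a minimal usage sketch of the NeutronQuantizer defined above. It mirrors the flow used by the tests added later in this PR; the toy module and input shape are illustrative assumptions, not part of the backend.

# Editor's illustrative sketch (not part of the diff). The toy model and input
# shape are assumptions; the flow mirrors the NeutronQuantizer tests below.
import torch
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Sequential(torch.nn.Conv2d(4, 8, 3), torch.nn.ReLU()).eval()
example_input = (torch.ones(1, 4, 32, 32),)

# Export to an ATen-level graph, let NeutronQuantizer annotate it, calibrate
# with representative data, then convert to the final QDQ representation.
exported = torch.export.export_for_training(model, example_input, strict=True).module()
prepared = prepare_pt2e(exported, NeutronQuantizer())
prepared(*example_input)  # calibration pass
quantized = convert_pt2e(prepared)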
@@ -0,0 +1,342 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Type, Union

import torch

from executorch.backends.nxp.quantizer.utils import get_bias_qparams
from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization.quantizer import (
    DerivedQuantizationSpec,
    FixedQParamsQuantizationSpec,
    SharedQuantizationSpec,
)


@dataclass
class PartitionAnchors:
""" | ||
All fields except output are lists of (node, args_index) pair, where node is from | ||
the given partition and node.args[args_index] is an input to the partition. Assumes | ||
a single output. | ||
Quantizer uses inputs, weights and biases for quantization annotation. The others | ||
field contains tensor inputs that aren't quantized, and the literals fields contains | ||
is used for other types of input values as well as handling default parameters. | ||
""" | ||

    # Inputs can share quantization parameters
    inputs: List[
        Union[
            Tuple[fx.Node, Union[int, Tuple[int, int]]],
            Tuple[
                fx.Node,
                Union[int, Tuple[int, int]],
                SharedQuantizationSpec,
            ],
        ]
    ] = field(default_factory=list)
    weights: List[Tuple[fx.Node, int]] = field(default_factory=list)
    biases: List[
        Union[Tuple[fx.Node, int], Tuple[fx.Node, int, DerivedQuantizationSpec]]
    ] = field(default_factory=list)
    others: List[Tuple[fx.Node, int]] = field(default_factory=list)
    literals: List[Tuple[fx.Node, int]] = field(default_factory=list)
    output: List[Union[Tuple[fx.Node], Tuple[fx.Node, SharedQuantizationSpec]]] = field(
        default_factory=list
    )
    empty: bool = False


class QuantizationPattern(ABC):
    @abstractmethod
    def partition_types(self) -> list[OpOverload]:
        """
        List of types to be passed to find_sequential_partitions_aten.
        """
        pass

    @abstractmethod
    def get_anchors(
        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Optional[PartitionAnchors]:
        pass


class SharedSpecPattern(QuantizationPattern):
""" | ||
Quantization pattern for shared quantization. | ||
The quantization is derived from the previous node quantization and the input and output shares the same | ||
quantization parameters (scale and zero-point). | ||
""" | ||

    def partition_types(self) -> List[Type[torch.nn.Module]]:
        pass

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors | None:
        node = fused_partition[0].nodes[-1]
        assert len(fused_partition[0].input_nodes) == 1
        prev_node = fused_partition[0].input_nodes[0]

        # Previous node was not quantized => we are not able to share q-params
        if "quantization_annotation" not in prev_node.meta:
            return None

        qspec = SharedQuantizationSpec(prev_node)

        return PartitionAnchors(
            inputs=[(node, 0)],
            weights=[],
            biases=[],
            output=[
                (node, qspec),
            ],
        )


class AddmmPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.addmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        addmm_node = fused_partition[0].nodes[-1]

        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (addmm_node.args[1], addmm_node),
                (addmm_node.args[2], addmm_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        return PartitionAnchors(
            inputs=[(addmm_node, 1)],
            weights=[(addmm_node, 2)],
            biases=[(addmm_node, 0, bias_qspec)],
            output=[(addmm_node,)],
        )


class AvgPoolPattern(SharedSpecPattern):
    """
    Quantizer for AvgPool2D operator.
    """

    def partition_types(self):
        return [torch.ops.aten.avg_pool2d.default]


class Conv1dPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.conv1d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        conv1d_node = fused_partition[0].nodes[-1]

        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (conv1d_node.args[0], conv1d_node),
                (conv1d_node.args[1], conv1d_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # Keep bias empty if not supplied
        bias = []
        if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
            bias = [(conv1d_node, 2, bias_qspec)]

        return PartitionAnchors(
            inputs=[(conv1d_node, 0)],
            weights=[(conv1d_node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(conv1d_node,)],
        )


class Conv2dPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.conv2d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        conv2d_node = fused_partition[0].nodes[-1]

        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (conv2d_node.args[0], conv2d_node),
                (conv2d_node.args[1], conv2d_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # Keep bias empty if not supplied
        bias = []
        if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
            bias = [(conv2d_node, 2, bias_qspec)]

        return PartitionAnchors(
            inputs=[(conv2d_node, 0)],
            weights=[(conv2d_node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(conv2d_node,)],
        )


class LinearPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.linear.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        linear_node = fused_partition[0].nodes[-1]

        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (linear_node.args[0], linear_node),
                (linear_node.args[1], linear_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # Keep bias empty if not supplied
        bias = []
        if len(linear_node.args) > 2:
            bias = [(linear_node, 2, bias_qspec)]

        return PartitionAnchors(
            inputs=[(linear_node, 0)],
            weights=[(linear_node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(linear_node,)],
        )


class MaxPoolPattern(SharedSpecPattern):
    """
    Quantizer for MaxPool2D operator.
    """

    def partition_types(self):
        return [torch.ops.aten.max_pool2d.default]


class PadPattern(SharedSpecPattern):
    """
    Quantizer for Pad operator.
    """

    def partition_types(self):
        return [torch.ops.aten.pad.default]


class PermutePattern(SharedSpecPattern):
    """
    Quantizer for Permute operator.
    """

    def partition_types(self):
        return [torch.ops.aten.permute.default]


class ReluPattern(SharedSpecPattern):
""" | ||
Quantizer for Relu operator. Shared quantization spec is selected, as ReLU usually follows computation layer. | ||
""" | ||
|
||
def partition_types(self): | ||
return [torch.ops.aten.relu.default] | ||
|
||
|
||
class ReluInPlacePattern(SharedSpecPattern): | ||
""" | ||
Quantizer for Relu operator with param inplace=True. Shared quantization spec is selected, as ReLU usually | ||
follows computation layer. | ||
""" | ||
|
||
def partition_types(self): | ||
return [torch.ops.aten.relu_.default] | ||
|
||
|
||
class ReshapePattern(SharedSpecPattern): | ||
""" | ||
Quantizer for Reshape operator. | ||
""" | ||
|
||
def partition_types(self): | ||
return [torch.ops.aten.reshape.default] | ||
|
||
|
||
class SoftMaxPattern(QuantizationPattern): | ||
""" | ||
Quantizer for Softmax operator. | ||
The quantization of Softmax output is fixed to scale 1/256, zero point -128, dtype int8. | ||
""" | ||
|
||
def partition_types(self) -> List[OpOverload]: | ||
return [torch.ops.aten.softmax.int] | ||
|
||
def get_anchors( | ||
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] | ||
) -> PartitionAnchors: | ||
node = fused_partition[0].nodes[-1] | ||
assert len(fused_partition[0].input_nodes) == 1 | ||
|
||
qspec = FixedQParamsQuantizationSpec( | ||
dtype=torch.int8, | ||
scale=1.0 / 256.0, | ||
zero_point=-128, | ||
quant_min=-128, | ||
quant_max=127, | ||
qscheme=torch.per_tensor_affine, | ||
) | ||
|
||
return PartitionAnchors( | ||
inputs=[(node, 0)], | ||
weights=[], | ||
biases=[], | ||
output=[ | ||
(node, qspec), | ||
], | ||
) |
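The shared-spec patterns above all follow the same recipe: subclass SharedSpecPattern and return the target aten overload(s) from partition_types. As a purely hypothetical illustration (not part of this PR), a pattern for aten.view and its registration would look like the sketch below.

# Hypothetical extension example; ViewPattern is not part of this PR.
class ViewPattern(SharedSpecPattern):
    """
    Quantizer for View operator. Input and output reuse the producer's q-params.
    """

    def partition_types(self):
        return [torch.ops.aten.view.default]


# It would then be registered in NeutronQuantizer.__init__ alongside the
# existing entries, e.g.:
#     NeutronAtenQuantizer(ViewPattern(), static_qconfig),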
@@ -0,0 +1,151 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools
from collections import OrderedDict
from typing import Any, Dict, List, Tuple, Type

import torch
from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization import ObserverOrFakeQuantize
from torch.fx.passes.utils.source_matcher_utils import (
    check_subgraphs_connected,
    SourcePartition,
)


def is_annotated(nodes: List[fx.Node]) -> bool:
    annotated = False
    for node in nodes:
        annotated = annotated or (
            "quantization_annotation" in node.meta
            and node.meta["quantization_annotation"]._annotated
        )
    return annotated


def no_outside_users(fused_partition) -> bool:
""" | ||
Checks if each partition other than the last does not have any outside users. | ||
""" | ||
    for source_partition in fused_partition[:-1]:
        if len(source_partition.output_nodes) != 1:
            return False
        if len(source_partition.output_nodes[0].users) != 1:
            return False
    return True


def get_bias_qparams(
    obs_or_fqs: List[ObserverOrFakeQuantize],
) -> Tuple[torch.Tensor, torch.Tensor]:
    act_scale, _ = obs_or_fqs[0].calculate_qparams()
    weight_scale, _ = obs_or_fqs[1].calculate_qparams()
    bias_scale = act_scale * weight_scale
    bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
    return bias_scale, bias_zero_point


def get_aten_node_target_partitions(
    graph: torch.fx.Graph,
    wanted_original_aten_op: List[OpOverload],
):
    """
    Args:
        graph: The graph we want to partition
        wanted_original_aten_op: List of original_aten ops (OpOverload)
    Returns:
        Dictionary mapping aten ops that were given to a list of SourcePartitions
        that correspond to the list of nodes that were decomposed from the given
        aten ops.
    """
    modules: Dict[Type, Dict[str, List[torch.fx.Node]]] = {}

    for node in graph.nodes:
        # The metadata source_fn should contain a tuple of a unique name for the
        # source, and the source function if the node is decomposed from a
        # function, or the type of module if the node is decomposed from a leaf
        # module
        # TODO(matthiascremon): look into ways to avoid using source_fn_stack
        if (source_fn_st := node.meta.get("source_fn_stack")) is None:
            continue

        source_fn = source_fn_st[-1]
        if node.target not in wanted_original_aten_op:
            continue

        diff_modules = modules.setdefault(source_fn[1], {})
        partition = diff_modules.setdefault(node.name, [])
        partition.append(node)

    def make_partition(
        nodes: List[torch.fx.Node], module_type: Type
    ) -> SourcePartition:
        input_nodes = set()
        output_nodes = set()
        params = set()
        for node in nodes:
            for arg in node.args:
                if isinstance(arg, torch.fx.Node) and arg not in nodes:
                    input_nodes.add(arg)

            if node.op == "get_attr":
                params.add(node)

            for user in node.users.keys():
                if user not in nodes:
                    output_nodes.add(node)

        return SourcePartition(
            nodes,
            module_type,
            list(input_nodes),
            list(output_nodes),
            list(params),  # type: ignore[arg-type]
        )

    ret: Dict[Type[Any], List[SourcePartition]] = {}

    for k, v in modules.items():
        ret[k] = [make_partition(partition, k) for partition in v.values()]

    return ret


def _partitions_sequential(partitions: Tuple[SourcePartition]) -> bool:
    prev_partition = None
    for partition in partitions:
        if prev_partition is not None and not check_subgraphs_connected(
            prev_partition, partition
        ):
            return False
        prev_partition = partition
    return True


def find_sequential_partitions_aten(
    gm: torch.fx.GraphModule,
    partition_types: List[Any],
):
    typed_partitions: OrderedDict[Any, List[SourcePartition]] = OrderedDict()
    for partition_type in partition_types:
        partitions = get_aten_node_target_partitions(gm.graph, [partition_type])
        typed_partitions[partition_type] = list(
            itertools.chain.from_iterable(partitions.values())
        )

    typed_partitions_list = list(typed_partitions.values())
    fusion_candidates = itertools.product(*typed_partitions_list)
    fused_partitions = []
    for candidate in fusion_candidates:
        if _partitions_sequential(candidate):
            fused_partitions.append(candidate)
    return fused_partitions
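get_bias_qparams encodes the usual convention for int32 bias quantization: the bias scale is the product of the input-activation scale and the weight scale, with a zero point of 0. A small self-contained check of that behaviour follows; the observer choices and calibration tensors are illustrative assumptions, not part of the PR.

# Editor's sketch verifying the bias-scale convention; the observers and the
# calibration tensors are made up for illustration only.
import torch
from executorch.backends.nxp.quantizer.utils import get_bias_qparams
from torch.ao.quantization.observer import MinMaxObserver

act_obs = MinMaxObserver(dtype=torch.int8)
wgt_obs = MinMaxObserver(dtype=torch.int8, qscheme=torch.per_tensor_symmetric)
act_obs(torch.randn(8, 4) * 3.0)  # pretend calibration data
wgt_obs(torch.randn(4, 4))

bias_scale, bias_zero_point = get_bias_qparams([act_obs, wgt_obs])
act_scale, _ = act_obs.calculate_qparams()
wgt_scale, _ = wgt_obs.calculate_qparams()
assert torch.allclose(bias_scale, act_scale * wgt_scale)
assert int(bias_zero_point) == 0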
@@ -0,0 +1,238 @@
# Copyright 2024 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Collection, Union

import torch


class Conv2dModule(torch.nn.Module):
    def __init__(
        self,
        bias: bool = True,
        dilation: Union[int, tuple[int, int]] = 1,
        in_channels: int = 4,
        kernel_size: Union[int, tuple[int, int]] = 3,
        out_channels: int = 8,
        padding: Union[str, int, Collection[int]] = 0,
        stride: Union[int, tuple[int, int]] = 2,
    ):
        super().__init__()

        self.conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        return self.conv(x)


class Conv2dAndMaxPool2DModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = torch.nn.Conv2d(
            in_channels=8, out_channels=32, kernel_size=5, bias=True
        )
        self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv(x)
        return self.maxpool(x)


class Conv2dConstantPadNDModule(torch.nn.Module):
    def __init__(self, paddings: Collection[int], constant: float | int | None = None):
        super().__init__()
        self.pad = ConstantPadNDModule(paddings, constant)
        self.conv = Conv2dModule()

    def forward(self, x):
        x = self.conv(x)
        return self.pad(x)


class SoftmaxModule(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()

        self.softmax = torch.nn.Softmax(dim=dim)

    def forward(self, x):
        return self.softmax(x)


class SoftmaxConvModule(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()

        self.conv = Conv2dModule()
        self.softmax = SoftmaxModule(dim=dim)

    def forward(self, x):
        x = self.conv(x)
        return self.softmax(x)


class LinearModule(torch.nn.Module):
    def __init__(self, bias: bool):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16, bias=bias)

    def forward(self, x):
        return self.linear(x)


class LinearSoftmaxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.linear = torch.nn.Linear(12, 10)
        self.softmax = torch.nn.Softmax(1)

    def forward(self, x):
        x = self.linear(x)
        x = self.softmax(x)

        return x


class ConvFCSoftmaxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = torch.nn.Conv2d(4, 64, 2, bias=False)
        self.fc = torch.nn.Linear(1024, 10)
        self.softmax = torch.nn.Softmax(1)

    def forward(self, x):
        x = self.conv(x)
        x = torch.reshape(x, (-1, 1024))
        x = self.fc(x)
        x = self.softmax(x)

        return x


class ConstantPadNDModule(torch.nn.Module):
    def __init__(self, paddings: Collection[int], constant: float | int | None = None):
        super().__init__()
        self.paddings = paddings
        self.constant = constant

    def forward(self, x):
        if self.constant is None:
            return torch.nn.functional.pad(x, tuple(self.paddings), "constant")
        else:
            return torch.nn.functional.pad(
                x, tuple(self.paddings), "constant", self.constant
            )


class ConstantPadNDConvModule(torch.nn.Module):
    def __init__(self, paddings: Collection[int], constant: float | int | None = None):
        super().__init__()
        self.pad = ConstantPadNDModule(paddings, constant)
        self.conv = Conv2dModule()

    def forward(self, x):
        x = self.pad(x)
        return self.conv(x)


class MaxPool2dModule(torch.nn.Module):
    def __init__(self, padding=0):
        super().__init__()

        self.max_pool2d = torch.nn.MaxPool2d(
            kernel_size=3, stride=2, padding=padding, dilation=1
        )

    def forward(self, x):
        return self.max_pool2d(x)


class MaxPool2dConvModule(torch.nn.Module):
    def __init__(self, padding=0):
        super().__init__()

        self.conv = Conv2dModule()
        self.max_pool2d = torch.nn.MaxPool2d(
            kernel_size=3, stride=2, padding=padding, dilation=1
        )

    def forward(self, x):
        x = self.conv(x)
        return self.max_pool2d(x)


class AvgPool2dModule(torch.nn.Module):
    def __init__(self, count_include_pad, padding=0):
        super().__init__()

        self.avg_pool = torch.nn.AvgPool2d(
            kernel_size=3,
            stride=2,
            padding=padding,
            count_include_pad=count_include_pad,
        )

    def forward(self, x):
        return self.avg_pool(x)


class AvgPool2dConvModule(torch.nn.Module):
    def __init__(self, count_include_pad, padding=0):
        super().__init__()

        self.conv = Conv2dModule()
        self.avg_pool = torch.nn.AvgPool2d(
            kernel_size=3,
            stride=1,
            padding=padding,
            count_include_pad=count_include_pad,
        )

    def forward(self, x):
        x = self.conv(x)
        return self.avg_pool(x)


class ReLUModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x)


class Conv2dReLUModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = torch.nn.Conv2d(4, 64, 2, bias=False)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        return self.relu(x)


class Conv2dPermuteModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(4, 64, 2, bias=False)

    def forward(self, x):
        x = self.conv(x)
        return torch.permute(x, [0, 2, 1, 3])
@@ -0,0 +1,273 @@
# Copyright 2024 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Tests for NeutronQuantizer.

import executorch.backends.nxp.tests.models as models
import torch
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


def _get_target_name(node):
    return node._pretty_print_target(node.target)


def test_quantizer_conv2d():
    model = models.Conv2dModule()
    model.eval()

    example_input = (torch.ones(1, 4, 32, 32),)
    quantizer = NeutronQuantizer()
    graph_module = torch.export.export_for_training(
        model, example_input, strict=True
    ).module()

    # noinspection PyTypeChecker
    m = prepare_pt2e(graph_module, quantizer)
    m(*example_input)
    m = convert_pt2e(m)

    # Dry run
    m(*example_input)

    nodes = list(m.graph.nodes)
    assert len(nodes) == 11
    assert nodes[7].name == "conv2d"
    # [0]: input, [1]: weights, [2]: bias
    assert (
        _get_target_name(nodes[7].args[0])
        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
    )
    assert (
        _get_target_name(nodes[7].args[1])
        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
    )
    assert (
        _get_target_name(nodes[7].args[2])
        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
    )
    assert (
        _get_target_name(nodes[8])
        == "torch.ops.quantized_decomposed.quantize_per_tensor.default"
    )
    assert nodes[8].args[0].name == "conv2d"


def test_quantizer_linear():
    model = models.LinearModule(bias=True)
    model.eval()

    example_input = (torch.ones(10, 32),)
    quantizer = NeutronQuantizer()
    graph_module = torch.export.export_for_training(
        model, example_input, strict=True
    ).module()

    # noinspection PyTypeChecker
    m = prepare_pt2e(graph_module, quantizer)
    m(*example_input)
    m = convert_pt2e(m)

    # Dry run
    m(*example_input)

    nodes = list(m.graph.nodes)
    assert len(nodes) == 11
    assert nodes[7].name == "linear"
    # [0]: input, [1]: weights, [2]: bias
    assert (
        _get_target_name(nodes[7].args[0])
        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
    )
    assert (
        _get_target_name(nodes[7].args[1])
        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
    )
    assert (
        _get_target_name(nodes[7].args[2])
        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
    )
    assert (
        _get_target_name(nodes[8])
        == "torch.ops.quantized_decomposed.quantize_per_tensor.default"
    )
    assert nodes[8].args[0].name == "linear"


def test_quantizer_maxpool2d():
    model = models.Conv2dAndMaxPool2DModule()
    model.eval()

    example_input = (torch.ones(1, 8, 32, 32),)
    quantizer = NeutronQuantizer()
    graph_module = torch.export.export_for_training(
        model, example_input, strict=True
    ).module()

    # noinspection PyTypeChecker
    m = prepare_pt2e(graph_module, quantizer)
    m(*example_input)
    m = convert_pt2e(m)

    # Dry run
    m(*example_input)

    nodes = list(m.graph.nodes)
    assert len(nodes) == 14
    # Check if QDQ pattern:
    assert nodes[10].name == "max_pool2d"
    assert (
        _get_target_name(nodes[10].args[0])
        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
    )
    assert (
        _get_target_name(nodes[11])
        == "torch.ops.quantized_decomposed.quantize_per_tensor.default"
    )
    assert nodes[11].args[0].name == "max_pool2d"

    # Check if input and output quantization is same
    input_quant = nodes[10].args[0].args[1:]
    output_quant = nodes[11].args[1:]
    assert input_quant == output_quant


def test_quantizer_softmax():
    model = models.SoftmaxModule(dim=0)
    model.eval()

    example_input = (torch.ones(1, 10),)
    quantizer = NeutronQuantizer()
    graph_module = torch.export.export_for_training(
        model, example_input, strict=True
    ).module()

    # noinspection PyTypeChecker
    m = prepare_pt2e(graph_module, quantizer)
    m(*example_input)
    m = convert_pt2e(m)

    # Dry run
    m(*example_input)

    nodes = list(m.graph.nodes)
    assert len(nodes) == 7
    # Check if QDQ pattern:
    assert nodes[3].name == "softmax"
    assert (
        _get_target_name(nodes[3].args[0])
        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
    )
    assert (
        _get_target_name(nodes[4])
        == "torch.ops.quantized_decomposed.quantize_per_tensor.default"
    )
    assert nodes[4].args[0].name == "softmax"

    # Check output quantization
    scale, zp, _, _, dtype = nodes[4].args[1:]
    assert scale == 1.0 / 256.0
    assert zp == -128
    assert dtype == torch.int8


def test_quantizer_single_maxpool2d():
    model = models.MaxPool2dModule()
    model.eval()

    example_input = (torch.ones(1, 4, 32, 32),)
    quantizer = NeutronQuantizer()
    graph_module = torch.export.export_for_training(
        model, example_input, strict=True
    ).module()

    # noinspection PyTypeChecker
    m = prepare_pt2e(graph_module, quantizer)
    m(*example_input)
    m = convert_pt2e(m)

    # Dry run
    m(*example_input)

    nodes = list(m.graph.nodes)
    assert len(nodes) == 3
    assert nodes[1].name == "max_pool2d"
    assert "quantization_annotation" not in nodes[1].meta


def test_quantizer_conv2d_relu():
    model = models.Conv2dReLUModule()
    model.eval()

    example_input = (torch.ones(1, 4, 32, 32),)
    quantizer = NeutronQuantizer()
    graph_module = torch.export.export_for_training(
        model, example_input, strict=True
    ).module()

    # noinspection PyTypeChecker
    m = prepare_pt2e(graph_module, quantizer)
    m(*example_input)
    m = convert_pt2e(m)

    # Dry run
    m(*example_input)

    nodes = list(m.graph.nodes)
    assert len(nodes) == 12
    assert nodes[7].name == "dequantize_per_tensor_default_2"
    assert nodes[8].name == "relu"
    assert nodes[9].name == "quantize_per_tensor_default_3"


def test_quantizer_conv2d_avg_pool2d():
    model = models.AvgPool2dConvModule(count_include_pad=False)
    model.eval()

    example_input = (torch.ones(1, 4, 16, 16),)
    quantizer = NeutronQuantizer()
    graph_module = torch.export.export_for_training(
        model, example_input, strict=True
    ).module()

    # noinspection PyTypeChecker
    m = prepare_pt2e(graph_module, quantizer)
    m(*example_input)
    m = convert_pt2e(m)

    # Dry run
    m(*example_input)

    nodes = list(m.graph.nodes)
    assert len(nodes) == 14
    assert nodes[9].name == "dequantize_per_tensor_default_3"
    assert nodes[10].name == "avg_pool2d"
    assert nodes[11].name == "quantize_per_tensor_default_4"


def test_quantizer_conv2d_permute():
    model = models.Conv2dPermuteModule()
    model.eval()

    example_input = (torch.ones(1, 4, 16, 16),)
    quantizer = NeutronQuantizer()
    graph_module = torch.export.export_for_training(
        model, example_input, strict=True
    ).module()

    # noinspection PyTypeChecker
    m = prepare_pt2e(graph_module, quantizer)
    m(*example_input)
    m = convert_pt2e(m)

    # Dry run
    m(*example_input)

    nodes = list(m.graph.nodes)
    assert len(nodes) == 12
    assert nodes[7].name == "dequantize_per_tensor_default_2"
    assert nodes[8].name == "permute"
    assert nodes[9].name == "quantize_per_tensor_default_3"