Commit 296c5d4

Add docstring for PT2E and HQQ (#1937)
Signed-off-by: yiliu30 <[email protected]>
1 parent 437c8e7

17 files changed (+454, −114 lines)

.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt

Lines changed: 4 additions & 0 deletions
@@ -15,3 +15,7 @@
 /neural-compressor/neural_compressor/strategy
 /neural-compressor/neural_compressor/training.py
 /neural-compressor/neural_compressor/utils
+/neural_compressor/torch/algorithms/pt2e_quant
+/neural_compressor/torch/export
+/neural_compressor/common
+/neural_compressor/torch/algorithms/weight_only/hqq

neural_compressor/torch/algorithms/pt2e_quant/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""The PT2E-related modules."""
 
 
 from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer

neural_compressor/torch/algorithms/pt2e_quant/core.py

Lines changed: 36 additions & 3 deletions
@@ -14,7 +14,7 @@
 
 # Some code snippets are taken from the X86InductorQuantizer tutorial.
 # https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html
-
+"""The quantizer using PT2E path."""
 
 from typing import Any
 
@@ -30,13 +30,24 @@
 
 
 class W8A8PT2EQuantizer(Quantizer):
+    """The W8A8 quantizer using PT2E."""
+
     is_dynamic = False
 
     def __init__(self, quant_config=None):
+        """Initialize the quantizer."""
         super().__init__(quant_config)
 
     @staticmethod
     def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuantizer:
+        """Updates the quantizer based on the given quantization configuration.
+
+        Args:
+            quant_config (dict): The quantization configuration. Defaults to None.
+
+        Returns:
+            X86InductorQuantizer: The updated quantizer object.
+        """
         if not quant_config:
             quantizer = X86InductorQuantizer()
             quantizer.set_global(
@@ -47,9 +58,18 @@ def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuan
         return quantizer
 
     def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args, **kwargs) -> GraphModule:
-        """Prepare the model for calibration.
+        """Prepares the model for calibration.
 
         Create the `quantizer` according to the `quant_config`, and insert the observers accordingly.
+
+        Args:
+            model (GraphModule): The model to be prepared for calibration.
+            example_inputs (tuple, optional): Example inputs to be used for calibration. Defaults to None.
+            inplace (bool, optional): Whether to modify the model in-place or return a new prepared model.
+                Defaults to True.
+
+        Returns:
+            GraphModule: The prepared model.
         """
         quant_config = self.quant_config
         assert model._exported, "The model should be exported before preparing it for calibration."
@@ -58,7 +78,14 @@ def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args,
         return prepared_model
 
     def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
-        """Convert the calibrated model into qdq mode."""
+        """Convert the calibrated model into qdq mode.
+
+        Args:
+            model (GraphModule): The prepared model.
+
+        Returns:
+            GraphModule: The converted quantized model.
+        """
         fold_quantize = kwargs.get("fold_quantize", False)
         converted_model = convert_pt2e(model, fold_quantize=fold_quantize)
         logger.warning("Converted the model in qdq mode, please compile it to accelerate inference.")
@@ -67,6 +94,12 @@ def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
         return converted_model
 
     def half_precision_transformation(self, model, config):
+        """Applies half-precision transformation to the given model in-place.
+
+        Args:
+            model: The model to apply the transformation to.
+            config: The configuration for the transformation.
+        """
         half_precision_node_set = hp_rewriter.get_half_precision_node_set(model, config)
         logger.info("Try to convert %d nodes to half precision.", len(half_precision_node_set))
         hp_rewriter.transformation(model, half_precision_node_set)
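
Taken together, these docstrings describe an export → prepare → calibrate → convert flow. A minimal usage sketch (not part of this commit; the `export` helper from `neural_compressor.torch.export` is assumed to wrap `torch.export` and set the `model._exported` flag that `prepare` asserts):

import torch
from neural_compressor.torch.algorithms.pt2e_quant import W8A8PT2EQuantizer
from neural_compressor.torch.export import export  # assumed helper; sets model._exported

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)

example_inputs = (torch.randn(2, 8),)
exported_model = export(ToyModel(), example_inputs=example_inputs)

quantizer = W8A8PT2EQuantizer()  # quant_config=None falls back to the default static W8A8 config
prepared = quantizer.prepare(exported_model, example_inputs=example_inputs)

# Calibrate by running representative inputs through the observed model.
for _ in range(4):
    prepared(*example_inputs)

# Convert to qdq mode; per the warning above, compile the result to accelerate inference.
converted = quantizer.convert(prepared)
optimized = torch.compile(converted)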

neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py

Lines changed: 47 additions & 1 deletion
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""Rewrite the FP32 operators to FP16 or BF16 operators."""
 
 from dataclasses import dataclass
 from functools import partial
@@ -34,6 +35,14 @@
 
 @dataclass
 class PatternPair:
+    """Represents a pair of patterns used for search and replacement in a graph.
+
+    Attributes:
+        fn (TorchFuncType): The function type associated with the pattern pair.
+        search_pattern (torch.fx.GraphModule): The search pattern to be matched in the graph.
+        replace_pattern (torch.fx.GraphModule): The replacement pattern to be used when a match is found.
+    """
+
    fn: TorchFuncType
    search_pattern: torch.fx.GraphModule
    replace_pattern: torch.fx.GraphModule
@@ -101,6 +110,15 @@ def _register_pattern_pair(dtype: torch.dtype) -> None:
 
 
 def get_filter_fn(node_list, fn):
+    """Filter function to check if a node with the target operator is in the given `node_list`.
+
+    Args:
+        node_list (list): List of nodes to check against.
+        fn (str): Target operator.
+
+    Returns:
+        bool: True if the node with the target operator is in the `node_list`, False otherwise.
+    """
    target_op = FN_ATEN_OPS_MAPPING[fn]
 
    def is_target_node_in_candidate_list(match, original_graph, pattern_graph):
@@ -119,6 +137,16 @@ def is_target_node_in_candidate_list(match, original_graph, pattern_graph):
 
 
 def apply_single_pattern_pair(gm: torch.fx.GraphModule, pattern_pair: PatternPair, node_list):
+    """Applies a single pattern pair to a given GraphModule.
+
+    Args:
+        gm (torch.fx.GraphModule): The GraphModule to apply the pattern pair to.
+        pattern_pair (PatternPair): The pattern pair containing the search and replace patterns.
+        node_list: The list of nodes to filter for pattern matching.
+
+    Returns:
+        List[Match]: A list of Match objects representing the matches found after applying the pattern pair.
+    """
    filter_fn = get_filter_fn(node_list, pattern_pair.fn)
    match_and_replacements = subgraph_rewriter.replace_pattern_with_filters(
        gm=gm,
@@ -133,6 +161,14 @@ def apply_single_pattern_pair(gm: torch.fx.GraphModulePai
 
 
 def get_unquantized_node_set(gm: torch.fx.GraphModule):
+    """Retrieves the set of unquantized nodes from a given GraphModule.
+
+    Args:
+        gm (torch.fx.GraphModule): The GraphModule to retrieve unquantized nodes from.
+
+    Returns:
+        set: A set containing the unquantized nodes.
+    """
    unquantized_node_set = set()
    for node in gm.graph.nodes:
        if meta := getattr(node, "meta"):
@@ -180,7 +216,17 @@ def _parse_node_candidate_set_from_user_config(config, gm):
 
 
 def get_half_precision_node_set(gm, config):
-    """Intersection between `unquantized_node_set` and `node_set_from_user_config`"""
+    """Retrieves a set of nodes from the given graph model (gm) that are candidates for conversion to half precision.
+
+    The result is the intersection between `unquantized_node_set` and `node_set_from_user_config`.
+
+    Args:
+        gm (GraphModel): The graph model to search for nodes.
+        config (dict): User configuration for node candidate set.
+
+    Returns:
+        set: A set of nodes that are candidates for conversion to half precision.
+    """
    # TODO: implement it, current return all unquantized_node_set
 
    node_set_from_user_config = _parse_node_candidate_set_from_user_config(config, gm)
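
`PatternPair` and `apply_single_pattern_pair` build on `torch.fx.subgraph_rewriter`. An illustrative sketch of that underlying mechanism, using the simpler `replace_pattern` (generic torch.fx code rather than this module's aten-level patterns; `get_filter_fn` additionally restricts matches to a candidate node list):

import torch
from torch.fx import subgraph_rewriter, symbolic_trace

def fp32_linear_pattern(x, weight, bias):
    # The subgraph to search for: a plain fp32 linear call.
    return torch.nn.functional.linear(x, weight, bias)

def fp16_linear_replacement(x, weight, bias):
    # Compute the matmul in half precision, then cast the result back to fp32.
    out = torch.nn.functional.linear(
        x.to(torch.float16), weight.to(torch.float16), bias.to(torch.float16)
    )
    return out.to(torch.float32)

class M(torch.nn.Module):
    def forward(self, x, weight, bias):
        return torch.nn.functional.linear(x, weight, bias)

gm = symbolic_trace(M())
matches = subgraph_rewriter.replace_pattern(gm, fp32_linear_pattern, fp16_linear_replacement)
print(f"rewrote {len(matches)} match(es)")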

neural_compressor/torch/algorithms/pt2e_quant/save_load.py

Lines changed: 17 additions & 0 deletions
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""Save and load the quantized model."""
+
 
 import json
 import os
@@ -22,6 +24,13 @@
 
 
 def save(model, example_inputs, output_dir="./saved_results"):
+    """Save the quantized model and its configuration.
+
+    Args:
+        model (torch.nn.Module): The quantized model to be saved.
+        example_inputs (torch.Tensor or tuple of torch.Tensor): Example inputs used for tracing the model.
+        output_dir (str, optional): The directory where the saved results will be stored. Defaults to "./saved_results".
+    """
    os.makedirs(output_dir, exist_ok=True)
    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
    qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
@@ -37,6 +46,14 @@ def save(model, example_inputs, output_dir="./saved_results"):
 
 
 def load(output_dir="./saved_results"):
+    """Load a quantized model from the specified output directory.
+
+    Args:
+        output_dir (str): The directory where the quantized model is saved. Defaults to "./saved_results".
+
+    Returns:
+        torch.nn.Module: The loaded quantized model.
+    """
    qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
    loaded_quantized_ep = torch.export.load(qmodel_file_path)
    return loaded_quantized_ep.module()
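
A round-trip sketch using the two helpers above (`quantized_model` and `example_inputs` are assumed to come from a PT2E quantization flow like the one sketched earlier):

from neural_compressor.torch.algorithms.pt2e_quant import save_load

# Persist the quantized model (serialized via torch.export) and its configuration.
save_load.save(quantized_model, example_inputs, output_dir="./saved_results")

# Later: restore a runnable torch.nn.Module from the same directory.
loaded_model = save_load.load("./saved_results")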

neural_compressor/torch/algorithms/pt2e_quant/utility.py

Lines changed: 22 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""Utility functions for PT2E quantization."""
 
 from typing import Dict
 
@@ -24,6 +25,18 @@
 
 
 def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
+    """Create a quantization specification based on the given configuration.
+
+    Args:
+        dtype (str): The desired data type for quantization. Valid options are "int8" and "uint8".
+        sym (bool): Whether to use symmetric quantization or not.
+        granularity (str): The granularity of quantization. Valid options are "per_channel" and "per_tensor".
+        algo (str): The algorithm to use for quantization. Valid options are "placeholder", "minmax", and "kl".
+        is_dynamic (bool, optional): Whether to use dynamic quantization or not. Defaults to False.
+
+    Returns:
+        QuantizationSpec: The created quantization specification.
+    """
    dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8}
    select_dtype = dtype_mapping[dtype]
    min_max_mapping = {torch.int8: (-128, 127), torch.uint8: (0, 255)}
@@ -76,6 +89,15 @@ def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> Quant
 
 
 def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86InductorQuantizer:
+    """Creates an instance of X86InductorQuantizer based on the given configuration.
+
+    Args:
+        config: The configuration object containing the quantization settings.
+        is_dynamic: A boolean indicating whether dynamic quantization is enabled.
+
+    Returns:
+        An instance of X86InductorQuantizer initialized with the provided configuration.
+    """
    quantizer = xiq.X86InductorQuantizer()
    # set global
    global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
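
Since the docstring enumerates the valid string options, a typical static W8A8 pair of specs could be built like this (a sketch using only values named above):

from neural_compressor.torch.algorithms.pt2e_quant.utility import create_quant_spec_from_config

# Asymmetric per-tensor uint8 activations observed with the "kl" algorithm.
act_spec = create_quant_spec_from_config(dtype="uint8", sym=False, granularity="per_tensor", algo="kl")

# Symmetric per-channel int8 weights observed with "minmax".
weight_spec = create_quant_spec_from_config(dtype="int8", sym=True, granularity="per_channel", algo="minmax")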

neural_compressor/torch/algorithms/weight_only/hqq/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""HQQ-related modules."""
 
 from .quantizer import HQQuantizer
 from .config import HQQModuleConfig, QTensorConfig
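
A hedged sketch of how these exports might compose; the QTensorConfig field names (nbits, channel_wise, group_size) are assumptions based on common HQQ settings, not confirmed by this diff:

from neural_compressor.torch.algorithms.weight_only.hqq import (
    HQQModuleConfig,
    HQQuantizer,
    QTensorConfig,
)

# Hypothetical 4-bit, group-wise weight config; field names are assumed.
weight_cfg = QTensorConfig(nbits=4, channel_wise=True, group_size=64)
hqq_config = HQQModuleConfig(weight=weight_cfg)
# HQQuantizer would then consume a mapping from modules to such configs.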
