Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions neural_compressor/torch/algorithms/pt2e_quant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@


from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
from .save_load import save, load
42 changes: 42 additions & 0 deletions neural_compressor/torch/algorithms/pt2e_quant/save_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import torch

from neural_compressor.common.utils import load_config_mapping, save_config_mapping
from neural_compressor.torch.utils import QCONFIG_NAME, WEIGHT_NAME, logger


def save(model, example_inputs, output_dir="./saved_results"):
os.makedirs(output_dir, exist_ok=True)
qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
quantized_ep = torch.export.export(model, example_inputs)
torch.export.save(quantized_ep, qmodel_file_path)
for key, op_config in model.qconfig.items():
model.qconfig[key] = op_config.to_dict()
with open(qconfig_file_path, "w") as f:
json.dump(model.qconfig, f, indent=4)

logger.info("Save quantized model to {}.".format(qmodel_file_path))
logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))


def load(output_dir="./saved_results"):
qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
loaded_quantized_ep = torch.export.load(qmodel_file_path)
return loaded_quantized_ep.module()
6 changes: 6 additions & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def static_quant_entry(
def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
logger.info("Quantize model with the PT2E static quant algorithm.")
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
from neural_compressor.torch.algorithms.pt2e_quant.save_load import save

run_fn = kwargs.get("run_fn", None)
example_inputs = kwargs.get("example_inputs", None)
Expand All @@ -221,6 +222,8 @@ def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode
model = w8a8_quantizer.execute(
model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
)
model.qconfig = configs_mapping
model.save = MethodType(save, model)
return model


Expand All @@ -230,6 +233,7 @@ def pt2e_dynamic_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode
def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode, *args, **kwargs) -> torch.nn.Module:
logger.info("Quantize model with the PT2E static quant algorithm.")
from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
from neural_compressor.torch.algorithms.pt2e_quant.save_load import save

run_fn = kwargs.get("run_fn", None)
example_inputs = kwargs.get("example_inputs", None)
Expand All @@ -240,6 +244,8 @@ def pt2e_static_quant_entry(model: torch.nn.Module, configs_mapping, mode: Mode,
model = w8a8_quantizer.execute(
model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace
)
model.qconfig = configs_mapping
model.save = MethodType(save, model)
return model


Expand Down
5 changes: 5 additions & 0 deletions neural_compressor/torch/quantization/load_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
from neural_compressor.torch.algorithms import static_quant

return static_quant.load(model_name_or_path)
elif "static_quant" in per_op_qconfig.keys() or "pt2e_dynamic_quant" in per_op_qconfig.keys(): # PT2E
from neural_compressor.torch.algorithms import pt2e_quant

return pt2e_quant.load(model_name_or_path)
else:
config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
# select load function
Expand All @@ -99,6 +103,7 @@ def load(model_name_or_path, original_model=None, format="default", device="cpu"
from neural_compressor.torch.algorithms import habana_fp8

return habana_fp8.load(model_name_or_path, original_model)

elif format == LoadFormat.HUGGINGFACE.value:
# now only support load huggingface WOQ causal language model
from neural_compressor.torch.algorithms import weight_only
Expand Down
18 changes: 15 additions & 3 deletions test/3x/torch/quantization/test_pt2e_quant.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import os
import unittest
from unittest.mock import patch
import shutil

import pytest
import torch
Expand Down Expand Up @@ -33,6 +31,8 @@ def _is_ipex_imported():


class TestPT2EQuantization:
def teardown_class(self):
shutil.rmtree("saved_results", ignore_errors=True)

@staticmethod
def get_toy_model():
Expand Down Expand Up @@ -114,6 +114,18 @@ def calib_fn(model):
config.freezing = True
q_model_out = q_model(*example_inputs)
assert torch.allclose(float_model_output, q_model_out, atol=1e-2), "Quantization failed!"

# test save and load
q_model.save(
example_inputs=example_inputs,
output_dir="./saved_results",
)
from neural_compressor.torch.quantization import load

loaded_quantized_model = load("./saved_results")
loaded_q_model_out = loaded_quantized_model(*example_inputs)
assert torch.equal(loaded_q_model_out, q_model_out)

opt_model = torch.compile(q_model)
out = opt_model(*example_inputs)
logger.warning("out shape is %s", out.shape)
Expand Down