
[quant][fx] Move the remaining fixed qparam ops to backend_config_dict #75314

Closed · wants to merge 2 commits
9 changes: 9 additions & 0 deletions test/quantization/fx/test_numeric_suite_fx.py
@@ -1607,6 +1607,12 @@ def test_op_io_dtype_coverage(self):
# embedding shadowing is not implemented, for now
continue
else:
if (
base_op in FUNS_UNMATCHABLE or
base_op in MODS_UNMATCHABLE or
base_op in METHS_UNMATCHABLE
):
continue
if qhandler_cls(None, {}).is_general_tensor_value_op():
self.assertTrue(
(base_op in FUNS_IO_TYPE_FP32_OR_INT8) or
@@ -1615,6 +1621,9 @@
f"missing IO type handling for {base_op} using {qhandler_cls}")
else:
self.assertTrue(
(base_op in FUNS_IO_TYPE_FP32_OR_INT8) or
(base_op in MODS_IO_TYPE_FP32_OR_INT8) or
(base_op in METHS_IO_TYPE_FP32_OR_INT8) or
(base_op in FUNS_IO_TYPE_FP32) or
(base_op in MODS_IO_TYPE_FP32) or
(base_op in MODS_IO_TYPE_FP32_OR_INT8),
13 changes: 0 additions & 13 deletions test/quantization/fx/test_quantize_fx.py
@@ -4033,19 +4033,6 @@ def forward(self, x):
def _assertFixedQParamsFakeQuantizeEqual(self, fq1, fq2):
self.assertEqual(fq1()._observer_ctr, fq2()._observer_ctr)

def test_fixed_qparams_patterns(self):
hard_sigmoid_keys = [torch.nn.functional.hardsigmoid, "hardsigmoid", "hardsigmoid_"]
sigmoid_keys = [torch.nn.Sigmoid, torch.sigmoid, "sigmoid", "sigmoid_"]
tanh_keys = [torch.nn.Tanh, torch.tanh, "tanh", "tanh_"]
for k in hard_sigmoid_keys + sigmoid_keys:
self.assertEqual(DEFAULT_OUTPUT_OBSERVER_MAP[k], default_affine_fixed_qparams_observer)
self._assertFixedQParamsFakeQuantizeEqual(DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[k],
default_affine_fixed_qparams_fake_quant)
for k in tanh_keys:
self.assertEqual(DEFAULT_OUTPUT_OBSERVER_MAP[k], default_symmetric_fixed_qparams_observer)
self._assertFixedQParamsFakeQuantizeEqual(DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[k],
default_symmetric_fixed_qparams_fake_quant)

def test_register_patterns(self):
@register_fusion_pattern("dummy_fusion")
class DummyFusion():
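Note: the checks from the removed test_fixed_qparams_patterns are effectively captured by the per-op configs generated in torch/ao/quantization/fx/backend_config/native.py below. A hedged sketch of the same check against the new entries; the keys and observer pairing come from the diff below, while the import paths are assumptions and may differ across PyTorch versions:

# Sketch only (not part of this PR): mirror the removed test_fixed_qparams_patterns
# against the configs produced by _get_fixed_qparams_op_configs(). The keys
# "pattern" and "_overwrite_output_observer" are taken from the native.py diff;
# the import paths are assumptions.
import torch
from torch.ao.quantization.observer import (
    default_affine_fixed_qparams_observer,
    default_symmetric_fixed_qparams_observer,
)
from torch.ao.quantization.fx.backend_config.native import _get_fixed_qparams_op_configs

tanh_keys = {torch.nn.Tanh, torch.tanh, "tanh", "tanh_"}

for config in _get_fixed_qparams_op_configs():
    # tanh variants use the symmetric observer; hardsigmoid/sigmoid variants use the affine one.
    expected = (
        default_symmetric_fixed_qparams_observer
        if config["pattern"] in tanh_keys
        else default_affine_fixed_qparams_observer
    )
    assert config["_overwrite_output_observer"] is expected, config["pattern"]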
54 changes: 37 additions & 17 deletions torch/ao/quantization/fx/backend_config/native.py
@@ -8,7 +8,10 @@
import torch.nn.intrinsic.qat as nniqat
import torch.nn.qat as nnqat
import torch.nn.quantized._reference as nnqr
from ...observer import default_affine_fixed_qparams_observer
from ...observer import (
default_affine_fixed_qparams_observer,
default_symmetric_fixed_qparams_observer,
)
from ...fake_quantize import FixedQParamsFakeQuantize
from ...fuser_method_mappings import reverse_sequential_wrapper2

@@ -328,21 +331,38 @@ def _get_binary_op_configs():
return binary_op_configs


_HARDSIGMOID_MODULE_CONFIG = {
"pattern": torch.nn.Hardsigmoid,
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
# TODO: The following two keys are temporary, since we don't want to put observer in the configs
# we expect that it's provided by user
# What we want to put here is the requirement on observers, in this case dtype,
# quant_min, quant_max etc., but we need to first move all configs to
# backend_config_dict to do that, we'll remove these keys after we fully migrated
# everything to use backend_config_dict
"_overwrite_output_fake_quantizer": FixedQParamsFakeQuantize.with_args(observer=default_affine_fixed_qparams_observer),
"_overwrite_output_observer": default_affine_fixed_qparams_observer,
"dtype_configs": [
weighted_op_int8_dtype_config,
],
}
def _get_fixed_qparams_op_configs():
fixed_qparams_op_configs = []
for fixed_qparam_op, output_observer in [
(torch.nn.Hardsigmoid, default_affine_fixed_qparams_observer),
(torch.nn.functional.hardsigmoid, default_affine_fixed_qparams_observer),
("hardsigmoid", default_affine_fixed_qparams_observer),
("hardsigmoid_", default_affine_fixed_qparams_observer),
(torch.nn.Sigmoid, default_affine_fixed_qparams_observer),
(torch.sigmoid, default_affine_fixed_qparams_observer),
("sigmoid", default_affine_fixed_qparams_observer),
("sigmoid_", default_affine_fixed_qparams_observer),
(torch.nn.Tanh, default_symmetric_fixed_qparams_observer),
(torch.tanh, default_symmetric_fixed_qparams_observer),
("tanh", default_symmetric_fixed_qparams_observer),
("tanh_", default_symmetric_fixed_qparams_observer),
]:
fixed_qparams_op_configs.append({
"pattern": fixed_qparam_op,
"observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
# TODO: The following two keys are temporary, since we don't want to put observer in the configs
# we expect that it's provided by user
# What we want to put here is the requirement on observers, in this case dtype,
# quant_min, quant_max etc., but we need to first move all configs to
# backend_config_dict to do that, we'll remove these keys after we fully migrated
# everything to use backend_config_dict
"_overwrite_output_fake_quantizer": FixedQParamsFakeQuantize.with_args(observer=output_observer),
"_overwrite_output_observer": output_observer,
"dtype_configs": [
weighted_op_int8_dtype_config,
],
})
return fixed_qparams_op_configs

_CAT_CONFIG = {
"pattern": torch.cat,
@@ -398,7 +418,7 @@ def get_native_backend_config_dict():
*_get_linear_configs(),
*_get_conv_configs(),
*_get_binary_op_configs(),
_HARDSIGMOID_MODULE_CONFIG,
*_get_fixed_qparams_op_configs(),
_CAT_CONFIG,
*_get_bn_configs(),
],
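For illustration, a minimal sketch (not part of this PR) of how the generated fixed-qparams entries surface in the native backend config; the top-level "configs" key and the import path are assumptions based on this diff and may differ across PyTorch versions:

# Sketch only: list the fixed-qparams entries contributed by
# _get_fixed_qparams_op_configs() to the native backend_config_dict.
# The "configs" key and the import path are assumptions; adjust to your version.
import torch
from torch.ao.quantization.fx.backend_config.native import get_native_backend_config_dict

fixed_qparams_patterns = {
    torch.nn.Hardsigmoid, torch.nn.functional.hardsigmoid, "hardsigmoid", "hardsigmoid_",
    torch.nn.Sigmoid, torch.sigmoid, "sigmoid", "sigmoid_",
    torch.nn.Tanh, torch.tanh, "tanh", "tanh_",
}

for config in get_native_backend_config_dict()["configs"]:
    if config.get("pattern") in fixed_qparams_patterns:
        # Each entry carries the temporary observer overrides described in the TODO above.
        print(config["pattern"], config["_overwrite_output_observer"])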
30 changes: 2 additions & 28 deletions torch/ao/quantization/fx/quantization_patterns.py
@@ -2,18 +2,9 @@
from torch.fx.graph import (
Node,
)
from ..observer import (
default_affine_fixed_qparams_observer,
default_symmetric_fixed_qparams_observer,
)

from ..utils import (
activation_dtype,
)

from .pattern_utils import (
register_quant_pattern,
get_default_output_activation_post_process_map,
Pattern,
)
from .utils import (
@@ -158,26 +149,9 @@ class DefaultNodeQuantizeHandler(QuantizeHandler):
"""
pass

@register_quant_pattern(torch.nn.functional.hardsigmoid, default_affine_fixed_qparams_observer)
@register_quant_pattern('hardsigmoid', default_affine_fixed_qparams_observer)
@register_quant_pattern('hardsigmoid_', default_affine_fixed_qparams_observer)
@register_quant_pattern(torch.nn.Sigmoid, default_affine_fixed_qparams_observer)
@register_quant_pattern(torch.sigmoid, default_affine_fixed_qparams_observer)
@register_quant_pattern('sigmoid', default_affine_fixed_qparams_observer)
@register_quant_pattern('sigmoid_', default_affine_fixed_qparams_observer)
@register_quant_pattern(torch.nn.Tanh, default_symmetric_fixed_qparams_observer)
@register_quant_pattern(torch.tanh, default_symmetric_fixed_qparams_observer)
@register_quant_pattern('tanh', default_symmetric_fixed_qparams_observer)
@register_quant_pattern('tanh_', default_symmetric_fixed_qparams_observer)
# TODO: remove this class
class FixedQParamsOpQuantizeHandler(QuantizeHandler):
# some qhandlers override the activations constructor
def get_activation_ctr(self, qconfig, pattern, is_training) -> Optional[Callable]:
act_dtype = activation_dtype(qconfig)
if act_dtype == torch.quint8:
return get_default_output_activation_post_process_map(is_training).get(
pattern, qconfig.activation)
else:
return qconfig.activation
pass

@register_quant_pattern(torch.nn.AdaptiveAvgPool1d)
@register_quant_pattern(torch.nn.AdaptiveAvgPool2d)