Commit ed5cdb7

HDCharles authored and facebook-github-bot committed
[ao][sparsity] make sparsity and PTQ compose (pytorch#74845)
Summary:
Pull Request resolved: pytorch#74845

This PR adds support for the quantization flow to detect parametrized modules and match them using their original module types. This mainly involved using the new type_before_parametrizations function rather than type to check for module matching.

Test Plan: python test/test_ao_sparsity.py TestComposability

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D35240274

fbshipit-source-id: 7294d89c9c2e069e51d8b9bafa45c15f92bed124
1 parent 8b8ed7b commit ed5cdb7
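
To illustrate the matching problem this change addresses, here is a minimal sketch, not taken from the PR: MaskParametrization is a made-up stand-in for the sparsifier's weight parametrization, and the snippet assumes a PyTorch build that already contains this commit. Registering a parametrization swaps the module's class to a dynamically generated Parametrized* subclass, so matching on type() no longer finds nn.Linear, while type_before_parametrizations still does.

# Minimal sketch; MaskParametrization is a made-up stand-in for the sparsifier's
# mask parametrization, and this assumes a PyTorch build that includes this commit.
import torch
from torch import nn
from torch.nn.utils import parametrize

class MaskParametrization(nn.Module):
    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, weight):
        # Zero out masked weights, the way a sparsity mask would.
        return weight * self.mask

linear = nn.Linear(4, 4)
parametrize.register_parametrization(linear, "weight", MaskParametrization(torch.ones(4, 4)))

# The class is now a dynamically generated ParametrizedLinear, so plain type()
# no longer matches mapping keys that are keyed on nn.Linear.
print(type(linear) is nn.Linear)                                      # False
print(parametrize.type_before_parametrizations(linear) is nn.Linear)  # True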

File tree

6 files changed: +136 -12 lines changed
test/ao/sparsity/test_composability.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+# Owner(s): ["module: unknown"]
+
+
+import logging
+
+import torch
+import torch.quantization as tq
+from torch import nn
+from torch.ao import sparsity
+from torch.testing._internal.common_utils import TestCase
+
+logging.basicConfig(
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
+)
+
+sparse_defaults = {
+    "sparsity_level": 0.8,
+    "sparse_block_shape": (1, 4),
+    "zeros_per_block": 4,
+}
+
+
+class TestComposability(TestCase):
+    def _get_model_and_sparsifier_and_sparse_config(self):
+        model = nn.Sequential(
+            nn.Linear(4, 4),  # 0
+            nn.ReLU(),
+            nn.Linear(4, 4),  # 2
+            nn.ReLU(),
+            tq.QuantStub(),
+            nn.Linear(4, 4),  # 5
+            nn.Identity(),
+            # nn.ReLU(), not testing fusion yet
+            tq.DeQuantStub(),
+        )
+        model[5].qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
+        model[4].qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
+
+        sparsifier = sparsity.WeightNormSparsifier(**sparse_defaults)
+
+        sparse_config = [
+            {
+                "module": model[5],
+                "sparsity_level": 0.7,
+                "sparse_block_shape": (1, 4),
+                "zeros_per_block": 4,
+            },
+            model[0],
+        ]
+        return model, sparsifier, sparse_config
+
+    def _check_parametrizations_and_observers(self, model):
+        self.assertTrue(hasattr(model[0], "parametrizations"))
+        self.assertTrue(hasattr(model[5], "parametrizations"))
+        self.assertTrue(hasattr(model[5], "activation_post_process"))
+
+    def _squash_mask_calibrate_and_convert(self, model, sparsifier, input):
+        sparsifier.step()
+        sparsifier.squash_mask()
+        model(input)
+        tq.convert(model, inplace=True)
+
+    def test_q_prep_before_s_prep(self):
+        (
+            mod,
+            sparsifier,
+            sparse_config,
+        ) = self._get_model_and_sparsifier_and_sparse_config()
+
+        tq.prepare(mod, inplace=True)
+        sparsifier.prepare(mod, config=sparse_config)
+        self._check_parametrizations_and_observers(mod)
+        self._squash_mask_calibrate_and_convert(
+            mod, sparsifier, torch.randn(1, 4, 4, 4)
+        )
+        self.assertTrue(isinstance(mod[5], torch.nn.quantized.Linear))
+        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))
+
+    def test_s_prep_before_q_prep(self):
+        (
+            mod,
+            sparsifier,
+            sparse_config,
+        ) = self._get_model_and_sparsifier_and_sparse_config()
+
+        sparsifier.prepare(mod, config=sparse_config)
+        torch.quantization.prepare(mod, inplace=True)
+        self._check_parametrizations_and_observers(mod)
+        self._squash_mask_calibrate_and_convert(
+            mod, sparsifier, torch.randn(1, 4, 4, 4)
+        )
+        self.assertTrue(isinstance(mod[5], torch.nn.quantized.Linear))
+        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))
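
For reference, here is a condensed sketch of the flow the test above exercises. It is illustrative only: the toy model, qconfig choice, and sparsifier settings are made up and not taken from the PR, and it assumes a PyTorch build that includes this commit.

# Condensed, illustrative usage sketch of the sparsity + post-training-quantization
# composition; model, data, and settings are made up for illustration.
import torch
import torch.quantization as tq
from torch import nn
from torch.ao import sparsity

model = nn.Sequential(tq.QuantStub(), nn.Linear(4, 4), tq.DeQuantStub())
model[0].qconfig = tq.get_default_qconfig("fbgemm")
model[1].qconfig = tq.get_default_qconfig("fbgemm")

sparsifier = sparsity.WeightNormSparsifier(
    sparsity_level=0.8, sparse_block_shape=(1, 4), zeros_per_block=4
)

# After this change the two prepare calls compose in either order.
sparsifier.prepare(model, config=[model[1]])
tq.prepare(model, inplace=True)

sparsifier.step()                # compute the sparsity masks
sparsifier.squash_mask()         # fold the masks into the weights
model(torch.randn(1, 4))         # calibrate the observers
tq.convert(model, inplace=True)  # model[1] becomes a quantized Linear with sparsified weights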

test/test_ao_sparsity.py

Lines changed: 3 additions & 0 deletions
@@ -20,5 +20,8 @@
 # Scheduler
 from ao.sparsity.test_scheduler import TestScheduler  # noqa: F401
 
+# Composability
+from ao.sparsity.test_composability import TestComposability  # noqa: F401
+
 if __name__ == '__main__':
     run_tests()

torch/ao/quantization/quantization_mappings.py

Lines changed: 2 additions & 1 deletion
@@ -23,6 +23,7 @@
     default_symmetric_fixed_qparams_fake_quant,
 )
 from torch.ao.quantization.utils import get_combined_dict
+from torch.nn.utils.parametrize import type_before_parametrizations
 
 # Default map for swapping float module to reference quantized modules
 DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
@@ -306,7 +307,7 @@ def _get_special_act_post_process(module: torch.nn.Module) -> Optional[Callable]
     input: torch.nn.Sigmoid
     output: default_affine_fixed_qparam_fake_quant
     """
-    return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get(type(module), None)
+    return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get(type_before_parametrizations(module), None)
 
 def _has_special_act_post_process(module: torch.nn.Module) -> bool:
     return module.training and type(module) in DEFAULT_MODULE_TO_ACT_POST_PROCESS

torch/ao/quantization/quantize.py

Lines changed: 11 additions & 10 deletions
@@ -17,7 +17,7 @@
     _has_special_act_post_process,
     _get_special_act_post_process,
 )
-from .utils import get_qparam_dict
+from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations
 from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
 from torch.ao.quantization.qconfig import (
     add_module_to_qconfig_obs_ctr,
@@ -26,6 +26,7 @@
     float_qparams_weight_only_qconfig,
     float_qparams_weight_only_qconfig_4bit,
     activation_is_memoryless)
+from torch.nn.utils.parametrize import type_before_parametrizations
 
 def is_activation_post_process(module):
     return (isinstance(module, torch.ao.quantization.ObserverBase) or
@@ -170,9 +171,9 @@ def insert_activation_post_process(m, special_act_post_process=None):
 
     for name, child in module.named_children():
         # TODO remove Dropout special after codebase stable
-        if type(child) in [nn.Dropout]:
+        if type_before_parametrizations(child) in [nn.Dropout]:
             continue
-        elif type(child) in [nnq.FloatFunctional, nnq.QFunctional]:
+        elif type_before_parametrizations(child) in [nnq.FloatFunctional, nnq.QFunctional]:
             if needs_observation(child):
                 child.activation_post_process = get_activation_post_process(child.qconfig, device)
         elif isinstance(child, _FusedModule):
@@ -182,23 +183,23 @@ def insert_activation_post_process(m, special_act_post_process=None):
         elif _has_special_act_post_process(child):
             special_act_post_process = _get_special_act_post_process(child)
             insert_activation_post_process(child, special_act_post_process)
-        elif non_leaf_module_list is not None and type(child) in non_leaf_module_list:
+        elif non_leaf_module_list is not None and type_before_parametrizations(child) in non_leaf_module_list:
            if needs_observation(child):
                insert_activation_post_process(child)
-        elif needs_observation(child) and type(child) in custom_module_class_mapping:
-            observed_child = custom_module_class_mapping[type(child)].from_float(child)
+        elif needs_observation(child) and type_before_parametrizations(child) in custom_module_class_mapping:
+            observed_child = custom_module_class_mapping[type_before_parametrizations(child)].from_float(child)
             setattr(module, name, observed_child)
             # TODO: These are the modules that cannot be observed
             # Once there are more, we should move them to a separate list
-            if custom_module_class_mapping[type(child)] not in no_observer_set():
+            if custom_module_class_mapping[type_before_parametrizations(child)] not in no_observer_set():
                 insert_activation_post_process(observed_child)
         else:
             add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
 
     # Insert observers only for leaf nodes, note that this observer is for
     # the output of the module, for input QuantStub will observe them
-    if len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \
-            and type(module) in qconfig_propagation_list:
+    if has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) \
+            and type_before_parametrizations(module) in qconfig_propagation_list:
         insert_activation_post_process(module)
 
 def get_unique_devices_(module):
@@ -220,7 +221,7 @@ def add_quant_dequant(module):
     wraps the input module, the latter case only happens when the input
     module is a leaf module and we want to quantize it.
     """
-    if len(module._modules) == 0 and hasattr(module, 'qconfig') and module.qconfig:
+    if has_no_children_ignoring_parametrizations(module) and hasattr(module, 'qconfig') and module.qconfig:
         return QuantWrapper(module)
 
     for name, child in module.named_children():

torch/ao/quantization/utils.py

Lines changed: 14 additions & 0 deletions
@@ -6,6 +6,7 @@
 import torch
 from torch.ao.quantization.quant_type import QuantType, quant_type_to_str
 from typing import Tuple, Any, Union, Callable
+from torch.nn.utils.parametrize import is_parametrized
 
 # Type for fusion patterns, it can be more complicated than the following actually,
 # see pattern.md for docs
@@ -356,3 +357,16 @@ def _parent_name(target):
         return '', r[0]
     else:
         return r[0], r[1]
+
+def has_no_children_ignoring_parametrizations(module):
+    """
+    Checks whether module._modules is empty or, if the module
+    is parametrized, whether module._modules contains only
+    the 'parametrizations' module.
+    """
+    if len(module._modules) == 0:
+        return True
+    elif is_parametrized(module):
+        return len(module._modules) == 1 and 'parametrizations' in module._modules
+    else:
+        return False
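
A short sketch of what this helper changes for leaf detection, illustrative only: Scale is a made-up parametrization, and the import assumes a PyTorch build that includes this commit. Once a module is parametrized it gains a 'parametrizations' child, so the old len(module._modules) == 0 check stops treating it as a leaf, while the new helper still does.

# Illustrative only; Scale is a made-up parametrization and the helper import
# assumes a PyTorch build that includes this change.
import torch
from torch import nn
from torch.nn.utils import parametrize
from torch.ao.quantization.utils import has_no_children_ignoring_parametrizations

class Scale(nn.Module):
    def forward(self, weight):
        return weight * 0.5

linear = nn.Linear(4, 4)
print(len(linear._modules) == 0)                          # True: a plain Linear has no child modules
print(has_no_children_ignoring_parametrizations(linear))  # True

parametrize.register_parametrization(linear, "weight", Scale())
print(len(linear._modules) == 0)                          # False: a 'parametrizations' child was added
print(has_no_children_ignoring_parametrizations(linear))  # True: still treated as a leaf for observer insertion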

torch/nn/utils/parametrize.py

Lines changed: 12 additions & 1 deletion
@@ -573,7 +573,6 @@ def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool:
     else:
         return tensor_name in parametrizations
 
-
 def remove_parametrizations(
     module: Module, tensor_name: str, leave_parametrized: bool = True
 ) -> Module:
@@ -644,3 +643,15 @@ def remove_parametrizations(
     orig_cls = module.__class__.__bases__[0]
     module.__class__ = orig_cls
     return module
+
+def type_before_parametrizations(module: Module) -> type:
+    r"""Returns the module type before parametrizations were applied,
+    or the module's current type if it is not parametrized.
+
+    Args:
+        module (nn.Module): module to get type of
+    """
+    if is_parametrized(module):
+        return module.__class__.__bases__[0]
+    else:
+        return type(module)
