This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit e7c046c

jerryzh168 authored and facebook-github-bot committed
Add required example_args argument to prepare_fx and prepare_qat_fx (#77608)
Summary: X-link: pytorch/pytorch#77608 X-link: pytorch/fx2trt#76 X-link: facebookresearch/d2go#249 X-link: fairinternal/ClassyVision#104 X-link: pytorch/benchmark#916 Pull Request resolved: #791 X-link: facebookresearch/mobile-vision#68

FX Graph Mode Quantization needs to know whether an fx node is a floating point Tensor before it can decide whether to insert an observer/fake_quantize module, since we only insert observer/fake_quantize modules for floating point Tensors. Currently we support this with some hacks, such as the NON_OBSERVABLE_ARG_DICT rules (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the PyTorch code base.

As discussed in the design review, we need to ask users to provide sample args (and possibly sample keyword args) so that we can infer the type in a more robust way. This PR starts by changing the prepare_fx and prepare_qat_fx APIs to require example arguments through a new example_inputs argument. Note that this API does not support kwargs. Kwargs could make pytorch/pytorch#76496 (comment) simpler, but that case will be rare, and even then it can still be worked around with positional arguments; torch.jit.trace (https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp (https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140) also take only positional args. We will use a single example_inputs argument for now; if needed, the API can later be extended with an optional example_kwargs, e.g. when forward takes many arguments and it makes more sense to pass them by keyword.

BC-breaking Note:

Before:
m = resnet18(...)
m = prepare_fx(m, qconfig_dict)

After:
m = resnet18(...)
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))

Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: 58c1e0afa7421ce79c164a31e88bb7dc4541f42b
1 parent 35c8f33 commit e7c046c
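
For context, here is a minimal end-to-end sketch of the post-training quantization flow with the new required argument. This is not part of the commit; it assumes torch >= 1.13, and the toy model, shapes, and calibration loop are purely illustrative:

    import torch
    import torch.ao.quantization as tq
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    # Any float model works here; a tiny MLP keeps the sketch self-contained.
    model = torch.nn.Sequential(torch.nn.Linear(3, 2), torch.nn.ReLU()).eval()

    qconfig_dict = {"": tq.default_qconfig}
    # example_inputs is now required: a tuple of positional args for forward().
    example_inputs = (torch.randn(1, 3),)
    prepared = prepare_fx(model, qconfig_dict, example_inputs)

    # Calibrate so the inserted observers record activation ranges.
    with torch.no_grad():
        for _ in range(4):
            prepared(torch.randn(8, 3))

    quantized = convert_fx(prepared)
    print(quantized(torch.randn(1, 3)).shape)

The same tuple convention applies to prepare_qat_fx.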

4 files changed, +47 -30 lines changed

test/models_densenet_test.py

Lines changed: 14 additions & 3 deletions

@@ -113,21 +113,32 @@ def _test_quantize_model(self, model_config):
             _find_block_full_path(model.features, block_name)
             for block_name in heads.keys()
         ]
+        # TODO[quant-example-inputs]: The dimension here is random, if we need to
+        # use dimension/rank in the future we'd need to get the correct dimensions
+        standalone_example_inputs = (torch.randn(1, 3, 3, 3),)
         # we need to keep the modules used in head standalone since
         # it will be accessed with path name directly in execution
         prepare_custom_config_dict["standalone_module_name"] = [
             (
                 head,
                 {"": tq.default_qconfig},
+                standalone_example_inputs,
                 {"input_quantized_idxs": [0], "output_quantized_idxs": []},
                 None,
             )
             for head in head_path_from_blocks
         ]
-        model.initial_block = prepare_fx(model.initial_block, {"": tq.default_qconfig})
+        # TODO[quant-example-inputs]: The dimension here is random, if we need to
+        # use dimension/rank in the future we'd need to get the correct dimensions
+        example_inputs = (torch.randn(1, 3, 3, 3),)
+        model.initial_block = prepare_fx(
+            model.initial_block, {"": tq.default_qconfig}, example_inputs
+        )
+
         model.features = prepare_fx(
             model.features,
             {"": tq.default_qconfig},
+            example_inputs,
             prepare_custom_config_dict,
         )
         model.set_heads(heads)

@@ -148,8 +159,8 @@ def test_small_densenet(self):
         self._test_model(MODELS["small_densenet"])

     @unittest.skipIf(
-        get_torch_version() < [1, 8],
-        "FX Graph Modee Quantization is only availablee from 1.8",
+        get_torch_version() < [1, 13],
+        "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13"
     )
     def test_quantized_small_densenet(self):
         self._test_quantize_model(MODELS["small_densenet"])
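
As the densenet diff above shows, each entry in prepare_custom_config_dict["standalone_module_name"] now carries its own example inputs as a third element. The following schematic restates the entry layout; the submodule path is hypothetical and the trailing None simply mirrors the entries in the diff:

    import torch
    import torch.ao.quantization as tq

    standalone_example_inputs = (torch.randn(1, 3, 3, 3),)
    prepare_custom_config_dict = {
        "standalone_module_name": [
            (
                "features.block1.head",      # hypothetical submodule path
                {"": tq.default_qconfig},    # qconfig for the standalone module
                standalone_example_inputs,   # new: example inputs for that submodule
                {"input_quantized_idxs": [0], "output_quantized_idxs": []},
                None,                        # trailing element kept as in the diff above
            )
        ]
    }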

test/models_mlp_test.py

Lines changed: 6 additions & 10 deletions

@@ -26,23 +26,19 @@ def test_build_model(self):
         self.assertEqual(output.shape, torch.Size([2, 1]))

     @unittest.skipIf(
-        get_torch_version() < [1, 8],
-        "FX Graph Modee Quantization is only availablee from 1.8",
+        get_torch_version() < [1, 13],
+        "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13"
     )
     def test_quantize_model(self):
-        if get_torch_version() >= [1, 11]:
-            import torch.ao.quantization as tq
-            from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
-        else:
-            import torch.quantization as tq
-            from torch.quantization.quantize_fx import convert_fx, prepare_fx
-
+        import torch.ao.quantization as tq
+        from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
         config = {"name": "mlp", "input_dim": 3, "output_dim": 1, "hidden_dims": [2]}
         model = build_model(config)
         self.assertTrue(isinstance(model, ClassyModel))

         model.eval()
-        model.mlp = prepare_fx(model.mlp, {"": tq.default_qconfig})
+        example_inputs = (torch.rand(1, 3),)
+        model.mlp = prepare_fx(model.mlp, {"": tq.default_qconfig}, example_inputs)
         model.mlp = convert_fx(model.mlp)

         tensor = torch.tensor([[1, 2, 3]], dtype=torch.float)

test/models_regnet_test.py

Lines changed: 8 additions & 9 deletions

@@ -163,25 +163,24 @@ def test_build_model(self, config):
         model = build_model(config)
         assert isinstance(model, RegNet)

+    @unittest.skipIf(
+        get_torch_version() < [1, 13],
+        "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13"
+    )
     @parameterized.expand(REGNET_TEST_CONFIGS + REGNET_TEST_PRESETS)
     def test_quantize_model(self, config):
         """
         Test that the model builds using a config using either model_params or
         model_name and calls fx graph mode quantization apis
         """
-        if get_torch_version() < [1, 8]:
-            self.skipTest("FX Graph Modee Quantization is only availablee from 1.8")
-        if get_torch_version() >= [1, 11]:
-            import torch.ao.quantization as tq
-            from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
-        else:
-            import torch.quantization as tq
-            from torch.quantization.quantize_fx import convert_fx, prepare_fx
+        import torch.ao.quantization as tq
+        from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

         model = build_model(config)
         assert isinstance(model, RegNet)
         model.eval()
-        model.stem = prepare_fx(model.stem, {"": tq.default_qconfig})
+        example_inputs = (torch.rand(1, 3, 3, 3),)
+        model.stem = prepare_fx(model.stem, {"": tq.default_qconfig}, example_inputs)
         model.stem = convert_fx(model.stem)


test/models_resnext_test.py

Lines changed: 19 additions & 8 deletions

@@ -107,18 +107,29 @@ def _post_training_quantize(model, input):
         ]
         # we need to keep the modules used in head standalone since
         # it will be accessed with path name directly in execution
+        # TODO[quant-example-inputs]: Fix the shape if it is needed in quantization
+        standalone_example_inputs = (torch.rand(1, 3, 3, 3),)
         prepare_custom_config_dict["standalone_module_name"] = [
             (
                 head,
                 {"": tq.default_qconfig},
+                standalone_example_inputs,
                 {"input_quantized_idxs": [0], "output_quantized_idxs": []},
                 None,
             )
             for head in head_path_from_blocks
         ]
-        model.initial_block = prepare_fx(model.initial_block, {"": tq.default_qconfig})
+        # TODO[quant-example-inputs]: Fix the shape if it is needed in quantization
+        example_inputs = (torch.rand(1, 3, 3, 3),)
+        model.initial_block = prepare_fx(
+            model.initial_block, {"": tq.default_qconfig}, example_inputs
+        )
+
         model.blocks = prepare_fx(
-            model.blocks, {"": tq.default_qconfig}, prepare_custom_config_dict
+            model.blocks,
+            {"": tq.default_qconfig},
+            example_inputs,
+            prepare_custom_config_dict,
         )
         model.set_heads(heads)

@@ -222,8 +233,8 @@ def test_small_resnext(self):
         self._test_model(MODELS["small_resnext"])

     @unittest.skipIf(
-        get_torch_version() < [1, 8],
-        "FX Graph Modee Quantization is only availablee from 1.8",
+        get_torch_version() < [1, 13],
+        "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13"
     )
     def test_quantized_small_resnext(self):
         self._test_quantize_model(MODELS["small_resnext"])

@@ -232,8 +243,8 @@ def test_small_resnet(self):
         self._test_model(MODELS["small_resnet"])

     @unittest.skipIf(
-        get_torch_version() < [1, 8],
-        "FX Graph Modee Quantization is only availablee from 1.8",
+        get_torch_version() < [1, 13],
+        "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13"
     )
     def test_quantized_small_resnet(self):
         self._test_quantize_model(MODELS["small_resnet"])

@@ -242,8 +253,8 @@ def test_small_resnet_se(self):
         self._test_model(MODELS["small_resnet_se"])

     @unittest.skipIf(
-        get_torch_version() < [1, 8],
-        "FX Graph Modee Quantization is only availablee from 1.8",
+        get_torch_version() < [1, 13],
+        "This test is using a new api of FX Graph Mode Quantization which is only available after 1.13"
     )
     def test_quantized_small_resnet_se(self):
         self._test_quantize_model(MODELS["small_resnet_se"])
