Skip to content

Commit def5ce3

Browse files
jerryzh168 and facebook-github-bot
authored and committed
Add required example_args argument to prepare_fx and prepare_qat_fx
Summary: X-link: pytorch/pytorch#77608 X-link: pytorch/fx2trt#76 X-link: facebookresearch/d2go#249 X-link: fairinternal/ClassyVision#104 X-link: pytorch/benchmark#916 X-link: facebookresearch/ClassyVision#791 Pull Request resolved: facebookresearch#68 FX Graph Mode Quantization needs to know whether an fx node is a floating point Tensor before it can decide whether to insert observer/fake_quantize module or not, since we only insert observer/fake_quantize module for floating point Tensors. Currently we have some hacks to support this by defining some rules like NON_OBSERVABLE_ARG_DICT (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the pytorch code base. As we discussed in the design review, we'd need to ask users to provide sample args and sample keyword args so that we can infer the type in a more robust way. This PR starts by changing the prepare_fx and prepare_qat_fx api to require the user to provide example arguments through example_inputs. Note: this api doesn't support kwargs; kwargs could make pytorch/pytorch#76496 (comment) simpler, but it will be rare, and even then we can still work around it with positional arguments. Also torch.jit.trace (https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp (https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140) just take single positional args, so we'll just use a single example_inputs argument for now. If needed, we can extend the api with an optional example_kwargs, e.g. in case there are a lot of arguments for forward and it makes more sense to pass the arguments by keyword. BC-breaking Note: Before: ```python m = resnet18(...) m = prepare_fx(m, qconfig_dict) # or m = prepare_qat_fx(m, qconfig_dict) ``` After: ```python m = resnet18(...) 
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),)) # or m = prepare_qat_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),)) ``` Reviewed By: vkuzo, andrewor14 Differential Revision: D35984526 fbshipit-source-id: 06a3020780ab9745abad3f069f35a66a8bec58be
1 parent f313fb8 commit def5ce3

File tree

6 files changed

+20
-7
lines changed

6 files changed

+20
-7
lines changed

mobile_cv/arch/tests/test_fbnet_v2_quantize.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,11 @@ def test_qat(self):
8686
model.train()
8787

8888
qconfig_dict = {"": torch.ao.quantization.get_default_qat_qconfig("fbgemm")}
89-
model_prepared = quantize_fx.prepare_qat_fx(model, qconfig_dict)
89+
example_inputs = (torch.rand(2, 3, 8, 8),)
90+
model_prepared = quantize_fx.prepare_qat_fx(
91+
model, qconfig_dict, example_inputs=example_inputs
92+
)
93+
9094
print(f"Prepared model {model_prepared}")
9195

9296
# calibration

mobile_cv/arch/tests/test_fbnet_v2_res_block.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def test_res_block_quantize_partial(self):
9494
data = torch.zeros(1, 8, 4, 4)
9595

9696
qconfig_dict = qu.get_qconfig_dict(model, qconfig)
97-
model = prepare_fx(model, qconfig_dict)
97+
example_inputs = (data,)
98+
model = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
9899
model = convert_fx(model)
99100
print(model)
100101

mobile_cv/arch/tests/test_utils_quantize_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,8 @@ def forward(self, x):
361361
model = MM().eval()
362362
qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
363363
qconfig_dict = qu.get_qconfig_dict(model, qconfig)
364-
model = prepare_fx(model, qconfig_dict)
364+
example_inputs = (torch.rand(1, 1, 3, 3),)
365+
model = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
365366
model = convert_fx(model)
366367
print(model)
367368

mobile_cv/arch/utils/quantize_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,13 +186,15 @@ def set_quant_config(self, quant_cfg):
186186
self.qconfig = quant_cfg
187187
return self
188188

189-
def prepare(self, qconfig_dict=None):
189+
def prepare(self, example_inputs, qconfig_dict=None):
190190
if qconfig_dict is None:
191191
qconfig_dict = get_qconfig_dict(self.model, self.qconfig)
192192
if qconfig_dict is None:
193193
qconfig_dict = {"": self.qconfig}
194194
self._prepared_model = torch.ao.quantization.quantize_fx.prepare_fx(
195-
self.model, qconfig_dict
195+
self.model,
196+
qconfig_dict=qconfig_dict,
197+
example_inputs=example_inputs,
196198
)
197199
return self
198200

mobile_cv/model_zoo/tools/create_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,10 @@ def convert_int8_jit(args, model, data, folder_name="int8_jit"):
140140
)
141141
else:
142142
quant = qu.PostQuantizationFX(model)
143+
example_inputs = tuple(data)
143144
quant_model = (
144145
quant.set_quant_backend("default")
145-
.prepare()
146+
.prepare(example_inputs=example_inputs)
146147
.calibrate_model([data], 1)
147148
.convert_model()
148149
)

mobile_cv/model_zoo/tools/model_exporter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,19 +309,23 @@ def export_to_torchscript_int8(
309309
):
310310
cur_loader = itertools.chain([inputs], data_iter)
311311

312+
example_inputs = tuple(inputs)
312313
if hasattr(task, "get_quantized_model"):
314+
print("calling get quantized model")
313315
ptq_model = task.get_quantized_model(model, cur_loader)
314316
model_attrs = _get_model_attributes(ptq_model)
317+
print("after calling get quantized model")
315318
elif args.use_graph_mode_quant:
316319
print(f"Post quantization using {args.post_quant_backend} backend fx mode...")
317320
model_attrs = _get_model_attributes(model)
318321
quant = quantize_utils.PostQuantizationFX(model)
319322
ptq_model = (
320323
quant.set_quant_backend(args.post_quant_backend)
321-
.prepare()
324+
.prepare(example_inputs=example_inputs)
322325
.calibrate_model(cur_loader, 1)
323326
.convert_model()
324327
)
328+
print("after calling callback")
325329
else:
326330
print(f"Post quantization using {args.post_quant_backend} backend...")
327331
qa_model = task.get_quantizable_model(model)

0 commit comments

Comments
 (0)