Commit 9e0b8bb

jerryzh168 authored and facebook-github-bot committed
[quant][fx][bc-breaking] Add required example_inputs argument to prepare_fx and prepare_qat_fx (pytorch#77608)
Summary:
Pull Request resolved: pytorch#77608

X-link: pytorch/fx2trt#76
X-link: facebookresearch/d2go#249
X-link: fairinternal/ClassyVision#104
X-link: pytorch/benchmark#916
X-link: facebookresearch/ClassyVision#791
X-link: facebookresearch/mobile-vision#68

FX Graph Mode Quantization needs to know whether an fx node is a floating point Tensor before it can decide whether to insert an observer/fake_quantize module, since we only insert observer/fake_quantize modules for floating point Tensors. Currently we support this with hacks such as the NON_OBSERVABLE_ARG_DICT rules (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the pytorch code base.

As discussed in the design review, we need users to provide sample args and sample keyword args so that we can infer the type in a more robust way. This PR starts by changing the prepare_fx and prepare_qat_fx APIs to require users to provide example arguments through example_inputs. Note that this API does not support kwargs. Kwargs could make pytorch#76496 (comment) simpler, but that case will be rare, and even then we can still work around it with positional arguments; torch.jit.trace (https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp (https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140) also take only positional args. We will use a single example_inputs argument for now; if needed, we can extend the API with an optional example_kwargs, e.g. for cases where forward takes many arguments and it makes more sense to pass them by keyword.

BC-breaking Note:

Before:
```python
m = resnet18(...)
m = prepare_fx(m, qconfig_dict)
# or
m = prepare_qat_fx(m, qconfig_dict)
```

After:
```python
m = resnet18(...)
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))
# or
m = prepare_qat_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))
```

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestQuantizeFxModels

Imported from OSS

Static Docs Preview: classyvision
[Full Site](https://our.intern.facebook.com/intern/staticdocs/eph/D35984526/V44/classyvision/)

Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: 716e5992ebe99cfb90be669357f56b214d692aef
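For readers migrating call sites, the sketch below shows an end-to-end post-training quantization flow against the new API. It is a minimal illustration, not part of this commit: the torchvision resnet18, the fbgemm qconfig, and the random calibration batches are all placeholder assumptions.

```python
import torch
from torch.ao.quantization import quantize_fx
from torchvision.models import resnet18

# placeholder model and qconfig; any float module and valid qconfig_dict work the same way
m = resnet18(pretrained=False).eval()
qconfig_dict = {'': torch.quantization.get_default_qconfig('fbgemm')}

# example_inputs is now required: a tuple of example positional arguments for forward()
example_inputs = (torch.randn(1, 3, 224, 224),)
mp = quantize_fx.prepare_fx(m, qconfig_dict, example_inputs=example_inputs)

# calibrate with a few batches (random tensors here, purely illustrative)
for _ in range(4):
    mp(torch.randn(1, 3, 224, 224))

mq = quantize_fx.convert_fx(mp)
```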
1 parent c0abd83 commit 9e0b8bb

9 files changed, +536 -340 lines changed


test/quantization/bc/test_backward_compatibility.py
Lines changed: 2 additions & 1 deletion

@@ -171,9 +171,10 @@ def _do_quant_transforms(
     m: torch.nn.Module,
     input_tensor: torch.Tensor,
 ) -> torch.nn.Module:
+    example_inputs = (input_tensor,)
     # do the quantizaton transforms and save result
     qconfig = torch.quantization.get_default_qconfig('fbgemm')
-    mp = quantize_fx.prepare_fx(m, {'': qconfig})
+    mp = quantize_fx.prepare_fx(m, {'': qconfig}, example_inputs=example_inputs)
     mp(input_tensor)
     mq = quantize_fx.convert_fx(mp)
     return mq

test/quantization/dbr/test_quantize_dbr.py
Lines changed: 14 additions & 14 deletions

@@ -82,7 +82,7 @@ def _test_auto_tracing(

     # compare it against FX
     if do_fx_comparison:
-        m_copy_p = prepare_fx(m_copy, {'': qconfig})
+        m_copy_p = prepare_fx(m_copy, {'': qconfig}, example_inputs=example_args)
         out_m_copy_p = m_copy_p(*example_args)
         # print(m_copy_p)
         m_copy_q = convert_fx(m_copy_p)
@@ -1236,11 +1236,11 @@ def test_qconfig_dict_unsupported_does_not_crash_when_empty(self):
     """
     m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
     qconfig_dict = {'': torch.quantization.default_qconfig}
+    example_inputs = (torch.randn(1, 1, 1, 1),)
     # this modifies qconfig_dict inplace to include more keys
-    mp = prepare_fx(m, qconfig_dict)
-    example_args = (torch.randn(1, 1, 1, 1),)
+    mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
     # need this line to not crash
-    mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
+    mp = _quantize_dbr.prepare(m, qconfig_dict, example_inputs)

 def _test_serialization(self, model, input_shape):
     example_inputs = (torch.randn(*input_shape),)
@@ -1324,15 +1324,15 @@ def test_jit_tracing_removes_aliases(self):
         ),
     )
     qconfig_dict = {'': torch.quantization.default_qconfig}
-    example_args = (torch.randn(1, 1, 1, 1),)
-    mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
+    example_inputs = (torch.randn(1, 1, 1, 1),)
+    mp = _quantize_dbr.prepare(m, qconfig_dict, example_inputs)
     mq = _quantize_dbr.convert(mp)
-    mqs = torch.jit.trace(mq, example_args)
+    mqs = torch.jit.trace(mq, example_inputs)
     FileCheck().check_count("aten::alias", 5, exactly=True).run(
         mqs.inlined_graph)
-    res1 = mqs(*example_args)
+    res1 = mqs(*example_inputs)
     mqs = remove_redundant_aliases(mqs)
-    res2 = mqs(*example_args)
+    res2 = mqs(*example_inputs)
     self.assertTrue(torch.allclose(res1, res2))
     # TODO(future PR): figure out why aliasing still appears in the inlined
     # graph, and if that is fixed then just check the inlined graph.
@@ -1609,11 +1609,11 @@ def test_mobilenet_v2_removes_aliases(self):
     m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False)\
         .eval().float()
     qconfig_dict = {'': torch.quantization.default_qconfig}
-    example_args = (torch.randn(1, 3, 224, 224),)
-    mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
+    example_inputs = (torch.randn(1, 3, 224, 224),)
+    mp = _quantize_dbr.prepare(m, qconfig_dict, example_inputs)
     mq = _quantize_dbr.convert(mp)
-    mqs = torch.jit.trace(mq, example_args)
-    res1 = mqs(*example_args)
+    mqs = torch.jit.trace(mq, example_inputs)
+    res1 = mqs(*example_inputs)
     mqs = remove_redundant_aliases(mqs)
-    res2 = mqs(*example_args)
+    res2 = mqs(*example_inputs)
     self.assertTrue(torch.allclose(res1, res2))
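The dbr changes above standardize on a single example_inputs tuple that is passed to prepare, reused for torch.jit.trace, and unpacked when calling the traced module. Below is a minimal sketch of the same pattern with FX graph mode quantization; the tiny conv model and default qconfig mirror the test fixtures, and the flow itself is illustrative rather than code from this commit.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_fx

m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
qconfig_dict = {'': torch.quantization.default_qconfig}
example_inputs = (torch.randn(1, 1, 1, 1),)

mp = quantize_fx.prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
mp(*example_inputs)                        # calibration pass
mq = quantize_fx.convert_fx(mp)
mqs = torch.jit.trace(mq, example_inputs)  # the same tuple works as trace example inputs
res = mqs(*example_inputs)
```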

test/quantization/fx/test_equalize_fx.py
Lines changed: 59 additions & 11 deletions

@@ -274,7 +274,14 @@ def test_input_weight_equalization_prepare(self):

     for (M, node_occurrence) in tests:
         m = M().eval()
-        prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        # TODO[quant-example-inputs]: if shape is important we need to define a example_inputs for each test
+        # for now we do not need shape so this can be fixed later
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        prepared = prepare_fx(
+            m,
+            specific_qconfig_dict,
+            example_inputs=example_inputs,
+            equalization_qconfig_dict=default_equalization_qconfig_dict)
         self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)

 def test_input_weight_equalization_branching(self):
@@ -305,7 +312,10 @@ def forward(self, x):
     }

     m = TestBranchingWithoutEqualizationModel().eval()
-    prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+    example_inputs = (torch.randn(1, 5),)
+    prepared = prepare_fx(
+        m, specific_qconfig_dict, example_inputs=example_inputs,
+        equalization_qconfig_dict=default_equalization_qconfig_dict)
     self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_eq_branching_node_occurrence)

     # Tests that we will add an equalization observer because there is only
@@ -326,7 +336,10 @@ def forward(self, x):
     }

     m = TestBranchingWithEqualizationModel().eval()
-    prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+    example_inputs = (torch.randn(1, 5),)
+    prepared = prepare_fx(
+        m, specific_qconfig_dict, example_inputs=example_inputs,
+        equalization_qconfig_dict=default_equalization_qconfig_dict)
     self.checkGraphModuleNodes(prepared, expected_node_occurrence=eq_branching_node_occurrence)

 @skipIfNoFBGEMM
@@ -353,17 +366,22 @@ def test_input_weight_equalization_convert(self):
     elif ndim == 4:
         x = torch.rand((16, 3, 224, 224))

+    example_inputs = (x,)
     prepared = prepare_fx(
         copy.deepcopy(m),
         specific_qconfig_dict,
+        example_inputs=example_inputs,
         equalization_qconfig_dict=default_equalization_qconfig_dict
     )
     output = prepared(x)

     convert_ref = _convert_equalization_ref(prepared)
     convert_ref_output = convert_ref(x)

-    prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+    prepared = prepare_fx(
+        m, specific_qconfig_dict,
+        example_inputs=example_inputs,
+        equalization_qconfig_dict=default_equalization_qconfig_dict)
     prepared(x)
     convert_fx(prepared)  # Check if compile
     self.assertEqual(output, convert_ref_output)
@@ -411,8 +429,12 @@ def test_input_weight_equalization_equalization_scales(self):
     m = M().eval()
     exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())

-    prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
-    prepared(x)
+    example_inputs = (x,)
+    prepared = prepare_fx(
+        m, specific_qconfig_dict,
+        example_inputs=example_inputs,
+        equalization_qconfig_dict=default_equalization_qconfig_dict)
+    prepared(*example_inputs)
     convert_ref = _convert_equalization_ref(prepared)
     convert_ref(x)

@@ -460,7 +482,11 @@ def test_input_weight_equalization_weights_bias(self):
     exp_eq_scales = self.get_expected_eq_scales(m, x.detach().numpy())
     exp_weights, exp_bias = self.get_expected_weights_bias(m, x.detach().numpy(), exp_eq_scales)

-    prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+    example_inputs = (x,)
+    prepared = prepare_fx(
+        m, specific_qconfig_dict,
+        example_inputs=example_inputs,
+        equalization_qconfig_dict=default_equalization_qconfig_dict)
     prepared(x)
     convert_ref = _convert_equalization_ref(prepared)
     convert_ref(x)
@@ -516,7 +542,11 @@ def test_input_weight_equalization_activation_values(self):
     exp_inp_act_vals = self.get_expected_inp_act_vals(m, x, exp_eq_scales, exp_weights, exp_bias)
     exp_weight_act_vals = self.get_expected_weight_act_vals(exp_weights)

-    prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+    example_inputs = (x,)
+    prepared = prepare_fx(
+        m, specific_qconfig_dict,
+        example_inputs=example_inputs,
+        equalization_qconfig_dict=default_equalization_qconfig_dict)
     prepared(x)
     convert_ref = _convert_equalization_ref(prepared)
     convert_ref(x)
@@ -751,7 +781,13 @@ def test_input_weight_equalization_graphs(self):

     for (M, node_list) in tests:
         m = M().eval()
-        prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
+        # TODO[quant-example-inputs]: if shape is important we need to define a example_inputs for each test
+        # for now we do not need shape so this can be fixed later
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        prepared = prepare_fx(
+            m, specific_qconfig_dict,
+            example_inputs=example_inputs,
+            equalization_qconfig_dict=default_equalization_qconfig_dict)
         equalized_quantized_model = convert_fx(prepared)

         # Check the order of nodes in the graph
@@ -771,7 +807,12 @@ def test_input_weight_equalization_results(self):
     m = M().eval()

     # No equalization
-    prepared = prepare_fx(copy.deepcopy(m), specific_qconfig_dict, equalization_qconfig_dict={})
+    example_inputs = (x,)
+    prepared = prepare_fx(
+        copy.deepcopy(m),
+        specific_qconfig_dict,
+        example_inputs=example_inputs,
+        equalization_qconfig_dict={})
     prepared(x)
     quantized = convert_fx(prepared)  # Check if compile
     quantized_output = quantized(x)
@@ -780,6 +821,7 @@ def test_input_weight_equalization_results(self):
     prepared = prepare_fx(
         copy.deepcopy(m),
         specific_qconfig_dict,
+        example_inputs=example_inputs,
        equalization_qconfig_dict=default_equalization_qconfig_dict
     )
     prepared(x)
@@ -817,7 +859,12 @@ def forward(self, x):
         [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])

     # Quantize the float model
-    prepared_model = prepare_fx(copy.deepcopy(float_model), specific_qconfig_dict)
+    example_inputs = (x,)
+    prepared_model = prepare_fx(
+        copy.deepcopy(float_model),
+        specific_qconfig_dict,
+        example_inputs=example_inputs
+    )
     prepared_model(x)
     quantized_model = convert_fx(copy.deepcopy(prepared_model))

@@ -832,6 +879,7 @@ def forward(self, x):
     prepared_model = prepare_fx(
         copy.deepcopy(float_model),
         specific_qconfig_dict,
+        example_inputs=example_inputs,
         equalization_qconfig_dict=selective_equalization_qconfig_dict,
     )
     prepared_model(x)
