
Commit c3dd49b

update NS for FX tutorial for PyTorch v1.13

Summary: Makes two updates to ensure this tutorial still runs on PyTorch 1.13:

1. changes the `qconfig_dict` argument of `prepare_fx` to a `qconfig_mapping`
2. adds `example_inputs` to `prepare_fx`

Test plan: Run the tutorial; it runs without errors on master.

1 parent 7d8cb43 commit c3dd49b
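For context, this is the 1.13 call pattern the commit migrates to, shown as a minimal standalone sketch. The toy `Sequential` model and input shape below are illustrative placeholders, not code from the tutorial:

import torch
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization.qconfig_mapping import get_default_qconfig_mapping

# hypothetical toy model, for illustration only
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# 1.13 API: a QConfigMapping object replaces the old qconfig_dict, and
# prepare_fx now takes an example_inputs tuple as a required argument
qconfig_mapping = get_default_qconfig_mapping('fbgemm')
prepared = quantize_fx.prepare_fx(model, qconfig_mapping, example_inputs)
prepared(*example_inputs)  # run example data through the observers to calibrate
quantized = quantize_fx.convert_fx(prepared)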

File tree

1 file changed (+15, -16 lines)

prototype_source/fx_numeric_suite_tutorial.py

Lines changed: 15 additions & 16 deletions
@@ -39,9 +39,10 @@
 import copy
 import torch
 import torchvision
-import torch.quantization
+import torch.ao.quantization
 import torch.ao.ns._numeric_suite_fx as ns
-import torch.quantization.quantize_fx as quantize_fx
+import torch.ao.quantization.quantize_fx as quantize_fx
+from torch.ao.quantization.qconfig_mapping import get_default_qconfig_mapping
 
 import matplotlib.pyplot as plt
 from tabulate import tabulate
@@ -68,25 +69,23 @@ def plot(xdata, ydata, xlabel, ylabel, title):
 mobilenetv2_float = torchvision.models.quantization.mobilenet_v2(
     pretrained=True, quantize=False).eval()
 
+# adjust the default qconfig to make the results more interesting to explore:
+# 1. turn off quantization for the first couple of layers
+# 2. use MinMaxObserver for `features.17`, this should lead to worse
+#    weight SQNR
+qconfig_mapping = get_default_qconfig_mapping('fbgemm')\
+    .set_module_name('features.0', None)\
+    .set_module_name('features.1', None)\
+    .set_module_name('features.17', torch.ao.quantization.default_qconfig)
+
 # create quantized model
-qconfig_dict = {
-    '': torch.quantization.get_default_qconfig('fbgemm'),
-    # adjust the qconfig to make the results more interesting to explore
-    'module_name': [
-        # turn off quantization for the first couple of layers
-        ('features.0', None),
-        ('features.1', None),
-        # use MinMaxObserver for `features.17`, this should lead to worse
-        # weight SQNR
-        ('features.17', torch.quantization.default_qconfig),
-    ]
-}
+
 # Note: quantization APIs are inplace, so we save a copy of the float model for
 # later comparison to the quantized model. This is done throughout the
 # tutorial.
-mobilenetv2_prepared = quantize_fx.prepare_fx(
-    copy.deepcopy(mobilenetv2_float), qconfig_dict)
 datum = torch.randn(1, 3, 224, 224)
+mobilenetv2_prepared = quantize_fx.prepare_fx(
+    copy.deepcopy(mobilenetv2_float), qconfig_mapping, (datum,))
 mobilenetv2_prepared(datum)
 # Note: there is a long standing issue that we cannot copy.deepcopy a
 # quantized model. Since quantization APIs are inplace and we need to use
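To sanity-check the migrated code end to end, here is a sketch of the comparison step this prepared model feeds into, pieced together from the `torch.ao.ns._numeric_suite_fx` import in this diff; the tutorial's actual downstream code may differ in detail:

import copy
import torch
import torchvision
import torch.ao.ns._numeric_suite_fx as ns
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization.qconfig_mapping import get_default_qconfig_mapping

# same setup as the tutorial code in this diff
mobilenetv2_float = torchvision.models.quantization.mobilenet_v2(
    pretrained=True, quantize=False).eval()
datum = torch.randn(1, 3, 224, 224)
mobilenetv2_prepared = quantize_fx.prepare_fx(
    copy.deepcopy(mobilenetv2_float),
    get_default_qconfig_mapping('fbgemm'), (datum,))
mobilenetv2_prepared(datum)  # calibrate on one example input

# convert a copy: convert_fx mutates its argument, and the prepared
# model is reused later in the tutorial
mobilenetv2_quantized = quantize_fx.convert_fx(
    copy.deepcopy(mobilenetv2_prepared))

# extract matching weight pairs from the float and quantized models,
# then attach an SQNR score to each pair
wt_compare = ns.extract_weights(
    'fp32', mobilenetv2_float, 'int8', mobilenetv2_quantized)
ns.extend_logger_results_with_comparison(
    wt_compare, 'fp32', 'int8', compute_sqnr, 'sqnr')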
