
Commit 13ad473

asl3 authored and pytorchmergebot committed
[quant] Implement PTQ for APoT FakeQuant (pytorch#81040)
### Summary:

This PR implements PTQ (post-training quantization) for APoT FakeQuant. It runs a pre-trained ResNet-18 model on the ImageNet dataset to compare accuracy metrics across qconfig settings with uniform vs. APoT quantized activations and weights. According to the collected accuracy stats, Model #2 (uniform activation, APoT weight) shows a slight accuracy improvement over Model #1 (uniform activation, uniform weight) at 8 bits and a significant improvement at 4 bits (see "Accuracy Stats" below).

### Test Plan:

Run the models with:

`python test/quantization/core/experimental/fx_graph_mode_apot.py`

### Accuracy Stats:

8-bit (uniform int8; APoT b = 8, k = 2)

**Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.43% (Top-1), 85.62% (Top-5)

**Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.51% (Top-1), 85.78% (Top-5)

**Model #3:** APoT activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.32% (Top-1), 85.78% (Top-5)

4-bit (uniform int4; APoT b = 4, k = 2)

**Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 45.63% (Top-1), 71.96% (Top-5)

**Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.24% (Top-1), 85.56% (Top-5)

**Model #3:** APoT activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 45.40% (Top-1), 76.21% (Top-5)

**Full precision (float) model**
Evaluation accuracy on test dataset: 69.76% (Top-1), 89.08% (Top-5)

**Eager mode quantized model**
Evaluation accuracy on test dataset: 69.49% (Top-1), 88.90% (Top-5)

Pull Request resolved: pytorch#81040
Approved by: https://github.com/jerryzh168
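For orientation, the end-to-end PTQ flow that the new test script follows can be condensed into a short sketch. This is a hedged restatement of the script added in this commit, not additional functionality; the qconfig names come from `torch/ao/quantization/experimental/qconfig.py` below, while `my_float_model` and `my_calib_loader` are placeholder names.

```python
import copy
import torch
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.experimental.qconfig import (
    uniform_qconfig_8bit,       # Model #1: uniform activation, uniform weight
    apot_weights_qconfig_8bit,  # Model #2: uniform activation, APoT weight
    apot_qconfig_8bit,          # Model #3: APoT activation, APoT weight
)

def prepare_linear_ptq(float_model, qconfig, calib_loader):
    # Apply the qconfig only to torch.nn.Linear modules, then calibrate.
    qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
    prepared = prepare_fx(copy.deepcopy(float_model), qconfig_dict)  # insert observers / fake-quant
    with torch.no_grad():
        for image, _ in calib_loader:
            prepared(image)  # calibration pass collects observer statistics
    return prepared

# The uniform qconfigs are additionally lowered with convert_fx; the APoT qconfigs
# are evaluated as prepared (fake-quantized) models, mirroring the script in this diff.
# quantized = convert_fx(prepare_linear_ptq(my_float_model, uniform_qconfig_8bit, my_calib_loader))
# fake_quant = prepare_linear_ptq(my_float_model, apot_weights_qconfig_8bit, my_calib_loader)
```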
1 parent f445c22 commit 13ad473

File tree

9 files changed: +358 −13 lines changed

mypy.ini

Lines changed: 3 additions & 0 deletions

@@ -73,6 +73,9 @@ ignore_missing_imports = True
 [mypy-torch.ao.quantization.experimental.fake_quantize_function]
 ignore_missing_imports = True
 
+[mypy-torch.ao.quantization.experimental.fake_quantize]
+ignore_missing_imports = True
+
 #
 # Files with various errors. Mostly real errors, possibly some false
 # positives as well.
44.7 MB
Binary file not shown.
Lines changed: 257 additions & 0 deletions

@@ -0,0 +1,257 @@
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.transforms as transforms
from torchvision.models.resnet import resnet18  # float ResNet-18 constructor used by load_model
import os
import torch.quantization

# Setup warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.quantization'
)

"""
Define helper functions
"""

# Specify random seed for repeatable results
_ = torch.manual_seed(191009)

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def evaluate(model, criterion, data_loader):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
    print('')

    return top1, top5

def load_model(model_file):
    model = resnet18(pretrained=False)
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to("cpu")
    return model

def print_size_of_model(model):
    if isinstance(model, torch.jit.RecursiveScriptModule):
        torch.jit.save(model, "temp.p")
    else:
        torch.jit.save(torch.jit.script(model), "temp.p")
    print("Size (MB):", os.path.getsize("temp.p") / 1e6)
    os.remove("temp.p")

def prepare_data_loaders(data_path):

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    dataset = torchvision.datasets.ImageNet(data_path,
                                            split="train",
                                            transform=transforms.Compose([transforms.RandomResizedCrop(224),
                                                                          transforms.RandomHorizontalFlip(),
                                                                          transforms.ToTensor(),
                                                                          normalize]))
    dataset_test = torchvision.datasets.ImageNet(data_path,
                                                 split="val",
                                                 transform=transforms.Compose([transforms.Resize(256),
                                                                               transforms.CenterCrop(224),
                                                                               transforms.ToTensor(),
                                                                               normalize]))

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=train_batch_size,
        sampler=train_sampler)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=eval_batch_size,
        sampler=test_sampler)

    return data_loader, data_loader_test

data_path = '~/my_imagenet/'
saved_model_dir = '/data/home/amandaliu/cluster/pytorch/test/quantization/core/experimental/data/'
float_model_file = 'resnet18_pretrained_float.pth'

train_batch_size = 30
eval_batch_size = 50

data_loader, data_loader_test = prepare_data_loaders(data_path)
criterion = nn.CrossEntropyLoss()
float_model = load_model(saved_model_dir + float_model_file).to("cpu")
float_model.eval()

# deepcopy the model since we need to keep the original model around
import copy
model_to_quantize = copy.deepcopy(float_model)

model_to_quantize.eval()

"""
Prepare models
"""

# Note that this is temporary, we'll expose these functions to torch.quantization after official release
from torch.quantization.quantize_fx import prepare_fx, convert_fx

def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

from torch.ao.quantization.experimental.qconfig import (
    uniform_qconfig_8bit,
    apot_weights_qconfig_8bit,
    apot_qconfig_8bit,
    uniform_qconfig_4bit,
    apot_weights_qconfig_4bit,
    apot_qconfig_4bit
)

"""
Prepare full precision model
"""
full_precision_model = float_model

top1, top5 = evaluate(full_precision_model, criterion, data_loader_test)
print("Model #0 Evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))

"""
Prepare model for PTQ with the specified qconfig applied to torch.nn.Linear
"""
def prepare_ptq_linear(qconfig):
    qconfig_dict = {"object_type": [(torch.nn.Linear, qconfig)]}
    prepared_model = prepare_fx(copy.deepcopy(float_model), qconfig_dict)  # fuse modules and insert observers
    calibrate(prepared_model, data_loader_test)  # run calibration on sample data
    return prepared_model

"""
Prepare model with uniform activation, uniform weight
b=8, k=2
"""

prepared_model = prepare_ptq_linear(uniform_qconfig_8bit)
quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model

top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("Model #1 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))

"""
Prepare model with uniform activation, uniform weight
b=4, k=2
"""

prepared_model = prepare_ptq_linear(uniform_qconfig_4bit)
quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model

top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
print("Model #1 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))

"""
Prepare model with uniform activation, APoT weight
(b=8, k=2)
"""

prepared_model = prepare_ptq_linear(apot_weights_qconfig_8bit)

# evaluate the prepared (fake-quantized) model directly
top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print("Model #2 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))

"""
Prepare model with uniform activation, APoT weight
(b=4, k=2)
"""

prepared_model = prepare_ptq_linear(apot_weights_qconfig_4bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print("Model #2 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))


"""
Prepare model with APoT activation and weight
(b=8, k=2)
"""

prepared_model = prepare_ptq_linear(apot_qconfig_8bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print("Model #3 Evaluation accuracy on test dataset (b=8, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))

"""
Prepare model with APoT activation and weight
(b=4, k=2)
"""

prepared_model = prepare_ptq_linear(apot_qconfig_4bit)

top1, top5 = evaluate(prepared_model, criterion, data_loader_test)
print("Model #3 Evaluation accuracy on test dataset (b=4, k=2): %2.2f, %2.2f" % (top1.avg, top5.avg))

"""
Prepare eager mode quantized model
"""

from torchvision.models.quantization.resnet import resnet18
eager_quantized_model = resnet18(pretrained=True, quantize=True).eval()
top1, top5 = evaluate(eager_quantized_model, criterion, data_loader_test)
print("Eager mode quantized model evaluation accuracy on test dataset: %2.2f, %2.2f" % (top1.avg, top5.avg))

test/quantization/core/experimental/test_fake_quantize.py

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ def test_fake_calc_qparams(self):
 
     r""" Tests fake quantize forward() method
     by comparing result with expected
-    float_to_reduced_precision mapping of input tensor.
+    quant_dequant_APoT mapping of input tensor.
     Uses input tensor with random values from 0 -> 1000
     and APoT observer with hard-coded values b=4, k=2
     """

torch/ao/quantization/experimental/apot_utils.py

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ def float_to_apot(x, levels, indices, alpha):
 reduced precision floating point value
 based on quantization levels
 """
-def float_to_reduced_precision(x, levels, indices):
+def quant_dequant_util(x, levels, indices):
     levels_lst = list(levels)
     indices_lst = list(indices)

torch/ao/quantization/experimental/fake_quantize.py

Lines changed: 11 additions & 6 deletions

@@ -10,18 +10,23 @@ class APoTFakeQuantize(FakeQuantizeBase):
     quantization_levels: Tensor
     level_indices: Tensor
 
-    def __init__(self, **observer_kwargs):
+    def __init__(self, observer=APoTObserver, **observer_kwargs):
         super().__init__()
-        self.activation_post_process = APoTObserver(**observer_kwargs)
+        self.activation_post_process = observer(**observer_kwargs)
+        self.dtype = self.activation_post_process.dtype
 
-    def calculate_qparams(self, signed: bool):  # type: ignore[override]
+    def calculate_qparams(self, signed=False):  # type: ignore[override]
         return self.activation_post_process.calculate_qparams(signed=signed)
 
-    def forward(self, X: torch.Tensor, signed: bool):  # type: ignore[override]
+    def forward(self, X: torch.Tensor):  # type: ignore[override]
         if self.observer_enabled[0] == 1:
             self.activation_post_process.forward(X)
-            self.alpha, self.gamma, self.quantization_levels, self.level_indices = \
-                self.activation_post_process.calculate_qparams(signed)
+            result = self.activation_post_process.calculate_qparams(signed=False)
+            self.alpha = result[0]
+            self.gamma = result[1]
+            self.quantization_levels = result[2]
+            self.level_indices = result[3]
+
         if self.fake_quant_enabled[0] == 1:
             assert (self.alpha is not None
                     and self.gamma is not None
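With these changes, `APoTFakeQuantize.forward()` no longer takes a `signed` argument and the observer class is injectable through the constructor. A minimal usage sketch under the signatures shown in this diff; the `b`, `k`, and `dtype` values are illustrative, not taken from the PR:

```python
import torch
from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize

# Constructor kwargs are forwarded to the (default) APoTObserver.
fq = APoTFakeQuantize(b=4, k=2, dtype=torch.quint8)

x = torch.rand(8, 16)
y = fq(x)  # observer records statistics, then the tensor is fake-quantized

# calculate_qparams() now defaults to signed=False and returns the same tuple
# that forward() unpacks: (alpha, gamma, quantization_levels, level_indices).
alpha, gamma, quantization_levels, level_indices = fq.calculate_qparams()
```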

torch/ao/quantization/experimental/observer.py

Lines changed: 2 additions & 2 deletions

@@ -23,7 +23,7 @@ def __init__(
         self,
         b,
         k,
-        dtype=torch.int32) -> None:
+        dtype=torch.quint8) -> None:
         super().__init__(dtype)
         self.b = b
         self.k = k
@@ -47,7 +47,7 @@ def calculate_qparams(self, signed):
         quantization_levels: non-uniform quantization levels (fp representation)
         level_indices: int representation of quantization_levels indices
     """
-    def _calculate_qparams(self, signed, min_val=None, max_val=None):
+    def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
         if min_val is not None:
             self.min_val = min_val
         if max_val is not None:
Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
import torch
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize

"""
Default symmetric fake_quant for activations.
"""
default_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver,
                                                      qscheme=torch.per_tensor_symmetric,
                                                      dtype=torch.quint8)

"""
Default symmetric fake_quant for weights.
"""
default_weight_symmetric_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver,
                                                             qscheme=torch.per_tensor_symmetric,
                                                             dtype=torch.qint8)

# uniform activation and weight, b=8 k=2
uniform_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant,
                               weight=default_weight_symmetric_fake_quant)

# uniform activation, APoT weight, b=8 k=2
apot_weights_qconfig_8bit = QConfig(activation=default_symmetric_fake_quant,
                                    weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8))

# APoT activation and weight, b=8 k=2
apot_qconfig_8bit = QConfig(activation=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.quint8),
                            weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8))

# uniform activation and weight, b=4 k=2
uniform_qconfig_4bit = QConfig(activation=default_symmetric_fake_quant.with_args(quant_min=0,
                                                                                 quant_max=15),
                               weight=default_weight_symmetric_fake_quant.with_args(quant_min=0,
                                                                                    quant_max=15))

# uniform activation, APoT weight, b=4 k=2
apot_weights_qconfig_4bit = QConfig(activation=default_symmetric_fake_quant.with_args(quant_min=0,
                                                                                      quant_max=15),
                                    weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8))

# APoT activation and weight, b=4 k=2
apot_qconfig_4bit = QConfig(activation=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.quint8),
                            weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8))
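As a quick sanity check on these definitions, the stored constructors can be instantiated directly; this is a hedged sketch, assuming the attribute names shown in the observer and fake-quantize diffs above:

```python
import torch
from torch.ao.quantization.experimental.qconfig import apot_weights_qconfig_4bit

# QConfig fields hold constructors; calling them yields the modules that
# prepare_fx would insert for activations and weights of matched nn.Linear layers.
act_fq = apot_weights_qconfig_4bit.activation()  # uniform FakeQuantize, 4-bit range
wt_fq = apot_weights_qconfig_4bit.weight()       # APoTFakeQuantize with b=4, k=2

print(type(act_fq).__name__, act_fq.quant_min, act_fq.quant_max)         # expected 0, 15
print(type(wt_fq).__name__,
      wt_fq.activation_post_process.b, wt_fq.activation_post_process.k)  # expected 4, 2
```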
