Emulate weight and activation only quant with fake quant, numerics test #26625

Closed
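In brief, the diff below adds: a test model without average pooling (fake quant does not model avg-pool/mean accurately, per the comment in common_quantization.py); a test comparing fake-quant and true-quant numerics against the float model; a test comparing weight-only and activation-only fake quant against the float model; and two new QConfigs that use torch.nn.Identity as a pass-through so that only weights or only activations are fake-quantized. The tests score agreement with the signal-to-quantization-noise ratio they compute inline; as a standalone sketch (the helper name is illustrative, not part of the diff):

import torch

def sqnr_db(reference, test):
    # 20 * log10(||ref|| / ||ref - test||), the metric the tests below compute inline
    return 20 * torch.log10(torch.norm(reference) / torch.norm(reference - test))

Higher is better; the tests set targets between 35 dB (quantized vs. float) and 60 dB (fake quant vs. true quant).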
34 changes: 34 additions & 0 deletions test/common_quantization.py
@@ -532,3 +532,37 @@ def forward(self, x):
out = out.view(-1, 3 * 2 * 2)
out = self.fc(out)
return out

# Model to ensure consistency of fake quant with true quant
# Average pooling and mean operations are not modelled
# accurately with fake-quant so this model does not
# contain those operations
class ModelMultipleOpsNoAvgPool(torch.nn.Module):
def __init__(self):
super(ModelMultipleOpsNoAvgPool, self).__init__()
norm_layer = nn.BatchNorm2d
inplanes = 3
self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
self.bn1 = norm_layer(inplanes)
self.relu1 = nn.ReLU()
self.relu2 = nn.ReLU()
self.skip_add = nn.quantized.FloatFunctional()
self.cat = nn.quantized.FloatFunctional()
self.maxpool = nn.MaxPool2d((4, 4))
self.fc = nn.Linear(12, 6)

def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
skip = self.conv2(x)
out = self.skip_add.add(out, skip)
out = self.relu2(out)
out = self.maxpool(out)
out = self.conv2(out)
out = torch.nn.functional.max_pool2d(out, 2, 2)
out = self.cat.cat([out, out])
out = out.view(-1, 3 * 2 * 2)
out = self.fc(out)
return out
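For reference (not part of the diff): with the 15x15 inputs used by the tests, the 4x4 max pool maps 15x15 to 3x3, the second 2x2 max pool maps 3x3 to 1x1, and FloatFunctional.cat concatenates along dim 0, so view(-1, 3 * 2 * 2) regroups the resulting 3-channel 1x1 maps into rows of 12 features for the final Linear(12, 6). A quick shape check, assuming the class above is importable from common_quantization:

import torch
from common_quantization import ModelMultipleOpsNoAvgPool

model = ModelMultipleOpsNoAvgPool().eval()
x = torch.rand(10, 3, 15, 15)
# cat along dim 0 doubles the rows to 20x3x1x1, and view(-1, 12) regroups the
# 60 values into 5 rows, so a batch of 10 produces an output of shape (5, 6).
print(model(x).shape)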
58 changes: 57 additions & 1 deletion test/test_quantized_models.py
@@ -2,7 +2,7 @@
import torch.jit
import unittest
from common_utils import run_tests
from common_quantization import QuantizationTestCase, ModelMultipleOps
from common_quantization import QuantizationTestCase, ModelMultipleOps, ModelMultipleOpsNoAvgPool

@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
"Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
@@ -50,5 +50,61 @@ def test_float_quant_compare_per_channel(self):
# Setting target SQNR to be 35 dB
self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')

def test_fake_quant_true_quant_compare(self):
torch.manual_seed(67)
myModel = ModelMultipleOpsNoAvgPool().to(torch.float32)
calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
myModel.eval()
out_ref = myModel(eval_data)
fqModel = torch.quantization.QuantWrapper(myModel)
fqModel.train()
fqModel.qconfig = torch.quantization.default_qat_qconfig
torch.quantization.fuse_modules(fqModel.module, [['conv1', 'bn1', 'relu1']])
torch.quantization.prepare_qat(fqModel)
fqModel.eval()
fqModel.apply(torch.quantization.disable_fake_quant)
fqModel.apply(torch.nn._intrinsic.qat.freeze_bn_stats)
fqModel(calib_data)
fqModel.apply(torch.quantization.enable_fake_quant)
fqModel.apply(torch.quantization.disable_observer)
out_fq = fqModel(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
# Quantized model output should be close to floating point model output numerically
# Setting target SQNR to be 35 dB
self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
torch.quantization.convert(fqModel)
out_q = fqModel(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10))
self.assertGreater(SQNRdB, 60, msg='Fake quant and true quant numerics diverge, expect SQNR > 60 dB')

# Test to compare weight only quantized model numerics and
# activation only quantized model numerics with float
def test_weight_only_activation_only_fakequant(self):
torch.manual_seed(67)
calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
qconfigset = [torch.quantization.default_weight_only_quant_qconfig,
torch.quantization.default_activation_only_quant_qconfig]
SQNRTarget = [35, 45]
for idx, qconfig in enumerate(qconfigset):
myModel = ModelMultipleOpsNoAvgPool().to(torch.float32)
myModel.eval()
out_ref = myModel(eval_data)
fqModel = torch.quantization.QuantWrapper(myModel)
fqModel.train()
fqModel.qconfig = qconfig
torch.quantization.fuse_modules(fqModel.module, [['conv1', 'bn1', 'relu1']])
torch.quantization.prepare_qat(fqModel)
fqModel.eval()
fqModel.apply(torch.quantization.disable_fake_quant)
fqModel.apply(torch.nn._intrinsic.qat.freeze_bn_stats)
fqModel(calib_data)
fqModel.apply(torch.quantization.enable_fake_quant)
fqModel.apply(torch.quantization.disable_observer)
out_fq = fqModel(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
self.assertGreater(SQNRdB, SQNRTarget[idx], msg='Quantized model numerics diverge from float')

if __name__ == "__main__":
run_tests()
5 changes: 5 additions & 0 deletions torch/quantization/QConfig.py
@@ -64,3 +64,8 @@ def __new__(cls, weight):

default_qat_qconfig = QConfig(activation=default_fake_quant,
weight=default_weight_fake_quant)

default_weight_only_quant_qconfig = QConfig(activation=torch.nn.Identity,
weight=default_weight_fake_quant)
default_activation_only_quant_qconfig = QConfig(activation=default_fake_quant,
weight=torch.nn.Identity)
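
As a usage sketch (not part of this diff), the new qconfigs slot into the same calibrate-then-evaluate flow that test_weight_only_activation_only_fakequant exercises above; the module names, fusion list, and input shapes below are taken from that test:

import torch
import torch.quantization
from common_quantization import ModelMultipleOpsNoAvgPool

model = ModelMultipleOpsNoAvgPool().to(torch.float32)
fq_model = torch.quantization.QuantWrapper(model)
fq_model.train()
# Weight-only: activations pass through torch.nn.Identity untouched.
fq_model.qconfig = torch.quantization.default_weight_only_quant_qconfig
torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']])
torch.quantization.prepare_qat(fq_model)
fq_model.eval()
# Calibrate observers with fake quant disabled, then freeze everything and evaluate.
fq_model.apply(torch.quantization.disable_fake_quant)
fq_model.apply(torch.nn._intrinsic.qat.freeze_bn_stats)
fq_model(torch.rand(2048, 3, 15, 15))
fq_model.apply(torch.quantization.enable_fake_quant)
fq_model.apply(torch.quantization.disable_observer)
out_weight_only = fq_model(torch.rand(10, 3, 15, 15))

Swapping in default_activation_only_quant_qconfig instead fake-quantizes only the activations, with the weight observer replaced by torch.nn.Identity.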