From 36fad622f54c64f2d1b1f9d8b1084dc5865e4fc2 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Mon, 22 Apr 2024 10:52:45 -0700
Subject: [PATCH] Add convert path for 8da4w QAT

Summary: This commit implements the convert path for 8da4w QAT, which
swaps the QAT linear with the quantized linear and quantizes the
weights the same way the PTQ flow does. The result is a model that is
identical to the one output by the PTQ flow.

Test Plan:
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar
---
 test/quantization/test_qat.py         | 12 ++++++++
 torchao/quantization/prototype/qat.py | 41 ++++++++++++++++++++++++---
 2 files changed, 49 insertions(+), 4 deletions(-)

diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 64518c0599..031e5ef14d 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -188,6 +188,18 @@ def test_qat_8da4w_quantizer(self):
         ptq_out = ptq_model(*x2)
         torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
 
+        # Convert QAT model and compare model values
+        converted_model = qat_quantizer.convert(qat_model)
+        converted_out = converted_model(*x)
+        torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)
+
+        # Compare converted state dict
+        ptq_state_dict = ptq_model.state_dict()
+        converted_state_dict = converted_model.state_dict()
+        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+        for k in ptq_state_dict.keys():
+            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat.py
index 621f4bf80f..7ba64f3aca 100644
--- a/torchao/quantization/prototype/qat.py
+++ b/torchao/quantization/prototype/qat.py
@@ -16,7 +16,10 @@
 
 
 if TORCH_VERSION_AFTER_2_3:
-    from torchao.quantization.GPTQ import _replace_linear_8da4w
+    from torchao.quantization.GPTQ import (
+        _replace_linear_8da4w,
+        Int8DynActInt4WeightLinear,
+    )
 
 class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
     """
@@ -60,10 +63,38 @@ def convert(
         *args: Any,
         **kwargs: Any
     ) -> torch.nn.Module:
-        # TODO: replace Int8DynActInt4WeightQATLinear -> Int8DynActInt4WeightLinear
-        pass
-
+        _convert_qat_linear_8da4w(model)
+        return model
 
+def _convert_qat_linear_8da4w(module: torch.nn.Module):
+    """
+    Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`.
+ """ + for name, child in module.named_children(): + if isinstance(child, Int8DynActInt4WeightQATLinear): + quantized_linear = Int8DynActInt4WeightLinear( + child.in_features, + child.out_features, + bias=False, + groupsize=child.groupsize, + precision=child.precision, + scales_precision=child.scales_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (qmin, qmax) = child._get_qmin_qmax(n_bit) + (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize) + q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( + child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize, + ) + quantized_linear.weight = q_weight + quantized_linear.scales = s + quantized_linear.zeros = zp + else: + _convert_qat_linear_8da4w(child) + class Int8DynActInt4WeightQATLinear(torch.nn.Linear): """ This module implements a linear layer with int8 dynamic per token fake @@ -96,6 +127,7 @@ def __init__( ), f"require in_features:{in_features} % groupsize:{groupsize} == 0" assert not bias, "require bias=False" self.groupsize = groupsize + self.precision = precision self.scales_precision = scales_precision def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -123,6 +155,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) return torch.nn.functional.linear(x_fq, w_fq) + # TODO: move this to common util def _get_qmin_qmax(self, n_bit: int): qmin = -(2 ** (n_bit - 1)) qmax = 2 ** (n_bit - 1) - 1