diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 64518c0599..031e5ef14d 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -188,6 +188,18 @@ def test_qat_8da4w_quantizer(self):
         ptq_out = ptq_model(*x2)
         torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
 
+        # Convert QAT model and compare model values
+        converted_model = qat_quantizer.convert(qat_model)
+        converted_out = converted_model(*x)
+        torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)
+
+        # Compare converted state dict
+        ptq_state_dict = ptq_model.state_dict()
+        converted_state_dict = converted_model.state_dict()
+        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+        for k in ptq_state_dict.keys():
+            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat.py
index 621f4bf80f..7ba64f3aca 100644
--- a/torchao/quantization/prototype/qat.py
+++ b/torchao/quantization/prototype/qat.py
@@ -16,7 +16,10 @@
 
 
 if TORCH_VERSION_AFTER_2_3:
-    from torchao.quantization.GPTQ import _replace_linear_8da4w
+    from torchao.quantization.GPTQ import (
+        _replace_linear_8da4w,
+        Int8DynActInt4WeightLinear,
+    )
 
     class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
         """
@@ -60,10 +63,38 @@ def convert(
             *args: Any,
             **kwargs: Any
         ) -> torch.nn.Module:
-            # TODO: replace Int8DynActInt4WeightQATLinear -> Int8DynActInt4WeightLinear
-            pass
-
+            _convert_qat_linear_8da4w(model)
+            return model
+
+    def _convert_qat_linear_8da4w(module: torch.nn.Module):
+        """
+        Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`.
+        """
+        for name, child in module.named_children():
+            if isinstance(child, Int8DynActInt4WeightQATLinear):
+                quantized_linear = Int8DynActInt4WeightLinear(
+                    child.in_features,
+                    child.out_features,
+                    bias=False,
+                    groupsize=child.groupsize,
+                    precision=child.precision,
+                    scales_precision=child.scales_precision,
+                )
+                setattr(module, name, quantized_linear)
+
+                # Load weights and qparams into quantized linear
+                n_bit = 4
+                (qmin, qmax) = child._get_qmin_qmax(n_bit)
+                (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
+                q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
+                    child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize,
+                )
+                quantized_linear.weight = q_weight
+                quantized_linear.scales = s
+                quantized_linear.zeros = zp
+            else:
+                _convert_qat_linear_8da4w(child)
 
     class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
         """
         This module implements a linear layer with int8 dynamic per token fake
@@ -96,6 +127,7 @@ def __init__(
             ), f"require in_features:{in_features} % groupsize:{groupsize} == 0"
             assert not bias, "require bias=False"
             self.groupsize = groupsize
+            self.precision = precision
             self.scales_precision = scales_precision
 
         def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -123,6 +155,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             return torch.nn.functional.linear(x_fq, w_fq)
 
+        # TODO: move this to common util
         def _get_qmin_qmax(self, n_bit: int):
            qmin = -(2 ** (n_bit - 1))
            qmax = 2 ** (n_bit - 1) - 1
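
For reviewers who want to exercise the completed two-step flow end to end, here is a minimal usage sketch. The toy module `M`, the tensor shapes, the groupsize choice, and the training-loop placeholder are illustrative assumptions, not part of this diff; `prepare` predates this change, and `convert` is what this diff fills in.

import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Toy module for illustration (an assumption, not from this diff); any model
# whose linears satisfy in_features % groupsize == 0 and bias=False works.
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(256, 128, bias=False)

    def forward(self, x):
        return self.linear(x)

model = M()
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)

# Step 1: swap nn.Linear -> Int8DynActInt4WeightQATLinear, which fake-quantizes
# activations (int8, dynamic per token) and weights (int4, per channel group).
model = quantizer.prepare(model)
# ... fine-tune `model` here with fake quantization in the loop ...

# Step 2 (implemented by this diff): swap each QAT linear for a real
# Int8DynActInt4WeightLinear carrying quantized weights (int4 values stored
# in an int8 tensor) plus per-group scales and zero points.
model = quantizer.convert(model)
out = model(torch.randn(1, 64, 256))

Note that the test asserts exact equality (atol=0, rtol=0) between the PTQ and converted models; this should hold because the convert path runs the same quantization primitives (get_group_qparams_symmetric, quantize_per_channel_group) as the existing PTQ quantizer.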
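
The conversion's numerics are compact enough to sketch by hand: for n_bit = 4 the integer range is [qmin, qmax] = [-8, 7] (as computed by `_get_qmin_qmax`), each group of `groupsize` weights gets a symmetric scale with the zero point pinned to 0, and values are rounded and clamped into that range. The helper below is an illustrative stand-in for what `get_group_qparams_symmetric` plus `quantize_per_channel_group` compute together, not torchao's actual implementation.

import torch

def quantize_per_group_symmetric(w: torch.Tensor, n_bit: int, groupsize: int):
    """Illustrative stand-in, not torchao's kernel: symmetric per-group
    quantization of a 2D weight [out_features, in_features]."""
    qmin = -(2 ** (n_bit - 1))   # -8 when n_bit == 4
    qmax = 2 ** (n_bit - 1) - 1  #  7 when n_bit == 4
    grouped = w.reshape(-1, groupsize)
    max_abs = grouped.abs().amax(dim=1, keepdim=True)
    # Symmetric scale: the largest magnitude in each group maps onto the
    # integer range; the zero point stays 0 by construction.
    scales = (max_abs / ((qmax - qmin) / 2)).clamp_min(torch.finfo(w.dtype).eps)
    zeros = torch.zeros_like(scales)
    q = torch.clamp(torch.round(grouped / scales), qmin, qmax).to(torch.int8)
    return q.reshape(w.shape), scales, zeros

w = torch.randn(128, 256)
q, s, _ = quantize_per_group_symmetric(w, n_bit=4, groupsize=32)
# Dequantization recovers w up to at most half a scale step per group:
w_dq = (q.reshape(-1, 32).to(w.dtype) * s).reshape(w.shape)
assert (w - w_dq).abs().max() <= s.max() / 2 + 1e-6

Pinning the zero point at 0 is what lets `_convert_qat_linear_8da4w` copy `zeros` straight into the quantized linear without affecting the matmul, and it matches the symmetric fake quantization the QAT linear applies during training.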