
Commit 7d93bf7

Add convert path for 8da4w QAT
Summary: This commit implements the convert path for 8da4w QAT, which swaps the QAT linear with the quantized linear and quantizes the weights the same way as the PTQ flow. The result is a model that is identical to the one output by the PTQ flow.

Test Plan: python test/quantization/test_qat.py -k test_qat_8da4w_quantizer

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar
1 parent ec6affe commit 7d93bf7
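For context, a minimal usage sketch of the two-step flow this commit completes. The toy model, group size, and the quantizer's constructor/prepare signature are illustrative assumptions, not part of this commit:

    import torch
    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

    # Toy model; the QAT linear requires bias=False and in_features divisible by groupsize.
    model = torch.nn.Sequential(torch.nn.Linear(256, 128, bias=False))
    quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)

    # Step 1 (prepare): swap nn.Linear -> Int8DynActInt4WeightQATLinear (fake quantized),
    # then fine-tune the model as usual.
    qat_model = quantizer.prepare(model)

    # Step 2 (convert, added in this commit): swap Int8DynActInt4WeightQATLinear ->
    # Int8DynActInt4WeightLinear and quantize the trained weights to grouped int4.
    quantized_model = quantizer.convert(qat_model)
    out = quantized_model(torch.randn(4, 256))

The test below asserts that the converted model's outputs and state dict match the PTQ flow exactly (atol=0, rtol=0).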

File tree

2 files changed: +49, -4 lines changed

test/quantization/test_qat.py

Lines changed: 12 additions & 0 deletions

@@ -188,6 +188,18 @@ def test_qat_8da4w_quantizer(self):
         ptq_out = ptq_model(*x2)
         torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
 
+        # Convert QAT model and compare model values
+        converted_model = qat_quantizer.convert(qat_model)
+        converted_out = converted_model(*x)
+        torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)
+
+        # Compare converted state dict
+        ptq_state_dict = ptq_model.state_dict()
+        converted_state_dict = converted_model.state_dict()
+        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+        for k in ptq_state_dict.keys():
+            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
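After conversion, each replaced linear carries the int8 weight together with the per-group scales and zeros assigned in _convert_qat_linear_8da4w (see the next file). A hypothetical inspection snippet, continuing the usage sketch above:

    for name, mod in quantized_model.named_modules():
        if hasattr(mod, "scales") and hasattr(mod, "zeros"):
            # Expect an int8 weight plus one (scale, zero point) pair per weight group
            print(name, mod.weight.dtype, tuple(mod.scales.shape), tuple(mod.zeros.shape))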

torchao/quantization/prototype/qat.py

Lines changed: 37 additions & 4 deletions

@@ -16,7 +16,10 @@
 
 
 if TORCH_VERSION_AFTER_2_3:
-    from torchao.quantization.GPTQ import _replace_linear_8da4w
+    from torchao.quantization.GPTQ import (
+        _replace_linear_8da4w,
+        Int8DynActInt4WeightLinear,
+    )
 
 class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
     """
@@ -60,10 +63,38 @@ def convert(
         *args: Any,
         **kwargs: Any
     ) -> torch.nn.Module:
-        # TODO: replace Int8DynActInt4WeightQATLinear -> Int8DynActInt4WeightLinear
-        pass
-
+        _convert_qat_linear_8da4w(model)
+        return model
 
+def _convert_qat_linear_8da4w(module: torch.nn.Module):
+    """
+    Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`.
+    """
+    for name, child in module.named_children():
+        if isinstance(child, Int8DynActInt4WeightQATLinear):
+            quantized_linear = Int8DynActInt4WeightLinear(
+                child.in_features,
+                child.out_features,
+                bias=False,
+                groupsize=child.groupsize,
+                precision=child.precision,
+                scales_precision=child.scales_precision,
+            )
+            setattr(module, name, quantized_linear)
+
+            # Load weights and qparams into quantized linear
+            n_bit = 4
+            (qmin, qmax) = child._get_qmin_qmax(n_bit)
+            (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
+            q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
+                child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize,
+            )
+            quantized_linear.weight = q_weight
+            quantized_linear.scales = s
+            quantized_linear.zeros = zp
+        else:
+            _convert_qat_linear_8da4w(child)
+
 class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
     """
     This module implements a linear layer with int8 dynamic per token fake
@@ -96,6 +127,7 @@ def __init__(
         ), f"require in_features:{in_features} % groupsize:{groupsize} == 0"
         assert not bias, "require bias=False"
         self.groupsize = groupsize
+        self.precision = precision
        self.scales_precision = scales_precision
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -123,6 +155,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
         return torch.nn.functional.linear(x_fq, w_fq)
 
+    # TODO: move this to common util
     def _get_qmin_qmax(self, n_bit: int):
         qmin = -(2 ** (n_bit - 1))
         qmax = 2 ** (n_bit - 1) - 1
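The convert path quantizes each trained weight with symmetric per-group int4 quantization via get_group_qparams_symmetric and torch.ops.quantized_decomposed.quantize_per_channel_group. A rough standalone sketch of that scheme follows; it illustrates the math only and is not the library kernel (the exact scale formula in torchao may differ slightly):

    import torch

    def quantize_per_group_symmetric(w: torch.Tensor, n_bit: int = 4, groupsize: int = 32):
        # qmin/qmax mirror _get_qmin_qmax above: -8 and 7 for int4
        qmin, qmax = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1
        out_features, in_features = w.shape
        w_grouped = w.reshape(out_features, in_features // groupsize, groupsize)
        # Symmetric quantization: one scale per (output channel, group), zero point fixed at 0
        scales = (w_grouped.abs().amax(dim=-1, keepdim=True) / qmax).clamp(min=1e-6)
        zeros = torch.zeros_like(scales)
        q = torch.clamp(torch.round(w_grouped / scales), qmin, qmax).to(torch.int8)
        return q.reshape(out_features, in_features), scales.squeeze(-1), zeros.squeeze(-1)

    q_weight, s, zp = quantize_per_group_symmetric(torch.randn(128, 256))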
