Commit d56774b
Add support for int4 weight-only QAT
Summary: This commit adds support for int4 weight-only QAT, which simulates the numerics of the existing Int4WeightOnlyQuantizer. The main motivation is to provide an end-to-end path for running QAT and lowering to the efficient int4 tinygemm cuda kernel. To enable this, we add new fake quantization primitives that match the numerics of the tinygemm kernel, which required refactoring the existing quant primitives to skip dtype casting.

Test Plan:
python test/quantization/test_qat.py -k test_qat_4w_linear

Reviewers: jerryzh168, msaroufim

Subscribers: jerryzh168, msaroufim, HDCharles, supriyar
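At a high level, the QAT path computes the same groupwise affine scales and zero points as the tinygemm PTQ path, then round-trips the weight through quantize and dequantize while skipping the dtype casts, so training sees an int4-shaped weight that is still floating point. Below is a minimal sketch of that round trip, using the torchao helpers and the cast_dtypes=False flag exercised in the test diff; the wrapper name fake_quantize_int4_weight is illustrative only, not an API added by this commit.

import torch
import torch.nn.functional as F
from torchao.quantization.utils import (
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
    groupwise_affine_dequantize_tensor_from_qparams,
)

def fake_quantize_int4_weight(weight, n_bit=4, group_size=32):
    # Same groupwise affine qparams as the tinygemm PTQ path.
    scales, zero_points = get_groupwise_affine_qparams(weight, n_bit, group_size)
    # Quantize then immediately dequantize; cast_dtypes=False keeps everything
    # in the weight's original dtype (the "skip dtype casting" refactor from the summary).
    w_q = groupwise_affine_quantize_tensor_from_qparams(
        weight, scales, zero_points, n_bit, group_size, cast_dtypes=False,
    )
    return groupwise_affine_dequantize_tensor_from_qparams(
        w_q, scales, zero_points, n_bit, group_size, cast_dtypes=False,
    )

# During QAT, a linear layer runs its forward pass on the fake-quantized weight:
weight = torch.randn(512, 256, dtype=torch.bfloat16)
x = torch.randn(8, 256, dtype=torch.bfloat16)
out = F.linear(x, fake_quantize_int4_weight(weight))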
1 parent b69083d commit d56774b

File tree

3 files changed: +468, -41 lines

test/quantization/test_qat.py

Lines changed: 166 additions & 18 deletions
@@ -18,31 +18,39 @@
     fake_quantize_per_channel_group,
     fake_quantize_per_token,
 )
-from torchao.quantization.utils import get_group_qparams_symmetric
+from torchao.quantization.utils import (
+    get_group_qparams_symmetric,
+    get_groupwise_affine_qparams,
+    groupwise_affine_dequantize_tensor_from_qparams,
+    groupwise_affine_quantize_tensor,
+    groupwise_affine_quantize_tensor_from_qparams,
+)
 from torchao.utils import TORCH_VERSION_AFTER_2_4
 
 
 # TODO: put this in a common test utils file
+_CUDA_IS_AVAILABLE = torch.cuda.is_available()
+
 class Sub(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear = torch.nn.Linear(32, 32, bias=False).to(torch.float)
+        self.linear = torch.nn.Linear(256, 256, bias=False).to(torch.float)
 
     def example_inputs(self):
-        return (torch.randn(1, 32).to(torch.float),)
+        return (torch.randn(1, 256).to(torch.float),)
 
     def forward(self, x):
         return self.linear(x)
 
 class M(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(512, 256, bias=False).to(torch.float)
         self.sub = Sub()
-        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)
 
     def example_inputs(self):
-        return (torch.randn(1, 64).to(torch.float),)
+        return (torch.randn(1, 512).to(torch.float),)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -111,23 +119,46 @@ def test_fake_quantize_per_token(self):
 
     def _set_ptq_weight(
         self,
-        ptq_linear: "Int8DynActInt4WeightLinear",
-        fp32_weight: torch.Tensor,
-        group_size: int,
+        ptq_linear: torch.nn.Module,
+        qat_linear: torch.nn.Module,
     ):
         """
         Set the weight to the quantized version of the given fp32 weights,
         for making linear outputs comparable with QAT.
         """
+        from torchao.quantization.GPTQ import (
+            Int8DynActInt4WeightLinear,
+            WeightOnlyInt4Linear,
+        )
+        from torchao.quantization.prototype.qat import (
+            Int8DynActInt4WeightQATLinear,
+            Int4WeightOnlyQATLinear,
+        )
         n_bit = 4
         (qmin, qmax) = self._get_qmin_qmax(n_bit)
-        (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
-        q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
-            fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
-        )
-        ptq_linear.weight = q_weight
-        ptq_linear.scales = s
-        ptq_linear.zeros = zp
+        if isinstance(ptq_linear, Int8DynActInt4WeightLinear):
+            assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear)
+            fp32_weight = qat_linear.weight
+            group_size = qat_linear.groupsize
+            (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
+            q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
+                fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
+            )
+            ptq_linear.weight = q_weight
+            ptq_linear.scales = s
+            ptq_linear.zeros = zp
+        elif isinstance(ptq_linear, WeightOnlyInt4Linear):
+            assert isinstance(qat_linear, Int4WeightOnlyQATLinear)
+            (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
+                qat_linear.weight, n_bit, qat_linear.groupsize,
+            )
+            q_weight = torch.ops.aten._convert_weight_to_int4pack(
+                q_weight.to("cuda"), qat_linear.inner_k_tiles,
+            )
+            ptq_linear.weight = q_weight
+            ptq_linear.scales_and_zeros = scales_and_zeros
+        else:
+            raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear))
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
@@ -144,7 +175,7 @@ def test_qat_8da4w_linear(self):
         )
 
         # Force the weights to be the same
-        self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size)
+        self._set_ptq_weight(ptq_linear, qat_linear)
 
         # Compare linear values
         torch.manual_seed(self.SEED)
@@ -280,7 +311,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         loss_fn1 = torch.nn.CrossEntropyLoss()
         loss_fn2 = torch.nn.CrossEntropyLoss()
         example_inputs = nn_model.example_inputs()
-        target = torch.randn(1, 64).float()
+        target = torch.randn(1, 512).float()
         output1 = nn_model(*example_inputs)
         output2 = qat_model(*example_inputs)
         torch.testing.assert_close(output1, output2, atol=0, rtol=0)
@@ -322,6 +353,123 @@ def test_qat_generic_fake_quantize(self):
         torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
         torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0)
 
+    def _assert_close_4w(self, val, ref):
+        # Note: for int4 weight-only quantization, we do not expect exact match
+        # because torch._weight_int4pack_mm and torch.mm do not match exactly.
+        # Here we use the same error bar as PyTorch core to determine closeness:
+        # https://github.com/pytorch/pytorch/blob/6079c5091091d872b8dafbaa4e31a5b6194647ad/test/test_linalg.py#L6079
+        mean_err = ((val - ref) / ref).mean().abs()
+        print(mean_err)
+        self.assertTrue(mean_err < 0.05)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_primitives(self):
+        n_bit = 4
+        group_size = 32
+        inner_k_tiles = 8
+        scales_precision = torch.bfloat16
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256, dtype=dtype, device=device)
+        weight = torch.randn(512, 256, dtype=dtype, device=device)
+
+        # PTQ
+        (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
+            weight, n_bit, group_size, scales_precision,
+        )
+        q_weight = torch.ops.aten._convert_weight_to_int4pack(
+            q_weight.to(device), inner_k_tiles,
+        )
+        ptq_out = torch.ops.aten._weight_int4pack_mm(
+            x, q_weight, group_size, scales_and_zeros
+        )
+
+        # QAT
+        scales, zero_points = get_groupwise_affine_qparams(
+            weight, n_bit, group_size, scales_precision,
+        )
+        w_q = groupwise_affine_quantize_tensor_from_qparams(
+            weight, scales, zero_points, n_bit, group_size, cast_dtypes=False,
+        )
+        w_dq = groupwise_affine_dequantize_tensor_from_qparams(
+            w_q, scales, zero_points, n_bit, group_size, cast_dtypes=False,
+        )
+        qat_out = torch.nn.functional.linear(x, w_dq)
+
+        self._assert_close_4w(qat_out, ptq_out)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_linear(self):
+        from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
+        from torchao.quantization.GPTQ import WeightOnlyInt4Linear
+
+        group_size = 128
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        qat_linear = Int4WeightOnlyQATLinear(
+            256, 688, bias=False, groupsize=group_size, device=device,
+        )
+        ptq_linear = WeightOnlyInt4Linear(
+            256, 688, bias=False, groupsize=group_size, device=device,
+        )
+
+        # Force the weights to be the same
+        self._set_ptq_weight(ptq_linear, qat_linear)
+
+        # Compare linear values
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256, dtype=dtype, device=device)
+        x2 = copy.deepcopy(x)
+        qat_out = qat_linear(x)
+        ptq_out = ptq_linear(x2)
+        self._assert_close_4w(qat_out, ptq_out)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_quantizer(self):
+        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
+
+        group_size = 32
+        inner_k_tiles = 8
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        m = M().to(device).to(dtype)
+        m2 = copy.deepcopy(m)
+        qat_quantizer = Int4WeightOnlyQATQuantizer(
+            groupsize=group_size, inner_k_tiles=inner_k_tiles,
+        )
+        ptq_quantizer = Int4WeightOnlyQuantizer(
+            groupsize=group_size, inner_k_tiles=inner_k_tiles,
+        )
+        qat_model = qat_quantizer.prepare(m)
+        ptq_model = ptq_quantizer.quantize(m2)
+
+        # Compare model values
+        torch.manual_seed(self.SEED)
+        x = [i.to(device).to(dtype) for i in m.example_inputs()]
+        x2 = copy.deepcopy(x)
+        qat_out = qat_model(*x)
+        ptq_out = ptq_model(*x2)
+        self._assert_close_4w(qat_out, ptq_out)
+
+        # Convert QAT model and compare model values
+        converted_model = qat_quantizer.convert(qat_model)
+        converted_out = converted_model(*x)
+        torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
+
+        # Compare converted state dict
+        ptq_state_dict = ptq_model.state_dict()
+        converted_state_dict = converted_model.state_dict()
+        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+        for k in ptq_state_dict.keys():
+            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
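For context, the end-to-end flow the summary motivates (prepare a model for QAT, fine-tune, then convert to real int4 weights for the tinygemm kernel) would look roughly like the sketch below. This is illustrative only: Int4WeightOnlyQATQuantizer with its prepare and convert methods comes from this commit (see test_qat_4w_quantizer above), while the model, training loop, and hyperparameters are placeholders.

import torch
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer

# Toy single-layer model; any nn.Linear with tinygemm-compatible shapes would do.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 256, bias=False),
).to("cuda").to(torch.bfloat16)
quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)

# 1. Swap nn.Linear modules for QAT linears that fake-quantize their weights
#    with int4 tinygemm numerics on every forward pass.
model = quantizer.prepare(model)

# 2. Fine-tune as usual; gradients flow through the fake-quantized weights.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(10):
    x = torch.randn(8, 512, dtype=torch.bfloat16, device="cuda")
    loss = model(x).sum()  # placeholder loss for illustration
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 3. Convert to real int4 weights packed for torch.ops.aten._weight_int4pack_mm.
model = quantizer.convert(model)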
