
Commit 2a8dc5d

Enable dispatch to tinygemm int4 and int8 kernels for unified quantized tensor
Summary:
This adds some dispatch to the tinygemm kernels for CUDA, although we still need to resolve the implementation mismatch problem for tinygemm first.

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent e7bbbd2 commit 2a8dc5d
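The dispatch this commit enables relies on the tensor-subclass override mechanism: a quantized weight subclass intercepts F.linear and can route it to a specialized kernel. Below is a minimal illustrative sketch of that pattern only, not torchao's AffineQuantizedTensor; the class name MyQuantWeight and its dequantize-and-matmul fallback are hypothetical, and a real implementation would call the tinygemm int4/int8 kernels in that branch.

# Illustrative sketch only -- NOT this commit's implementation. It shows the
# __torch_function__ dispatch pattern that lets a quantized weight subclass
# intercept F.linear; MyQuantWeight and its int8-style packing are hypothetical.
import torch
import torch.nn.functional as F

class MyQuantWeight(torch.Tensor):
    def __new__(cls, float_weight):
        return torch.Tensor._make_subclass(cls, float_weight, False)

    def __init__(self, float_weight):
        # per-row symmetric int8 quantization, just to have something to dispatch on
        self.scale = float_weight.abs().amax(dim=1, keepdim=True) / 127.0
        self.int_data = torch.round(float_weight / self.scale).to(torch.int8)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            # placeholder math: dequantize and matmul; the point of the commit is
            # that this branch can instead dispatch to the tinygemm int4/int8 kernels
            return x @ (w.int_data.to(x.dtype) * w.scale).t()
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

x = torch.randn(2, 64)
w = MyQuantWeight(torch.randn(32, 64))
print(F.linear(x, w).shape)  # torch.Size([2, 32]) -- handled by the override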

File tree

5 files changed: +243 -43 lines changed

test/quantization/test_quant_api.py

Lines changed: 96 additions & 9 deletions

@@ -9,7 +9,6 @@
 import unittest
 import torch
 import os
-from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import (
     prepare_pt2e,
     convert_pt2e,
@@ -36,7 +35,7 @@


 def dynamic_quant(model, example_inputs):
-    m = capture_pre_autograd_graph(model, example_inputs)
+    m = torch.export.export(model, example_inputs).module()
     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
     m = prepare_pt2e(m, quantizer)
     m = convert_pt2e(m)
@@ -50,14 +49,14 @@ def _apply_dynamic_quant(model):
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
-        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features))),
+        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
         lambda mod, fqn: isinstance(mod, torch.nn.Linear),
     )
     return model


 def capture_and_prepare(model, example_inputs):
-    m = capture_pre_autograd_graph(model, example_inputs)
+    m = torch.export.export(model, example_inputs)
     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
     m = prepare_pt2e(m, quantizer)
     # TODO: we can run the weight observer in convert_pt2e so that user don't need to run this
@@ -88,13 +87,13 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
         return model

 class ToyLinearModel(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, m=64, n=32, k=64):
         super().__init__()
-        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
-        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

     def example_inputs(self):
-        return (torch.randn(1, 64).to(torch.float),)
+        return (torch.randn(1, self.linear1.in_features).to(torch.float),)

     def forward(self, x):
         x = self.linear1(x)
@@ -104,8 +103,9 @@ def forward(self, x):
 class TestQuantFlow(unittest.TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
+        example_inputs = m.example_inputs()
         m = _apply_dynamic_quant(m)
-        quantized = m(*m.example_inputs())
+        quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
         # m = torch.compile(m, mode="max-autotune")
@@ -442,7 +442,94 @@ def get_per_token_block_size(x):
         ref = m_copy(*example_inputs)
         self.assertTrue(torch.equal(res, ref))

+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_tensor_subclass_int4(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        from torchao.quantization.quant_primitives import ZeroPointDomain
+        import copy
+
+        # weight settings
+        groupsize = 32
+        mapping_type = MappingType.ASYMMETRIC
+        block_size = (1, groupsize)
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        eps = 1e-6
+        preserve_zero = False
+        zero_point_dtype = torch.bfloat16
+
+        # weight only quantization
+        input_quant_func = None
+
+        # use 1024 so that we don't need padding
+        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+
+        def to_quantized(weight):
+            return AffineQuantizedTensor.from_float(
+                weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
+                zero_point_dtype=zero_point_dtype,
+                preserve_zero=preserve_zero,
+                zero_point_domain=ZeroPointDomain.FLOAT,
+                input_quant_func=input_quant_func,
+            )
+
+        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
+        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
+        change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+
+        self.assertTrue(torch.equal(res, ref))
+
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_tensor_subclass_int8(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        import copy
+
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # weight only quantization
+        input_quant_func = None
+
+        m = ToyLinearModel().eval().to(torch.bfloat16)
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
+
+        def to_quantized(weight):
+            block_size = (1, weight.shape[1])
+            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)
+
+        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
+        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
+        change_linear_weights_to_int8_woqtensors(m_copy)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)

+        torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)


 if __name__ == "__main__":

test/quantization/test_quant_primitives.py

Lines changed: 4 additions & 7 deletions

@@ -327,6 +327,8 @@ def test_not_preserve_zero_not_supported(self):


     def test_tinygemm_get_groupwise_affine_qparams(self):
+        from torchao.quantization.quant_primitives import ZeroPointDomain
+
         input = torch.randn(10, 256)
         n_bit = 4
         scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
@@ -351,16 +353,11 @@ def test_tinygemm_get_groupwise_affine_qparams(self):
            scale_dtype=scale_dtype,
            zero_point_dtype=zero_point_dtype,
            preserve_zero=False,
+           zero_point_domain=ZeroPointDomain.FLOAT,
        )

-        def int_zero_point_to_float(zero_point, scale, qaunt_min, mid_point):
-            return (quant_min - zero_point + mid_point) * scale
-
-        mid_point = 2 ** (n_bit - 1)
-        zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)
-
         self.assertTrue(torch.equal(scale, scale_ref))
-        torch.testing.assert_close(zero_point_float, zero_point_ref, rtol=0.00001, atol=torch.max(scale)*0.03)
+        self.assertTrue(torch.equal(zero_point, zero_point_ref))


 if __name__ == "__main__":

torchao/quantization/autoquant.py

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@
     quantize_activation_per_token_absmax,
     safe_int_mm,
 )
+from .utils import TORCH_VERSION_AFTER_2_4
 import torch.nn.functional as F
 try:
     from torch._inductor.utils import do_bench

torchao/quantization/quant_primitives.py

Lines changed: 66 additions & 19 deletions

@@ -72,6 +72,14 @@ def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
     torch.uint7: (0, 2**7-1),
 })

+class MappingType(Enum):
+    SYMMETRIC = 0
+    ASYMMETRIC = 1
+
+class ZeroPointDomain(Enum):
+    INT = 0
+    FLOAT = 1
+
 # TODO: decide on if we want to allow custom quant_min/quant_max here
 def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
     """Get quant_min and quant_max args based on dtype and also
@@ -141,7 +149,8 @@ def quantize_affine(
     zero_point: Optional[torch.Tensor],
     output_dtype: torch.dtype,
     quant_min: Optional[int] = None,
-    quant_max: Optional[int] = None
+    quant_max: Optional[int] = None,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
 ):
     """
     Args:
@@ -153,6 +162,12 @@
      output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
      quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
      quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
+      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+        if zero_point is in integer domain, zero point is added to the quantized integer value during
+        quantization
+        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
+        value during quantization
+        default is ZeroPointDomain.INT

     Note:
       How can block_size represent different granularities?
@@ -184,9 +199,19 @@
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)

-    quant = torch.clamp(
-        torch.round(input / scale) + zero_point, quant_min, quant_max
-    ).to(output_dtype)
+    if zero_point_domain == ZeroPointDomain.INT:
+        quant = torch.clamp(
+            torch.round(input / scale) + zero_point, quant_min, quant_max
+        ).to(output_dtype)
+    else:
+        assert zero_point_domain == ZeroPointDomain.FLOAT
+        mid_point = (quant_max + quant_min + 1) / 2
+        min_val = zero_point - scale * mid_point
+        quant = (
+            torch.clamp(
+                torch.round((input - min_val) / scale),
+                quant_min, quant_max)
+        ).to(output_dtype)
     quant = quant.view(original_shape)

     return quant
@@ -199,6 +224,7 @@ def dequantize_affine(
     input_dtype: torch.dtype,
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     *,
     output_dtype: torch.dtype = torch.float32,
 ):
@@ -213,6 +239,12 @@
      quant_min (Optional[int]): minimum quantized value for input Tensor
      quant_max (Optional[int]): maximum quantized value for input Tensor
      output_dtype (torch.dtype): dtype for output Tensor, default is fp32
+      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+        if zero_point is in integer domain, zero point is added to the quantized integer value during
+        quantization
+        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
+        value during quantization
+        default is ZeroPointDomain.INT

     Output:
       dequantized Tensor, with requested dtype or fp32
@@ -233,18 +265,22 @@
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)

-    dequant = input.to(torch.int32)
-    if zero_point is not None:
-        dequant -= zero_point.to(torch.int32)
-    dequant = dequant.to(output_dtype)
-    dequant *= scale
-    dequant = dequant.view(original_shape)
-    return dequant.to(output_dtype)
+    if zero_point_domain == ZeroPointDomain.INT:
+        dequant = input.to(torch.int32)
+        if zero_point is not None:
+            dequant -= zero_point.to(torch.int32)
+        dequant = dequant.to(output_dtype)
+        dequant *= scale
+    else:
+        assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
+        mid_point = (quant_max + quant_min + 1) / 2
+        dequant = input - mid_point
+        dequant = dequant.to(output_dtype)
+        dequant *= scale
+        if zero_point is not None:
+            dequant += zero_point

-
-class MappingType(Enum):
-    SYMMETRIC = 0
-    ASYMMETRIC = 1
+    return dequant.view(original_shape).to(output_dtype)

 def choose_qparams_affine(
     input: torch.Tensor,
@@ -256,7 +292,8 @@ def choose_qparams_affine(
     eps: Optional[float] = None,
     scale_dtype: Optional[torch.dtype] = None,
     zero_point_dtype: Optional[torch.dtype] = None,
-    preserve_zero = True,
+    preserve_zero: bool = True,
+    zero_point_domain = ZeroPointDomain.INT,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
@@ -280,6 +317,13 @@

      If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point

+      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+        if zero_point is in integer domain, zero point is added to the quantized integer value during
+        quantization
+        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
+        value during quantization
+        default is ZeroPointDomain.INT
+
     Output:
       Tuple of scales and zero_points Tensor with requested dtype
     """
@@ -310,15 +354,18 @@
         scale = max_val_pos / (float(quant_max - quant_min) / 2)
         if not preserve_zero:
             raise ValueError("preserve_zero == False is not supported for symmetric quantization")
-        zero_point = torch.full_like(scale, int((quant_min + quant_max + 1) / 2))
+        if zero_point_domain != ZeroPointDomain.INT:
+            raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
+        zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
     else:
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         if preserve_zero:
             zero_point = quant_min - torch.round(min_val_neg / scale)
             zero_point = torch.clamp(zero_point, quant_min, quant_max)
         else:
-            zero_point = quant_min - min_val_neg / scale
-
+            assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
+            mid_point = (quant_max + quant_min + 1) / 2
+            zero_point = min_val_neg + scale * mid_point

     if eps is None:
         eps = torch.finfo(input.dtype).eps
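To make the new ZeroPointDomain.FLOAT branches concrete, here is a standalone numeric sketch that re-implements the formulas added above in plain torch for a single quantization group with a 4-bit unsigned range; it does not call torchao itself and only checks that the quantize/dequantize round trip stays within half a quantization step.

# Standalone sketch of the ZeroPointDomain.FLOAT math (tinygemm-style asymmetric
# quantization with a floating-point zero point), following the formulas in the diff.
import torch

x = torch.randn(32)                          # one quantization group
quant_min, quant_max = 0, 15                 # 4-bit unsigned range
mid_point = (quant_max + quant_min + 1) / 2  # 8, as in the added code

# choose_qparams_affine: asymmetric, preserve_zero=False, ZeroPointDomain.FLOAT
min_val_neg = torch.clamp(x.min(), max=0.0)
max_val_pos = torch.clamp(x.max(), min=0.0)
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
zero_point = min_val_neg + scale * mid_point     # zero point lives in the float domain

# quantize_affine, FLOAT branch: recover min_val from the zero point, then round
min_val = zero_point - scale * mid_point
q = torch.clamp(torch.round((x - min_val) / scale), quant_min, quant_max)

# dequantize_affine, FLOAT branch: center on mid_point, scale, then shift by zero_point
x_hat = (q - mid_point) * scale + zero_point

# round-trip error is bounded by half a quantization step
assert (x - x_hat).abs().max() <= scale / 2 + 1e-6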
