
Commit d54cecb

Enable dispatch to tinygemm int4 and int8 kernels for unified quantized tensor
Summary: This adds dispatch to the tinygemm kernels for CUDA, although we still need to resolve an implementation mismatch problem for tinygemm first.
Test Plan: TODO
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent b91b6be commit d54cecb
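
Note: with this change, weight-only int4 quantization can route F.linear through the tinygemm kernel via the tensor subclass. Below is a minimal sketch of that flow, mirroring the test added in this commit (the argument values are taken from the test, not a recommendation), assuming a CUDA device and a bfloat16 linear layer whose in_features is a multiple of the groupsize (padding is not supported yet, see the TODO in subclass.py):

    import torch
    from torchao.quantization.subclass import TinygemmAffineQuantizedTensor
    from torchao.quantization.quant_primitives import MappingType

    linear = torch.nn.Linear(1024, 1024, bias=False).to(torch.bfloat16).to("cuda")

    groupsize = 32
    quantized_weight = TinygemmAffineQuantizedTensor.from_float(
        linear.weight,
        MappingType.ASYMMETRIC,
        block_size=(1, groupsize),
        eps=1e-6,
        input_quant_func=None,  # None selects the weight-only path in __torch_function__
    )
    linear.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)

    x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda")
    y = linear(x)  # dispatches to torch.ops.aten._weight_int4pack_mm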

3 files changed: 174 additions, 8 deletions


test/quantization/test_quant_api.py

Lines changed: 78 additions & 4 deletions
@@ -88,13 +88,13 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
         return model
 
 class ToyLinearModel(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, m=64, n=32, k=64):
         super().__init__()
-        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
-        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
 
     def example_inputs(self):
-        return (torch.randn(1, 64).to(torch.float),)
+        return (torch.randn(1, self.linear1.in_features).to(torch.float),)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -442,7 +442,81 @@ def get_per_token_block_size(x):
         ref = m_copy(*example_inputs)
         self.assertTrue(torch.equal(res, ref))
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    def test_quantized_tensor_subclass_int4(self):
+        from torchao.quantization.subclass import TinygemmAffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        import copy
+
+        # weight settings
+        groupsize = 32
+        mapping_type = MappingType.ASYMMETRIC
+        block_size = (1, groupsize)
+        eps = 1e-6
+        preserve_zero = False
+
+        # weight only quantization
+        input_quant_func = None
+
+        # use 1024 so that we don't need padding
+        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+
+        def to_quantized(weight):
+            return TinygemmAffineQuantizedTensor.from_float(weight, mapping_type, block_size, eps, input_quant_func=input_quant_func)
+
+        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
+        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        assert isinstance(m.linear1.weight, TinygemmAffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, TinygemmAffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
+        change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+
+        torch.testing.assert_close(res, ref, rtol=0.00001, atol=0.02)
+
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    def test_quantized_tensor_subclass_int8(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        import copy
+
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # weight only quantization
+        input_quant_func = None
+
+        m = ToyLinearModel().eval().to(torch.bfloat16)
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
+
+        def to_quantized(weight):
+            block_size = (1, weight.shape[1])
+            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)
+
+        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
+        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
+        change_linear_weights_to_int8_woqtensors(m_copy)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
 
+        torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
 
 
 if __name__ == "__main__":
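
Note: the two new tests differ mainly in their block_size. (1, groupsize) requests one scale/zero_point pair per group of groupsize consecutive values in each weight row (the tinygemm int4 layout), while (1, weight.shape[1]) collapses to one pair per output channel (the int8 path). The following is a small hand-written illustration of that block convention, independent of the torchao helpers (in the tests the actual parameters come from choose_qparams_affine inside from_float); per_block_qparams and its shapes are illustrative only:

    import torch

    def per_block_qparams(w, block_size, quant_min=0, quant_max=15, eps=1e-6):
        # one (scale, zero_point) pair per (1, block_size[1]) block of the weight,
        # i.e. the convention that block_size=(1, g) expresses
        rows, cols = w.shape
        g = block_size[1]
        blocks = w.reshape(rows, cols // g, g)
        w_min, w_max = blocks.amin(dim=-1), blocks.amax(dim=-1)
        scale = ((w_max - w_min) / (quant_max - quant_min)).clamp(min=eps)
        zero_point = quant_min - torch.round(w_min / scale)
        return scale, zero_point

    w = torch.randn(4, 8)
    scale_group, _ = per_block_qparams(w, block_size=(1, 4))  # per-group: shape (4, 2)
    scale_chan, _ = per_block_qparams(w, block_size=(1, 8))   # per-channel: shape (4, 1)
    print(scale_group.shape, scale_chan.shape)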

test/quantization/test_quant_primitives.py

Lines changed: 1 addition & 1 deletion
@@ -353,7 +353,7 @@ def test_tinygemm_get_groupwise_affine_qparams(self):
             preserve_zero=False,
         )
 
-        def int_zero_point_to_float(zero_point, scale, qaunt_min, mid_point):
+        def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
             return (quant_min - zero_point + mid_point) * scale
 
         mid_point = 2 ** (n_bit - 1)

torchao/quantization/subclass.py

Lines changed: 95 additions & 3 deletions
@@ -14,7 +14,9 @@
     dynamically_quantize_per_channel,
     groupwise_affine_quantize_tensor,
     quant_int8_dynamic_per_token_linear,
+    pack_tinygemm_scales_and_zeros,
     unpack_tinygemm_scales_and_zeros,
+    groupwise_affine_quantize_tensor_from_qparams,
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,
@@ -619,7 +621,7 @@ class AffineQuantizedTensor(torch.Tensor):
       shape (torch.Size): the shape for the Tensor
       quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
       quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
-      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes input Tensor as input and outputs an AffineQuantizedTensor object
+      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object
       dtype: dtype for external representation of the tensor, e.g. torch.float32
     """
 
@@ -635,6 +637,7 @@ def __new__(
         quant_max: Optional[int] = None,
         input_quant_func: Optional[Callable] = None,
         dtype=None,
+        # TODO: remove args and kwargs
         *args,
         **kwargs
     ):
@@ -677,7 +680,9 @@ def __repr__(self):
             f"device={self.device}, dtype={self.dtype}, input_quant_func={self.input_quant_func}, requires_grad={self.requires_grad})"
         )
 
-    def dequantize(self, output_dtype=torch.float32):
+    def dequantize(self, output_dtype=None):
+        if output_dtype is None:
+            output_dtype = self.dtype
         return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, output_dtype=output_dtype)
 
     def __tensor_flatten__(self):
@@ -740,7 +745,54 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                 args[1],
                 args[2] if len(args) > 2 else None,
             )
-            if weight_qtensor.input_quant_func is not None:
+            if weight_qtensor.input_quant_func is None:
+                is_cuda = args[0].is_cuda
+                is_cpu = args[0].device == torch.device("cpu")
+                # weight only quantization
+                is_int8 = (
+                    weight_qtensor.int_data.dtype == torch.int8 and
+                    weight_qtensor.quant_min is None or weight_qtensor.quant_min == -128 and
+                    weight_qtensor.quant_max is None or weight_qtensor.quant_max == 127
+                )
+                is_uint4 = (
+                    weight_qtensor.int_data.dtype == torch.int32 and
+                    weight_qtensor.quant_min == 0 and
+                    weight_qtensor.quant_max == 15
+                )
+
+                # TODO: enable cpu and mps path as well
+                # TODO: make sure weight dimension matches the expectation of the int4mm kernel
+                # TODO: move this to TinygemmAffineQuantizedTensor
+                if (
+                    is_cuda and
+                    is_uint4 and
+                    weight_qtensor.dtype == torch.bfloat16 and
+                    len(weight_qtensor.shape) == 2 and
+                    weight_qtensor.block_size[0] == 1
+                ):
+                    # groupwise int4 quantization
+                    # TODO: currently doing packing on the fly, we'll need to figure out
+                    # the API to do packing before hand
+                    # TODO: expose the arg
+                    innerKTiles = 8
+                    packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles)
+                    scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point)
+                    groupsize = weight_qtensor.block_size[-1]
+                    return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
+                elif (
+                    is_cpu and
+                    is_int8 and
+                    len(weight_qtensor.shape) == 2 and
+                    len(weight_qtensor.block_size) == 2 and
+                    weight_qtensor.block_size[0] == 1 and
+                    weight_qtensor.block_size[1] == weight_qtensor.shape[1]
+                ):
+                    # TODO: enable mps path as well
+                    # per channel int8 weight only quantized mm
+                    return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
+            else:
+                # dynamic quantization
+                # TODO: enable int8 dynamic quant dispatch
                 input_tensor = weight_qtensor.input_quant_func(input_tensor)
                 input_tensor = input_tensor.dequantize()
             weight_tensor = weight_qtensor.dequantize()
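
Note: the CUDA branch above reduces to an on-the-fly repack plus one fused kernel call. The following is a standalone sketch of that sequence using the same ops and helper this diff calls; the shapes and random values are made up, and it assumes a bfloat16 activation with K divisible by the groupsize, as the dispatch conditions require:

    import torch
    from torchao.quantization.quant_primitives import pack_tinygemm_scales_and_zeros

    N, K, groupsize, inner_k_tiles = 1024, 1024, 32, 8  # hypothetical sizes

    # uint4 values stored in an int32 tensor, plus per-group bfloat16 scale/zero
    int_data = torch.randint(0, 16, (N, K), dtype=torch.int32, device="cuda")
    scale = torch.rand(N, K // groupsize, dtype=torch.bfloat16, device="cuda")
    zero_point = torch.rand(N, K // groupsize, dtype=torch.bfloat16, device="cuda")

    # pack the weight into the tinygemm layout (done on the fly in the dispatch above)
    packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scale, zero_point)

    x = torch.randn(1, K, dtype=torch.bfloat16, device="cuda")
    y = torch.ops.aten._weight_int4pack_mm(x, packed_weight, groupsize, scales_and_zeros)
    print(y.shape)  # (1, N)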
@@ -865,3 +917,43 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             kwargs,
             args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
         )
+
+
+# TODO: add padding support
+class TinygemmAffineQuantizedTensor(AffineQuantizedTensor):
+    @classmethod
+    def from_float(
+        cls,
+        input_float,
+        mapping_type,
+        block_size,
+        eps = None,
+        scale_dtype = None,
+        zero_point_dtype = None,
+        input_quant_func = None,
+    ):
+        # TODO: replace this with uint4 dtype
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        preserve_zero = False
+        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero)
+        def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
+            return (quant_min - zero_point + mid_point) * scale
+
+        mid_point = (quant_min + quant_max + 1) / 2
+        zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)
+        n_bit = 4
+        groupsize = block_size[1]
+        int_data = groupwise_affine_quantize_tensor_from_qparams(input_float, scale, zero_point_float, n_bit, groupsize)
+        return cls(
+            int_data,
+            scale,
+            zero_point_float,
+            block_size,
+            input_float.shape,
+            quant_min,
+            quant_max,
+            input_quant_func=input_quant_func,
+            dtype=input_float.dtype
+        )
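
Note: the int_zero_point_to_float helper bridges two conventions. choose_qparams_affine returns an integer zero point for the usual (q - zero_point) * scale dequantization, while the tinygemm path dequantizes as (q - mid_point) * scale + zero, with a floating-point zero anchored at the group mid-point. A small sanity check of that identity with made-up numbers:

    import torch

    quant_min, quant_max = 0, 15
    mid_point = (quant_min + quant_max + 1) / 2  # 8 for uint4

    scale = torch.tensor(0.1)
    zero_point = torch.tensor(3)  # integer zero point, as returned by choose_qparams_affine

    # the conversion used by TinygemmAffineQuantizedTensor.from_float
    zero_point_float = (quant_min - zero_point + mid_point) * scale  # 0.5 here

    q = torch.arange(quant_min, quant_max + 1)
    conventional = (q - zero_point) * scale                      # (q - zp) * s
    tinygemm_style = (q - mid_point) * scale + zero_point_float  # (q - mid) * s + zero
    assert torch.allclose(conventional, tinygemm_style)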
