
Commit 230ebf6

Enable dispatch to tinygemm int4 and int8 kernels for unified quantized tensor

Summary: This adds dispatch to the tinygemm kernels for CUDA, although the implementation mismatch problem for tinygemm still needs to be resolved first.
Test Plan: TODO
Reviewers:
Subscribers:
Tasks:
Tags:

1 parent b34d1ac commit 230ebf6
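
For context, the tinygemm path being dispatched to boils down to two aten ops that also appear in the diff below: aten._convert_weight_to_int4pack and aten._weight_int4pack_mm. The following is a minimal, hedged sketch of that call sequence on CUDA; the shapes, the int32 unpacked-weight dtype with values in [0, 15], and the [K // groupsize, N, 2] scales-and-zeros layout are assumptions that vary across PyTorch builds, not something this commit specifies.

    # Hedged sketch of the tinygemm int4 weight-only matmul path (assumptions noted above).
    import torch

    M, K, N = 32, 256, 512
    groupsize, innerKTiles = 32, 8
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    # assumed: unpacked int4 values stored as int32 in [0, 15]
    w_int4 = torch.randint(0, 16, (N, K), dtype=torch.int32, device="cuda")
    # assumed layout: [K // groupsize, N, 2] holding (scale, zero) pairs in bfloat16
    scales_and_zeros = torch.rand(K // groupsize, N, 2, dtype=torch.bfloat16, device="cuda")

    packed_w = torch.ops.aten._convert_weight_to_int4pack(w_int4, innerKTiles)
    out = torch.ops.aten._weight_int4pack_mm(x, packed_w, groupsize, scales_and_zeros)
    print(out.shape)  # torch.Size([32, 512])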

File tree

2 files changed: +95 -4 lines
  test/quantization/test_quant_api.py
  torchao/quantization/subclass.py


test/quantization/test_quant_api.py  (+40)
@@ -442,6 +442,46 @@ def get_per_token_block_size(x):
         ref = m_copy(*example_inputs)
         self.assertTrue(torch.equal(res, ref))
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    def test_quantized_tensor_subclass_int4(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        import copy
+
+        # weight settings
+        groupsize = 32
+        mapping_type = MappingType.ASYMMETRIC
+        block_size = (1, groupsize)
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.bfloat16).eps
+        quant_min = -8
+        quant_max = 7
+        preserve_zero = False
+
+        # weight only quantization
+        input_quant_func = None
+
+        m = ToyLinearModel().eval().to(torch.bfloat16)
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
+
+        def to_quantized(weight):
+            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, preserve_zero, quant_min, quant_max, eps, input_quant_func=input_quant_func)
+
+        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
+        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
+        change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+
+        self.assertTrue(torch.equal(res, ref))
+
 

torchao/quantization/subclass.py  (+55 -4)
@@ -626,7 +626,7 @@ class AffineQuantizedTensor(torch.Tensor):
       shape (torch.Size): the shape for the Tensor
       quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
       quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
-      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes input Tensor as input and outputs an AffineQuantizedTensor object
+      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object
       dtype: dtype for external representation of the tensor, e.g. torch.float32
    """
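
Since input_quant_func is the hook that the dynamic-quantization branch further down relies on, here is a hedged sketch of what such a callable could look like, built on the from_float constructor shown later in this diff; the per-token block size and the asymmetric int8 range are illustrative choices, not part of this commit.

    # Hedged example of an input_quant_func (illustrative: per-token blocks, asymmetric int8).
    import torch
    from torchao.quantization.subclass import AffineQuantizedTensor
    from torchao.quantization.quant_primitives import MappingType

    def example_input_quant_func(x: torch.Tensor):
        # one quantization block per token: all leading dims are 1, last dim spans the features
        block_size = (1,) * (x.dim() - 1) + (x.shape[-1],)
        return AffineQuantizedTensor.from_float(
            x,
            MappingType.ASYMMETRIC,
            block_size,
            torch.int8,
            quant_min=-128,
            quant_max=127,
            eps=torch.finfo(torch.float32).eps,
        )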

@@ -642,6 +642,7 @@ def __new__(
         quant_max: Optional[int] = None,
         input_quant_func: Optional[Callable] = None,
         dtype=None,
+        # TODO: remove args and kwargs
         *args,
         **kwargs
     ):
@@ -684,7 +685,9 @@ def __repr__(self):
             f"device={self.device}, dtype={self.dtype}, input_quant_func={self.input_quant_func}, requires_grad={self.requires_grad})"
         )
 
-    def dequantize(self, output_dtype=torch.float32):
+    def dequantize(self, output_dtype=None):
+        if output_dtype is None:
+            output_dtype = self.dtype
         return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, output_dtype=output_dtype)
 
     def __tensor_flatten__(self):
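
The practical effect of the dequantize change above is that calling dequantize() with no argument now returns the tensor's external dtype instead of always float32. A hedged usage sketch, assuming from_float records the float input's dtype as the external dtype (which the bfloat16 test above relies on); the explicit qparams are illustrative:

    # Hedged sketch: dequantize() without arguments now follows the external dtype (bfloat16 here).
    import torch
    from torchao.quantization.subclass import AffineQuantizedTensor
    from torchao.quantization.quant_primitives import MappingType

    w = torch.randn(8, 32, dtype=torch.bfloat16)
    qt = AffineQuantizedTensor.from_float(
        w, MappingType.ASYMMETRIC, (1, 32), torch.int8,
        quant_min=-128, quant_max=127, eps=torch.finfo(torch.bfloat16).eps,
    )
    assert qt.dequantize().dtype == torch.bfloat16              # follows self.dtype
    assert qt.dequantize(torch.float32).dtype == torch.float32  # explicit override still works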
@@ -716,13 +719,15 @@ def from_float(
         mapping_type,
         block_size,
         target_dtype,
+        preserve_zero = True,
         quant_min = None,
         quant_max = None,
         eps = None,
         scale_dtype = None,
         zero_point_dtype = None,
         input_quant_func = None,
     ):
+        # TODO: add preserve_zero arg to choose_qparams_affine
         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
         return cls(
@@ -810,7 +815,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         if (
             func in [aten.mm.default, aten.addmm.default]
             and args[0].is_floating_point()
-            and args[0].is_cuda
         ):
             if func == aten.addmm.default:
                 assert args[1].shape[-1] == args[2].shape[0], (
@@ -832,7 +836,54 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                     args[1],
                     None if len(args) == 2 else args[2],
                 )
-            if weight_qtensor.input_quant_func is not None:
+            if weight_qtensor.input_quant_func is None:
+                is_cuda = args[0].is_cuda
+                # weight only quantization
+                is_int8 = (
+                    weight_qtensor.int_data.dtype == torch.int8 and
+                    (weight_qtensor.quant_min is None or weight_qtensor.quant_min == -128) and
+                    (weight_qtensor.quant_max is None or weight_qtensor.quant_max == 127)
+                )
+                is_int4 = (
+                    weight_qtensor.int_data.dtype == torch.int8 and
+                    (weight_qtensor.quant_min is None or weight_qtensor.quant_min == -8) and
+                    (weight_qtensor.quant_max is None or weight_qtensor.quant_max == 7)
+                )
+
+                if (
+                    is_cuda and
+                    is_int4 and
+                    len(weight_qtensor.shape) == 2 and
+                    weight_qtensor.block_size[0] == 1
+                ):
+                    # groupwise int4 quantization
+                    # TODO: currently doing packing on the fly, we'll need to figure out
+                    # the API to do packing beforehand
+                    # TODO: zero_point transform
+                    # TODO: expose the innerKTiles arg
+                    innerKTiles = 8
+                    packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data, innerKTiles)
+                    groupsize = weight_qtensor.block_size[-1]
+
+                    # adjust zero_point to be compatible with tinygemm
+                    def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
+                        return (quant_min - zero_point + mid_point) * scale
+
+                    mid_point = 8
+                    zero_point_float = int_zero_point_to_float(weight_qtensor.zero_point, weight_qtensor.scale, weight_qtensor.quant_min, mid_point)
+                    scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, zero_point_float)
+                    return _weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
+                elif (
+                    is_cuda and
+                    is_int8 and
+                    len(weight_qtensor.shape) == 2 and
+                    len(weight_qtensor.block_size) == 2 and
+                    weight_qtensor.block_size[0] == 1 and
+                    weight_qtensor.block_size[1] == weight_qtensor.shape[1]
+                ):
+                    # per channel int8 weight only quantization
+                    return _weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
+            else:
+                # dynamic quantization
                 input_tensor = weight_qtensor.input_quant_func(input_tensor)
                 input_tensor = input_tensor.dequantize()
             weight_tensor = weight_qtensor.dequantize()
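
One way to read the int_zero_point_to_float adjustment in the int4 branch (an interpretation, not something the commit states): tinygemm dequantizes an unsigned nibble q_u as (q_u - mid_point) * scale + zero_float, while the affine scheme here dequantizes the signed value q = q_u + quant_min as (q - zero_point) * scale; choosing zero_float = (quant_min - zero_point + mid_point) * scale makes the two agree. A small numeric check under those assumptions:

    # Hedged check of the zero-point transform, assuming q_u = q - quant_min and the
    # tinygemm dequant form (q_u - mid_point) * scale + zero_float described above.
    scale, zero_point, quant_min, mid_point = 0.05, 3, -8, 8
    q = 5                              # stored signed int4 value in [-8, 7]
    q_u = q - quant_min                # unsigned nibble in [0, 15]
    zero_float = (quant_min - zero_point + mid_point) * scale
    tinygemm_dequant = (q_u - mid_point) * scale + zero_float
    affine_dequant = (q - zero_point) * scale
    assert abs(tinygemm_dequant - affine_dequant) < 1e-9   # both 0.10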
