Skip to content

Enable gemlite copy_ and fix slice #2071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@
from torchao.quantization import (
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
GemliteUIntXWeightOnlyConfig,
float8_weight_only,
int4_dynamic_activation_int4_weight,
int4_weight_only,
@@ -35,7 +36,7 @@
quantize_,
)
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.testing.utils import skip_if_no_cuda, skip_if_rocm
from torchao.testing.utils import skip_if_no_cuda, skip_if_rocm, skip_if_no_gemlite
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
check_cpu_version,
@@ -342,7 +343,7 @@ def test_alias(self, device, dtype):
@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("dtype", [torch.bfloat16])
@skip_if_no_cuda()
def test_slice(self, device, dtype):
def test_slice_int4wo(self, device, dtype):
# in_feature not divisible by 1024
# out_feature not divisible by 8
# to test slice + padding for int4 weight only quantization
@@ -353,6 +354,21 @@ def test_slice(self, device, dtype):
_ = dummy.weight.narrow(1, 0, 128)


@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("dtype", [torch.float16])
@skip_if_no_cuda()
@skip_if_no_gemlite()
def test_slice_gemlite(self, device, dtype):
    """Slicing (narrow) a gemlite-quantized Linear weight must run without error.

    Mirrors test_slice_int4wo, but quantizes with GemliteUIntXWeightOnlyConfig
    so the aten.slice paths of the gemlite tensor impl are exercised.
    """
    # NOTE(review): the original comment here ("in_feature not divisible by
    # 1024 / out_feature not divisible by 8") was copied from the int4 test;
    # 512 *is* divisible by 8, so the padding remark may not apply — confirm.
    dummy = nn.Linear(256, 512, dtype=dtype, device=device)
    quantize_(dummy, GemliteUIntXWeightOnlyConfig())
    # make sure these run without error
    _ = dummy.weight.narrow(0, 0, 64)  # slice along dim 0 (out_features)
    _ = dummy.weight.narrow(1, 0, 128)  # slice along dim 1 (in_features)


common_utils.instantiate_parametrized_tests(TestAffineQuantized)
common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)

76 changes: 48 additions & 28 deletions torchao/dtypes/uintx/gemlite_layout.py
Original file line number Diff line number Diff line change
@@ -24,6 +24,25 @@

aten = torch.ops.aten

def _same_metadata(
    self: "GemliteAQTTensorImpl",
    src: "GemliteAQTTensorImpl",
) -> bool:
    """Return True when `src`'s metadata matches `self`'s, i.e. an in-place
    `copy_` of the constituent tensors is safe.

    Compares types, outer shape, the shapes of the packed weight / scale /
    zero_point tensors, the layout type, and the `gemlite_kwargs` dicts
    (excluding "scale_activations").
    """
    # Guard types first: every attribute access below assumes both operands
    # are GemliteAQTTensorImpl (the original indexed src.gemlite_kwargs
    # before these checks, which could raise AttributeError).
    if not (
        isinstance(self, GemliteAQTTensorImpl)
        and isinstance(src, GemliteAQTTensorImpl)
    ):
        return False

    # Compare key sets, not just lengths: a length-only check followed by
    # src.gemlite_kwargs[k] raises KeyError when the keys differ.
    if self.gemlite_kwargs.keys() != src.gemlite_kwargs.keys():
        return False
    for k, v in self.gemlite_kwargs.items():
        # "scale_activations" is deliberately excluded from value comparison
        # (original behavior; presumably not reliably comparable with ==).
        if k != "scale_activations" and v != src.gemlite_kwargs[k]:
            return False

    return (
        self.shape == src.shape
        and self.packed_weight.shape == src.packed_weight.shape
        and self.scale.shape == src.scale.shape
        and self.zero_point.shape == src.zero_point.shape
        and type(self._layout) == type(src._layout)
    )


def get_gemlite_quant_kwargs(bit_width, group_size):
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
@@ -172,6 +191,7 @@ def from_plain(
zero_point: Optional[torch.Tensor],
_layout: Layout,
):
print(f"from plain: {int_data.shape=} {scale.shape=}")
from gemlite.core import DType, GemLiteLinearTriton, set_autotune

assert isinstance(
@@ -296,9 +316,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

if func is aten.slice.Tensor:
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0:
assert step == 1, "Only step == 1 is supported in slicing right now"
assert step == 1, "Only step == 1 is supported in slicing right now"
if dim in [0, 1]:
int_data, scale, zero_point = self.get_plain()
# scale and zero_point are transposed compared to int_data
data_len = int_data.shape[dim]
param_dim = 1 - dim
scale_len = scale.shape[param_dim]
@@ -307,46 +328,41 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
end_scale = int(end / ratio)

int_data = aten.slice.Tensor(int_data, dim, start, end, step)
scale = aten.slice.Tensor(
scale, param_dim, start_scale, end_scale, step
)
scale = aten.slice.Tensor(scale, param_dim, start_scale, end_scale, step)
if zero_point is not None and zero_point.numel() > 0:
zero_point = aten.slice.Tensor(
zero_point, param_dim, start_scale, end_scale, step
)
else:
zero_point = None

sliced = self.from_plain(int_data, scale, zero_point, self._layout)
return return_and_correct_aliasing(func, args, kwargs, sliced)
elif dim == 1:
assert step == 1, "Only step == 1 is supported in slicing right now"
int_data, scale, zero_point = self.get_plain()
data_len = int_data.shape[dim]
# scale and zero_point are transposed compared to int_data
param_dim = 1 - dim
scale_len = scale.shape[param_dim]
ratio = data_len / scale_len
start_scale = int(start / ratio)
end_scale = int(end / ratio)

int_data = aten.slice.Tensor(int_data, dim, start, end, step)
scale = aten.slice.Tensor(
scale, param_dim, start_scale, end_scale, step
# this is to handle padding
int_data, scale, zero_point = self._layout.post_process(
int_data, scale, zero_point, self.block_size
)
if zero_point is not None and zero_point.numel() > 0:
zero_point = aten.slice.Tensor(
zero_point, param_dim, start_scale, end_scale, step
)
else:
zero_point = None
# TODO: maybe get_plain should not output a transposed scale and zp?
# since scale and zero_point are transposed from `get_plain`,
# we need to transpose them back before feeding to from_plain
scale = scale.t().contiguous()
zero_point = zero_point.t().contiguous()
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
return return_and_correct_aliasing(func, args, kwargs, sliced)
else:
raise NotImplementedError(
f"GemliteAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
)

elif func is aten.copy_.default:
self = args[0]
src = args[1]
if _same_metadata(self, src):
self_tensors = self.__tensor_flatten__()[0]
for tensor_name in self_tensors:
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
return
raise ValueError(
f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
)

raise NotImplementedError(
f"GemliteAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
)
@@ -356,6 +372,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
def get_layout(self) -> Layout:
return self._layout

@property
def block_size(self):
    """Quantization block size as (rows, cols) per quantization group.

    Returns (1, group_size): derived from the layout's group_size and fed to
    `self._layout.post_process` when re-packing after a slice.
    """
    return (1, self._layout.group_size)


# logic taken from gemlite's core.py
def _matmul_type_fn(batch_size: int, bit_width: int) -> str:
22 changes: 20 additions & 2 deletions torchao/testing/utils.py
Original file line number Diff line number Diff line change
@@ -91,12 +91,30 @@ def wrapper(*args, **kwargs):


def skip_if_no_cuda():
    """Decorator factory: skip the wrapped test when CUDA is unavailable.

    Reports the skip via pytest so it shows uniformly in pytest runs.
    """
    import pytest

    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            if not torch.cuda.is_available():
                # pytest.skip raises pytest's Skipped exception itself;
                # `raise pytest.skip(...)` (as before) was dead/misleading —
                # the `raise` never executes.
                pytest.skip("No cuda available")
            return test_func(*args, **kwargs)

        return wrapper

    return decorator


def skip_if_no_gemlite():
    """Decorator factory: skip the wrapped test when `gemlite` is not importable."""
    import pytest

    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            try:
                # Presence check only; the imported module is unused.
                import gemlite
            # NOTE(review): bare except also swallows KeyboardInterrupt and
            # unrelated errors during import — prefer `except ImportError:`.
            except:
                # NOTE(review): message says "No cuda available" but this
                # decorator guards gemlite — should say gemlite is missing.
                # Also, pytest.skip raises on its own; the `raise` is dead.
                raise pytest.skip("No cuda available")
            return test_func(*args, **kwargs)

        return wrapper