Improve GemLite Integration #2096

Merged: 24 commits into pytorch:main, Apr 25, 2025

Conversation

@mobicham (Collaborator) commented on Apr 22, 2025

Tasks

  • Faster get_plain() via Triton unpacking instead of a matmul with the identity matrix; this should also make slicing faster.
  • Fix serialization issues.
  • General clean-up.
  • Remove the CUDA device check in .to() so models can be loaded in transformers.
  • Force the CUDA device in from_plain / get_plain for vLLM weight-loader compatibility.
  • Add bfloat16 support.

Note: Slicing still performs an unpack -> repack round trip. If we can restrict the slicing step to self._layout.packing_bitwidth // self._layout.bit_width, we can avoid this and slice the packed data directly (see the sketch below).
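A minimal sketch of that idea, assuming row-major packing with 4-bit values stored 8-per-int32 (the helper name and signature are hypothetical, not part of this PR):

def slice_packed_rows(packed, start, length, packing_bitwidth=32, bit_width=4):
    # Hypothetical: slice the packed rows directly, valid only when the offsets
    # align with the packing granularity (no unpack -> repack round trip needed).
    elems_per_word = packing_bitwidth // bit_width  # e.g. 32 // 4 = 8 values per packed word
    assert start % elems_per_word == 0 and length % elems_per_word == 0
    return packed[start // elems_per_word : (start + length) // elems_per_word]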

Test

import torch, gemlite
from torchao.quantization import GemliteUIntXWeightOnlyConfig, quantize_
device = 'cuda:0'
dtype = torch.float16

layer = torch.nn.Linear(256, 512, bias=False, dtype=dtype, device=device)
weight = layer.weight.data.clone()
orig_shape = weight.shape

group_size = 64
quantize_(layer, GemliteUIntXWeightOnlyConfig(bit_width=4, group_size=group_size))

# Test dot product
#################################################
torch.manual_seed(0)
x = torch.randn((1, layer.in_features), device=device, dtype=dtype) / 10.
y_ref = x @ weight.T
y_gem  = layer(x)
assert (y_ref - y_gem).abs().mean() < 5e-3, "Dot product mismatch"

# Test get_plain()
#################################################
W_q, s, z       = layer.weight.tensor_impl.get_plain()
W_r_getplain = ((W_q.view([-1, group_size]) - z.view(-1, 1)) * s.view(-1, 1)).view(orig_shape)
W_r_dotprod = layer(torch.eye(layer.in_features, device=device, dtype=dtype)).T #SLOW because it autotunes
assert (W_r_getplain - W_r_dotprod).abs().mean() < 1e-4, "get_plain() incorrect results"
assert (y_gem - (x @ W_r_getplain.T)).abs().mean() < 1e-4, "get_plain() incorrect results"

# Test slicing
#################################################
def _dequant(tensor_impl, in_features, orig_shape):
    int_data   = tensor_impl.packed_weight
    scale      = tensor_impl.scale
    zero_point = tensor_impl.zero_point

    W_q = gemlite.bitpack.unpack_over_rows(int_data, W_nbits=4, num_output_rows=in_features, dtype=torch.uint8).T.contiguous()
    s   = scale.t().contiguous()
    z   = zero_point.t().contiguous()
    return ((W_q.view([-1, group_size]) - z.view(-1, 1)) * s.view(-1, 1)).view(orig_shape)

torch.manual_seed(0)
x     = torch.randn((1, layer.in_features), device=device, dtype=dtype) / 10.
y_ref = layer(x)
W_r   = _dequant(layer.weight.tensor_impl, layer.in_features, orig_shape).T
y_2   = x @ W_r
assert (y_ref - y_2).abs().mean() < 1e-4, "Incorrect dequant results"


layer_sliced = layer.weight.narrow(0, 0, 256)
W_slice1   = _dequant(layer_sliced.tensor_impl, layer.in_features, [orig_shape[0]//2, orig_shape[1]]).T
assert (W_r[:, :256] - W_slice1).abs().mean() == 0, "slice1 along axis=0 is incorrect"

layer_sliced = layer.weight.narrow(0, 256, 256)
W_slice2   = _dequant(layer_sliced.tensor_impl, layer.in_features, [orig_shape[0]//2, orig_shape[1]]).T
assert (W_r[:, 256:] - W_slice2).abs().mean() == 0 , "slice2 along axis=0 is incorrect"

layer_sliced = layer.weight.narrow(1, 0, 128)
W_slice1   = _dequant(layer_sliced.tensor_impl, layer.in_features//2, [orig_shape[0], orig_shape[1]//2]).T
assert (W_r[:128, :] - W_slice1).abs().mean() == 0, "slice1 along axis=1 is incorrect"

layer_sliced = layer.weight.narrow(1, 128, 128)
W_slice2   = _dequant(layer_sliced.tensor_impl, layer.in_features//2, [orig_shape[0], orig_shape[1]//2]).T
assert (W_r[128:, :] - W_slice2).abs().mean() == 0 , "slice2 along axis=1 is incorrect"
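
With the new bfloat16 support, the same smoke test should also pass in bfloat16. A minimal variant (assumption: identical API; the tolerance is loosened since bfloat16 has fewer mantissa bits than float16):

# bfloat16 variant (sketch; the 1e-2 tolerance is an assumption, not from the PR)
dtype_bf16 = torch.bfloat16
layer_bf16 = torch.nn.Linear(256, 512, bias=False, dtype=dtype_bf16, device=device)
weight_bf16 = layer_bf16.weight.data.clone()
quantize_(layer_bf16, GemliteUIntXWeightOnlyConfig(bit_width=4, group_size=group_size))
x_bf16 = torch.randn((1, layer_bf16.in_features), device=device, dtype=dtype_bf16) / 10.
assert (x_bf16 @ weight_bf16.T - layer_bf16(x_bf16)).abs().mean() < 1e-2, "bfloat16 dot product mismatch"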

@pytorch-bot commented on Apr 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2096

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0e65350 with merge base 11472c9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Apr 22, 2025
@mobicham changed the title from "Improve GemLite Integration" to "Improve GemLite Integration topic: improvemen" on Apr 22, 2025
@mobicham changed the title from "Improve GemLite Integration topic: improvemen" back to "Improve GemLite Integration" on Apr 22, 2025
@jerryzh168 added the topic: improvement label on Apr 23, 2025
@@ -91,12 +91,30 @@ def wrapper(*args, **kwargs):


 def skip_if_no_cuda():
-    import unittest
+    import pytest
A contributor commented:

we might have to use unittest to make CI happy I think
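
A unittest-based variant of the helper might look like this (a sketch of the suggestion, not the code in this PR):

import unittest
import torch

def skip_if_no_cuda():
    # unittest.skipIf returns a decorator, so tests can use @skip_if_no_cuda()
    return unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")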

@jerryzh168 added and removed the topic: improvement label on Apr 25, 2025
@jerryzh168 merged commit 3da411a into pytorch:main on Apr 25, 2025 (19 of 20 checks passed)
Labels: CLA Signed, topic: improvement