-
Notifications
You must be signed in to change notification settings - Fork 312
Add sparse marlin AQT layout #621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
fdad96f
feat: starting layout implementation
Diogo-V b810707
compile kind of working
jcaip b5eddd8
fix: batching and layout outputs correct results
Diogo-V cf5e286
fix: torch.compile
Diogo-V ef17ed6
chore: cleanup
Diogo-V ea70c74
chore: review
Diogo-V 153fd0b
chore: review v2
Diogo-V 333a88f
update benchmarks + README
jcaip File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import torch | ||
import copy | ||
import pytest | ||
|
||
from torch import nn | ||
from torch.testing._internal.common_utils import TestCase, run_tests | ||
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 | ||
from torchao.dtypes import MarlinSparseLayoutType | ||
from torchao.sparsity.sparse_api import apply_fake_sparsity | ||
from torchao.quantization.quant_api import int4_weight_only, quantize_ | ||
from torchao.sparsity.marlin import ( | ||
pack_to_marlin_24, | ||
unpack_from_marlin_24, | ||
inject_24 | ||
) | ||
from torchao.quantization.quant_primitives import ( | ||
choose_qparams_affine, | ||
quantize_affine, | ||
ZeroPointDomain, | ||
MappingType, | ||
) | ||
|
||
|
||
class SparseMarlin24(TestCase): | ||
|
||
def setUp(self): | ||
super().setUp() | ||
torch.manual_seed(0) | ||
|
||
self.input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda") | ||
self.model = ( | ||
nn.Sequential( | ||
nn.Linear(4096, 21504), | ||
nn.Linear(21504, 4096), | ||
nn.ReLU(), | ||
nn.Linear(4096, 21504), | ||
nn.Linear(21504, 4096), | ||
) | ||
.half() | ||
.cuda() | ||
) | ||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") | ||
def test_quant_sparse_marlin_layout_eager(self): | ||
apply_fake_sparsity(self.model) | ||
model_copy = copy.deepcopy(self.model) | ||
|
||
# Quantized | ||
quantize_(model_copy.bfloat16(), int4_weight_only()) | ||
dense_result = model_copy(self.input.bfloat16()).half() | ||
|
||
# Sparse + quantized | ||
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType())) | ||
sparse_result = self.model(self.input) | ||
|
||
assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close" | ||
|
||
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") | ||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") | ||
def test_quant_sparse_marlin_layout_compile(self): | ||
apply_fake_sparsity(self.model) | ||
model_copy = copy.deepcopy(self.model) | ||
|
||
# Quantized | ||
quantize_(model_copy.bfloat16(), int4_weight_only()) | ||
model_copy.foward = torch.compile(model_copy.forward, fullgraph=True) | ||
dense_result = model_copy(self.input.bfloat16()).half() | ||
|
||
# Sparse + quantized | ||
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType())) | ||
self.model.forward = torch.compile(self.model.forward, fullgraph=True) | ||
sparse_result = self.model(self.input) | ||
|
||
assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close" | ||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") | ||
def test_pack_unpack_equivalence(self): | ||
num_bits = 4 | ||
group_size = 128 | ||
shape = (11008, 4096) | ||
block_size = (1, group_size) | ||
target_dtype = torch.int32 | ||
quant_min = 0 | ||
quant_max = 15 | ||
eps = 1e-6 | ||
zero_point_dtype = torch.bfloat16 | ||
mapping_type = MappingType.SYMMETRIC | ||
preserve_zero = True | ||
zero_point_domain = ZeroPointDomain.INT | ||
scale_dtype = None | ||
|
||
w = torch.rand(shape, dtype=torch.float16, device="cuda") | ||
|
||
# Inject 2:4 sparsity mask | ||
w_24, _ = inject_24(w, *w.shape) | ||
|
||
# Quantize weights | ||
scales, zeros = choose_qparams_affine(w_24, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) | ||
w_q_24 = quantize_affine(w_24, block_size, scales, zeros, target_dtype, quant_min, quant_max, zero_point_domain) | ||
scales = scales.reshape(-1, w_q_24.shape[1]) | ||
|
||
# Test pack/unpack equivalence | ||
q_w_comp, packed_scales, meta = pack_to_marlin_24( | ||
w_q_24, scales, num_bits, group_size | ||
) | ||
unpacked_q_w, unpacked_scales = unpack_from_marlin_24( | ||
q_w_comp, packed_scales, meta, shape, group_size, num_bits | ||
) | ||
|
||
assert torch.equal(w_q_24, unpacked_q_w), "Unpacked weights do not match original weights" | ||
assert torch.equal(scales, unpacked_scales), "Unpacked scales do not match original scales" | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.