
Add sparse marlin AQT layout #621


Merged
merged 8 commits on Sep 6, 2024
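Roughly how the new layout is exercised by the tests and benchmarks below: prune to a 2:4 pattern, then quantize through the existing `quantize_` API with the new `MarlinSparseLayoutType`. A minimal sketch, assuming a CUDA device and a float16 model; the toy model and sizes are illustrative only:

```python
import torch
from torch import nn

from torchao.dtypes import MarlinSparseLayoutType
from torchao.quantization.quant_api import int4_weight_only, quantize_
from torchao.sparsity.sparse_api import apply_fake_sparsity

# Illustrative toy model; any float16 CUDA model with Linear layers is handled the same way.
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)).half().cuda()

# Prune the Linear weights to a 2:4 pattern so the Marlin 2:4 kernel applies.
apply_fake_sparsity(model)

# Swap the Linear weights for int4 AQT tensors backed by the sparse Marlin layout.
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))

x = torch.randn(32, 4096, dtype=torch.float16, device="cuda")
out = model(x)
```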
115 changes: 115 additions & 0 deletions test/sparsity/test_marlin.py
@@ -0,0 +1,115 @@
import torch
import copy
import pytest

from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.dtypes import MarlinSparseLayoutType
from torchao.sparsity.sparse_api import apply_fake_sparsity
from torchao.quantization.quant_api import int4_weight_only, quantize_
from torchao.sparsity.marlin import (
pack_to_marlin_24,
unpack_from_marlin_24,
inject_24
)
from torchao.quantization.quant_primitives import (
choose_qparams_affine,
quantize_affine,
ZeroPointDomain,
MappingType,
)


class SparseMarlin24(TestCase):

def setUp(self):
super().setUp()
torch.manual_seed(0)

self.input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
self.model = (
nn.Sequential(
nn.Linear(4096, 21504),
nn.Linear(21504, 4096),
nn.ReLU(),
nn.Linear(4096, 21504),
nn.Linear(21504, 4096),
)
.half()
.cuda()
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_quant_sparse_marlin_layout_eager(self):
apply_fake_sparsity(self.model)
model_copy = copy.deepcopy(self.model)

# Quantized
quantize_(model_copy.bfloat16(), int4_weight_only())
dense_result = model_copy(self.input.bfloat16()).half()

# Sparse + quantized
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
sparse_result = self.model(self.input)

assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_quant_sparse_marlin_layout_compile(self):
apply_fake_sparsity(self.model)
model_copy = copy.deepcopy(self.model)

# Quantized
quantize_(model_copy.bfloat16(), int4_weight_only())
model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
dense_result = model_copy(self.input.bfloat16()).half()

# Sparse + quantized
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
self.model.forward = torch.compile(self.model.forward, fullgraph=True)
sparse_result = self.model(self.input)

assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_pack_unpack_equivalence(self):
num_bits = 4
group_size = 128
shape = (11008, 4096)
block_size = (1, group_size)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
zero_point_dtype = torch.bfloat16
mapping_type = MappingType.SYMMETRIC
preserve_zero = True
zero_point_domain = ZeroPointDomain.INT
scale_dtype = None

w = torch.rand(shape, dtype=torch.float16, device="cuda")

# Inject 2:4 sparsity mask
w_24, _ = inject_24(w, *w.shape)

# Quantize weights
scales, zeros = choose_qparams_affine(w_24, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
w_q_24 = quantize_affine(w_24, block_size, scales, zeros, target_dtype, quant_min, quant_max, zero_point_domain)
scales = scales.reshape(-1, w_q_24.shape[1])

# Test pack/unpack equivalence
q_w_comp, packed_scales, meta = pack_to_marlin_24(
w_q_24, scales, num_bits, group_size
)
unpacked_q_w, unpacked_scales = unpack_from_marlin_24(
q_w_comp, packed_scales, meta, shape, group_size, num_bits
)

assert torch.equal(w_q_24, unpacked_q_w), "Unpacked weights do not match original weights"
assert torch.equal(scales, unpacked_scales), "Unpacked scales do not match original scales"


if __name__ == "__main__":
run_tests()
30 changes: 27 additions & 3 deletions test/sparsity/test_sparse_api.py
@@ -11,12 +11,11 @@
int8_dynamic_activation_int8_semi_sparse_weight,
semi_sparse_weight,
)
from torchao.dtypes import MarlinSparseLayoutType
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
_get_subclass_inserter,
_is_linear,
int8_dynamic_activation_int8_weight,
quantize_,
int4_weight_only,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
from torch.testing._internal.common_utils import TestCase
@@ -73,5 +72,30 @@ def test_quant_semi_sparse(self):

assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_sparse_marlin(self):
input = torch.rand((256, 256)).half().cuda()
model = (
nn.Sequential(
nn.Linear(256, 1024),
nn.Linear(1024, 256),
)
.half()
.cuda()
)

apply_fake_sparsity(model)
model_copy = copy.deepcopy(model)

# Quantized
quantize_(model_copy.bfloat16(), int4_weight_only())
dense_result = model_copy(input.bfloat16()).half()

# Sparse + quantized
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

if __name__ == "__main__":
unittest.main()
24 changes: 15 additions & 9 deletions test/test_ops.py
@@ -304,6 +304,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
)


MARLIN_24_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]
MNK_FACTORS = [
@@ -318,8 +319,8 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

MARLIN_TEST_PARAMS = list(itertools.product(
MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS,
MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS,
MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
))

def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int):
@@ -374,15 +375,15 @@ def reshape_w(w):
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str)
def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
@pytest.mark.parametrize("batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str)
def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor

a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
a_input = torch.randn((batch_size, size_m, size_k), dtype=torch.float16, device="cuda")
b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda")

# Inject 2:4 sparsity
@@ -391,19 +392,24 @@ def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
# Symmetric quantize
w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size)

# Reshape input into 2D tensor
input_2d = a_input.view(-1, a_input.shape[-1])
a_input_in, a_input_out = input_2d.shape

# Obtains reference output
output_ref = torch.matmul(a_input, w_24_ref)
output_ref = torch.matmul(input_2d, w_24_ref)
output_ref = output_ref.reshape(a_input.shape[:-1] + (scale.shape[1],))

# Packs to marlin 2:4
marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size)
workspace_24 = marlin_24_workspace(size_n)

fn_inputs = (
a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1],
input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
num_bits, a_input_in, marlin_24_scale.shape[1], a_input_out,
)
output = torchao.ops.marlin_24_gemm(*fn_inputs)
torch.cuda.synchronize()
output = output.reshape(a_input.shape[:-1] + (marlin_24_scale.shape[1],))

max_diff = compute_max_diff(output, output_ref)
assert max_diff < 0.04
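Because the kernel itself is 2D, the updated test flattens the batched activation before the call and restores the leading dimensions afterwards. A hypothetical helper (not part of this PR) capturing that pattern, where the packed weight, metadata, scales, and workspace are assumed to come from pack_to_marlin_24 and marlin_24_workspace as in the tests above:

```python
import torch
import torchao.ops


def marlin_24_mm(x: torch.Tensor, marlin_q_w, meta, marlin_scale, workspace, num_bits: int):
    """Hypothetical helper: apply the 2D marlin_24_gemm kernel to an activation
    with arbitrary leading batch dimensions.

    marlin_q_w, meta, marlin_scale, and workspace are assumed to come from
    pack_to_marlin_24 and marlin_24_workspace, as in the tests above.
    """
    x_2d = x.view(-1, x.shape[-1])    # flatten leading dims into rows
    size_m, size_k = x_2d.shape
    size_n = marlin_scale.shape[1]    # number of output features
    out = torchao.ops.marlin_24_gemm(
        x_2d, marlin_q_w, meta, marlin_scale, workspace,
        num_bits, size_m, size_n, size_k,
    )
    return out.reshape(x.shape[:-1] + (size_n,))    # restore leading dims
```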
1 change: 1 addition & 0 deletions torchao/_models/llama/benchmark_results.txt
@@ -38,3 +38,4 @@ kv cache quantization:
20240826171015, tok/s= 1.95, mem/s= 29.21 GB/s, peak_mem=59.27 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072
20240826172121, tok/s= 1.73, mem/s= 26.02 GB/s, peak_mem=52.62 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: True, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072--kv_cache_quantization
20240826173230, tok/s= 1.73, mem/s= 25.95 GB/s, peak_mem=34.18 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: True, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072--kv_cache_quantization --linear_causal_mask
20240906054415, tok/s=226.02, mem/s= 689.20 GB/s, peak_mem= 5.32 GB, model_size= 3.05 GB quant: marlin, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization marlin --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
4 changes: 3 additions & 1 deletion torchao/_models/llama/benchmarks.sh
@@ -30,13 +30,15 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt

# sparse marlin (NOTE: float16)
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
# auto-round w/ quant_lm_head
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround
# auto-round w/o quant_lm_head
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0



export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization
3 changes: 3 additions & 0 deletions torchao/_models/llama/generate.py
@@ -225,6 +225,9 @@ def main(
groupsize=int(quantization.split("-")[-1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize_(model, int4_weight_only(group_size=groupsize))
if "marlin" in quantization:
from torchao.dtypes import MarlinSparseLayoutType
quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
if "autoround" in quantization:
from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
from transformers import AutoTokenizer
2 changes: 2 additions & 0 deletions torchao/_models/sam/benchmark.sh
@@ -8,3 +8,5 @@ python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
# int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
# int8 dynamic quant attn + int4 wo + sparse marlin lin 1 + 2:4 sparse lin2
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half float16 --device cuda --compress int4_weight_only_sparse
34 changes: 24 additions & 10 deletions torchao/_models/sam/eval_combo.py
@@ -283,6 +283,16 @@ def run(
for block in predictor.model.image_encoder.blocks:
block.attn.use_rel_pos = use_rel_pos

# Helper filter functions
def attn_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'attn' in name
def mlp_lin1_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'lin1' in name
def mlp_lin2_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'lin2' in name
def mlp_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'mlp' in name

if compress == "int8_dynamic_quant":
quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
if not TORCH_VERSION_AT_LEAST_2_5:
@@ -296,15 +306,6 @@ def mlp_only(mod, name):
apply_fake_sparsity(predictor.model.image_encoder)
sparsify_(predictor.model.image_encoder, semi_sparse_weight())
elif compress == "int8_dynamic_quant_sparse":
def attn_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'attn' in name
def mlp_lin1_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'lin1' in name
def mlp_lin2_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'lin2' in name
def mlp_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'mlp' in name

# apply sparsify first to set qparams
apply_fake_sparsity(predictor.model.image_encoder,
filter_fn=mlp_only)
@@ -320,7 +321,20 @@ def mlp_only(mod, name):
mlp_lin2_only)
if not TORCH_VERSION_AT_LEAST_2_5:
predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)

elif compress == "int4_weight_only_sparse":
# apply sparsify first to set qparams
apply_fake_sparsity(predictor.model.image_encoder,
filter_fn=mlp_only)
from torchao.dtypes import MarlinSparseLayoutType
quantize_(predictor.model.image_encoder,
int8_dynamic_activation_int8_weight(),
attn_only)
quantize_(predictor.model.image_encoder, int4_weight_only(layout_type=MarlinSparseLayoutType()), mlp_lin1_only)
sparsify_(predictor.model.image_encoder,
semi_sparse_weight(),
mlp_lin2_only)
if not TORCH_VERSION_AT_LEAST_2_5:
predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
else:
assert compress is None, f"Unsupported compress mode {compress}"

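Putting the new int4_weight_only_sparse branch together, the per-module routing amounts to roughly the following sketch. The import paths and the standalone image_encoder argument are assumptions made for illustration; in the benchmark the module is predictor.model.image_encoder:

```python
import torch

from torchao.dtypes import MarlinSparseLayoutType
from torchao.quantization.quant_api import (
    int4_weight_only,
    int8_dynamic_activation_int8_weight,
    quantize_,
)
from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_  # paths assumed


def apply_int4_weight_only_sparse(image_encoder: torch.nn.Module) -> None:
    # Filter functions route each scheme to a subset of Linear modules by name.
    def attn_only(mod, name):
        return isinstance(mod, torch.nn.Linear) and 'attn' in name
    def mlp_lin1_only(mod, name):
        return isinstance(mod, torch.nn.Linear) and 'lin1' in name
    def mlp_lin2_only(mod, name):
        return isinstance(mod, torch.nn.Linear) and 'lin2' in name
    def mlp_only(mod, name):
        return isinstance(mod, torch.nn.Linear) and 'mlp' in name

    # Prune MLP weights to 2:4 first so quantization parameters are chosen on the sparse weights.
    apply_fake_sparsity(image_encoder, filter_fn=mlp_only)

    # attn: int8 dynamic activation + int8 weight
    quantize_(image_encoder, int8_dynamic_activation_int8_weight(), attn_only)
    # mlp lin1: int4 weight-only on the sparse Marlin layout
    quantize_(image_encoder, int4_weight_only(layout_type=MarlinSparseLayoutType()), mlp_lin1_only)
    # mlp lin2: 2:4 semi-structured sparsity only
    sparsify_(image_encoder, semi_sparse_weight(), mlp_lin2_only)
```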
1 change: 1 addition & 0 deletions torchao/_models/sam/results.csv
@@ -4,3 +4,4 @@ cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,ma
cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,17068,21,23.96093702681232,41.73459489004953,0.5485481164943489,max-autotune,torch.float16,int4_weight_only_sparse,False,True,True,32,154,4928,None,None
2 changes: 1 addition & 1 deletion torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
@@ -1123,4 +1123,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
m.impl("torchao::marlin_24_gemm", &marlin_24_gemm);
}

} // namespace torchao
} // namespace torchao
2 changes: 1 addition & 1 deletion torchao/csrc/sparse_marlin.cpp
@@ -5,4 +5,4 @@
TORCH_LIBRARY_FRAGMENT(torchao, m) {
m.impl_abstract_pystub("torchao.ops");
m.def("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor");
}
}
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -15,6 +15,7 @@
TensorCoreTiledLayoutType,
Float8LayoutType,
Float8AQTLayout,
MarlinSparseLayoutType,
)

__all__ = [
@@ -33,4 +34,5 @@
"TensorCoreTiledLayoutType",
"Float8LayoutType",
"Float8AQTLayout",
"MarlinSparseLayoutType",
]