diff --git a/README.md b/README.md
index 1dd2a72340..e31dc63a8f 100644
--- a/README.md
+++ b/README.md
@@ -49,20 +49,18 @@ And a quick crash course on inference quantization to help parse the above table
 Sparsifying your model is also a 1 liner that should work on any model with an `nn.Linear`. We find that sparsity works best on compute bound models like SAM, specifically the MLP layers.
 ```python
-from torchao.sparsity import sparsify
-from torch.sparse import to_sparse_semi_structured
+from torchao.sparsity import sparsify_, semi_sparse_weight

-m = sparsify(m, to_sparse_semi_structured)
+m = sparsify_(m, semi_sparse_weight())
 ```

 Sparsity can also be composed with int8 dynamic quantization for further speedups:

 ```python
-from torchao.sparsity import sparsify
-from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
+from torchao.sparsity import sparsify_, int8_dynamic_activation_int8_semi_sparse_weight

-m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight())
+m = sparsify_(m, int8_dynamic_activation_int8_semi_sparse_weight())
 ```
-We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to mlp layer 1 and 2:4 sparsity to mlp layer 2 yielded the best configuration.
+We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 semi-structured sparsity to mlp layer 1, and 2:4 sparsity to mlp layer 2 yielded the best configuration.
 We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIOU)**.

 The following benchmarks were run for [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast) ViT-h on an NVIDIA-A100-80GB, with batch_size=32 and `bfloat16` dtype, with `torch.compile="max_autotune"`:
diff --git a/scripts/sam/benchmark.sh b/scripts/sam/benchmark.sh
index 5c1262f9cc..c52ce33151 100755
--- a/scripts/sam/benchmark.sh
+++ b/scripts/sam/benchmark.sh
@@ -8,4 +8,3 @@ python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017
 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
 # int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
-
diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py
index e83ec25300..b9733bd98b 100644
--- a/scripts/sam/eval_combo.py
+++ b/scripts/sam/eval_combo.py
@@ -9,6 +9,10 @@
 import time
 import resource

+from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only
+from torchao.sparsity import sparsify_, apply_fake_sparsity, int8_dynamic_activation_int8_semi_sparse_weight, semi_sparse_weight
+from torchao.utils import unwrap_tensor_subclass
+
 torch._dynamo.config.cache_size_limit = 50000

 def 
unbind_jagged(device, data, sizes, offsets): @@ -279,30 +283,17 @@ def run( block.attn.use_rel_pos = use_rel_pos if compress == "int8_dynamic_quant": - from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight - from torchao.utils import unwrap_tensor_subclass quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight()) predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder) elif compress == "sparse_mlp_only": def mlp_only(mod, name): return isinstance(mod, torch.nn.Linear) and 'mlp' in name - from torchao.sparsity import sparsify - from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only) - predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only) + sparsify_(predictor.model.image_encoder, semi_sparse_weight(), filter_fn=mlp_only) elif compress == "sparse": - from torchao.sparsity import sparsify - from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity apply_fake_sparsity(predictor.model.image_encoder) - predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured) + sparsify_(predictor.model.image_encoder, semi_sparse_weight()) elif compress == "int8_dynamic_quant_sparse": - from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor - SparseSemiStructuredTensor._FORCE_CUTLASS = False - from torchao.sparsity import sparsify, apply_fake_sparsity - from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight - from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight - from torchao.utils import unwrap_tensor_subclass - def attn_only(mod, name): return isinstance(mod, torch.nn.Linear) and 'attn' in name def mlp_lin1_only(mod, name): @@ -316,20 +307,17 @@ def mlp_only(mod, name): apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only) - quantize_( - predictor.model.image_encoder, - int8_dynamic_activation_int8_weight(), - attn_only - ) + quantize_(predictor.model.image_encoder, + int8_dynamic_activation_int8_weight(), + attn_only) + quantize_(predictor.model.image_encoder, + int8_dynamic_activation_int8_semi_sparse_weight(), + mlp_lin1_only) + sparsify_(predictor.model.image_encoder, + semi_sparse_weight(), + mlp_lin2_only) predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder) - predictor.model.image_encoder = sparsify(predictor.model.image_encoder, - int8_dynamic_activation_int8_2x4_sparse_weight(), - mlp_lin1_only, prune=False) - - predictor.model.image_encoder = sparsify(predictor.model.image_encoder, - to_sparse_semi_structured, - mlp_lin2_only, prune=False) else: assert compress is None, f"Unsupported compress mode {compress}" @@ -413,6 +401,6 @@ def mlp_only(mod, name): vals = ",".join(map(str, [device, sam_model_type, batch_size, max_memory_allocated_bytes, max_memory_allocated_percentage, img_s, batch_ms_batch_size, mIoU, use_compile, use_half, compress, use_compile_decoder, use_rel_pos, pad_input_image_batch, num_workers, num_batches, num_images, profile_path, memory_path])) f.write(vals+"\n") - + if __name__ == '__main__': fire.Fire(run) diff --git a/scripts/sam/results.csv b/scripts/sam/results.csv index 01aad5c022..0be02c7f37 100644 --- a/scripts/sam/results.csv +++ b/scripts/sam/results.csv @@ -1,6 +1,6 @@ 
device,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,use_compile_decoder,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path -cuda,vit_h,32,15172,18,22.74609667033727,43.96358700541707,0.5811068585673369,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None -cuda,vit_h,32,15154,18,24.908711866303545,40.14659631407106,0.5822020528694204,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None -cuda,vit_h,32,15632,19,24.806623549763994,40.311814221468836,0.5671732654673084,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None -cuda,vit_h,32,13429,16,24.299052218005198,41.15386851422198,0.5305645705002248,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None -cuda,vit_h,32,14865,18,26.46342281926203,37.7880067453756,0.5668329259098808,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None +cuda,vit_h,32,15172,18,22.533401716616083,44.37856354651513,0.5812715827356921,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None +cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None +cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None +cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None +cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 3e566732bb..b846afa454 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -1,18 +1,24 @@ +import copy import logging import unittest import torch from torch import nn -from torch.sparse import to_sparse_semi_structured -from torchao.sparsity import apply_fake_sparsity, sparsify -from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight +from torchao.sparsity import ( + apply_fake_sparsity, + sparsify_, + int8_dynamic_activation_int8_semi_sparse_weight, + semi_sparse_weight, +) from torchao.quantization.quant_api import ( _replace_with_custom_fn_if_matches_filter, _get_subclass_inserter, _is_linear, + int8_dynamic_activation_int8_weight, + quantize_, ) -from torchao.utils import TORCH_VERSION_AFTER_2_3 +from torchao.utils import TORCH_VERSION_AFTER_2_3, unwrap_tensor_subclass from torch.testing._internal.common_utils import TestCase @@ -38,12 +44,11 @@ def test_sparse(self): apply_fake_sparsity(model) dense_result = model(input) - model = sparsify(model, to_sparse_semi_structured) + sparsify_(model, semi_sparse_weight()) sparse_result = model(input) assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) - class TestQuantSemiSparse(TestCase): @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature") @@ -58,15 +63,15 @@ def test_quant_semi_sparse(self): .half() .cuda() ) - apply_fake_sparsity(model) - dense_result = model(input) + model_copy = copy.deepcopy(model) + quantize_(model_copy, int8_dynamic_activation_int8_weight()) + dense_result = model_copy(input) - sparsify(model, 
int8_dynamic_activation_int8_2x4_sparse_weight()) + quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight()) sparse_result = model(input) - assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1) - + assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2) if __name__ == "__main__": unittest.main() diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 39372fe27f..e4b47b8229 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -7,6 +7,7 @@ to_affine_quantized_static, LayoutType, PlainLayoutType, + SemiSparseLayoutType, TensorCoreTiledLayoutType, ) @@ -19,5 +20,6 @@ "to_affine_quantized_static", "LayoutType", "PlainLayoutType", + "SemiSparseLayoutType", "TensorCoreTiledLayoutType", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index da5cc7d28b..807a588aed 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -31,6 +31,17 @@ class PlainLayoutType(LayoutType): pass +@dataclass(frozen=True) +class SemiSparseLayoutType(LayoutType): + + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + # prune to 2:4 if not already + temp = input.detach() + pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2] + temp.view(-1, 4).scatter_(1, pruning_inds, value=0) + return temp + + @dataclass(frozen=True) class TensorCoreTiledLayoutType(LayoutType): inner_k_tiles: int = 8 @@ -472,6 +483,47 @@ def from_plain( assert isinstance(layout_type, PlainLayoutType) return cls(int_data, scale, zero_point, layout_type) +@register_layout_cls(SemiSparseLayoutType) +class SemiSparseAQTLayout(PlainAQTLayout): + """ + Layout storage class for semi_sparse_cusparselt layout for affine quantized tensor + """ + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"SparseAQTLayout dispatch: attempting to run {func}, this is not supported" + ) + + def get_plain(self): + # Currently we don't have cuSPARSELt expansion routines, so we matmul by + # the identity matrix to get the original dense matrix. This is slow though. 
+ cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0]) + int_data_expanded = torch._cslt_sparse_mm(self.int_data, + torch.eye(cols, + dtype=self.int_data.dtype, + device=self.int_data.device).t()) + return int_data_expanded, self.scale, self.zero_point + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + layout_type: LayoutType, + ): + assert isinstance(layout_type, SemiSparseLayoutType) + int_data_compressed = torch._cslt_compress(int_data) + return cls(int_data_compressed, scale, zero_point, layout_type) + + @register_layout_cls(TensorCoreTiledLayoutType) class TensorCoreTiledAQTLayout(AQTLayout): """ @@ -668,6 +720,31 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias): if bias is not None: y += bias return y + # handle int8 dynamic_quant + semi_structured_sparse + elif( + is_cuda and + input_is_int8 and + input_tensor.dtype == weight_qtensor.dtype and + isinstance(input_tensor.layout_type, PlainLayoutType) and + isinstance(weight_qtensor.layout_type, SemiSparseLayoutType) + ): + x_vals_int8 = input_tensor.layout_tensor.int_data + x_scales = input_tensor.layout_tensor.scale + w_vals_int8 = weight_qtensor.layout_tensor.int_data + w_scales = weight_qtensor.layout_tensor.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm + y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( + w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16 + ).t() + y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( + *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] + ) + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y else: input_tensor = input_tensor.dequantize() diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index a1cf1bf034..6bf37f0080 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -32,6 +32,7 @@ "quantize_", "int8_dynamic_activation_int4_weight", "int8_dynamic_activation_int8_weight", + "int8_dynamic_activation_int8_semi_sparse_weight", "int4_weight_only", "int8_weight_only", ] diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3b02930c3c..161a84c4e4 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -14,13 +14,14 @@ come along with it and because that is how we access the intended quantized and mixed GEMM kernels """ - +from functools import partial import torch import torchao import torch.nn as nn import torch.nn.functional as F from typing import Any, Callable, Union, Dict, Optional +from torchao.dtypes import PlainLayoutType from torchao.utils import ( TORCH_VERSION_AFTER_2_4, unwrap_tensor_subclass, @@ -57,6 +58,7 @@ "quantize_", "int8_dynamic_activation_int4_weight", "int8_dynamic_activation_int8_weight", + "int8_dynamic_activation_int8_semi_sparse_weight", "int4_weight_only", "int8_weight_only", ] @@ -410,7 +412,8 @@ def apply_int8wo_quant(weight): return _get_linear_subclass_inserter(apply_int8wo_quant) -def int8_dynamic_activation_int8_weight(): + +def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()): """ Applies int8 dynamic symmetric per-token activation and int8 per-channel weight quantization to linear layers @@ -432,16 +435,31 @@ def get_weight_block_size(x): zero_point_dtype = torch.int64 # input settings + def get_per_token_block_size(x): + 
block_size = list(x.shape)
+        for i in range(len(block_size)-1):
+            block_size[i] = 1
+        return block_size
+
     input_mapping_type = MappingType.SYMMETRIC
     input_target_dtype = torch.int8
     input_eps = 1e-5
     input_quant_min = -127
     input_quant_max = 127
-    input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+    input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

     block_size = get_weight_block_size(weight)
-    weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+    weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
     weight = to_linear_act_quantized(weight, input_quant_func)
     return weight

     return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant)
+
+
+def int8_dynamic_activation_int8_semi_sparse_weight():
+    """
+    Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
+    quantization + 2:4 sparsity to linear layers.
+    """
+    from torchao.dtypes import SemiSparseLayoutType
+    return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())
diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py
index 9b288c07f9..c3b10f949a 100644
--- a/torchao/sparsity/__init__.py
+++ b/torchao/sparsity/__init__.py
@@ -6,11 +6,18 @@
 from .wanda import WandaSparsifier  # noqa: F403
 from .utils import PerChannelNormObserver  # noqa: F403
-from .sparse_api import apply_fake_sparsity, sparsify
+from .sparse_api import (
+    apply_fake_sparsity,
+    sparsify_,
+    semi_sparse_weight,
+    int8_dynamic_activation_int8_semi_sparse_weight
+)

 __all__ = [
     "WandaSparsifier",
     "PerChannelNormObserver",
     "apply_fake_sparsity",
-    "sparsify"
+    "sparsify_",
+    "semi_sparse_weight",
+    "int8_dynamic_activation_int8_semi_sparse_weight"
 ]
diff --git a/torchao/sparsity/prototype/dynamic_quant_sparse.py b/torchao/sparsity/prototype/dynamic_quant_sparse.py
deleted file mode 100644
index 2f2a198278..0000000000
--- a/torchao/sparsity/prototype/dynamic_quant_sparse.py
+++ /dev/null
@@ -1,314 +0,0 @@
-import torch
-import torch.nn as nn
-from typing import Tuple, Optional
-
-from torchao.quantization.utils import (
-    dynamically_quantize_per_channel,
-    quant_int8_dynamic_per_token_linear,
-    quantize_activation_per_token_absmax,
-    dequantize_per_channel,
-)
-
-from torchao.quantization.subclass import (
-    Int8DynamicallyQuantizedLinearWeight,
-    QuantizedLinearWeightBase,
-)
-
-from torch.sparse import to_sparse_semi_structured
-
-# Quant + Sparse helper functinos
-def sparse_quant_int8_dynamic_linear(
-    x : torch.Tensor,
-    w_vals_int8_packed : torch.Tensor,
-    w_meta_int32 : Optional[torch.Tensor],
-    w_scales : torch.Tensor,
-    bias : Optional[torch.Tensor],
-    out_dtype : torch.dtype,
-    fuse_mul=False,
-):
-    x_vals_int8, x_scales = quantize_activation_per_token_absmax(x)
-    # w_meta_int32 is either None or meta tensor
-    if w_meta_int32 is None:
-        if fuse_mul:
-            mm_out = sparse_quant_int8_cslt_matmul_fuse_mul(
-                x_vals_int8, x_scales, w_vals_int8_packed, w_scales, out_dtype,
-            )
-        else:
-            mm_out = sparse_quant_int8_cslt_matmul(
x_vals_int8, x_scales, w_vals_int8_packed, w_scales, out_dtype, - ) - else: - mm_out = sparse_quant_int8_cutlass_matmul( - x_vals_int8, x_scales, w_vals_int8_packed, w_meta_int32, w_scales, out_dtype, - ) - - if bias is not None: - mm_out += bias - return mm_out - -def sparse_quant_int8_cslt_matmul_fuse_mul( - x_vals_int8, - x_scales, - w_vals_int8, - w_scales, - out_dtype, -): - - assert ( - x_vals_int8.dtype == torch.int8 - ), f"x dtype {x_vals_int8.dtype} not yet supported" - assert ( - w_vals_int8.dtype == torch.int8 - ), f"w dtype {w_vals_int8.dtype} not yet supported" - # assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' - - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() - - assert x_scales.dtype in [ - torch.float, - torch.bfloat16, - ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" - - y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( - w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16 - ).t() - y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( - *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] - ) - y = y.to(out_dtype) - - return y - -def sparse_quant_int8_cslt_matmul( - x_vals_int8, - x_scales, - w_vals_int8, - w_scales, - out_dtype, -): - - assert ( - x_vals_int8.dtype == torch.int8 - ), f"x dtype {x_vals_int8.dtype} not yet supported" - assert ( - w_vals_int8.dtype == torch.int8 - ), f"w dtype {w_vals_int8.dtype} not yet supported" - # assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' - - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() - - assert x_scales.dtype in [ - torch.float, - torch.bfloat16, - ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" - - y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( - w_vals_int8, tmp.t(), out_dtype=torch.bfloat16 - ).t() - y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1) * w_scales).reshape( - *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] - ) - y = y.to(out_dtype) - - return y - - -def sparse_quant_int8_cutlass_matmul( - x_vals_int8, - x_scales, - w_vals_int8, - w_meta_int32, - w_scales, - out_dtype, -): - assert ( - x_vals_int8.dtype == torch.int8 - ), f"x dtype {x_vals_int8.dtype} not yet supported" - assert ( - w_vals_int8.dtype == torch.int8 - ), f"w dtype {w_vals_int8.dtype} not yet supported" - assert w_scales.dtype == out_dtype, f"{w_scales.dtype} does not match {out_dtype}" - assert w_meta_int32.dtype == torch.int32, f"{w_meta_int32.dtype} not yet supported" - - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() - - assert x_scales.dtype in [ - torch.float, - torch.bfloat16, - ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" - - y_dot_int32 = torch._sparse_semi_structured_linear( - tmp, w_vals_int8, w_meta_int32.view(torch.int32), out_dtype=torch.int32 - ) - y = (y_dot_int32 * x_scales.reshape(-1, 1) * w_scales).reshape( - *x_vals_int8.shape[:-1], y_dot_int32.shape[-1] - ) - y = y.to(out_dtype) - return y - -class Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight( - Int8DynamicallyQuantizedLinearWeight -): - def dequantize(self, dtype=None): - # overload dequantize op for __repr__ - zero_points = torch.zeros(self.q_scales.shape, device=self.q_scales.device, dtype=self.q_scales.dtype) - int_data_expanded = torch._cslt_sparse_mm(self.int_data, torch.eye(self.shape[1], - dtype=self.int_data.dtype, - 
device=self.int_data.device)) - dq_t = dequantize_per_channel( - int_data_expanded, self.q_scales, zero_points, self.dtype if dtype is None else dtype - ).to(self.dtype) - - return dq_t if not self.transposed else dq_t.t() - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - return sparse_quant_int8_dynamic_linear( - act_mat, w_qtensor.int_data, None, w_qtensor.q_scales, bias, act_mat.dtype, - fuse_mul=True - ) - - @classmethod - def from_float(cls, input_float, qmin=-128, qmax=127): - - assert input_float.is_cuda - - w_int_repr, w_scales, _ = dynamically_quantize_per_channel( - input_float, qmin, qmax, torch.int8 - ) - - int_data = w_int_repr.contiguous() - int_data = torch._cslt_compress(int_data) - - return cls( - int_data, - w_scales, - False, - input_float.shape, - dtype=input_float.dtype, - ) - - -class Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight(QuantizedLinearWeightBase): - - @staticmethod - def __new__(cls, int_data, mask_meta, q_scales, transposed, shape, **kwargs): - kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype) - return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined] - - def __init__(self, int_data, mask_meta, q_scales, transposed, shape, **kwargs): - self.q_scales = q_scales - self.mask_meta = mask_meta - super().__init__(int_data, transposed) - - def dequantize(self, dtype=None): - """ - Obtain the dequantized version of the quantized tensor subclass - """ - dq_t = dequantize_per_channel( - self.int_data, self.q_scales, 0, self.dtype if dtype is None else dtype - ).to(self.dtype) - # data was transposed to dequantize so make sure shape is correct - return dq_t if not self.transposed else dq_t.t() - - def int_repr(self): - """ - Get the internal integer representation of the quantized tensor - """ - return self.int_data if self.transposed else self.int_data.t() - - def q_params(self): - """ - Get the quantization scales for the quantized tensor - """ - return {"q_scales": self.q_scales} - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.int_data.to(kwargs["device"]), - self.mask_meta.to(kwargs["device"]), - self.q_scales.to(kwargs["device"]), - self.transposed, - self.shape, - **kwargs, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.int_data), - fn(self.mask_meta), - fn(self.q_scales), - self.transposed, - self.shape, - dtype=self.dtype, - ) - - def _change_shape(self, shape): - return self.__class__( - self.int_data, - self.mask_meta, - self.q_scales, - self.transposed, - shape, - dtype=self.dtype, - ) - - def __tensor_flatten__(self): - return ["int_data", "mask_meta", "q_scales"], [ - self.transposed, - self.dtype, - self.shape, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None - ): - int_data, q_scales = tensor_data_dict["int_data"], tensor_data_dict["q_scales"] - mask_meta = tensor_data_dict["mask_meta"] - transposed, dtype, shape = tensor_attributes - return cls( - int_data, - mask_meta, - q_scales, - transposed, - shape if outer_size is None else outer_size, - dtype=dtype, - strides=outer_stride, - ) - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - return sparse_quant_int8_dynamic_linear( - act_mat, - w_qtensor.int_data, - w_qtensor.mask_meta, - w_qtensor.q_scales, - bias, - act_mat.dtype, - ) - - @classmethod - def from_float(cls, input_float, qmin=-128, qmax=127): - - assert input_float.is_cuda - - 
w_int_repr, w_scales, _ = dynamically_quantize_per_channel(
-            input_float, qmin, qmax, torch.int8
-        )
-
-        int_data = w_int_repr.contiguous()
-        sparse_tensor = to_sparse_semi_structured(int_data)
-
-        return cls(
-            sparse_tensor.packed,
-            sparse_tensor.meta,
-            w_scales,
-            False,
-            input_float.shape,
-            dtype=input_float.dtype,
-        )
-
-def int8_dynamic_activation_int8_2x4_sparse_weight():
-    return Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float
diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py
index 8f8ca24a39..a12d954422 100644
--- a/torchao/sparsity/sparse_api.py
+++ b/torchao/sparsity/sparse_api.py
@@ -7,6 +7,7 @@
     _is_linear,
     _replace_with_custom_fn_if_matches_filter,
     _get_linear_subclass_inserter,
+    int8_dynamic_activation_int8_semi_sparse_weight,
 )

 # Sparsity helper functions
@@ -29,16 +30,21 @@ def apply_fake_sparsity(model, **kwargs):
     sparsifier.step()
     sparsifier.squash_mask()

+def semi_sparse_weight():
+    """
+    Convert the weight of linear modules to semi-structured (2:4) sparsity
+    """
+    return _get_linear_subclass_inserter(to_sparse_semi_structured)

-def sparsify(model: torch.nn.Module,
+def sparsify_(model: torch.nn.Module,
              apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
              filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None) -> torch.nn.Module:
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`
     This function is essentially the same as quantize_, but for sparsity subclasses.

     Currently, we support two options for sparsity:
-        - semi-structured (2:4) sparsity with `to_sparse_semi_structured`
-        - int8 dynamic quantization + 2:4 sparsity with `int8_dynamic_activation_int8_2x4_sparse_weight`, which is also available via the quantize API
+        - semi-structured (2:4) sparsity with `semi_sparse_weight`
+        - int8 dynamic quantization + 2:4 sparsity with `int8_dynamic_activation_int8_semi_sparse_weight`, which is also available via the quantize_ API

     Args:
         model (torch.nn.Module): input model
@@ -49,7 +55,7 @@ def sparsify(model: torch.nn.Module,
     Example::
         import torch
         import torch.nn as nn
-        from torchao.sparsity import sparsify
+        from torchao.sparsity import sparsify_

         def filter_fn(module: nn.Module, fqn: str) -> bool:
             return isinstance(module, nn.Linear)
@@ -57,17 +63,15 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))

         # for 2:4 sparsity
-        from torch.sparse import to_sparse_semi_structured
-        m = sparsify(m, to_sparse_semi_structured, filter_fn)
+        from torchao.sparsity import semi_sparse_weight
+        m = sparsify_(m, semi_sparse_weight(), filter_fn)

         # for int8 dynamic quantization + 2:4 sparsity
-        from torchao.sparsity.prototype import int8_dynamic_activation_int8_2x4_sparse_weight
-        m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight(), filter_fn)
+        from torchao.sparsity import int8_dynamic_activation_int8_semi_sparse_weight
+        m = sparsify_(m, int8_dynamic_activation_int8_semi_sparse_weight(), filter_fn)
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
-        _get_linear_subclass_inserter(apply_tensor_subclass),
+        apply_tensor_subclass,
         _is_linear if filter_fn is None else filter_fn,
     )
-
-    return model
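
For reviewers skimming the patch, here is a minimal end-to-end sketch of the API surface it lands (`sparsify_`, `semi_sparse_weight`, and the `quantize_`-routed int8 + 2:4 path). The toy two-layer model, its sizes, and the fp16 dtype are illustrative assumptions mirroring `test_sparse_api.py`; like the tests, it assumes a CUDA GPU with 2:4 sparse kernel support (e.g. an A100).

```python
# Hypothetical toy model; only the torchao calls below come from this patch.
import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.sparsity import (
    apply_fake_sparsity,
    int8_dynamic_activation_int8_semi_sparse_weight,
    semi_sparse_weight,
    sparsify_,
)

# 2:4 sparse kernels need CUDA and fp16/bf16 weights, as in test_sparse_api.py.
model = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128)).half().cuda()
x = torch.rand((128, 128), dtype=torch.float16, device="cuda")

# Option 1: weight-only 2:4 sparsity.
# apply_fake_sparsity magnitude-prunes each nn.Linear weight to a 2:4 pattern
# in-place; sparsify_ then swaps the pruned weights for semi-structured sparse tensors.
apply_fake_sparsity(model)
sparsify_(model, semi_sparse_weight())
out_sparse = model(x)

# Option 2: int8 dynamic activation quantization + 2:4 sparse weight, now routed
# through quantize_ since the sparse format is just a layout (SemiSparseLayoutType)
# of the int8 affine-quantized weight; its pre_process prunes to 2:4 if needed.
model2 = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128)).half().cuda()
quantize_(model2, int8_dynamic_activation_int8_semi_sparse_weight())
out_quant_sparse = model2(x)

# Scripts that torch.compile the model (e.g. eval_combo.py above) also call
# torchao.utils.unwrap_tensor_subclass(model) before compiling.
```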