
Implement sparsity as an AQT Layout #498


Merged (23 commits) on Jul 26, 2024
12 changes: 5 additions & 7 deletions README.md
@@ -49,20 +49,18 @@ And a quick crash course on inference quantization to help parse the above table

Sparsifying your model is also a one-liner that should work on any model with an `nn.Linear`. We find that sparsity works best on compute-bound models like SAM, specifically the MLP layers.
```diff
- from torchao.sparsity import sparsify
- from torch.sparse import to_sparse_semi_structured
+ from torchao.sparsity import sparsify_, semi_sparse_weight

- m = sparsify(m, to_sparse_semi_structured)
+ m = sparsify_(m, semi_sparse_weight())
```
Contributor:

does semi_sparse_weight have to talk about dtype as well?

Contributor Author:

it will work for bf16, fp16, and fp32, so i don't think specifying the dtype makes sense. Maybe dense_activation_semi_sparse_weight to keep it consistent?

Contributor:

I see, then it's fine. we have int4_weight_only() as well so I feel it's fine that we don't mention activation

(we could remove only as well)
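
For context, here is a minimal end-to-end sketch of the new API; the toy model, shapes, and dtype are illustrative, mirroring the pattern used in `test/sparsity/test_sparse_api.py` further down:

```python
import torch
from torch import nn

from torchao.sparsity import apply_fake_sparsity, sparsify_, semi_sparse_weight

# toy two-layer model; the semi-structured sparse kernels expect fp16/bf16/fp32 weights on CUDA
model = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128)).half().cuda()
x = torch.rand(64, 128, dtype=torch.float16, device="cuda")

apply_fake_sparsity(model)              # magnitude-prune each weight to the 2:4 pattern
sparsify_(model, semi_sparse_weight())  # swap the pruned weights for semi-structured sparse tensors
out = model(x)
```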

Sparsity can also be composed with int8 dynamic quantization for further speedups:

```diff
- from torchao.sparsity import sparsify
- from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
+ from torchao.sparsity import sparsify_, int8_dynamic_activation_int8_semi_sparse_weight

- m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight())
+ m = sparsify_(m, int8_dynamic_activation_int8_semi_sparse_weight())
```
We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + semi-sparse (2:4) sparsity to MLP layer 1, and 2:4 sparsity to MLP layer 2 yielded the best configuration.
We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIOU)**.
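
A sketch of that mixed configuration, using `filter_fn` callbacks to pick out linears by module name; the name substrings and the `model` placeholder are assumptions, and `scripts/sam/eval_combo.py` below shows the exact SAM setup:

```python
import torch

from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.sparsity import (
    apply_fake_sparsity,
    sparsify_,
    int8_dynamic_activation_int8_semi_sparse_weight,
    semi_sparse_weight,
)

# `model` stands in for the SAM image encoder (any module with attn/mlp Linear layers)
def attn_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and "attn" in name

def mlp_lin1_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and "mlp" in name and "lin1" in name

def mlp_lin2_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and "mlp" in name and "lin2" in name

def mlp_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and "mlp" in name

apply_fake_sparsity(model, filter_fn=mlp_only)                                       # prune MLP weights to 2:4
quantize_(model, int8_dynamic_activation_int8_weight(), attn_only)                   # int8 dynamic quant: attention
quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight(), mlp_lin1_only)   # int8 + 2:4: MLP lin1
sparsify_(model, semi_sparse_weight(), mlp_lin2_only)                                # 2:4 only: MLP lin2
```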

The following benchmarks were run for [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast) ViT-h on an NVIDIA-A100-80GB, with batch_size=32, `bfloat16` dtype, and `torch.compile="max_autotune"`:
1 change: 0 additions & 1 deletion scripts/sam/benchmark.sh
@@ -8,4 +8,3 @@ python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
# int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse

44 changes: 16 additions & 28 deletions scripts/sam/eval_combo.py
@@ -9,6 +9,10 @@
import time
import resource

from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only
from torchao.sparsity import sparsify_, apply_fake_sparsity, int8_dynamic_activation_int8_semi_sparse_weight, semi_sparse_weight
from torchao.utils import unwrap_tensor_subclass

torch._dynamo.config.cache_size_limit = 50000

def unbind_jagged(device, data, sizes, offsets):
@@ -279,30 +283,17 @@ def run(
block.attn.use_rel_pos = use_rel_pos

if compress == "int8_dynamic_quant":
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass
quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
elif compress == "sparse_mlp_only":
def mlp_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'mlp' in name
from torchao.sparsity import sparsify
from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity
apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only)
predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only)
sparsify_(predictor.model.image_encoder, semi_sparse_weight(), filter_fn=mlp_only)
elif compress == "sparse":
from torchao.sparsity import sparsify
from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity
apply_fake_sparsity(predictor.model.image_encoder)
predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured)
sparsify_(predictor.model.image_encoder, semi_sparse_weight())
elif compress == "int8_dynamic_quant_sparse":
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
SparseSemiStructuredTensor._FORCE_CUTLASS = False
from torchao.sparsity import sparsify, apply_fake_sparsity
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass

def attn_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'attn' in name
def mlp_lin1_only(mod, name):
@@ -316,20 +307,17 @@ def mlp_only(mod, name):
apply_fake_sparsity(predictor.model.image_encoder,
filter_fn=mlp_only)

quantize_(
predictor.model.image_encoder,
int8_dynamic_activation_int8_weight(),
attn_only
)
quantize_(predictor.model.image_encoder,
int8_dynamic_activation_int8_weight(),
attn_only)
quantize_(predictor.model.image_encoder,
int8_dynamic_activation_int8_semi_sparse_weight(),
mlp_lin1_only)
sparsify_(predictor.model.image_encoder,
semi_sparse_weight(),
mlp_lin2_only)
predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)

predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
int8_dynamic_activation_int8_2x4_sparse_weight(),
mlp_lin1_only, prune=False)

predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
to_sparse_semi_structured,
mlp_lin2_only, prune=False)
else:
assert compress is None, f"Unsupported compress mode {compress}"

@@ -413,6 +401,6 @@ def mlp_only(mod, name):
vals = ",".join(map(str, [device, sam_model_type, batch_size, max_memory_allocated_bytes, max_memory_allocated_percentage, img_s, batch_ms_batch_size, mIoU, use_compile,
use_half, compress, use_compile_decoder, use_rel_pos, pad_input_image_batch, num_workers, num_batches, num_images, profile_path, memory_path]))
f.write(vals+"\n")

if __name__ == '__main__':
fire.Fire(run)
10 changes: 5 additions & 5 deletions scripts/sam/results.csv
@@ -1,6 +1,6 @@
device,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,use_compile_decoder,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path
cuda,vit_h,32,15172,18,22.74609667033727,43.96358700541707,0.5811068585673369,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15154,18,24.908711866303545,40.14659631407106,0.5822020528694204,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15632,19,24.806623549763994,40.311814221468836,0.5671732654673084,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
cuda,vit_h,32,13429,16,24.299052218005198,41.15386851422198,0.5305645705002248,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,14865,18,26.46342281926203,37.7880067453756,0.5668329259098808,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15172,18,22.533401716616083,44.37856354651513,0.5812715827356921,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
27 changes: 16 additions & 11 deletions test/sparsity/test_sparse_api.py
@@ -1,18 +1,24 @@
import copy
import logging
import unittest

import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured

from torchao.sparsity import apply_fake_sparsity, sparsify
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
from torchao.sparsity import (
apply_fake_sparsity,
sparsify_,
int8_dynamic_activation_int8_semi_sparse_weight,
semi_sparse_weight,
)
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
_get_subclass_inserter,
_is_linear,
int8_dynamic_activation_int8_weight,
quantize_,
)
from torchao.utils import TORCH_VERSION_AFTER_2_3
from torchao.utils import TORCH_VERSION_AFTER_2_3, unwrap_tensor_subclass
from torch.testing._internal.common_utils import TestCase


@@ -38,12 +44,11 @@ def test_sparse(self):
apply_fake_sparsity(model)
dense_result = model(input)

model = sparsify(model, to_sparse_semi_structured)
sparsify_(model, semi_sparse_weight())
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)


class TestQuantSemiSparse(TestCase):

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature")
@@ -58,15 +63,15 @@ def test_quant_semi_sparse(self):
.half()
.cuda()
)

apply_fake_sparsity(model)
dense_result = model(input)
model_copy = copy.deepcopy(model)
quantize_(model_copy, int8_dynamic_activation_int8_weight())
dense_result = model_copy(input)

sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight())
quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)

assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)

if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -7,6 +7,7 @@
to_affine_quantized_static,
LayoutType,
PlainLayoutType,
SemiSparseLayoutType,
TensorCoreTiledLayoutType,
)

@@ -19,5 +20,6 @@
"to_affine_quantized_static",
"LayoutType",
"PlainLayoutType",
"SemiSparseLayoutType",
"TensorCoreTiledLayoutType",
]
77 changes: 77 additions & 0 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -31,6 +31,17 @@
class PlainLayoutType(LayoutType):
pass

@dataclass(frozen=True)
class SemiSparseLayoutType(LayoutType):

def pre_process(self, input: torch.Tensor) -> torch.Tensor:
# prune to 2:4 if not already
temp = input.detach()
pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2]
temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
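# e.g. the row [0.1, -3.0, 0.2, 4.0] has |.|-argsort [0, 2, 1, 3], so indices 0 and 2 are
# zeroed, giving [0.0, -3.0, 0.0, 4.0]: the 2:4 pattern that the cuSPARSELt compression in
# SemiSparseAQTLayout.from_plain below expects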
return temp


@dataclass(frozen=True)
class TensorCoreTiledLayoutType(LayoutType):
inner_k_tiles: int = 8
@@ -472,6 +483,47 @@ def from_plain(
assert isinstance(layout_type, PlainLayoutType)
return cls(int_data, scale, zero_point, layout_type)

@register_layout_cls(SemiSparseLayoutType)
class SemiSparseAQTLayout(PlainAQTLayout):
"""
Layout storage class for semi_sparse_cusparselt layout for affine quantized tensor
"""
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs

if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

raise NotImplementedError(
f"SemiSparseAQTLayout dispatch: attempting to run {func}, this is not supported"
)

def get_plain(self):
# Currently we don't have cuSPARSELt expansion routines, so we matmul by
# the identity matrix to get the original dense matrix. This is slow though.
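# (the 16/10 factor is presumably the inverse of cuSPARSELt's int8 compression ratio:
# half the values are kept plus roughly 2 bits of metadata per element, about 10/16 of
# the dense numel, so dividing by the row count recovers the dense column count)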
cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0])
int_data_expanded = torch._cslt_sparse_mm(self.int_data,
torch.eye(cols,
dtype=self.int_data.dtype,
device=self.int_data.device).t())
return int_data_expanded, self.scale, self.zero_point

@classmethod
def from_plain(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
layout_type: LayoutType,
):
assert isinstance(layout_type, SemiSparseLayoutType)
int_data_compressed = torch._cslt_compress(int_data)
return cls(int_data_compressed, scale, zero_point, layout_type)


@register_layout_cls(TensorCoreTiledLayoutType)
class TensorCoreTiledAQTLayout(AQTLayout):
"""
@@ -668,6 +720,31 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
if bias is not None:
y += bias
return y
# handle int8 dynamic_quant + semi_structured_sparse
elif(
is_cuda and
input_is_int8 and
input_tensor.dtype == weight_qtensor.dtype and
isinstance(input_tensor.layout_type, PlainLayoutType) and
isinstance(weight_qtensor.layout_type, SemiSparseLayoutType)
):
x_vals_int8 = input_tensor.layout_tensor.int_data
x_scales = input_tensor.layout_tensor.scale
w_vals_int8 = weight_qtensor.layout_tensor.int_data
w_scales = weight_qtensor.layout_tensor.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
# we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
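# i.e. y[t, i] = x_scale[t] * w_scale[i] * dot(x_int8[t], w_int8[i]); the per-channel
# w_scales ride along as `alpha` inside the sparse mm, leaving only the per-token
# x_scales multiply below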
y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16
).t()
y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
*x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
)
output_dtype = input_tensor.dtype
y = y.to(output_dtype)
if bias is not None:
y += bias
return y
else:
input_tensor = input_tensor.dequantize()

1 change: 1 addition & 0 deletions torchao/quantization/__init__.py
@@ -32,6 +32,7 @@
"quantize_",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int8_dynamic_activation_int8_semi_sparse_weight",
"int4_weight_only",
"int8_weight_only",
]
26 changes: 22 additions & 4 deletions torchao/quantization/quant_api.py
@@ -14,13 +14,14 @@
come along with it and because that is how we access the intended quantized
and mixed GEMM kernels
"""

from functools import partial
import torch
import torchao
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Union, Dict, Optional

from torchao.dtypes import PlainLayoutType
from torchao.utils import (
TORCH_VERSION_AFTER_2_4,
unwrap_tensor_subclass,
@@ -57,6 +58,7 @@
"quantize_",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int8_dynamic_activation_int8_semi_sparse_weight",
"int4_weight_only",
"int8_weight_only",
]
@@ -410,7 +412,8 @@ def apply_int8wo_quant(weight):

return _get_linear_subclass_inserter(apply_int8wo_quant)

def int8_dynamic_activation_int8_weight():

def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()):
"""
Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
quantization to linear layers
@@ -432,16 +435,31 @@ def get_weight_block_size(x):
zero_point_dtype = torch.int64

# input settings
def get_per_token_block_size(x):
block_size = list(x.shape)
for i in range(len(block_size)-1):
block_size[i] = 1
return block_size
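# e.g. an activation of shape (batch, seq_len, hidden) gets block_size [1, 1, hidden],
# i.e. one scale/zero-point per token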

input_mapping_type = MappingType.SYMMETRIC
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

block_size = get_weight_block_size(weight)
weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
weight = to_linear_act_quantized(weight, input_quant_func)
return weight

return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant)


def int8_dynamic_activation_int8_semi_sparse_weight():
Contributor:

does this have similar config as int8_dynamic_activation_int8_weight? if so we can add a layout_type arg to that function directly

Contributor:

sorry I meant that we could remove this, and just use int8_dynamic_activation_int8_weight for sparsity as well

Contributor Author:

Oh I see, yeah that sounds good to me too.

"""
Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
quantization + 2:4 sparsity to linear layers.
"""
from torchao.dtypes import SemiSparseLayoutType
return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())
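
Equivalently, per the function body above, the layout can be passed directly; a short sketch (`model` is a placeholder for any module with `nn.Linear` layers):

```python
from torchao.dtypes import SemiSparseLayoutType
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

# same effect as int8_dynamic_activation_int8_semi_sparse_weight()
quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))
```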
11 changes: 9 additions & 2 deletions torchao/sparsity/__init__.py
@@ -6,11 +6,18 @@

from .wanda import WandaSparsifier # noqa: F403
from .utils import PerChannelNormObserver # noqa: F403
from .sparse_api import apply_fake_sparsity, sparsify
from .sparse_api import (
apply_fake_sparsity,
sparsify_,
semi_sparse_weight,
int8_dynamic_activation_int8_semi_sparse_weight
)

__all__ = [
"WandaSparsifier",
"PerChannelNormObserver",
"apply_fake_sparsity",
"sparsify_",
"semi_sparse_weight",
"int8_dynamic_activation_int8_semi_sparse_weight",
]