Implement sparsity as an AQT Layout #498
@@ -1,6 +1,6 @@
 device,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,use_compile_decoder,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path
-cuda,vit_h,32,15172,18,22.74609667033727,43.96358700541707,0.5811068585673369,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,15154,18,24.908711866303545,40.14659631407106,0.5822020528694204,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,15632,19,24.806623549763994,40.311814221468836,0.5671732654673084,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,13429,16,24.299052218005198,41.15386851422198,0.5305645705002248,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
-cuda,vit_h,32,14865,18,26.46342281926203,37.7880067453756,0.5668329259098808,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15172,18,22.533401716616083,44.37856354651513,0.5812715827356921,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
+cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
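For context, a quick way to read the SAM `vit_h` results above is to compare each `compress` mode's throughput against the uncompressed baseline row. This is only an illustrative sketch, not part of the PR; the file path and the pandas dependency are assumptions.

```python
import pandas as pd

# Hypothetical path; substitute the actual location of the benchmark results CSV.
df = pd.read_csv("results.csv")

# The baseline row stores the literal string "None" in the compress column,
# which pandas may parse as NaN, so normalize it before filtering.
df["compress"] = df["compress"].fillna("None")

baseline_img_s = df.loc[df["compress"] == "None", "img_s(avg)"].iloc[0]

# Throughput of each compression mode relative to the bfloat16 baseline.
df["speedup_vs_baseline"] = df["img_s(avg)"] / baseline_img_s
print(df[["compress", "memory(MiB)", "img_s(avg)", "mIoU", "speedup_vs_baseline"]])
```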
@@ -14,13 +14,14 @@
 come along with it and because that is how we access the intended quantized
 and mixed GEMM kernels
 """

 from functools import partial
 import torch
 import torchao
 import torch.nn as nn
 import torch.nn.functional as F
 from typing import Any, Callable, Union, Dict, Optional

+from torchao.dtypes import PlainLayoutType
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_4,
     unwrap_tensor_subclass,
@@ -57,6 +58,7 @@
     "quantize_",
     "int8_dynamic_activation_int4_weight",
     "int8_dynamic_activation_int8_weight",
+    "int8_dynamic_activation_int8_semi_sparse_weight",
     "int4_weight_only",
     "int8_weight_only",
 ]
@@ -410,7 +412,8 @@ def apply_int8wo_quant(weight):

     return _get_linear_subclass_inserter(apply_int8wo_quant)

-def int8_dynamic_activation_int8_weight():
+
+def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()):
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization to linear layers
@@ -432,16 +435,31 @@ def get_weight_block_size(x):
         zero_point_dtype = torch.int64

         # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size)-1):
+                block_size[i] = 1
+            return block_size
+
         input_mapping_type = MappingType.SYMMETRIC
         input_target_dtype = torch.int8
         input_eps = 1e-5
         input_quant_min = -127
         input_quant_max = 127
-        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
         weight = to_linear_act_quantized(weight, input_quant_func)
         return weight

     return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant)

+
+def int8_dynamic_activation_int8_semi_sparse_weight():
+    """
+    Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
+    quantization + 2:4 sparsity to linear layers.
+    """
+    from torchao.dtypes import SemiSparseLayoutType
+    return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())
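To make the per-token block size concrete: the helper added above keeps only the last dimension of the activation as the quantization block, so every token gets its own scale. A standalone sketch of the same logic (the example shapes are illustrative, not from the PR):

```python
import torch

def get_per_token_block_size(x):
    # Mirrors the helper in the diff: every dimension except the last is
    # collapsed to 1, so quantization parameters are computed per token.
    block_size = list(x.shape)
    for i in range(len(block_size) - 1):
        block_size[i] = 1
    return block_size

x = torch.randn(2, 16, 4096)        # (batch, seq_len, hidden)
print(get_per_token_block_size(x))  # [1, 1, 4096] -> one scale per token
```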
Review discussion on `int8_dynamic_activation_int8_semi_sparse_weight`:

- "does this have similar config as […]?"
- "sorry, I meant that we could remove this, and just use […]"
- "Oh I see, yeah that sounds good to me too."

Review discussion on the new function's name:

- "does `semi_sparse_weight` have to talk about dtype as well?"
- "it will work for bf16, fp16, and fp32, so I don't think specifying the dtype makes sense. Maybe `dense_activation_semi_sparse_weight` to keep it consistent?"
- "I see, then it's fine. We have `int4_weight_only()` as well, so I feel it's fine that we don't mention activation (we could remove `only` as well)."
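For readers who want to see the end-to-end API this PR adds, here is a minimal usage sketch. It assumes a CUDA machine with semi-structured (2:4) sparse kernel support, that the linear weights have already been pruned to a 2:4 pattern, and a toy model that is purely illustrative; only the imported names come from the diff above.

```python
import torch
import torch.nn as nn
from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int8_weight,
    int8_dynamic_activation_int8_semi_sparse_weight,
)
from torchao.dtypes import SemiSparseLayoutType

# Toy model; in the benchmark above this would be SAM's vit_h image encoder.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
model = model.cuda().to(torch.bfloat16)

# NOTE: the 2:4 layout assumes the weights were pruned to that pattern
# beforehand; this sketch only exercises the quantization API surface.

# Option 1: the convenience config added in this PR.
quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())

# Option 2 (equivalent, per the review discussion): pass the layout explicitly.
# quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))

x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    out = model(x)
```

The review thread above suggests the explicit `layout_type=SemiSparseLayoutType()` form may eventually replace the dedicated helper, which is why both spellings are shown.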