@@ -34,6 +34,7 @@
 from .modules import WeightOnlyLinear

 DEBUG = False
+accelerator = auto_detect_accelerator()


 # ================ device related ===================
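Note: auto_detect_accelerator() is not defined in this file; it comes from the package's accelerator utilities. A minimal sketch of the interface the patch relies on, assuming only that mark_step() must be a no-op off HPU (the class names and import-probe below are illustrative, not the library's implementation):

    class CPUAccelerator:
        def current_device_name(self):
            return "cpu"

        def mark_step(self):
            # Nothing to flush under eager CPU/CUDA execution.
            pass

    def auto_detect_accelerator():
        # Prefer HPU when the Habana torch bridge is importable.
        try:
            import habana_frameworks.torch.core as htcore

            class HPUAccelerator:
                def current_device_name(self):
                    return "hpu"

                def mark_step(self):
                    htcore.mark_step()  # flush the lazily accumulated graph

            return HPUAccelerator()
        except ImportError:
            return CPUAccelerator()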
@@ -542,8 +543,10 @@ def forward(layer, *args, **kwargs):
        if self.run_fn:
            if self.run_args:
                self.run_fn(self.model, *self.run_args)
+                accelerator.mark_step()
            else:
                self.run_fn(self.model)
+                accelerator.mark_step()
        else:
            for batch in tqdm(self.dataloader):
                if not self.use_layer_wise:
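Note: on Gaudi the PyTorch bridge defaults to lazy mode, where ops are only queued until a step marker (or a host read) forces compilation and execution; without mark_step() the whole calibration pass would be traced as one oversized graph. The control flow above, isolated as a sketch (the calibrate name is hypothetical; assumes the accelerator shim sketched earlier):

    def calibrate(model, run_fn, run_args, accelerator):
        if run_args:
            run_fn(model, *run_args)
        else:
            run_fn(model)
        # No-op on CPU/CUDA; on HPU, compile and execute the queued graph.
        accelerator.mark_step()

The next hunk applies the same marker inside the replay loop, bounding each traced graph to a single batch.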
@@ -663,6 +666,7 @@ def tmp(_, inp, out):
            for j in range(batch_num):
                cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
                cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
+                accelerator.mark_step()
                out = transformer_block(*cache_positional_batch, **cache_keyword_batch)
                out = self.track_hidden_states(out)
            self.cache_key_arguments["batch_num"] = batch_num
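Note: this loop replays calibration batch j through one transformer block from inputs cached earlier. A hypothetical re-implementation of the two gather helpers, mirroring their apparent contract (the real ones are methods defined elsewhere in this class):

    def gather_single_batch_from_dict(cache_dict, j):
        # {"attention_mask": [batch0, batch1, ...]} -> {"attention_mask": batch_j}
        return {k: v[j] for k, v in cache_dict.items() if isinstance(v, list)}

    def gather_single_batch_from_list(cache_list, j):
        # [[arg0_b0, arg0_b1, ...], [arg1_b0, ...]] -> [arg0_bj, arg1_bj, ...]
        return [cached_arg[j] for cached_arg in cache_list]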
@@ -682,6 +686,9 @@ def tmp(_, inp, out):
                    W = load_value(self.model, full_layer_name + ".weight", model_path)
                else:
                    W = sub_layers[layer_name].weight.data.clone()
+                accelerator.mark_step()
+                if "hpu" in self.device:
+                    W = W.to("cpu")
                scale, zp, Q = gptq_for_this_block[layer_name].fasterquant(
                    W,
                    blocksize=weight_config_this_layer["block_size"],
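Note: the commit does not state why W is pulled to the CPU on HPU targets; a plausible reading is that fasterquant's solver-style, heavily in-place math maps poorly onto HPU lazy mode, so it runs on host and the results are moved back (see the later hunks). The round-trip as a hypothetical wrapper:

    def offload_quantize(W, device, quant_fn):
        # quant_fn stands in for fasterquant(); names are illustrative.
        if "hpu" in str(device):
            W = W.to("cpu")  # quantize on host
        scale, zero, Q = quant_fn(W)
        if "hpu" in str(device):
            scale, zero, Q = (t.to(device) for t in (scale, zero, Q))
        return scale, zero, Q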
@@ -854,6 +861,8 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F
            self.quantizer.find_params(W, weight=True)

        H = self.H
+        if "hpu" in self.device:
+            H = H.to("cpu")
        del self.H
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
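Note: H here is GPTQ's Hessian approximation. A zero diagonal entry means the corresponding input channel never activated during calibration and would make the subsequent Cholesky factorization singular; setting it to 1 keeps H positive definite. Toy reproduction of the fix-up:

    import torch

    H = torch.tensor([[2.0, 0.0],
                      [0.0, 0.0]])  # second channel never fired
    dead = torch.diag(H) == 0
    H[dead, dead] = 1                # same fix-up as above
    L = torch.linalg.cholesky(H)     # succeeds; singular without the fix-up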
@@ -958,6 +967,10 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F
                zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
+        if "hpu" in self.device:
+            scale = scale.to(self.device)
+            zero = zero.to(self.device)
+            Q = Q.to(self.device)
        return scale, zero, Q

    def free(self):
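Note: moving scale, zero, and Q back to self.device preserves the return contract for callers, which continue packing on the original device. For reference, the concatenation above stacks one column per quantization group (illustrative sizes):

    import torch

    out_features, n_groups = 8, 4
    scale = [torch.rand(out_features, 1) for _ in range(n_groups)]
    scale = torch.cat(scale, dim=1)  # shape [out_features, n_groups]
    assert scale.shape == (out_features, n_groups)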
@@ -973,25 +986,25 @@ def free(self):
class Quantizer(nn.Module):
    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
-        self.register_buffer("maxq", torch.tensor(0))
+        self.maxq = 0
        self.register_buffer("scale", torch.zeros(shape))
        self.register_buffer("zero", torch.zeros(shape))

    def configure(self, weight_config_this_layer, norm=2.4, grid=100, maxshrink=0.8, trits=False):
        for k, v in weight_config_this_layer.items():
            setattr(self, k, v)
-        self.maxq = torch.tensor(2**self.bits - 1)
+        # self.maxq = torch.tensor(2**self.bits - 1)
+        self.maxq = 2**self.bits - 1
        self.scheme = "sym" if self.sym else "asym"
        self.double_quant_scheme = "sym" if self.double_quant_sym else "asym"
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
-            self.maxq = torch.tensor(-1)
+            self.maxq = -1

    def find_params(self, x, weight=False):
        dev = x.device
-        self.maxq = self.maxq.to(dev)
        # NF4 FP4
        if self.dtype != "int":
            from .utility import quant_tensor
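Note: turning maxq from a registered tensor buffer into a plain Python int is what allows deleting the self.maxq.to(dev) line in find_params above: an int participates in tensor arithmetic on any device with no placement call and no host-device sync. Minimal check of the device-agnostic arithmetic:

    import torch

    maxq = 2**4 - 1                  # 15 for 4-bit, as in configure()
    x = torch.rand(3)                # same code works for hpu/cuda tensors
    q = torch.clamp(torch.round(x * maxq), 0, maxq)  # no .to(dev) needed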