@@ -198,7 +198,7 @@ def forward(self, x):
198
198
199
199
def _maybe_get_quantized_linear_native (nbit , has_weight_zeros ):
200
200
try :
201
- if nbit in [1 , 2 , 3 , 4 , 5 , 6 , 7 ]:
201
+ if nbit in [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ]:
202
202
wzp_suffix = "" if has_weight_zeros else "0zp"
203
203
return _Int8DynActIntxWeightQuantizedLinearNative (
204
204
pack_weight_op = getattr (
@@ -230,7 +230,7 @@ def _replace_linear_with_quantized_linear(module: nn.Module, kwargs={}):
230
230
has_weight_zeros = kwargs ["has_weight_zeros" ]
231
231
232
232
assert not isinstance (module , nn .Linear )
233
- assert nbit >= 1 and nbit <= 7
233
+ assert nbit >= 1 and nbit <= 8
234
234
235
235
for name , child in module .named_children ():
236
236
if not isinstance (child , nn .Linear ):
@@ -366,9 +366,9 @@ def quantize_and_pack_weights(self, weights, group_size):
366
366
weight_qvals , weight_scales , weight_zeros = _quantize (
367
367
weights , self .group_size , self .nbit , has_weight_zeros = True
368
368
)
369
- self .weight_qvals = weight_qvals .to (torch .int8 )
369
+ self .weight_qvals = weight_qvals .to (torch .int32 )
370
370
self .weight_scales = weight_scales
371
- self .weight_zeros = weight_zeros .to (torch .int8 )
371
+ self .weight_zeros = weight_zeros .to (torch .int32 )
372
372
373
373
def forward (self , x ):
374
374
shape = x .shape
@@ -394,7 +394,7 @@ def _replace_embedding_with_quantized_embedding(module: nn.Module, kwargs={}):
394
394
nbit = kwargs ["nbit" ]
395
395
396
396
assert not isinstance (module , nn .Embedding )
397
- assert nbit >= 1 and nbit <= 7
397
+ assert nbit >= 1 and nbit <= 8
398
398
399
399
for name , child in module .named_children ():
400
400
if not isinstance (child , nn .Embedding ):
0 commit comments