Skip to content

Commit 285d20b

Browse files
committed
Fix quantize embedding error
1 parent d3a4def commit 285d20b

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

examples/models/llama/source_transformation/quantize.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,8 @@ def quantize( # noqa C901
120120
if group_size is None:
121121
raise Exception("For 8da4w quantization, group size must be specified.")
122122

123-
from torchao.quantization import (
124-
quantize_,
125-
int8_dynamic_activation_int4_weight,
126-
)
123+
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
124+
127125
quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
128126

129127
if verbose:
@@ -664,7 +662,7 @@ def convert_for_runtime(self) -> nn.Module:
664662
def quantized_model(self) -> nn.Module:
665663
model_updated_state_dict = self.create_quantized_state_dict(self.packed)
666664
self.convert_for_runtime()
667-
self.mod.load_state_dict(model_updated_state_dict)
665+
self.mod.load_state_dict(model_updated_state_dict, assign=True)
668666
return self.mod
669667

670668

0 commit comments

Comments
 (0)