
Commit e667169

renaming get_state_dict function

1 parent: 37c1443

File tree

3 files changed: +11 -7 lines changed

src/transformers/modeling_utils.py

Lines changed: 5 additions & 2 deletions
@@ -4012,8 +4012,11 @@ def save_pretrained(
         repo_id = self._create_repo(repo_id, **kwargs)
         files_timestamps = self._get_files_timestamps(save_directory)
 
+        metadata = {}
         if hf_quantizer is not None:
-            state_dict = hf_quantizer.get_state_dict(self)
+            state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
+        metadata["format"] = "pt"
+
         # Only save the model itself if we are using distributed training
         model_to_save = unwrap_model(self)
         # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
@@ -4291,7 +4294,7 @@ def save_pretrained(
             if safe_serialization:
                 # At some point we will need to deal better with save_function (used for TPU and other distributed
                 # joyfulness), but for now this enough.
-                safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
+                safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
             else:
                 save_function(shard, os.path.join(save_directory, shard_file))
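
With this change, quantizer-provided metadata flows into the safetensors header instead of the hard-coded {"format": "pt"}. A minimal sketch of the new flow, assuming safetensors is installed; save_shard is a hypothetical standalone helper (transformers does the equivalent inside save_pretrained):

import os

from safetensors.torch import save_file as safe_save_file  # same alias modeling_utils uses


def save_shard(model, hf_quantizer, save_directory, shard_file, safe_serialization=True):
    """Hypothetical helper mirroring the updated save_pretrained logic."""
    state_dict = model.state_dict()
    metadata = {}
    if hf_quantizer is not None:
        # New contract: the quantizer returns (state_dict, metadata), not just state_dict.
        state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(model, safe_serialization)
    metadata["format"] = "pt"  # safetensors metadata must map str -> str

    if safe_serialization:
        # The metadata dict is stored in the safetensors file header.
        safe_save_file(state_dict, os.path.join(save_directory, shard_file), metadata=metadata)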

src/transformers/quantizers/base.py

Lines changed: 3 additions & 3 deletions
@@ -334,9 +334,9 @@ def is_compileable(self) -> bool:
         """Flag indicating whether the quantized model can be compiled"""
         return False
 
-    def get_state_dict(self, model):
-        """Get state dict. Useful when we need to modify a bit the state dict due to quantization"""
-        return None
+    def get_state_dict_and_metadata(self, model, safe_serialization=False):
+        """Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization"""
+        return None, {}
 
     @abstractmethod
     def _process_model_before_weight_loading(self, model, **kwargs): ...
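
The base implementation now returns a (state_dict, metadata) pair, with (None, {}) meaning "no changes", so every caller can unpack uniformly. A hypothetical sketch of a subclass overriding this hook (MyQuantizer and the metadata value are illustrative, and HfQuantizer's other abstract methods are omitted):

from transformers.quantizers.base import HfQuantizer


class MyQuantizer(HfQuantizer):
    # Illustrative subclass; HfQuantizer's other abstract methods are omitted here.
    def get_state_dict_and_metadata(self, model, safe_serialization=False):
        # Start from the plain state dict, then rewrite or annotate it as needed.
        state_dict = model.state_dict()
        metadata = {"quantization": "my_method"}  # illustrative key/value
        return state_dict, metadata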

src/transformers/quantizers/quantizer_mxfp4.py

Lines changed: 3 additions & 2 deletions
@@ -366,7 +366,7 @@ def update_param_name(self, param_name: str) -> str:
             return param_name.replace("down_proj", "down_proj_blocks")
         return param_name
 
-    def get_state_dict(self, model):
+    def get_state_dict_and_metadata(self, model):
         from ..integrations import Mxfp4GptOssExperts
 
         state_dict = model.state_dict()
@@ -398,7 +398,8 @@ def get_state_dict(self, model):
                 ).transpose(-1, -2)
             )
 
-        return state_dict
+        metadata = {}
+        return state_dict, metadata
 
     def is_serializable(self, safe_serialization=None):
         return True
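
At call sites the rename also changes the unpacking. A hedged before/after sketch, assuming quantizer is the Mxfp4HfQuantizer instance attached to a quantized model (note this override takes only model, whereas the base signature also accepts safe_serialization):

# Before this commit:
#     state_dict = quantizer.get_state_dict(model)
# After this commit:
state_dict, metadata = quantizer.get_state_dict_and_metadata(model)
assert metadata == {}  # the mxfp4 quantizer returns empty metadata in this commit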
