diff --git a/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py b/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py index 050819a22..0a37fa299 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py @@ -11,6 +11,7 @@ ) from huggingface_hub import snapshot_download from transformers.integrations.deepspeed import HfDeepSpeedConfig +from transformers.modeling_utils import no_init_weights from dschat.utils.model.reward_model import RewardModel from dschat.utils.utils import load_state_dict_into_model, print_rank_0 @@ -99,7 +100,8 @@ def create_hf_model(model_class, dschf = None if rlhf_training: # the weight loading is handled by create critic model - model = model_class.from_config(model_config) + with no_init_weights(): + model = model_class.from_config(model_config) else: model = model_class.from_pretrained( model_name_or_path,