diff --git a/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py b/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py index ac36de6..b3011a8 100644 --- a/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py +++ b/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py @@ -15,6 +15,7 @@ import json import logging import os +import re from pathlib import Path from typing import Optional @@ -43,11 +44,21 @@ def is_aws_neuron_available(): logger = logging.getLogger(__name__) PYTORCH_WEIGHTS_NAME = "pytorch_model.bin" +SAFETENSORS_WEIGHTS_NAME = "model.safetensors" TF2_WEIGHTS_NAME = "tf_model.h5" -FRAMEWORK_MAPPING = {"pytorch": PYTORCH_WEIGHTS_NAME, "tensorflow": TF2_WEIGHTS_NAME} +FRAMEWORK_MAPPING = { + "pytorch": PYTORCH_WEIGHTS_NAME, + "tensorflow": TF2_WEIGHTS_NAME, +} +PYTORCH_WEIGHTS_NAME_PATTERN = r"pytorch_model-\d+-\of-\d+.bin" +SAFETENSORS_WEIGHTS_NAME_PATTERN = r"model-\d+-\of-\d+\.safetensors" +TF2_WEIGHTS_NAME_PATTERN = r"tf_model-\d+-\of-\d+.h5" FILE_LIST_NAMES = [ "config.json", + "model.safetensors.index.json", + "pytorch_model.bin.index.json", + "tf_model.h5.index.json", "special_tokens_map.json", "tokenizer_config.json", "tokenizer.json", @@ -192,11 +203,40 @@ def _load_model_from_hub( os.makedirs(storage_folder, exist_ok=True) # filters files to download - download_file_list = [ - file.rfilename - for file in model_info.siblings - if file.rfilename in FILE_LIST_NAMES + [FRAMEWORK_MAPPING[framework]] - ] + download_file_list = [] + + # prioritize safe tensors weights if they exist + repo_using_safetensors = False + for file in model_info.siblings: + if file.rfilename == SAFETENSORS_WEIGHTS_NAME or re.match(SAFETENSORS_WEIGHTS_NAME_PATTERN, file.rfilename): + repo_using_safetensors = True + download_file_list = [ + file.rfilename + for file in model_info.siblings + if file.rfilename == SAFETENSORS_WEIGHTS_NAME + or file.rfilename in FILE_LIST_NAMES + or re.match(SAFETENSORS_WEIGHTS_NAME_PATTERN, file.rfilename) + ] + break + + # if repo doesn't use safetensors, use framework specific weights + if not repo_using_safetensors: + if (framework) == "pytorch": + download_file_list = [ + file.rfilename + for file in model_info.siblings + if file.rfilename == PYTORCH_WEIGHTS_NAME + or file.rfilename in FILE_LIST_NAMES + or re.match(PYTORCH_WEIGHTS_NAME_PATTERN, file.rfilename) + ] + elif (framework) == "tensorflow": + download_file_list = [ + file.rfilename + for file in model_info.siblings + if file.rfilename == TF2_WEIGHTS_NAME + or file.rfilename in FILE_LIST_NAMES + or re.match(TF2_WEIGHTS_NAME_PATTERN, file.rfilename) + ] # download files to storage_folder and removes cache for file in download_file_list: