diff --git a/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py b/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py
index 36517b9..e87f0ba 100644
--- a/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py
+++ b/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py
@@ -15,6 +15,7 @@
 import json
 import logging
 import os
+import re
 from pathlib import Path
 from typing import Optional
 
@@ -41,8 +42,15 @@ def is_aws_neuron_available():
 logger = logging.getLogger(__name__)
 
 PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
+PYTORCH_WEIGHTS_SHARDED_NAME = r"pytorch_model-\d{5}-of-\d{5}\.bin"
 TF2_WEIGHTS_NAME = "tf_model.h5"
-FRAMEWORK_MAPPING = {"pytorch": PYTORCH_WEIGHTS_NAME, "tensorflow": TF2_WEIGHTS_NAME}
+TF2_WEIGHTS_SHARDED_NAME = r"tf_model-\d{5}-of-\d{5}\.h5"
+SHARDED_INDEX_EXTENSION = ".index.json"
+FRAMEWORK_MAPPING = {
+    "pytorch": [PYTORCH_WEIGHTS_NAME, PYTORCH_WEIGHTS_NAME + SHARDED_INDEX_EXTENSION],
+    "tensorflow": [TF2_WEIGHTS_NAME, TF2_WEIGHTS_NAME + SHARDED_INDEX_EXTENSION],
+}
+FRAMEWORK_MAPPING_SHARDED = {"pytorch": PYTORCH_WEIGHTS_SHARDED_NAME, "tensorflow": TF2_WEIGHTS_SHARDED_NAME}
 
 FILE_LIST_NAMES = [
     "config.json",
@@ -188,9 +196,7 @@ def _load_model_from_hub(
 
     # filters files to download
     download_file_list = [
-        file.rfilename
-        for file in model_info.siblings
-        if file.rfilename in FILE_LIST_NAMES + [FRAMEWORK_MAPPING[framework]]
+        file.rfilename for file in model_info.siblings if _should_download_file(file.rfilename, framework)
     ]
 
     # download files to storage_folder and removes cache
@@ -205,6 +211,18 @@ def _load_model_from_hub(
     return storage_folder
 
 
+def _should_download_file(filename: str, framework: str) -> bool:
+    """
+    Check if a file is in the global allowlist of files to download, is an allowed weight or index file for the framework, or is a sharded weight file.
+    """
+    if filename in FILE_LIST_NAMES + FRAMEWORK_MAPPING[framework]:
+        return True
+    elif re.match(FRAMEWORK_MAPPING_SHARDED[framework], filename):
+        return True
+    else:
+        return False
+
+
 def infer_task_from_model_architecture(model_config_path: str, architecture_index=0) -> str:
     """
     Infer task from `config.json` of trained model. It is not guaranteed to detect, e.g. some models implement multiple architectures or
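To see what the new filter admits, here is a minimal, self-contained sketch of the logic above. It is not part of the patch: the constants are trimmed stand-ins for the real ones, and the sibling filenames are invented for illustration.

import re

# Trimmed stand-ins for the module-level constants in the diff above.
FILE_LIST_NAMES = ["config.json", "tokenizer_config.json"]
FRAMEWORK_MAPPING = {
    "pytorch": ["pytorch_model.bin", "pytorch_model.bin.index.json"],
    "tensorflow": ["tf_model.h5", "tf_model.h5.index.json"],
}
FRAMEWORK_MAPPING_SHARDED = {
    "pytorch": r"pytorch_model-\d{5}-of-\d{5}\.bin",
    "tensorflow": r"tf_model-\d{5}-of-\d{5}\.h5",
}


def should_download(filename: str, framework: str) -> bool:
    # Allowlisted file, framework weight/index file, or sharded weight file.
    if filename in FILE_LIST_NAMES + FRAMEWORK_MAPPING[framework]:
        return True
    return re.match(FRAMEWORK_MAPPING_SHARDED[framework], filename) is not None


# Invented repository file listing, mimicking model_info.siblings.
siblings = [
    "config.json",
    "pytorch_model.bin.index.json",
    "pytorch_model-00001-of-00002.bin",
    "pytorch_model-00002-of-00002.bin",
    "tf_model.h5",
    "flax_model.msgpack",
]
print([f for f in siblings if should_download(f, "pytorch")])
# -> ['config.json', 'pytorch_model.bin.index.json',
#     'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin']

One caveat worth noting: re.match only anchors at the start of the string, so a stray file such as pytorch_model-00001-of-00002.bin.lock would also pass the pattern check; re.fullmatch would reject it.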
diff --git a/tests/unit/test_transformers_utils.py b/tests/unit/test_transformers_utils.py
index 028a55b..ac2d101 100644
--- a/tests/unit/test_transformers_utils.py
+++ b/tests/unit/test_transformers_utils.py
@@ -26,6 +26,7 @@
     _get_framework,
     _is_gpu_available,
     _load_model_from_hub,
+    _should_download_file,
     get_pipeline,
     infer_task_from_hub,
     infer_task_from_model_architecture,
@@ -158,3 +159,16 @@ def test_wrapped_pipeline():
     res = conv_pipe(data)
     assert "conversation" in res
     assert "generated_text" in res
+
+
+def test_allow_sharded_files():
+    assert _should_download_file("pytorch_model-00001-of-00002.bin", "pytorch") is True
+    assert _should_download_file("pytorch_model-00002-of-00002.bin", "pytorch") is True
+    assert _should_download_file("pytorch_model-00002-of-00002.bin", "tensorflow") is False
+    assert _should_download_file("pytorch_model-abc-of-def.bin", "pytorch") is False
+    assert _should_download_file("tf_model-00001-of-00002.h5", "tensorflow") is True
+    assert _should_download_file("tf_model-00002-of-00002.h5", "tensorflow") is True
+    assert _should_download_file("tf_model.h5.index.json", "tensorflow") is True
+    assert _should_download_file("pytorch_model.bin.index.json", "pytorch") is True
+    assert _should_download_file("tf_model.h5.index.json", "pytorch") is False
+    assert _should_download_file("pytorch_model.bin.index.json", "tensorflow") is False
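For background on why the .index.json names are allowlisted: when transformers saves a checkpoint above its shard-size limit, it writes numbered shard files plus an index that maps each tensor to the shard holding it. The sketch below shows the rough shape of such an index; the tensor names and total_size are invented for illustration.

# Invented, abbreviated shape of a pytorch_model.bin.index.json
# for a two-shard model.
index = {
    "metadata": {"total_size": 4540379136},
    "weight_map": {
        "embeddings.word_embeddings.weight": "pytorch_model-00001-of-00002.bin",
        "encoder.layer.0.attention.self.query.weight": "pytorch_model-00001-of-00002.bin",
        "pooler.dense.weight": "pytorch_model-00002-of-00002.bin",
    },
}

# Every shard the index references matches the sharded-name pattern, so the
# filter downloads the index plus all shards it points to.
print(sorted(set(index["weight_map"].values())))
# -> ['pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin']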