diff --git a/.github/workflows/integ-test.yml b/.github/workflows/integ-test.yml index 3906b98..eb2ec64 100644 --- a/.github/workflows/integ-test.yml +++ b/.github/workflows/integ-test.yml @@ -10,10 +10,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.8 - name: Install Python dependencies run: pip install -e .[test,dev] - name: Run Integration Tests diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index 1a59089..0b3f24f 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.8 - name: Install Python dependencies run: pip install -e .[quality] - name: Run Quality check diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 847ec70..3b67e79 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.7 + - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.8 - name: Install Python dependencies run: pip install -e .[test,dev] - name: Run Unit Tests diff --git a/README.md b/README.md index a1a4d3c..d6a2757 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ [![Latest Version](https://img.shields.io/pypi/v/sagemaker_huggingface_inference_toolkit.svg)](https://pypi.python.org/pypi/sagemaker_huggingface_inference_toolkit) [![Supported Python Versions](https://img.shields.io/pypi/pyversions/sagemaker_huggingface_inference_toolkit.svg)](https://pypi.python.org/pypi/sagemaker_huggingface_inference_toolkit) [![Code Style: Black](https://img.shields.io/badge/code_style-black-000000.svg)](https://github.com/python/black) -SageMaker Hugging Face Inference Toolkit is an open-source library for serving 🤗 Transformers models on Amazon SageMaker. This library provides default pre-processing, predict and postprocessing for certain 🤗 Transformers models and tasks. It utilizes the [SageMaker Inference Toolkit](https://github.com/aws/sagemaker-inference-toolkit) for starting up the model server, which is responsible for handling inference requests. +SageMaker Hugging Face Inference Toolkit is an open-source library for serving 🤗 Transformers and Diffusers models on Amazon SageMaker. This library provides default pre-processing, predict and postprocessing for certain 🤗 Transformers and Diffusers models and tasks. It utilizes the [SageMaker Inference Toolkit](https://github.com/aws/sagemaker-inference-toolkit) for starting up the model server, which is responsible for handling inference requests. For Training, see [Run training on Amazon SageMaker](https://huggingface.co/docs/sagemaker/train). @@ -109,6 +109,14 @@ The `HF_API_TOKEN` environment variable defines the your Hugging Face authorizat HF_API_TOKEN="api_XXXXXXXXXXXXXXXXXXXXXXXXXXXXX" ``` +#### `HF_TRUST_REMOTE_CODE` + +The `HF_TRUST_REMOTE_CODE` environment variable defines whether or not to allow for custom models defined on the Hub in their own modeling files.
Allowed values are `"True"` and `"False"`. + +```bash +HF_TRUST_REMOTE_CODE="True" +``` + #### `HF_OPTIMUM_BATCH_SIZE` The `HF_OPTIMUM_BATCH_SIZE` environment variable defines the batch size, which is used when compiling the model to Neuron. The default value is `1`. Not required when model is already converted. @@ -172,3 +180,25 @@ Install all test and development packages with ```bash pip3 install -e ".[test,dev]" ``` +## Run Model Locally + +1. Manually download the `MMS_CONFIG_FILE` +``` +wget -O sagemaker-mms.properties https://raw.githubusercontent.com/aws/deep-learning-containers/master/huggingface/build_artifacts/inference/config.properties +``` + +2. Run the model server locally, e.g. for `text-to-image` +``` +HF_MODEL_ID="stabilityai/stable-diffusion-xl-base-1.0" HF_TASK="text-to-image" python src/sagemaker_huggingface_inference_toolkit/serving.py +``` +3. Adjust `handler_service.py` and comment out the `if content_type in content_types.UTF8_TYPES:` check; it is needed on SageMaker but cannot be used locally + +4. Send a request +``` +curl --request POST \ + --url http://localhost:8080/invocations \ + --header 'Accept: image/png' \ + --header 'Content-Type: application/json' \ + --data '{"inputs": "Camera"}' \ + --output image.png +``` \ No newline at end of file diff --git a/setup.py b/setup.py index 8391d7c..68cfde1 100644 --- a/setup.py +++ b/setup.py @@ -55,6 +55,7 @@ # Hugging Face specific dependencies extras["transformers"] = ["transformers[sklearn,sentencepiece]>=4.17.0"] +extras["diffusers"] = ["diffusers>=0.23.0"] # framework specific dependencies extras["torch"] = ["torch>=1.8.0", "torchaudio"] @@ -87,8 +88,7 @@ "flake8>=3.8.3", ] -extras["dev"] = extras["transformers"] + extras["mms"] + extras["torch"] + extras["tensorflow"] - +extras["dev"] = extras["transformers"] + extras["mms"] + extras["torch"] + extras["tensorflow"] + extras["diffusers"] setup( name="sagemaker-huggingface-inference-toolkit", version=VERSION, diff --git a/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py b/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py index 00eac62..a5e061e 100644 --- a/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py +++ b/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py @@ -20,7 +20,6 @@ import numpy as np from sagemaker_inference import errors from sagemaker_inference.decoder import _npy_to_numpy -from sagemaker_inference.encoder import _array_to_npy from mms.service import PredictionException from PIL import Image @@ -100,7 +99,7 @@ def default(self, obj): return super(_JSONEncoder, self).default(obj) -def encode_json(content): +def encode_json(content, accept_type=None): """ encodes json with custom `JSONEncoder` """ @@ -114,7 +113,25 @@ def encode_json(content): ) -def encode_csv(content): # type: (str) -> np.array +def _array_to_npy(array_like, accept_type=None): + """Convert an array-like object to the NPY format. + + To understand better what an array-like object is see: + https://docs.scipy.org/doc/numpy/user/basics.creation.html#converting-python-array-like-objects-to-numpy-arrays + + Args: + array_like (np.array or Iterable or int or float): array-like object + to be converted to NPY. + + Returns: + (obj): NPY array. + """ + buffer = BytesIO() + np.save(buffer, array_like) + return buffer.getvalue() + + +def encode_csv(content, accept_type=None): """Convert the result of a transformers pipeline to CSV. Args: content (dict | list): result of transformers pipeline.
@@ -133,10 +150,32 @@ def encode_csv(content): # type: (str) -> np.array return stream.getvalue() +def encode_image(image, accept_type=content_types.PNG): + """Convert a PIL.Image object to a byte stream. + Args: + image (PIL.Image): image to be converted. + accept_type (str): content type of the image. + Returns: + (bytes): byte stream of the image. + """ + accept_type = "PNG" if content_types.X_IMAGE == accept_type else accept_type.split("/")[-1].upper() + + with BytesIO() as out: + image.save(out, format=accept_type) + return out.getvalue() + + _encoder_map = { content_types.NPY: _array_to_npy, content_types.CSV: encode_csv, content_types.JSON: encode_json, + content_types.JPEG: encode_image, + content_types.PNG: encode_image, + content_types.TIFF: encode_image, + content_types.BMP: encode_image, + content_types.GIF: encode_image, + content_types.WEBP: encode_image, + content_types.X_IMAGE: encode_image, } _decoder_map = { content_types.NPY: _npy_to_numpy, @@ -172,12 +211,12 @@ def decode(content, content_type=content_types.JSON): raise pred_err -def encode(content, content_type=content_types.JSON): +def encode(content, accept_type=content_types.JSON): """ Encode an 🤗 Transformers object in a specific content_type. """ try: - encoder = _encoder_map[content_type] - return encoder(content) + encoder = _encoder_map[accept_type] + return encoder(content, accept_type) except KeyError: - raise errors.UnsupportedFormatError(content_type) + raise errors.UnsupportedFormatError(accept_type) diff --git a/src/sagemaker_huggingface_inference_toolkit/diffusers_utils.py b/src/sagemaker_huggingface_inference_toolkit/diffusers_utils.py new file mode 100644 index 0000000..068a41a --- /dev/null +++ b/src/sagemaker_huggingface_inference_toolkit/diffusers_utils.py @@ -0,0 +1,75 @@ +# Copyright 2023 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import importlib.util +import logging + +from transformers.utils.import_utils import is_torch_bf16_gpu_available + + +logger = logging.getLogger(__name__) + +_diffusers = importlib.util.find_spec("diffusers") is not None + + +def is_diffusers_available(): + return _diffusers + + +if is_diffusers_available(): + import torch + + from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionPipeline + + +class SMAutoPipelineForText2Image: + def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU + dtype = torch.float32 + if device == "cuda": + dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16 + device_map = "auto" if device == "cuda" else None + + self.pipeline = AutoPipelineForText2Image.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map) + # try to use DPMSolverMultistepScheduler + if isinstance(self.pipeline, StableDiffusionPipeline): + try: + self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config) + except Exception: + pass + self.pipeline.to(device) + + def __call__( + self, + prompt, + **kwargs, + ): + # TODO: add support for more images (Reason is correct output) + if "num_images_per_prompt" in kwargs: + kwargs.pop("num_images_per_prompt") + logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1.") + + # Call pipeline with parameters + out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs) + return out.images[0] + + +DIFFUSERS_TASKS = { + "text-to-image": SMAutoPipelineForText2Image, +} + + +def get_diffusers_pipeline(task=None, model_dir=None, device=-1, **kwargs): + """Get a pipeline for Diffusers models.""" + device = "cuda" if device == 0 else "cpu" + pipeline = DIFFUSERS_TASKS[task](model_dir=model_dir, device=device) + return pipeline diff --git a/src/sagemaker_huggingface_inference_toolkit/handler_service.py b/src/sagemaker_huggingface_inference_toolkit/handler_service.py index af4d396..a30bc6e 100644 --- a/src/sagemaker_huggingface_inference_toolkit/handler_service.py +++ b/src/sagemaker_huggingface_inference_toolkit/handler_service.py @@ -105,9 +105,13 @@ def load(self, model_dir): elif "config.json" in os.listdir(model_dir): task = infer_task_from_model_architecture(f"{model_dir}/config.json") hf_pipeline = get_pipeline(task=task, model_dir=model_dir, device=self.device) + elif "model_index.json" in os.listdir(model_dir): + task = "text-to-image" + hf_pipeline = get_pipeline(task=task, model_dir=model_dir, device=self.device) else: raise ValueError( - f"You need to define one of the following {list(SUPPORTED_TASKS.keys())} as env 'HF_TASK'.", 403 + f"You need to define one of the following {list(SUPPORTED_TASKS.keys())} or text-to-image as env 'HF_TASK'.", + 403, ) return hf_pipeline diff --git a/src/sagemaker_huggingface_inference_toolkit/optimum_utils.py b/src/sagemaker_huggingface_inference_toolkit/optimum_utils.py index 0968267..22a5202 100644 --- a/src/sagemaker_huggingface_inference_toolkit/optimum_utils.py +++ b/src/sagemaker_huggingface_inference_toolkit/optimum_utils.py @@ -73,6 +73,7 @@ def get_input_shapes(model_dir): return {"batch_size": int(batch_size), "sequence_length": int(sequence_length)} +# TODO: not used yet, need to sync on how to determine if we are running on inf2 instance def get_optimum_neuron_pipeline(task, model_dir): """Method to get optimum neuron pipeline for a given task. 
Method checks if task is supported by optimum neuron and if required environment variables are set, in case model is not converted. If all checks pass, optimum neuron pipeline is returned. If checks fail, an error is raised.""" from optimum.neuron.pipelines import NEURONX_SUPPORTED_TASKS, pipeline diff --git a/src/sagemaker_huggingface_inference_toolkit/serving.py b/src/sagemaker_huggingface_inference_toolkit/serving.py index 74f521b..a0e70b4 100644 --- a/src/sagemaker_huggingface_inference_toolkit/serving.py +++ b/src/sagemaker_huggingface_inference_toolkit/serving.py @@ -32,3 +32,7 @@ def _start_mms(): def main(): _start_mms() + + +if __name__ == "__main__": + main() diff --git a/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py b/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py index ac36de6..ba8141a 100644 --- a/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py +++ b/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py @@ -18,13 +18,12 @@ from pathlib import Path from typing import Optional -from huggingface_hub import HfApi -from huggingface_hub.file_download import cached_download, hf_hub_url -from transformers import pipeline +from huggingface_hub import HfApi, login, snapshot_download +from transformers import AutoTokenizer, pipeline from transformers.file_utils import is_tf_available, is_torch_available from transformers.pipelines import Conversation, Pipeline -from sagemaker_huggingface_inference_toolkit.optimum_utils import is_optimum_neuron_available +from sagemaker_huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, is_diffusers_available if is_tf_available(): @@ -40,43 +39,40 @@ def is_aws_neuron_available(): return _aws_neuron_available +def strtobool(val): + """Convert a string representation of truth to True or False. + True values are 'y', 'yes', 't', 'true', 'on', '1', 'TRUE', or 'True'; false values + are 'n', 'no', 'f', 'false', 'off', '0', 'FALSE', or 'False'. Raises ValueError if + 'val' is anything else.
+ """ + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1", "TRUE", "True"): + return True + elif val in ("n", "no", "f", "false", "off", "0", "FALSE", "False"): + return False + else: + raise ValueError("invalid truth value %r" % (val,)) + + logger = logging.getLogger(__name__) -PYTORCH_WEIGHTS_NAME = "pytorch_model.bin" -TF2_WEIGHTS_NAME = "tf_model.h5" -FRAMEWORK_MAPPING = {"pytorch": PYTORCH_WEIGHTS_NAME, "tensorflow": TF2_WEIGHTS_NAME} - -FILE_LIST_NAMES = [ - "config.json", - "special_tokens_map.json", - "tokenizer_config.json", - "tokenizer.json", - "vocab.json", - "vocab.txt", - "merges.txt", - "dict.txt", - "preprocessor_config.json", - "added_tokens.json", - "README.md", - "spiece.model", - "sentencepiece.bpe.model", - "sentencepiece.bpe.vocab", - "sentence.bpe.model", - "bpe.codes", - "source.spm", - "target.spm", - "spm.model", - "sentence_bert_config.json", - "sentence_roberta_config.json", - "sentence_distilbert_config.json", - "added_tokens.json", - "model_args.json", - "entity_vocab.json", - "pooling_config.json", -] - -if is_optimum_neuron_available(): - FILE_LIST_NAMES.append("model.neuron") + +FRAMEWORK_MAPPING = { + "pytorch": "pytorch*", + "tensorflow": "tf*", + "tf": "tf*", + "pt": "pytorch*", + "flax": "flax*", + "rust": "rust*", + "onnx": "*onnx*", + "safetensors": "*safetensors", + "coreml": "*mlmodel", + "tflite": "*tflite", + "savedmodel": "*tar.gz", + "openvino": "*openvino*", + "ckpt": "*ckpt", +} + REPO_ID_SEPARATOR = "__" @@ -99,6 +95,21 @@ def is_aws_neuron_available(): HF_API_TOKEN = os.environ.get("HF_API_TOKEN", None) HF_MODEL_REVISION = os.environ.get("HF_MODEL_REVISION", None) +TRUST_REMOTE_CODE = strtobool(os.environ.get("HF_TRUST_REMOTE_CODE", "False")) + + +def create_artifact_filter(framework): + """ + Returns a list of regex patterns based on the DL framework, which will be used to ignore files when downloading + """ + ignore_regex_list = list(set(FRAMEWORK_MAPPING.values())) + + pattern = FRAMEWORK_MAPPING.get(framework, None) + if pattern in ignore_regex_list: + ignore_regex_list.remove(pattern) + return ignore_regex_list + else: + return [] def wrap_conversation_pipeline(pipeline): @@ -179,11 +190,8 @@ def _load_model_from_hub( "This is an experimental beta features, which allows downloading model from the Hugging Face Hub on start up.
" "It loads the model defined in the env var `HF_MODEL_ID`" ) - # get all files from repository - _api = HfApi() - model_info = _api.model_info(repo_id=model_id, revision=revision, token=use_auth_token) - os.makedirs(model_dir, exist_ok=True) - + if use_auth_token is not None: + login(token=use_auth_token) # extracts base framework framework = _get_framework() @@ -191,21 +199,23 @@ def _load_model_from_hub( storage_folder = _build_storage_path(model_id, model_dir, revision) os.makedirs(storage_folder, exist_ok=True) - # filters files to download - download_file_list = [ - file.rfilename - for file in model_info.siblings - if file.rfilename in FILE_LIST_NAMES + [FRAMEWORK_MAPPING[framework]] - ] - - # download files to storage_folder and removes cache - for file in download_file_list: - url = hf_hub_url(model_id, filename=file, revision=revision) - - path = cached_download(url, cache_dir=storage_folder, force_filename=file, use_auth_token=use_auth_token) - - if os.path.exists(path + ".lock"): - os.remove(path + ".lock") + # check if safetensors weights are available + if framework == "pytorch": + files = HfApi().model_info(model_id).siblings + if any(f.rfilename.endswith("safetensors") for f in files): + framework = "safetensors" + + # create regex to only include the framework specific weights + ignore_regex = create_artifact_filter(framework) + + # Download the repository to the workdir and filter out non-framework specific weights + snapshot_download( + model_id, + revision=revision, + local_dir=str(storage_folder), + local_dir_use_symlinks=False, + ignore_patterns=ignore_regex, + ) return storage_folder @@ -273,8 +283,23 @@ def get_pipeline(task: str, device: int, model_dir: Path, **kwargs) -> Pipeline: else: kwargs["tokenizer"] = model_dir - # load pipeline - hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs) + if TRUST_REMOTE_CODE and os.environ.get("HF_MODEL_ID", None) is not None and device == 0: + tokenizer = AutoTokenizer.from_pretrained(os.environ["HF_MODEL_ID"]) + + hf_pipeline = pipeline( + task=task, + model=os.environ["HF_MODEL_ID"], + tokenizer=tokenizer, + trust_remote_code=TRUST_REMOTE_CODE, + model_kwargs={"device_map": "auto", "torch_dtype": "auto"}, + ) + elif is_diffusers_available() and task == "text-to-image": + hf_pipeline = get_diffusers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs) + else: + # load pipeline + hf_pipeline = pipeline( + task=task, model=model_dir, device=device, trust_remote_code=TRUST_REMOTE_CODE, **kwargs + ) # wrapp specific pipeline to support better ux if task == "conversational": diff --git a/tests/integ/test_diffusers.py b/tests/integ/test_diffusers.py new file mode 100644 index 0000000..e14d1fc --- /dev/null +++ b/tests/integ/test_diffusers.py @@ -0,0 +1,84 @@ +import os +import re +from io import BytesIO + +import boto3 +from integ.utils import clean_up, timeout_and_delete_by_name +from PIL import Image +from sagemaker import Session +from sagemaker.model import Model + + +os.environ["AWS_DEFAULT_REGION"] = os.environ.get("AWS_DEFAULT_REGION", "us-east-1") +SAGEMAKER_EXECUTION_ROLE = os.environ.get("SAGEMAKER_EXECUTION_ROLE", "sagemaker_execution_role") + + +def get_framework_ecr_image(registry_id="763104351884", repository_name="huggingface-pytorch-inference", device="cpu"): + client = boto3.client("ecr") + + def get_all_ecr_images(registry_id, repository_name, result_key): + response = client.list_images( + registryId=registry_id, + repositoryName=repository_name, + ) + results = 
response[result_key] + while "nextToken" in response: + response = client.list_images( + registryId=registry_id, + nextToken=response["nextToken"], + repositoryName=repository_name, + ) + results.extend(response[result_key]) + return results + + images = get_all_ecr_images(registry_id=registry_id, repository_name=repository_name, result_key="imageIds") + image_tags = [image["imageTag"] for image in images] + image_regex = re.compile(r"\d\.\d\.\d-" + device + "-.{4}$") + tag = sorted(list(filter(image_regex.match, image_tags)), reverse=True)[0] + return f"{registry_id}.dkr.ecr.{os.environ.get('AWS_DEFAULT_REGION','us-east-1')}.amazonaws.com/{repository_name}:{tag}" + + +# TODO: needs existing container +def test_text_to_image_model(): + image_uri = get_framework_ecr_image(repository_name="huggingface-pytorch-inference", device="gpu") + + name = "hf-test-text-to-image" + task = "text-to-image" + model = "echarlaix/tiny-random-stable-diffusion-xl" + # instance_type = "ml.m5.large" if device == "cpu" else "ml.g4dn.xlarge" + instance_type = "local_gpu" + env = {"HF_MODEL_ID": model, "HF_TASK": task} + + sagemaker_session = Session() + client = boto3.client("sagemaker-runtime") + + hf_model = Model( + image_uri=image_uri, # A Docker image URI. + model_data=None, # The S3 location of a SageMaker model data .tar.gz + env=env, # Environment variables to run with image_uri when hosted in SageMaker (default: None). + role=SAGEMAKER_EXECUTION_ROLE, # An AWS IAM role (either name or full ARN). + name=name, # The model name + sagemaker_session=sagemaker_session, + ) + + with timeout_and_delete_by_name(name, sagemaker_session, minutes=59): + # Use accelerator type to differentiate EI vs. CPU and GPU. Don't use processor value + hf_model.deploy( + initial_instance_count=1, + instance_type=instance_type, + endpoint_name=name, + ) + response = client.invoke_endpoint( + EndpointName=name, + Body='{"inputs": "a yellow lemon tree"}', # JSON string payload, not a dict + ContentType="application/json", + Accept="image/png", + ) + + # validate response; the body is raw PNG bytes, so it must not be decoded as text + response_body = response["Body"].read() + + img = Image.open(BytesIO(response_body)) + assert isinstance(img, Image.Image) + + clean_up(endpoint_name=name, sagemaker_session=sagemaker_session) diff --git a/tests/integ/test_models_from_hub.py b/tests/integ/test_models_from_hub.py index 87e0ed2..7bd7de0 100644 --- a/tests/integ/test_models_from_hub.py +++ b/tests/integ/test_models_from_hub.py @@ -36,7 +36,6 @@ def get_all_ecr_images(registry_id, repository_name, result_key): images = get_all_ecr_images(registry_id=registry_id, repository_name=repository_name, result_key="imageIds") image_tags = [image["imageTag"] for image in images] - print(image_tags) image_regex = re.compile("\d\.\d\.\d-" + device + "-.{4}$") tag = sorted(list(filter(image_regex.match, image_tags)), reverse=True)[0] return f"{registry_id}.dkr.ecr.{os.environ.get('AWS_DEFAULT_REGION','us-east-1')}.amazonaws.com/{repository_name}:{tag}" @@ -169,7 +168,6 @@ def test_deployment_from_hub(task, device, framework): "p95_request_time": np.percentile(time_buffer, 95), "body": json.loads(response_body), } - print(data) json.dump(data, outfile) assert task2performance[task][device]["average_request_time"] >= np.mean(time_buffer) diff --git a/tests/unit/test_diffusers_utils.py b/tests/unit/test_diffusers_utils.py new file mode 100644 index 0000000..c00c139 --- /dev/null +++ b/tests/unit/test_diffusers_utils.py @@ -0,0 +1,52 @@ +# Copyright 2023 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile + +from transformers.testing_utils import require_torch, slow + +from PIL import Image +from sagemaker_huggingface_inference_toolkit.diffusers_utils import SMAutoPipelineForText2Image +from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline + + +@require_torch +def test_get_diffusers_pipeline(): + with tempfile.TemporaryDirectory() as tmpdirname: + storage_dir = _load_model_from_hub( + "hf-internal-testing/tiny-stable-diffusion-torch", + tmpdirname, + ) + pipe = get_pipeline("text-to-image", -1, storage_dir) + assert isinstance(pipe, SMAutoPipelineForText2Image) + + +@slow +@require_torch +def test_pipe_on_gpu(): + with tempfile.TemporaryDirectory() as tmpdirname: + storage_dir = _load_model_from_hub( + "hf-internal-testing/tiny-stable-diffusion-torch", + tmpdirname, + ) + pipe = get_pipeline("text-to-image", 0, storage_dir) + assert pipe.device.type == "cuda" + + +@require_torch +def test_text_to_image_task(): + with tempfile.TemporaryDirectory() as tmpdirname: + storage_dir = _load_model_from_hub("hf-internal-testing/tiny-stable-diffusion-torch", tmpdirname) + pipe = get_pipeline("text-to-image", -1, storage_dir) + res = pipe("Lets create an embedding") + assert isinstance(res, Image.Image) diff --git a/tests/unit/test_transformers_utils.py b/tests/unit/test_transformers_utils.py index 028a55b..902a074 100644 --- a/tests/unit/test_transformers_utils.py +++ b/tests/unit/test_transformers_utils.py @@ -19,9 +19,6 @@ from transformers.testing_utils import require_tf, require_torch, slow from sagemaker_huggingface_inference_toolkit.transformers_utils import ( - FILE_LIST_NAMES, - PYTORCH_WEIGHTS_NAME, - TF2_WEIGHTS_NAME, _build_storage_path, _get_framework, _is_gpu_available, @@ -51,9 +48,7 @@ def test_loading_model_from_hub(): # folder contains all config files and pytorch_model.bin folder_contents = os.listdir(storage_folder) - assert any([True for files in FILE_LIST_NAMES if files in folder_contents]) - assert PYTORCH_WEIGHTS_NAME in folder_contents - assert TF2_WEIGHTS_NAME not in folder_contents + assert "config.json" in folder_contents @require_torch @@ -64,11 +59,21 @@ def test_loading_model_from_hub_with_revision(): # folder contains all config files and pytorch_model.bin assert REVISION in storage_folder folder_contents = os.listdir(storage_folder) - assert any([True for files in FILE_LIST_NAMES if files in folder_contents]) - assert PYTORCH_WEIGHTS_NAME in folder_contents + assert "config.json" in folder_contents assert "tokenizer_config.json" not in folder_contents +@require_torch +def test_loading_model_safetensor_from_hub_with_revision(): + with tempfile.TemporaryDirectory() as tmpdirname: + storage_folder = _load_model_from_hub( + model_id="hf-internal-testing/tiny-random-bert-safetensors", model_dir=tmpdirname + ) + + folder_contents = os.listdir(storage_folder) + assert "model.safetensors" in folder_contents + + def test_gpu_is_not_available(): device = 
_is_gpu_available() assert device is False
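
A note on consuming the new image responses: with `Accept: image/png` the endpoint returns raw PNG bytes produced by `encode_image`, so clients should read the response body without decoding it as text. A minimal client-side sketch follows; the endpoint name and region are placeholders, not part of this change:

```python
import io

import boto3
from PIL import Image

# Assumes an endpoint deployed with HF_TASK="text-to-image" and a
# Diffusers-compatible HF_MODEL_ID; name and region are illustrative.
runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")

response = runtime.invoke_endpoint(
    EndpointName="hf-test-text-to-image",
    Body=b'{"inputs": "a yellow lemon tree"}',  # JSON payload as bytes
    ContentType="application/json",
    Accept="image/png",  # served by the new encode_image() path
)

# The response body is raw PNG bytes; do not decode it as UTF-8.
image = Image.open(io.BytesIO(response["Body"].read()))
image.save("output.png")
```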
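
Similarly, the reworked Hub download in `_load_model_from_hub` reduces to a single `snapshot_download` call whose `ignore_patterns` come from `create_artifact_filter`. A rough standalone equivalent, assuming PyTorch is the detected framework and the repository ships safetensors weights (model id and target directory are illustrative):

```python
from huggingface_hub import snapshot_download

# Patterns mirror FRAMEWORK_MAPPING minus "*safetensors", so only the
# safetensors weights (plus configs/tokenizer files) are downloaded.
ignore_patterns = [
    "pytorch*", "tf*", "flax*", "rust*", "*onnx*", "*mlmodel",
    "*tflite", "*tar.gz", "*openvino*", "*ckpt",
]

storage_folder = snapshot_download(
    "hf-internal-testing/tiny-random-bert-safetensors",
    local_dir="./model",
    local_dir_use_symlinks=False,
    ignore_patterns=ignore_patterns,
)
print(storage_folder)
```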