Add diffusers utils #104

Merged · 17 commits · Nov 17, 2023
4 changes: 2 additions & 2 deletions .github/workflows/integ-test.yml
@@ -10,10 +10,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: 3.8
- name: Install Python dependencies
run: pip install -e .[test,dev]
- name: Run Integration Tests
4 changes: 2 additions & 2 deletions .github/workflows/quality.yml
@@ -7,10 +7,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: 3.8
- name: Install Python dependencies
run: pip install -e .[quality]
- name: Run Quality check
4 changes: 2 additions & 2 deletions .github/workflows/unit-test.yml
@@ -7,10 +7,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: 3.8
- name: Install Python dependencies
run: pip install -e .[test,dev]
- name: Run Unit Tests
32 changes: 31 additions & 1 deletion README.md
@@ -11,7 +11,7 @@
[![Latest Version](https://img.shields.io/pypi/v/sagemaker_huggingface_inference_toolkit.svg)](https://pypi.python.org/pypi/sagemaker_huggingface_inference_toolkit) [![Supported Python Versions](https://img.shields.io/pypi/pyversions/sagemaker_huggingface_inference_toolkit.svg)](https://pypi.python.org/pypi/sagemaker_huggingface_inference_toolkit) [![Code Style: Black](https://img.shields.io/badge/code_style-black-000000.svg)](https://github.com/python/black)


SageMaker Hugging Face Inference Toolkit is an open-source library for serving 🤗 Transformers models on Amazon SageMaker. This library provides default pre-processing, predict and postprocessing for certain 🤗 Transformers models and tasks. It utilizes the [SageMaker Inference Toolkit](https://github.com/aws/sagemaker-inference-toolkit) for starting up the model server, which is responsible for handling inference requests.
SageMaker Hugging Face Inference Toolkit is an open-source library for serving 🤗 Transformers and Diffusers models on Amazon SageMaker. This library provides default pre-processing, predict and postprocessing for certain 🤗 Transformers and Diffusers models and tasks. It utilizes the [SageMaker Inference Toolkit](https://github.com/aws/sagemaker-inference-toolkit) for starting up the model server, which is responsible for handling inference requests.

For Training, see [Run training on Amazon SageMaker](https://huggingface.co/docs/sagemaker/train).

@@ -109,6 +109,14 @@ The `HF_API_TOKEN` environment variable defines your Hugging Face authorization token.
HF_API_TOKEN="api_XXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
```

#### `HF_TRUST_REMOTE_CODE`

The `HF_TRUST_REMOTE_CODE` environment variable defines whether or not to allow custom models defined on the Hub in their own modeling files. Allowed values are `"True"` and `"False"`.

```bash
HF_TRUST_REMOTE_CODE="True"
```

#### `HF_OPTIMUM_BATCH_SIZE`

The `HF_OPTIMUM_BATCH_SIZE` environment variable defines the batch size, which is used when compiling the model to Neuron. The default value is `1`. It is not required when the model is already converted.
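
For example, to compile with a batch size of 4 (the value here is illustrative):

```bash
HF_OPTIMUM_BATCH_SIZE="4"
```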
@@ -172,3 +180,25 @@ Install all test and development packages with
```bash
pip3 install -e ".[test,dev]"
```
## Run Model Locally

1. Manually change `MMS_CONFIG_FILE` by downloading the SageMaker MMS config:
```
wget -O sagemaker-mms.properties https://raw.githubusercontent.com/aws/deep-learning-containers/master/huggingface/build_artifacts/inference/config.properties
```

2. Run the model server, e.g. for `text-to-image`:
```
HF_MODEL_ID="stabilityai/stable-diffusion-xl-base-1.0" HF_TASK="text-to-image" python src/sagemaker_huggingface_inference_toolkit/serving.py
```
3. Adjust `handler_service.py` by commenting out the `if content_type in content_types.UTF8_TYPES:` check; it is needed on SageMaker but cannot be used locally.

4. Send a request:
```
curl --request POST \
--url http://localhost:8080/invocations \
--header 'Accept: image/png' \
--header 'Content-Type: application/json' \
--data '{"inputs": "Camera"}' \
--output image.png
```
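
Alternatively, a minimal Python client can replace the `curl` call (a sketch assuming the `requests` package is installed and the server from step 2 is listening on port 8080):

```python
import requests

# Same payload and headers as the curl example above
response = requests.post(
    "http://localhost:8080/invocations",
    json={"inputs": "Camera"},
    headers={"Accept": "image/png"},
)
response.raise_for_status()

# The server returns the generated image as PNG bytes
with open("image.png", "wb") as f:
    f.write(response.content)
```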
4 changes: 2 additions & 2 deletions setup.py
@@ -55,6 +55,7 @@

# Hugging Face specific dependencies
extras["transformers"] = ["transformers[sklearn,sentencepiece]>=4.17.0"]
extras["diffusers"] = ["diffusers>=0.23.0"]

# framework specific dependencies
extras["torch"] = ["torch>=1.8.0", "torchaudio"]
@@ -87,8 +88,7 @@
"flake8>=3.8.3",
]

extras["dev"] = extras["transformers"] + extras["mms"] + extras["torch"] + extras["tensorflow"]

extras["dev"] = extras["transformers"] + extras["mms"] + extras["torch"] + extras["tensorflow"] + extras["diffusers"]
setup(
name="sagemaker-huggingface-inference-toolkit",
version=VERSION,
53 changes: 46 additions & 7 deletions src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py
@@ -20,7 +20,6 @@
import numpy as np
from sagemaker_inference import errors
from sagemaker_inference.decoder import _npy_to_numpy
from sagemaker_inference.encoder import _array_to_npy

from mms.service import PredictionException
from PIL import Image
@@ -100,7 +99,7 @@ def default(self, obj):
return super(_JSONEncoder, self).default(obj)


def encode_json(content):
def encode_json(content, accept_type=None):
"""
encodes json with custom `JSONEncoder`
"""
@@ -114,7 +113,25 @@ def encode_json(content):
)


def encode_csv(content): # type: (str) -> np.array
def _array_to_npy(array_like, accept_type=None):
"""Convert an array-like object to the NPY format.

To understand better what an array-like object is see:
https://docs.scipy.org/doc/numpy/user/basics.creation.html#converting-python-array-like-objects-to-numpy-arrays

Args:
array_like (np.array or Iterable or int or float): array-like object
to be converted to NPY.

Returns:
(obj): NPY array.
"""
buffer = BytesIO()
np.save(buffer, array_like)
return buffer.getvalue()


def encode_csv(content, accept_type=None):
"""Convert the result of a transformers pipeline to CSV.
Args:
content (dict | list): result of transformers pipeline.
@@ -133,10 +150,32 @@ def encode_csv(content): # type: (str) -> np.array
return stream.getvalue()


def encode_image(image, accept_type=content_types.PNG):
"""Convert a PIL.Image object to a byte stream.
Args:
image (PIL.Image): image to be converted.
accept_type (str): content type of the image.
Returns:
(bytes): byte stream of the image.
"""
accept_type = "PNG" if content_types.X_IMAGE == accept_type else accept_type.split("/")[-1].upper()

with BytesIO() as out:
image.save(out, format=accept_type)
return out.getvalue()


_encoder_map = {
content_types.NPY: _array_to_npy,
content_types.CSV: encode_csv,
content_types.JSON: encode_json,
content_types.JPEG: encode_image,
content_types.PNG: encode_image,
content_types.TIFF: encode_image,
content_types.BMP: encode_image,
content_types.GIF: encode_image,
content_types.WEBP: encode_image,
content_types.X_IMAGE: encode_image,
}
_decoder_map = {
content_types.NPY: _npy_to_numpy,
@@ -172,12 +211,12 @@ def decode(content, content_type=content_types.JSON):
raise pred_err


def encode(content, content_type=content_types.JSON):
def encode(content, accept_type=content_types.JSON):
"""
Encode an 🤗 Transformers object in a specific content_type.
"""
try:
encoder = _encoder_map[content_type]
return encoder(content)
encoder = _encoder_map[accept_type]
return encoder(content, accept_type)
except KeyError:
raise errors.UnsupportedFormatError(content_type)
raise errors.UnsupportedFormatError(accept_type)
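
For illustration, a minimal sketch of the new image-encoding path (assuming `Pillow` and `sagemaker-inference` are installed; the dummy image and output filename are placeholders):

```python
from PIL import Image
from sagemaker_inference import content_types

from sagemaker_huggingface_inference_toolkit.decoder_encoder import encode

# A dummy 64x64 image stands in for a real pipeline output
image = Image.new("RGB", (64, 64), color="white")

# An image/png accept type routes to encode_image and yields PNG bytes
png_bytes = encode(image, accept_type=content_types.PNG)

with open("out.png", "wb") as f:
    f.write(png_bytes)
```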
75 changes: 75 additions & 0 deletions src/sagemaker_huggingface_inference_toolkit/diffusers_utils.py
@@ -0,0 +1,75 @@
# Copyright 2023 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import logging

from transformers.utils.import_utils import is_torch_bf16_gpu_available


logger = logging.getLogger(__name__)

_diffusers = importlib.util.find_spec("diffusers") is not None


def is_diffusers_available():
return _diffusers


if is_diffusers_available():
import torch

from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionPipeline


class SMAutoPipelineForText2Image:
def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU
dtype = torch.float32
if device == "cuda":
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
device_map = "auto" if device == "cuda" else None

self.pipeline = AutoPipelineForText2Image.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map)
# try to use DPMSolverMultistepScheduler
if isinstance(self.pipeline, StableDiffusionPipeline):
try:
self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
except Exception:
pass
self.pipeline.to(device)

def __call__(
self,
prompt,
**kwargs,
):
# TODO: add support for returning multiple images (limited to a single image to keep the output format correct)
if "num_images_per_prompt" in kwargs:
kwargs.pop("num_images_per_prompt")
logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1.")

# Call pipeline with parameters
out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs)
return out.images[0]


DIFFUSERS_TASKS = {
"text-to-image": SMAutoPipelineForText2Image,
}


def get_diffusers_pipeline(task=None, model_dir=None, device=-1, **kwargs):
"""Get a pipeline for Diffusers models."""
device = "cuda" if device == 0 else "cpu"
pipeline = DIFFUSERS_TASKS[task](model_dir=model_dir, device=device)
return pipeline
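
A minimal local usage sketch of the new module (assuming `diffusers` and `torch` are installed; the model id and output filename below are placeholders, resolved via `from_pretrained`):

```python
from sagemaker_huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline

# device=0 maps to "cuda" inside get_diffusers_pipeline; any other value falls back to CPU
pipe = get_diffusers_pipeline(
    task="text-to-image",
    model_dir="stabilityai/stable-diffusion-xl-base-1.0",
    device=0,
)
image = pipe("a photo of a vintage camera")  # returns a single PIL.Image
image.save("camera.png")
```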
src/sagemaker_huggingface_inference_toolkit/handler_service.py
@@ -105,9 +105,13 @@ def load(self, model_dir):
elif "config.json" in os.listdir(model_dir):
task = infer_task_from_model_architecture(f"{model_dir}/config.json")
hf_pipeline = get_pipeline(task=task, model_dir=model_dir, device=self.device)
elif "model_index.json" in os.listdir(model_dir):
task = "text-to-image"
hf_pipeline = get_pipeline(task=task, model_dir=model_dir, device=self.device)
else:
raise ValueError(
f"You need to define one of the following {list(SUPPORTED_TASKS.keys())} as env 'HF_TASK'.", 403
f"You need to define one of the following {list(SUPPORTED_TASKS.keys())} or text-to-image as env 'HF_TASK'.",
403,
)
return hf_pipeline

src/sagemaker_huggingface_inference_toolkit/optimum_utils.py
@@ -73,6 +73,7 @@ def get_input_shapes(model_dir):
return {"batch_size": int(batch_size), "sequence_length": int(sequence_length)}


# TODO: not used yet, need to sync on how to determine if we are running on inf2 instance
def get_optimum_neuron_pipeline(task, model_dir):
"""Method to get optimum neuron pipeline for a given task. Method checks if task is supported by optimum neuron and if required environment variables are set, in case model is not converted. If all checks pass, optimum neuron pipeline is returned. If checks fail, an error is raised."""
from optimum.neuron.pipelines import NEURONX_SUPPORTED_TASKS, pipeline
4 changes: 4 additions & 0 deletions src/sagemaker_huggingface_inference_toolkit/serving.py
@@ -32,3 +32,7 @@ def _start_mms():

def main():
_start_mms()


if __name__ == "__main__":
main()