Skip to content

Add new modalities #60

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
</div>




# SageMaker Hugging Face Inference Toolkit

[![Latest Version](https://img.shields.io/pypi/v/sagemaker_huggingface_inference_toolkit.svg)](https://pypi.python.org/pypi/sagemaker_huggingface_inference_toolkit) [![Supported Python Versions](https://img.shields.io/pypi/pyversions/sagemaker_huggingface_inference_toolkit.svg)](https://pypi.python.org/pypi/sagemaker_huggingface_inference_toolkit) [![Code Style: Black](https://img.shields.io/badge/code_style-black-000000.svg)](https://github.com/python/black)
Expand Down Expand Up @@ -111,7 +113,7 @@ HF_API_TOKEN="api_XXXXXXXXXXXXXXXXXXXXXXXXXXXXX"

## 🧑🏻‍💻 User defined code/modules

The Hugging Face Inference Toolkit allows user to override the default methods of the `HuggingFaceHandlerService`. Therefor the need to create a named `code/` with a `inference.py` file in it.
The Hugging Face Inference Toolkit allows users to override the default methods of the `HuggingFaceHandlerService`. To do so, they need to create a folder named `code/` with an `inference.py` file in it. You can find an example in [sagemaker/17_custom_inference_script](https://github.com/huggingface/notebooks/blob/master/sagemaker/17_custom_inference_script/sagemaker-notebook.ipynb).
For example:
```bash
model.tar.gz/
Expand Down Expand Up @@ -144,3 +146,13 @@ requests to us.
## 📜 License

SageMaker Hugging Face Inference Toolkit is licensed under the Apache 2.0 License.

---

## 🧑🏻‍💻 Development Environment

Install all test and development packages with

```bash
pip3 install -e ".[test,dev]"
```
18 changes: 15 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,34 @@
# We don't declare our dependency on transformers here because we build with
# different packages for different variants

VERSION = "1.3.1"
VERSION = "2.0.0"


# Ubuntu packages
# libsndfile1-dev: torchaudio requires the development version of the libsndfile package which can be installed via a system package manager. On Ubuntu it can be installed as follows: apt install libsndfile1-dev
# ffmpeg: ffmpeg is required for audio processing. On Ubuntu it can be installed as follows: apt install ffmpeg
# libavcodec-extra: libavcodec-extra includes additional codecs for ffmpeg

# Runtime dependencies installed with the toolkit itself.
install_requires = [
    "sagemaker-inference>=1.5.11",
    "huggingface_hub>=0.0.8",
    "retrying",
    "numpy",
    # image (vision) support
    "Pillow",
    # audio (speech) support, used together with torchaudio
    "librosa",
    "pyctcdecode>=0.3.0",
    "phonemizer",
]

extras = {}

# Hugging Face specific dependencies
extras["transformers"] = ["transformers[sklearn,sentencepiece]>=4.5.1"]
extras["transformers"] = ["transformers[sklearn,sentencepiece]>=4.17.0"]

# framework specific dependencies
extras["torch"] = ["torch>=1.8.0"]
extras["torch"] = ["torch>=1.8.0", "torchaudio"]
extras["tensorflow"] = ["tensorflow>=2.4.0"]

# MMS Server dependencies
Expand Down
37 changes: 37 additions & 0 deletions src/sagemaker_huggingface_inference_toolkit/content_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains constants that define MIME content types."""
# --- Generic MIME types ---
JSON = "application/json"
CSV = "text/csv"
OCTET_STREAM = "application/octet-stream"
ANY = "*/*"
NPY = "application/x-npy"
# JSON and CSV are grouped together under UTF8_TYPES.
UTF8_TYPES = [
    JSON,
    CSV,
]

# --- Image (vision) MIME types ---
JPEG = "image/jpeg"
PNG = "image/png"
TIFF = "image/tiff"
BMP = "image/bmp"
GIF = "image/gif"
WEBP = "image/webp"
X_IMAGE = "image/x-image"
# Every image content type declared above, in declaration order.
VISION_TYPES = [
    JPEG,
    PNG,
    TIFF,
    BMP,
    GIF,
    WEBP,
    X_IMAGE,
]

# --- Audio (speech) MIME types ---
FLAC = "audio/x-flac"
MP3 = "audio/mpeg"
WAV = "audio/wave"
OGG = "audio/ogg"
X_AUDIO = "audio/x-audio"
# Every audio content type declared above, in declaration order.
AUDIO_TYPES = [
    FLAC,
    MP3,
    WAV,
    OGG,
    X_AUDIO,
]
51 changes: 47 additions & 4 deletions src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import csv
import datetime
import json
from io import StringIO
from io import BytesIO, StringIO

import numpy as np
from sagemaker_inference import content_types, errors
from sagemaker_inference.decoder import _npy_to_numpy, _npz_to_sparse
from sagemaker_inference import errors
from sagemaker_inference.decoder import _npy_to_numpy
from sagemaker_inference.encoder import _array_to_npy

from mms.service import PredictionException
from PIL import Image
from sagemaker_huggingface_inference_toolkit import content_types


def decode_json(content):
Expand Down Expand Up @@ -51,6 +54,28 @@ def decode_csv(string_like): # type: (str) -> np.array
return {"inputs": request_list}


def decode_image(bpayload: bytearray):
    """Turn an encoded image payload (.jpeg / .png / .tiff, ...) into an inputs dict.

    Args:
        bpayload (bytearray): raw bytes of the encoded image.
    Returns:
        (dict): dictionary of the form ``{"inputs": image}``, where ``image`` is
            the payload opened with PIL and converted to RGB mode.
    """
    # PIL reads from a file-like object, so wrap the raw bytes in a buffer.
    buffer = BytesIO(bpayload)
    rgb_image = Image.open(buffer).convert("RGB")
    return {"inputs": rgb_image}


def decode_audio(bpayload: bytearray):
    """Turn an audio payload (.wav / .flac / .mp3, ...) into an inputs dict.

    The audio is not decoded here; the payload is only copied into an
    immutable ``bytes`` object.

    Args:
        bpayload (bytearray): raw bytes of the audio file.
    Returns:
        (dict): dictionary of the form ``{"inputs": <bytes>}``.
    """
    raw_audio = bytes(bpayload)
    return {"inputs": raw_audio}


# https://github.com/automl/SMAC3/issues/453
class _JSONEncoder(json.JSONEncoder):
"""
Expand All @@ -66,6 +91,11 @@ def default(self, obj):
return obj.tolist()
elif isinstance(obj, datetime.datetime):
return obj.__str__()
elif isinstance(obj, Image.Image):
with BytesIO() as out:
obj.save(out, format="PNG")
png_string = out.getvalue()
return base64.b64encode(png_string).decode("utf-8")
else:
return super(_JSONEncoder, self).default(obj)

Expand Down Expand Up @@ -111,8 +141,21 @@ def encode_csv(content): # type: (str) -> np.array
# Dispatch table: maps a request content type (MIME string constant from
# `content_types`) to the function used to decode a payload of that type.
_decoder_map = {
content_types.NPY: _npy_to_numpy,
content_types.CSV: decode_csv,
# NOTE(review): `NPZ` is not among the constants defined in the toolkit's own
# content_types module shown in this diff, and `_npz_to_sparse` is dropped from
# the updated imports; this entry looks like the removed side of the diff. Confirm.
content_types.NPZ: _npz_to_sparse,
content_types.JSON: decode_json,
# image mime-types
# every image content type is handled by the same decoder
content_types.JPEG: decode_image,
content_types.PNG: decode_image,
content_types.TIFF: decode_image,
content_types.BMP: decode_image,
content_types.GIF: decode_image,
content_types.WEBP: decode_image,
content_types.X_IMAGE: decode_image,
# audio mime-types
# every audio content type is handled by the same decoder
content_types.FLAC: decode_audio,
content_types.MP3: decode_audio,
content_types.WAV: decode_audio,
content_types.OGG: decode_audio,
content_types.X_AUDIO: decode_audio,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import time
from abc import ABC

from sagemaker_inference import content_types, environment, utils
from sagemaker_inference import environment, utils
from transformers.pipelines import SUPPORTED_TASKS

from mms.service import PredictionException
from sagemaker_huggingface_inference_toolkit import decoder_encoder
from sagemaker_huggingface_inference_toolkit import content_types, decoder_encoder
from sagemaker_huggingface_inference_toolkit.transformers_utils import (
_is_gpu_available,
get_pipeline,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,21 @@ def get_pipeline(task: str, device: int, model_dir: Path, **kwargs) -> Pipeline:
raise EnvironmentError(
"The task for this model is not set: Please set one: https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined"
)
# pass model_dir as the tokenizer or feature extractor kwarg so the pipeline loads correctly
if task in {
"automatic-speech-recognition",
"image-segmentation",
"image-classification",
"audio-classification",
"object-detection",
"zero-shot-image-classification",
}:
kwargs["feature_extractor"] = model_dir
else:
kwargs["tokenizer"] = model_dir

hf_pipeline = pipeline(task=task, model=model_dir, tokenizer=model_dir, device=device, **kwargs)
# load pipeline
hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs)

# wrap specific pipelines to support better UX
if task == "conversational":
Expand Down
45 changes: 43 additions & 2 deletions tests/integ/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import os

from integ.utils import (
validate_automatic_speech_recognition,
validate_classification,
validate_feature_extraction,
validate_fill_mask,
validate_ner,
validate_question_answering,
validate_summarization,
validate_text2text_generation,
validate_text_classification,
validate_text_generation,
validate_translation,
validate_zero_shot_classification,
Expand Down Expand Up @@ -53,6 +56,14 @@
"pytorch": "gpt2",
"tensorflow": "gpt2",
},
"image-classification": {
"pytorch": "google/vit-base-patch16-224",
"tensorflow": "google/vit-base-patch16-224",
},
"automatic-speech-recognition": {
"pytorch": "facebook/wav2vec2-base-100h",
"tensorflow": "facebook/wav2vec2-base-960h",
},
}

task2input = {
Expand All @@ -78,6 +89,8 @@
"inputs": "question: What is 42 context: 42 is the answer to life, the universe and everything."
},
"text-generation": {"inputs": "My name is philipp and I am"},
"image-classification": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(),
"automatic-speech-recognition": open(os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb").read(),
}

task2output = {
Expand All @@ -98,6 +111,16 @@
"feature-extraction": None,
"fill-mask": None,
"text-generation": None,
"image-classification": [
{"score": 0.8858247399330139, "label": "tiger, Panthera tigris"},
{"score": 0.10940514504909515, "label": "tiger cat"},
{"score": 0.0006216464680619538, "label": "jaguar, panther, Panthera onca, Felis onca"},
{"score": 0.0004262699221726507, "label": "dhole, Cuon alpinus"},
{"score": 0.00030842673731967807, "label": "lion, king of beasts, Panthera leo"},
],
"automatic-speech-recognition": {
"text": "GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP OAUDIENCES IN DROFTY SCHOOL ROOMS DAY AFTER DAY FOR A FORT NIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS"
},
}

task2performance = {
Expand Down Expand Up @@ -181,10 +204,26 @@
"average_request_time": 3,
},
},
"image-classification": {
"cpu": {
"average_request_time": 4,
},
"gpu": {
"average_request_time": 1,
},
},
"automatic-speech-recognition": {
"cpu": {
"average_request_time": 6,
},
"gpu": {
"average_request_time": 6,
},
},
}

task2validation = {
"text-classification": validate_text_classification,
"text-classification": validate_classification,
"zero-shot-classification": validate_zero_shot_classification,
"feature-extraction": validate_feature_extraction,
"ner": validate_ner,
Expand All @@ -194,4 +233,6 @@
"translation_xx_to_yy": validate_translation,
"text2text-generation": validate_text2text_generation,
"text-generation": validate_text_generation,
"image-classification": validate_classification,
"automatic-speech-recognition": validate_automatic_speech_recognition,
}
Loading