
Commit 2f1fae5

Merge pull request #60 from aws/new-modalities

Add new modalities

2 parents 7cb5009 + 419b278, commit 2f1fae5
20 files changed: +248 -29 lines

README.md (+13 -1)

@@ -4,6 +4,8 @@
 </div>
 
 
+
+
 # SageMaker Hugging Face Inference Toolkit
 
 [![Latest Version](https://img.shields.io/pypi/v/sagemaker_huggingface_inference_toolkit.svg)](https://pypi.python.org/pypi/sagemaker_huggingface_inference_toolkit) [![Supported Python Versions](https://img.shields.io/pypi/pyversions/sagemaker_huggingface_inference_toolkit.svg)](https://pypi.python.org/pypi/sagemaker_huggingface_inference_toolkit) [![Code Style: Black](https://img.shields.io/badge/code_style-black-000000.svg)](https://github.com/python/black)
@@ -111,7 +113,7 @@ HF_API_TOKEN="api_XXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
 
 ## 🧑🏻‍💻 User defined code/modules
 
-The Hugging Face Inference Toolkit allows users to override the default methods of the `HuggingFaceHandlerService`. To do so, create a folder named `code/` with an `inference.py` file in it.
+The Hugging Face Inference Toolkit allows users to override the default methods of the `HuggingFaceHandlerService`. To do so, create a folder named `code/` with an `inference.py` file in it. You can find an example in [sagemaker/17_custom_inference_script](https://github.com/huggingface/notebooks/blob/master/sagemaker/17_custom_inference_script/sagemaker-notebook.ipynb).
 For example:
 ```bash
 model.tar.gz/
@@ -144,3 +146,13 @@ requests to us.
 ## 📜 License
 
 SageMaker Hugging Face Inference Toolkit is licensed under the Apache 2.0 License.
+
+---
+
+## 🧑🏻‍💻 Development Environment
+
+Install all test and development packages with
+
+```bash
+pip3 install -e ".[test,dev]"
+```
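For context, the `code/inference.py` referenced in the README change above is where users plug in their own logic. A minimal sketch of such an override file, assuming the handler's standard `model_fn`/`predict_fn` hooks (illustration only, not part of this commit):

```python
# code/inference.py -- hypothetical user override, shown only to illustrate the
# mechanism described in the README. It assumes the handler service picks up
# `model_fn` and `predict_fn` when this file is bundled in model.tar.gz/code/.
from transformers import pipeline


def model_fn(model_dir):
    # Load the model unpacked from model.tar.gz; the task is hard-coded here
    # purely as an example.
    return pipeline(
        "image-classification",
        model=model_dir,
        feature_extractor=model_dir,
    )


def predict_fn(data, model):
    # `data` is the decoded request, e.g. {"inputs": <PIL.Image.Image>} for an
    # image/* content type.
    return model(data["inputs"])
```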

setup.py (+15 -3)

@@ -30,22 +30,34 @@
 # We don't declare our dependency on transformers here because we build with
 # different packages for different variants
 
-VERSION = "1.3.1"
+VERSION = "2.0.0"
+
+
+# Ubuntu packages
+# libsndfile1-dev: torchaudio requires the development version of the libsndfile package, which can be installed via a system package manager. On Ubuntu: apt install libsndfile1-dev
+# ffmpeg: required for audio processing. On Ubuntu: apt install ffmpeg
+# libavcodec-extra: includes additional codecs for ffmpeg
 
 install_requires = [
     "sagemaker-inference>=1.5.11",
     "huggingface_hub>=0.0.8",
     "retrying",
     "numpy",
+    # vision
+    "Pillow",
+    # speech + torchaudio
+    "librosa",
+    "pyctcdecode>=0.3.0",
+    "phonemizer",
 ]
 
 extras = {}
 
 # Hugging Face specific dependencies
-extras["transformers"] = ["transformers[sklearn,sentencepiece]>=4.5.1"]
+extras["transformers"] = ["transformers[sklearn,sentencepiece]>=4.17.0"]
 
 # framework specific dependencies
-extras["torch"] = ["torch>=1.8.0"]
+extras["torch"] = ["torch>=1.8.0", "torchaudio"]
 extras["tensorflow"] = ["tensorflow>=2.4.0"]
 
 # MMS Server dependencies
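The new audio dependencies (librosa, pyctcdecode, phonemizer, plus torchaudio in the torch extra) back the speech tasks. A quick local sanity check that the audio stack and the system packages above are installed might look like the following sketch (the file path is only a placeholder):

```python
# Sketch only: confirm the audio stack installed above can decode a sample file.
# The path is a placeholder; librosa relies on libsndfile/ffmpeg from the system
# packages listed in the setup.py comments.
import librosa

waveform, sampling_rate = librosa.load("tests/resources/audio/sample1.flac", sr=16000)
print(waveform.shape, sampling_rate)
```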
src/sagemaker_huggingface_inference_toolkit/content_types.py (+37 -0)

@@ -0,0 +1,37 @@
+# Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This module contains constants that define MIME content types."""
+# Default Mime-Types
+JSON = "application/json"
+CSV = "text/csv"
+OCTET_STREAM = "application/octet-stream"
+ANY = "*/*"
+NPY = "application/x-npy"
+UTF8_TYPES = [JSON, CSV]
+# Vision Mime-Types
+JPEG = "image/jpeg"
+PNG = "image/png"
+TIFF = "image/tiff"
+BMP = "image/bmp"
+GIF = "image/gif"
+WEBP = "image/webp"
+X_IMAGE = "image/x-image"
+VISION_TYPES = [JPEG, PNG, TIFF, BMP, GIF, WEBP, X_IMAGE]
+# Speech Mime-Types
+FLAC = "audio/x-flac"
+MP3 = "audio/mpeg"
+WAV = "audio/wave"
+OGG = "audio/ogg"
+X_AUDIO = "audio/x-audio"
+AUDIO_TYPES = [FLAC, MP3, WAV, OGG, X_AUDIO]
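These MIME type constants are what endpoint callers send as the request `ContentType`. As a rough usage sketch (the endpoint name is hypothetical and valid AWS credentials plus a deployed endpoint are assumed), a client invoking an image-classification endpoint could look like this:

```python
# Sketch only: invoke a deployed endpoint with one of the new vision MIME types.
# "my-huggingface-endpoint" is a placeholder name.
import boto3

runtime = boto3.client("sagemaker-runtime")

with open("tiger.jpeg", "rb") as f:
    payload = f.read()

response = runtime.invoke_endpoint(
    EndpointName="my-huggingface-endpoint",  # hypothetical endpoint name
    ContentType="image/jpeg",                # matches content_types.JPEG above
    Accept="application/json",
    Body=payload,
)
print(response["Body"].read().decode("utf-8"))
```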

src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py (+47 -4)

@@ -11,17 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import base64
 import csv
 import datetime
 import json
-from io import StringIO
+from io import BytesIO, StringIO
 
 import numpy as np
-from sagemaker_inference import content_types, errors
-from sagemaker_inference.decoder import _npy_to_numpy, _npz_to_sparse
+from sagemaker_inference import errors
+from sagemaker_inference.decoder import _npy_to_numpy
 from sagemaker_inference.encoder import _array_to_npy
 
 from mms.service import PredictionException
+from PIL import Image
+from sagemaker_huggingface_inference_toolkit import content_types
 
 
 def decode_json(content):
@@ -51,6 +54,28 @@ def decode_csv(string_like):  # type: (str) -> np.array
     return {"inputs": request_list}
 
 
+def decode_image(bpayload: bytearray):
+    """Convert a .jpeg / .png / .tiff... object to a proper inputs dict.
+    Args:
+        bpayload (bytes): byte stream.
+    Returns:
+        (dict): dictionary for input
+    """
+    image = Image.open(BytesIO(bpayload)).convert("RGB")
+    return {"inputs": image}
+
+
+def decode_audio(bpayload: bytearray):
+    """Convert a .wav / .flac / .mp3 object to a proper inputs dict.
+    Args:
+        bpayload (bytes): byte stream.
+    Returns:
+        (dict): dictionary for input
+    """
+
+    return {"inputs": bytes(bpayload)}
+
+
 # https://github.com/automl/SMAC3/issues/453
 class _JSONEncoder(json.JSONEncoder):
     """
@@ -66,6 +91,11 @@ def default(self, obj):
             return obj.tolist()
         elif isinstance(obj, datetime.datetime):
             return obj.__str__()
+        elif isinstance(obj, Image.Image):
+            with BytesIO() as out:
+                obj.save(out, format="PNG")
+                png_string = out.getvalue()
+                return base64.b64encode(png_string).decode("utf-8")
         else:
             return super(_JSONEncoder, self).default(obj)
 
@@ -111,8 +141,21 @@ def encode_csv(content):  # type: (str) -> np.array
 _decoder_map = {
     content_types.NPY: _npy_to_numpy,
     content_types.CSV: decode_csv,
-    content_types.NPZ: _npz_to_sparse,
     content_types.JSON: decode_json,
+    # image mime-types
+    content_types.JPEG: decode_image,
+    content_types.PNG: decode_image,
+    content_types.TIFF: decode_image,
+    content_types.BMP: decode_image,
+    content_types.GIF: decode_image,
+    content_types.WEBP: decode_image,
+    content_types.X_IMAGE: decode_image,
+    # audio mime-types
+    content_types.FLAC: decode_audio,
+    content_types.MP3: decode_audio,
+    content_types.WAV: decode_audio,
+    content_types.OGG: decode_audio,
+    content_types.X_AUDIO: decode_audio,
 }
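To make the new decoders concrete, here is a small local sketch (not part of the commit; the file paths are placeholders, and the toolkit plus its multi-model-server dependency are assumed to be installed):

```python
# Sketch: call the new decoders directly to see the pipeline input format.
from sagemaker_huggingface_inference_toolkit.decoder_encoder import decode_audio, decode_image

with open("tiger.jpeg", "rb") as f:
    image_request = decode_image(f.read())   # {"inputs": <PIL.Image.Image>}

with open("sample1.flac", "rb") as f:
    audio_request = decode_audio(f.read())   # {"inputs": b"..."} raw bytes for the ASR pipeline

print(type(image_request["inputs"]), len(audio_request["inputs"]))
```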

src/sagemaker_huggingface_inference_toolkit/handler_service.py (+2 -2)

@@ -19,11 +19,11 @@
 import time
 from abc import ABC
 
-from sagemaker_inference import content_types, environment, utils
+from sagemaker_inference import environment, utils
 from transformers.pipelines import SUPPORTED_TASKS
 
 from mms.service import PredictionException
-from sagemaker_huggingface_inference_toolkit import decoder_encoder
+from sagemaker_huggingface_inference_toolkit import content_types, decoder_encoder
 from sagemaker_huggingface_inference_toolkit.transformers_utils import (
     _is_gpu_available,
     get_pipeline,

src/sagemaker_huggingface_inference_toolkit/transformers_utils.py (+14 -1)

@@ -255,8 +255,21 @@ def get_pipeline(task: str, device: int, model_dir: Path, **kwargs) -> Pipeline:
         raise EnvironmentError(
             "The task for this model is not set: Please set one: https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined"
         )
+    # define tokenizer or feature extractor as kwargs to load the pipeline correctly
+    if task in {
+        "automatic-speech-recognition",
+        "image-segmentation",
+        "image-classification",
+        "audio-classification",
+        "object-detection",
+        "zero-shot-image-classification",
+    }:
+        kwargs["feature_extractor"] = model_dir
+    else:
+        kwargs["tokenizer"] = model_dir
 
-    hf_pipeline = pipeline(task=task, model=model_dir, tokenizer=model_dir, device=device, **kwargs)
+    # load pipeline
+    hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs)
 
     # wrap specific pipelines to support a better UX
     if task == "conversational":
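The net effect of this change is that, for the new vision and audio tasks, the model directory is passed to `pipeline()` as the `feature_extractor` instead of the `tokenizer`. A standalone illustration of the same dispatch (the model IDs mirror the integration-test config below and would be downloaded from the Hugging Face Hub when run):

```python
# Illustration only: the branching in get_pipeline(), written out by hand.
from transformers import pipeline

# Vision/audio tasks: the model path also serves as the feature extractor.
vision_pipe = pipeline(
    task="image-classification",
    model="google/vit-base-patch16-224",
    feature_extractor="google/vit-base-patch16-224",
    device=-1,  # CPU
)

# Text tasks keep the previous behaviour: the model path also serves as the tokenizer.
text_pipe = pipeline(
    task="text-generation",
    model="gpt2",
    tokenizer="gpt2",
    device=-1,
)
```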

tests/integ/config.py (+43 -2)

@@ -1,11 +1,14 @@
+import os
+
 from integ.utils import (
+    validate_automatic_speech_recognition,
+    validate_classification,
     validate_feature_extraction,
     validate_fill_mask,
     validate_ner,
     validate_question_answering,
     validate_summarization,
     validate_text2text_generation,
-    validate_text_classification,
     validate_text_generation,
     validate_translation,
     validate_zero_shot_classification,
@@ -53,6 +56,14 @@
         "pytorch": "gpt2",
         "tensorflow": "gpt2",
     },
+    "image-classification": {
+        "pytorch": "google/vit-base-patch16-224",
+        "tensorflow": "google/vit-base-patch16-224",
+    },
+    "automatic-speech-recognition": {
+        "pytorch": "facebook/wav2vec2-base-100h",
+        "tensorflow": "facebook/wav2vec2-base-960h",
+    },
 }
 
 task2input = {
@@ -78,6 +89,8 @@
         "inputs": "question: What is 42 context: 42 is the answer to life, the universe and everything."
     },
     "text-generation": {"inputs": "My name is philipp and I am"},
+    "image-classification": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(),
+    "automatic-speech-recognition": open(os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb").read(),
 }
 
 task2output = {
@@ -98,6 +111,16 @@
     "feature-extraction": None,
     "fill-mask": None,
     "text-generation": None,
+    "image-classification": [
+        {"score": 0.8858247399330139, "label": "tiger, Panthera tigris"},
+        {"score": 0.10940514504909515, "label": "tiger cat"},
+        {"score": 0.0006216464680619538, "label": "jaguar, panther, Panthera onca, Felis onca"},
+        {"score": 0.0004262699221726507, "label": "dhole, Cuon alpinus"},
+        {"score": 0.00030842673731967807, "label": "lion, king of beasts, Panthera leo"},
+    ],
+    "automatic-speech-recognition": {
+        "text": "GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP OAUDIENCES IN DROFTY SCHOOL ROOMS DAY AFTER DAY FOR A FORT NIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS"
+    },
 }
 
 task2performance = {
@@ -181,10 +204,26 @@
             "average_request_time": 3,
         },
     },
+    "image-classification": {
+        "cpu": {
+            "average_request_time": 4,
+        },
+        "gpu": {
+            "average_request_time": 1,
+        },
+    },
+    "automatic-speech-recognition": {
+        "cpu": {
+            "average_request_time": 6,
+        },
+        "gpu": {
+            "average_request_time": 6,
+        },
+    },
 }
 
 task2validation = {
-    "text-classification": validate_text_classification,
+    "text-classification": validate_classification,
     "zero-shot-classification": validate_zero_shot_classification,
     "feature-extraction": validate_feature_extraction,
     "ner": validate_ner,
@@ -194,4 +233,6 @@
     "translation_xx_to_yy": validate_translation,
     "text2text-generation": validate_text2text_generation,
     "text-generation": validate_text_generation,
+    "image-classification": validate_classification,
+    "automatic-speech-recognition": validate_automatic_speech_recognition,
 }
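`tests/integ/utils.py` itself is not part of this diff, so the shared `validate_classification` helper is not shown. Purely as a hypothetical sketch of the kind of check the config above wires up for both text- and image-classification outputs:

```python
# Hypothetical sketch -- NOT the repository's actual validate_classification,
# which lives in tests/integ/utils.py and is not shown in this diff.
def validate_classification(result, snapshot=None):
    # The pipelines return a list of {"label": ..., "score": ...} dicts.
    assert isinstance(result, list) and len(result) > 0
    for prediction in result:
        assert "label" in prediction and "score" in prediction
        assert 0.0 <= prediction["score"] <= 1.0
    return True
```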
