Skip to content

Commit 7cb5009

Browse files
authored
Merge pull request #56 from aws/fix-configurable-start-up-timeout
fixes configurable startup timeout
2 parents 5b25d12 + 7f38612 commit 7cb5009

File tree

3 files changed

+26
-14
lines changed

3 files changed

+26
-14
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@
3030
# We don't declare our dependency on transformers here because we build with
3131
# different packages for different variants
3232

33-
VERSION = "1.3.0"
33+
VERSION = "1.3.1"
3434

3535
install_requires = [
36-
"sagemaker-inference>=1.5.5",
36+
"sagemaker-inference>=1.5.11",
3737
"huggingface_hub>=0.0.8",
3838
"retrying",
3939
"numpy",

src/sagemaker_huggingface_inference_toolkit/mms_model_server.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import pathlib
1818
import subprocess
1919

20-
from sagemaker_inference import logging
20+
from sagemaker_inference import environment, logging
2121
from sagemaker_inference.environment import model_dir
2222
from sagemaker_inference.model_server import (
2323
DEFAULT_MMS_LOG_FILE,
@@ -28,7 +28,7 @@
2828
_add_sigterm_handler,
2929
_create_model_server_config_file,
3030
_install_requirements,
31-
_retrieve_mms_server_process,
31+
_retry_retrieve_mms_server_process,
3232
_set_python_path,
3333
)
3434

@@ -84,7 +84,8 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
8484
else:
8585
_adapt_to_mms_format(handler_service, model_dir)
8686

87-
_create_model_server_config_file()
87+
env = environment.Environment()
88+
_create_model_server_config_file(env)
8889

8990
if os.path.exists(REQUIREMENTS_PATH):
9091
_install_requirements()
@@ -102,7 +103,9 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
102103

103104
logger.info(multi_model_server_cmd)
104105
subprocess.Popen(multi_model_server_cmd)
105-
mms_process = _retrieve_mms_server_process()
106+
# retry for configured timeout
107+
mms_process = _retry_retrieve_mms_server_process(env.startup_timeout)
108+
106109
_add_sigterm_handler(mms_process)
107110
_add_sigchild_handler()
108111

tests/unit/test_mms_model_server.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,15 @@
2626

2727
@patch("subprocess.call")
2828
@patch("subprocess.Popen")
29-
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process")
29+
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process")
3030
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler")
3131
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements")
3232
@patch("os.path.exists", return_value=True)
3333
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file")
3434
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format")
35+
@patch("sagemaker_inference.environment.Environment")
3536
def test_start_mms_default_service_handler(
37+
env,
3638
adapt,
3739
create_config,
3840
exists,
@@ -42,10 +44,11 @@ def test_start_mms_default_service_handler(
4244
subprocess_popen,
4345
subprocess_call,
4446
):
47+
env.return_value.startup_timeout = 10000
4548
mms_model_server.start_model_server()
4649

4750
adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, model_dir)
48-
create_config.assert_called_once_with()
51+
create_config.assert_called_once_with(env.return_value)
4952
exists.assert_called_once_with(mms_model_server.REQUIREMENTS_PATH)
5053
install_requirements.assert_called_once_with()
5154

@@ -67,7 +70,7 @@ def test_start_mms_default_service_handler(
6770
@patch("sagemaker_huggingface_inference_toolkit.transformers_utils._aws_neuron_available", return_value=True)
6871
@patch("subprocess.call")
6972
@patch("subprocess.Popen")
70-
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process")
73+
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process")
7174
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub")
7275
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler")
7376
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements")
@@ -76,7 +79,9 @@ def test_start_mms_default_service_handler(
7679
@patch("os.path.exists", return_value=True)
7780
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file")
7881
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format")
82+
@patch("sagemaker_inference.environment.Environment")
7983
def test_start_mms_neuron(
84+
env,
8085
adapt,
8186
create_config,
8287
exists,
@@ -90,11 +95,11 @@ def test_start_mms_neuron(
9095
subprocess_call,
9196
is_aws_neuron_available,
9297
):
93-
98+
env.return_value.startup_timeout = 10000
9499
mms_model_server.start_model_server()
95100

96101
adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, model_dir)
97-
create_config.assert_called_once_with()
102+
create_config.assert_called_once_with(env.return_value)
98103
exists.assert_called_once_with(mms_model_server.REQUIREMENTS_PATH)
99104
install_requirements.assert_called_once_with()
100105

@@ -115,7 +120,7 @@ def test_start_mms_neuron(
115120

116121
@patch("subprocess.call")
117122
@patch("subprocess.Popen")
118-
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process")
123+
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process")
119124
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub")
120125
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler")
121126
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements")
@@ -124,7 +129,9 @@ def test_start_mms_neuron(
124129
@patch("os.path.exists", return_value=True)
125130
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file")
126131
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format")
132+
@patch("sagemaker_inference.environment.Environment")
127133
def test_start_mms_with_model_from_hub(
134+
env,
128135
adapt,
129136
create_config,
130137
exists,
@@ -137,6 +144,8 @@ def test_start_mms_with_model_from_hub(
137144
subprocess_popen,
138145
subprocess_call,
139146
):
147+
env.return_value.startup_timeout = 10000
148+
140149
os.environ["HF_MODEL_ID"] = "lysandre/tiny-bert-random"
141150

142151
mms_model_server.start_model_server()
@@ -149,7 +158,7 @@ def test_start_mms_with_model_from_hub(
149158
)
150159

151160
adapt.assert_called_once_with(mms_model_server.DEFAULT_HANDLER_SERVICE, load_model_from_hub())
152-
create_config.assert_called_once_with()
161+
create_config.assert_called_once_with(env.return_value)
153162
exists.assert_called_with(mms_model_server.REQUIREMENTS_PATH)
154163
install_requirements.assert_called_once_with()
155164

@@ -172,7 +181,7 @@ def test_start_mms_with_model_from_hub(
172181
@patch("sagemaker_huggingface_inference_toolkit.transformers_utils._aws_neuron_available", return_value=True)
173182
@patch("subprocess.call")
174183
@patch("subprocess.Popen")
175-
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process")
184+
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process")
176185
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub")
177186
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler")
178187
@patch("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements")

0 commit comments

Comments
 (0)