26
26
27
27
@patch ("subprocess.call" )
28
28
@patch ("subprocess.Popen" )
29
- @patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process " )
29
+ @patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process " )
30
30
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler" )
31
31
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements" )
32
32
@patch ("os.path.exists" , return_value = True )
33
33
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file" )
34
34
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format" )
35
+ @patch ("sagemaker_inference.environment.Environment" )
35
36
def test_start_mms_default_service_handler (
37
+ env ,
36
38
adapt ,
37
39
create_config ,
38
40
exists ,
@@ -42,10 +44,11 @@ def test_start_mms_default_service_handler(
42
44
subprocess_popen ,
43
45
subprocess_call ,
44
46
):
47
+ env .return_value .startup_timeout = 10000
45
48
mms_model_server .start_model_server ()
46
49
47
50
adapt .assert_called_once_with (mms_model_server .DEFAULT_HANDLER_SERVICE , model_dir )
48
- create_config .assert_called_once_with ()
51
+ create_config .assert_called_once_with (env . return_value )
49
52
exists .assert_called_once_with (mms_model_server .REQUIREMENTS_PATH )
50
53
install_requirements .assert_called_once_with ()
51
54
@@ -67,7 +70,7 @@ def test_start_mms_default_service_handler(
67
70
@patch ("sagemaker_huggingface_inference_toolkit.transformers_utils._aws_neuron_available" , return_value = True )
68
71
@patch ("subprocess.call" )
69
72
@patch ("subprocess.Popen" )
70
- @patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process " )
73
+ @patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process " )
71
74
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub" )
72
75
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler" )
73
76
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements" )
@@ -76,7 +79,9 @@ def test_start_mms_default_service_handler(
76
79
@patch ("os.path.exists" , return_value = True )
77
80
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file" )
78
81
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format" )
82
+ @patch ("sagemaker_inference.environment.Environment" )
79
83
def test_start_mms_neuron (
84
+ env ,
80
85
adapt ,
81
86
create_config ,
82
87
exists ,
@@ -90,11 +95,11 @@ def test_start_mms_neuron(
90
95
subprocess_call ,
91
96
is_aws_neuron_available ,
92
97
):
93
-
98
+ env . return_value . startup_timeout = 10000
94
99
mms_model_server .start_model_server ()
95
100
96
101
adapt .assert_called_once_with (mms_model_server .DEFAULT_HANDLER_SERVICE , model_dir )
97
- create_config .assert_called_once_with ()
102
+ create_config .assert_called_once_with (env . return_value )
98
103
exists .assert_called_once_with (mms_model_server .REQUIREMENTS_PATH )
99
104
install_requirements .assert_called_once_with ()
100
105
@@ -115,7 +120,7 @@ def test_start_mms_neuron(
115
120
116
121
@patch ("subprocess.call" )
117
122
@patch ("subprocess.Popen" )
118
- @patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process " )
123
+ @patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process " )
119
124
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub" )
120
125
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler" )
121
126
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements" )
@@ -124,7 +129,9 @@ def test_start_mms_neuron(
124
129
@patch ("os.path.exists" , return_value = True )
125
130
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._create_model_server_config_file" )
126
131
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._adapt_to_mms_format" )
132
+ @patch ("sagemaker_inference.environment.Environment" )
127
133
def test_start_mms_with_model_from_hub (
134
+ env ,
128
135
adapt ,
129
136
create_config ,
130
137
exists ,
@@ -137,6 +144,8 @@ def test_start_mms_with_model_from_hub(
137
144
subprocess_popen ,
138
145
subprocess_call ,
139
146
):
147
+ env .return_value .startup_timeout = 10000
148
+
140
149
os .environ ["HF_MODEL_ID" ] = "lysandre/tiny-bert-random"
141
150
142
151
mms_model_server .start_model_server ()
@@ -149,7 +158,7 @@ def test_start_mms_with_model_from_hub(
149
158
)
150
159
151
160
adapt .assert_called_once_with (mms_model_server .DEFAULT_HANDLER_SERVICE , load_model_from_hub ())
152
- create_config .assert_called_once_with ()
161
+ create_config .assert_called_once_with (env . return_value )
153
162
exists .assert_called_with (mms_model_server .REQUIREMENTS_PATH )
154
163
install_requirements .assert_called_once_with ()
155
164
@@ -172,7 +181,7 @@ def test_start_mms_with_model_from_hub(
172
181
@patch ("sagemaker_huggingface_inference_toolkit.transformers_utils._aws_neuron_available" , return_value = True )
173
182
@patch ("subprocess.call" )
174
183
@patch ("subprocess.Popen" )
175
- @patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._retrieve_mms_server_process " )
184
+ @patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._retry_retrieve_mms_server_process " )
176
185
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._load_model_from_hub" )
177
186
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._add_sigterm_handler" )
178
187
@patch ("sagemaker_huggingface_inference_toolkit.mms_model_server._install_requirements" )
0 commit comments