3
3
#
4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
- import unittest
7
- from typing import Tuple
6
+
7
+ from typing import Any , Dict , SupportsInt , Tuple , cast
8
+ from unittest import TestCase
8
9
9
10
from torch .distributed .elastic .rendezvous import (
10
11
RendezvousHandler ,
@@ -40,7 +41,7 @@ def get_run_id(self) -> str:
40
41
return ""
41
42
42
43
43
- class RendezvousHandlerFactoryTest (unittest . TestCase ):
44
+ class RendezvousHandlerFactoryTest (TestCase ):
44
45
def test_double_registration (self ):
45
46
factory = RendezvousHandlerFactory ()
46
47
factory .register ("mock" , create_mock_rdzv_handler )
@@ -67,17 +68,171 @@ def test_create_handler(self):
67
68
self .assertTrue (isinstance (mock_rdzv_handler , MockRendezvousHandler ))
68
69
69
70
70
- class RendezvousParametersTest (unittest .TestCase ):
71
- def test_get_or_default (self ):
72
-
73
- params = RendezvousParameters (
74
- backend = "foobar" ,
75
- endpoint = "localhost" ,
76
- run_id = "1234" ,
77
- min_nodes = 1 ,
78
- max_nodes = 1 ,
79
- timeout1 = 10 ,
71
+ class RendezvousParametersTest (TestCase ):
72
+ def setUp (self ) -> None :
73
+ self ._backend = "dummy_backend"
74
+ self ._endpoint = "dummy_endpoint"
75
+ self ._run_id = "dummy_run_id"
76
+ self ._min_nodes = 3
77
+ self ._max_nodes = 6
78
+ self ._kwargs : Dict [str , Any ] = {}
79
+
80
+ def _create_params (self ) -> RendezvousParameters :
81
+ return RendezvousParameters (
82
+ backend = self ._backend ,
83
+ endpoint = self ._endpoint ,
84
+ run_id = self ._run_id ,
85
+ min_nodes = self ._min_nodes ,
86
+ max_nodes = self ._max_nodes ,
87
+ ** self ._kwargs ,
80
88
)
81
89
82
- self .assertEqual (10 , params .get ("timeout1" , 20 ))
83
- self .assertEqual (60 , params .get ("timeout2" , 60 ))
90
+ def test_init_initializes_params (self ) -> None :
91
+ self ._kwargs ["dummy_param" ] = "x"
92
+
93
+ params = self ._create_params ()
94
+
95
+ self .assertEqual (params .backend , self ._backend )
96
+ self .assertEqual (params .endpoint , self ._endpoint )
97
+ self .assertEqual (params .run_id , self ._run_id )
98
+ self .assertEqual (params .min_nodes , self ._min_nodes )
99
+ self .assertEqual (params .max_nodes , self ._max_nodes )
100
+
101
+ self .assertEqual (params .get ("dummy_param" ), "x" )
102
+
103
+ def test_init_initializes_params_if_min_nodes_equals_to_1 (self ) -> None :
104
+ self ._min_nodes = 1
105
+
106
+ params = self ._create_params ()
107
+
108
+ self .assertEqual (params .min_nodes , self ._min_nodes )
109
+ self .assertEqual (params .max_nodes , self ._max_nodes )
110
+
111
+ def test_init_initializes_params_if_min_nodes_equals_to_max_nodes (self ) -> None :
112
+ self ._max_nodes = 3
113
+
114
+ params = self ._create_params ()
115
+
116
+ self .assertEqual (params .min_nodes , self ._min_nodes )
117
+ self .assertEqual (params .max_nodes , self ._max_nodes )
118
+
119
+ def test_init_raises_error_if_backend_is_none_or_empty (self ) -> None :
120
+ for backend in [None , "" ]:
121
+ with self .subTest (backend = backend ):
122
+ self ._backend = backend # type: ignore[assignment]
123
+
124
+ with self .assertRaisesRegex (
125
+ ValueError ,
126
+ r"^The rendezvous backend name must be a non-empty string.$" ,
127
+ ):
128
+ self ._create_params ()
129
+
130
+ def test_init_raises_error_if_min_nodes_is_less_than_1 (self ) -> None :
131
+ for min_nodes in [0 , - 1 , - 5 ]:
132
+ with self .subTest (min_nodes = min_nodes ):
133
+ self ._min_nodes = min_nodes
134
+
135
+ with self .assertRaisesRegex (
136
+ ValueError ,
137
+ rf"^The minimum number of rendezvous nodes \({ min_nodes } \) must be greater "
138
+ rf"than zero.$" ,
139
+ ):
140
+ self ._create_params ()
141
+
142
+ def test_init_raises_error_if_max_nodes_is_less_than_min_nodes (self ) -> None :
143
+ for max_nodes in [2 , 1 , - 2 ]:
144
+ with self .subTest (max_nodes = max_nodes ):
145
+ self ._max_nodes = max_nodes
146
+
147
+ with self .assertRaisesRegex (
148
+ ValueError ,
149
+ rf"^The maximum number of rendezvous nodes \({ max_nodes } \) must be greater "
150
+ "than or equal to the minimum number of rendezvous nodes "
151
+ rf"\({ self ._min_nodes } \).$" ,
152
+ ):
153
+ self ._create_params ()
154
+
155
+ def test_get_returns_none_if_key_does_not_exist (self ) -> None :
156
+ params = self ._create_params ()
157
+
158
+ self .assertIsNone (params .get ("dummy_param" ))
159
+
160
+ def test_get_returns_default_if_key_does_not_exist (self ) -> None :
161
+ params = self ._create_params ()
162
+
163
+ self .assertEqual (params .get ("dummy_param" , default = "x" ), "x" )
164
+
165
+ def test_get_as_bool_returns_none_if_key_does_not_exist (self ) -> None :
166
+ params = self ._create_params ()
167
+
168
+ self .assertIsNone (params .get_as_bool ("dummy_param" ))
169
+
170
+ def test_get_as_bool_returns_default_if_key_does_not_exist (self ) -> None :
171
+ params = self ._create_params ()
172
+
173
+ self .assertTrue (params .get_as_bool ("dummy_param" , default = True ))
174
+
175
+ def test_get_as_bool_returns_true_if_value_represents_true (self ) -> None :
176
+ for value in ["1" , "True" , "tRue" , "T" , "t" , "yEs" , "Y" , 1 , True ]:
177
+ with self .subTest (value = value ):
178
+ self ._kwargs ["dummy_param" ] = value
179
+
180
+ params = self ._create_params ()
181
+
182
+ self .assertTrue (params .get_as_bool ("dummy_param" ))
183
+
184
+ def test_get_as_bool_returns_false_if_value_represents_false (self ) -> None :
185
+ for value in ["0" , "False" , "faLse" , "F" , "f" , "nO" , "N" , 0 , False ]:
186
+ with self .subTest (value = value ):
187
+ self ._kwargs ["dummy_param" ] = value
188
+
189
+ params = self ._create_params ()
190
+
191
+ self .assertFalse (params .get_as_bool ("dummy_param" ))
192
+
193
+ def test_get_as_bool_raises_error_if_value_is_invalid (self ) -> None :
194
+ for value in ["01" , "Flse" , "Ture" , "g" , "4" , "_" , "truefalse" , 2 , - 1 ]:
195
+ with self .subTest (value = value ):
196
+ self ._kwargs ["dummy_param" ] = value
197
+
198
+ params = self ._create_params ()
199
+
200
+ with self .assertRaisesRegex (
201
+ ValueError ,
202
+ r"^The rendezvous configuration option 'dummy_param' does not represent a "
203
+ r"valid boolean value.$" ,
204
+ ):
205
+ params .get_as_bool ("dummy_param" )
206
+
207
+ def test_get_as_int_returns_none_if_key_does_not_exist (self ) -> None :
208
+ params = self ._create_params ()
209
+
210
+ self .assertIsNone (params .get_as_int ("dummy_param" ))
211
+
212
+ def test_get_as_int_returns_default_if_key_does_not_exist (self ) -> None :
213
+ params = self ._create_params ()
214
+
215
+ self .assertEqual (params .get_as_int ("dummy_param" , default = 5 ), 5 )
216
+
217
+ def test_get_as_int_returns_integer_if_value_represents_integer (self ) -> None :
218
+ for value in ["0" , "-10" , "5" , " 4" , "4 " , " 4 " , 0 , - 4 , 3 ]:
219
+ with self .subTest (value = value ):
220
+ self ._kwargs ["dummy_param" ] = value
221
+
222
+ params = self ._create_params ()
223
+
224
+ self .assertEqual (params .get_as_int ("dummy_param" ), int (cast (SupportsInt , value )))
225
+
226
+ def test_get_as_int_raises_error_if_value_is_invalid (self ) -> None :
227
+ for value in ["a" , "0a" , "3b" , "abc" ]:
228
+ with self .subTest (value = value ):
229
+ self ._kwargs ["dummy_param" ] = value
230
+
231
+ params = self ._create_params ()
232
+
233
+ with self .assertRaisesRegex (
234
+ ValueError ,
235
+ r"^The rendezvous configuration option 'dummy_param' does not represent a "
236
+ r"valid integer value.$" ,
237
+ ):
238
+ params .get_as_int ("dummy_param" )
0 commit comments