Skip to content

Commit 359d0a0

Browse files
cbalioglufacebook-github-bot
authored andcommitted
[torch/elastic] Improve the implementation of RendezvousParameters and add its unit tests. (#146)
Summary: Pull Request resolved: pytorch/elastic#146 Pull Request resolved: pytorch#54807 Improve the implementation and the unit test coverage of `RendezvousParameters`. Test Plan: Run the existing and newly-introduced unit tests. Reviewed By: kiukchung Differential Revision: D27342444 fbshipit-source-id: 88de356c0a799844a739eb9105185bb8c1acf11f
1 parent 7f06c65 commit 359d0a0

File tree

3 files changed

+228
-80
lines changed

3 files changed

+228
-80
lines changed

test/distributed/elastic/rendezvous/api_test.py

Lines changed: 170 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6-
import unittest
7-
from typing import Tuple
6+
7+
from typing import Any, Dict, SupportsInt, Tuple, cast
8+
from unittest import TestCase
89

910
from torch.distributed.elastic.rendezvous import (
1011
RendezvousHandler,
@@ -40,7 +41,7 @@ def get_run_id(self) -> str:
4041
return ""
4142

4243

43-
class RendezvousHandlerFactoryTest(unittest.TestCase):
44+
class RendezvousHandlerFactoryTest(TestCase):
4445
def test_double_registration(self):
4546
factory = RendezvousHandlerFactory()
4647
factory.register("mock", create_mock_rdzv_handler)
@@ -67,17 +68,171 @@ def test_create_handler(self):
6768
self.assertTrue(isinstance(mock_rdzv_handler, MockRendezvousHandler))
6869

6970

70-
class RendezvousParametersTest(unittest.TestCase):
71-
def test_get_or_default(self):
72-
73-
params = RendezvousParameters(
74-
backend="foobar",
75-
endpoint="localhost",
76-
run_id="1234",
77-
min_nodes=1,
78-
max_nodes=1,
79-
timeout1=10,
71+
class RendezvousParametersTest(TestCase):
72+
def setUp(self) -> None:
73+
self._backend = "dummy_backend"
74+
self._endpoint = "dummy_endpoint"
75+
self._run_id = "dummy_run_id"
76+
self._min_nodes = 3
77+
self._max_nodes = 6
78+
self._kwargs: Dict[str, Any] = {}
79+
80+
def _create_params(self) -> RendezvousParameters:
81+
return RendezvousParameters(
82+
backend=self._backend,
83+
endpoint=self._endpoint,
84+
run_id=self._run_id,
85+
min_nodes=self._min_nodes,
86+
max_nodes=self._max_nodes,
87+
**self._kwargs,
8088
)
8189

82-
self.assertEqual(10, params.get("timeout1", 20))
83-
self.assertEqual(60, params.get("timeout2", 60))
90+
def test_init_initializes_params(self) -> None:
91+
self._kwargs["dummy_param"] = "x"
92+
93+
params = self._create_params()
94+
95+
self.assertEqual(params.backend, self._backend)
96+
self.assertEqual(params.endpoint, self._endpoint)
97+
self.assertEqual(params.run_id, self._run_id)
98+
self.assertEqual(params.min_nodes, self._min_nodes)
99+
self.assertEqual(params.max_nodes, self._max_nodes)
100+
101+
self.assertEqual(params.get("dummy_param"), "x")
102+
103+
def test_init_initializes_params_if_min_nodes_equals_to_1(self) -> None:
104+
self._min_nodes = 1
105+
106+
params = self._create_params()
107+
108+
self.assertEqual(params.min_nodes, self._min_nodes)
109+
self.assertEqual(params.max_nodes, self._max_nodes)
110+
111+
def test_init_initializes_params_if_min_nodes_equals_to_max_nodes(self) -> None:
112+
self._max_nodes = 3
113+
114+
params = self._create_params()
115+
116+
self.assertEqual(params.min_nodes, self._min_nodes)
117+
self.assertEqual(params.max_nodes, self._max_nodes)
118+
119+
def test_init_raises_error_if_backend_is_none_or_empty(self) -> None:
120+
for backend in [None, ""]:
121+
with self.subTest(backend=backend):
122+
self._backend = backend # type: ignore[assignment]
123+
124+
with self.assertRaisesRegex(
125+
ValueError,
126+
r"^The rendezvous backend name must be a non-empty string.$",
127+
):
128+
self._create_params()
129+
130+
def test_init_raises_error_if_min_nodes_is_less_than_1(self) -> None:
131+
for min_nodes in [0, -1, -5]:
132+
with self.subTest(min_nodes=min_nodes):
133+
self._min_nodes = min_nodes
134+
135+
with self.assertRaisesRegex(
136+
ValueError,
137+
rf"^The minimum number of rendezvous nodes \({min_nodes}\) must be greater "
138+
rf"than zero.$",
139+
):
140+
self._create_params()
141+
142+
def test_init_raises_error_if_max_nodes_is_less_than_min_nodes(self) -> None:
143+
for max_nodes in [2, 1, -2]:
144+
with self.subTest(max_nodes=max_nodes):
145+
self._max_nodes = max_nodes
146+
147+
with self.assertRaisesRegex(
148+
ValueError,
149+
rf"^The maximum number of rendezvous nodes \({max_nodes}\) must be greater "
150+
"than or equal to the minimum number of rendezvous nodes "
151+
rf"\({self._min_nodes}\).$",
152+
):
153+
self._create_params()
154+
155+
def test_get_returns_none_if_key_does_not_exist(self) -> None:
156+
params = self._create_params()
157+
158+
self.assertIsNone(params.get("dummy_param"))
159+
160+
def test_get_returns_default_if_key_does_not_exist(self) -> None:
161+
params = self._create_params()
162+
163+
self.assertEqual(params.get("dummy_param", default="x"), "x")
164+
165+
def test_get_as_bool_returns_none_if_key_does_not_exist(self) -> None:
166+
params = self._create_params()
167+
168+
self.assertIsNone(params.get_as_bool("dummy_param"))
169+
170+
def test_get_as_bool_returns_default_if_key_does_not_exist(self) -> None:
171+
params = self._create_params()
172+
173+
self.assertTrue(params.get_as_bool("dummy_param", default=True))
174+
175+
def test_get_as_bool_returns_true_if_value_represents_true(self) -> None:
176+
for value in ["1", "True", "tRue", "T", "t", "yEs", "Y", 1, True]:
177+
with self.subTest(value=value):
178+
self._kwargs["dummy_param"] = value
179+
180+
params = self._create_params()
181+
182+
self.assertTrue(params.get_as_bool("dummy_param"))
183+
184+
def test_get_as_bool_returns_false_if_value_represents_false(self) -> None:
185+
for value in ["0", "False", "faLse", "F", "f", "nO", "N", 0, False]:
186+
with self.subTest(value=value):
187+
self._kwargs["dummy_param"] = value
188+
189+
params = self._create_params()
190+
191+
self.assertFalse(params.get_as_bool("dummy_param"))
192+
193+
def test_get_as_bool_raises_error_if_value_is_invalid(self) -> None:
194+
for value in ["01", "Flse", "Ture", "g", "4", "_", "truefalse", 2, -1]:
195+
with self.subTest(value=value):
196+
self._kwargs["dummy_param"] = value
197+
198+
params = self._create_params()
199+
200+
with self.assertRaisesRegex(
201+
ValueError,
202+
r"^The rendezvous configuration option 'dummy_param' does not represent a "
203+
r"valid boolean value.$",
204+
):
205+
params.get_as_bool("dummy_param")
206+
207+
def test_get_as_int_returns_none_if_key_does_not_exist(self) -> None:
208+
params = self._create_params()
209+
210+
self.assertIsNone(params.get_as_int("dummy_param"))
211+
212+
def test_get_as_int_returns_default_if_key_does_not_exist(self) -> None:
213+
params = self._create_params()
214+
215+
self.assertEqual(params.get_as_int("dummy_param", default=5), 5)
216+
217+
def test_get_as_int_returns_integer_if_value_represents_integer(self) -> None:
218+
for value in ["0", "-10", "5", " 4", "4 ", " 4 ", 0, -4, 3]:
219+
with self.subTest(value=value):
220+
self._kwargs["dummy_param"] = value
221+
222+
params = self._create_params()
223+
224+
self.assertEqual(params.get_as_int("dummy_param"), int(cast(SupportsInt, value)))
225+
226+
def test_get_as_int_raises_error_if_value_is_invalid(self) -> None:
227+
for value in ["a", "0a", "3b", "abc"]:
228+
with self.subTest(value=value):
229+
self._kwargs["dummy_param"] = value
230+
231+
params = self._create_params()
232+
233+
with self.assertRaisesRegex(
234+
ValueError,
235+
r"^The rendezvous configuration option 'dummy_param' does not represent a "
236+
r"valid integer value.$",
237+
):
238+
params.get_as_int("dummy_param")

torch/distributed/elastic/rendezvous/api.py

Lines changed: 49 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -133,16 +133,22 @@ def main():
133133

134134

135135
class RendezvousParameters:
136+
"""Holds the parameters to construct a `RendezvousHandler`.
137+
138+
Args:
139+
backend:
140+
The name of the backend to use to handle the rendezvous.
141+
endpoint:
142+
The endpoint of the rendezvous, usually in form <hostname>[:<port>].
143+
run_id:
144+
The id of the rendezvous.
145+
min_nodes:
146+
The minimum number of nodes to admit to the rendezvous.
147+
max_nodes:
148+
The maximum number of nodes to admit to the rendezvous.
149+
**kwargs:
150+
Additional parameters for the specified backend.
136151
"""
137-
The data object holding parameters to construct a ``RendezvousHandler``.
138-
"""
139-
140-
# Default timeout for the rendezvous.
141-
_DEFAULT_TIMEOUT: int = 600 # 10 minutes
142-
143-
# Additional waiting time after reaching the minimum number of nodes
144-
# in case the rendezvous is elastic (min != max).
145-
_DEFAULT_LAST_CALL_TIMEOUT: int = 30 # 30 seconds
146152

147153
def __init__(
148154
self,
@@ -153,25 +159,17 @@ def __init__(
153159
max_nodes: int,
154160
**kwargs,
155161
):
156-
"""
157-
Args:
158-
backend: The backend that is used to register the rendezvous.
159-
endpoint: The endpoint of the rendezvous. Usually it is a string in the format
160-
<hostname>:<port>.
161-
run_id: The id of the rendezvous.
162-
min_nodes: The minimum number of nodes required to complete the rendezvous.
163-
max_nodes: The maximum number of nodes that are allowed to join the rendezvous.
164-
**kwargs: Additional parameters for the specified backend.
165-
"""
166-
if backend is None:
167-
raise ValueError("The backend cannot be None.")
162+
if not backend:
163+
raise ValueError("The rendezvous backend name must be a non-empty string.")
168164

169165
if min_nodes < 1:
170-
raise ValueError("The minimum number of nodes must be greater than zero.")
166+
raise ValueError(
167+
f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero."
168+
)
171169
if max_nodes < min_nodes:
172170
raise ValueError(
173-
"The maximum number of nodes must be greater than"
174-
" or equal to the minimum number of nodes."
171+
f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or "
172+
f"equal to the minimum number of rendezvous nodes ({min_nodes})."
175173
)
176174

177175
self.backend = backend
@@ -181,54 +179,40 @@ def __init__(
181179
self.max_nodes = max_nodes
182180
self.config = kwargs
183181

184-
@property
185-
def timeout(self):
186-
"""
187-
Gets the timeout for the rendezvous.
188-
"""
189-
return self.get_as_int("timeout", self._DEFAULT_TIMEOUT)
190-
191-
@property
192-
def last_call_timeout(self):
193-
"""
194-
Gets additional waiting time after reaching the minimum number of nodes
195-
in case the rendezvous is elastic (min != max).
196-
"""
197-
return self.get_as_int("last_call_timeout", self._DEFAULT_LAST_CALL_TIMEOUT)
198-
199182
def get(self, key: str, default: Any = None) -> Any:
200-
"""
201-
Returns the value for ``key`` if ``key`` exists, else ``default``.
202-
"""
183+
"""Returns the value for `key` if `key` exists, else `default`."""
203184
return self.config.get(key, default)
204185

205186
def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]:
206-
"""
207-
Returns the value for ``key`` as a ``bool`` if ``key`` exists.
208-
"""
209-
val = self.get(key, default)
210-
if val is None:
211-
return val
212-
if isinstance(val, int) or isinstance(val, bool):
213-
return True if val else False
214-
if isinstance(val, str):
215-
return val.lower() in ["1", "true", "t", "yes", "y"]
187+
"""Returns the value for `key` as a `bool`."""
188+
value = self.get(key, default)
189+
if value is None or isinstance(value, bool):
190+
return value
191+
if isinstance(value, int):
192+
if value == 1:
193+
return True
194+
if value == 0:
195+
return False
196+
elif isinstance(value, str):
197+
if value.lower() in ["1", "true", "t", "yes", "y"]:
198+
return True
199+
if value.lower() in ["0", "false", "f", "no", "n"]:
200+
return False
216201
raise ValueError(
217-
f"The '{key}' rendezvous config does not represent a valid boolean value."
202+
f"The rendezvous configuration option '{key}' does not represent a valid boolean value."
218203
)
219204

220205
def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
221-
"""
222-
Returns the value for ``key`` as an ``int`` if ``key`` exists.
223-
"""
224-
val = self.get(key, default)
225-
if val is None:
226-
return val
206+
"""Returns the value for `key` as an `int`."""
207+
value = self.get(key, default)
208+
if value is None:
209+
return value
227210
try:
228-
return int(val)
211+
return int(value)
229212
except ValueError:
230213
raise ValueError(
231-
f"The '{key}' rendezvous config does not represent a valid integer value."
214+
f"The rendezvous configuration option '{key}' does not represent a valid integer "
215+
"value."
232216
)
233217

234218

@@ -269,15 +253,17 @@ def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
269253
creator = self._registry[params.backend]
270254
except KeyError:
271255
raise ValueError(
272-
f"The rendezvous backend '{params.backend}' is not registered. Did you forget to call {self.register.__name__}?"
256+
f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
257+
f"to call {self.register.__name__}?"
273258
)
274259

275260
handler = creator(params)
276261

277262
# Do some sanity check.
278263
if handler.get_backend() != params.backend:
279264
raise RuntimeError(
280-
f"The rendezvous handler backend '{handler.get_backend()}' does not match the requested backend '{params.backend}'."
265+
f"The rendezvous handler backend '{handler.get_backend()}' does not match the "
266+
f"requested backend '{params.backend}'."
281267
)
282268

283269
return handler

torch/distributed/elastic/rendezvous/etcd_rendezvous.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ class EtcdRendezvousRetryImmediately(Exception):
5454
pass
5555

5656

57+
# Default timeout for the rendezvous.
58+
_DEFAULT_TIMEOUT: int = 600 # 10 minutes
59+
60+
# Additional waiting time after reaching the minimum number of nodes
61+
# in case the rendezvous is elastic (min != max).
62+
_DEFAULT_LAST_CALL_TIMEOUT: int = 30 # 30 seconds
63+
5764
# Various constants used internally in EtcdRendezvous
5865
CONST_ETCD_SETUP_TTL = 5
5966
CONST_ETCD_FROZEN_TTL = 10
@@ -1252,7 +1259,7 @@ def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
12521259
run_id=params.run_id,
12531260
num_min_workers=params.min_nodes,
12541261
num_max_workers=params.max_nodes,
1255-
timeout=params.timeout,
1256-
last_call_timeout=params.last_call_timeout,
1262+
timeout=params.get_as_int("timeout", _DEFAULT_TIMEOUT),
1263+
last_call_timeout=params.get_as_int("last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT),
12571264
)
12581265
return EtcdRendezvousHandler(rdzv_impl=rdzv)

0 commit comments

Comments
 (0)