Skip to content

Commit 1961e45

Browse files
authored
Merge pull request #44 from apoorvkh/pg-timeout
add pg_timeout flag
2 parents cd1a895 + eea5998 commit 1961e45

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

src/torchrunx/agent.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import datetime
34
import os
45
import socket
56
import sys
@@ -33,6 +34,7 @@ class WorkerArgs:
3334
local_world_size: int
3435
world_size: int
3536
log_file: os.PathLike
37+
timeout: int
3638

3739
def to_bytes(self) -> bytes:
3840
return cloudpickle.dumps(self)
@@ -81,7 +83,11 @@ def entrypoint(serialized_worker_args: bytes):
8183
if backend is None:
8284
backend = "nccl" if torch.cuda.is_available() else "gloo"
8385
dist.init_process_group(
84-
backend=backend, world_size=worker_args.world_size, rank=worker_args.rank, store=store
86+
backend=backend,
87+
world_size=worker_args.world_size,
88+
rank=worker_args.rank,
89+
store=store,
90+
timeout=datetime.timedelta(seconds=worker_args.timeout),
8591
)
8692

8793
os.environ["RANK"] = str(worker_args.rank)
@@ -130,6 +136,7 @@ def main(launcher_agent_group: LauncherAgentGroup):
130136
local_world_size=num_workers,
131137
world_size=worker_world_size,
132138
log_file=worker_log_files[i],
139+
timeout=launcher_payload.timeout,
133140
).to_bytes(),
134141
)
135142
for i in range(num_workers)

src/torchrunx/launcher.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class Launcher:
9696
]
9797
)
9898
env_file: str | os.PathLike | None = None
99+
timeout: int = 600
99100

100101
def run(
101102
self,
@@ -209,6 +210,7 @@ def run(
209210
worker_global_ranks=worker_global_ranks,
210211
worker_log_files=worker_log_files,
211212
backend=self.backend,
213+
timeout=self.timeout,
212214
)
213215

214216
agent_payloads: list[AgentPayload] = launcher_agent_group.sync_payloads(payload=payload)[1:] # pyright: ignore[reportAssignmentType]
@@ -270,6 +272,7 @@ def launch(
270272
"NCCL*",
271273
],
272274
env_file: str | os.PathLike | None = None,
275+
timeout: int = 600,
273276
) -> dict[int, Any]:
274277
"""
275278
Launch a distributed PyTorch function on the specified nodes.
@@ -292,6 +295,8 @@ def launch(
292295
:type env_vars: list[str], optional
293296
:param env_file: An additional environment file that will be sourced prior to executing ``func``, defaults to None
294297
:type env_file: str | os.PathLike | None, optional
298+
:param timeout: Worker process group timeout, defaults to 600
299+
:type timeout: int, optional
295300
:raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func``
296301
:return: A dictionary mapping worker ranks to their output
297302
:rtype: dict[int, Any]
@@ -304,4 +309,5 @@ def launch(
304309
log_dir=log_dir,
305310
env_vars=env_vars,
306311
env_file=env_file,
312+
timeout=timeout,
307313
).run(func=func, func_kwargs=func_kwargs)

src/torchrunx/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class LauncherPayload:
2929
worker_global_ranks: list[list[int]]
3030
worker_log_files: list[list[Path]]
3131
backend: Literal["mpi", "gloo", "nccl", "ucc", None]
32+
timeout: int
3233

3334

3435
@dataclass

0 commit comments

Comments (0)