6
6
import socket
7
7
import sys
8
8
import tempfile
9
+ import traceback
9
10
from dataclasses import dataclass
10
11
from typing import Any , Callable , Literal
11
12
12
13
import cloudpickle
13
14
import torch
14
15
import torch .distributed as dist
15
- from torch .distributed .elastic .multiprocessing import start_processes
16
- from typing_extensions import Self
16
+ import torch .distributed .elastic .multiprocessing as dist_mp
17
17
18
18
from .logging_utils import log_records_to_socket , redirect_stdio_to_logger
19
19
from .utils import (
@@ -40,16 +40,20 @@ class WorkerArgs:
40
40
hostname : str
41
41
timeout : int
42
42
43
- def to_bytes (self ) -> bytes :
44
- return cloudpickle . dumps ( self )
43
+ def serialize (self ) -> SerializedWorkerArgs :
44
+ return SerializedWorkerArgs ( worker_args = self )
45
45
46
- @classmethod
47
- def from_bytes (cls , serialized : bytes ) -> Self :
48
- return cloudpickle .loads (serialized )
49
46
47
+ class SerializedWorkerArgs :
48
+ def __init__ (self , worker_args : WorkerArgs ) -> None :
49
+ self .bytes = cloudpickle .dumps (worker_args )
50
50
51
- def entrypoint (serialized_worker_args : bytes ) -> Any | WorkerException :
52
- worker_args = WorkerArgs .from_bytes (serialized_worker_args )
51
+ def deserialize (self ) -> WorkerArgs :
52
+ return cloudpickle .loads (self .bytes )
53
+
54
+
55
+ def entrypoint (serialized_worker_args : SerializedWorkerArgs ) -> Any | WorkerException :
56
+ worker_args : WorkerArgs = serialized_worker_args .deserialize ()
53
57
54
58
logger = logging .getLogger ()
55
59
@@ -63,18 +67,14 @@ def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:
63
67
64
68
redirect_stdio_to_logger (logger )
65
69
66
- store = dist .TCPStore ( # pyright: ignore[reportPrivateImportUsage]
70
+ store = dist .TCPStore ( # pyright: ignore [reportPrivateImportUsage]
67
71
host_name = worker_args .main_agent_hostname ,
68
72
port = worker_args .main_agent_port ,
69
73
world_size = worker_args .world_size ,
70
74
is_master = (worker_args .rank == 0 ),
71
75
)
72
76
73
- backend = worker_args .backend
74
- if backend is None :
75
- backend = "nccl" if torch .cuda .is_available () else "gloo"
76
-
77
- logger .debug (f"using backend: { backend } " )
77
+ backend = worker_args .backend or ("nccl" if torch .cuda .is_available () else "gloo" )
78
78
79
79
dist .init_process_group (
80
80
backend = backend ,
@@ -91,19 +91,17 @@ def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:
91
91
os .environ ["MASTER_ADDR" ] = worker_args .main_agent_hostname
92
92
os .environ ["MASTER_PORT" ] = str (worker_args .main_agent_port )
93
93
94
- logger .debug (f"executing function: { worker_args .function } " )
95
-
96
94
try :
97
95
return worker_args .function ()
98
96
except Exception as e :
99
- logger . error ( e )
97
+ traceback . print_exc ( )
100
98
return WorkerException (exception = e )
101
99
finally :
102
100
sys .stdout .flush ()
103
101
sys .stderr .flush ()
104
102
105
103
106
- def main (launcher_agent_group : LauncherAgentGroup , logger_hostname : str , logger_port : int ):
104
+ def main (launcher_agent_group : LauncherAgentGroup , logger_hostname : str , logger_port : int ) -> None :
107
105
agent_rank = launcher_agent_group .rank - 1
108
106
109
107
payload = AgentPayload (
@@ -132,16 +130,9 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
132
130
133
131
redirect_stdio_to_logger (logger )
134
132
135
- if torch .__version__ >= "2.3" :
136
- from torch .distributed .elastic .multiprocessing import DefaultLogsSpecs
137
-
138
- log_kwargs = {"logs_specs" : DefaultLogsSpecs (log_dir = tempfile .mkdtemp ())}
139
- else :
140
- log_kwargs = {"log_dir" : tempfile .mkdtemp ()}
141
-
142
133
# spawn workers
143
134
144
- ctx = start_processes (
135
+ ctx = dist_mp . start_processes (
145
136
name = f"{ hostname } _" ,
146
137
entrypoint = entrypoint ,
147
138
args = {
@@ -159,31 +150,30 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
159
150
world_size = worker_world_size ,
160
151
hostname = launcher_payload .hostnames [agent_rank ],
161
152
timeout = launcher_payload .timeout ,
162
- ).to_bytes (),
153
+ ).serialize (),
163
154
)
164
155
for i in range (num_workers )
165
156
},
166
157
envs = {i : {} for i in range (num_workers )},
167
- ** log_kwargs , # pyright: ignore [reportArgumentType]
158
+ ** (
159
+ {"logs_specs" : dist_mp .DefaultLogsSpecs (log_dir = tempfile .mkdtemp ())}
160
+ if torch .__version__ >= "2.3"
161
+ else {"log_dir" : tempfile .mkdtemp ()}
162
+ ), # pyright: ignore [reportArgumentType]
168
163
)
169
- logger .info ("starting processes" )
170
164
171
165
try :
172
166
status = None
173
167
while True :
174
168
if status is None or status .state == "running" :
175
- status = AgentStatus .from_result (
176
- result = ctx .wait (5 ), worker_global_ranks = worker_global_ranks
177
- )
169
+ status = AgentStatus .from_result (ctx .wait (5 ))
178
170
179
171
agent_statuses = launcher_agent_group .sync_agent_statuses (status = status )
180
172
181
- if all (s .state == "done" for s in agent_statuses ):
182
- break
183
- elif any ( s . state == "failed" for s in agent_statuses ) :
173
+ all_done = all (s .state == "done" for s in agent_statuses )
174
+ any_failed = any ( s . state == "failed" for s in agent_statuses )
175
+ if all_done or any_failed :
184
176
break
185
- except :
186
- raise
187
177
finally :
188
178
ctx .close ()
189
179
sys .stdout .flush ()
0 commit comments