  Components APIs
  -----------------
  """
+ import shlex
  from pathlib import Path
- from typing import Dict, Optional
+ from typing import Dict, Optional, Iterable

  import torchx
  import torchx.specs as specs

  def ddp(
      *script_args: str,
-     script: str,
+     script: Optional[str] = None,
+     m: Optional[str] = None,
      image: str = torchx.IMAGE,
      name: Optional[str] = None,
+     h: Optional[str] = None,
      cpu: int = 2,
      gpu: int = 0,
      memMB: int = 1024,
-     h: Optional[str] = None,
      j: str = "1x2",
      env: Optional[Dict[str, str]] = None,
-     rdzv_endpoint: str = "etcd-server.default.svc.cluster.local:2379",
+     max_restarts: Optional[int] = None,
+     rdzv_backend: str = "c10d",
+     rdzv_endpoint: Optional[str] = None,
  ) -> specs.AppDef:
      """
      Distributed data parallel style application (one role, multi-replica).
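
For orientation, a minimal usage sketch of the updated signature (the image name, script, and sizes below are hypothetical, not taken from this diff):

    # Exactly one of script= or m= may be given; j="2x4" means 2 nodes x 4 procs per node.
    app = ddp(
        "--epochs", "10",                      # forwarded verbatim as *script_args
        script="train.py",                     # or: m="my_pkg.train"
        image="example.com/ml/train:latest",   # hypothetical image name
        j="2x4",
        gpu=4,
        max_restarts=3,                        # emitted as --max_restarts 3
    )
    # app is a specs.AppDef with one role that launches torch.distributed.run
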
@@ -154,6 +158,7 @@ def ddp(
      Args:
          script_args: arguments to the main module
          script: script or binary to run within the image
+         m: the python module path to run
          image: image (e.g. docker)
          name: job name override (uses the script name if not specified)
          cpu: number of cpus per replica
@@ -162,9 +167,14 @@ def ddp(
          h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
          j: {nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
          env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
-         rdzv_endpoint: etcd server endpoint (only matters when nnodes > 1)
+         max_restarts: the number of restarts allowed
+         rdzv_backend: rendezvous backend (only matters when nnodes > 1)
+         rdzv_endpoint: rendezvous server endpoint (only matters when nnodes > 1), defaults to rank0 host for schedulers that support it
      """

+     if (script is None) == (m is None):
+         raise ValueError("exactly one of --script and -m must be specified")
+
      rep = j.split("x")
      if len(rep) == 1:  # num replicas only
          nnodes = 1
@@ -175,33 +185,79 @@ def ddp(
      else:
          raise ValueError(f"Invalid format for -j, usage example: 1x4. Given: {j}")

-     script_name_noext = Path(script).stem  # script name no extension
+     if script:
+         # script name/module no extension
+         role_name = Path(script).stem
+     elif m:
+         role_name = m.rpartition(".")[2]
+     else:
+         raise ValueError("failed to compute role_name")
+
+     if rdzv_endpoint is None:
+         rdzv_endpoint = _noquote(f"$${macros.rank0_env}:29500")
+
+     if nnodes == 1:
+         rdzv_backend = "c10d"
+         rdzv_endpoint = "localhost:29500"
+
+     if env is None:
+         env = {}
+     env.setdefault("LOGLEVEL", "INFO")
+
+     cmd = [
+         "python",
+         "-m",
+         "torch.distributed.run",
+         "--rdzv_backend",
+         rdzv_backend,
+         "--rdzv_endpoint",
+         rdzv_endpoint,
+         "--rdzv_id",
+         f"{macros.app_id}",
+         "--nnodes",
+         str(nnodes),
+         "--nproc_per_node",
+         str(nproc_per_node),
+     ]
+     if max_restarts is not None:
+         cmd += ["--max_restarts", str(max_restarts)]
+     if script is not None:
+         cmd += [script]
+     elif m is not None:
+         cmd += ["-m", m]
+     cmd += script_args
      return specs.AppDef(
-         name=name or script_name_noext,
+         name=name or role_name,
          roles=[
              specs.Role(
-                 name=script_name_noext,
+                 name=role_name,
                  image=image,
-                 entrypoint="python",
+                 entrypoint="bash",
                  num_replicas=nnodes,
                  resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
-                 args=[
-                     "-m",
-                     "torch.distributed.run",
-                     "--rdzv_backend",
-                     ("c10d" if nnodes == 1 else "etcd"),
-                     "--rdzv_endpoint",
-                     ("localhost:29500" if nnodes == 1 else rdzv_endpoint),
-                     "--rdzv_id",
-                     f"{macros.app_id}",
-                     "--nnodes",
-                     str(nnodes),
-                     "--nproc_per_node",
-                     str(nproc_per_node),
-                     script,
-                     *script_args,
-                 ],
-                 env=env or {},
+                 args=["-c", _args_join(cmd)],
+                 env=env,
+                 port_map={
+                     "c10d": 29500,
+                 },
              )
          ],
      )
+
+
+ def _args_join(args: Iterable[str]) -> str:
+     """
+     _args_join is like shlex.join but if the argument is wrapped in _noquote
+     it'll not quote that argument.
+     """
+     quoted = [arg if isinstance(arg, _noquote) else shlex.quote(arg) for arg in args]
+     return " ".join(quoted)
+
+
+ class _noquote(str):
+     """
+     _noquote is a wrapper around str that indicates that the argument shouldn't
+     be passed through shlex.quote.
+     """
+
+     pass
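
The quoting helpers behave roughly as follows (illustrative values; "$RANK0_HOST" stands in for whatever macro string the scheduler substitutes and is not taken from this diff):

    # A _noquote argument is passed through verbatim so it can be substituted
    # later; ordinary arguments are shell-quoted via shlex.quote.
    _args_join(["echo", _noquote("$RANK0_HOST"), "hello world"])
    # -> "echo $RANK0_HOST 'hello world'"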
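
Putting it together, each replica's entrypoint is now bash -c over the joined command rather than python with per-element args; for the single-node defaults the string handed to bash looks approximately like the sketch below (placeholder values; the app id is shown unexpanded):

    # Roughly what _args_join(cmd) produces for j="1x2", script="train.py",
    # script_args=("--epochs", "10"):
    #
    #   python -m torch.distributed.run --rdzv_backend c10d \
    #       --rdzv_endpoint localhost:29500 --rdzv_id <app_id> \
    #       --nnodes 1 --nproc_per_node 2 train.py --epochs 10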