 import torch
 import torch.nn.functional as F
+from pippy.PipelineSchedule import PipelineScheduleGPipe
+from pippy.PipelineStage import PipelineStage
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.distributed.tensor.parallel import loss_parallel
@@ -129,7 +131,9 @@ def main(job_config: JobConfig):
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    # torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    torch.cuda.set_device(device)
     init_distributed(job_config)

     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -148,6 +152,14 @@ def main(job_config: JobConfig):
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+    else:
+        pp_degree, pp_rank = 1, 0
+
     data_loader = build_dataloader_fn(
         job_config.training.dataset,
         job_config.training.dataset_path,
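Aside: a small standalone sketch (not part of this diff) of how a joint PP x DP device mesh like the one queried above can be built and sliced with stock PyTorch APIs. The 2 x 4 split, the script name, and the torchrun launch line are illustrative assumptions.

# Standalone sketch, not from this PR: build and slice a 2-D ("pp", "dp") mesh.
# Launch with: torchrun --nproc_per_node=8 mesh_sketch.py  (the 2 x 4 split is illustrative)
import os

import torch
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
# init_device_mesh initializes the default process group from the torchrun env vars
world_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("pp", "dp"))

pp_mesh, dp_mesh = world_mesh["pp"], world_mesh["dp"]
print(
    f"global rank {os.environ['RANK']}: "
    f"pp {pp_mesh.get_local_rank()}/{pp_mesh.size()}, "
    f"dp {dp_mesh.get_local_rank()}/{dp_mesh.size()}"
)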
@@ -197,18 +209,54 @@ def main(job_config: JobConfig):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
-    model.to_empty(device="cuda")
-    model.init_weights()
-
-    # build optimizer after applying parallelisms to the model
-    optimizer = build_optimizer(model, job_config)
-    scheduler = get_lr_scheduler(optimizer, job_config)
+    if parallel_dims.pp_enabled:
+        pipe_meta = model
+        model = pipe_meta.get_stage_module(pp_rank)

     # build grad scaler which is effective only when mixed precision training
     # is enabled with fp16 param dtype under FSDP
     scaler = build_grad_scaler(model)

+    def loss_fn(pred, labels):
+        with (
+            loss_parallel()
+            if parallel_dims.loss_parallel_enabled
+            else contextlib.nullcontext()
+        ):
+            loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
+
+        # backward on scaled loss to create scaled gradients
+        scaler.scale(loss)
+        return loss
+
+    # TODO(whc) everything below needs to become a function that can be applied to each
+    # 'virtual stage' of PP, if there are virtual stages
+    if parallel_dims.pp_enabled:
+        stage = PipelineStage(
+            pipe=pipe_meta,
+            stage_index=pp_rank,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
+        pp_schedule = PipelineScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
+        model.to_empty(device="cuda")
+    else:
+        # If PP is enabled, we can't use init_weights. Instead, we have to rely on creating an initial checkpoint
+        # offline and loading it to get initialization values. This is because the init_weights functions are
+        # written assuming the whole model (all its weights, or FQNs) exists on one rank. With PP, init_weights
+        # on stage 1 might crash because it can't find the "embedding" layer, for example.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.to_empty(device="cuda")
+        model.init_weights()
+
+    # build optimizer after applying parallelisms to the model
+    optimizer = build_optimizer(model, job_config)
+    scheduler = get_lr_scheduler(optimizer, job_config)
+
     metric_logger = build_metric_logger(job_config)

     # torch.compile model for improved performance
@@ -278,21 +326,32 @@ def main(job_config: JobConfig):

             input_ids = input_ids.cuda()
             labels = labels.cuda()
-
             optimizer.zero_grad()

-            # forward
-            pred = model(input_ids)
+            if parallel_dims.pp_enabled:
+                # pipeline F/Loss/B
+                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

-            with (
-                loss_parallel()
-                if parallel_dims.loss_parallel_enabled
-                else contextlib.nullcontext()
-            ):
-                loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
+                if pp_mesh.get_local_rank() == 0:
+                    pp_schedule.step(input_ids)
+                elif is_last_stage:
+                    losses = []
+                    pp_schedule.step(target=labels, losses=losses)
+                else:
+                    pp_schedule.step()
+
+                # accumulate losses across pipeline microbatches
+                current_loss = (
+                    torch.mean(torch.stack(losses)).item() if is_last_stage else -1.0
+                )
+            else:
+                # non-pipeline F/Loss/B
+                pred = model(input_ids)
+
+                loss = loss_fn(pred, labels)
+                loss.backward()

-            # backward on scaled loss to create scaled gradients
-            scaler.scale(loss).backward()
+                current_loss = loss.item()

             # clip gradients (after unscaling gradients of the optimizer's params)
             scaler.unscale_(optimizer)
@@ -309,7 +368,6 @@ def main(job_config: JobConfig):
             # updates the scale for next iteration
             scaler.update()

-            current_loss = loss.item()
             losses_since_last_log.append(current_loss)

             # log metrics
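Aside: with this schedule only the last pipeline stage observes real loss values, so the other stages append -1.0 above. A hedged sketch (not part of this diff) of one way to share the averaged loss for uniform logging, using the stock torch.distributed broadcast over the pp group; pp_mesh, is_last_stage, current_loss, and device are the names defined earlier in this file.

# Sketch only, not part of this PR: broadcast the last stage's averaged loss to every
# rank in the pipeline group so all ranks log the same number instead of -1.0.
import torch
import torch.distributed as dist


def share_loss_from_last_stage(current_loss, pp_mesh, is_last_stage, device):
    group = pp_mesh.get_group()
    src = dist.get_global_rank(group, pp_mesh.size() - 1)  # global rank of the last stage
    buf = torch.tensor([current_loss if is_last_stage else 0.0], device=device)
    dist.broadcast(buf, src=src, group=group)
    return buf.item()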