@@ -15,6 +15,8 @@

 import torch
 import torch.nn.functional as F
+from pippy.PipelineSchedule import PipelineScheduleGPipe
+from pippy.PipelineStage import PipelineStage
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel

@@ -120,7 +122,9 @@ def main(job_config: JobConfig):
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    # torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    torch.cuda.set_device(device)
     init_distributed(job_config)

     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -139,6 +143,14 @@ def main(job_config: JobConfig):
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+    else:
+        pp_degree, pp_rank = 1, 0
+
     data_loader = build_dataloader_fn(
         job_config.training.dataset,
         job_config.training.dataset_path,
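For context, a minimal single-file sketch (not part of the diff) of how a named "pp" sub-mesh yields the degree and local rank read above. The 2x2 mesh shape, the torchrun launch with 4 processes, and the variable names are assumptions for illustration only.

# Sketch: reading a "pp" sub-mesh from a 2-D DeviceMesh (assumes torchrun with 4 GPUs).
import os
import torch
from torch.distributed.device_mesh import init_device_mesh

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# hypothetical 2x2 mesh: pipeline dimension x data-parallel dimension
world_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "dp"))
pp_mesh = world_mesh["pp"]          # 1-D sub-mesh along the pipeline dimension
pp_degree = pp_mesh.size()          # number of pipeline stages
pp_rank = pp_mesh.get_local_rank()  # which stage this rank will own
print(f"global rank {world_mesh.get_rank()} -> pp stage {pp_rank} of {pp_degree}")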
@@ -197,14 +209,38 @@ def loss_fn(pred, labels):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
+    if parallel_dims.pp_enabled:
+        pipe_meta = model
+        model = pipe_meta.get_stage_module(pp_rank)
+
     model.to_empty(device="cuda")
-    model.init_weights()
+
+    # TODO(whc) everything below needs to become a function that can be applied to each
+    # 'virtual stage' of PP, if there are virtual stages
+    if parallel_dims.pp_enabled:
+        stage = PipelineStage(
+            pipe=pipe_meta,
+            stage_index=pp_rank,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
+        pp_schedule = PipelineScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
+    else:
+        # if PP is enabled, we can't use init_weights. Instead, we have to rely on creating an
+        # initial checkpoint offline and loading it to get initialization values. This is because
+        # the init_weights functions are written assuming the whole model (all its weights, or FQNs)
+        # exists on one rank. In PP, init_weights on stage 1 might crash because it can't find the "embedding" layer, for example.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.init_weights()

     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
-
     metric_logger = build_metric_logger(job_config)

     # torch.compile model for improved performance
@@ -274,13 +310,30 @@ def loss_fn(pred, labels):

             input_ids = input_ids.cuda()
             labels = labels.cuda()
-
             optimizer.zero_grad()

-            # forward / backward
-            pred = model(input_ids)
-            loss = loss_fn(pred, labels)
-            loss.backward()
+            if parallel_dims.pp_enabled:
+                # pipeline parallel forward / backward inside step() call
+                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+                if pp_mesh.get_local_rank() == 0:
+                    pp_schedule.step(input_ids)
+                elif is_last_stage:
+                    losses = []
+                    pp_schedule.step(target=labels, losses=losses)
+                else:
+                    pp_schedule.step()
+
+                # accumulate losses across pipeline microbatches
+                current_loss = (
+                    torch.mean(torch.stack(losses)).item() if is_last_stage else -1.0
+                )
+            else:
+                # forward / backward
+                pred = model(input_ids)
+                loss = loss_fn(pred, labels)
+                loss.backward()
+                current_loss = loss.item()

             # clip gradients
             torch.nn.utils.clip_grad_norm_(
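A possible reading of the branching above (not part of the diff): only the first stage feeds input_ids, only the last stage supplies targets and collects per-microbatch losses, and middle stages simply advance the schedule. The hypothetical helper below restates that logic in one place; run_pp_step and its signature are illustrative, not an API.

# Sketch: the stage-role branching from the diff, factored into a helper.
import torch

def run_pp_step(pp_schedule, pp_mesh, input_ids, labels):
    pp_rank = pp_mesh.get_local_rank()
    is_first_stage = pp_rank == 0
    is_last_stage = pp_rank == pp_mesh.size() - 1

    losses = []
    if is_first_stage:
        pp_schedule.step(input_ids)
    elif is_last_stage:
        pp_schedule.step(target=labels, losses=losses)
    else:
        pp_schedule.step()

    # only the last stage holds real losses; other stages report a sentinel
    return torch.mean(torch.stack(losses)).item() if is_last_stage else -1.0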
@@ -291,7 +344,6 @@ def loss_fn(pred, labels):
             optimizer.step()
             scheduler.step()

-            current_loss = loss.item()
             losses_since_last_log.append(current_loss)

             # log metrics