 import torch
 import torch.nn.functional as F
+
+# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
+try:
+    from pippy import ScheduleGPipe
+    from pippy.PipelineStage import _PipelineStage
+except ImportError as exc:
+    raise ImportError(
+        "pippy is not installed. Please install it to use pipeline parallelism. "
+        "`pip install git+https://github.com/pytorch/pippy`"
+    ) from exc
+
 from torch.distributed import destroy_process_group
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -126,7 +137,8 @@ def main(job_config: JobConfig):
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    torch.cuda.set_device(device)
     init_distributed(job_config)

     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -144,6 +156,15 @@ def main(job_config: JobConfig):
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+
+    else:
+        pp_degree, pp_rank = 1, 0
+
     data_loader = build_hf_data_loader(
         job_config.training.dataset,
         job_config.training.dataset_path,
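
The hunk above only reads the `"pp"` sub-mesh out of `world_mesh`; the mesh itself comes from `parallel_dims.build_mesh`. As a rough sketch of how such a named mesh behaves (the 2x2 shape, the `"dp"` dim name, and the torchrun launch are illustrative assumptions, not this PR's code):

```python
# Sketch only: a 2 (pp) x 2 (dp) device mesh, launched with
#   torchrun --nproc_per_node=4 mesh_sketch.py
# Shapes and dim names are illustrative assumptions.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Build a named 2D mesh: dim 0 is pipeline parallel, dim 1 is data parallel.
world_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "dp"))

# Indexing by name returns the 1D sub-mesh this rank belongs to.
pp_mesh = world_mesh["pp"]
pp_degree = pp_mesh.size()          # 2 pipeline stages
pp_rank = pp_mesh.get_local_rank()  # this rank's stage index
pp_group = pp_mesh.get_group()      # ProcessGroup for stage-to-stage comms

print(f"rank {local_rank}: pp_rank={pp_rank} of {pp_degree}")
```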
@@ -201,13 +222,44 @@ def loss_fn(pred, labels):
     # obtain the peak flops of bf16 type for MFU calculation
     gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name)

-    # apply PT-D parallelisms and activation checkpointing
+    if parallel_dims.pp_enabled:
+        # TODO(whc) now I need to figure out how to align this with the `models_parallelize_fns[model_name]` pattern
+        from torchtitan.parallelisms.parallelize_llama import apply_pipeline_parallelism
+
+        model, pipe_info = apply_pipeline_parallelism(
+            model, world_mesh, parallel_dims, job_config
+        )
+
+    # apply PT-D DP/TP parallelisms and activation checkpointing
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
+
     model.to_empty(device="cuda")
-    model.init_weights()
+
+    # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
+    # there are virtual stages
+    if parallel_dims.pp_enabled:
+        stage = _PipelineStage(
+            stage_module=model,
+            stage_index=pp_rank,
+            pipe_info=pipe_info,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
+        pp_schedule = ScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
+    else:
+        # If PP is enabled, we can't use init_weights. Instead, we have to rely on offline creating an initial checkpoint
+        # and loading it to get initialization values. This is because the init_weights functions are written assuming
+        # the whole model (all its weights, or FQNs) exists on one rank. In PP, init_weights on stage 1 might crash
+        # because it can't find the "embedding" layer, for example.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.init_weights()

     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
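
`ScheduleGPipe` implements a fill-drain schedule: each `step()` splits the batch into `n_microbatches`, runs the stage's forward (and backward) once per microbatch, and on the last stage records one loss per microbatch so the training loop can reduce them. The single-process toy below only illustrates that per-step accounting; the model, sizes, and loss reduction are assumptions for the illustration, not pippy's implementation:

```python
# Single-process illustration of one GPipe-style step over microbatches.
# This is NOT pippy's implementation, only the accounting it performs.
import torch
import torch.nn as nn
import torch.nn.functional as F

n_microbatches = 4
stage = nn.Linear(16, 8)             # stand-in for one pipeline stage
inputs = torch.randn(32, 16)         # global batch on this rank
labels = torch.randint(0, 8, (32,))

losses = []
for mb_input, mb_label in zip(
    inputs.chunk(n_microbatches), labels.chunk(n_microbatches)
):
    pred = stage(mb_input)                  # forward one microbatch
    loss = F.cross_entropy(pred, mb_label)  # last stage computes the loss
    loss.backward()                         # grads accumulate across microbatches
    losses.append(loss.detach())

# same reduction the training loop applies on the last stage
step_loss = torch.mean(torch.stack(losses))
print(step_loss)
```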
@@ -219,7 +271,6 @@ def loss_fn(pred, labels):
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
-
     metric_logger = build_metric_logger(job_config)

     # torch.compile model for improved performance
@@ -257,7 +308,13 @@ def loss_fn(pred, labels):
         logger.info("Created seed checkpoint")
         return

-    checkpoint.load()
+    checkpoint_loaded = checkpoint.load()
+
+    if parallel_dims.pp_enabled and not checkpoint_loaded:
+        raise RuntimeError(
+            "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
+            "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
+        )

     # plot losses loaded from checkpoint (if any) to TensorBoard
     # NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
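
The reason for this check is in the comment above: under PP each rank materializes only its own stage, so a whole-model `init_weights` cannot run; instead the full model is initialized once offline and every stage loads just the FQNs it owns. A toy single-process illustration of that idea follows (module and file names are made up, and torchtitan uses `create_seed_checkpoint.sh` with `torch.distributed.checkpoint`, not `torch.save`):

```python
# Toy illustration of the seed-checkpoint idea: full init once, then each
# pipeline stage loads only the keys it owns. Names here are hypothetical.
import torch
import torch.nn as nn


class FullModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(100, 16)
        self.output = nn.Linear(16, 100)


class LastStage(nn.Module):
    # the last pipeline stage owns only the output projection
    def __init__(self):
        super().__init__()
        self.output = nn.Linear(16, 100)


# offline, on one rank: initialize the whole model and save a "seed" checkpoint
torch.save(FullModel().state_dict(), "seed.pt")

# at training time, a stage loads just its own FQNs; a whole-model
# init_weights here would fail because tok_embeddings does not exist.
stage = LastStage()
full_sd = torch.load("seed.pt")
stage_sd = {k: v for k, v in full_sd.items() if k in stage.state_dict()}
stage.load_state_dict(stage_sd, strict=True)
```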
@@ -299,14 +356,33 @@ def loss_fn(pred, labels):
         input_ids = input_ids.cuda()
         labels = labels.cuda()
-
         optimizer.zero_grad()

-        # forward / backward
-        with loss_parallel_ctx():
-            pred = model(input_ids)
-            loss = loss_fn(pred, labels)
-            loss.backward()
+        if parallel_dims.pp_enabled:
+            # pipeline parallel forward / backward inside step() call
+            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+            with loss_parallel_ctx():
+                if pp_mesh.get_local_rank() == 0:
+                    pp_schedule.step(input_ids)
+                elif is_last_stage:
+                    losses = []
+                    pp_schedule.step(target=labels, losses=losses)
+                else:
+                    pp_schedule.step()
+
+            # accumulate losses across pipeline microbatches
+            loss = (
+                torch.mean(torch.stack(losses))
+                if is_last_stage
+                else torch.Tensor([-1.0])
+            )
+        else:
+            # Non-PP forward / backward
+            with loss_parallel_ctx():
+                pred = model(input_ids)
+                loss = loss_fn(pred, labels)
+                loss.backward()

         # clip gradients
         torch.nn.utils.clip_grad_norm_(