@@ -8,6 +8,7 @@
 from typing import Tuple

 import torch
+from pippy import annotate_split_points, Pipe, PipeSplitWrapper
 from torch.distributed._tensor import Replicate, Shard

 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -143,7 +144,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     # apply PTD parallelisms
     if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")
+        pp_mesh = world_mesh["pp"]
+        stage_idx = pp_mesh.get_local_rank()
+        layers_per_rank = len(model.layers) // parallel_dims.pp
+        for i in range(1, parallel_dims.pp):
+            annotate_split_points(
+                model,
+                {
+                    f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING
+                },
+            )
+
+        # Get example input
+        label_shape = input_shape = (8, 2048)  # TODO
+        input_ids = torch.randint(
+            model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+        )
+        labels = torch.randint(
+            model.vocab_size, label_shape, dtype=torch.int64, device="meta"
+        )
+        print("input_ids: ", input_ids.shape, input_ids.dtype)
+        print("labels: ", labels.shape, labels.dtype)
+
+        # Create a pipeline representation from the model
+        pipe = Pipe.from_tracing(model, parallel_dims.pp, example_args=(input_ids,))
+        model = pipe.get_stage_module(stage_idx)

     # First we apply Tensor Parallelism if it's enabled
     if parallel_dims.tp_enabled:
@@ -256,10 +281,14 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     meta_to_real_init_fn(model)
     model.cuda()

-    # TODO(whc) - proposal: remove this call, and assert that we always load a checkpoint
-    # we have now moved from meta to device,
-    # reset parameters for proper initialization
-    model.reset_parameters()
-    logger.info("Model fully initialized via reset_parameters")
+    if parallel_dims.pp_enabled:
+        setattr(pipe.split_gm, f"submod_{stage_idx}", model)
+        return pipe
+    else:
+        # TODO(whc) - proposal: remove this call, and assert that we always load a checkpoint
+        # we have now moved from meta to device,
+        # reset parameters for proper initialization
+        model.reset_parameters()
+        logger.info("Model fully initialized via reset_parameters")

     return model
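
Note on usage: with this change, parallelize_llama no longer always returns an nn.Module. When parallel_dims.pp_enabled is true it returns the traced pippy Pipe, with this rank's stage module swapped back in as submod_{stage_idx}; otherwise it returns the plain module as before. A minimal, hypothetical caller-side sketch of how a trainer might branch on the return type (build_model_for_training is an illustrative name, not part of this PR):

import torch.nn as nn
from pippy import Pipe


def build_model_for_training(model, world_mesh, parallel_dims, job_config):
    # parallelize_llama is the function modified in this diff; it now returns
    # either a pippy Pipe (PP enabled) or a plain nn.Module (PP disabled).
    parallelized = parallelize_llama(model, world_mesh, parallel_dims, job_config)
    if isinstance(parallelized, Pipe):
        # PP path: only this rank's stage submodule holds the parameters
        # that will be optimized locally.
        stage_idx = world_mesh["pp"].get_local_rank()
        stage_module: nn.Module = parallelized.get_stage_module(stage_idx)
        return parallelized, stage_module
    # Non-PP path: the fully initialized model itself is returned.
    return None, parallelized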