@@ -14,6 +14,13 @@

 import torch
 import torch.nn.functional as F
+from pippy.PipelineSchedule import (
+    create_metadata_tensor,
+    extract_metadata_from_tensor,
+    get_stage_shapes,
+    PipelineStageV2Impl,
+    validate_stage_shapes,
+)
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
@@ -195,9 +202,39 @@ def main(job_config: JobConfig):
     pp_degree = pp_mesh.size()
     pp_rank = pp_mesh.get_local_rank()
     logger.info(
-        f"{Color.blue}Extracting pipeline module for stage {pp_mesh.get_local_rank()}{Color.reset}"
+        f"{Color.blue}Extracting pipeline module for stage {pp_rank}{Color.reset}"
     )
     model = pmod.get_stage_module(pp_mesh.get_local_rank())
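+    # Build a dummy microbatch on this rank's GPU; shapes are hard-coded (batch, seq_len)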
+    input_shape = label_shape = (8, 2048)
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    microbatch = torch.empty(input_shape, dtype=torch.int64, device=device)
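+    # Propagate the dummy microbatch across the pipeline group to record
+    # each stage's input/output tensor shapes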
+    shape_meta = get_stage_shapes(
+        models=[model],
+        stage_ids=[pp_rank],
+        num_stages=pp_degree,
+        rank=pp_rank,
+        world_size=pp_degree,
+        group=pp_mesh.get_group(),
+        device=device,
+        microbatch=[microbatch],
+    )
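+    # Preallocate placeholder tensors matching this stage's inferred inputs and outputs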
+    input_args = [
+        torch.empty(s, device=device) for s in shape_meta[pp_rank]["inputs"]
+    ]
+    output_args = [
+        torch.empty(s, device=device) for s in shape_meta[pp_rank]["outputs"]
+    ]
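+    # Wrap this rank's module in a pipeline stage that a pipeline schedule can drive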
+    stage = PipelineStageV2Impl(
+        module=model,
+        stage_id=pp_rank,
+        num_stages=pp_degree,
+        rank=world_mesh.get_rank(),
+        world_size=world_mesh.size(),
+        device=device,
+        input_args=input_args,
+        output_args=output_args,
+    )
+    print(stage)

     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)