
Commit df55fbf

wip pipelinestage
ghstack-source-id: be5cff6
Pull Request resolved: #174
1 parent 6a430d8 commit df55fbf

File tree

1 file changed: +38 −1 lines changed


train.py

Lines changed: 38 additions & 1 deletion
@@ -14,6 +14,13 @@
 
 import torch
 import torch.nn.functional as F
+from pippy.PipelineSchedule import (
+    create_metadata_tensor,
+    extract_metadata_from_tensor,
+    get_stage_shapes,
+    PipelineStageV2Impl,
+    validate_stage_shapes,
+)
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
@@ -195,9 +202,39 @@ def main(job_config: JobConfig):
         pp_degree = pp_mesh.size()
         pp_rank = pp_mesh.get_local_rank()
         logger.info(
-            f"{Color.blue}Extracting pipeline module for stage {pp_mesh.get_local_rank()}{Color.reset}"
+            f"{Color.blue}Extracting pipeline module for stage {pp_rank}{Color.reset}"
         )
         model = pmod.get_stage_module(pp_mesh.get_local_rank())
+        input_shape = label_shape = (8, 2048)
+        device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+        microbatch = torch.empty(input_shape, dtype=torch.int64, device=device)
+        shape_meta = get_stage_shapes(
+            models=[model],
+            stage_ids=[pp_rank],
+            num_stages=pp_degree,
+            rank=pp_rank,
+            world_size=pp_degree,
+            group=pp_mesh.get_group(),
+            device=device,
+            microbatch=[microbatch],
+        )
+        input_args = [
+            torch.empty(s, device=device) for s in shape_meta[pp_rank]["inputs"]
+        ]
+        output_args = [
+            torch.empty(s, device=device) for s in shape_meta[pp_rank]["outputs"]
+        ]
+        stage = PipelineStageV2Impl(
+            module=model,
+            stage_id=pp_rank,
+            num_stages=pp_degree,
+            rank=world_mesh.get_rank(),
+            world_size=world_mesh.size(),
+            device=device,
+            input_args=input_args,
+            output_args=output_args,
+        )
+        print(stage)
 
         # build optimizer after applying parallelisms to the model
         optimizer = build_optimizer(model, job_config)
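Note on the shape handshake in the hunk above: before the stage can pre-allocate its receive buffers (the input_args / output_args tensors), each rank needs to learn the tensor shapes flowing in and out of its neighbours, which is what the imported get_stage_shapes / create_metadata_tensor / extract_metadata_from_tensor helpers presumably coordinate. The snippet below is only an illustrative sketch of the underlying idea, packing a shape into a fixed-size int64 tensor so same-sized buffers can be exchanged over a process group; pack_shape, unpack_shape, and MAX_DIMS are hypothetical names for this sketch, not the pippy API.

import torch

# Illustrative sketch only: a shape is written into a fixed-size int64 tensor
# (padded with -1) so every rank can exchange it with a single same-sized
# send/recv or all_gather, then decode it back. Names here are made up.
MAX_DIMS = 8  # assumed padding length, not taken from the commit

def pack_shape(shape):
    meta = torch.full((MAX_DIMS,), -1, dtype=torch.int64)
    meta[: len(shape)] = torch.tensor(shape, dtype=torch.int64)
    return meta

def unpack_shape(meta):
    return tuple(int(d) for d in meta if d >= 0)

# Round-trips the (batch, seq) shape used for the microbatch in the diff.
assert unpack_shape(pack_shape((8, 2048))) == (8, 2048)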
