Add Pipeline Parallel (and 2D PP+FSDP) support #161
Conversation
train.py
Outdated
logger.info(
    f"{Color.blue}Extracting pipeline module for stage {pp_mesh.get_local_rank()}{Color.reset}"
)
model = pmod.get_stage_module(pp_mesh.get_local_rank())
nit: watch out for rank-stage inequality in case of Interleaved 1F1B.
Yeah, I need to switch to an interleaved schedule and clean this up.
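For reference, a minimal sketch of why rank and stage diverge under interleaved 1F1B (the helper below is hypothetical, not part of this PR): each rank owns several virtual stages, so indexing the stage module by `pp_mesh.get_local_rank()` alone is not enough.

```python
def stage_ids_for_rank(pp_rank: int, pp_size: int, num_stages: int) -> list[int]:
    # Round-robin assignment commonly used for interleaved 1F1B:
    # rank r owns stages r, r + pp_size, r + 2*pp_size, ...
    assert num_stages % pp_size == 0
    return [pp_rank + i * pp_size for i in range(num_stages // pp_size)]

# With 4 pipeline ranks and 8 virtual stages, rank 1 owns stages [1, 5],
# so pmod.get_stage_module(pp_rank) would only pick up the first of them.
print(stage_ids_for_rank(pp_rank=1, pp_size=4, num_stages=8))  # [1, 5]
```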
Thanks a lot for the demo! LGTM!
- uses pipeline tracer frontend to extract a graph and partition it into chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose option to switch schedule and test other schedules)
- supports 2D parallelism currently, 3D (TP) is work in progress

ghstack-source-id: 6bd8013
Pull Request resolved: #161
LGTM!
Thanks for pulling PP in!
print("labels: ", labels.shape, labels.dtype) | ||
|
||
# Create a pipeline representation from the model | ||
pipe = pipeline(model, parallel_dims.pp, example_args=(input_ids,)) |
nit: strictly speaking, the second arg is the number of microbatches -- it is okay to use the PP dim to represent it for now. Longer term, I think it should be exposed as a field in the config file.
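A small sketch of the longer-term idea (the helper and the `configured` parameter are hypothetical, not an existing config field): let a config value drive the microbatch count, and fall back to the PP degree only when it is unset.

```python
def resolve_num_microbatches(batch_size: int, pp_degree: int, configured=None) -> int:
    # Use the configured microbatch count if given, else fall back to the PP degree
    # (the current behavior of passing parallel_dims.pp as the second arg).
    n = configured if configured is not None else pp_degree
    assert batch_size % n == 0, "batch must split evenly into microbatches"
    return n

print(resolve_num_microbatches(batch_size=8, pp_degree=2))                # 2
print(resolve_num_microbatches(batch_size=8, pp_degree=2, configured=4))  # 4
```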
)

# Get example input
label_shape = input_shape = (8, 2048)  # TODO
Hmm, would PP work for all cases that are not this shape, or does it require this to be the exact input shape at runtime?
Need to double-check how this works and fix it.
# TODO(whc) need to fix PP + FSDP-mixed-precision
# tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
# param_dtype=torch.bfloat16, reduce_dtype=torch.float32
param_dtype=torch.float32,
We shouldn't change this by default; it would make the FSDP and FSDP + TP cases use fp32 instead of bf16.
I wonder if supporting bf16 should be a criterion for landing. I would imagine that training with FSDP + PP in fp32 is not really viable efficiency-wise (at least for larger jobs).
I think we should fix this before landing the PP change. There was a possible way to fix this in the tracer, but I lost track of it; I'll dig it up.
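For context, this is the mixed-precision policy the code comments above describe as the intended end state once the tracer handles bf16 inputs; the import path below is an assumption based on the FSDP2 API, not something this PR adds.

```python
import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

# Intended policy (sketch): compute/communicate params in bf16, reduce grads in fp32.
# The PR temporarily forces param_dtype=torch.float32 to keep the PP tracer happy.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)
```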
train.py
Outdated
# there are virtual stages
if parallel_dims.pp_enabled:
    stage = PipelineStage(
        pipe=pipe_meta,
Should this be pipe_meta or model?
It's correct. Ke proposed an alternative, but we'd still have to pass the pipe_info and the model into _PipelineStage in that case. I could make this change.
pipe=pipe_meta,
stage_index=pp_rank,
device=device,
group=pp_mesh.get_group(),
Wondering if we should put the stage creation into parallelize_llama; IMO we only need pp_schedule in train.py.
Yeah, I think this question and Ke's suggestion about returning a PipelineStage from parallelize_llama are better taken up in the context of a follow-up PR that also adds support for looped schedules.
Looped schedules further complicate things because the PP logic first needs to chunk up the model, then apply the DP/TP portion of parallelize_llama on each chunk, and finally pass all the chunks into the schedule.
In the end, I might prefer to separate PP out from parallelize_llama and have a flow where we take the return value of the PP apply function and iteratively call parallelize_llama on those chunks.
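A minimal sketch of the flow described above (all names here are hypothetical placeholders, not this PR's API): PP splits first, DP/TP is applied per chunk, and the resulting chunks feed a looped schedule.

```python
def apply_pp_then_dp_tp(model, pp_split_fn, parallelize_fn, build_schedule_fn):
    # 1) PP front-end splits the model into the virtual-stage chunks owned by this rank.
    chunks = pp_split_fn(model)
    # 2) Apply the DP/TP portion of parallelization to each chunk independently.
    chunks = [parallelize_fn(chunk) for chunk in chunks]
    # 3) Hand all chunks to the schedule (e.g. an interleaved/looped 1F1B).
    return build_schedule_fn(chunks)
```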
loss = (
    torch.mean(torch.stack(losses))
    if is_last_stage
    else torch.Tensor([-1.0])
Why do we need the default -1 value? Is it for logging purposes?
Oh, yeah, I could make it None, but then I'd have to update the logger to not log at all. Maybe that's actually a better way to do it; let me try that.
OK, so what I could do is alter the metrics code so that on non-last-stage ranks we either omit printing the loss or print "loss: None" instead of -1.
The change will add more lines of code, since I need to deal with several places that expect loss and global_[avg/mean]_loss to be valid numbers:
- avoid writing them into the metrics dict
- replace their format string with a string value instead of a float value in the logger.info call
- avoid calling loss.item() in the first place
I agree in principle that's the "right" fix, but I'm not sure it's worth the LOC / complexity. I don't totally hate the -1 thing.
Another option I considered is to skip the whole '# log metrics' code block on non-last-stage ranks. I ruled this out, since it is still useful to log MFU and memory for the other ranks.
So let me know what you want to do here @wanchaol
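A sketch of the "loss: None" option (the helper is hypothetical, not code from this PR); only the loss field changes, so MFU/memory logging on non-last-stage ranks is unaffected.

```python
import torch

def format_loss_for_log(losses, is_last_stage: bool) -> str:
    # Non-last-stage ranks never compute a real loss, so report None
    # instead of the -1.0 sentinel.
    if not is_last_stage:
        return "loss: None"
    loss = torch.mean(torch.stack(losses)).item()
    return f"loss: {loss:.4f}"
```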
    for i in range(1, parallel_dims.pp)
}
# Get example input
label_shape = input_shape = (8, 2048)  # TODO
@kwen2501 any ideas for a clean way to do this in torchtrain? Do we expect people to get a batch out of their dataloader and then reset it, or do we expect people to hardcode it?
I think what I might do is pass input_shape directly from train.py, and in train.py I can set input_shape = (job_config.batch_size, job_config.seq_len) or something. Is that clean enough?
OK, pushed a variation on this.
Not sure whether it's better to hide this inside parallelize, since we already have the job config there, or to make it explicit in train.py that we are passing input_shape in for some reason.
Either way sounds okay to me -- eventually, the shape comes from the config.
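A tiny sketch of the config-driven variant (the attribute names on job_config are an assumption for illustration, following the comment above):

```python
def example_input_shape(job_config) -> tuple:
    # Derive the example input shape from the job config instead of
    # hardcoding (8, 2048); exact attribute paths may differ in the repo.
    return (job_config.batch_size, job_config.seq_len)
```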
layers_per_rank = len(model.layers) // parallel_dims.pp
split_spec = {
    f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, parallel_dims.pp)
I'm new to the PP API and have a question: if layers_per_rank = 5 and parallel_dims.pp = 2, what should the split_spec be? My straightforward thought is that SplitPoint.BEGINNING should be added for i = 1, 3, 5, but according to the code it's just i = 1.
parallel_dims.pp refers to the number of pipeline stages we split the model into.
For example, if the model has 10 layers and pp = 2, then 10 // 2 = 5, so we put 5 layers per stage (i.e. layers_per_rank = 5).
Hence we make a single cut at model.layers.5 -- that is, (nRanks - 1) split points.
squashed
Stack from ghstack (oldest at bottom):

- uses pipeline tracer frontend to extract a graph and partition it into chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose option to switch schedule and test other schedules)