[DO NOT MERGE][example] fold batch and sequence dimensions to accelerate Sequence Parallel #437
base: gh/tianyu-l/12/base
Conversation
@@ -350,37 +350,41 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
Curious: could this be `output_layouts=Shard(0)`, so that the `PrepareModuleInput` would not be needed?
@awgu
Currently we do the folding after the embedding layer, so we can't do what you suggested.
But I just realized that we could fold even before the embedding layer; then I think this would work, just like in the non-folding case.
@awgu
OK, I tried out the change; please see the comparison here.
Everything works except that CI fails with:
RuntimeError: It seems that we cannot capture your model as a full graph. Typical reasons include graph breaks, data/shape-dependent control flow, or missing meta kernels for custom operators. You can use our manual pipeline interfaces, or try to fix the graph breaks
So I decided to change it back.
torchtitan/models/llama/model.py
Outdated
@@ -187,7 +185,10 @@ def forward(
        torch.Tensor: Output tensor after attention.

    """
    bs, seqlen, _ = x.shape
    # dim 0 of x is a folded dimension of [bs, seqlen]
Nit (for consistency with other comments, though it does not matter since this is not for landing):

- # dim 0 of x is a folded dimension of [bs, seqlen]
+ # dim 0 of x is a folded dimension of (bs, seqlen)
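To make the fold/unfold pattern discussed in this diff concrete, here is a minimal sketch (the function name and the identity stub for attention are hypothetical, not the PR's code): the input arrives with dim 0 as a folded `(bs * seqlen)` dimension, is unfolded for attention, and is folded back before the next Sequence Parallel collective.

```python
import torch

def attention_forward_folded(x: torch.Tensor, bs: int, seqlen: int) -> torch.Tensor:
    # dim 0 of x is a folded dimension of (bs, seqlen): x is (bs * seqlen, dim)
    dim = x.shape[-1]
    # unfold to (bs, seqlen, dim); attention needs an explicit sequence dim
    h = x.view(bs, seqlen, dim)
    # ... the actual attention computation would go here; stubbed as identity ...
    out = h
    # fold back to (bs * seqlen, dim) so the next SP collective shards dim 0
    return out.reshape(bs * seqlen, dim)

x = torch.randn(2 * 8, 16)  # bs=2, seqlen=8, dim=16, folded layout
y = attention_forward_folded(x, bs=2, seqlen=8)
assert y.shape == (2 * 8, 16)
assert torch.equal(y, x)  # identity because attention is stubbed out
```

Because the collectives then operate on dim 0 of a contiguous tensor, the gathered shards need no reassembly `cat`.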
FWIW, this can also be achieved with torch.compile + force_stride_order, without changing the model code. Basically, we can force the stride order of the all-gather/reduce-scatter input so that the extra cat is not needed. Async-TP currently does this (example). With some work we can make it work for all-gather/reduce-scatter too.
Why is this marked as "example" and "do not merge"? What is the issue with this PR? Thanks!
Stack from ghstack (oldest at bottom):

Note: This PR is for showcasing purpose only and is almost a reverse of #190.

At the cost of model code change, we can obtain better Sequence Parallel performance. Without folding and unfolding, all-gather and reduce-scatter are performed on dim 1 (sequence dim) instead of dim 0 (folded dim), which incurs an extra `aten.cat` after each collective.

Stats from @awgu:

> for 8k seq len, batch size 1 on H100, these two cats take about 0.18 ms out of 3 ms of FFN compute (6%)

Experiment on 8-layer `debug_model`

before:

![image](https://github.com/pytorch/torchtitan/assets/150487191/04e5ea4b-fa9e-48e5-92be-582841cb2796)

after:

![image](https://github.com/pytorch/torchtitan/assets/150487191/38c39506-462d-485a-a16c-48770a28edb0)
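To illustrate why sharding the folded dim avoids the extra `aten.cat`, here is a toy single-process simulation of the two sharding layouts (the variable names and the use of `chunk`/`cat` to stand in for the collectives are illustrative assumptions, not the PR's code):

```python
import torch

world_size = 4
bs, seqlen, dim = 2, 8, 16
full = torch.randn(bs * seqlen, dim)  # folded layout: (bs * seqlen, dim)

# Shard(0) on the folded dim: each rank holds a contiguous block of rows.
# An all-gather can write these blocks directly into one flat output buffer,
# so reassembling along dim 0 is essentially free.
shards0 = list(full.chunk(world_size, dim=0))
gathered0 = torch.cat(shards0, dim=0)
assert torch.equal(gathered0, full)

# Shard(1) on the sequence dim of the unfolded (bs, seqlen, dim) layout:
# the gathered chunks interleave along dim 1, so reassembling them requires
# an explicit aten.cat (a real copy) after every collective.
unfolded = full.view(bs, seqlen, dim)
shards1 = list(unfolded.chunk(world_size, dim=1))
gathered1 = torch.cat(shards1, dim=1)
assert torch.equal(gathered1, unfolded)
```

In the real Sequence Parallel path the `cat` along dim 1 is what shows up as the ~0.18 ms overhead quoted above; the dim-0 layout lets the collective produce the final tensor directly.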