
Conversation

@tianyu-l (Contributor) commented on Jul 4, 2024

Stack from ghstack (oldest at bottom):

Note: This PR is for showcasing purposes only and is almost a reverse of #190.

At the cost of a model code change, we can obtain better Sequence Parallel performance. Without folding and unfolding, all-gather and reduce-scatter are performed on dim 1 (the sequence dim) instead of dim 0 (the folded dim), which incurs an extra `aten.cat` after each collective.

Stats from @awgu:
> for 8k seq len, batch size 1 on H100, these two cats take about 0.18 ms out of 3 ms of FFN compute (6%)
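To see why sharding on the folded dim 0 avoids the extra copy, here is a toy, single-process sketch (my illustration, not the actual torchtitan code; `torch.stack` stands in for the collective writing each rank's chunk back-to-back into one flat receive buffer): gathering dim-0 shards is just a free view of the receive buffer, while gathering dim-1 shards needs a transpose plus a materializing copy, which is the extra `aten.cat`.

```python
import torch

bs, seqlen, dim, world = 2, 8, 4, 2
x = torch.randn(bs, seqlen, dim)

# Sequence-dim (dim 1) sharding: rank r holds x[:, r*(seqlen//world):(r+1)*(seqlen//world), :].
seq_shards = list(x.chunk(world, dim=1))
recv = torch.stack(seq_shards)            # (world, bs, seqlen//world, dim) flat buffer
# Recovering (bs, seqlen, dim) needs a transpose + copy -- the extra cat:
y = recv.transpose(0, 1).reshape(bs, seqlen, dim)
assert torch.equal(y, x)
assert y.data_ptr() != recv.data_ptr()    # reshape had to materialize a copy

# Folded (dim 0) sharding: rank r holds a contiguous row block of (bs*seqlen, dim).
folded = x.reshape(bs * seqlen, dim)
fold_shards = list(folded.chunk(world, dim=0))
frecv = torch.stack(fold_shards)          # (world, bs*seqlen//world, dim) flat buffer
fy = frecv.reshape(bs * seqlen, dim)      # a free view of the receive buffer
assert torch.equal(fy, folded)
assert fy.data_ptr() == frecv.data_ptr()  # no copy needed
```

The asymmetry is purely about memory layout: dim-0 chunks of a contiguous tensor are themselves contiguous, so the flat receive buffer already is the gathered result.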

Experiment on 8-layer `debug_model`
before:
<img width="1023" alt="image" src="https://github.com/pytorch/torchtitan/assets/150487191/04e5ea4b-fa9e-48e5-92be-582841cb2796">
after:
<img width="1023" alt="image" src="https://github.com/pytorch/torchtitan/assets/150487191/38c39506-462d-485a-a16c-48770a28edb0">

tianyu-l added a commit that referenced this pull request Jul 4, 2024
@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Jul 4, 2024
@tianyu-l changed the title from "fold batch and sequence dimensions to accelerate Sequence Parallel" to "[DO NOT MERGE][example] fold batch and sequence dimensions to accelerate Sequence Parallel" on Jul 4, 2024
tianyu-l added a commit that referenced this pull request Jul 4, 2024
@@ -350,37 +350,41 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
A Collaborator commented:
Curious: could this be `output_layouts=Shard(0)`, so that the `PrepareModuleInput` is not needed?

@tianyu-l (Author) replied:
@awgu
Currently we do the folding after the embedding layer, so we can't do what you suggested.
But I just realized that maybe we could do the folding even before the embedding layer; then I think we could do this, just like in the non-folding case.

@tianyu-l (Author) replied:

@awgu
OK I tried out the change. Please see comparison here.
Everything works except that CI fails with:

> RuntimeError: It seems that we cannot capture your model as a full graph. Typical reasons include graph breaks, data/shape-dependent control flow, or missing meta kernels for custom operators. You can use our manual pipeline interfaces, or try to fix the graph breaks

So I decided to change it back.
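For readers following along, the model-code change under discussion boils down to collapsing (bs, seqlen) into a single leading dimension around the transformer blocks, so that SP collectives shard dim 0. A minimal sketch of the idea (helper names are illustrative, not the actual torchtitan code):

```python
import torch

def fold(x: torch.Tensor) -> torch.Tensor:
    """Collapse (bs, seqlen, dim) -> (bs*seqlen, dim); a free view on contiguous input."""
    bs, seqlen, dim = x.shape
    return x.view(bs * seqlen, dim)

def unfold(x: torch.Tensor, bs: int) -> torch.Tensor:
    """Restore (bs*seqlen, dim) -> (bs, seqlen, dim)."""
    n, dim = x.shape
    return x.view(bs, n // bs, dim)

h = torch.randn(2, 8, 16)
assert torch.equal(unfold(fold(h), bs=2), h)   # round-trip is lossless
assert fold(h).shape == (16, 16)
```

Where exactly the fold happens (before or after the embedding layer) is what determines whether `tok_embeddings` can emit `Shard(0)` directly, as discussed above.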

@@ -187,7 +185,10 @@ def forward(
torch.Tensor: Output tensor after attention.

"""
bs, seqlen, _ = x.shape
# dim 0 of x is a folded dimension of [bs, seqlen]
A Collaborator commented:
nit: for consistency with other comments, though it does not matter since this is not for landing

Suggested change:
- # dim 0 of x is a folded dimension of [bs, seqlen]
+ # dim 0 of x is a folded dimension of (bs, seqlen)

@yifuwang commented on Jul 8, 2024

FWIW, this can also be achieved with `torch.compile` + `force_stride_order`, without changing the model code.

Basically, we can force the stride order of the all-gather/reduce-scatter input to be such that `input.swapdims(0, dim)` is contiguous, and stop enforcing the contiguity of all-gather/reduce-scatter outputs. This way, the layout transformation is subsumed into the leading/following pointwise ops.

Async-TP currently does this (example). With some work we can make it work for all-gather/reduce-scatter too.
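Here is a toy illustration of the layout condition described above (my sketch, not the Async-TP code): forcing a stride order in which swapping dim 0 with the sharded dim is contiguous means each rank's sequence-dim chunk is already one contiguous block of memory, so a collective can consume or produce it without a layout-fixing cat.

```python
import torch

x = torch.randn(2, 8, 4)          # (bs, seqlen, dim), sharded along dim 1

# Default row-major layout: swapping dims 0 and 1 yields a strided view.
assert not x.swapdims(0, 1).is_contiguous()

# Force the stride order so that x.swapdims(0, 1) is contiguous, i.e. make
# the sequence dim outermost in memory while keeping the logical shape:
x_forced = x.swapdims(0, 1).contiguous().swapdims(0, 1)
assert x_forced.shape == x.shape
assert x_forced.swapdims(0, 1).is_contiguous()

# Now a chunk along the sequence dim is one contiguous memory block:
chunk = x_forced.narrow(1, 0, 4)  # first half of the sequence positions
assert chunk.swapdims(0, 1).is_contiguous()
```

The `.contiguous()` copy here is only for constructing the example; in the compiled setting the producer op would write directly into this layout.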

tianyu-l added a commit that referenced this pull request Jul 9, 2024
tianyu-l added a commit that referenced this pull request Jul 9, 2024
tianyu-l added a commit that referenced this pull request Jul 9, 2024
tianyu-l added a commit that referenced this pull request Jul 10, 2024
@tianyu-l force-pushed the gh/tianyu-l/12/base branch from b0ed7f0 to 64d47fd on August 16, 2024
@tianyu-l force-pushed the gh/tianyu-l/12/head branch from 4945b85 to 43c08cd on August 16, 2024
tianyu-l added a commit that referenced this pull request Aug 16, 2024
@lw commented on Dec 31, 2024

Why is this marked as "example" and "do not merge"? What is the issue with this PR? Thanks!

@awgu (Collaborator) commented on Dec 31, 2024

@lw Because this requires changing the model code, I think @tianyu-l left it as a non-merged PR to show how people could change their own (forked) code to enable this.
