Add Sequence Parallelism to llama #32
Conversation
    parallelize_plan=layer_plan,
)

rank0_log(f"Applied Sequence Parallelism to the model...")
I wonder if it's useful to log more info about the SP plan. I was thinking about this for PP too: what info do we want to print? Should each parallelism print its own summary, or should we have one overall function that prints the overall parallel info in a unified way?
🤔 That's a good point. I think we should probably log the parallelize plan for SP. This would require some changes in PyTorch to add `__str__` to our ParallelStyles; I can add the log once the PyTorch PR is merged.

> Should each parallelism print its own summary, or should we have one overall function that prints overall parallel info in a unified way?

My two cents: it's a bit tricky to give an overall summary. I think we should first figure out how to print the intended summary for each parallelism. For example, when too many TransformerBlocks are stacked we can't log/print every layer's parallel plan, so maybe we print the PP degree of TransformerBlocks, and we might not want to print the SP plan for each PP TransformerBlock.
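As a strawman for the per-parallelism summary, a helper along these lines could log the SP plan today without waiting on `__str__` for the ParallelStyles; `rank0_log` is the logging helper already used in this diff, the rest is illustrative only:

```python
def log_parallelize_plan(plan, prefix="SP"):
    # Illustrative sketch: one line per module pattern, using the class name
    # of its ParallelStyle instead of a (not-yet-existing) __str__.
    for fqn, style in plan.items():
        rank0_log(f"{prefix} plan: {fqn} -> {type(style).__name__}")


# Log the block-level plan once, rather than once per stacked TransformerBlock.
log_parallelize_plan(layer_plan)
```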
Looks great to me! One inline question:
distribute_rmsnorm(transformer_block.attention_norm, tp_mesh)
distribute_rmsnorm(transformer_block.ffn_norm, tp_mesh)
Shall we also apply it to the final norm after all the transformer blocks?
Not something currently enabled, but I think we can explore this in real training and see if sharding the final norm would give additional memory/perf benefits :)
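For reference, if we do try it, the change would likely be a one-liner mirroring the calls above, assuming the final norm is exposed as `model.norm` (the attribute name is an assumption):

```python
# Hypothetical: also shard the final norm after the last TransformerBlock,
# reusing the same helper applied to attention_norm / ffn_norm above.
distribute_rmsnorm(model.norm, tp_mesh)
```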
* Add other filtering options beyond group
* Address review comments
Stack from ghstack (oldest at bottom):

Somehow torch.compile is not working even though eager sequence parallelism works, so currently we don't turn it on by default.
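Until the compile failure is root-caused, one way to keep `torch.compile` on for the non-SP path is to gate it on the parallelism setting; the flag names below (`enable_compile`, `sequence_parallel`) are placeholders, not actual job config fields:

```python
# Placeholder flags: keep torch.compile enabled when SP is off, since eager
# sequence parallelism works but the compiled path currently fails.
if enable_compile and not sequence_parallel:
    model = torch.compile(model)
```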