Layernorm bwd OPT #1880
base: main
Conversation
Maybe we can follow the recent CUDA change here: pytorch/pytorch@73b4938
Pull Request Overview
This PR optimizes the backward pass computation for layer normalization's gamma and beta gradients by implementing a two-stage column reduction approach to improve parallelism when the matrix dimension M is much larger than N.
- Introduces a new `GammaBetaReduceFunctor` kernel that uses tiled computation with local memory for better occupancy
- Adds logic to automatically select between the optimized two-stage reduction and the existing simple kernel based on occupancy thresholds
- Implements separate code paths for the different combinations of gamma and beta gradient computations
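For context, the quantities being reduced are the standard LayerNorm parameter gradients; the notation below is mine rather than the PR's ($\partial L/\partial y_{ij}$ is the upstream gradient, $\mu_i$ and $\mathrm{rstd}_i$ the per-row mean and inverse standard deviation):

$$
\frac{\partial L}{\partial \gamma_j} = \sum_{i=1}^{M} \frac{\partial L}{\partial y_{ij}}\,\hat{x}_{ij},
\qquad
\frac{\partial L}{\partial \beta_j} = \sum_{i=1}^{M} \frac{\partial L}{\partial y_{ij}},
\qquad
\hat{x}_{ij} = (x_{ij} - \mu_i)\,\mathrm{rstd}_i .
$$

Both are pure column reductions over the M rows, which is why a single-workgroup kernel becomes the bottleneck when M >> N.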
Co-authored-by: Copilot <[email protected]>
Force-pushed e8bef72 to 0bfee0f
I noticed that layer norm backward for gamma and beta is very slow when the column is much longer than the row, i.e., an [M, N] column reduction with M >> N.

For example, in timm tnt_s_patch16_224 training, the layernorm backward shape is [25088, 16, 24] with normalized shape [24], i.e., M = 25088 × 16 = 401408 and N = 24, and the existing kernel launches only one workgroup. I use a two-stage column reduction to increase parallelism. GammaBetaBackwardSimpleKernelFunctor takes 9 ms on PVC and 8.5 ms on BMG. After the optimization, GammaBetaReduceFunctor plus two sum kernels perform the column reduction in 0.09 ms + 0.06 ms × 2 (about 0.21 ms total) on PVC and 0.19 ms + 0.04 ms × 2 (about 0.27 ms total) on BMG.