-
Notifications
You must be signed in to change notification settings - Fork 476
Add backward RMSNorm+Add fusion #2028
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Jan Bielak <[email protected]>
Signed-off-by: Jan Bielak <[email protected]>
Signed-off-by: Jan Bielak <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Jan Bielak <[email protected]>
/te-ci |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM, with some nits and future optimizations.
auto kernel = launch_params.params.add ? &rmsnorm_bwd_tuned_kernel<Kernel_traits, true> | ||
: &rmsnorm_bwd_tuned_kernel<Kernel_traits, false>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are doubling the number of kernel compilations for RMSNorm. This is fine for now, but it makes an NVRTC impl more compelling.
@@ -111,6 +118,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke | |||
} | |||
} | |||
|
|||
maybe_t<Ovec[LDGS], FusedAdd> add; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Future optimization: Load add
directly into dx
to avoid extra register usage. However, this would require that Ivec
and Ovec
are identical types.
transformer_engine/common/include/transformer_engine/normalization.h
Outdated
Show resolved
Hide resolved
Co-authored-by: Tim Moon <[email protected]> Signed-off-by: Jan Bielak <[email protected]> Signed-off-by: Jan Bielak <[email protected]>
22d91b0
to
65aeee7
Compare
Signed-off-by: Tim Moon <[email protected]>
/te-ci core pytorch |
Signed-off-by: Jan Bielak <[email protected]>
Signed-off-by: Jan Bielak <[email protected]>
/te-ci pytorch core |
Signed-off-by: Jan Bielak <[email protected]>
/te-ci core |
…error in `test_normalization.cu` Signed-off-by: Jan Bielak <[email protected]>
/te-ci core |
Signed-off-by: Jan Bielak <[email protected]>
/te-ci core |
Description
Fuse
MakeExtraOutput
withRMSNorm
in the backward pass. Works through addingnvte_rmsnorm_bwd_add
.Type of change
Changes
nvte_rmsnorm_bwd_add
as well as PyTorch binding, test, and update kernel to support the fusionBackwardAddRMSNorm
fused operation, including testChecklist: