Add backward RMSNorm+Add fusion #2028


Open
janekb04 wants to merge 15 commits into main from normalization-backward-add-fusion

Conversation


@janekb04 (Contributor) commented Aug 5, 2025

Description

Fuse MakeExtraOutput with RMSNorm in the backward pass. This is implemented by adding a new nvte_rmsnorm_bwd_add function.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • (8be541e) Update outdated comments
  • (165f3fb) Add nvte_rmsnorm_bwd_add along with a PyTorch binding and test, and update the kernel to support the fusion
  • (c2d43b7) Add BackwardAddRMSNorm fused operation, including test

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 (Collaborator) commented:

/te-ci

@timmoon10 (Collaborator) left a review:

Overall LGTM, with some nits and future optimizations.

Comment on lines +20 to +21
auto kernel = launch_params.params.add ? &rmsnorm_bwd_tuned_kernel<Kernel_traits, true>
: &rmsnorm_bwd_tuned_kernel<Kernel_traits, false>;

We are doubling the number of kernel compilations for RMSNorm. This is fine for now, but it makes an NVRTC impl more compelling.

@@ -111,6 +118,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke
}
}

maybe_t<Ovec[LDGS], FusedAdd> add;

Future optimization: Load add directly into dx to avoid extra register usage. However, this would require that Ivec and Ovec are identical types.

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Jan Bielak <[email protected]>
@janekb04 force-pushed the normalization-backward-add-fusion branch from 22d91b0 to 65aeee7 on August 7, 2025 18:36.
@timmoon10 commented:

/te-ci core pytorch

@timmoon10 commented:

/te-ci pytorch core

@timmoon10 commented:

/te-ci core

…error in `test_normalization.cu`

Signed-off-by: Jan Bielak <[email protected]>
@timmoon10 commented:

/te-ci core

@timmoon10 commented:

/te-ci core
