[ATen][Native][CUDA][SCALED_MM] limit f8f8bf16 rowwise scaled matmul to sm_90 #145728

Closed - wants to merge 6 commits

Conversation

@Aidyn-A Aidyn-A (Collaborator) commented Jan 27, 2025

The CUTLASS-based kernel for the f8f8bf16 rowwise scaled matmul is specific to Hopper devices and is not reusable on newer devices without modifications. This PR adds a guard restricting this matmul to sm_90. Once a kernel for newer architectures lands, the guard may be removed.

cc @ptrblck @msaroufim @eqy @yanbing-j @vkuzo @albanD @kadeng @penguinwu @manuelcandales @SherlockNoMad @angelayi


pytorch-bot bot commented Jan 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145728

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d8acae8 with merge base 71caac2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Aidyn-A Aidyn-A requested review from malfet and removed request for eqy and syed-ahmed January 27, 2025 11:41
@Aidyn-A Aidyn-A added the "topic: not user facing", "matrix multiplication", "module: float8", and "module: core aten" labels Jan 27, 2025
@Aidyn-A Aidyn-A requested a review from eqy January 27, 2025 11:42
@Aidyn-A Aidyn-A changed the title from "[ATen][FP8] limit f8f8bf16 rowwise scaled matmul to sm_90" to "[ATen][Native][CUDA][SCALED_MM] limit f8f8bf16 rowwise scaled matmul to sm_90" Jan 27, 2025
@@ -790,6 +790,9 @@ void check_inputs(
     const at::Tensor& scale_b,
     const std::optional<at::Tensor>& bias,
     const at::Tensor& out) {
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  TORCH_CHECK(dprops->major == 9, "f8f8bf16_rowwise is sm_90 specific.");
Collaborator

Shouldn't there be another change to not call into this at all?
We should fall back to another implementation for sm_100+, right?

Contributor

I also think that, at the very least, we should have a tracker for functionality/features we skip on new hardware, so that they are tracked and support can be added in full later.

Collaborator

What about SM_90 minor version? Not relevant at all here?

Collaborator Author

What about SM_90 minor version? Not relevant at all here?

It is not relevant. But I missed SM_89; I will include it as well.

Collaborator Author

We should fall back to another implementation for sm10+ right?

Correct, the current approach/kernel is not compatible with SM_100+. Since there are no kernels for the Blackwell machines yet, I propose to simply throw an exception; otherwise it fails with a CUTLASS error, which is not elegant behavior.
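
For illustration, a minimal sketch of the kind of early device-capability guard being discussed, written here as a standalone helper (the name check_rowwise_arch_supported is hypothetical; in the actual PR the check is added inside check_inputs, as the hunks in this thread show):

#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Exception.h>

// Fail fast with a readable error instead of letting the Hopper-only
// CUTLASS kernel abort on an unsupported architecture.
void check_rowwise_arch_supported() {
  // Properties of the current CUDA device (compute capability major/minor).
  auto dprops = at::cuda::getCurrentDeviceProperties();
  // The wgmma-based kernel cannot run on Blackwell (sm_100+), so reject
  // anything newer than compute capability 9.x up front.
  TORCH_CHECK(
      dprops->major <= 9,
      "f8f8bf16_rowwise is sm_90 specific; got compute capability ",
      dprops->major, ".", dprops->minor);
}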

Collaborator Author

@Aidyn-A do they still fail with CUTLASS 3.7 btw? Or do we need to wait for 3.8?

That is a good question. I will need to check it. Thanks for reminding me about the CUTLASS update!

Collaborator

@Aidyn-A We also need a cuDNN update (only for the versions we started the ManyLinux upgrade on, so 2.6/2.8).

Collaborator Author

Thanks @Skylion007 for the PR! I just ran the test with the latest CUTLASS 3.8 on SM_100 and got the following errors:

FAILED [0.3029s] test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_float8_rowwise_scaling_sanity_use_fast_accum_False_cuda - AssertionError: Tensor-likes are not close!
FAILED [2.5473s] test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_float8_rowwise_scaling_sanity_use_fast_accum_True_cuda - AssertionError: Tensor-likes are not close!
FAILED [0.0025s] test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_mm_vs_emulated_row_wise_bfloat16_cuda - AssertionError: Tensor-likes are not close!

The reason it failed with numerical mismatches is that the kernel was simply aborted with the following message:

ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.

Collaborator

Seems like __CUDA_ARCH_FEAT_SM90_ALL is not defined; not sure if this is a CMake or CUTLASS bug.
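
For context, a rough sketch of the compile-time gating pattern that produces the runtime abort quoted above (this is not the actual CUTLASS source; the SM90A_WGMMA_ENABLED flag and mma_or_abort function are invented for this sketch, while __CUDA_ARCH_FEAT_SM90_ALL is the macro nvcc defines when a translation unit is built for the sm_90a target):

#include <cstdio>

// Hopper wgmma paths are only compiled in when targeting sm_90a, which is
// what defines __CUDA_ARCH_FEAT_SM90_ALL. Building for plain sm_90 or for
// sm_100 leaves the macro undefined, so device code falls into the
// runtime-abort branch instead of issuing wgmma instructions.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900) && \
    defined(__CUDA_ARCH_FEAT_SM90_ALL)
#define SM90A_WGMMA_ENABLED 1  // hypothetical flag, for this sketch only
#endif

__device__ void mma_or_abort() {
#if defined(SM90A_WGMMA_ENABLED)
  // ... issue wgmma.mma_async / wgmma.commit_group / wgmma.wait_group here ...
#else
  // Mirrors the observed behavior: print a message and trap the kernel.
  printf(
      "Arch conditional MMA instruction used without targeting sm90a "
      "compute capability. Aborting.\n");
  asm volatile("trap;");
#endif
}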

@@ -790,6 +790,9 @@ void check_inputs(
     const at::Tensor& scale_b,
     const std::optional<at::Tensor>& bias,
     const at::Tensor& out) {
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  TORCH_CHECK(dprops->major <= 9, "f8f8bf16_rowwise is sm_90 specific.");
Collaborator

The real bug here is that CUTLASS just prints a message if the arch isn't supported and doesn't propagate an error up the stack, right? Seems like an API design failure over there, but I want to know if there is an easy fix that can be made for 3.8.

@Aidyn-A Aidyn-A (Collaborator Author) commented Jan 27, 2025

I do not think there is an easy solution for that, as a bunch of PTX-level instructions needed for the RowwiseScaledMM are not supported on Blackwell:

ptxas /tmp/tmpxft_000138f8_00000000-6_RowwiseScaledMM.ptx, line 208010; error   : Instruction 'wgmma.mma_async with FP8 types' not supported on .target 'sm_100'
ptxas /tmp/tmpxft_000138f8_00000000-6_RowwiseScaledMM.ptx, line 208010; error   : Instruction 'wgmma.mma_async with FP8 types' cannot be compiled for architecture 'sm_100'
ptxas /tmp/tmpxft_000138f8_00000000-6_RowwiseScaledMM.ptx, line 208015; error   : Instruction 'wgmma.commit_group' not supported on .target 'sm_100'
ptxas /tmp/tmpxft_000138f8_00000000-6_RowwiseScaledMM.ptx, line 208015; error   : Instruction 'wgmma.commit_group' cannot be compiled for architecture 'sm_100'
ptxas /tmp/tmpxft_000138f8_00000000-6_RowwiseScaledMM.ptx, line 208025; error   : Instruction 'wgmma.wait_group' not supported on .target 'sm_100'
ptxas /tmp/tmpxft_000138f8_00000000-6_RowwiseScaledMM.ptx, line 208025; error   : Instruction 'wgmma.wait_group' cannot be compiled for architecture 'sm_100'
ptxas /tmp/tmpxft_000138f8_00000000-6_RowwiseScaledMM.ptx, line 208389; error   : Instruction 'wgmma.fence' not supported on .target 'sm_100'
ptxas /tmp/tmpxft_000138f8_00000000-6_RowwiseScaledMM.ptx, line 208389; error   : Instruction 'wgmma.fence' cannot be compiled for architecture 'sm_100'
ptxas /tmp/tmpxft_000138f8_00000000-6_RowwiseScaledMM.ptx, line 209850; error   : Instruction 'setmaxnreg.dec' not supported on .target 'sm_100'
ptxas /tmp/tmpxft_000138f8_00000000-6_RowwiseScaledMM.ptx, line 210122; error   : Instruction 'setmaxnreg.inc' not supported on .target 'sm_100'

@Skylion007 Skylion007 (Collaborator) commented Jan 27, 2025

@Aidyn-A Okay, I found the issue: wgmma support was dropped completely in Blackwell and replaced with a new instruction with a different call signature/convention, so yeah this isn't an easy fix and probably requires a deep refactor of CUTLASS ;-;

Collaborator

@Aidyn-A Don't add a new exception here, just reuse the current exception logic on line 712 and expand the if conditional


@Aidyn-A Aidyn-A requested a review from Skylion007 January 28, 2025 06:43
@drisspg drisspg added the "module: cuda" label Jan 29, 2025
@drisspg drisspg (Contributor) commented Jan 29, 2025

cc @eqy do we have any tracking for the sm100 updates needed? I think the proper fix is to have another instantiation with

using ArchTag = cutlass::arch::Sm90;

swapped for sm100, which in theory replaces wgmma with tcgen05, once we land CUTLASS 3.8 and CUDA 12.8.

I'm mostly curious whether we have an issue or somewhere we can backlog this stuff.
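
As a rough illustration of that dispatch shape (not the actual RowwiseScaledMM code; the placeholder Sm90/Sm100 tag structs below stand in for cutlass::arch::Sm90 and a future Blackwell tag, and the function names are invented for this sketch):

#include <stdexcept>

// Placeholder arch tags; the real ones live in cutlass/arch/arch.h.
struct Sm90 {};   // stands in for cutlass::arch::Sm90 (wgmma-based kernels)
struct Sm100 {};  // stands in for a future Blackwell tag (tcgen05-based)

// In the real code, ArchTag would select CUTLASS collective/kernel builders.
template <typename ArchTag>
void f8f8bf16_rowwise_impl(/* tensors, scales, bias, out */) {}

void f8f8bf16_rowwise_dispatch(int cc_major) {
  if (cc_major == 9) {
    f8f8bf16_rowwise_impl<Sm90>();
  } else if (cc_major >= 10) {
    // Would become f8f8bf16_rowwise_impl<Sm100>() once tcgen05 kernels land
    // with CUTLASS 3.8+ / CUDA 12.8; until then this PR's guard errors out.
    throw std::runtime_error("f8f8bf16_rowwise: no kernel for sm_100+ yet");
  } else {
    throw std::runtime_error("f8f8bf16_rowwise: unsupported compute capability");
  }
}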

@janeyx99 janeyx99 added the "triaged" label Jan 29, 2025
@Skylion007 Skylion007 (Collaborator) commented:

cc @eqy do we have any tracking for the sm100 updates needed? I think the proper fix is to have another instantiation with

using ArchTag = cutlass::arch::Sm90;

swapped for sm100, which in theory replaces wgmma with tcgen05, once we land CUTLASS 3.8 and CUDA 12.8.

I'm mostly curious whether we have an issue or somewhere we can backlog this stuff.

sm100 doesn't have any PTXAS instructions for wgmma emulation, so we need a CUTLASS update with new kernels beyond 3.8 (updating to 3.8 will not fix this sadly).

@eqy eqy (Collaborator) commented Jan 29, 2025

OK with merging this temporarily, before the CUTLASS upgrade, to abate noisy failures on Blackwell.

@Aidyn-A Aidyn-A (Collaborator Author) commented Jan 30, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the "ciflow/trunk" label Jan 30, 2025
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorch-bot pytorch-bot bot temporarily deployed to upload-benchmark-results January 30, 2025 06:44 Inactive
mori360 pushed a commit to mori360/pytorch that referenced this pull request Feb 6, 2025
…to sm_90 (pytorch#145728)

The CUTLASS-based kernel for f8f8bf16 rowwise scaled matmul is specific to Hopper devices only. It is not re-usable on newer devices without modifications. This PR adds a guard for this matmul to be sm_90 specific. Once the kernel is there, the guard may be removed.

Pull Request resolved: pytorch#145728
Approved by: https://github.com/Skylion007, https://github.com/eqy
Labels: ciflow/trunk, matrix multiplication, Merged, module: core aten, module: cuda, module: float8, open source, topic: not user facing, triaged
8 participants