Conversation

timmoon10
Collaborator

@timmoon10 timmoon10 commented Jul 17, 2025

Description

#1865 introduced a failure in the distributed debug tests. The root cause is that we only all-gather the row-wise data for DebugQuantizedTensor:

rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]

However, if the linear layer is caching its original input tensor and requantizing it in the backward pass, the correct behavior is to quantize only the column-wise data. This PR is a hacky workaround that applies the debug quantizer only to the gathered input tensor.
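The workaround can be sketched as a toy model (all classes and helpers below are hypothetical stand-ins, not the actual Transformer Engine code): gather the raw high-precision input first, then apply the debug quantizer to the gathered tensor, so both the row-wise data (forward GEMM) and the column-wise data (backward wgrad GEMM) come from one consistent quantization step.

```python
# Toy sketch of the fix. ToyQuantizer and gather_along_first_dim are
# simplified stand-ins for the Transformer Engine quantizer and the
# all-gather across the tensor-parallel process group.

class ToyQuantizer:
    def __init__(self):
        self.rowwise_usage = True
        self.columnwise_usage = True

    def set_usage(self, rowwise, columnwise):
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise

    def __call__(self, tensor):
        # Pretend quantization: record which usages were produced.
        return {
            "data": tensor,
            "rowwise": self.rowwise_usage,
            "columnwise": self.columnwise_usage,
        }

def gather_along_first_dim(local_chunks):
    # Stand-in for the all-gather: concatenate per-rank chunks.
    return [x for chunk in local_chunks for x in chunk]

# Fixed path: gather the raw input, then quantize the gathered tensor
# with both usages, so the backward pass has column-wise data available.
quantizer = ToyQuantizer()
quantizer.set_usage(rowwise=True, columnwise=True)
local_chunks = [[1, 2], [3, 4]]  # one chunk per rank
gathered = gather_along_first_dim(local_chunks)
quantized = quantizer(gathered)
print(quantized["columnwise"])  # → True
```

In the buggy path, the locally quantized tensor was gathered with only its row-wise data, so the requantization in the backward pass had nothing column-wise to work with.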

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Modify linear backward to avoid all-gathering debug tensor with only column-wise data

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 timmoon10 requested a review from pggPL July 17, 2025 23:43
@timmoon10 timmoon10 force-pushed the debug-linear-original-input branch from 5387dd7 to 6f66aa4 Compare July 17, 2025 23:52
@timmoon10
Collaborator Author

/te-ci pytorch L1

@cyanguwa cyanguwa mentioned this pull request Jul 18, 2025
Signed-off-by: Tim Moon <[email protected]>
@timmoon10 timmoon10 requested a review from ksivaman July 18, 2025 21:21
 else:
-    quantizer.set_usage(rowwise=False, columnwise=True)
+    quantizer.set_usage(rowwise=True, columnwise=True)
Contributor

Hi Tim, why do we need the rowwise data here?

Collaborator Author

@timmoon10 timmoon10 Jul 21, 2025

In principle we shouldn't need it, but I was running into issues where FP8 casts were failing without it. Actually, I don't think we need it once we skip the debug quantizer case.

Contributor

I think this may be due to the modified condition. The Float8Quantizer and Float8CurrentScalingQuantizer cases don't support quantizing only the column-wise data. But if backward_input_needs_gather is False, Float8Quantizer and Float8CurrentScalingQuantizer tensors also take the else path, which causes the error. It also hurts the performance of blockwise FP8 and MXFP8. Do you need me to create a fix PR for this, or will you fix it together with the DebugTensor?
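The control-flow issue being described can be illustrated with a toy sketch (the class names mirror the discussion, but the logic and the `supports_columnwise_only` attribute are simplified stand-ins, not the real Transformer Engine code):

```python
# Toy illustration: some quantizers cannot produce only column-wise
# data, so requesting columnwise-only usage on them fails. Blockwise
# quantizers can, and skipping the row-wise data saves memory.

class Float8Quantizer:
    supports_columnwise_only = False  # plain FP8: no transpose-only cast

class Float8BlockwiseQuantizer:
    supports_columnwise_only = True   # blockwise FP8 / MXFP8

def choose_usage(quantizer, backward_input_needs_gather):
    """Pick (rowwise, columnwise) usage flags for caching the input."""
    if backward_input_needs_gather:
        # The gathered tensor is requantized later; keep both usages.
        return (True, True)
    if quantizer.supports_columnwise_only:
        # Cache only the column-wise data for the wgrad GEMM.
        return (False, True)
    # Quantizers without columnwise-only support must keep both,
    # otherwise the FP8 cast fails.
    return (True, True)

print(choose_usage(Float8Quantizer(), False))           # → (True, True)
print(choose_usage(Float8BlockwiseQuantizer(), False))  # → (False, True)
```

The complaint in the comment is that the modified condition sent quantizers without columnwise-only support down the columnwise-only branch, and forced blockwise FP8 / MXFP8 to keep row-wise data it doesn't need.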

@pggPL
Collaborator

pggPL commented Jul 21, 2025

I think this line is the source of the error:

out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor

The rowwise/columnwise tensors in DebugQuantizedTensor are the tensors used in the GEMMs, so they can both be the same Float8Tensor object, for example. And update_usage() currently does nothing for debug tensors.

@timmoon10
Collaborator Author

@pggPL I tried changing

out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor

to

out_obj.columnwise_gemm_tensor = out_obj.rowwise_gemm_tensor 

However, the error reappeared when I applied the debug quantizer to the local input tensor.

For now, I think we should merge this as a quick bugfix and we can fix the edge cases for the debug tensor later.

ksivaman
ksivaman previously approved these changes Jul 21, 2025
Member

@ksivaman ksivaman left a comment


LGTM

@timmoon10
Collaborator Author

/te-ci pytorch L1

FP8 does not support transpose-only cast.

Signed-off-by: Tim Moon <[email protected]>
@timmoon10 timmoon10 merged commit 315b47d into NVIDIA:main Jul 22, 2025
12 checks passed
KshitijLakhani pushed a commit that referenced this pull request Jul 22, 2025
…ug quantizer (#1963)

* Debug linear layer when saving original input and using debug quantizer

Signed-off-by: Tim Moon <[email protected]>

* Workaround bugs with quantizing with only column-wise usage

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unused imports

Signed-off-by: Tim Moon <[email protected]>

* Avoid unnecessary row-wise data

Signed-off-by: Tim Moon <[email protected]>

* Workaround bugs with quantizing with only column-wise usage

FP8 does not support transpose-only cast.

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 participants