Conversation

@guyueh1 (Contributor) commented Mar 19, 2025

Fix mxfp8 columnwise data missing when switching from validation to training

Description

When we use mxfp8 to first run a few validation steps and then run a training step, the columnwise weight data in the row-wise linear layer is missing. This PR fixes it.
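For context, here is a minimal, hypothetical repro sketch of the scenario (illustrative only, not code from this PR; it assumes the MXFP8 recipe class is recipe.MXFP8BlockScaling and uses Transformer Engine's weight caching via is_first_microbatch):

```python
# Hypothetical repro sketch (not code from this PR): a validation (no-grad)
# step followed by a training step, with MXFP8 and weight caching enabled via
# is_first_microbatch. Before this fix, the training step could fail because
# the cached MXFP8 weight lacked column-wise data. Requires hardware with
# MXFP8 support; recipe.MXFP8BlockScaling is an assumed recipe class name.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.MXFP8BlockScaling()
layer = te.Linear(1024, 1024).cuda()
inp = torch.randn(128, 1024, device="cuda")

# "Validation": forward only -> only row-wise weight data is needed and cached.
with torch.no_grad(), te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    layer(inp, is_first_microbatch=True)

# "Training": forward + backward -> column-wise weight data is now required,
# but the cached workspace tensor from the validation step does not have it.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp, is_first_microbatch=True)
out.sum().backward()
```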


Type of change

  • Bug fix (non-breaking change which fixes an issue)


Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ksivaman (Member) left a comment


LGTM

@ksivaman (Member)

/te-ci pytorch

Guyue Huang and others added 3 commits March 19, 2025 17:05
@guyueh1 (Contributor, Author) commented Mar 20, 2025

The previous change didn't cover all the buggy cases, so I had to change it to the current approach.
Basically, fp8_workspaces is now updated whenever it finds that, for the same name, the cached tensor's usage differs from what the quantizer specifies.
I think this needs more careful review, cc @ksivaman @ptrendx

Guyue Huang added 2 commits March 19, 2025 17:16
Signed-off-by: Guyue Huang <[email protected]>
…1/TransformerEngine into fix_mxfp8_columnwise_data_missing

Signed-off-by: Guyue Huang <[email protected]>
@pggPL (Collaborator) commented Mar 20, 2025

The problem occurs when we change the quantizer usage/recipe during training. For example, in your case validation does not need gradients, but training does. We plan to handle switching the recipe, but it is not done yet.

But I'm quite confused about how you get this error. In the first training step you should run with is_first_microbatch set to True or None, and that should update the cache with the correct tensor.

@guyueh1 (Contributor, Author) commented Mar 20, 2025

> The problem occurs when we change the quantizer usage/recipe during training. For example, in your case validation does not need gradients, but training does. We plan to handle switching the recipe, but it is not done yet.
>
> But I'm quite confused about how you get this error. In the first training step you should run with is_first_microbatch set to True or None, and that should update the cache with the correct tensor.

To give an example that illustrates the bug:

We first run 1 microbatch of validation, then we run 1 microbatch of training:

  1. 1st validation iteration: is_first_microbatch = True, update_workspace = True, cache_name = "weight", quantizer.rowwise_usage = True, quantizer.columnwise_usage = False
  2. 1st training iteration: is_first_microbatch = True, update_workspace = True, cache_name = "weight", quantizer.rowwise_usage = True, quantizer.columnwise_usage = True

After step 1, self._fp8_workspace['weight'] holds a quantized mxfp8 tensor with no columnwise data. After step 2, we should in theory update the workspace object in self._fp8_workspace['weight'] in place at line 1000, so that out gets both rowwise and columnwise data. In practice this didn't happen: after tex.quantize(...), out still doesn't have columnwise data, which raised the error (a schematic sketch follows below).
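To make the failure mode concrete, here is a schematic, self-contained model of the caching logic described above (a sketch only; Quantizer and QuantizedTensor are stand-ins, not the actual Transformer Engine classes, and quantize() only mimics the old in-place-update behaviour):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Quantizer:                          # stand-in for the MXFP8 quantizer
    rowwise_usage: bool = True
    columnwise_usage: bool = True

@dataclass
class QuantizedTensor:                    # stand-in for an MXFP8 tensor
    rowwise_data: Optional[str] = None
    columnwise_data: Optional[str] = None

def quantize(weight, quantizer, out=None):
    """Model of the old behaviour: when `out` is given, only refresh the
    fields it already has and ignore what the quantizer now asks for."""
    if out is not None:
        if out.rowwise_data is not None:
            out.rowwise_data = f"q({weight})"
        if out.columnwise_data is not None:
            out.columnwise_data = f"q_t({weight})"
        return out
    return QuantizedTensor(
        rowwise_data=f"q({weight})" if quantizer.rowwise_usage else None,
        columnwise_data=f"q_t({weight})" if quantizer.columnwise_usage else None,
    )

fp8_workspace = {}

# Step 1 -- validation iteration: row-wise usage only.
q_val = Quantizer(rowwise_usage=True, columnwise_usage=False)
fp8_workspace["weight"] = quantize("W", q_val)

# Step 2 -- training iteration: column-wise data is now required, but the
# in-place update into the cached tensor never creates it.
q_train = Quantizer(rowwise_usage=True, columnwise_usage=True)
out = quantize("W", q_train, out=fp8_workspace["weight"])
assert out.columnwise_data is None        # the missing-data error described above
```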

@ptrendx (Member) commented Mar 20, 2025

Right, the problem at that line 1000 is that it passes the out parameter to the tex.quantize function, which then just assumes it is the right output and does not check what the quantizer said. There are 2 possibilities here: tex.quantize should realize that the provided output is not actually correct and

  • either add the missing fields
  • or error out

The second option would require the caller to provide the right output (probably by creating a new tensor), which would incur some slight overhead. The first option would require us to actually go over the functions (since tex.quantize is not the only one that outputs quantized tensors) and modify them all to properly handle this. I'm inclined to go with the first option for perf reasons (and will implement it; see the sketch after this comment), but wanted to hear your opinions as well @guyueh1 @ksivaman @timmoon10 @pggPL.
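A hedged sketch of what option 1 could look like, reusing the stand-in Quantizer/QuantizedTensor model from the sketch above (not the actual tex.quantize implementation): the function consults the quantizer and allocates only the pieces the provided output is missing.

```python
def quantize_fill_missing(weight, quantizer, out=None):
    """Option 1 (sketch): consult the quantizer and allocate only the pieces
    that the provided output is missing, instead of trusting `out` blindly."""
    if out is None:
        out = QuantizedTensor()
    if quantizer.rowwise_usage and out.rowwise_data is None:
        out.rowwise_data = f"q({weight})"
    if quantizer.columnwise_usage and out.columnwise_data is None:
        out.columnwise_data = f"q_t({weight})"
    return out

# With this behaviour, step 2 from the previous sketch would repair the cache:
out = quantize_fill_missing("W", q_train, out=fp8_workspace["weight"])
assert out.columnwise_data is not None
```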

@guyueh1 (Contributor, Author) commented Mar 20, 2025

For option 1, you would still update the tex.quantize(...) function to add/remove the rowwise or columnwise data fields of the output according to the quantizer, is that right?

I agree it's the right thing to do.

My PR is an alternative: whenever it finds that the saved tensor in _fp8_workspace doesn't match the quantizer (we can expand the "not matching" condition of course; right now it only checks the row-/column-wise usage), it deletes the saved tensor, and the subsequent code will create a new one and cache it (a sketch follows below). I think the overhead is the same as if you modified tex.quantize(...).
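A sketch of this eviction approach, again in terms of the stand-in model above (illustrative only; the real check presumably lives in transformer_engine/pytorch/module/base.py, which this PR touches):

```python
def usage_matches(cached, quantizer):
    """Does the cached tensor provide exactly the usages the quantizer wants?"""
    return ((cached.rowwise_data is not None) == quantizer.rowwise_usage and
            (cached.columnwise_data is not None) == quantizer.columnwise_usage)

# Recreate the stale cache entry from the validation-only step, then apply
# the eviction check before the training-step quantization.
fp8_workspace["weight"] = quantize("W", q_val)

name = "weight"
if name in fp8_workspace and not usage_matches(fp8_workspace[name], q_train):
    del fp8_workspace[name]               # evict the stale entry
if name not in fp8_workspace:
    # Re-quantize from scratch with the current quantizer and cache the result.
    fp8_workspace[name] = quantize("W", q_train)
assert fp8_workspace[name].columnwise_data is not None
```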

@ptrendx (Member) commented Mar 20, 2025

Right, your PR would have similar overhead to option 2, since that way you need to allocate the full tensor again. I'm fine with merging it as is, and I will follow up with option 1, which just allocates the missing pieces.

@guyueh1 (Contributor, Author) commented Mar 21, 2025

Can we merge this? @ksivaman

@ksivaman (Member) commented

I agree that functions like tex.quantize that produce fp8 output should give precedence to the quantizer rather than to what exists in the provided output tensor, so that's better for the longer term. For now, I will merge if CI is clean.

@ksivaman (Member)

/te-ci pytorch

Signed-off-by: Tim Moon <[email protected]>
@timmoon10 (Collaborator)

/te-ci pytorch

@guyueh1 (Contributor, Author) commented Mar 25, 2025

The L0 test seems to be failing for unrelated reasons (paged attention); how should we fix it? @timmoon10

@timmoon10 timmoon10 merged commit abbdd76 into NVIDIA:main Mar 25, 2025
10 of 11 checks passed
KshitijLakhani pushed a commit that referenced this pull request Mar 26, 2025
* Fix mxfp8 columnwise data missing when switching from validation to training

Signed-off-by: Guyue Huang <[email protected]>

* Fix when you interleave training and inference

Signed-off-by: Guyue Huang <[email protected]>

* refact

Signed-off-by: Guyue Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rm useless code

Signed-off-by: Guyue Huang <[email protected]>

* Update transformer_engine/pytorch/module/base.py

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: guyueh1 <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Guyue Huang <[email protected]>
Signed-off-by: guyueh1 <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Guyue Huang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
lhb8125 pushed a commit to lhb8125/TransformerEngine that referenced this pull request Apr 8, 2025