Fix mxfp8 columnwise data missing #1593
Conversation
Fix mxfp8 columnwise data missing when switching from validation to training
Signed-off-by: Guyue Huang <[email protected]>
LGTM
/te-ci pytorch
Signed-off-by: Guyue Huang <[email protected]>
Signed-off-by: Guyue Huang <[email protected]>
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
Signed-off-by: Guyue Huang <[email protected]>
…1/TransformerEngine into fix_mxfp8_columnwise_data_missing Signed-off-by: Guyue Huang <[email protected]>
The problem arises when we change the quantizer usage/recipe during training. For example, in your case validation does not need gradients, but training does. We plan to handle switching the recipe, but it is not done yet. Still, I'm quite confused about how you get this error; in the first training step you should run with …
To give an example that illustrates the bug: we first run one microbatch of validation, then we run one microbatch of training.
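A minimal repro sketch of that scenario, assuming the TransformerEngine PyTorch API (te.Linear, te.fp8_autocast) and a recipe class named MXFP8BlockScaling; class and argument names may differ across TE versions:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Assumed recipe class name for MXFP8; adjust to your TE version.
mxfp8 = recipe.MXFP8BlockScaling()
layer = te.Linear(1024, 1024).cuda()
x = torch.randn(16, 1024, device="cuda", requires_grad=True)

# Validation microbatch: no grad, so the weight's cached FP8 workspace is
# built with rowwise (forward-only) data.
with torch.no_grad(), te.fp8_autocast(enabled=True, fp8_recipe=mxfp8):
    layer(x)

# Training microbatch: backward needs the columnwise weight data, which the
# workspace cached during validation never allocated.
with te.fp8_autocast(enabled=True, fp8_recipe=mxfp8):
    out = layer(x)
out.sum().backward()  # hit the missing-columnwise-data error before this fix
```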
Right, the problem in that line (line 1000) is that it passes the out parameter to the tex.quantize function, which then just assumes that it is the right output and does not check what the quantizer requested. There are two possibilities here: tex.quantize should realize that the provided output is not actually correct and either
1. allocate only the missing rowwise/columnwise data in the provided output, or
2. discard it and allocate the full tensor again.
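A hypothetical sketch of option 1, not TE's actual implementation: before writing into a caller-provided output, the quantize entry point reconciles the output's allocated buffers with the usages the quantizer currently requests. The attribute and helper names (rowwise_usage, _rowwise_data, update_usage, update_quantized) are assumptions:

```python
def quantize(tensor, quantizer, out=None):
    """Quantize `tensor`, reusing `out` if the caller provides one."""
    if out is None:
        return quantizer.quantize(tensor)
    # Check whether the provided output is missing data the quantizer needs.
    needs_rowwise = quantizer.rowwise_usage and out._rowwise_data is None
    needs_columnwise = quantizer.columnwise_usage and out._columnwise_data is None
    if needs_rowwise or needs_columnwise:
        # Option 1: allocate only the missing pieces instead of rebuilding
        # the whole quantized tensor (which is option 2).
        out.update_usage(
            rowwise_usage=quantizer.rowwise_usage,
            columnwise_usage=quantizer.columnwise_usage,
        )
    quantizer.update_quantized(tensor, out)  # in-place quantization
    return out
```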
For option 1, you would still update the tex.quantize(..) function to add or remove rowwise/columnwise data fields according to the quantizer, is that right? I agree it's the right thing to do. My PR is an alternative: whenever it finds that the saved tensor in _fp8_workspace doesn't match the quantizer (we can expand the 'not matching' condition, of course; right now it only checks row-/column-wise usage), it deletes the saved tensor, and the subsequent code creates a new one and saves it. I think the overhead is the same as if you modify tex.quantize.
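A minimal sketch of this PR's approach, with hypothetical method and attribute names; the real change lives in transformer_engine/pytorch/module/base.py:

```python
def get_weight_workspace(self, key, quantizer):
    """Return a quantized weight workspace matching `quantizer`'s usage."""
    cached = self._fp8_workspace.get(key)
    if cached is not None and (
        (quantizer.rowwise_usage and cached._rowwise_data is None)
        or (quantizer.columnwise_usage and cached._columnwise_data is None)
    ):
        # The saved tensor does not match the quantizer's usage: delete it
        # and let the code below create and cache a fresh one.
        del self._fp8_workspace[key]
        cached = None
    if cached is None:
        cached = quantizer.quantize(self.weight)
        self._fp8_workspace[key] = cached
    return cached
```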
Right, your PR would have similar overhead to option 2, since that way you need to allocate the full tensor again. I'm fine with merging it as is, and I will follow up with option 1, which just allocates the missing pieces.
Can we merge this? @ksivaman
I agree that functions like …
/te-ci pytorch
Co-authored-by: Tim Moon <[email protected]> Signed-off-by: guyueh1 <[email protected]>
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
Signed-off-by: Tim Moon <[email protected]>
/te-ci pytorch
The L0 test seems to be failing for unrelated reasons (paged attention); how should we fix it? @timmoon10
* Fix mxfp8 columnwise data missing when switching from validation to training
  Signed-off-by: Guyue Huang <[email protected]>
* Fix when you interleave training and inference
  Signed-off-by: Guyue Huang <[email protected]>
* refact
  Signed-off-by: Guyue Huang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* rm useless code
  Signed-off-by: Guyue Huang <[email protected]>
* Update transformer_engine/pytorch/module/base.py
  Co-authored-by: Tim Moon <[email protected]>
  Signed-off-by: guyueh1 <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Fix linter warnings
  Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Guyue Huang <[email protected]>
Signed-off-by: guyueh1 <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Guyue Huang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Description
When we use mxfp8 to first run a few validation steps and then run a training step, the columnwise weight data in the row-wise linear layer is missing. This PR fixes it.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- In transformer_engine/pytorch/module/base.py, delete the weight tensor cached in _fp8_workspace when its rowwise/columnwise usage does not match the current quantizer, so that the subsequent code creates and caches a fresh tensor with the correct usage.
Checklist: