[PyTorch] Debug linear layer when saving original input and using debug quantizer #1963
Conversation
Signed-off-by: Tim Moon <[email protected]>
Force-pushed from 5387dd7 to 6f66aa4
[pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci
/te-ci pytorch L1
Signed-off-by: Tim Moon <[email protected]>
```diff
 else:
-    quantizer.set_usage(rowwise=False, columnwise=True)
+    quantizer.set_usage(rowwise=True, columnwise=True)
```
Hi Tim, why do we need the rowwise data here?
In principle we shouldn't need it, but I was running into issues where FP8 casts were failing without it. Actually, I don't think we need it once we skip the debug quantizer case.
I think this may be because the condition was modified. The `(Float8Quantizer, Float8CurrentScalingQuantizer)` cases do not support quantizing only the column-wise data. But if `backward_input_needs_gather` is False, `(Float8Quantizer, Float8CurrentScalingQuantizer)` tensors would also go into the `else` path, which causes the error. This hurts the performance of blockwise FP8 and MXFP8. Do you need me to create a fix PR for this, or will you fix it together with the DebugTensor?
I think this line is the source of the error.
The rowwise/columnwise tensors in `DebugQuantizedTensor` are the tensors used in the GEMMs, so they can both be the same `Float8Tensor` object, for example. And `update_usage()` currently does nothing for debug tensors.
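As a sketch of the aliasing described above (hypothetical stand-in classes, not the real `DebugQuantizedTensor`): when the two GEMM tensors are one and the same object and `update_usage()` is a no-op, requesting column-wise usage never materializes column-wise data.

```python
class FakeFloat8Tensor:
    """Stand-in for a quantized tensor holding row-wise data only."""
    def __init__(self, data):
        self.data = data
        self.has_columnwise = False


class FakeDebugQuantizedTensor:
    """Stand-in for a debug tensor: the row-wise and column-wise GEMM
    tensors may alias the same underlying object."""
    def __init__(self, tensor):
        self.rowwise_gemm_tensor = tensor
        self.columnwise_gemm_tensor = tensor  # same object, not a transpose

    def update_usage(self, rowwise_usage=None, columnwise_usage=None):
        # Currently a no-op for debug tensors, so requesting column-wise
        # usage does not create column-wise data.
        pass


t = FakeDebugQuantizedTensor(FakeFloat8Tensor([1.0, 2.0]))
t.update_usage(columnwise_usage=True)

# The column-wise GEMM tensor still aliases the row-wise one and
# still carries no column-wise data:
assert t.columnwise_gemm_tensor is t.rowwise_gemm_tensor
assert not t.columnwise_gemm_tensor.has_columnwise
```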
Signed-off-by: Tim Moon <[email protected]>
@pggPL I tried changing that line to `out_obj.columnwise_gemm_tensor = out_obj.rowwise_gemm_tensor`. However, the error reappeared when I applied the debug quantizer to the local input tensor. For now, I think we should merge this as a quick bugfix, and we can fix the edge cases for the debug tensor later.
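A rough sketch of the workaround's shape (hypothetical helper names and toy stand-ins, not the actual PR diff): the debug quantizer is applied only to the gathered input tensor, after the all-gather, rather than to the locally cached copy.

```python
def forward_with_gather(inp, quantizer, debug_quantizer, gather):
    """Hypothetical sketch: the debug quantizer wraps only the gathered
    (global) input, not the locally cached tensor."""
    local_q = quantizer(inp)          # local copy, cached for backward
    gathered = gather(local_q)        # all-gather across ranks
    return debug_quantizer(gathered)  # debug wrapping happens post-gather


# Toy stand-ins to exercise the flow:
quantizer = lambda x: {"rowwise": x}
gather = lambda t: {"rowwise": t["rowwise"] * 2}  # fake 2-rank all-gather
debug_quantizer = lambda t: {"debug": True, **t}

out = forward_with_gather([1.0], quantizer, debug_quantizer, gather)
assert out["debug"] and out["rowwise"] == [1.0, 1.0]
```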
LGTM
/te-ci pytorch L1
FP8 does not support transpose-only cast. Signed-off-by: Tim Moon <[email protected]>
[PyTorch] Debug linear layer when saving original input and using debug quantizer (#1963)

* Debug linear layer when saving original input and using debug quantizer
  Signed-off-by: Tim Moon <[email protected]>
* Workaround bugs with quantizing with only column-wise usage
  Signed-off-by: Tim Moon <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Remove unused imports
  Signed-off-by: Tim Moon <[email protected]>
* Avoid unnecessary row-wise data
  Signed-off-by: Tim Moon <[email protected]>
* Workaround bugs with quantizing with only column-wise usage
  FP8 does not support transpose-only cast.
  Signed-off-by: Tim Moon <[email protected]>

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
#1865 introduced a failure in the distributed debug tests.
The root cause is that we only all-gather the row-wise data for `DebugQuantizedTensor` (TransformerEngine/transformer_engine/pytorch/distributed.py, line 1394 at f8933bb).
However, if the linear layer is caching its original input tensor and requantizing in the backward pass, the correct behavior is to quantize only the column-wise data. This PR is a hacky workaround that applies the debug quantizer only to the gathered input tensor.
Type of change
Changes
Checklist: