blockwise fp8 weight memory optimization: on-demand columnwise fp8 weight creation #2168
Conversation
Force-pushed from 7eef7c5 to b6a51f7.
In fact you don't need a standalone kernel to do the transpose. Just call …
If using …: but yes, not sure why it does not use an existing kernel like tex.fp8_transpose, as has been done in …
If you want to release the column data after use, you can set …
* Custom call tests passing
* Fix test_layer.py
* Lint
* Fix comments
* Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling
* Fix shardy issue with amax being shape 1,1,1 instead of shape (1,)
* Add higher-precision VJP tests to test_distributed_layernorm_mlp
* Cast non-quantized kernels to input dtype in VJPs
* Rename HighPrecisionTensor to NoScaleTensor
* Use NoScaleTensor in pure JAX impls where it was missing
* Fix tests

Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: skydoorkai <[email protected]>
Force-pushed from a492545 to 3e8f377.
Yes, no new kernel is needed. Updated to use tex.fp8_transpose.
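For context, here is a minimal pure-PyTorch sketch of what a blockwise-quantization-aware FP8 transpose amounts to when weights are quantized in 2D blocks. This is not tex.fp8_transpose or the PR's Triton kernel; the function and tensor names and the scale layout are illustrative assumptions.

```python
import torch

def columnwise_from_rowwise(rowwise_data: torch.Tensor,
                            rowwise_scale_inv: torch.Tensor):
    """Derive the columnwise FP8 view from the rowwise one without requantizing.

    Assumes 2D (e.g. 128x128) block quantization, so transposing the data and
    the per-block scales is exact -- there is no accuracy loss.
    """
    # Treat the FP8 payload as raw bytes so the transpose is pure data movement.
    data_bytes = rowwise_data.view(torch.uint8)
    columnwise_data = data_bytes.t().contiguous().view(rowwise_data.dtype)
    # Swap the block-scale axes so they index (col_block, row_block).
    columnwise_scale_inv = rowwise_scale_inv.t().contiguous()
    return columnwise_data, columnwise_scale_inv
```

As the comments above note, an existing transpose kernel already performs this data movement on the GPU, which is why no standalone kernel is required.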
Force-pushed from fbdb58b to 59e7dcd.
Force-pushed from f35e5dd to 2d460d2.
Description
When blockwise fp8 is used, weights are quantized into rowwise_data and columnwise_data: rowwise_data is used in the forward pass and columnwise_data in the backward pass.
The current implementation creates columnwise_data in advance for training, so blockwise fp8 weight GPU memory consumption is similar to bf16 (plus two scale tensors).
Since weights are 2D-quantized for blockwise fp8, columnwise_data can be created from rowwise_data with a blockwise-quantization-aware transpose without loss of accuracy. This PR therefore adds a GPU memory optimization that creates columnwise_data from rowwise_data only when it is needed (in the backward pass), cutting fp8 weight GPU memory roughly in half.
Similar to activation checkpointing, this is a tradeoff between GPU memory and computation.
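To make the savings concrete, here is a rough back-of-envelope calculation for a single square weight. The 8192 hidden size and the 128x128 block granularity are illustrative assumptions, not values from the PR.

```python
# Illustrative memory accounting for one [8192, 8192] weight.
hidden = 8192
elems = hidden * hidden

bf16_weight = elems * 2              # 128 MiB
fp8_row_plus_col = elems * 1 * 2     # rowwise + columnwise ~ 128 MiB (similar to bf16)
fp8_rowwise_only = elems * 1         # on-demand transpose keeps only this ~ 64 MiB

# 2D block scales (assuming 128x128 blocks with fp32 scales) are comparatively tiny.
blocks_per_dim = hidden // 128
scale_bytes = blocks_per_dim ** 2 * 4  # 4096 blocks * 4 B = 16 KiB per scale tensor

print(bf16_weight / 2**20, fp8_row_plus_col / 2**20,
      fp8_rowwise_only / 2**20, scale_bytes / 2**10)
```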
Type of change
Changes
Please list the changes introduced in this PR:
Add FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE in FP8GlobalStateManager. It defaults to False; users can enable it by setting the environment variable NVTE_ON_DEMAND_FP8_WEIGHT_TRANSPOSE=1 when the blockwise fp8 recipe is used (see the usage sketch after this list).
Add a get_columnwise_fp8_tensor function to create columnwise fp8 data from rowwise fp8 data via a Triton kernel.
Add support for GroupedLinear, Linear, LayerNormLinear.
Add test_fp8_weight_on_demand_transpose test
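A hedged usage sketch for the new flag. Module and recipe names (te.Linear, fp8_autocast, Float8BlockScaling) follow current TransformerEngine conventions and should be verified against the installed version; shapes are arbitrary.

```python
import os
# Set the flag before TE reads it.
os.environ["NVTE_ON_DEMAND_FP8_WEIGHT_TRANSPOSE"] = "1"

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling  # blockwise fp8 recipe

layer = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).cuda()
inp = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=Float8BlockScaling()):
    out = layer(inp)      # forward uses the rowwise fp8 weight only
out.sum().backward()      # columnwise fp8 weight is created on demand here
```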
Checklist: