
Conversation

skydoorkai
Contributor

Description

When blockwise FP8 is used, weights are quantized into rowwise_data and columnwise_data: rowwise_data is used in the forward pass and columnwise_data in the backward pass.

The current implementation creates columnwise_data in advance for training, so blockwise FP8 weight GPU memory consumption is similar to BF16 (plus two scale tensors).

Since weights are 2D-quantized for blockwise FP8, columnwise_data can be created from rowwise_data with a blockwise-quantization-aware transpose without any loss of accuracy. This PR adds a GPU memory optimization that creates columnwise_data from rowwise_data only when it is needed (in the backward pass), cutting FP8 weight GPU memory roughly in half.
Similar to activation checkpointing, this is a tradeoff between GPU memory and computation.
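
A toy illustration of why this is lossless (plain PyTorch, not the PR's kernel): with 2D block quantization, the columnwise representation is just the transpose of the rowwise payload together with the transpose of its block-scale matrix, so it can be rebuilt in the backward pass and freed right after use.

```python
import torch

# Toy demonstration, not the PR's kernel. The FP8 payload is stood in by raw
# uint8 bytes and the scales by one value per 128x128 block.
BLOCK = 128
rowwise_data = torch.randint(0, 256, (512, 1024), dtype=torch.uint8)  # FP8 bytes, row-major
rowwise_scale = torch.rand(512 // BLOCK, 1024 // BLOCK)               # one scale per block

# Built on demand at the start of the backward pass: pure transposes, so no
# requantization and therefore no accuracy loss.
columnwise_data = rowwise_data.t().contiguous()
columnwise_scale = rowwise_scale.t().contiguous()

# ... backward GEMM consumes (columnwise_data, columnwise_scale) ...
del columnwise_data, columnwise_scale  # released immediately, trading compute for memory
```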

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE to FP8GlobalStateManager. It defaults to False; users can enable it by setting the environment variable NVTE_ON_DEMAND_FP8_WEIGHT_TRANSPOSE=1 when the blockwise FP8 recipe is used (see the usage sketch after this list).

  • Add a get_columnwise_fp8_tensor function that creates columnwise FP8 data from rowwise FP8 data via a Triton kernel.

  • Add support for GroupedLinear, Linear, LayerNormLinear.

  • Add a test_fp8_weight_on_demand_transpose test.
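
A usage sketch under the assumption that the blockwise recipe is exposed as Float8BlockScaling and that the environment variable is read before the model runs; names other than the flag above are illustrative:

```python
import os
# Opt in to on-demand weight transpose; the feature is off by default.
os.environ["NVTE_ON_DEMAND_FP8_WEIGHT_TRANSPOSE"] = "1"

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling  # blockwise FP8 recipe (assumed name)

layer = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).cuda()
inp = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=Float8BlockScaling()):
    out = layer(inp)
# Columnwise weight data is created on demand during backward and then released.
out.sum().backward()
```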

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@skydoorkai skydoorkai force-pushed the fp8_blockwise_weight_opt branch from 7eef7c5 to b6a51f7 on September 10, 2025 02:59
@yaox12
Member

yaox12 commented Sep 10, 2025

In fact you don't need a standalone kernel to do the transpose. Just call weight_fp8.update_usage(columnwise_usage=True).

@BestJuly
Collaborator

BestJuly commented Sep 10, 2025

> In fact you don't need a standalone kernel to do the transpose. Just call weight_fp8.update_usage(columnwise_usage=True).

If using update_usage, it will generate the transposed data when it does not already have it. After the transpose, the columnwise data persists into the next micro-batch's backward pass within the same global batch, which may not save memory as intended. I think they want to use this kernel to generate the transposed weight and release the memory right after use.

But yes, I'm not sure why it doesn't use an existing kernel such as tex.fp8_transpose, as is done in _create_columnwise().

@yaox12
Member

yaox12 commented Sep 15, 2025

> In fact you don't need a standalone kernel to do the transpose. Just call weight_fp8.update_usage(columnwise_usage=True).

> If using update_usage, it will generate the transposed data when it does not already have it. After the transpose, the columnwise data persists into the next micro-batch's backward pass within the same global batch, which may not save memory as intended. I think they want to use this kernel to generate the transposed weight and release the memory right after use.

> But yes, I'm not sure why it doesn't use an existing kernel such as tex.fp8_transpose, as is done in _create_columnwise().

If you want to release the column data after use, you can set columnwise_usage=False after the GEMM. And I agree, at least we don't need to add another kernel to do the transpose.
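
For reference, a minimal sketch of the materialize-then-release pattern being suggested (the update_usage keyword arguments are assumed to follow the usual QuantizedTensor signature):

```python
def backward_with_transient_columnwise(weight_fp8, grad_output, backward_gemm):
    # Materialize the transposed (columnwise) FP8 data just before the backward GEMM.
    weight_fp8.update_usage(columnwise_usage=True)
    dgrad = backward_gemm(weight_fp8, grad_output)
    # Drop it again so it does not stay resident across micro-batches.
    weight_fp8.update_usage(rowwise_usage=True, columnwise_usage=False)
    return dgrad
```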

jberchtold-nvidia and others added 2 commits September 15, 2025 19:55
* Custom call tests passing

Signed-off-by: Jeremy Berchtold <[email protected]>

* Fix test_layer.py

Signed-off-by: Jeremy Berchtold <[email protected]>

* Lint

Signed-off-by: Jeremy Berchtold <[email protected]>

* Fix comments

Signed-off-by: Jeremy Berchtold <[email protected]>

* Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling

Signed-off-by: Jeremy Berchtold <[email protected]>

* Fix shardy issue with amax being shape 1,1,1 instead of shape (1,)

Signed-off-by: Jeremy Berchtold <[email protected]>

* Add higher-precision VJP tests to test_distributed_layernorm_mlp

Signed-off-by: Jeremy Berchtold <[email protected]>

* Cast non-quantized kernels to input dtype in VJPs

Signed-off-by: Jeremy Berchtold <[email protected]>

* Rename HighPrecisionTensor to NoScaleTensor

Signed-off-by: Jeremy Berchtold <[email protected]>

* Use NoScaleTensor in pure JAX impls where it was missing

Signed-off-by: Jeremy Berchtold <[email protected]>

* Fix tests

Signed-off-by: Jeremy Berchtold <[email protected]>

---------

Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: skydoorkai <[email protected]>
@skydoorkai skydoorkai force-pushed the fp8_blockwise_weight_opt branch from a492545 to 3e8f377 on September 15, 2025 11:56
@skydoorkai
Contributor Author

> In fact you don't need a standalone kernel to do the transpose. Just call weight_fp8.update_usage(columnwise_usage=True).

> If using update_usage, it will generate the transposed data when it does not already have it. After the transpose, the columnwise data persists into the next micro-batch's backward pass within the same global batch, which may not save memory as intended. I think they want to use this kernel to generate the transposed weight and release the memory right after use.
> But yes, I'm not sure why it doesn't use an existing kernel such as tex.fp8_transpose, as is done in _create_columnwise().

> If you want to release the column data after use, you can set columnwise_usage=False after the GEMM. And I agree, at least we don't need to add another kernel to do the transpose.

Yes, no new kernel is needed. Updated to use tex.fp8_transpose.
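
A rough sketch of the resulting on-demand path (the tensor attribute names and the exact tex.fp8_transpose call are assumptions for illustration, not copied from the PR):

```python
import transformer_engine_torch as tex

def make_columnwise_on_demand(weight_fp8):
    # Transpose the rowwise FP8 payload into the columnwise layout needed by the
    # backward GEMM; the 128x128 block scales are transposed alongside.
    # Attribute names and the fp8_transpose signature below are assumed.
    columnwise_data = tex.fp8_transpose(weight_fp8._rowwise_data, weight_fp8._fp8_dtype, out=None)
    columnwise_scale = weight_fp8._rowwise_scale_inv.t().contiguous()
    return columnwise_data, columnwise_scale
```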

Signed-off-by: skydoorkai <[email protected]>
@skydoorkai skydoorkai force-pushed the fp8_blockwise_weight_opt branch from fbdb58b to 59e7dcd on September 15, 2025 12:00
@skydoorkai skydoorkai force-pushed the fp8_blockwise_weight_opt branch from f35e5dd to 2d460d2 on September 15, 2025 12:03