transformer_engine/pytorch/module/linear.py (39 additions & 26 deletions)
@@ -65,8 +65,6 @@
 )
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
-from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
-from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ..export import is_in_onnx_export_mode, assert_warmed_up
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...debug.pytorch.debug_state import TEDebugState
@@ -170,16 +168,19 @@ def forward(
             if input_quantizer is None:
                 raise ValueError("Missing quantizer for input tensor")
             if not isinstance(inputmat, QuantizedTensorBase):
-                input_quantizer.set_usage(
-                    rowwise=True, columnwise=backward_needs_input and not save_original_input
-                )
+                own_quantized_input = True
+                input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
                 if isinstance(
                     input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
                 ):
                     # All-gather is not supported with FP8 column-wise data
                     input_quantizer.set_usage(columnwise=False)
+                if save_original_input:
+                    # No need for column-wise data since this
+                    # tensor will not be cached for backward pass
+                    input_quantizer.set_usage(columnwise=False)
+                    own_quantized_input = False
                 inputmat = input_quantizer(inputmat)
-                own_quantized_input = True
             else:
                 inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP

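For reference, a minimal runnable sketch of the usage decision this hunk implements for the input quantizer. The helper and its flags are hypothetical stand-ins rather than Transformer Engine API; is_tensor_scaled_fp8 abbreviates the isinstance check against Float8Quantizer and Float8CurrentScalingQuantizer:

def input_usage(backward_needs_input: bool,
                save_original_input: bool,
                is_tensor_scaled_fp8: bool) -> tuple:
    """Return the (rowwise, columnwise) usage the quantizer ends up with."""
    rowwise = True
    columnwise = backward_needs_input
    if is_tensor_scaled_fp8:
        # All-gather is not supported with FP8 column-wise data
        columnwise = False
    if save_original_input:
        # Tensor will not be cached for backward, so column-wise data is unneeded
        columnwise = False
    return rowwise, columnwise

# Column-wise data is only requested when the quantized tensor itself
# will be cached for the backward pass:
assert input_usage(True, False, False) == (True, True)
assert input_usage(True, True, False) == (True, False)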
@@ -344,23 +345,29 @@ def forward(
             inputmat = inp

         ctx.weight_quantizer = weight_quantizer
-        saved_inputmat = None

+        ctx.backward_input_needs_gather = (
+            weight.requires_grad and parallel_mode == "column" and sequence_parallel
+        )
+
+        # Discard unneeded data in input tensor
+        if (
+            backward_needs_input
+            and own_quantized_input
+            and isinstance(inputmat, QuantizedTensorBase)
+        ):
+            if ctx.backward_input_needs_gather and isinstance(
+                quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+            ):
+                # All-gather is not supported with FP8 column-wise data
+                inputmat.update_usage(rowwise_usage=True, columnwise_usage=False)
+            else:
+                # Discard row-wise data since it is not needed in backward pass
+                inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
+
+        # Cached input tensor
+        saved_inputmat = None
         if backward_needs_input:
-            if not save_original_input:
-                if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
-                    # For sequence parallel in vanilla FP8, row-wise data is
-                    # needed to gather the input. For MXFP8, only column-wise
-                    # data can be all-gathered.
-                    if (
-                        isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
-                        or not ctx.backward_input_needs_gather
-                    ):
-                        inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
             saved_inputmat = inputmat

         # Weight with column-wise usage is needed for dgrad GEMM.
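As a hedged illustration of what the new "discard unneeded data" step buys, the toy class below (hypothetical, not QuantizedTensorBase itself) shows the two outcomes: keep only row-wise data when the backward pass must all-gather a tensor-scaled FP8 input, otherwise keep only the column-wise copy needed by the wgrad GEMM:

class ToyQuantizedTensor:
    # Toy stand-in for QuantizedTensorBase; tracks which copies are alive.
    def __init__(self):
        self.rowwise_data = object()     # pretend FP8 payload
        self.columnwise_data = object()  # pretend transposed payload

    def update_usage(self, rowwise_usage: bool, columnwise_usage: bool):
        # Dropping a reference frees that copy's memory.
        if not rowwise_usage:
            self.rowwise_data = None
        if not columnwise_usage:
            self.columnwise_data = None

t = ToyQuantizedTensor()
t.update_usage(rowwise_usage=True, columnwise_usage=False)  # gather path
assert t.columnwise_data is None and t.rowwise_data is not None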
@@ -572,20 +579,26 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             inputmat_total = None
             inputmat_total_work = None
             if ctx.requires_wgrad:
-                input_is_quantized = isinstance(inputmat, QuantizedTensorBase)
                 if ctx.fp8 or ctx.debug:
-                    if not input_is_quantized:
+                    if isinstance(inputmat, QuantizedTensorBase):
+                        # Input tensor is already quantized
+                        pass
+                    elif ctx.debug:
+                        # Debug quantizer will be applied immediately before wgrad GEMM
+                        pass
+                    else:
+                        # Quantize input tensor
                         quantizer = ctx.input_quantizer
-                        if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
-                            quantizer.set_usage(
-                                rowwise=True,
-                                columnwise=not ctx.backward_input_needs_gather,
-                            )
+                        if ctx.backward_input_needs_gather and isinstance(
+                            quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+                        ):
+                            # All-gather is not supported with FP8 column-wise data
+                            quantizer.set_usage(rowwise=True, columnwise=False)
                         else:
-                            quantizer.set_usage(rowwise=False, columnwise=True)
+                            quantizer.set_usage(rowwise=True, columnwise=True)
[Inline review thread on the quantizer.set_usage(rowwise=True, columnwise=True) line above]

Contributor: Hi Tim, why do we need the rowwise data here?

timmoon10 (Collaborator, Author), Jul 21, 2025: In principle we shouldn't need it, but I was running into issues where FP8 casts were failing without it. Actually, I don't think we need it once we skip the debug quantizer case.

Contributor: I think this may be due to the condition being modified. For the (Float8Quantizer, Float8CurrentScalingQuantizer) cases, quantizing only the column-wise data is not supported, but when backward_input_needs_gather is False those tensors also go into the else path, which caused the error. The rowwise=True workaround there hurts the performance of blockwise FP8 and MXFP8. Do you need me to create a fix PR for this, or will you fix it together with the DebugTensor?
                         inputmat = quantizer(inputmat)
                 else:
-                    if input_is_quantized:
+                    if isinstance(inputmat, QuantizedTensorBase):
                         inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
                     else:
                         inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
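To make the thread's concern concrete, here is a small sketch of the dispatch in the new else branch, using toy classes and a hypothetical supports_columnwise_only attribute rather than Transformer Engine code. Quantizers that can emit column-wise data alone, such as blockwise FP8 and MXFP8, still get rowwise=True whenever no gather is needed, which is the performance regression the reviewer describes:

class ToyTensorScaledFP8:
    # Models Float8Quantizer / Float8CurrentScalingQuantizer:
    # cannot produce column-wise data alone.
    supports_columnwise_only = False

class ToyBlockwiseFP8:
    # Models MXFP8 / blockwise FP8: column-wise-only output is fine.
    supports_columnwise_only = True

def wgrad_usage(quantizer, backward_input_needs_gather: bool) -> tuple:
    # Mirrors the modified condition in backward():
    if backward_input_needs_gather and not quantizer.supports_columnwise_only:
        # All-gather is not supported with FP8 column-wise data
        return (True, False)
    # Reviewer's point: blockwise/MXFP8 quantizers land here even when no
    # gather is needed, so rowwise=True produces data the wgrad GEMM alone
    # would not require; (False, True) would suffice for them.
    return (True, True)

assert wgrad_usage(ToyBlockwiseFP8(), backward_input_needs_gather=False) == (True, True)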