diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index de55155b96..b1d4196dfd 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -65,8 +65,6 @@
 )
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
-from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
-from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ..export import is_in_onnx_export_mode, assert_warmed_up
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...debug.pytorch.debug_state import TEDebugState
@@ -170,16 +168,19 @@ def forward(
             if input_quantizer is None:
                 raise ValueError("Missing quantizer for input tensor")
             if not isinstance(inputmat, QuantizedTensorBase):
-                input_quantizer.set_usage(
-                    rowwise=True, columnwise=backward_needs_input and not save_original_input
-                )
+                own_quantized_input = True
+                input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
                 if isinstance(
                     input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
                 ):
                     # All-gather is not supported with FP8 column-wise data
                     input_quantizer.set_usage(columnwise=False)
+                if save_original_input:
+                    # No need for column-wise data since this
+                    # tensor will not be cached for backward pass
+                    input_quantizer.set_usage(columnwise=False)
+                    own_quantized_input = False
                 inputmat = input_quantizer(inputmat)
-                own_quantized_input = True
         else:
             inputmat = cast_if_needed(inp, activation_dtype)  # Cast for AMP
 
@@ -344,23 +345,29 @@ def forward(
                 inputmat = inp
 
             ctx.weight_quantizer = weight_quantizer
-            saved_inputmat = None
             ctx.backward_input_needs_gather = (
                 weight.requires_grad and parallel_mode == "column" and sequence_parallel
             )
 
+            # Discard unneeded data in input tensor
+            if (
+                backward_needs_input
+                and own_quantized_input
+                and isinstance(inputmat, QuantizedTensorBase)
+            ):
+                if ctx.backward_input_needs_gather and isinstance(
+                    quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+                ):
+                    # All-gather is not supported with FP8 column-wise data
+                    inputmat.update_usage(rowwise_usage=True, columnwise_usage=False)
+                else:
+                    # Discard row-wise data since it is not needed in backward pass
+                    inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
+
+            # Cached input tensor
+            saved_inputmat = None
             if backward_needs_input:
-                if not save_original_input:
-                    if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
-                        # For sequence parallel in vanilla FP8, rowwise data is
-                        # to gather the input. For MXFP8, columnwise only data
-                        # can be allgathered.
-                        if (
-                            isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
-                            or not ctx.backward_input_needs_gather
-                        ):
-                            inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
                 saved_inputmat = inputmat
 
             # Weight with column-wise usage is needed for dgrad GEMM.
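The forward-pass hunks above reduce to one decision about the cached input: when the module owns the quantized input and it will be saved for backward, keep only the usages the backward pass can actually consume. A minimal sketch of that decision follows, under stated assumptions: `needs_gather` stands in for `ctx.backward_input_needs_gather`, `is_fp8_per_tensor` stands in for the `Float8Quantizer`/`Float8CurrentScalingQuantizer` check, and the returned pair stands in for the `rowwise_usage`/`columnwise_usage` flags passed to `inputmat.update_usage(...)`; none of these helper names exist in Transformer Engine.

def cached_input_usage(
    backward_needs_input: bool,
    own_quantized_input: bool,
    needs_gather: bool,
    is_fp8_per_tensor: bool,
) -> tuple[bool, bool] | None:
    """Sketch: which (rowwise, columnwise) data to keep on the cached input."""
    if not (backward_needs_input and own_quantized_input):
        # Nothing to trim: the tensor is not cached, or its layout is not ours to change.
        return None
    if needs_gather and is_fp8_per_tensor:
        # All-gather is not supported with FP8 column-wise data,
        # so keep row-wise data for the gather.
        return (True, False)
    # Only column-wise data is needed for the wgrad GEMM.
    return (False, True)

# Sequence-parallel, column-parallel layer with per-tensor FP8 scaling:
assert cached_input_usage(True, True, needs_gather=True, is_fp8_per_tensor=True) == (True, False)
assert cached_input_usage(True, True, needs_gather=False, is_fp8_per_tensor=True) == (False, True)

Compared with the old code, the FP8 gather case now keeps row-wise data explicitly instead of relying on the removed MXFP8/blockwise special-casing.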
@@ -572,20 +579,26 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         inputmat_total = None
         inputmat_total_work = None
         if ctx.requires_wgrad:
-            input_is_quantized = isinstance(inputmat, QuantizedTensorBase)
             if ctx.fp8 or ctx.debug:
-                if not input_is_quantized:
+                if isinstance(inputmat, QuantizedTensorBase):
+                    # Input tensor is already quantized
+                    pass
+                elif ctx.debug:
+                    # Debug quantizer will be applied immediately before wgrad GEMM
+                    pass
+                else:
+                    # Quantize input tensor
                     quantizer = ctx.input_quantizer
-                    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
-                        quantizer.set_usage(
-                            rowwise=True,
-                            columnwise=not ctx.backward_input_needs_gather,
-                        )
+                    if ctx.backward_input_needs_gather and isinstance(
+                        quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+                    ):
+                        # All-gather is not supported with FP8 column-wise data
+                        quantizer.set_usage(rowwise=True, columnwise=False)
                     else:
-                        quantizer.set_usage(rowwise=False, columnwise=True)
+                        quantizer.set_usage(rowwise=True, columnwise=True)
                     inputmat = quantizer(inputmat)
             else:
-                if input_is_quantized:
+                if isinstance(inputmat, QuantizedTensorBase):
                     inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
                 else:
                     inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
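The backward hunk replaces the `input_is_quantized` flag with an explicit three-way dispatch before the wgrad GEMM. The sketch below mirrors that control flow under the same assumptions as above (`already_quantized`, `needs_gather`, and `is_fp8_per_tensor` are illustrative parameters, not Transformer Engine API):

def wgrad_input_path(already_quantized: bool, debug: bool,
                     needs_gather: bool, is_fp8_per_tensor: bool) -> str:
    """Sketch of the FP8/debug branch that prepares the input for the wgrad GEMM."""
    if already_quantized:
        # Input tensor is already quantized
        return "keep cached tensor as-is"
    if debug:
        # Debug quantizer will be applied immediately before wgrad GEMM
        return "defer to debug quantizer"
    if needs_gather and is_fp8_per_tensor:
        # All-gather is not supported with FP8 column-wise data
        return "quantize with set_usage(rowwise=True, columnwise=False)"
    return "quantize with set_usage(rowwise=True, columnwise=True)"

print(wgrad_input_path(False, False, needs_gather=False, is_fp8_per_tensor=False))

Note that the non-gather case now requests both row-wise and column-wise usage, where the old code requested column-wise only.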