From 8228aade5e4db644e405545cd22ce02f7420c836 Mon Sep 17 00:00:00 2001
From: Sudhakar Singh
Date: Fri, 27 Jun 2025 11:14:24 -0700
Subject: [PATCH] Fix the broken fp8_calibration path

The quantizer getters in linear, layernorm_linear, and layernorm_mlp
returned None whenever self.fp8 was False, so running with only
fp8_calibration enabled produced no quantizers to calibrate. Extend the
checks to also account for self.fp8_calibration.

Signed-off-by: Sudhakar Singh
---
 transformer_engine/pytorch/module/layernorm_linear.py | 4 ++--
 transformer_engine/pytorch/module/layernorm_mlp.py    | 4 ++--
 transformer_engine/pytorch/module/linear.py           | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index b99952ad2a..1d682f14b8 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -1577,7 +1577,7 @@ def forward(
         return out
 
     def _get_quantizers(self, fp8_output, fp8_grad):
-        if not self.fp8:
+        if not self.fp8 and not self.fp8_calibration:
             return [None] * 6
         grad_input_quantizer = None
         grad_weight_quantizer = None
@@ -1680,7 +1680,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
 
     def _get_weight_quantizers(self) -> List[Quantizer]:
         """Get the weight quantizers of the module."""
-        if not self.fp8:
+        if not self.fp8 and not self.fp8_calibration:
             return [None]
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index 375db477b0..0c191a46ac 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -1860,7 +1860,7 @@ def _get_quantizers(self, fp8_output):
             fc2_grad_output_quantizer,
         ) = [None] * 10
         fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers()
-        if self.fp8:
+        if self.fp8 or self.fp8_calibration:
             fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
             fc1_input_quantizer.internal = True
             fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
@@ -1991,7 +1991,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
 
     def _get_weight_quantizers(self) -> List[Quantizer]:
         """Get the weight quantizers of the module."""
-        if not self.fp8:
+        if not self.fp8 and not self.fp8_calibration:
             return [None, None]
         fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         fc1_weight_quantizer.internal = True
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 8a7c0ce2d1..af1a7f38d8 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -1377,7 +1377,7 @@ def forward(
         return out
 
     def _get_quantizers(self, fp8_output, fp8_grad):
-        if not self.fp8:
+        if not self.fp8 and not self.fp8_calibration:
             return [None] * 6
         grad_input_quantizer = None
         grad_weight_quantizer = None
@@ -1479,7 +1479,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
 
     def _get_weight_quantizers(self) -> List[Quantizer]:
         """Get the weight quantizers of the module."""
-        if not self.fp8:
+        if not self.fp8 and not self.fp8_calibration:
             return [None]
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
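
A minimal usage sketch (not part of the patch) of the calibration flow
these checks gate: fp8_autocast with enabled=False and calibrating=True
runs the modules in high precision while still exercising the quantizers
touched above. It assumes transformer_engine is installed on an
FP8-capable GPU; the layer and tensor sizes are arbitrary illustrations.

import torch
import transformer_engine.pytorch as te

# Hypothetical sizes, for illustration only.
model = te.Linear(1024, 1024, device="cuda")
inp = torch.randn(32, 1024, device="cuda")

# enabled=False keeps the GEMMs in high precision; calibrating=True sets
# fp8_calibration on the module, so the patched getters must return real
# quantizers (to record calibration statistics) rather than None.
with te.fp8_autocast(enabled=False, calibrating=True):
    out = model(inp)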