diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index c3812e0fb2..cdb75aa1b6 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -999,6 +999,13 @@ def get_weight_workspace(
         out = None
         if cache_name is not None:
             out = self._fp8_workspaces.get(cache_name, None)
+            if quantizer is not None and isinstance(out, MXFP8TensorBase):
+                if quantizer.rowwise_usage and out._rowwise_data is None:
+                    out = None
+                    del self._fp8_workspaces[cache_name]
+                elif quantizer.columnwise_usage and out._columnwise_data is None:
+                    out = None
+                    del self._fp8_workspaces[cache_name]

         # Gather cached Fp8 workspace if it's distributed
         # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work