diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 1d6c69d174..868d4f52a6 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -42,7 +42,6 @@ def amax_to_scale( float8_dtype: The float8 dtype. orig_dtype: The original dtype of the tensor. """ - scale = torch.empty_like(amax, dtype=torch.float32) if float8_dtype in FP8_TYPES: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) else: @@ -53,8 +52,7 @@ def amax_to_scale( # to care about this for float32/bfloat16. if orig_dtype is torch.float16: res = torch.clamp(res, max=torch.finfo(torch.float16).max) - scale.copy_(res) - return scale + return res.to(torch.float32) @torch.no_grad()