Commit c89475c

vkuzo authored and jainapurva committed
float8: remove unneeded kernel for scale generation (#616)

Summary: The code to create a float8 scale was unnecessarily creating an extra GPU kernel launch by calling `torch.empty_like`; this change removes it. There is no performance impact, but it does make things easier to debug by reducing log size and making GPU traces simpler.

Test Plan:

```
// extract trace of a linear fwd+bwd with
python benchmarks/float8/profile_linear_float8.py ~/local/tmp/test

// verify that the GPU kernel creating an empty scale tensor is no longer there

// unit tests pass
./test/float8/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent dccb065 commit c89475c
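
For context, here is a minimal sketch (not the torchao source; the function names and the `EPS` value are illustrative) contrasting the allocate-then-copy pattern the commit removes with the direct conversion it keeps:

```python
import torch

EPS = 1e-12  # illustrative epsilon; torchao defines its own constant


def scale_old(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Old pattern: pre-allocate an output tensor, compute, then copy into it.
    # Per the commit message, the allocation showed up as an extra kernel
    # in GPU traces.
    scale = torch.empty_like(amax, dtype=torch.float32)
    res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
    scale.copy_(res)
    return scale


def scale_new(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # New pattern: compute the scale and cast it directly.
    res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
    return res.to(torch.float32)
```

Both versions return the same values, e.g. `scale_new(torch.rand(8), torch.float8_e4m3fn)`; the difference is only in how many operations the device executes.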

File tree

1 file changed: +1 −3 lines changed


torchao/float8/float8_utils.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -42,7 +42,6 @@ def amax_to_scale(
         float8_dtype: The float8 dtype.
         orig_dtype: The original dtype of the tensor.
     """
-    scale = torch.empty_like(amax, dtype=torch.float32)
     if float8_dtype in FP8_TYPES:
         res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
     else:
@@ -53,8 +52,7 @@ def amax_to_scale(
     # to care about this for float32/bfloat16.
     if orig_dtype is torch.float16:
         res = torch.clamp(res, max=torch.finfo(torch.float16).max)
-    scale.copy_(res)
-    return scale
+    return res.to(torch.float32)
 
 
 @torch.no_grad()
```
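
As a rough way to reproduce the test plan's trace check without the full benchmark script, one could profile the two patterns directly. This is a sketch under the assumptions that a CUDA device and a torch build with float8 dtypes are available; `scale_old` and `scale_new` are illustrative stand-ins for the before/after code paths, not the torchao functions themselves:

```python
import torch
from torch.profiler import profile, ProfilerActivity

EPS = 1e-12  # illustrative epsilon


def scale_old(amax: torch.Tensor) -> torch.Tensor:
    # Before: allocate-then-copy.
    scale = torch.empty_like(amax, dtype=torch.float32)
    res = torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=EPS)
    scale.copy_(res)
    return scale


def scale_new(amax: torch.Tensor) -> torch.Tensor:
    # After: compute and cast directly.
    res = torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=EPS)
    return res.to(torch.float32)


amax = torch.rand(1, device="cuda")
for fn in (scale_old, scale_new):
    fn(amax)  # warm up so one-time setup work does not skew the trace
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        fn(amax)
    # Inspect the table: per the commit, the kernel creating the empty
    # scale tensor should be absent for scale_new.
    print(fn.__name__)
    print(prof.key_averages().table(sort_by="cuda_time_total"))
```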

0 commit comments
