
Commit 18bcf62

XiaobingSuper authored and pytorchmergebot committed
inductor: promote half/bfloat16 constant to float for cpu vectorization path (pytorch#105440)
As in the scalar path, we should also promote half/bfloat16 constants to float for better accuracy; after this PR, the TIMM `dm_nfnet` model's amp path passes. Pull Request resolved: pytorch#105440 Approved by: https://github.com/jgong5, https://github.com/jansel
1 parent 7ddb66e commit 18bcf62
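
For context, a minimal eager-mode sketch of the accuracy gap this addresses (illustrative only, not part of the commit): the constant below is the one used in the new test; rounding it to bfloat16 before the multiply corresponds to what the vectorized path did before, while keeping it in float32 matches the scalar path's behavior.

import torch

# Illustrative sketch, not from this PR: compare multiplying by the constant
# pre-rounded to bfloat16 with multiplying by the constant kept in float32.
c = 1.7015043497085571
x = torch.randn(4, 5, dtype=torch.bfloat16)

# Constant pre-rounded to bfloat16 (lower accuracy).
low = (x * torch.tensor(c, dtype=torch.bfloat16)).float()

# Constant kept in float32; inputs promoted to float for the computation.
high = (x.float() * c).to(torch.bfloat16).float()

print((low - high).abs().max())  # typically non-zero: the rounded constant changes results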

File tree: 2 files changed, +13 −0 lines changed


test/inductor/test_cpu_repro.py

Lines changed: 9 additions & 0 deletions
@@ -1977,6 +1977,15 @@ def f(a):
         x = torch.rand(4, 5)
         self.common(f, (x,))
 
+    def test_scalar_mul_bfloat16(self):
+        def f(x):
+            return torch.ops.aten.mul.Tensor(x, 1.7015043497085571)
+
+        metrics.reset()
+        x = torch.randn(4, 5, dtype=torch.bfloat16)
+        self.common(f, (x,))
+        assert metrics.generated_cpp_vec_kernel_count == 1
+
     def test_to_channels_last_bfloat16(self):
         def f(a):
             return a.to(memory_format=torch.channels_last)
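
The new test can be run on its own (assuming a source checkout with pytest available) with:

pytest test/inductor/test_cpu_repro.py -k test_scalar_mul_bfloat16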

torch/_inductor/codegen/cpp.py

Lines changed: 4 additions & 0 deletions
@@ -2094,6 +2094,10 @@ def store_reduction(name, index, value):
 
     @staticmethod
     def constant(val, dtype):
+        if dtype == torch.bfloat16:
+            # Since load promotes all bfloat16-precision inputs to float, constants
+            # must be promoted as well
+            dtype = torch.float32
         with RecordOptimizationContext(__name__) as node_ctx:
             opt_ctx: OptimizationContext = node_ctx.get_opt_ctx()
             assert opt_ctx
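
A quick way to exercise the vectorized path outside the test suite (an illustrative sketch, not from this PR): running it with TORCH_COMPILE_DEBUG=1 set in the environment makes Inductor dump the generated C++ kernel, where the promoted float constant can be inspected.

import torch

def f(x):
    # same constant as the new test above
    return x * 1.7015043497085571

x = torch.randn(4, 5, dtype=torch.bfloat16)
compiled = torch.compile(f)
print(compiled(x))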
