Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Commit 21bffac

Browse files
committed
[inductor] Fix a use before def error
Summary: for #918
1 parent a36aaed commit 21bffac

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

test/test_torchinductor.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2871,7 +2871,7 @@ def fn(x):
28712871
],
28722872
)
28732873

2874-
def test_tmp_not_defined_issue(self):
2874+
def test_tmp_not_defined_issue1(self):
28752875
def forward(
28762876
primals_3,
28772877
primals_4,
@@ -2920,6 +2920,22 @@ def forward(
29202920
inps = [torch.randn(shape, dtype=dtype) for (shape, dtype) in inps]
29212921
self.common(forward, inps)
29222922

2923+
def test_tmp_not_defined_issue2(self):
2924+
def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4):
2925+
div_tensor_7 = torch.ops.aten.div.Tensor(getitem_17, arg81_1)
2926+
mul_tensor_24 = torch.ops.aten.mul.Tensor(div_tensor_7, arg38_1)
2927+
sum_default_7 = torch.ops.aten.sum.default(mul_tensor_24)
2928+
return (new_zeros_default_4, sum_default_7)
2929+
2930+
args = [
2931+
((1, 88, 40, 40), (140800, 1600, 40, 1), torch.float32),
2932+
((), (), torch.float32),
2933+
((1, 88, 40, 40), (140800, 1600, 40, 1), torch.float32),
2934+
((3,), (1,), torch.float32),
2935+
]
2936+
args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args]
2937+
self.common(forward, args)
2938+
29232939

29242940
if HAS_CPU:
29252941

torchinductor/codegen/triton.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,12 @@ def load(self, name: str, index: sympy.Expr, upcast: bool = False):
704704
if upcast:
705705
line += ".to(tl.float32)"
706706

707-
if self.inside_reduction and "rmask" not in mask and not indirect_indexing:
707+
if (
708+
self.inside_reduction
709+
and "rmask" not in mask
710+
and "tmp" not in mask
711+
and not indirect_indexing
712+
):
708713
# can lift a common load outside of reduction loop
709714
# One exception is when this is an indirect_load.
710715
tmp = self.cse.generate(self.body, line)

0 commit comments

Comments
 (0)