Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Commit 6c37dfe

Browse files
authored
[inductor] Fix a use before def error (#956)
Summary: for #918
1 parent ca8243c commit 6c37dfe

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

test/test_torchinductor.py

Lines changed: 17 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2871,7 +2871,7 @@ def fn(x):
28712871
],
28722872
)
28732873

2874-
def test_tmp_not_defined_issue(self):
2874+
def test_tmp_not_defined_issue1(self):
28752875
def forward(
28762876
primals_3,
28772877
primals_4,
@@ -2920,6 +2920,22 @@ def forward(
29202920
inps = [torch.randn(shape, dtype=dtype) for (shape, dtype) in inps]
29212921
self.common(forward, inps)
29222922

2923+
def test_tmp_not_defined_issue2(self):
2924+
def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4):
2925+
div_tensor_7 = torch.ops.aten.div.Tensor(getitem_17, arg81_1)
2926+
mul_tensor_24 = torch.ops.aten.mul.Tensor(div_tensor_7, arg38_1)
2927+
sum_default_7 = torch.ops.aten.sum.default(mul_tensor_24)
2928+
return (new_zeros_default_4, sum_default_7)
2929+
2930+
args = [
2931+
((1, 88, 40, 40), (140800, 1600, 40, 1), torch.float32),
2932+
((), (), torch.float32),
2933+
((1, 88, 40, 40), (140800, 1600, 40, 1), torch.float32),
2934+
((3,), (1,), torch.float32),
2935+
]
2936+
args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args]
2937+
self.common(forward, args)
2938+
29232939

29242940
if HAS_CPU:
29252941

torchinductor/codegen/triton.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -616,7 +616,12 @@ def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot):
616616
new_index = index.subs(dict(zip(index_vars, reindex(new_index_vars))))
617617
return new_index
618618

619-
def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=False):
619+
def indexing(
620+
self,
621+
index: sympy.Expr,
622+
copy_shape=None,
623+
dense_indexing=False,
624+
):
620625
"""
621626
Compute the index and mask to pass to tl.load() or tl.store()
622627
"""
@@ -632,6 +637,7 @@ def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=False):
632637
or indirect_indexing
633638
or self._load_mask is not None
634639
) and index != 0
640+
635641
have_dense = True
636642
have_loop_vars = False
637643
mask = []
@@ -646,7 +652,7 @@ def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=False):
646652
mask.append(f"{tree.prefix}mask")
647653
dense_mask.append(f"{tree.prefix}mask")
648654

649-
if need_dense and not have_dense:
655+
if (need_dense and not have_dense) or index == 0:
650656
mask = dense_mask
651657
index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
652658
elif not have_loop_vars and copy_shape:
@@ -704,7 +710,12 @@ def load(self, name: str, index: sympy.Expr, upcast: bool = False):
704710
if upcast:
705711
line += ".to(tl.float32)"
706712

707-
if self.inside_reduction and "rmask" not in mask and not indirect_indexing:
713+
if (
714+
self.inside_reduction
715+
and "rmask" not in mask
716+
and "tmp" not in mask
717+
and not indirect_indexing
718+
):
708719
# can lift a common load outside of reduction loop
709720
# One exception is when this is an indirect_load.
710721
tmp = self.cse.generate(self.body, line)

0 commit comments

Comments (0)