Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Commit 777059f

Browse files
committed
Force dense_indexing when load with zero as the index
1 parent 21bffac commit 777059f

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

torchinductor/codegen/triton.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,13 @@ def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot):
616616
new_index = index.subs(dict(zip(index_vars, reindex(new_index_vars))))
617617
return new_index
618618

619-
def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=False):
619+
def indexing(
620+
self,
621+
index: sympy.Expr,
622+
copy_shape=None,
623+
dense_indexing=False,
624+
load_index0=False,
625+
):
620626
"""
621627
Compute the index and mask to pass to tl.load() or tl.store()
622628
"""
@@ -632,6 +638,7 @@ def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=False):
632638
or indirect_indexing
633639
or self._load_mask is not None
634640
) and index != 0
641+
635642
have_dense = True
636643
have_loop_vars = False
637644
mask = []
@@ -646,7 +653,7 @@ def indexing(self, index: sympy.Expr, copy_shape=None, dense_indexing=False):
646653
mask.append(f"{tree.prefix}mask")
647654
dense_mask.append(f"{tree.prefix}mask")
648655

649-
if need_dense and not have_dense:
656+
if (need_dense and not have_dense) or load_index0:
650657
mask = dense_mask
651658
index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
652659
elif not have_loop_vars and copy_shape:
@@ -692,7 +699,7 @@ def mask_loads(self, mask):
692699
def load(self, name: str, index: sympy.Expr, upcast: bool = False):
693700
var = self.args.input(name)
694701
indirect_indexing = self.is_indirect_indexing(index)
695-
index, mask = self.indexing(index)
702+
index, mask = self.indexing(index, load_index0=(index == 0))
696703
if "rmask" in mask:
697704
# This eviction policy heuristic is untested.
698705
# ptillet suggested we should try only doing this for

0 commit comments

Comments
 (0)