Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Commit 7659e77

Browse files
committed
Fix a misaligned address bug
Summary: #981 introduced a misaligned address bug which relates how tl.load from index 0 should be written in triton.
1 parent ed0b4ce commit 7659e77

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

test/test_torchinductor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2940,6 +2940,20 @@ def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4):
29402940
args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args]
29412941
self.common(forward, args)
29422942

2943+
def test_misaligned_address_issue1(self):
2944+
def forward(sub_tensor_1, unsqueeze_default):
2945+
gather_default = torch.ops.aten.gather.default(
2946+
sub_tensor_1, 1, unsqueeze_default
2947+
)
2948+
return gather_default
2949+
2950+
args = [
2951+
((1, 1000), (1000, 1), torch.float32),
2952+
((1, 1), (1, 1), torch.int64),
2953+
]
2954+
args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args]
2955+
self.common(forward, args)
2956+
29432957

29442958
if HAS_CPU:
29452959

torchinductor/codegen/triton.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -648,9 +648,13 @@ def indexing(
648648
mask.append(f"{tree.prefix}mask")
649649
dense_mask.append(f"{tree.prefix}mask")
650650

651-
if need_dense and not have_dense:
652-
mask = dense_mask
651+
if (need_dense and not have_dense) or index == 0:
653652
index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
653+
if index == 0:
654+
return index_str, "None"
655+
else:
656+
mask = dense_mask
657+
654658
elif not have_loop_vars and copy_shape:
655659
mask = dense_mask
656660
index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"
@@ -699,11 +703,7 @@ def mask_loads(self, mask):
699703
def load(self, name: str, index: sympy.Expr, upcast: bool = False):
700704
var = self.args.input(name)
701705
indirect_indexing = self.is_indirect_indexing(index)
702-
if index == 0:
703-
# No need to use mask when loading a single element from index 0
704-
index, mask = "0", "None"
705-
else:
706-
index, mask = self.indexing(index)
706+
index, mask = self.indexing(index)
707707

708708
if "rmask" in mask:
709709
# This eviction policy heuristic is untested.

0 commit comments

Comments
 (0)