Skip to content
This repository was archived by the owner on Aug 1, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions test/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2940,6 +2940,20 @@ def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4):
args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args]
self.common(forward, args)

def test_misaligned_address_issue1(self):
def forward(sub_tensor_1, unsqueeze_default):
gather_default = torch.ops.aten.gather.default(
sub_tensor_1, 1, unsqueeze_default
)
return gather_default

args = [
((1, 1000), (1000, 1), torch.float32),
((1, 1), (1, 1), torch.int64),
]
args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args]
self.common(forward, args)


if HAS_CPU:

Expand Down
14 changes: 7 additions & 7 deletions torchinductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,9 +648,13 @@ def indexing(
mask.append(f"{tree.prefix}mask")
dense_mask.append(f"{tree.prefix}mask")

if need_dense and not have_dense:
mask = dense_mask
if (need_dense and not have_dense) or index == 0:
index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
if index == 0:
return index_str, "None"
else:
mask = dense_mask

elif not have_loop_vars and copy_shape:
mask = dense_mask
index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"
Expand Down Expand Up @@ -699,11 +703,7 @@ def mask_loads(self, mask):
def load(self, name: str, index: sympy.Expr, upcast: bool = False):
var = self.args.input(name)
indirect_indexing = self.is_indirect_indexing(index)
if index == 0:
# No need to use mask when loading a single element from index 0
index, mask = "0", "None"
else:
index, mask = self.indexing(index)
index, mask = self.indexing(index)

if "rmask" in mask:
# This eviction policy heuristic is untested.
Expand Down