diff --git a/test/test_torchinductor.py b/test/test_torchinductor.py
index 4f4bf74480..cad687ca35 100755
--- a/test/test_torchinductor.py
+++ b/test/test_torchinductor.py
@@ -2940,6 +2940,20 @@ def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4):
         args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args]
         self.common(forward, args)
 
+    def test_misaligned_address_issue1(self):
+        def forward(sub_tensor_1, unsqueeze_default):
+            gather_default = torch.ops.aten.gather.default(
+                sub_tensor_1, 1, unsqueeze_default
+            )
+            return gather_default
+
+        args = [
+            ((1, 1000), (1000, 1), torch.float32),
+            ((1, 1), (1, 1), torch.int64),
+        ]
+        args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args]
+        self.common(forward, args)
+
 
 if HAS_CPU:
 
diff --git a/torchinductor/codegen/triton.py b/torchinductor/codegen/triton.py
index e05170c2c4..0499751d75 100644
--- a/torchinductor/codegen/triton.py
+++ b/torchinductor/codegen/triton.py
@@ -648,9 +648,13 @@ def indexing(
                 mask.append(f"{tree.prefix}mask")
             dense_mask.append(f"{tree.prefix}mask")
 
-        if need_dense and not have_dense:
-            mask = dense_mask
+        if (need_dense and not have_dense) or index == 0:
             index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
+            if index == 0:
+                return index_str, "None"
+            else:
+                mask = dense_mask
+
         elif not have_loop_vars and copy_shape:
             mask = dense_mask
             index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"
@@ -699,11 +703,7 @@ def mask_loads(self, mask):
     def load(self, name: str, index: sympy.Expr, upcast: bool = False):
         var = self.args.input(name)
         indirect_indexing = self.is_indirect_indexing(index)
-        if index == 0:
-            # No need to use mask when loading a single element from index 0
-            index, mask = "0", "None"
-        else:
-            index, mask = self.indexing(index)
+        index, mask = self.indexing(index)
 
         if "rmask" in mask:
             # This eviction policy heuristic is untested.