
Commit 8c704f7

Jiong Gong authored and pytorchmergebot committed
[inductor cpp] fix argmax with >1 reduction dims (pytorch#113168)
Fix pytorch#113013. The argmax (and argmin) implementation doesn't compute the accumulator index properly when there is more than one reduction dim; it wrongly assumed a single reduction dim. With the given reproducer, the generated code before the change:

```c++
#include "/tmp/torchinductor_jgong5/tb/ctbgktuhgnnlel6ipqkfk76lfztr5pledachdkcq3asdqtlxpzt6.h"
extern "C" void kernel(const double* in_ptr0, long* out_ptr0)
{
    {
        {
            struct IndexValue_1 {size_t index; double value;};
            IndexValue_1 tmp_acc0{0, -std::numeric_limits<double>::infinity()};
            #if !defined(__clang_major__) || __clang_major__ > 9
            #pragma omp declare reduction(argmax : IndexValue_1 :\
                omp_out.value = omp_in.value < omp_out.value ? omp_out.value : omp_in.value,\
                omp_out.index = omp_in.value < omp_out.value ? omp_out.index : omp_in.index)\
                initializer(omp_priv = {0, -std::numeric_limits<double>::infinity()})
            #endif
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(9L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(2L); x1+=static_cast<long>(1L))
                {
                    auto tmp0 = c10::convert<long>(0);
                    auto tmp1 = c10::convert<long>(1);
                    auto tmp2 = tmp0 < tmp1;
                    auto tmp3 = c10::convert<long>(at::native::div_floor_integer((3L*x1), 2L));
                    auto tmp4 = c10::convert<long>(2L + (at::native::div_floor_integer((3L*x1), 2L)));
                    auto tmp5 = tmp3 < tmp4;
                    auto tmp6 = tmp2 & tmp5;
                    auto tmp7 = [&]
                    {
                        auto tmp8 = in_ptr0[static_cast<long>((3L*x0) + (at::native::div_floor_integer((3L*x1), 2L)))];
                        return tmp8;
                    }
                    ;
                    auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                    auto tmp10 = c10::convert<long>(1L + (at::native::div_floor_integer((3L*x1), 2L)));
                    auto tmp11 = tmp10 < tmp4;
                    auto tmp12 = tmp2 & tmp11;
                    auto tmp13 = [&]
                    {
                        auto tmp14 = in_ptr0[static_cast<long>(1L + (3L*x0) + (at::native::div_floor_integer((3L*x1), 2L)))];
                        return tmp14;
                    }
                    ;
                    auto tmp15 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                    auto tmp16 = tmp15 + tmp9;
                    auto tmp17 = [&]
                    {
                        auto tmp18 = c10::convert<double>(1.0);
                        return tmp18;
                    }
                    ;
                    auto tmp19 = tmp6 ? tmp17() : static_cast<decltype(tmp17())>(0.0);
                    auto tmp20 = [&]
                    {
                        auto tmp21 = c10::convert<double>(1.0);
                        return tmp21;
                    }
                    ;
                    auto tmp22 = tmp12 ? tmp20() : static_cast<decltype(tmp20())>(0.0);
                    auto tmp23 = tmp22 + tmp19;
                    auto tmp24 = tmp16 / tmp23;
                    if (tmp_acc0.value < tmp24) {
                        tmp_acc0.index = x1; tmp_acc0.value = tmp24; // both x0 and x1 are reduction vars while only x1 is assigned to tmp_acc0.index
                    }
                }
            }
            out_ptr0[static_cast<long>(0L)] = tmp_acc0.index;
        }
    }
}
```

After the fix:

```c++
#include "/tmp/torchinductor_jgong5/tb/ctbgktuhgnnlel6ipqkfk76lfztr5pledachdkcq3asdqtlxpzt6.h"
extern "C" void kernel(const double* in_ptr0, long* out_ptr0)
{
    {
        {
            struct IndexValue_1 {size_t index; double value;};
            IndexValue_1 tmp_acc0{0, -std::numeric_limits<double>::infinity()};
            #if !defined(__clang_major__) || __clang_major__ > 9
            #pragma omp declare reduction(argmax : IndexValue_1 :\
                omp_out.value = omp_in.value < omp_out.value ? omp_out.value : omp_in.value,\
                omp_out.index = omp_in.value < omp_out.value ? omp_out.index : omp_in.index)\
                initializer(omp_priv = {0, -std::numeric_limits<double>::infinity()})
            #endif
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(9L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(2L); x1+=static_cast<long>(1L))
                {
                    auto tmp0 = c10::convert<long>(0);
                    auto tmp1 = c10::convert<long>(1);
                    auto tmp2 = tmp0 < tmp1;
                    auto tmp3 = c10::convert<long>(at::native::div_floor_integer((3L*x1), 2L));
                    auto tmp4 = c10::convert<long>(2L + (at::native::div_floor_integer((3L*x1), 2L)));
                    auto tmp5 = tmp3 < tmp4;
                    auto tmp6 = tmp2 & tmp5;
                    auto tmp7 = [&]
                    {
                        auto tmp8 = in_ptr0[static_cast<long>((3L*x0) + (at::native::div_floor_integer((3L*x1), 2L)))];
                        return tmp8;
                    }
                    ;
                    auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                    auto tmp10 = c10::convert<long>(1L + (at::native::div_floor_integer((3L*x1), 2L)));
                    auto tmp11 = tmp10 < tmp4;
                    auto tmp12 = tmp2 & tmp11;
                    auto tmp13 = [&]
                    {
                        auto tmp14 = in_ptr0[static_cast<long>(1L + (3L*x0) + (at::native::div_floor_integer((3L*x1), 2L)))];
                        return tmp14;
                    }
                    ;
                    auto tmp15 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                    auto tmp16 = tmp15 + tmp9;
                    auto tmp17 = [&]
                    {
                        auto tmp18 = c10::convert<double>(1.0);
                        return tmp18;
                    }
                    ;
                    auto tmp19 = tmp6 ? tmp17() : static_cast<decltype(tmp17())>(0.0);
                    auto tmp20 = [&]
                    {
                        auto tmp21 = c10::convert<double>(1.0);
                        return tmp21;
                    }
                    ;
                    auto tmp22 = tmp12 ? tmp20() : static_cast<decltype(tmp20())>(0.0);
                    auto tmp23 = tmp22 + tmp19;
                    auto tmp24 = tmp16 / tmp23;
                    if (tmp_acc0.value < tmp24) {
                        tmp_acc0.index = static_cast<long>(x1 + (2L*x0)); tmp_acc0.value = tmp24;
                    }
                }
            }
            out_ptr0[static_cast<long>(0L)] = tmp_acc0.index;
        }
    }
}
```

Pull Request resolved: pytorch#113168
Approved by: https://github.com/lezcano, https://github.com/jansel
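For context, here is a minimal reproducer sketch distilled from the new test added below. The exact script from the issue may differ, and the `torch.compile` wrapper is an assumption about how the CPU C++ kernel gets exercised:

```python
import torch

def fn(x):
    # adaptive_avg_pool1d followed by a global argmax gives inductor an
    # argmax whose reduction spans more than one dim on CPU
    x = torch.adaptive_avg_pool1d(input=x, output_size=2)
    return torch.argmax(input=x)

x = torch.rand([3, 3, 3], dtype=torch.float64)
eager = fn(x)
compiled = torch.compile(fn, backend="inductor")(x)
print(eager.item(), compiled.item())  # should agree; before the fix the flat index could be wrong
```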
1 parent be66d5e · commit 8c704f7

File tree

3 files changed: +16 −1 lines changed


test/inductor/test_torchinductor.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -7649,6 +7649,16 @@ def fn(x, y):
         b = torch.randn(65, 2**24, device=self.device)
         fn(a, b)
 
+    def test_adaptive_avg_pool1d_argmax(self):
+        # https://github.com/pytorch/pytorch/issues/113013
+        def fn(x):
+            x = torch.adaptive_avg_pool1d(input=x, output_size=2)
+            x = torch.argmax(input=x)
+            return x
+
+        x = torch.rand([3, 3, 3], dtype=torch.float64)
+        self.common(fn, (x,))
+
 
 @dataclasses.dataclass
 class TestFailure:
```
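For readers unfamiliar with the harness, `self.common(fn, (x,))` runs `fn` both eagerly and through inductor and compares the results. A rough standalone stand-in (a sketch, not the harness's actual implementation; `check_against_eager` is a made-up name):

```python
import torch

def check_against_eager(fn, args):
    # Approximation of what the test harness's common() helper checks:
    # the inductor-compiled result must match the eager result.
    expected = fn(*args)
    actual = torch.compile(fn, backend="inductor")(*args)
    torch.testing.assert_close(actual, expected)

check_against_eager(
    lambda x: torch.argmax(torch.adaptive_avg_pool1d(x, output_size=2)),
    (torch.rand([3, 3, 3], dtype=torch.float64),),
)
```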

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -136,6 +136,7 @@ def run(*ex, **kwargs):
     "test_zeros_dynamic_shapes": TestFailure(("cpu",)),
     "test_uint_dynamic_shapes": TestFailure(("cpu",)),
     "test_issue102546_dynamic_shapes": TestFailure(("cpu",)),
+    "test_adaptive_avg_pool1d_argmax_dynamic_shapes": TestFailure(("cpu",)),
     #
     # Failed to find for loop/triton kernel:
     #
```

torch/_inductor/codegen/cpp.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -1269,10 +1269,14 @@ def reduction(self, dtype, src_dtype, reduction_type, value):
                 argmax_argmin_prefix(reduction_type, src_dtype, acc)
             )
             compare_op = "<" if reduction_type == "argmax" else ">"
+            assert self.reduction_depth is not None
+            index = self.itervars[self.reduction_depth]
+            for i in range(self.reduction_depth + 1, len(self.itervars)):
+                index = index * self.ranges[i] + self.itervars[i]
             self.stores.writelines(
                 [
                     f"if ({acc}.value {compare_op} {value}) {{",
-                    f" {acc}.index = {self.itervars[-1]}; {acc}.value = {value};",
+                    f" {acc}.index = {cexpr_index(index)}; {acc}.value = {value};",
                     "}",
                 ],
             )
```
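For intuition, the added lines build a row-major flattening of all reduction itervars instead of using only the innermost one. A plain-Python sketch with concrete integers standing in for the sympy symbols the kernel actually manipulates (`flat_reduction_index` is a made-up helper name):

```python
def flat_reduction_index(itervar_values, ranges, reduction_depth=0):
    # Mirrors the new loop above: fold every reduction itervar after
    # reduction_depth into one flat (row-major) index.
    index = itervar_values[reduction_depth]
    for i in range(reduction_depth + 1, len(itervar_values)):
        index = index * ranges[i] + itervar_values[i]
    return index

# Loops from the generated kernel in the commit message: x0 in [0, 9),
# x1 in [0, 2). The flat index is x1 + 2*x0, which is exactly what the
# fixed C++ stores into tmp_acc0.index.
assert flat_reduction_index([3, 1], [9, 2]) == 1 + 2 * 3
assert flat_reduction_index([8, 1], [9, 2]) == 17  # last element of a 9x2 reduction
```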
