Skip to content

Commit fc4ec44

Browse files
committed
Update
[ghstack-poisoned]
2 parents bd46558 + 80589b0 commit fc4ec44

File tree

4 files changed

+2
-41
lines changed

4 files changed

+2
-41
lines changed

kernels/portable/cpu/op_argmax.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,7 @@ Tensor& argmax_out(
5050
for (const auto out_ix : c10::irange(out.numel())) {
5151
std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
5252
[](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
53-
// the below condition as written is equivalent to
54-
// !isnan(accval) && (isnan(v) || v > acc_val). See
55-
// argument in op_argmin.cpp.
56-
if (!std::isnan(acc_val) && !(v <= acc_val)) {
53+
if (!std::isnan(acc_val) && (std::isnan(v) || v > acc_val)) {
5754
acc_val = v;
5855
acc_ix = ix;
5956
}

kernels/portable/cpu/op_argmin.cpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,7 @@ Tensor& argmin_out(
5050
for (const auto out_ix : c10::irange(out.numel())) {
5151
std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
5252
[](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
53-
// the below condition as written is equivalent to !isnan(accval) &&
54-
// (isnan(v) || v < acc_val). cases:
55-
// - if neither acc_val nor v is NaN, !(v >= acc_val) is
56-
// trivially equivalent to v < acc_val.
57-
// - if acc_val is NaN, the whole thing is trivially false.
58-
// - if acc_val is not NaN and v is NaN, then v >= acc_val
59-
// - is false because all comparisons involving NaN are
60-
// - false, so the result is true. The result is trivially
61-
// - true for the above condition that uses isnan(v) as
62-
// - well.
63-
if (!std::isnan(acc_val) && !(v >= acc_val)) {
53+
if (!std::isnan(acc_val) && (std::isnan(v) || v < acc_val)) {
6454
acc_val = v;
6555
acc_ix = ix;
6656
}

kernels/test/op_argmax_test.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,3 @@ TEST_F(OpArgmaxTest, SanityCheckNullDim) {
9090
EXPECT_TENSOR_EQ(out, expected);
9191
// clang-format on
9292
}
93-
94-
TEST_F(OpArgmaxTest, FirstNaNWins) {
95-
TensorFactory<ScalarType::Float> tf_float;
96-
Tensor in = tf_float.make({4}, {1, NAN, -4, NAN});
97-
98-
TensorFactory<ScalarType::Long> tf_long;
99-
Tensor out = tf_long.zeros({});
100-
Tensor expected = tf_long.make({}, {1});
101-
102-
Tensor ret = op_argmax_out(in, {}, false, out);
103-
EXPECT_TENSOR_EQ(out, ret);
104-
EXPECT_TENSOR_EQ(out, expected);
105-
}

kernels/test/op_argmin_test.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,3 @@ TEST_F(OpArgminTest, SanityCheckNullDim) {
9090
EXPECT_TENSOR_EQ(out, expected);
9191
// clang-format on
9292
}
93-
94-
TEST_F(OpArgminTest, FirstNaNWins) {
95-
TensorFactory<ScalarType::Float> tf_float;
96-
Tensor in = tf_float.make({4}, {1, NAN, -4, NAN});
97-
98-
TensorFactory<ScalarType::Long> tf_long;
99-
Tensor out = tf_long.zeros({});
100-
Tensor expected = tf_long.make({}, {1});
101-
102-
Tensor ret = op_argmin_out(in, {}, false, out);
103-
EXPECT_TENSOR_EQ(out, ret);
104-
EXPECT_TENSOR_EQ(out, expected);
105-
}

0 commit comments

Comments
 (0)