Skip to content

Commit 8f7e3af

Browse files
swolchokfacebook-github-bot
authored andcommitted
Use template to implement binary Tensor/Scalar comparisons
Summary: Similar to D56744651, now we can make changes all at once. Differential Revision: D56744904
1 parent ad548a4 commit 8f7e3af

File tree

7 files changed

+111
-233
lines changed

7 files changed

+111
-233
lines changed

kernels/portable/cpu/op_eq.cpp

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -35,45 +35,8 @@ Tensor& eq_scalar_out(
3535
const Tensor& a,
3636
const Scalar& b,
3737
Tensor& out) {
38-
(void)ctx;
39-
40-
// Resize for dynamic shape
41-
ET_KERNEL_CHECK_MSG(
42-
ctx,
43-
resize_tensor(out, a.sizes()) == Error::Ok,
44-
InvalidArgument,
45-
out,
46-
"Failed to resize output tensor.");
47-
48-
ScalarType a_type = a.scalar_type();
49-
ScalarType b_type = utils::get_scalar_dtype(b);
50-
ScalarType common_type = promoteTypes(a_type, b_type);
51-
ScalarType out_type = out.scalar_type();
52-
53-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "eq.Scalar_out", CTYPE_A, [&]() {
54-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "eq.Scalar_out", CTYPE_B, [&]() {
55-
ET_SWITCH_REAL_TYPES_AND(
56-
Bool, common_type, ctx, "eq.Scalar_out", CTYPE_IN, [&]() {
57-
ET_SWITCH_REAL_TYPES_AND(
58-
Bool, out_type, ctx, "eq.Scalar_out", CTYPE_OUT, [&]() {
59-
CTYPE_B val_b = 0;
60-
utils::extract_scalar(b, &val_b);
61-
apply_unary_map_fn(
62-
[val_b](const CTYPE_A val_a) {
63-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
64-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
65-
bool value = a_casted == b_casted;
66-
return static_cast<CTYPE_OUT>(value);
67-
},
68-
a.const_data_ptr<CTYPE_A>(),
69-
out.mutable_data_ptr<CTYPE_OUT>(),
70-
out.numel());
71-
});
72-
});
73-
});
74-
});
75-
76-
return out;
38+
return scalar_comparison_op_with_regular_promotion_out<std::equal_to>(
39+
ctx, a, b, out, "eq.Scalar_out");
7740
}
7841

7942
} // namespace native

kernels/portable/cpu/op_ge.cpp

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -35,45 +35,8 @@ Tensor& ge_scalar_out(
3535
const Tensor& a,
3636
const Scalar& b,
3737
Tensor& out) {
38-
(void)ctx;
39-
40-
// Resize for dynamic shape
41-
ET_KERNEL_CHECK_MSG(
42-
ctx,
43-
resize_tensor(out, a.sizes()) == Error::Ok,
44-
InvalidArgument,
45-
out,
46-
"Failed to resize output tensor.");
47-
48-
ScalarType a_type = a.scalar_type();
49-
ScalarType b_type = utils::get_scalar_dtype(b);
50-
ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
51-
ScalarType out_type = out.scalar_type();
52-
53-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "ge.Scalar_out", CTYPE_A, [&]() {
54-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "ge.Scalar_out", CTYPE_B, [&]() {
55-
ET_SWITCH_REAL_TYPES_AND(
56-
Bool, common_type, ctx, "ge.Scalar_out", CTYPE_IN, [&]() {
57-
ET_SWITCH_REAL_TYPES_AND(
58-
Bool, out_type, ctx, "ge.Scalar_out", CTYPE_OUT, [&]() {
59-
CTYPE_B val_b = 0;
60-
utils::extract_scalar(b, &val_b);
61-
apply_unary_map_fn(
62-
[val_b](const CTYPE_A val_a) {
63-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
64-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
65-
bool value = a_casted >= b_casted;
66-
return static_cast<CTYPE_OUT>(value);
67-
},
68-
a.const_data_ptr<CTYPE_A>(),
69-
out.mutable_data_ptr<CTYPE_OUT>(),
70-
out.numel());
71-
});
72-
});
73-
});
74-
});
75-
76-
return out;
38+
return scalar_comparison_op_with_scalar_promotion_out<std::greater_equal>(
39+
ctx, a, b, out, "ge.Scalar_out");
7740
}
7841

7942
} // namespace native

kernels/portable/cpu/op_gt.cpp

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -35,45 +35,8 @@ Tensor& gt_scalar_out(
3535
const Tensor& a,
3636
const Scalar& b,
3737
Tensor& out) {
38-
(void)ctx;
39-
40-
// Resize for dynamic shape
41-
ET_KERNEL_CHECK_MSG(
42-
ctx,
43-
resize_tensor(out, a.sizes()) == Error::Ok,
44-
InvalidArgument,
45-
out,
46-
"Failed to resize output tensor.");
47-
48-
ScalarType a_type = a.scalar_type();
49-
ScalarType b_type = utils::get_scalar_dtype(b);
50-
ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
51-
ScalarType out_type = out.scalar_type();
52-
53-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "gt.Scalar_out", CTYPE_A, [&]() {
54-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "gt.Scalar_out", CTYPE_B, [&]() {
55-
ET_SWITCH_REAL_TYPES_AND(
56-
Bool, common_type, ctx, "gt.Scalar_out", CTYPE_IN, [&]() {
57-
ET_SWITCH_REAL_TYPES_AND(
58-
Bool, out_type, ctx, "gt.Scalar_out", CTYPE_OUT, [&]() {
59-
CTYPE_B val_b = 0;
60-
utils::extract_scalar(b, &val_b);
61-
apply_unary_map_fn(
62-
[val_b](const CTYPE_A val_a) {
63-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
64-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
65-
bool value = a_casted > b_casted;
66-
return static_cast<CTYPE_OUT>(value);
67-
},
68-
a.const_data_ptr<CTYPE_A>(),
69-
out.mutable_data_ptr<CTYPE_OUT>(),
70-
out.numel());
71-
});
72-
});
73-
});
74-
});
75-
76-
return out;
38+
return scalar_comparison_op_with_scalar_promotion_out<std::greater>(
39+
ctx, a, b, out, "gt.Scalar_out");
7740
}
7841

7942
} // namespace native

kernels/portable/cpu/op_le.cpp

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -35,45 +35,8 @@ Tensor& le_scalar_out(
3535
const Tensor& a,
3636
const Scalar& b,
3737
Tensor& out) {
38-
(void)ctx;
39-
40-
// Resize for dynamic shape
41-
ET_KERNEL_CHECK_MSG(
42-
ctx,
43-
resize_tensor(out, a.sizes()) == Error::Ok,
44-
InvalidArgument,
45-
out,
46-
"Failed to resize output tensor.");
47-
48-
ScalarType a_type = a.scalar_type();
49-
ScalarType b_type = utils::get_scalar_dtype(b);
50-
ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
51-
ScalarType out_type = out.scalar_type();
52-
53-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "le.Scalar_out", CTYPE_A, [&]() {
54-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "le.Scalar_out", CTYPE_B, [&]() {
55-
ET_SWITCH_REAL_TYPES_AND(
56-
Bool, common_type, ctx, "le.Scalar_out", CTYPE_IN, [&]() {
57-
ET_SWITCH_REAL_TYPES_AND(
58-
Bool, out_type, ctx, "le.Scalar_out", CTYPE_OUT, [&]() {
59-
CTYPE_B val_b = 0;
60-
utils::extract_scalar(b, &val_b);
61-
apply_unary_map_fn(
62-
[val_b](const CTYPE_A val_a) {
63-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
64-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
65-
bool value = a_casted <= b_casted;
66-
return static_cast<CTYPE_OUT>(value);
67-
},
68-
a.const_data_ptr<CTYPE_A>(),
69-
out.mutable_data_ptr<CTYPE_OUT>(),
70-
out.numel());
71-
});
72-
});
73-
});
74-
});
75-
76-
return out;
38+
return scalar_comparison_op_with_scalar_promotion_out<std::less_equal>(
39+
ctx, a, b, out, "le.Scalar_out");
7740
}
7841

7942
} // namespace native

kernels/portable/cpu/op_lt.cpp

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -35,45 +35,8 @@ Tensor& lt_scalar_out(
3535
const Tensor& a,
3636
const Scalar& b,
3737
Tensor& out) {
38-
(void)ctx;
39-
40-
// Resize for dynamic shape
41-
ET_KERNEL_CHECK_MSG(
42-
ctx,
43-
resize_tensor(out, a.sizes()) == Error::Ok,
44-
InvalidArgument,
45-
out,
46-
"Failed to resize output tensor.");
47-
48-
ScalarType a_type = a.scalar_type();
49-
ScalarType b_type = utils::get_scalar_dtype(b);
50-
ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
51-
ScalarType out_type = out.scalar_type();
52-
53-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "lt.Scalar_out", CTYPE_A, [&]() {
54-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "lt.Scalar_out", CTYPE_B, [&]() {
55-
ET_SWITCH_REAL_TYPES_AND(
56-
Bool, common_type, ctx, "lt.Scalar_out", CTYPE_IN, [&]() {
57-
ET_SWITCH_REAL_TYPES_AND(
58-
Bool, out_type, ctx, "lt.Scalar_out", CTYPE_OUT, [&]() {
59-
CTYPE_B val_b = 0;
60-
utils::extract_scalar(b, &val_b);
61-
apply_unary_map_fn(
62-
[val_b](const CTYPE_A val_a) {
63-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
64-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
65-
bool value = a_casted < b_casted;
66-
return static_cast<CTYPE_OUT>(value);
67-
},
68-
a.const_data_ptr<CTYPE_A>(),
69-
out.mutable_data_ptr<CTYPE_OUT>(),
70-
out.numel());
71-
});
72-
});
73-
});
74-
});
75-
76-
return out;
38+
return scalar_comparison_op_with_scalar_promotion_out<std::less>(
39+
ctx, a, b, out, "lt.Scalar_out");
7740
}
7841

7942
} // namespace native

kernels/portable/cpu/op_ne.cpp

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -35,44 +35,8 @@ Tensor& ne_scalar_out(
3535
const Tensor& a,
3636
const Scalar& b,
3737
Tensor& out) {
38-
(void)ctx;
39-
// Resize for dynamic shape
40-
ET_KERNEL_CHECK_MSG(
41-
ctx,
42-
resize_tensor(out, a.sizes()) == Error::Ok,
43-
InvalidArgument,
44-
out,
45-
"Failed to resize output tensor.");
46-
47-
ScalarType a_type = a.scalar_type();
48-
ScalarType b_type = utils::get_scalar_dtype(b);
49-
ScalarType common_type = promoteTypes(a_type, b_type);
50-
ScalarType out_type = out.scalar_type();
51-
52-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "ne.Scalar_out", CTYPE_A, [&]() {
53-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "ne.Scalar_out", CTYPE_B, [&]() {
54-
ET_SWITCH_REAL_TYPES_AND(
55-
Bool, common_type, ctx, "ne.Scalar_out", CTYPE_IN, [&]() {
56-
ET_SWITCH_REAL_TYPES_AND(
57-
Bool, out_type, ctx, "ne.Scalar_out", CTYPE_OUT, [&]() {
58-
CTYPE_B val_b = 0;
59-
utils::extract_scalar(b, &val_b);
60-
apply_unary_map_fn(
61-
[val_b](const CTYPE_A val_a) {
62-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
63-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
64-
bool value = a_casted != b_casted;
65-
return static_cast<CTYPE_OUT>(value);
66-
},
67-
a.const_data_ptr<CTYPE_A>(),
68-
out.mutable_data_ptr<CTYPE_OUT>(),
69-
out.numel());
70-
});
71-
});
72-
});
73-
});
74-
75-
return out;
38+
return scalar_comparison_op_with_regular_promotion_out<std::not_equal_to>(
39+
ctx, a, b, out, "ne.Scalar_out");
7640
}
7741

7842
} // namespace native

0 commit comments

Comments
 (0)