Skip to content

Commit ad548a4

Browse files
swolchokfacebook-github-bot
authored andcommitted
Use compile-time promotion to reduce comparison op size & build time
Summary: Yet another 6 improved ops (see previous diffs). Differential Revision: D56744787
1 parent d347569 commit ad548a4

File tree

1 file changed

+15
-17
lines changed

1 file changed

+15
-17
lines changed

kernels/portable/cpu/pattern/comparison_op.h

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,25 +35,23 @@ Tensor& comparison_op_out(
3535

3636
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, op_name, CTYPE_A, [&]() {
3737
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, op_name, CTYPE_B, [&]() {
38-
ET_SWITCH_REAL_TYPES_AND(
39-
Bool, common_type, ctx, op_name, CTYPE_IN, [&]() {
40-
ET_SWITCH_REAL_TYPES_AND(
41-
Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
42-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
43-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
44-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
45-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
46-
bool value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
47-
return static_cast<CTYPE_OUT>(value);
48-
},
49-
a,
50-
b,
51-
out);
52-
});
53-
});
38+
using CTYPE_IN =
39+
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
40+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
41+
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
42+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
43+
[](const CTYPE_A val_a, const CTYPE_B val_b) {
44+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
45+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
46+
bool value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
47+
return static_cast<CTYPE_OUT>(value);
48+
},
49+
a,
50+
b,
51+
out);
52+
});
5453
});
5554
});
56-
5755
return out;
5856
}
5957
} // namespace native

0 commit comments

Comments
 (0)