Skip to content

Commit 8e35e24

Browse files
swolchok authored and facebook-github-bot committed
Use compile-time promotion to reduce eq/ne scalar op size & build time
Summary: These two scalar ops use promoteTypes, so we can use compile-time promotion right away. Differential Revision: D56744985
1 parent 3b68b9c commit 8e35e24

File tree

1 file changed

+17
-18
lines changed

1 file changed

+17
-18
lines changed

kernels/portable/cpu/pattern/comparison_op.h

Lines changed: 17 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -128,24 +128,23 @@ Tensor& scalar_comparison_op_with_regular_promotion_out(
128128

129129
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, op_name, CTYPE_A, [&]() {
130130
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() {
131-
ET_SWITCH_REAL_TYPES_AND(
132-
Bool, common_type, ctx, op_name, CTYPE_IN, [&]() {
133-
ET_SWITCH_REAL_TYPES_AND(
134-
Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
135-
CTYPE_B val_b = 0;
136-
utils::extract_scalar(b, &val_b);
137-
apply_unary_map_fn(
138-
[val_b](const CTYPE_A val_a) {
139-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
140-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
141-
bool value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
142-
return static_cast<CTYPE_OUT>(value);
143-
},
144-
a.const_data_ptr<CTYPE_A>(),
145-
out.mutable_data_ptr<CTYPE_OUT>(),
146-
out.numel());
147-
});
148-
});
131+
using CTYPE_IN =
132+
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
133+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
134+
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
135+
CTYPE_B val_b = 0;
136+
utils::extract_scalar(b, &val_b);
137+
apply_unary_map_fn(
138+
[val_b](const CTYPE_A val_a) {
139+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
140+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
141+
bool value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
142+
return static_cast<CTYPE_OUT>(value);
143+
},
144+
a.const_data_ptr<CTYPE_A>(),
145+
out.mutable_data_ptr<CTYPE_OUT>(),
146+
out.numel());
147+
});
149148
});
150149
});
151150

0 commit comments

Comments
 (0)