@@ -128,24 +128,23 @@ Tensor& scalar_comparison_op_with_regular_promotion_out(
128
128
129
129
ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, op_name, CTYPE_A, [&]() {
130
130
ET_SWITCH_SCALAR_OBJ_TYPES (b_type, ctx, op_name, CTYPE_B, [&]() {
131
- ET_SWITCH_REAL_TYPES_AND (
132
- Bool, common_type, ctx, op_name, CTYPE_IN, [&]() {
133
- ET_SWITCH_REAL_TYPES_AND (
134
- Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
135
- CTYPE_B val_b = 0 ;
136
- utils::extract_scalar (b, &val_b);
137
- apply_unary_map_fn (
138
- [val_b](const CTYPE_A val_a) {
139
- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
140
- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
141
- bool value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
142
- return static_cast <CTYPE_OUT>(value);
143
- },
144
- a.const_data_ptr <CTYPE_A>(),
145
- out.mutable_data_ptr <CTYPE_OUT>(),
146
- out.numel ());
147
- });
148
- });
131
+ using CTYPE_IN =
132
+ typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
133
+ ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
134
+ ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
135
+ CTYPE_B val_b = 0 ;
136
+ utils::extract_scalar (b, &val_b);
137
+ apply_unary_map_fn (
138
+ [val_b](const CTYPE_A val_a) {
139
+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
140
+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
141
+ bool value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
142
+ return static_cast <CTYPE_OUT>(value);
143
+ },
144
+ a.const_data_ptr <CTYPE_A>(),
145
+ out.mutable_data_ptr <CTYPE_OUT>(),
146
+ out.numel ());
147
+ });
149
148
});
150
149
});
151
150
0 commit comments