@@ -35,25 +35,23 @@ Tensor& comparison_op_out(
35
35
36
36
ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, op_name, CTYPE_A, [&]() {
37
37
ET_SWITCH_REAL_TYPES_AND (Bool, b_type, ctx, op_name, CTYPE_B, [&]() {
38
- ET_SWITCH_REAL_TYPES_AND (
39
- Bool, common_type, ctx, op_name, CTYPE_IN, [&]() {
40
- ET_SWITCH_REAL_TYPES_AND (
41
- Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
42
- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
43
- [](const CTYPE_A val_a, const CTYPE_B val_b) {
44
- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
45
- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
46
- bool value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
47
- return static_cast <CTYPE_OUT>(value);
48
- },
49
- a,
50
- b,
51
- out);
52
- });
53
- });
38
+ using CTYPE_IN =
39
+ typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
40
+ ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
41
+ ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
42
+ apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
43
+ [](const CTYPE_A val_a, const CTYPE_B val_b) {
44
+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
45
+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
46
+ bool value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
47
+ return static_cast <CTYPE_OUT>(value);
48
+ },
49
+ a,
50
+ b,
51
+ out);
52
+ });
54
53
});
55
54
});
56
-
57
55
return out;
58
56
}
59
57
} // namespace native
0 commit comments