14
14
namespace torch {
15
15
namespace executor {
16
16
namespace native {
17
+ namespace internal {
18
+
19
+ template <
20
+ bool can_cast,
21
+ template <typename >
22
+ typename OpFunc,
23
+ typename CTYPE_A,
24
+ typename CTYPE_B,
25
+ typename CTYPE_IN,
26
+ typename CTYPE_OUT>
27
+ struct BitwiseOpInner ;
28
+
29
+ template <
30
+ template <typename >
31
+ typename OpFunc,
32
+ typename CTYPE_A,
33
+ typename CTYPE_B,
34
+ typename CTYPE_IN,
35
+ typename CTYPE_OUT>
36
+ struct BitwiseOpInner <true , OpFunc, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
37
+ static void run (const Tensor& a, const Tensor& b, Tensor& out) {
38
+ apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
39
+ // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
40
+ [](const CTYPE_A val_a, const CTYPE_B val_b) {
41
+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
42
+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
43
+ CTYPE_IN value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
44
+
45
+ return static_cast <CTYPE_OUT>(value);
46
+ },
47
+ a,
48
+ b,
49
+ out);
50
+ }
51
+ };
52
+
53
+ struct ReportCanCastBug {
54
+ static void run (const Tensor&, const Tensor&, Tensor&) {
55
+ ET_DCHECK_MSG (false , " BUG: canCast should have been checked above" );
56
+ }
57
+ };
58
+
59
+ template <
60
+ template <typename >
61
+ typename OpFunc,
62
+ typename CTYPE_A,
63
+ typename CTYPE_B,
64
+ typename CTYPE_IN,
65
+ typename CTYPE_OUT>
66
+ struct BitwiseOpInner <false , OpFunc, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
67
+ : public ReportCanCastBug {};
68
+
69
+ } // namespace internal
70
+
17
71
template <template <typename > typename OpFunc>
18
72
Tensor& bitwise_op_out (
19
73
RuntimeContext& ctx,
@@ -36,21 +90,17 @@ Tensor& bitwise_op_out(
36
90
37
91
ET_SWITCH_INT_TYPES_AND (Bool, a_type, ctx, op_name, CTYPE_A, [&]() {
38
92
ET_SWITCH_INT_TYPES_AND (Bool, b_type, ctx, op_name, CTYPE_B, [&]() {
39
- ET_SWITCH_INT_TYPES_AND (Bool, common_type, ctx, op_name, CTYPE_IN, [&]() {
40
- ET_SWITCH_REAL_TYPES_AND (
41
- Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
42
- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
43
- [](const CTYPE_A val_a, const CTYPE_B val_b) {
44
- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
45
- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
46
- CTYPE_IN value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
47
-
48
- return static_cast <CTYPE_OUT>(value);
49
- },
50
- a,
51
- b,
52
- out);
53
- });
93
+ using CTYPE_IN =
94
+ typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
95
+ ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
96
+ ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
97
+ internal::BitwiseOpInner<
98
+ can_cast<CTYPE_IN, CTYPE_OUT>::value,
99
+ OpFunc,
100
+ CTYPE_A,
101
+ CTYPE_B,
102
+ CTYPE_IN,
103
+ CTYPE_OUT>::run (a, b, out);
54
104
});
55
105
});
56
106
});
0 commit comments