6
6
* LICENSE file in the root directory of this source tree.
7
7
*/
8
8
9
- #include < cmath>
9
+ // patternlint-disable-next-line executorch-cpp-nostdinc
10
+ #include < functional>
10
11
12
+ #include < executorch/kernels/portable/cpu/pattern/bitwise_op.h>
11
13
#include < executorch/kernels/portable/cpu/scalar_utils.h>
12
14
#include < executorch/kernels/portable/cpu/util/broadcast_util.h>
13
15
#include < executorch/kernels/portable/cpu/util/functional_util.h>
@@ -17,28 +19,13 @@ namespace torch {
17
19
namespace executor {
18
20
namespace native {
19
21
20
- namespace {
21
-
22
- template <typename CTYPE>
23
- CTYPE bitwise_xor (CTYPE a, CTYPE b) {
24
- return a ^ b;
25
- }
26
-
27
- template <>
28
- bool bitwise_xor<bool >(bool a, bool b) {
29
- return a != b;
30
- }
31
-
32
- } // namespace
33
-
34
22
using Tensor = exec_aten::Tensor;
35
23
36
24
Tensor& bitwise_xor_Tensor_out (
37
25
RuntimeContext& ctx,
38
26
const Tensor& a,
39
27
const Tensor& b,
40
28
Tensor& out) {
41
- // Determine output size and resize for dynamic shapes
42
29
ET_KERNEL_CHECK (
43
30
ctx,
44
31
resize_to_broadcast_target_size (a, b, out) == Error::Ok,
@@ -56,38 +43,23 @@ Tensor& bitwise_xor_Tensor_out(
56
43
Bool, a_type, ctx, " bitwise_xor.Tensor_out" , CTYPE_A, [&]() {
57
44
ET_SWITCH_INT_TYPES_AND (
58
45
Bool, b_type, ctx, " bitwise_xor.Tensor_out" , CTYPE_B, [&]() {
59
- ET_SWITCH_INT_TYPES_AND (
46
+ using CTYPE_IN = typename torch::executor::
47
+ promote_types<CTYPE_A, CTYPE_B>::type;
48
+ ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
49
+ ET_SWITCH_REAL_TYPES_AND (
60
50
Bool,
61
- common_type ,
51
+ out_type ,
62
52
ctx,
63
53
" bitwise_xor.Tensor_out" ,
64
- CTYPE_IN ,
54
+ CTYPE_OUT ,
65
55
[&]() {
66
- ET_SWITCH_REAL_TYPES_AND (
67
- Bool,
68
- out_type,
69
- ctx,
70
- " bitwise_xor.Tensor_out" ,
71
- CTYPE_OUT,
72
- [&]() {
73
- apply_binary_elementwise_fn<
74
- CTYPE_A,
75
- CTYPE_B,
76
- CTYPE_OUT>(
77
- [](const CTYPE_A val_a, const CTYPE_B val_b) {
78
- CTYPE_IN a_casted =
79
- static_cast <CTYPE_IN>(val_a);
80
- CTYPE_IN b_casted =
81
- static_cast <CTYPE_IN>(val_b);
82
- CTYPE_IN value =
83
- bitwise_xor (a_casted, b_casted);
84
-
85
- return static_cast <CTYPE_OUT>(value);
86
- },
87
- a,
88
- b,
89
- out);
90
- });
56
+ internal::BitwiseOpInner<
57
+ can_cast<CTYPE_IN, CTYPE_OUT>::value,
58
+ std::bit_xor,
59
+ CTYPE_A,
60
+ CTYPE_B,
61
+ CTYPE_IN,
62
+ CTYPE_OUT>::run (a, b, out);
91
63
});
92
64
});
93
65
});
@@ -143,8 +115,8 @@ Tensor& bitwise_xor_Scalar_out(
143
115
static_cast <CTYPE_IN>(val_a);
144
116
CTYPE_IN b_casted =
145
117
static_cast <CTYPE_IN>(val_b);
146
- CTYPE_IN value =
147
- bitwise_xor ( a_casted, b_casted);
118
+ CTYPE_IN value = std::bit_xor<CTYPE_IN>()(
119
+ a_casted, b_casted);
148
120
149
121
return static_cast <CTYPE_OUT>(value);
150
122
},
0 commit comments