6
6
* LICENSE file in the root directory of this source tree.
7
7
*/
8
8
9
- #include < cmath>
9
+ // patternlint-disable-next-line executorch-cpp-nostdinc
10
+ #include < functional>
10
11
12
+ #include < executorch/kernels/portable/cpu/pattern/bitwise_op.h>
11
13
#include < executorch/kernels/portable/cpu/scalar_utils.h>
12
14
#include < executorch/kernels/portable/cpu/util/broadcast_util.h>
13
15
#include < executorch/kernels/portable/cpu/util/functional_util.h>
@@ -17,82 +19,14 @@ namespace torch {
17
19
namespace executor {
18
20
namespace native {
19
21
20
- namespace {
21
-
22
- template <typename CTYPE>
23
- CTYPE bitwise_xor (CTYPE a, CTYPE b) {
24
- return a ^ b;
25
- }
26
-
27
- template <>
28
- bool bitwise_xor<bool >(bool a, bool b) {
29
- return a != b;
30
- }
31
-
32
- } // namespace
33
-
34
22
using Tensor = exec_aten::Tensor;
35
23
36
24
Tensor& bitwise_xor_Tensor_out (
37
25
RuntimeContext& ctx,
38
26
const Tensor& a,
39
27
const Tensor& b,
40
28
Tensor& out) {
41
- // Determine output size and resize for dynamic shapes
42
- ET_KERNEL_CHECK (
43
- ctx,
44
- resize_to_broadcast_target_size (a, b, out) == Error::Ok,
45
- InvalidArgument,
46
- out);
47
-
48
- ScalarType a_type = a.scalar_type ();
49
- ScalarType b_type = b.scalar_type ();
50
- ScalarType common_type = promoteTypes (a_type, b_type);
51
- ScalarType out_type = out.scalar_type ();
52
-
53
- ET_KERNEL_CHECK (ctx, canCast (common_type, out_type), InvalidArgument, out);
54
-
55
- ET_SWITCH_INT_TYPES_AND (
56
- Bool, a_type, ctx, " bitwise_xor.Tensor_out" , CTYPE_A, [&]() {
57
- ET_SWITCH_INT_TYPES_AND (
58
- Bool, b_type, ctx, " bitwise_xor.Tensor_out" , CTYPE_B, [&]() {
59
- ET_SWITCH_INT_TYPES_AND (
60
- Bool,
61
- common_type,
62
- ctx,
63
- " bitwise_xor.Tensor_out" ,
64
- CTYPE_IN,
65
- [&]() {
66
- ET_SWITCH_REAL_TYPES_AND (
67
- Bool,
68
- out_type,
69
- ctx,
70
- " bitwise_xor.Tensor_out" ,
71
- CTYPE_OUT,
72
- [&]() {
73
- apply_binary_elementwise_fn<
74
- CTYPE_A,
75
- CTYPE_B,
76
- CTYPE_OUT>(
77
- [](const CTYPE_A val_a, const CTYPE_B val_b) {
78
- CTYPE_IN a_casted =
79
- static_cast <CTYPE_IN>(val_a);
80
- CTYPE_IN b_casted =
81
- static_cast <CTYPE_IN>(val_b);
82
- CTYPE_IN value =
83
- bitwise_xor (a_casted, b_casted);
84
-
85
- return static_cast <CTYPE_OUT>(value);
86
- },
87
- a,
88
- b,
89
- out);
90
- });
91
- });
92
- });
93
- });
94
-
95
- return out;
29
+ return bitwise_op_out<std::bit_xor>(ctx, a, b, out, " bitwise_xor.Tensor_out" );
96
30
}
97
31
98
32
Tensor& bitwise_xor_Scalar_out (
@@ -143,8 +77,8 @@ Tensor& bitwise_xor_Scalar_out(
143
77
static_cast <CTYPE_IN>(val_a);
144
78
CTYPE_IN b_casted =
145
79
static_cast <CTYPE_IN>(val_b);
146
- CTYPE_IN value =
147
- bitwise_xor ( a_casted, b_casted);
80
+ CTYPE_IN value = std::bit_xor<CTYPE_IN>()(
81
+ a_casted, b_casted);
148
82
149
83
return static_cast <CTYPE_OUT>(value);
150
84
},
0 commit comments