Skip to content

Commit 0aee706

Browse files
swolchokfacebook-github-bot
authored andcommitted
Add pattern template for binary bitwise ops
Summary: Similar to D56744651. Differential Revision: D56852163
1 parent 58091a9 commit 0aee706

File tree

6 files changed

+95
-212
lines changed

6 files changed

+95
-212
lines changed

kernels/portable/cpu/op_bitwise_and.cpp

Lines changed: 6 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <cmath>
9+
// patternlint-disable-next-line executorch-cpp-nostdinc
10+
#include <functional>
1011

12+
#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
1113
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1214
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1315
#include <executorch/kernels/portable/cpu/util/functional_util.h>
@@ -17,81 +19,14 @@ namespace torch {
1719
namespace executor {
1820
namespace native {
1921

20-
namespace {
21-
22-
template <typename CTYPE>
23-
CTYPE bitwise_and(CTYPE a, CTYPE b) {
24-
return a & b;
25-
}
26-
27-
template <>
28-
bool bitwise_and<bool>(bool a, bool b) {
29-
return a && b;
30-
}
31-
32-
} // namespace
33-
3422
using Tensor = exec_aten::Tensor;
3523

3624
Tensor& bitwise_and_Tensor_out(
3725
RuntimeContext& ctx,
3826
const Tensor& a,
3927
const Tensor& b,
4028
Tensor& out) {
41-
ET_KERNEL_CHECK(
42-
ctx,
43-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
44-
InvalidArgument,
45-
out);
46-
47-
ScalarType a_type = a.scalar_type();
48-
ScalarType b_type = b.scalar_type();
49-
ScalarType common_type = promoteTypes(a_type, b_type);
50-
ScalarType out_type = out.scalar_type();
51-
52-
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
53-
54-
ET_SWITCH_INT_TYPES_AND(
55-
Bool, a_type, ctx, "bitwise_and.Tensor_out", CTYPE_A, [&]() {
56-
ET_SWITCH_INT_TYPES_AND(
57-
Bool, b_type, ctx, "bitwise_and.Tensor_out", CTYPE_B, [&]() {
58-
ET_SWITCH_INT_TYPES_AND(
59-
Bool,
60-
common_type,
61-
ctx,
62-
"bitwise_and.Tensor_out",
63-
CTYPE_IN,
64-
[&]() {
65-
ET_SWITCH_REAL_TYPES_AND(
66-
Bool,
67-
out_type,
68-
ctx,
69-
"bitwise_and.Tensor_out",
70-
CTYPE_OUT,
71-
[&]() {
72-
apply_binary_elementwise_fn<
73-
CTYPE_A,
74-
CTYPE_B,
75-
CTYPE_OUT>(
76-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
77-
CTYPE_IN a_casted =
78-
static_cast<CTYPE_IN>(val_a);
79-
CTYPE_IN b_casted =
80-
static_cast<CTYPE_IN>(val_b);
81-
CTYPE_IN value =
82-
bitwise_and(a_casted, b_casted);
83-
84-
return static_cast<CTYPE_OUT>(value);
85-
},
86-
a,
87-
b,
88-
out);
89-
});
90-
});
91-
});
92-
});
93-
94-
return out;
29+
return bitwise_op_out<std::bit_and>(ctx, a, b, out, "bitwise_and.Tensor_out");
9530
}
9631

9732
Tensor& bitwise_and_Scalar_out(
@@ -142,8 +77,8 @@ Tensor& bitwise_and_Scalar_out(
14277
static_cast<CTYPE_IN>(val_a);
14378
CTYPE_IN b_casted =
14479
static_cast<CTYPE_IN>(val_b);
145-
CTYPE_IN value =
146-
bitwise_and(a_casted, b_casted);
80+
CTYPE_IN value = std::bit_and<CTYPE_IN>()(
81+
a_casted, b_casted);
14782

14883
return static_cast<CTYPE_OUT>(value);
14984
},

kernels/portable/cpu/op_bitwise_or.cpp

Lines changed: 6 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <cmath>
9+
// patternlint-disable-next-line executorch-cpp-nostdinc
10+
#include <functional>
1011

12+
#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
1113
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1214
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1315
#include <executorch/kernels/portable/cpu/util/functional_util.h>
@@ -17,80 +19,14 @@ namespace torch {
1719
namespace executor {
1820
namespace native {
1921

20-
namespace {
21-
22-
template <typename CTYPE>
23-
CTYPE bitwise_or(CTYPE a, CTYPE b) {
24-
return a | b;
25-
}
26-
27-
template <>
28-
bool bitwise_or<bool>(bool a, bool b) {
29-
return a || b;
30-
}
31-
32-
} // namespace
33-
3422
using Tensor = exec_aten::Tensor;
3523

3624
Tensor& bitwise_or_Tensor_out(
3725
RuntimeContext& ctx,
3826
const Tensor& a,
3927
const Tensor& b,
4028
Tensor& out) {
41-
ET_KERNEL_CHECK(
42-
ctx,
43-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
44-
InvalidArgument,
45-
out);
46-
47-
ScalarType a_type = a.scalar_type();
48-
ScalarType b_type = b.scalar_type();
49-
ScalarType common_type = promoteTypes(a_type, b_type);
50-
ScalarType out_type = out.scalar_type();
51-
52-
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
53-
54-
ET_SWITCH_INT_TYPES_AND(
55-
Bool, a_type, ctx, "bitwise_or.Tensor_out", CTYPE_A, [&]() {
56-
ET_SWITCH_INT_TYPES_AND(
57-
Bool, b_type, ctx, "bitwise_or.Tensor_out", CTYPE_B, [&]() {
58-
ET_SWITCH_INT_TYPES_AND(
59-
Bool,
60-
common_type,
61-
ctx,
62-
"bitwise_or.Tensor_out",
63-
CTYPE_IN,
64-
[&]() {
65-
ET_SWITCH_REAL_TYPES_AND(
66-
Bool,
67-
out_type,
68-
ctx,
69-
"bitwise_or.Tensor_out",
70-
CTYPE_OUT,
71-
[&]() {
72-
apply_binary_elementwise_fn<
73-
CTYPE_A,
74-
CTYPE_B,
75-
CTYPE_OUT>(
76-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
77-
CTYPE_IN a_casted =
78-
static_cast<CTYPE_IN>(val_a);
79-
CTYPE_IN b_casted =
80-
static_cast<CTYPE_IN>(val_b);
81-
CTYPE_IN value = bitwise_or(a_casted, b_casted);
82-
83-
return static_cast<CTYPE_OUT>(value);
84-
},
85-
a,
86-
b,
87-
out);
88-
});
89-
});
90-
});
91-
});
92-
93-
return out;
29+
return bitwise_op_out<std::bit_or>(ctx, a, b, out, "bitwise_or.Tensor_out");
9430
}
9531

9632
Tensor& bitwise_or_Scalar_out(
@@ -141,7 +77,8 @@ Tensor& bitwise_or_Scalar_out(
14177
static_cast<CTYPE_IN>(val_a);
14278
CTYPE_IN b_casted =
14379
static_cast<CTYPE_IN>(val_b);
144-
CTYPE_IN value = bitwise_or(a_casted, b_casted);
80+
CTYPE_IN value =
81+
std::bit_or<CTYPE_IN>()(a_casted, b_casted);
14582

14683
return static_cast<CTYPE_OUT>(value);
14784
},

kernels/portable/cpu/op_bitwise_xor.cpp

Lines changed: 6 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <cmath>
9+
// patternlint-disable-next-line executorch-cpp-nostdinc
10+
#include <functional>
1011

12+
#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
1113
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1214
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1315
#include <executorch/kernels/portable/cpu/util/functional_util.h>
@@ -17,82 +19,14 @@ namespace torch {
1719
namespace executor {
1820
namespace native {
1921

20-
namespace {
21-
22-
template <typename CTYPE>
23-
CTYPE bitwise_xor(CTYPE a, CTYPE b) {
24-
return a ^ b;
25-
}
26-
27-
template <>
28-
bool bitwise_xor<bool>(bool a, bool b) {
29-
return a != b;
30-
}
31-
32-
} // namespace
33-
3422
using Tensor = exec_aten::Tensor;
3523

3624
Tensor& bitwise_xor_Tensor_out(
3725
RuntimeContext& ctx,
3826
const Tensor& a,
3927
const Tensor& b,
4028
Tensor& out) {
41-
// Determine output size and resize for dynamic shapes
42-
ET_KERNEL_CHECK(
43-
ctx,
44-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
45-
InvalidArgument,
46-
out);
47-
48-
ScalarType a_type = a.scalar_type();
49-
ScalarType b_type = b.scalar_type();
50-
ScalarType common_type = promoteTypes(a_type, b_type);
51-
ScalarType out_type = out.scalar_type();
52-
53-
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
54-
55-
ET_SWITCH_INT_TYPES_AND(
56-
Bool, a_type, ctx, "bitwise_xor.Tensor_out", CTYPE_A, [&]() {
57-
ET_SWITCH_INT_TYPES_AND(
58-
Bool, b_type, ctx, "bitwise_xor.Tensor_out", CTYPE_B, [&]() {
59-
ET_SWITCH_INT_TYPES_AND(
60-
Bool,
61-
common_type,
62-
ctx,
63-
"bitwise_xor.Tensor_out",
64-
CTYPE_IN,
65-
[&]() {
66-
ET_SWITCH_REAL_TYPES_AND(
67-
Bool,
68-
out_type,
69-
ctx,
70-
"bitwise_xor.Tensor_out",
71-
CTYPE_OUT,
72-
[&]() {
73-
apply_binary_elementwise_fn<
74-
CTYPE_A,
75-
CTYPE_B,
76-
CTYPE_OUT>(
77-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
78-
CTYPE_IN a_casted =
79-
static_cast<CTYPE_IN>(val_a);
80-
CTYPE_IN b_casted =
81-
static_cast<CTYPE_IN>(val_b);
82-
CTYPE_IN value =
83-
bitwise_xor(a_casted, b_casted);
84-
85-
return static_cast<CTYPE_OUT>(value);
86-
},
87-
a,
88-
b,
89-
out);
90-
});
91-
});
92-
});
93-
});
94-
95-
return out;
29+
return bitwise_op_out<std::bit_xor>(ctx, a, b, out, "bitwise_xor.Tensor_out");
9630
}
9731

9832
Tensor& bitwise_xor_Scalar_out(
@@ -143,8 +77,8 @@ Tensor& bitwise_xor_Scalar_out(
14377
static_cast<CTYPE_IN>(val_a);
14478
CTYPE_IN b_casted =
14579
static_cast<CTYPE_IN>(val_b);
146-
CTYPE_IN value =
147-
bitwise_xor(a_casted, b_casted);
80+
CTYPE_IN value = std::bit_xor<CTYPE_IN>()(
81+
a_casted, b_casted);
14882

14983
return static_cast<CTYPE_OUT>(value);
15084
},
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
12+
#include <executorch/runtime/kernel/kernel_includes.h>
13+
14+
namespace torch {
15+
namespace executor {
16+
namespace native {
17+
template <template <typename> typename OpFunc>
18+
Tensor& bitwise_op_out(
19+
RuntimeContext& ctx,
20+
const Tensor& a,
21+
const Tensor& b,
22+
Tensor& out,
23+
const char* op_name) {
24+
ET_KERNEL_CHECK(
25+
ctx,
26+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
27+
InvalidArgument,
28+
out);
29+
30+
ScalarType a_type = a.scalar_type();
31+
ScalarType b_type = b.scalar_type();
32+
ScalarType common_type = promoteTypes(a_type, b_type);
33+
ScalarType out_type = out.scalar_type();
34+
35+
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
36+
37+
ET_SWITCH_INT_TYPES_AND(Bool, a_type, ctx, op_name, CTYPE_A, [&]() {
38+
ET_SWITCH_INT_TYPES_AND(Bool, b_type, ctx, op_name, CTYPE_B, [&]() {
39+
ET_SWITCH_INT_TYPES_AND(Bool, common_type, ctx, op_name, CTYPE_IN, [&]() {
40+
ET_SWITCH_REAL_TYPES_AND(
41+
Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
42+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
43+
[](const CTYPE_A val_a, const CTYPE_B val_b) {
44+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
45+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
46+
CTYPE_IN value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
47+
48+
return static_cast<CTYPE_OUT>(value);
49+
},
50+
a,
51+
b,
52+
out);
53+
});
54+
});
55+
});
56+
});
57+
58+
return out;
59+
}
60+
} // namespace native
61+
} // namespace executor
62+
} // namespace torch

kernels/portable/cpu/pattern/targets.bzl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,18 @@ def define_common_targets():
66
The directory containing this targets.bzl file should also contain both
77
TARGETS and BUCK files that call this function.
88
"""
9+
runtime.cxx_library(
10+
name = "bitwise_op",
11+
exported_headers = [
12+
"bitwise_op.h",
13+
],
14+
compiler_flags = [],
15+
deps = [
16+
"//executorch/runtime/kernel:kernel_includes",
17+
],
18+
visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
19+
)
20+
921
runtime.cxx_library(
1022
name = "comparison_op",
1123
exported_headers = [

0 commit comments

Comments
 (0)