Skip to content

Commit 724d09a

Browse files
swolchok authored and facebook-github-bot committed
Use compile-time promotion to reduce bitwise op size & build time (#3487)
Summary: Finally getting close to the end of compile-time promotion for Tensor ops! Differential Revision: D56855548
1 parent 2efd867 commit 724d09a

File tree

7 files changed

+152
-144
lines changed

7 files changed

+152
-144
lines changed

kernels/portable/cpu/op_bitwise_and.cpp

Lines changed: 18 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <cmath>
9+
// patternlint-disable-next-line executorch-cpp-nostdinc
10+
#include <functional>
1011

12+
#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
1113
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1214
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1315
#include <executorch/kernels/portable/cpu/util/functional_util.h>
@@ -17,20 +19,6 @@ namespace torch {
1719
namespace executor {
1820
namespace native {
1921

20-
namespace {
21-
22-
template <typename CTYPE>
23-
CTYPE bitwise_and(CTYPE a, CTYPE b) {
24-
return a & b;
25-
}
26-
27-
template <>
28-
bool bitwise_and<bool>(bool a, bool b) {
29-
return a && b;
30-
}
31-
32-
} // namespace
33-
3422
using Tensor = exec_aten::Tensor;
3523

3624
Tensor& bitwise_and_Tensor_out(
@@ -55,38 +43,23 @@ Tensor& bitwise_and_Tensor_out(
5543
Bool, a_type, ctx, "bitwise_and.Tensor_out", CTYPE_A, [&]() {
5644
ET_SWITCH_INT_TYPES_AND(
5745
Bool, b_type, ctx, "bitwise_and.Tensor_out", CTYPE_B, [&]() {
58-
ET_SWITCH_INT_TYPES_AND(
46+
using CTYPE_IN = typename torch::executor::
47+
promote_types<CTYPE_A, CTYPE_B>::type;
48+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
49+
ET_SWITCH_REAL_TYPES_AND(
5950
Bool,
60-
common_type,
51+
out_type,
6152
ctx,
6253
"bitwise_and.Tensor_out",
63-
CTYPE_IN,
54+
CTYPE_OUT,
6455
[&]() {
65-
ET_SWITCH_REAL_TYPES_AND(
66-
Bool,
67-
out_type,
68-
ctx,
69-
"bitwise_and.Tensor_out",
70-
CTYPE_OUT,
71-
[&]() {
72-
apply_binary_elementwise_fn<
73-
CTYPE_A,
74-
CTYPE_B,
75-
CTYPE_OUT>(
76-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
77-
CTYPE_IN a_casted =
78-
static_cast<CTYPE_IN>(val_a);
79-
CTYPE_IN b_casted =
80-
static_cast<CTYPE_IN>(val_b);
81-
CTYPE_IN value =
82-
bitwise_and(a_casted, b_casted);
83-
84-
return static_cast<CTYPE_OUT>(value);
85-
},
86-
a,
87-
b,
88-
out);
89-
});
56+
internal::BitwiseOpInner<
57+
can_cast<CTYPE_IN, CTYPE_OUT>::value,
58+
std::bit_and,
59+
CTYPE_A,
60+
CTYPE_B,
61+
CTYPE_IN,
62+
CTYPE_OUT>::run(a, b, out);
9063
});
9164
});
9265
});
@@ -142,8 +115,8 @@ Tensor& bitwise_and_Scalar_out(
142115
static_cast<CTYPE_IN>(val_a);
143116
CTYPE_IN b_casted =
144117
static_cast<CTYPE_IN>(val_b);
145-
CTYPE_IN value =
146-
bitwise_and(a_casted, b_casted);
118+
CTYPE_IN value = std::bit_and<CTYPE_IN>()(
119+
a_casted, b_casted);
147120

148121
return static_cast<CTYPE_OUT>(value);
149122
},

kernels/portable/cpu/op_bitwise_or.cpp

Lines changed: 18 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <cmath>
9+
// patternlint-disable-next-line executorch-cpp-nostdinc
10+
#include <functional>
1011

12+
#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
1113
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1214
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1315
#include <executorch/kernels/portable/cpu/util/functional_util.h>
@@ -17,20 +19,6 @@ namespace torch {
1719
namespace executor {
1820
namespace native {
1921

20-
namespace {
21-
22-
template <typename CTYPE>
23-
CTYPE bitwise_or(CTYPE a, CTYPE b) {
24-
return a | b;
25-
}
26-
27-
template <>
28-
bool bitwise_or<bool>(bool a, bool b) {
29-
return a || b;
30-
}
31-
32-
} // namespace
33-
3422
using Tensor = exec_aten::Tensor;
3523

3624
Tensor& bitwise_or_Tensor_out(
@@ -55,37 +43,23 @@ Tensor& bitwise_or_Tensor_out(
5543
Bool, a_type, ctx, "bitwise_or.Tensor_out", CTYPE_A, [&]() {
5644
ET_SWITCH_INT_TYPES_AND(
5745
Bool, b_type, ctx, "bitwise_or.Tensor_out", CTYPE_B, [&]() {
58-
ET_SWITCH_INT_TYPES_AND(
46+
using CTYPE_IN = typename torch::executor::
47+
promote_types<CTYPE_A, CTYPE_B>::type;
48+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
49+
ET_SWITCH_REAL_TYPES_AND(
5950
Bool,
60-
common_type,
51+
out_type,
6152
ctx,
6253
"bitwise_or.Tensor_out",
63-
CTYPE_IN,
54+
CTYPE_OUT,
6455
[&]() {
65-
ET_SWITCH_REAL_TYPES_AND(
66-
Bool,
67-
out_type,
68-
ctx,
69-
"bitwise_or.Tensor_out",
70-
CTYPE_OUT,
71-
[&]() {
72-
apply_binary_elementwise_fn<
73-
CTYPE_A,
74-
CTYPE_B,
75-
CTYPE_OUT>(
76-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
77-
CTYPE_IN a_casted =
78-
static_cast<CTYPE_IN>(val_a);
79-
CTYPE_IN b_casted =
80-
static_cast<CTYPE_IN>(val_b);
81-
CTYPE_IN value = bitwise_or(a_casted, b_casted);
82-
83-
return static_cast<CTYPE_OUT>(value);
84-
},
85-
a,
86-
b,
87-
out);
88-
});
56+
internal::BitwiseOpInner<
57+
can_cast<CTYPE_IN, CTYPE_OUT>::value,
58+
std::bit_or,
59+
CTYPE_A,
60+
CTYPE_B,
61+
CTYPE_IN,
62+
CTYPE_OUT>::run(a, b, out);
8963
});
9064
});
9165
});
@@ -141,7 +115,8 @@ Tensor& bitwise_or_Scalar_out(
141115
static_cast<CTYPE_IN>(val_a);
142116
CTYPE_IN b_casted =
143117
static_cast<CTYPE_IN>(val_b);
144-
CTYPE_IN value = bitwise_or(a_casted, b_casted);
118+
CTYPE_IN value =
119+
std::bit_or<CTYPE_IN>()(a_casted, b_casted);
145120

146121
return static_cast<CTYPE_OUT>(value);
147122
},

kernels/portable/cpu/op_bitwise_xor.cpp

Lines changed: 18 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <cmath>
9+
// patternlint-disable-next-line executorch-cpp-nostdinc
10+
#include <functional>
1011

12+
#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
1113
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1214
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1315
#include <executorch/kernels/portable/cpu/util/functional_util.h>
@@ -17,28 +19,13 @@ namespace torch {
1719
namespace executor {
1820
namespace native {
1921

20-
namespace {
21-
22-
template <typename CTYPE>
23-
CTYPE bitwise_xor(CTYPE a, CTYPE b) {
24-
return a ^ b;
25-
}
26-
27-
template <>
28-
bool bitwise_xor<bool>(bool a, bool b) {
29-
return a != b;
30-
}
31-
32-
} // namespace
33-
3422
using Tensor = exec_aten::Tensor;
3523

3624
Tensor& bitwise_xor_Tensor_out(
3725
RuntimeContext& ctx,
3826
const Tensor& a,
3927
const Tensor& b,
4028
Tensor& out) {
41-
// Determine output size and resize for dynamic shapes
4229
ET_KERNEL_CHECK(
4330
ctx,
4431
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
@@ -56,38 +43,23 @@ Tensor& bitwise_xor_Tensor_out(
5643
Bool, a_type, ctx, "bitwise_xor.Tensor_out", CTYPE_A, [&]() {
5744
ET_SWITCH_INT_TYPES_AND(
5845
Bool, b_type, ctx, "bitwise_xor.Tensor_out", CTYPE_B, [&]() {
59-
ET_SWITCH_INT_TYPES_AND(
46+
using CTYPE_IN = typename torch::executor::
47+
promote_types<CTYPE_A, CTYPE_B>::type;
48+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
49+
ET_SWITCH_REAL_TYPES_AND(
6050
Bool,
61-
common_type,
51+
out_type,
6252
ctx,
6353
"bitwise_xor.Tensor_out",
64-
CTYPE_IN,
54+
CTYPE_OUT,
6555
[&]() {
66-
ET_SWITCH_REAL_TYPES_AND(
67-
Bool,
68-
out_type,
69-
ctx,
70-
"bitwise_xor.Tensor_out",
71-
CTYPE_OUT,
72-
[&]() {
73-
apply_binary_elementwise_fn<
74-
CTYPE_A,
75-
CTYPE_B,
76-
CTYPE_OUT>(
77-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
78-
CTYPE_IN a_casted =
79-
static_cast<CTYPE_IN>(val_a);
80-
CTYPE_IN b_casted =
81-
static_cast<CTYPE_IN>(val_b);
82-
CTYPE_IN value =
83-
bitwise_xor(a_casted, b_casted);
84-
85-
return static_cast<CTYPE_OUT>(value);
86-
},
87-
a,
88-
b,
89-
out);
90-
});
56+
internal::BitwiseOpInner<
57+
can_cast<CTYPE_IN, CTYPE_OUT>::value,
58+
std::bit_xor,
59+
CTYPE_A,
60+
CTYPE_B,
61+
CTYPE_IN,
62+
CTYPE_OUT>::run(a, b, out);
9163
});
9264
});
9365
});
@@ -143,8 +115,8 @@ Tensor& bitwise_xor_Scalar_out(
143115
static_cast<CTYPE_IN>(val_a);
144116
CTYPE_IN b_casted =
145117
static_cast<CTYPE_IN>(val_b);
146-
CTYPE_IN value =
147-
bitwise_xor(a_casted, b_casted);
118+
CTYPE_IN value = std::bit_xor<CTYPE_IN>()(
119+
a_casted, b_casted);
148120

149121
return static_cast<CTYPE_OUT>(value);
150122
},
kernels/portable/cpu/pattern/bitwise_op.h

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
12+
#include <executorch/runtime/kernel/kernel_includes.h>
13+
14+
namespace torch {
15+
namespace executor {
16+
namespace native {
17+
namespace internal {
18+
19+
// Compile-time dispatch helper shared by the bitwise tensor kernels.
//
// Only declared here; the two partial specializations on `can_cast`
// (true / false) that follow in this header supply the static `run` member.
//
// Template parameters:
//   can_cast  - compile-time flag selecting the specialization. The call
//               sites in this commit pass can_cast<CTYPE_IN, CTYPE_OUT>::value.
//   OpFunc    - functor template taking one type argument; it is
//               instantiated as OpFunc<CTYPE_IN> (call sites pass
//               std::bit_and, std::bit_or or std::bit_xor).
//   CTYPE_A   - element type read from the first input tensor.
//   CTYPE_B   - element type read from the second input tensor.
//   CTYPE_IN  - type both operands are cast to before OpFunc is applied.
//   CTYPE_OUT - type the result is cast to before being returned.
template <
20+
bool can_cast,
21+
template <typename>
22+
typename OpFunc,
23+
typename CTYPE_A,
24+
typename CTYPE_B,
25+
typename CTYPE_IN,
26+
typename CTYPE_OUT>
27+
struct BitwiseOpInner;
28+
29+
// Specialization chosen when `can_cast` is true: this is the one that
// actually performs the bitwise operation.
//
// run(a, b, out) forwards the three tensors to
// apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT> together with a
// lambda that computes one result from one (val_a, val_b) pair.
// NOTE(review): apply_binary_elementwise_fn is defined in broadcast_util.h,
// not shown here; presumably it invokes the lambda for every broadcast
// element pair and stores the result into `out` - confirm there.
template <
30+
template <typename>
31+
typename OpFunc,
32+
typename CTYPE_A,
33+
typename CTYPE_B,
34+
typename CTYPE_IN,
35+
typename CTYPE_OUT>
36+
struct BitwiseOpInner<true, OpFunc, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
37+
// Applies OpFunc<CTYPE_IN> to the inputs `a` and `b`, targeting `out`.
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
38+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
39+
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
40+
[](const CTYPE_A val_a, const CTYPE_B val_b) {
41+
// Cast both operands to CTYPE_IN so the functor sees a single type.
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
42+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
43+
// A default-constructed functor is created and invoked per call.
CTYPE_IN value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
44+
45+
// Convert the CTYPE_IN result to the output element type.
return static_cast<CTYPE_OUT>(value);
46+
},
47+
a,
48+
b,
49+
out);
50+
}
51+
};
52+
53+
// Shared base for the `can_cast == false` specialization of BitwiseOpInner.
//
// Its run() has the same parameter list as the working specialization but
// ignores all three tensors and only fires ET_DCHECK_MSG(false, ...).
// Per the message text, reaching it means an earlier canCast check was
// skipped or is wrong.
// NOTE(review): ET_DCHECK_MSG is presumably a debug-only check, in which
// case this is a no-op in builds where debug checks are disabled - confirm
// against the macro's definition.
struct ReportCanCastBug {
54+
static void run(const Tensor&, const Tensor&, Tensor&) {
55+
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
56+
}
57+
};
58+
59+
// Specialization chosen when `can_cast` is false.
//
// The body is empty: run() is inherited from ReportCanCastBug, which does
// not depend on OpFunc or any of the CTYPE_* parameters. This specialization
// therefore contains no per-type element-wise code of its own and only
// reports the bug if it is ever called.
template <
60+
template <typename>
61+
typename OpFunc,
62+
typename CTYPE_A,
63+
typename CTYPE_B,
64+
typename CTYPE_IN,
65+
typename CTYPE_OUT>
66+
struct BitwiseOpInner<false, OpFunc, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
67+
: public ReportCanCastBug {};
68+
69+
} // namespace internal
70+
} // namespace native
71+
} // namespace executor
72+
} // namespace torch

kernels/portable/cpu/pattern/targets.bzl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,17 @@ def define_common_targets():
66
The directory containing this targets.bzl file should also contain both
77
TARGETS and BUCK files that call this function.
88
"""
9+
runtime.cxx_library(
10+
name = "bitwise_op",
11+
exported_headers = [
12+
"bitwise_op.h",
13+
],
14+
compiler_flags = [],
15+
deps = [
16+
"//executorch/runtime/kernel:kernel_includes",
17+
],
18+
visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
19+
)
920

1021
runtime.cxx_library(
1122
name = "pattern",

0 commit comments

Comments
 (0)