Skip to content

Commit 446ed37

Browse files
swolchokfacebook-github-bot
authored andcommitted
Use compile-time promotion to reduce bitwise op size & build time (#3487)
Summary: Finally getting close to the end of compile-time promotion for Tensor ops! Differential Revision: D56855548
1 parent 58dba0c commit 446ed37

File tree

1 file changed

+65
-15
lines changed

1 file changed

+65
-15
lines changed

kernels/portable/cpu/pattern/bitwise_op.h

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,60 @@
1414
namespace torch {
1515
namespace executor {
1616
namespace native {
17+
namespace internal {
18+
19+
template <
20+
bool can_cast,
21+
template <typename>
22+
typename OpFunc,
23+
typename CTYPE_A,
24+
typename CTYPE_B,
25+
typename CTYPE_IN,
26+
typename CTYPE_OUT>
27+
struct BitwiseOpInner;
28+
29+
template <
30+
template <typename>
31+
typename OpFunc,
32+
typename CTYPE_A,
33+
typename CTYPE_B,
34+
typename CTYPE_IN,
35+
typename CTYPE_OUT>
36+
struct BitwiseOpInner<true, OpFunc, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
37+
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
38+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
39+
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
40+
[](const CTYPE_A val_a, const CTYPE_B val_b) {
41+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
42+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
43+
CTYPE_IN value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
44+
45+
return static_cast<CTYPE_OUT>(value);
46+
},
47+
a,
48+
b,
49+
out);
50+
}
51+
};
52+
53+
struct ReportCanCastBug {
54+
static void run(const Tensor&, const Tensor&, Tensor&) {
55+
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
56+
}
57+
};
58+
59+
template <
60+
template <typename>
61+
typename OpFunc,
62+
typename CTYPE_A,
63+
typename CTYPE_B,
64+
typename CTYPE_IN,
65+
typename CTYPE_OUT>
66+
struct BitwiseOpInner<false, OpFunc, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
67+
: public ReportCanCastBug {};
68+
69+
} // namespace internal
70+
1771
template <template <typename> typename OpFunc>
1872
Tensor& bitwise_op_out(
1973
RuntimeContext& ctx,
@@ -36,21 +90,17 @@ Tensor& bitwise_op_out(
3690

3791
ET_SWITCH_INT_TYPES_AND(Bool, a_type, ctx, op_name, CTYPE_A, [&]() {
3892
ET_SWITCH_INT_TYPES_AND(Bool, b_type, ctx, op_name, CTYPE_B, [&]() {
39-
ET_SWITCH_INT_TYPES_AND(Bool, common_type, ctx, op_name, CTYPE_IN, [&]() {
40-
ET_SWITCH_REAL_TYPES_AND(
41-
Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
42-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
43-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
44-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
45-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
46-
CTYPE_IN value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
47-
48-
return static_cast<CTYPE_OUT>(value);
49-
},
50-
a,
51-
b,
52-
out);
53-
});
93+
using CTYPE_IN =
94+
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
95+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
96+
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
97+
internal::BitwiseOpInner<
98+
can_cast<CTYPE_IN, CTYPE_OUT>::value,
99+
OpFunc,
100+
CTYPE_A,
101+
CTYPE_B,
102+
CTYPE_IN,
103+
CTYPE_OUT>::run(a, b, out);
54104
});
55105
});
56106
});

0 commit comments

Comments
 (0)