Skip to content

Commit 6ff3cb6

Browse files
swolchok and facebook-github-bot
authored and committed
Use compile-time promotion to reduce floor_divide size & build time
Summary: Continuing rollout of this technique. Differential Revision: D56827786
1 parent 3f8077a commit 6ff3cb6

File tree

2 files changed

+69
-29
lines changed

2 files changed

+69
-29
lines changed

kernels/portable/cpu/op_floor_divide.cpp

Lines changed: 63 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,59 @@ namespace native {
2020
using Tensor = exec_aten::Tensor;
2121
using ScalarType = exec_aten::ScalarType;
2222

23+
namespace {
24+
template <
25+
bool can_cast,
26+
typename CTYPE_A,
27+
typename CTYPE_B,
28+
typename CTYPE_IN,
29+
typename CTYPE_OUT>
30+
struct FloorDivideInner;
31+
32+
template <
33+
typename CTYPE_A,
34+
typename CTYPE_B,
35+
typename CTYPE_IN,
36+
typename CTYPE_OUT>
37+
struct FloorDivideInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
38+
static void
39+
run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) {
40+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
41+
[&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
42+
if (is_integral_type<CTYPE_IN, /*includeBool=*/true>::value) {
43+
if (val_b == 0) {
44+
div_by_zero_error = true;
45+
return static_cast<CTYPE_OUT>(0);
46+
}
47+
}
48+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
49+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
50+
CTYPE_IN value = utils::floor_divide<CTYPE_IN>(a_casted, b_casted);
51+
52+
return static_cast<CTYPE_OUT>(value);
53+
},
54+
a,
55+
b,
56+
out);
57+
}
58+
};
59+
60+
struct ReportCanCastBug {
61+
static void run(const Tensor&, const Tensor&, Tensor&, bool&) {
62+
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
63+
}
64+
};
65+
66+
template <
67+
typename CTYPE_A,
68+
typename CTYPE_B,
69+
typename CTYPE_IN,
70+
typename CTYPE_OUT>
71+
struct FloorDivideInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
72+
: public ReportCanCastBug {};
73+
74+
} // namespace
75+
2376
Tensor& floor_divide_out(
2477
RuntimeContext& ctx,
2578
const Tensor& a,
@@ -46,36 +99,17 @@ Tensor& floor_divide_out(
4699
Bool, a_type, ctx, "floor_divide.out", CTYPE_A, [&]() {
47100
ET_SWITCH_REAL_TYPES_AND(
48101
Bool, b_type, ctx, "floor_divide.out", CTYPE_B, [&]() {
102+
using CTYPE_IN = typename torch::executor::
103+
promote_types<CTYPE_A, CTYPE_B>::type;
104+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
49105
ET_SWITCH_REAL_TYPES(
50-
common_type, ctx, "floor_divide.out", CTYPE_IN, [&]() {
51-
ET_SWITCH_REAL_TYPES(
52-
out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() {
53-
apply_binary_elementwise_fn<
54-
CTYPE_A,
55-
CTYPE_B,
56-
CTYPE_OUT>(
57-
[common_type, &div_by_zero_error](
58-
const CTYPE_A val_a, const CTYPE_B val_b) {
59-
if (isIntegralType(
60-
common_type, /*includeBool=*/true)) {
61-
if (val_b == 0) {
62-
div_by_zero_error = true;
63-
return static_cast<CTYPE_OUT>(0);
64-
}
65-
}
66-
CTYPE_IN a_casted =
67-
static_cast<CTYPE_IN>(val_a);
68-
CTYPE_IN b_casted =
69-
static_cast<CTYPE_IN>(val_b);
70-
CTYPE_IN value = utils::floor_divide<CTYPE_IN>(
71-
a_casted, b_casted);
72-
73-
return static_cast<CTYPE_OUT>(value);
74-
},
75-
a,
76-
b,
77-
out);
78-
});
106+
out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() {
107+
FloorDivideInner<
108+
can_cast<CTYPE_IN, CTYPE_OUT>::value,
109+
CTYPE_A,
110+
CTYPE_B,
111+
CTYPE_IN,
112+
CTYPE_OUT>::run(a, b, out, div_by_zero_error);
79113
});
80114
});
81115
});

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -353,6 +353,12 @@ inline constexpr bool isIntegralType(
353353
t == exec_aten::ScalarType::Short);
354354
}
355355

356+
template <typename T, bool includeBool>
357+
struct is_integral_type
358+
: public std::integral_constant<
359+
bool,
360+
isIntegralType(CppTypeToScalarType<T>::value, includeBool)> {};
361+
356362
inline constexpr bool isFloatingType(exec_aten::ScalarType t) {
357363
return (
358364
t == exec_aten::ScalarType::Double || t == exec_aten::ScalarType::Float ||

0 commit comments

Comments
 (0)