Migrate elementwise_util callers to the variants with out_dtypes in template arguments #9841

Merged · 68 commits by swolchok (Mar 19 – Apr 22, 2025) · merged Apr 23, 2025
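The change is the same at every call site below: the output's `SupportedTensorDtypes` value moves from a trailing runtime argument into the template argument list, so the helper knows the supported output dtypes at compile time. As a hedged sketch of the pattern (names taken from the diffs in this PR; `fn` stands in for any elementwise lambda):

```cpp
// Before: out_dtypes passed as a runtime argument (now the deprecated overload).
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
    fn,
    ctx,
    a,
    utils::SupportedTensorDtypes::REALHBBF16,
    out,
    utils::SupportedTensorDtypes::SAME_AS_COMMON);

// After: out_dtypes is a compile-time template argument, which lets the
// implementation emit a dtype-specialized fast path (see the
// elementwise_util.h changes at the bottom of this diff).
utils::apply_unitensor_elementwise_fn<
    CTYPE_COMPUTE,
    op_name,
    utils::SupportedTensorDtypes::SAME_AS_COMMON>(
    fn,
    ctx,
    a,
    utils::SupportedTensorDtypes::REALHBBF16,
    out);
```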
20 changes: 12 additions & 8 deletions kernels/portable/cpu/op_add.cpp
@@ -52,17 +52,19 @@ Tensor& add_out(

   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBBF16>(
+        [val_alpha](const auto val_a, const auto val_b) {
           return val_a + val_alpha * val_b;
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBBF16);
+        out);
   });

   return out;
@@ -100,17 +102,19 @@ Tensor& add_scalar_out(
   static constexpr const char op_name[] = "add.Scalar_out";

   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [b, alpha](const CTYPE_COMPUTE val_a) {
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+        [b, alpha](const auto val_a) {
           CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
           CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
           return val_a + val_alpha * val_b;
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+        out);
   });

   return out;
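Alongside the template-argument change, lambdas that are safe to instantiate with more than one argument type switch from `const CTYPE_COMPUTE` parameters to `const auto`; call sites whose bodies still assume a scalar keep their typed parameters and gain a `// TODO: rewrite this to be vectorization-capable.` marker. A minimal sketch of why the generic form matters (the `Vec` type mentioned in the comment is hypothetical, not something this PR adds):

```cpp
#include <cassert>

int main() {
  // The same generic lambda body can be instantiated for a scalar today
  // and for a SIMD vector wrapper later, with no further source change.
  auto mul_fn = [](const auto val_a, const auto val_b) {
    return val_a * val_b;
  };
  assert(mul_fn(2.0f, 3.0f) == 6.0f); // scalar instantiation
  // Vec v = mul_fn(Vec(...), Vec(...)); // possible future vector instantiation
  return 0;
}
```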
10 changes: 6 additions & 4 deletions kernels/portable/cpu/op_addmm.cpp
@@ -88,17 +88,19 @@ Tensor& addmm_out(
             n,
             p);

-        utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
-            [alpha_val, beta_val](const CTYPE val_a, const CTYPE val_b) {
+        utils::apply_bitensor_elementwise_fn<
+            CTYPE,
+            op_name,
+            utils::SupportedTensorDtypes::REALHBF16>(
+            [alpha_val, beta_val](const auto val_a, const auto val_b) {
               return val_a * alpha_val + val_b * beta_val;
             },
             ctx,
             out,
             utils::SupportedTensorDtypes::REALHBF16,
             in,
             utils::SupportedTensorDtypes::REALHBF16,
-            out,
-            utils::SupportedTensorDtypes::REALHBF16);
+            out);
       }
     });

10 changes: 6 additions & 4 deletions kernels/portable/cpu/op_atan2.cpp
@@ -55,17 +55,19 @@ Tensor& atan2_out(
   static constexpr const char op_name[] = "atan2.out";

   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::FLOATHBF16>(
+        [](const auto val_a, const auto val_b) {
           return std::atan2(val_a, val_b);
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::FLOATHBF16);
+        out);
   });

   return out;
18 changes: 12 additions & 6 deletions kernels/portable/cpu/op_clamp.cpp
@@ -134,8 +134,12 @@ Tensor& clamp_out(
   static constexpr const char op_name[] = "clamp.out";

   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON>(
         [has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
+          // TODO: rewrite this to be vectorization-capable.
           CTYPE_COMPUTE val_out = val_in;
           if (has_min) {
             val_out = utils::max_override(
@@ -150,8 +154,7 @@ Tensor& clamp_out(
         ctx,
         in,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+        out);
   });

   return out;
@@ -210,11 +213,15 @@ Tensor& clamp_tensor_out(
   static constexpr const char op_name[] = "clamp.Tensor_out";

   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_tritensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBBF16>(
         [has_min, has_max](
             const CTYPE_COMPUTE val_in,
             const CTYPE_COMPUTE val_min,
             const CTYPE_COMPUTE val_max) {
+          // TODO: rewrite this to be vectorization-capable.
           CTYPE_COMPUTE val_out = val_in;
           if (has_min) {
             val_out = utils::max_override(val_out, val_min);
@@ -231,8 +238,7 @@ Tensor& clamp_tensor_out(
         utils::SupportedTensorDtypes::REALHBBF16,
         max,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBBF16);
+        out);
   });

   return out;
20 changes: 12 additions & 8 deletions kernels/portable/cpu/op_copy.cpp
@@ -47,15 +47,17 @@ Tensor& copy_out(
   static constexpr const char op_name[] = "copy.out";

   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
-        [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBBF16>(
+        [](ET_UNUSED const auto _, const auto val_src) { return val_src; },
         ctx,
         in,
         utils::SupportedTensorDtypes::REALHBBF16,
         src,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBBF16);
+        out);
   });

   return out;
@@ -80,15 +82,17 @@ Tensor& copy_(
   static constexpr const char op_name[] = "copy_";

   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
-        [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBBF16>(
+        [](ET_UNUSED const auto _, const auto val_src) { return val_src; },
         ctx,
         in,
         utils::SupportedTensorDtypes::REALHBBF16,
         src,
         utils::SupportedTensorDtypes::REALHBBF16,
-        in,
-        utils::SupportedTensorDtypes::REALHBBF16);
+        in);
   });

   return in;
31 changes: 18 additions & 13 deletions kernels/portable/cpu/op_div.cpp
@@ -58,17 +58,17 @@ Tensor& div_out(
   static constexpr const char op_name[] = "div.out";

   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          return val_a / val_b;
-        },
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::FLOATHBF16>(
+        [](const auto val_a, const auto val_b) { return val_a / val_b; },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::FLOATHBF16);
+        out);
   });

   return out;
@@ -122,9 +122,13 @@ Tensor& div_out_mode(
   bool div_by_zero_error = false;

   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBF16>(
         [mode_is_trunc, &div_by_zero_error](
             const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          // TODO: rewrite this to be vectorization-capable.
           if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
             if (val_b == 0) {
               div_by_zero_error = true;
@@ -146,8 +150,7 @@ Tensor& div_out_mode(
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
+        out);
   });

   ET_KERNEL_CHECK_MSG(
@@ -188,13 +191,15 @@ Tensor& div_scalar_out(

   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+        [val_b](const auto val_a) { return val_a / val_b; },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+        out);
   });

   return out;
11 changes: 7 additions & 4 deletions kernels/portable/cpu/op_elu.cpp
@@ -44,17 +44,20 @@ Tensor& elu_out(
     ET_EXTRACT_SCALAR(scale, math_scale);
     ET_EXTRACT_SCALAR(input_scale, math_input_scale);
     const auto negcoef = math_alpha * math_scale;
-    utils::apply_unitensor_elementwise_fn<CTYPE, op_name>(
-        [negcoef, math_scale, math_input_scale](auto x) {
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE,
+        op_name,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+        [negcoef, math_scale, math_input_scale](const auto x) {
+          // TODO: rewrite this to be vectorization-capable.
           return MathT(x) <= MathT(0)
               ? std::expm1(MathT(x) * math_input_scale) * negcoef
               : MathT(x) * math_scale;
         },
         ctx,
         in,
         utils::SupportedTensorDtypes::FLOATHBF16,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+        out);
   });
   return out;
 }
9 changes: 6 additions & 3 deletions kernels/portable/cpu/op_floor_divide.cpp
@@ -53,9 +53,13 @@ Tensor& floor_divide_out(
   bool div_by_zero_error = false;

   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBF16>(
         [&div_by_zero_error](
             const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          // TODO: rewrite this to be vectorization-capable.
           if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
             if (val_b == 0) {
               div_by_zero_error = true;
@@ -69,8 +73,7 @@ Tensor& floor_divide_out(
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
+        out);
   });

   ET_KERNEL_CHECK_MSG(
18 changes: 12 additions & 6 deletions kernels/portable/cpu/op_fmod.cpp
@@ -55,9 +55,13 @@ Tensor& fmod_Tensor_out(
   bool div_by_zero_error = false;

   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBF16>(
         [&div_by_zero_error](
             const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          // TODO: rewrite this to be vectorization-capable.
           CTYPE_COMPUTE value = 0;
           if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
             if (val_b == 0) {
@@ -73,8 +77,7 @@ Tensor& fmod_Tensor_out(
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
+        out);
   });

   ET_KERNEL_CHECK_MSG(
@@ -131,16 +134,19 @@ Tensor& fmod_Scalar_out(

   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBF16>(
         [val_b](const CTYPE_COMPUTE val_a) {
+          // TODO: rewrite this to be vectorization-capable.
           CTYPE_COMPUTE value = std::fmod(val_a, val_b);
           return value;
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
+        out);
   });

   return out;
8 changes: 5 additions & 3 deletions kernels/portable/cpu/op_maximum.cpp
@@ -45,7 +45,10 @@ Tensor& maximum_out(
   static constexpr const char op_name[] = "maximum.out";

   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBBF16>(
         [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
           return utils::max_override(val_a, val_b);
         },
@@ -54,8 +57,7 @@ Tensor& maximum_out(
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBBF16);
+        out);
   });

   return out;
9 changes: 6 additions & 3 deletions kernels/portable/cpu/op_minimum.cpp
@@ -45,17 +45,20 @@ Tensor& minimum_out(
   static constexpr const char op_name[] = "minimum.out";

   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBBF16>(
         [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          // TODO: rewrite this to be vectorization-capable.
           return utils::min_override(val_a, val_b);
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBBF16);
+        out);
   });

   return out;
18 changes: 11 additions & 7 deletions kernels/portable/cpu/op_mul.cpp
@@ -52,7 +52,10 @@ Tensor& mul_out(
       out);

   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBBF16>(
         [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
           return val_a * val_b;
         },
@@ -61,8 +64,7 @@ Tensor& mul_out(
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBBF16);
+        out);
   });

   return out;
@@ -95,13 +97,15 @@ Tensor& mul_scalar_out(

   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b](const CTYPE_COMPUTE val_a) { return val_a * val_b; },
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+        [val_b](const auto val_a) { return val_a * val_b; },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+        out);
   });

   return out;
27 changes: 18 additions & 9 deletions kernels/portable/cpu/op_pow.cpp
@@ -53,17 +53,20 @@ Tensor& pow_Tensor_Tensor_out(
   static constexpr const char op_name[] = "pow.Tensor_Tensor_out";

   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBF16>(
         [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          // TODO: rewrite this to be vectorization-capable.
           return std::pow(val_a, val_b);
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
+        out);
   });

   return out;
@@ -104,13 +107,16 @@ Tensor& pow_Tensor_Scalar_out(

   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBF16>(
+        // TODO: rewrite this to be vectorization-capable.
         [val_b](const CTYPE_COMPUTE val_a) { return std::pow(val_a, val_b); },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
+        out);
   });

   return out;
@@ -151,13 +157,16 @@ Tensor& pow_Scalar_out(

   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_a = utils::scalar_to<CTYPE_COMPUTE>(a);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBF16>(
+        // TODO: rewrite this to be vectorization-capable.
         [val_a](const CTYPE_COMPUTE val_b) { return std::pow(val_a, val_b); },
         ctx,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
+        out);
   });

   return out;
18 changes: 12 additions & 6 deletions kernels/portable/cpu/op_remainder.cpp
@@ -53,9 +53,13 @@ Tensor& remainder_Tensor_out(
   bool div_by_zero_error = false;

   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBF16>(
         [&div_by_zero_error](
             const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          // TODO: rewrite this to be vectorization-capable.
           CTYPE_COMPUTE value = 0;
           if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
             if (val_b == 0) {
@@ -71,8 +75,7 @@ Tensor& remainder_Tensor_out(
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
+        out);
   });

   ET_KERNEL_CHECK_MSG(
@@ -126,15 +129,18 @@ Tensor& remainder_Scalar_out(

   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBF16>(
         [val_b](const CTYPE_COMPUTE val_a) {
+          // TODO: rewrite this to be vectorization-capable.
           return utils::remainder_override(val_a, val_b);
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
+        out);
   });

   return out;
10 changes: 6 additions & 4 deletions kernels/portable/cpu/op_rsub.cpp
@@ -52,15 +52,17 @@ Tensor& rsub_scalar_out(
   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
     const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b, val_alpha](const CTYPE_COMPUTE val_a) {
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+        [val_b, val_alpha](const auto val_a) {
           return val_b - val_alpha * val_a;
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBF16,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+        out);
   });

   return out;
11 changes: 7 additions & 4 deletions kernels/portable/cpu/op_sigmoid.cpp
@@ -45,17 +45,20 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
   static constexpr const char op_name[] = "sigmoid.out";

   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [](const CTYPE_COMPUTE val_in) {
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::FLOATHBF16>(
+        [](const auto val_in) -> CTYPE_COMPUTE {
+          // TODO: rewrite this to be vectorization-capable
           CTYPE_COMPUTE out_val = static_cast<CTYPE_COMPUTE>(1.0) /
               (static_cast<CTYPE_COMPUTE>(1.0) + exp(-val_in));
           return out_val;
         },
         ctx,
         in,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::FLOATHBF16);
+        out);
   });

   return out;
20 changes: 12 additions & 8 deletions kernels/portable/cpu/op_sub.cpp
@@ -56,17 +56,19 @@ Tensor& sub_out(

   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBF16>(
+        [val_alpha](const auto val_a, const auto val_b) {
           return val_a - val_alpha * val_b;
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBF16,
         b,
         utils::SupportedTensorDtypes::REALHBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
+        out);
   });

   return out;
@@ -110,15 +112,17 @@ Tensor& sub_scalar_out(
   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
     const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b, val_alpha](const CTYPE_COMPUTE val_a) {
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+        [val_b, val_alpha](const auto val_a) {
           return val_a - val_alpha * val_b;
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBF16,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+        out);
   });

   return out;
14 changes: 8 additions & 6 deletions kernels/portable/cpu/op_where.cpp
@@ -43,19 +43,21 @@ Tensor& where_out(
   static constexpr const char op_name[] = "where.self_out";

   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [](const CTYPE_COMPUTE val_a,
-           const CTYPE_COMPUTE val_b,
-           const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; },
+    utils::apply_tritensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+        [](const auto val_a, const auto val_b, const auto val_c) {
+          return val_c ? val_a : val_b;
+        },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
         cond,
         utils::SupportedTensorDtypes::BOOL_OR_BYTE,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+        out);
   });

   return out;
18 changes: 12 additions & 6 deletions kernels/portable/cpu/pattern/bitwise_op.h
@@ -80,15 +80,18 @@ Tensor& bitwise_tensor_out(

   ET_SWITCH_INT_TYPES_AND(
       Bool, compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-        utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        utils::apply_bitensor_elementwise_fn<
+            CTYPE_COMPUTE,
+            op_name,
+            utils::SupportedTensorDtypes::REALHBBF16>(
+            // TODO: rewrite this to be vectorization-capable.
             BitwiseFnForOp<CTYPE_COMPUTE, op_name>::value,
             ctx,
             a,
             utils::SupportedTensorDtypes::INTB,
             b,
             utils::SupportedTensorDtypes::INTB,
-            out,
-            utils::SupportedTensorDtypes::REALHBBF16);
+            out);
       });

   return out;
@@ -121,16 +124,19 @@ Tensor& bitwise_scalar_out(
   ET_SWITCH_INT_TYPES_AND(
       Bool, compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
         const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-        utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        utils::apply_unitensor_elementwise_fn<
+            CTYPE_COMPUTE,
+            op_name,
+            utils::SupportedTensorDtypes::REALHBBF16>(
            [val_b](const CTYPE_COMPUTE val_a) {
+              // TODO: rewrite this to be vectorization-capable.
              return BitwiseFnForOp<CTYPE_COMPUTE, op_name>::value(
                  val_a, val_b);
            },
            ctx,
            a,
            utils::SupportedTensorDtypes::INTB,
-           out,
-           utils::SupportedTensorDtypes::REALHBBF16);
+           out);
       });

   return out;
18 changes: 12 additions & 6 deletions kernels/portable/cpu/pattern/comparison_op.h
@@ -91,15 +91,18 @@ Tensor& comparison_tensor_out(
   ScalarType compute_type = utils::get_compute_type(common_type);

   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBBF16>(
+        // TODO: rewrite this to be vectorization-capable.
         ComparisonFnForOp<CTYPE_COMPUTE, op_name>::value,
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBBF16);
+        out);
   });

   return out;
@@ -127,15 +130,18 @@ Tensor& comparison_scalar_out(

   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+    utils::apply_unitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::REALHBBF16>(
         [val_b](const CTYPE_COMPUTE val_a) {
+          // TODO: rewrite this to be vectorization-capable.
           return ComparisonFnForOp<CTYPE_COMPUTE, op_name>::value(val_a, val_b);
         },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBBF16);
+        out);
   });

   return out;
9 changes: 6 additions & 3 deletions kernels/portable/cpu/pattern/logical_op.h
@@ -34,15 +34,18 @@ Tensor& logical_tensor_out(
       InvalidArgument,
       out);

-  utils::apply_bitensor_elementwise_fn<bool, op_name>(
+  utils::apply_bitensor_elementwise_fn<
+      bool,
+      op_name,
+      utils::SupportedTensorDtypes::REALHBBF16>(
+      // TODO: rewrite this to be vectorization-capable.
       fn,
       ctx,
       a,
       utils::SupportedTensorDtypes::REALHBBF16,
       b,
       utils::SupportedTensorDtypes::REALHBBF16,
-      out,
-      utils::SupportedTensorDtypes::REALHBBF16);
+      out);

   return out;
 }
19 changes: 19 additions & 0 deletions kernels/portable/cpu/util/dtype_util.h
@@ -290,6 +290,25 @@ bool check_tensor_dtype(
     SupportedTensorDtypes dtypes,
     const ScalarType compute_type);

+/// Return the one output type we are willing to emit specialized code
+/// to handle, given a compute type of CTYPE_COMPUTE and supported
+/// output types of out_dtypes.
+template <typename CTYPE_COMPUTE>
+inline constexpr ScalarType specialized_output_scalar_type(
+    SupportedTensorDtypes out_dtypes) {
+  switch (out_dtypes) {
+    case SupportedTensorDtypes::BOOL_OR_BYTE:
+      return ScalarType::Bool;
+    case SupportedTensorDtypes::REALHBBF16:
+    case SupportedTensorDtypes::REALHBF16:
+    case SupportedTensorDtypes::FLOATHBF16:
+    case SupportedTensorDtypes::INTB:
+    case SupportedTensorDtypes::SAME_AS_COMPUTE:
+    case SupportedTensorDtypes::SAME_AS_COMMON:
+      return CppTypeToScalarType<CTYPE_COMPUTE>::value;
+  }
+}
+
 } // namespace internal
 } // namespace utils
 } // namespace native
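For intuition, a hedged illustration of what the helper above yields (the `static_assert` wrapper is mine, and it assumes the same `utils::internal` scope as the diff): with `CTYPE_COMPUTE = float`, `BOOL_OR_BYTE` specializes to `Bool`, while every other supported set specializes to the compute type itself.

```cpp
// Illustration only; assumes we are inside the utils namespace from the diff.
static_assert(
    internal::specialized_output_scalar_type<float>(
        SupportedTensorDtypes::BOOL_OR_BYTE) == ScalarType::Bool,
    "BOOL_OR_BYTE emits a Bool-output specialization");
static_assert(
    internal::specialized_output_scalar_type<float>(
        SupportedTensorDtypes::FLOATHBF16) == ScalarType::Float,
    "all other sets specialize to the compute type");
```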
119 changes: 103 additions & 16 deletions kernels/portable/cpu/util/elementwise_util.h
@@ -51,6 +51,44 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
 }

 namespace internal {
+template <
+    typename CTYPE_COMPUTE,
+    typename CTYPE_OUT,
+    typename Op,
+    typename... Args>
+inline void dtype_specialized_elementwise_fn_impl(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& out,
+    Args... inputs) {
+  constexpr auto kNumInputs = sizeof...(inputs);
+  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMPUTE)) && ...));
+
+  ::executorch::extension::parallel_for(
+      0,
+      out.numel(),
+      ::executorch::extension::internal::GRAIN_SIZE,
+      [&](const auto begin, const auto end) {
+        std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
+            inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};
+
+        CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
+
+        const auto range =
+            BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...);
+        auto begin_it = range.begin();
+        begin_it += begin;
+        for (; (*begin_it)[0] < end; ++begin_it) {
+          const auto& indexes = *begin_it;
+          std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
+          for (const auto idx : c10::irange(kNumInputs)) {
+            loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1]];
+          }
+          data_out[indexes[0]] = std::apply(compute_fun, loaded_inputs);
+        }
+      });
+}
+
 template <typename CTYPE_COMPUTE, typename Op, typename... Args>
 inline bool validate_elementwise_fn_inputs(
     const Op& compute_fun,
@@ -81,18 +119,12 @@ template <
     const char* op_name,
     typename Op,
     typename... Args>
-inline void apply_elementwise_fn(
+inline void apply_elementwise_fn_generic_impl(
     const Op& compute_fun,
     KernelRuntimeContext& ctx,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes,
     Args... inputs) {
-  const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
-      compute_fun, ctx, out, out_dtypes, inputs...);
-  if (!inputs_valid) {
-    return;
-  }
-
   constexpr auto kNumInputs = sizeof...(inputs);

   struct InputInfo {
@@ -138,6 +170,63 @@ inline void apply_elementwise_fn(
       });
 }

+template <
+    typename CTYPE_COMPUTE,
+    const char* op_name,
+    typename Op,
+    typename... Args>
+inline void apply_elementwise_fn_runtime_out_dtypes(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& out,
+    SupportedTensorDtypes out_dtypes,
+    Args... inputs) {
+  const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
+      compute_fun, ctx, out, out_dtypes, inputs...);
+  if (!inputs_valid) {
+    return;
+  }
+
+  apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
+      compute_fun, ctx, out, out_dtypes, inputs...);
+}
+
+template <
+    typename CTYPE_COMPUTE,
+    const char* op_name,
+    SupportedTensorDtypes out_dtypes,
+    typename Op,
+    typename... Args>
+inline void apply_elementwise_fn(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& out,
+    Args... inputs) {
+  const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
+      compute_fun, ctx, out, out_dtypes, inputs...);
+  if (!inputs_valid) {
+    return;
+  }
+
+  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
+  const bool all_inputs_compute_dtype =
+      ((inputs.first->scalar_type() == compute_type) && ...);
+
+  constexpr ScalarType out_specialized_scalar_type =
+      specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes);
+  if (all_inputs_compute_dtype &&
+      out.scalar_type() == out_specialized_scalar_type) {
+    using CTYPE_OUT =
+        typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
+    dtype_specialized_elementwise_fn_impl<CTYPE_COMPUTE, CTYPE_OUT>(
+        compute_fun, ctx, out, inputs...);
+    return;
+  }
+
+  apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
+      compute_fun, ctx, out, out_dtypes, inputs...);
+}
+
 /// DEPRECATED: prefer the variant with out_dtypes in the template argument.
 template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
 inline void apply_unitensor_elementwise_fn(
@@ -147,7 +236,7 @@ inline void apply_unitensor_elementwise_fn(
     SupportedTensorDtypes a_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>(
       compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
 }

@@ -162,8 +251,8 @@ inline void apply_unitensor_elementwise_fn(
     const Tensor& a,
     SupportedTensorDtypes a_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
-      compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
+  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
+      compute_fun, ctx, out, std::make_pair(&a, a_dtypes));
 }

 /**
@@ -179,7 +268,7 @@ inline void apply_bitensor_elementwise_fn(
     SupportedTensorDtypes b_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>(
       compute_fun,
       ctx,
       out,
@@ -206,11 +295,10 @@ inline void apply_bitensor_elementwise_fn(
     const Tensor& b,
     SupportedTensorDtypes b_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
       compute_fun,
       ctx,
       out,
-      out_dtypes,
       std::make_pair(&a, a_dtypes),
       std::make_pair(&b, b_dtypes));
 }
@@ -230,7 +318,7 @@ inline void apply_tritensor_elementwise_fn(
     SupportedTensorDtypes c_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>(
       compute_fun,
       ctx,
       out,
@@ -275,11 +363,10 @@ inline void apply_tritensor_elementwise_fn(
     const Tensor& c,
     SupportedTensorDtypes c_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
+  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
       compute_fun,
       ctx,
       out,
-      out_dtypes,
       std::make_pair(&a, a_dtypes),
       std::make_pair(&b, b_dtypes),
       std::make_pair(&c, c_dtypes));
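Net effect of the refactor above: the deprecated overloads route through `apply_elementwise_fn_runtime_out_dtypes`, which behaves exactly as before, while the new overloads can evaluate `specialized_output_scalar_type` at compile time and check one runtime condition to pick a specialized loop. A hedged caller-side sketch (the surrounding kernel and tensor names are mine, modeled on the op_mul diff):

```cpp
// With out_dtypes as a template argument, a mul-style call...
utils::apply_bitensor_elementwise_fn<
    CTYPE_COMPUTE,
    op_name,
    utils::SupportedTensorDtypes::REALHBBF16>(
    [](const auto val_a, const auto val_b) { return val_a * val_b; },
    ctx,
    a,
    utils::SupportedTensorDtypes::REALHBBF16,
    b,
    utils::SupportedTensorDtypes::REALHBBF16,
    out);
// ...dispatches like this: if a, b, and out all already have the compute
// dtype (out matching specialized_output_scalar_type<CTYPE_COMPUTE>(
// REALHBBF16)), dtype_specialized_elementwise_fn_impl runs, reading and
// writing typed data pointers directly under parallel_for; any other
// dtype combination falls back to apply_elementwise_fn_generic_impl,
// the same per-element convert-on-load/convert-on-store path used
// before this PR.
```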