Skip to content

Commit f2fae59

Browse files
swolchokkirklandsign
authored andcommitted
Add portable ELU implementation + test (#9520)
1 parent a5353ae commit f2fae59

File tree

7 files changed

+173
-0
lines changed

7 files changed

+173
-0
lines changed

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@
141141

142142
- op: div.out_mode
143143

144+
- op: elu.out
145+
144146
- op: embedding.out
145147

146148
- op: empty.out

kernels/portable/cpu/op_elu.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cmath>
10+
#include <type_traits>
11+
12+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
13+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
14+
#include <executorch/runtime/kernel/kernel_includes.h>
15+
16+
namespace torch::executor::native {
17+
18+
Tensor& elu_out(
19+
KernelRuntimeContext& ctx,
20+
const Tensor& in,
21+
const Scalar& alpha,
22+
const Scalar& scale,
23+
const Scalar& input_scale,
24+
Tensor& out) {
25+
ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
26+
ET_KERNEL_CHECK(
27+
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
28+
29+
ET_KERNEL_CHECK(
30+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
31+
32+
ET_KERNEL_CHECK(ctx, tensor_is_floating_type(in), InvalidArgument, out);
33+
34+
ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
35+
36+
static constexpr const char op_name[] = "elu.out";
37+
ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() {
38+
using MathT = std::
39+
conditional_t<c10::is_reduced_floating_point_v<CTYPE>, float, CTYPE>;
40+
MathT math_alpha = 0;
41+
MathT math_scale = 0;
42+
MathT math_input_scale = 0;
43+
ET_EXTRACT_SCALAR(alpha, math_alpha);
44+
ET_EXTRACT_SCALAR(scale, math_scale);
45+
ET_EXTRACT_SCALAR(input_scale, math_input_scale);
46+
const auto negcoef = math_alpha * math_scale;
47+
utils::apply_unitensor_elementwise_fn<CTYPE, op_name>(
48+
[negcoef, math_scale, math_input_scale](auto x) {
49+
return MathT(x) <= MathT(0)
50+
? std::expm1(MathT(x) * math_input_scale) * negcoef
51+
: MathT(x) * math_scale;
52+
},
53+
ctx,
54+
in,
55+
utils::SupportedTensorDtypes::FLOATHBF16,
56+
out,
57+
utils::SupportedTensorDtypes::SAME_AS_COMMON);
58+
});
59+
return out;
60+
}
61+
62+
} // namespace torch::executor::native

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,11 @@
329329
- arg_meta: null
330330
kernel_name: torch::executor::eq_tensor_out
331331

332+
- op: elu.out
333+
kernels:
334+
- arg_meta: null
335+
kernel_name: torch::executor::elu_out
336+
332337
- op: erf.out
333338
kernels:
334339
- arg_meta: null

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ set(all_test_sources
135135
"op_detach_copy_test.cpp"
136136
"op_diagonal_copy_test.cpp"
137137
"op_div_test.cpp"
138+
"op_elu_test.cpp"
138139
"op_embedding_test.cpp"
139140
"op_empty_test.cpp"
140141
"op_eq_test.cpp"

kernels/test/op_elu_test.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/TestUtil.h>
11+
#include <executorch/kernels/test/supported_features.h>
12+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
13+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15+
16+
#include <gtest/gtest.h>
17+
18+
using executorch::aten::Scalar;
19+
using executorch::aten::ScalarType;
20+
using executorch::aten::string_view;
21+
using executorch::aten::Tensor;
22+
using torch::executor::testing::TensorFactory;
23+
24+
class OpEluTest : public OperatorTest {
25+
protected:
26+
Tensor& op_elu_out(
27+
const Tensor& self,
28+
const Scalar& alpha,
29+
const Scalar& scale,
30+
const Scalar& input_scale,
31+
Tensor& out) {
32+
return torch::executor::aten::elu_outf(
33+
context_, self, alpha, scale, input_scale, out);
34+
}
35+
36+
template <ScalarType DTYPE>
37+
void test_elu_execution() {
38+
TensorFactory<DTYPE> tf;
39+
40+
const std::vector<int32_t> sizes = {3, 2};
41+
42+
Tensor in = tf.make(sizes, /*data=*/{-0.125, -0.25, -1, 0, 1.25, 100});
43+
44+
Tensor out = tf.zeros(sizes);
45+
46+
// Run full gelu.
47+
op_elu_out(in, 1.25, 1, 1, out);
48+
49+
// Check that it matches the expected output.
50+
EXPECT_TENSOR_CLOSE(
51+
out,
52+
tf.make(
53+
sizes,
54+
/*data=*/
55+
{-0.146879, -0.276499, -0.790151, 0, 1.25, 100}));
56+
}
57+
58+
template <ScalarType DTYPE>
59+
void test_integer_elu_dies() {
60+
TensorFactory<DTYPE> tf;
61+
62+
Tensor in = tf.ones({1});
63+
Tensor out = tf.ones({1});
64+
ET_EXPECT_KERNEL_FAILURE(context_, op_elu_out(in, 1, 1, 1, out));
65+
}
66+
};
67+
68+
TEST_F(OpEluTest, Basic) {
69+
#define TEST_ENTRY(ctype, dtype) test_elu_execution<ScalarType::dtype>();
70+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
71+
#undef TEST_ENTRY
72+
}
73+
74+
TEST_F(OpEluTest, UnhandledDtypeDies) {
75+
#define TEST_ENTRY(ctype, dtype) test_integer_elu_dies<ScalarType::dtype>();
76+
ET_FORALL_INT_TYPES(TEST_ENTRY);
77+
#undef TEST_ENTRY
78+
}
79+
80+
TEST_F(OpEluTest, MismatchedOutputDtypeDies) {
81+
// Two different dtypes. This test uses two types with the same size to
82+
// demonstrate that the ScalarType itself matters, not the size of the
83+
// tensor elements.
84+
TensorFactory<ScalarType::Float> tf_float;
85+
TensorFactory<ScalarType::Double> tf_double;
86+
87+
const std::vector<int32_t> sizes = {2, 2};
88+
89+
Tensor a = tf_float.ones(sizes);
90+
91+
// Destination with a dtype different from the input.
92+
Tensor out = tf_double.zeros(sizes);
93+
94+
ET_EXPECT_KERNEL_FAILURE(context_, op_elu_out(a, 1, 1, 1, out));
95+
}

kernels/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def define_common_targets():
215215
_common_op_test("op_detach_copy_test", ["aten", "portable"])
216216
_common_op_test("op_diagonal_copy_test", ["aten", "portable"])
217217
_common_op_test("op_div_test", ["aten", "portable", "optimized"])
218+
_common_op_test("op_elu_test", ["aten", "portable"])
218219
_common_op_test("op_embedding_test", ["aten", "portable"])
219220
_common_op_test("op_empty_test", ["aten", "portable"])
220221
_common_op_test("op_eq_test", ["aten", "portable"])

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,13 @@ ATEN_OPS = (
482482
":scalar_utils",
483483
],
484484
),
485+
op_target(
486+
name = "op_elu",
487+
deps = [
488+
":scalar_utils",
489+
"//executorch/kernels/portable/cpu/util:elementwise_util",
490+
],
491+
),
485492
op_target(
486493
name = "op_embedding",
487494
deps = [

0 commit comments

Comments
 (0)