From e9a4489a1639dfd26766c05ac46d49c978396138 Mon Sep 17 00:00:00 2001
From: Scott Wolchok <swolchok@meta.com>
Date: Fri, 21 Mar 2025 17:29:56 -0700
Subject: [PATCH] Update

[ghstack-poisoned]
---
 kernels/aten/functions.yaml                   |  2 +
 kernels/portable/cpu/op_elu.cpp               | 62 ++++++++++++
 kernels/portable/functions.yaml               |  5 +
 kernels/test/CMakeLists.txt                   |  1 +
 kernels/test/op_elu_test.cpp                  | 95 +++++++++++++++++++
 kernels/test/targets.bzl                      |  1 +
 .../kernels/portable/op_registration_util.bzl |  7 ++
 7 files changed, 173 insertions(+)
 create mode 100644 kernels/portable/cpu/op_elu.cpp
 create mode 100644 kernels/test/op_elu_test.cpp

diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml
index 7069f9140ab..a8fa6611478 100644
--- a/kernels/aten/functions.yaml
+++ b/kernels/aten/functions.yaml
@@ -141,6 +141,8 @@
 
 - op: div.out_mode
 
+- op: elu.out
+
 - op: embedding.out
 
 - op: empty.out
diff --git a/kernels/portable/cpu/op_elu.cpp b/kernels/portable/cpu/op_elu.cpp
new file mode 100644
index 00000000000..d4846fb1bfb
--- /dev/null
+++ b/kernels/portable/cpu/op_elu.cpp
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cmath>
+#include <type_traits>
+
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch::executor::native {
+
+Tensor& elu_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Scalar& alpha,
+    const Scalar& scale,
+    const Scalar& input_scale,
+    Tensor& out) {
+  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(ctx, tensor_is_floating_type(in), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
+
+  static constexpr const char op_name[] = "elu.out";
+  ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() {
+    using MathT = std::
+        conditional_t<c10::is_reduced_floating_point_v<CTYPE>, float, CTYPE>;
+    MathT math_alpha = 0;
+    MathT math_scale = 0;
+    MathT math_input_scale = 0;
+    ET_EXTRACT_SCALAR(alpha, math_alpha);
+    ET_EXTRACT_SCALAR(scale, math_scale);
+    ET_EXTRACT_SCALAR(input_scale, math_input_scale);
+    const auto negcoef = math_alpha * math_scale;
+    utils::apply_unitensor_elementwise_fn<CTYPE, op_name>(
+        [negcoef, math_scale, math_input_scale](auto x) {
+          return MathT(x) <= MathT(0)
+              ? std::expm1(MathT(x) * math_input_scale) * negcoef
+              : MathT(x) * math_scale;
+        },
+        ctx,
+        in,
+        utils::SupportedTensorDtypes::FLOATHBF16,
+        out,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+  });
+  return out;
+}
+
+} // namespace torch::executor::native
diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml
index 29dfe8b1a0c..5e45a210a70 100644
--- a/kernels/portable/functions.yaml
+++ b/kernels/portable/functions.yaml
@@ -329,6 +329,11 @@
     - arg_meta: null
       kernel_name: torch::executor::eq_tensor_out
 
+- op: elu.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::elu_out
+
 - op: erf.out
   kernels:
     - arg_meta: null
diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt
index b9f48f0c9a1..42578acbedd 100644
--- a/kernels/test/CMakeLists.txt
+++ b/kernels/test/CMakeLists.txt
@@ -135,6 +135,7 @@ set(all_test_sources
     "op_detach_copy_test.cpp"
     "op_diagonal_copy_test.cpp"
     "op_div_test.cpp"
+    "op_elu_test.cpp"
     "op_embedding_test.cpp"
     "op_empty_test.cpp"
     "op_eq_test.cpp"
diff --git a/kernels/test/op_elu_test.cpp b/kernels/test/op_elu_test.cpp
new file mode 100644
index 00000000000..73ee8ac31a7
--- /dev/null
+++ b/kernels/test/op_elu_test.cpp
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/kernels/test/supported_features.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+
+#include <gtest/gtest.h>
+
+using executorch::aten::Scalar;
+using executorch::aten::ScalarType;
+using executorch::aten::string_view;
+using executorch::aten::Tensor;
+using torch::executor::testing::TensorFactory;
+
+class OpEluTest : public OperatorTest {
+ protected:
+  Tensor& op_elu_out(
+      const Tensor& self,
+      const Scalar& alpha,
+      const Scalar& scale,
+      const Scalar& input_scale,
+      Tensor& out) {
+    return torch::executor::aten::elu_outf(
+        context_, self, alpha, scale, input_scale, out);
+  }
+
+  template <ScalarType DTYPE>
+  void test_elu_execution() {
+    TensorFactory<DTYPE> tf;
+
+    const std::vector<int32_t> sizes = {3, 2};
+
+    Tensor in = tf.make(sizes, /*data=*/{-0.125, -0.25, -1, 0, 1.25, 100});
+
+    Tensor out = tf.zeros(sizes);
+
+    // Run full gelu.
+    op_elu_out(in, 1.25, 1, 1, out);
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(
+        out,
+        tf.make(
+            sizes,
+            /*data=*/
+            {-0.146879, -0.276499, -0.790151, 0, 1.25, 100}));
+  }
+
+  template <ScalarType DTYPE>
+  void test_integer_elu_dies() {
+    TensorFactory<DTYPE> tf;
+
+    Tensor in = tf.ones({1});
+    Tensor out = tf.ones({1});
+    ET_EXPECT_KERNEL_FAILURE(context_, op_elu_out(in, 1, 1, 1, out));
+  }
+};
+
+TEST_F(OpEluTest, Basic) {
+#define TEST_ENTRY(ctype, dtype) test_elu_execution<ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpEluTest, UnhandledDtypeDies) {
+#define TEST_ENTRY(ctype, dtype) test_integer_elu_dies<ScalarType::dtype>();
+  ET_FORALL_INT_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpEluTest, MismatchedOutputDtypeDies) {
+  // Two different dtypes. This test uses two types with the same size to
+  // demonstrate that the ScalarType itself matters, not the size of the
+  // tensor elements.
+  TensorFactory<ScalarType::Float> tf_float;
+  TensorFactory<ScalarType::Double> tf_double;
+
+  const std::vector<int32_t> sizes = {2, 2};
+
+  Tensor a = tf_float.ones(sizes);
+
+  // Destination with a dtype different from the input.
+  Tensor out = tf_double.zeros(sizes);
+
+  ET_EXPECT_KERNEL_FAILURE(context_, op_elu_out(a, 1, 1, 1, out));
+}
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index 18ab0ac2e28..3824551a46b 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -215,6 +215,7 @@ def define_common_targets():
     _common_op_test("op_detach_copy_test", ["aten", "portable"])
     _common_op_test("op_diagonal_copy_test", ["aten", "portable"])
     _common_op_test("op_div_test", ["aten", "portable", "optimized"])
+    _common_op_test("op_elu_test", ["aten", "portable"])
     _common_op_test("op_embedding_test", ["aten", "portable"])
     _common_op_test("op_empty_test", ["aten", "portable"])
     _common_op_test("op_eq_test", ["aten", "portable"])
diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
index b56413b92f4..a1ffdc1eed3 100644
--- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -482,6 +482,13 @@ ATEN_OPS = (
             ":scalar_utils",
         ],
     ),
+    op_target(
+        name = "op_elu",
+        deps = [
+            ":scalar_utils",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
+        ],
+    ),
     op_target(
         name = "op_embedding",
         deps = [