From b00ce695d731631f38b71d2ddb7285ba875fa3df Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 6 Feb 2025 08:51:04 -0800 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- .gitmodules | 3 + kernels/aten/functions.yaml | 2 + kernels/optimized/CMakeLists.txt | 1 + kernels/optimized/cpu/op_fft_r2c.cpp | 187 ++++++++++++++++++ kernels/optimized/cpu/targets.bzl | 4 + kernels/optimized/optimized-oss.yaml | 5 + kernels/optimized/optimized.yaml | 5 + kernels/test/CMakeLists.txt | 1 + kernels/test/op_fft_r2c_test.cpp | 158 +++++++++++++++ .../exec_aten/testing_util/tensor_util.cpp | 12 ++ third-party/pocketfft | 1 + 11 files changed, 379 insertions(+) create mode 100644 kernels/optimized/cpu/op_fft_r2c.cpp create mode 100644 kernels/test/op_fft_r2c_test.cpp create mode 160000 third-party/pocketfft diff --git a/.gitmodules b/.gitmodules index 1468cca6363..f7da7e771fb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -67,3 +67,6 @@ [submodule "backends/cadence/utils/FACTO"] path = backends/cadence/utils/FACTO url = https://github.com/pytorch-labs/FACTO.git +[submodule "third-party/pocketfft"] + path = third-party/pocketfft + url = https://github.com/mreineck/pocketfft diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml index 833b37cfac0..463ef0f9d32 100644 --- a/kernels/aten/functions.yaml +++ b/kernels/aten/functions.yaml @@ -6,6 +6,8 @@ - op: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out +- op: _fft_r2c.out + - op: _linalg_det.result - op: _linalg_svd.U diff --git a/kernels/optimized/CMakeLists.txt b/kernels/optimized/CMakeLists.txt index abdeeb73453..0826fe73af8 100644 --- a/kernels/optimized/CMakeLists.txt +++ b/kernels/optimized/CMakeLists.txt @@ -60,6 +60,7 @@ message("Generated files ${gen_command_sources}") list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/") add_library(optimized_kernels ${_optimized_kernels__srcs}) +target_include_directories(optimized_kernels PRIVATE "${EXECUTORCH_ROOT}/third-party/pocketfft") target_link_libraries( optimized_kernels PRIVATE executorch_core cpublas extension_threadpool ) diff --git a/kernels/optimized/cpu/op_fft_r2c.cpp b/kernels/optimized/cpu/op_fft_r2c.cpp new file mode 100644 index 00000000000..07ef46f1598 --- /dev/null +++ b/kernels/optimized/cpu/op_fft_r2c.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +namespace torch::executor::native { + +// TODO: contents of this anonymous namespace are copy/pasted from +// PyTorch core (aten/src/ATen/native/mkl/SpectralOps.cpp). Small +// portions (the parts that don't depend on Tensor) could be reused; +// refactor to enable that once we can share headers from PyTorch +// core. +namespace { +pocketfft::stride_t stride_from_tensor(const Tensor& t) { + pocketfft::stride_t stride(t.strides().begin(), t.strides().end()); + for (auto& s : stride) { + s *= t.element_size(); + } + return stride; +} + +pocketfft::shape_t shape_from_tensor(const Tensor& t) { + return pocketfft::shape_t(t.sizes().begin(), t.sizes().end()); +} + +// NOTE: The reinterpret_cast in tensor_cdata is UB, but it's what +// PyTorch core does and I'm not aware of a portable way to do this +// that doesn't rely on UB. +template +inline std::complex* tensor_cdata(Tensor& t) { + return reinterpret_cast*>( + t.data_ptr>()); +} + +template +inline const std::complex* tensor_cdata(const Tensor& t) { + return reinterpret_cast*>( + t.const_data_ptr>()); +} + +// NOTE: in particular this is in ATen/native/SpectralOpsUtils.h and +// could be shared immediately. +enum class fft_norm_mode { + none, // No normalization + by_root_n, // Divide by sqrt(signal_size) + by_n, // Divide by signal_size +}; + +// NOTE: slight fork from upstream PyTorch to use ET_KERNEL_CHECK; +// upstream with TORCH_CHECK will be fine to use once we have code +// sharing. +template +std::optional +compute_fct(KernelRuntimeContext& ctx, int64_t size, int64_t normalization) { + constexpr auto one = static_cast(1); + switch (static_cast(normalization)) { + case fft_norm_mode::none: + return one; + case fft_norm_mode::by_n: + return one / static_cast(size); + case fft_norm_mode::by_root_n: + return one / std::sqrt(static_cast(size)); + } + ET_KERNEL_CHECK_MSG( + ctx, + false, + InvalidArgument, + std::nullopt, + "Unsupported normalization type: %" PRId64, + normalization); +} + +template +std::optional compute_fct( + KernelRuntimeContext& ctx, + const Tensor& t, + IntArrayRef dim, + int64_t normalization) { + if (static_cast(normalization) == fft_norm_mode::none) { + return static_cast(1); + } + const auto& sizes = t.sizes(); + int64_t n = 1; + for (auto idx : dim) { + n *= sizes[idx]; + } + return compute_fct(ctx, n, normalization); +} + +} // namespace + +Tensor& opt_fft_r2c_out( + KernelRuntimeContext& ctx, + const Tensor& in, + IntArrayRef dim, + int64_t normalization, + bool onesided, + Tensor& out) { + auto in_sizes = in.sizes(); + ET_KERNEL_CHECK(ctx, in.dim() <= kTensorDimensionLimit, InvalidArgument, out); + + std::array out_sizes_storage; + executorch::runtime::Span out_sizes( + out_sizes_storage.data(), in_sizes.size()); + std::copy(in_sizes.begin(), in_sizes.end(), out_sizes.begin()); + ET_KERNEL_CHECK(ctx, !dim.empty(), InvalidArgument, out); + + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + + ET_KERNEL_CHECK_MSG( + ctx, + onesided, + InvalidArgument, + out, + "onesided=False is not supported yet in _fft_r2c"); + + ET_KERNEL_CHECK_MSG( + ctx, + out.scalar_type() == executorch::runtime::toComplexType(in.scalar_type()), + InvalidArgument, + out, + "the output type for _fft_r2c must be the Complex type corresponding to the input type"); + + for (auto d : dim) { + ET_KERNEL_CHECK_MSG( + ctx, + d >= 0 && d < in.dim(), + InvalidArgument, + out, + "dims must be in bounds (got %" PRId64 ")", + d); + } + + if (onesided) { + out_sizes[dim.back()] = out_sizes[dim.back()] / 2 + 1; + } + ET_KERNEL_CHECK_MSG( + ctx, + resize_tensor( + out, + executorch::runtime::ArrayRef( + out_sizes.data(), out_sizes.size())) == Error::Ok, + InvalidArgument, + out, + "Failed to resize output tensor (last dim %d).", + out_sizes[dim.back()]); + + pocketfft::shape_t axes(dim.begin(), dim.end()); + auto in_shape = shape_from_tensor(in); + // TODO: if arbitrary strides are a possibility, we need to validate + // these, because pocketfft README says "Strides that lead to + // multiple accesses of the same memory address are not allowed." + auto in_stride = stride_from_tensor(in); + auto out_stride = stride_from_tensor(out); + // NOTE: as of this writing, upstream PyTorch only supports + // float/double, so we follow suit. + ET_SWITCH_FLOAT_TYPES(in.scalar_type(), ctx, "_fft_r2c.out", CTYPE_IN, [&] { + auto fct = compute_fct(ctx, in, dim, normalization); + if (!fct) { + // Check failed, just bail out of the lambda. + return; + } + pocketfft::r2c( + in_shape, + in_stride, + out_stride, + axes, + true, + in.const_data_ptr(), + tensor_cdata(out), + *fct); + + // TODO: fill with conjugate symmetry if not onesided; see + // ATen/native/mkl/SpectralOps.cpp + }); + return out; +} +} // namespace torch::executor::native diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl index d97e1eb5122..535555b5dcc 100644 --- a/kernels/optimized/cpu/targets.bzl +++ b/kernels/optimized/cpu/targets.bzl @@ -25,6 +25,10 @@ _OPTIMIZED_ATEN_OPS = ( ], ), op_target(name = "op_exp"), + op_target( + name = "op_fft_r2c", + deps = [] if runtime.is_oss else ["fbsource//third-party/pocketfft:pocketfft"], + ), op_target(name = "op_sigmoid"), op_target( name = "op_gelu", diff --git a/kernels/optimized/optimized-oss.yaml b/kernels/optimized/optimized-oss.yaml index 52262e2dd53..37a2730f928 100644 --- a/kernels/optimized/optimized-oss.yaml +++ b/kernels/optimized/optimized-oss.yaml @@ -5,6 +5,11 @@ # log_softmax, due to the OSS build not currently including sleef. # TODO (T183193812) +- op: _fft_r2c.out + kernels: + - arg_meta: null + kernel_name: torch::executor::opt_fft_r2c_out + - op: add.out kernels: - arg_meta: null diff --git a/kernels/optimized/optimized.yaml b/kernels/optimized/optimized.yaml index ea07126a3b9..fd5143b1511 100644 --- a/kernels/optimized/optimized.yaml +++ b/kernels/optimized/optimized.yaml @@ -2,6 +2,11 @@ # # This yaml file contains operators that have optimized kernels available. +- op: _fft_r2c.out + kernels: + - arg_meta: null + kernel_name: torch::executor::opt_fft_r2c_out + - op: _log_softmax.out kernels: - arg_meta: null diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt index 1bd63a2a5fe..32e53ec8ff1 100644 --- a/kernels/test/CMakeLists.txt +++ b/kernels/test/CMakeLists.txt @@ -265,6 +265,7 @@ set(_optimized_kernels_test_sources "op_bmm_test.cpp" "op_div_test.cpp" "op_exp_test.cpp" + "op_fft_r2c_test.cpp" "op_gelu_test.cpp" "op_le_test.cpp" "op_log_softmax_test.cpp" diff --git a/kernels/test/op_fft_r2c_test.cpp b/kernels/test/op_fft_r2c_test.cpp new file mode 100644 index 00000000000..629087da491 --- /dev/null +++ b/kernels/test/op_fft_r2c_test.cpp @@ -0,0 +1,158 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include // Declares the operator +#include +#include +#include +#include +#include +#include + +#include + +using executorch::aten::IntArrayRef; +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::runtime::testing::TensorFactory; + +class OpFftR2cOutTest : public OperatorTest { + protected: + Tensor& op_fft_r2c_out( + const Tensor& in, + IntArrayRef dim, + int64_t normalization, + bool onesided, + Tensor& out) { + return torch::executor::aten::_fft_r2c_outf( + context_, in, dim, normalization, onesided, out); + } + + template < + class CTYPE, + executorch::aten::ScalarType DTYPE, + bool expect_failure = false> + void test_dtype(int64_t norm, int64_t dim = 1, bool onesided = true) { + TensorFactory tf; + constexpr auto DTYPE_OUT = executorch::runtime::toComplexType(DTYPE); + TensorFactory tf_out; + + using CTYPE_OUT = + typename executorch::runtime::ScalarTypeToCppType::type; + + Tensor in = tf.make({2, 4}, {0, 1, 2, 3, 0, 1, 2, 3}); + Tensor out = tf_out.full({2, 3}, CTYPE_OUT{0, 0}); + + op_fft_r2c_out(in, {dim}, norm, onesided, out); + + double norm_factor = 1; + if (norm == 1) { + norm_factor = 2; + } else if (norm == 2) { + norm_factor = 4; + } + std::vector expected_data = { + CTYPE_OUT{6, 0}, + CTYPE_OUT{-2, 2}, + CTYPE_OUT{-2, 0}, + CTYPE_OUT{6, 0}, + CTYPE_OUT{-2, 2}, + CTYPE_OUT{-2, 0}}; + for (auto& elem : expected_data) { + elem.real_ /= norm_factor; + elem.imag_ /= norm_factor; + } + Tensor expected = tf_out.make({2, 3}, expected_data); + + if (!expect_failure) { + EXPECT_TENSOR_CLOSE(out, expected); + } + } + + template + void test_dtype_multiple_axes(bool onesided = true) { + TensorFactory tf; + constexpr auto DTYPE_OUT = executorch::runtime::toComplexType(DTYPE); + TensorFactory tf_out; + + using CTYPE_OUT = + typename executorch::runtime::ScalarTypeToCppType::type; + + Tensor in = + tf.make({4, 4}, {0, 1, 2, 3, 3, 2, 1, 0, 2, 3, 0, 1, 1, 2, 3, 0}); + Tensor out = tf_out.full({4, 3}, CTYPE_OUT{0, 0}); + + std::array dim = {0, 1}; + op_fft_r2c_out(in, dim, 0, onesided, out); + + std::vector expected_data = { + CTYPE_OUT{24, 0}, + CTYPE_OUT{0, -4}, + CTYPE_OUT{0, 0}, + + CTYPE_OUT{0, 0}, + CTYPE_OUT{-4, 0}, + CTYPE_OUT{0, 0}, + + CTYPE_OUT{0, 0}, + CTYPE_OUT{0, 4}, + CTYPE_OUT{-8, 0}, + + CTYPE_OUT{0, 0}, + CTYPE_OUT{-4, 8}, + CTYPE_OUT{0, 0}, + }; + Tensor expected = tf_out.make({4, 3}, expected_data); + + EXPECT_TENSOR_CLOSE(out, expected); + } +}; + +TEST_F(OpFftR2cOutTest, AllDtypesSupported) { +#define TEST_ENTRY(ctype, dtype) \ + test_dtype(0); \ + test_dtype(1); \ + test_dtype(2); + ET_FORALL_FLOAT_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +} + +TEST_F(OpFftR2cOutTest, MultipleDims) { +#define TEST_ENTRY(ctype, dtype) \ + test_dtype_multiple_axes(); + ET_FORALL_FLOAT_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +} + +TEST_F(OpFftR2cOutTest, InvalidNorm) { + auto invalid_norm = [this](int64_t norm) { + test_dtype(norm); + }; + ET_EXPECT_KERNEL_FAILURE(context_, invalid_norm(3)); + ET_EXPECT_KERNEL_FAILURE(context_, invalid_norm(4)); + ET_EXPECT_KERNEL_FAILURE(context_, invalid_norm(-1)); + ET_EXPECT_KERNEL_FAILURE(context_, invalid_norm(9999999)); +} + +TEST_F(OpFftR2cOutTest, InvalidDim) { + auto negative_dim = [this]() { + test_dtype(0, -1); + test_dtype(0, 3); + test_dtype(0, 9001); + }; + ET_EXPECT_KERNEL_FAILURE(context_, negative_dim()); +} + +// TODO: support this and patch test accordingly! +TEST_F(OpFftR2cOutTest, TwoSidedIsNotSupported) { + auto twosided = [this]() { + test_dtype( + 0, 1, false); + }; + ET_EXPECT_KERNEL_FAILURE(context_, twosided()); +} diff --git a/runtime/core/exec_aten/testing_util/tensor_util.cpp b/runtime/core/exec_aten/testing_util/tensor_util.cpp index d1ea069b6e8..56d7382aab2 100644 --- a/runtime/core/exec_aten/testing_util/tensor_util.cpp +++ b/runtime/core/exec_aten/testing_util/tensor_util.cpp @@ -254,6 +254,17 @@ std::ostream& print_data(std::ostream& os, const T* data, size_t numel) { return os; } +template +std::ostream& +print_data(std::ostream& os, const etensor::complex* data, size_t numel) { + for (auto i = 0; i < numel; i++) { + os << data[i].real_ << " + " << data[i].imag_ << "j"; + if (i < numel - 1) { + os << ", "; + } + } + return os; +} /** * Prints the elements of `data` to the stream as comma-separated strings. * @@ -297,6 +308,7 @@ std::ostream& operator<<(std::ostream& os, const Tensor& t) { switch (t.scalar_type()) { ET_FORALL_REAL_TYPES_AND3(Half, Bool, BFloat16, PRINT_CASE) + ET_FORALL_COMPLEX_TYPES(PRINT_CASE) default: ET_CHECK_MSG( false, diff --git a/third-party/pocketfft b/third-party/pocketfft new file mode 160000 index 00000000000..0fa0ef591e3 --- /dev/null +++ b/third-party/pocketfft @@ -0,0 +1 @@ +Subproject commit 0fa0ef591e38c2758e3184c6c23e497b9f732ffa From e23ad5140d69b0cdab802a45ab5851a1200e2802 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 6 Feb 2025 10:11:59 -0800 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- kernels/optimized/cpu/op_fft_r2c.cpp | 2 ++ kernels/optimized/cpu/targets.bzl | 2 +- kernels/test/op_fft_r2c_test.cpp | 13 +++++++++++++ kernels/test/targets.bzl | 1 + 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/kernels/optimized/cpu/op_fft_r2c.cpp b/kernels/optimized/cpu/op_fft_r2c.cpp index 07ef46f1598..45d3d9acb42 100644 --- a/kernels/optimized/cpu/op_fft_r2c.cpp +++ b/kernels/optimized/cpu/op_fft_r2c.cpp @@ -11,6 +11,8 @@ #include +#include + namespace torch::executor::native { // TODO: contents of this anonymous namespace are copy/pasted from diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl index 535555b5dcc..8fcaf210246 100644 --- a/kernels/optimized/cpu/targets.bzl +++ b/kernels/optimized/cpu/targets.bzl @@ -27,7 +27,7 @@ _OPTIMIZED_ATEN_OPS = ( op_target(name = "op_exp"), op_target( name = "op_fft_r2c", - deps = [] if runtime.is_oss else ["fbsource//third-party/pocketfft:pocketfft"], + deps = [] if runtime.is_oss else ["fbsource//third-party/pocket_fft:pocketfft"], ), op_target(name = "op_sigmoid"), op_target( diff --git a/kernels/test/op_fft_r2c_test.cpp b/kernels/test/op_fft_r2c_test.cpp index 629087da491..8730053bdc0 100644 --- a/kernels/test/op_fft_r2c_test.cpp +++ b/kernels/test/op_fft_r2c_test.cpp @@ -8,6 +8,7 @@ #include // Declares the operator #include +#include #include #include #include @@ -130,6 +131,10 @@ TEST_F(OpFftR2cOutTest, MultipleDims) { } TEST_F(OpFftR2cOutTest, InvalidNorm) { + if (torch::executor::testing::SupportedFeatures::get()->is_aten) { + GTEST_SKIP() << "ATen MKL path does not validate norm"; + return; + } auto invalid_norm = [this](int64_t norm) { test_dtype(norm); }; @@ -140,6 +145,10 @@ TEST_F(OpFftR2cOutTest, InvalidNorm) { } TEST_F(OpFftR2cOutTest, InvalidDim) { + if (torch::executor::testing::SupportedFeatures::get()->is_aten) { + GTEST_SKIP() << "ATen fails UBSAN"; + return; + } auto negative_dim = [this]() { test_dtype(0, -1); test_dtype(0, 3); @@ -150,6 +159,10 @@ TEST_F(OpFftR2cOutTest, InvalidDim) { // TODO: support this and patch test accordingly! TEST_F(OpFftR2cOutTest, TwoSidedIsNotSupported) { + if (torch::executor::testing::SupportedFeatures::get()->is_aten) { + GTEST_SKIP() << "ATen supports two-sided"; + return; + } auto twosided = [this]() { test_dtype( 0, 1, false); diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index db14173050b..60c59374174 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -244,6 +244,7 @@ def define_common_targets(): _common_op_test("op_exp_test", ["aten", "portable", "optimized"]) _common_op_test("op_expand_copy_test", ["aten", "portable"]) _common_op_test("op_expm1_test", ["aten", "portable"]) + _common_op_test("op_fft_r2c_test", ["aten", "optimized"]) _common_op_test("op_fill_test", ["aten", "portable"]) _common_op_test("op_flip_test", ["aten", "portable"]) _common_op_test("op_floor_divide_test", ["aten", "portable"])