Skip to content

Complex Support: bmm #10197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 16, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions kernels/optimized/cpu/op_bmm.cpp
Original file line number Diff line number Diff line change
@@ -6,9 +6,9 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/kernel/kernel_includes.h>

#include <executorch/kernels/optimized/blas/CPUBlas.h>
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

// Performs a batch matrix-matrix product of matrices stored in input and mat2.

@@ -136,33 +136,32 @@ Error resize_out_tensor(const Tensor& self, const Tensor& mat2, Tensor& out) {

// bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
Tensor& opt_bmm_out(
KernelRuntimeContext& context,
KernelRuntimeContext& ctx,
const Tensor& self,
const Tensor& mat2,
Tensor& out) {
(void)context;
(void)ctx;

ET_KERNEL_CHECK(
context,
ctx,
resize_out_tensor(self, mat2, out) == Error::Ok,
InvalidArgument,
out);
ET_KERNEL_CHECK(
context, check_bmm_out_args(self, mat2, out), InvalidArgument, out);

#define BMM_TENSOR(ctype, dtype) \
case ScalarType::dtype: \
bmm_kernel<ctype>(self, mat2, out); \
break;

auto scalar_type = self.scalar_type();
switch (scalar_type) {
ET_FORALL_REAL_TYPES_AND(Half, BMM_TENSOR)
default:
ET_CHECK_MSG(
false, "Unhandled dtype %" PRId8, static_cast<int8_t>(scalar_type));
ctx, check_bmm_out_args(self, mat2, out), InvalidArgument, out);

constexpr auto name = "bmm.out";
auto self_type = self.scalar_type();

if (executorch::runtime::isComplexType(self_type)) {
ET_SWITCH_COMPLEXH_TYPES(self_type, ctx, name, CTYPE, [&]() {
internal::bmm_out_impl<CTYPE>(self, mat2, out);
});
} else {
ET_SWITCH_REALH_TYPES(self_type, ctx, name, CTYPE, [&]() {
bmm_kernel<CTYPE>(self, mat2, out);
});
}
#undef BMM_TENSOR

return out;
}
1 change: 1 addition & 0 deletions kernels/optimized/cpu/targets.bzl
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@ _OPTIMIZED_ATEN_OPS = (
name = "op_bmm",
deps = [
"//executorch/kernels/optimized:libblas",
"//executorch/kernels/portable/cpu/util:matmul_ops_util",
],
),
op_target(
30 changes: 11 additions & 19 deletions kernels/portable/cpu/op_bmm.cpp
Original file line number Diff line number Diff line change
@@ -7,7 +7,6 @@
*/

#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
#include <executorch/kernels/portable/cpu/vec_ops.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
@@ -37,26 +36,19 @@ Tensor& bmm_out(
InvalidArgument,
out);

ET_SWITCH_REAL_TYPES_AND(
Half, in.scalar_type(), ctx, "bmm.out", CTYPE, [&]() {
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
const CTYPE* mat2_data = mat2.const_data_ptr<CTYPE>();
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
constexpr auto name = "bmm.out";

int64_t batch_size = in.size(0);
int64_t m = in.size(1);
int64_t n = in.size(2);
int64_t p = mat2.size(2);
auto in_type = in.scalar_type();

for (int i = 0; i < batch_size; ++i) {
const CTYPE* in_data_offset = in_data + i * m * n;
const CTYPE* mat2_data_offset = mat2_data + i * n * p;
CTYPE* out_data_offset = out_data + i * m * p;

vec_matmul<CTYPE>(
out_data_offset, in_data_offset, mat2_data_offset, m, n, p);
}
});
if (executorch::runtime::isComplexType(in_type)) {
ET_SWITCH_COMPLEXH_TYPES(in_type, ctx, name, CTYPE, [&]() {
internal::bmm_out_impl<CTYPE>(in, mat2, out);
});
} else {
ET_SWITCH_REALH_TYPES(in_type, ctx, name, CTYPE, [&]() {
internal::bmm_out_impl<CTYPE>(in, mat2, out);
});
}

return out;
}
31 changes: 31 additions & 0 deletions kernels/portable/cpu/util/matmul_ops_util.h
Original file line number Diff line number Diff line change
@@ -45,5 +45,36 @@ void get_linear_out_target_size(
Tensor::SizesType* out_sizes,
size_t* out_ndim);

namespace internal {

/**
 * Naive batched matrix multiplication: out[b] = in[b] @ mat2[b].
 *
 * Shapes (established by the indexing below): `in` is (B, m, n),
 * `mat2` is (B, n, p), and `out` is (B, m, p). The caller is expected
 * to have validated shapes/dtypes and resized `out` already — this
 * helper performs no checks of its own.
 *
 * @tparam CTYPE element type (works for real and complex types);
 *         accumulation happens directly in CTYPE.
 */
template <typename CTYPE>
void bmm_out_impl(const Tensor& in, const Tensor& mat2, Tensor& out) {
  const CTYPE* in_data = in.const_data_ptr<CTYPE>();
  const CTYPE* mat2_data = mat2.const_data_ptr<CTYPE>();
  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

  int64_t batch_size = in.size(0);
  int64_t m = in.size(1);
  int64_t n = in.size(2);
  int64_t p = mat2.size(2);

  // Use c10::irange for the batch loop too: the previous `int b` loop
  // silently narrowed the int64_t batch count and was inconsistent with
  // the inner loops below.
  for (const auto b : c10::irange(batch_size)) {
    // Per-batch base pointers into the contiguous (B, rows, cols) buffers.
    const CTYPE* in_data_offset = in_data + b * m * n;
    const CTYPE* mat2_data_offset = mat2_data + b * n * p;
    CTYPE* out_data_offset = out_data + b * m * p;

    for (const auto i : c10::irange(m)) {
      for (const auto j : c10::irange(p)) {
        CTYPE sum = static_cast<CTYPE>(0.0);
        for (const auto k : c10::irange(n)) {
          sum += in_data_offset[i * n + k] * mat2_data_offset[k * p + j];
        }
        out_data_offset[i * p + j] = sum;
      }
    }
  }
}

} // namespace internal
} // namespace executor
} // namespace torch
67 changes: 66 additions & 1 deletion kernels/test/op_bmm_test.cpp
Original file line number Diff line number Diff line change
@@ -43,6 +43,61 @@ class OpBmmOutTest : public OperatorTest {

EXPECT_TENSOR_EQ(out, expected);
}

template <typename CTYPE, ScalarType DTYPE>
void test_complex_dtype() {
TensorFactory<DTYPE> tf;
Tensor x = tf.make(
{2, 2, 3},
{CTYPE(1, 1),
CTYPE(2, 2),
CTYPE(3, 3),
CTYPE(4, 4),
CTYPE(5, 5),
CTYPE(6, 6),
CTYPE(7, 7),
CTYPE(8, 8),
CTYPE(9, 9),
CTYPE(10, 10),
CTYPE(11, 11),
CTYPE(12, 12)});
Tensor y = tf.make(
{2, 3, 2},
{CTYPE(2, 1),
CTYPE(4, 2),
CTYPE(6, 3),
CTYPE(8, 4),
CTYPE(10, 5),
CTYPE(12, 6),
CTYPE(14, 7),
CTYPE(16, 8),
CTYPE(18, 9),
CTYPE(20, 10),
CTYPE(22, 11),
CTYPE(24, 12)});
Tensor out = tf.make(
{2, 2, 2},
{CTYPE(0, 0),
CTYPE(0, 0),
CTYPE(0, 0),
CTYPE(0, 0),
CTYPE(0, 0),
CTYPE(0, 0),
CTYPE(0, 0),
CTYPE(0, 0)});
Tensor expected = tf.make(
{2, 2, 2},
{CTYPE(22, 66),
CTYPE(28, 84),
CTYPE(49, 147),
CTYPE(64, 192),
CTYPE(220, 660),
CTYPE(244, 732),
CTYPE(301, 903),
CTYPE(334, 1002)});
op_bmm_out(x, y, out);
EXPECT_TENSOR_CLOSE(out, expected);
}
};

TEST_F(OpBmmOutTest, OutputDim) {
@@ -132,7 +187,7 @@ TEST_F(OpBmmOutTest, OutputDimFloat) {

/// A generic smoke test that works for any dtype that supports ones() and
/// zeros().
TEST_F(OpBmmOutTest, AllDtypesSupported) {
TEST_F(OpBmmOutTest, AllRealDtypesSupported) {
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
ET_FORALL_REAL_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
@@ -141,6 +196,16 @@ TEST_F(OpBmmOutTest, AllDtypesSupported) {
// for those types.
}

TEST_F(OpBmmOutTest, AllComplexDtypesSupported) {
#define RUN_COMPLEX_TEST(ctype, dtype) \
  test_complex_dtype<ctype, ScalarType::dtype>();
  // The portable/optimized kernels dispatch over the COMPLEXH subset,
  // while the ATen backend covers the full complex type list.
  if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
    ET_FORALL_COMPLEXH_TYPES(RUN_COMPLEX_TEST);
  } else {
    ET_FORALL_COMPLEX_TYPES(RUN_COMPLEX_TEST);
  }
#undef RUN_COMPLEX_TEST
}

TEST_F(OpBmmOutTest, EmptyInputWithEmptyOutTensorPasses) {
TensorFactory<ScalarType::Int> tf;

Original file line number Diff line number Diff line change
@@ -372,7 +372,6 @@ ATEN_OPS = (
name = "op_bmm",
deps = [
"//executorch/kernels/portable/cpu/util:matmul_ops_util",
":vec_ops",
],
),
op_target(