diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 810026d034c0..05f515e2e783 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -10,7 +10,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Calculate the result of moe by summing up the partial results // from all selected experts. - m.def("moe_sum(Tensor! input, Tensor output) -> ()"); + m.def("moe_sum(Tensor input, Tensor! output) -> ()"); m.impl("moe_sum", torch::kCUDA, &moe_sum); // Aligning the number of tokens to be processed by each expert such diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 43ddc79fcb81..9a8ac242af79 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -575,3 +575,21 @@ def test_moe_align_block_size_opcheck(): opcheck(torch.ops._moe_C.moe_align_block_size, (topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad)) + + +@pytest.mark.parametrize("m", [1, 33, 222, 1024 * 128]) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("dtype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): + input = torch.randn((m, topk, k), device="cuda", dtype=dtype) + actual = torch.empty((m, k), device="cuda", dtype=dtype) + + expected = input.sum(dim=1) + torch.ops._moe_C.moe_sum(input, actual) + + torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0) + + opcheck(torch.ops._moe_C.moe_sum, (input, actual))