diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 474745f9481..ce420901e31 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -29,6 +29,7 @@ (224, 1024, 1536), (224, 3072, 1024), (224, 3072, 1536), + (1024 * 128, 1024, 1024), ] vllm_config = VllmConfig(parallel_config=ParallelConfig( diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 7238813a299..bed374cf4d5 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -15,7 +15,8 @@ from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, modular_triton_fused_moe) from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( @@ -76,6 +77,13 @@ def test_fused_moe( else: e_map = None + m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=None) + with set_current_vllm_config(vllm_config): torch_output = torch_moe(a, w1, w2, score, topk, e_map) iterative_output = iterative_moe(a, @@ -103,7 +111,20 @@ def test_fused_moe( expert_map=e_map, renormalize=False) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + m_triton_output = m_fused_moe(a, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=e, + expert_map=e_map) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(m_triton_output, + torch_output, + atol=2e-2, + rtol=0) torch.testing.assert_close(iterative_output, torch_output, atol=2e-2, diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index ef3e6adcfa3..d90202dfcb3 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + import pytest import torch -from tests.pplx_utils import ProcessGroupInfo, parallel_launch from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul @@ -14,6 +15,8 @@ FusedMoEModularKernel) from vllm.platforms import current_platform +from .deepep_utils import ProcessGroupInfo, parallel_launch + try: from pplx_kernels import AllToAll from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, @@ -64,6 +67,7 @@ def pplx_cutlass_moe( out_dtype, per_act_token: bool, per_out_ch: bool, + group_name: Optional[str], ): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( PplxPrepareAndFinalize) @@ -84,7 +88,7 @@ def pplx_cutlass_moe( else: scale_elems = (hidden_dim + block_size - 1) // block_size - ata = AllToAll.internode( + args = dict( max_num_tokens=max_num_tokens, num_experts=num_experts, experts_per_token=topk, @@ -96,6 +100,12 @@ def pplx_cutlass_moe( hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize, ) + if group_name is None: + ata = AllToAll.internode(**args) + else: + args["group_name"] = group_name + ata = 
AllToAll.intranode(**args) + w1 = w1.to(device) w2 = w2.to(device) w1_scale = w1_scale.to(device) @@ -113,7 +123,10 @@ def pplx_cutlass_moe( ) experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size, - out_dtype, per_act_token, per_out_ch) + out_dtype, + per_act_token, + per_out_ch, + use_batched_format=True) fused_cutlass_experts = FusedMoEModularKernel( prepare_finalize, @@ -184,11 +197,17 @@ def _pplx_moe( w2_full: torch.Tensor, per_act_token: bool, per_out_ch: bool, + use_internode: bool, ): - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) - nvshmem_init(uid, pgi.rank, pgi.world_size) + if use_internode: + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + else: + group_ranks = list(range(pgi.world_size)) + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") + group_name = cpu_group.group_name with set_current_vllm_config(vllm_config): torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights, @@ -196,7 +215,7 @@ def _pplx_moe( pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale, w2_scale, topk_weights, topk_ids, a1_scale, out_dtype, per_act_token, - per_out_ch) + per_out_ch, group_name) torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) @@ -207,7 +226,8 @@ def _pplx_moe( torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0) - nvshmem_finalize() + if use_internode: + nvshmem_finalize() @pytest.mark.parametrize("m", [2, 224]) @@ -218,6 +238,7 @@ def _pplx_moe( @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) +@pytest.mark.parametrize("use_internode", [False]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), @@ -232,6 +253,7 @@ def test_cutlass_moe_pplx( per_act_token: bool, per_out_ch: bool, world_dp_size: tuple[int, int], + use_internode: bool, ): current_platform.seed_everything(7) @@ -284,4 +306,5 @@ def test_cutlass_moe_pplx( parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, a_scale1, - dtype, a, w1_d, w2_d, per_act_token, per_out_ch) + dtype, a, w1_d, w2_d, per_act_token, per_out_ch, + use_internode) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 0b48bbef6ce..2d6a8f39cec 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -18,7 +18,6 @@ except ImportError: has_pplx = False -from tests.pplx_utils import ProcessGroupInfo, parallel_launch from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import override_config @@ -30,6 +29,8 @@ FusedMoEModularKernel) from vllm.platforms import current_platform +from .deepep_utils import ProcessGroupInfo, parallel_launch + requires_pplx = pytest.mark.skipif( not has_pplx, reason="Requires PPLX kernels", @@ -153,7 +154,10 @@ def batched_moe( num_experts = w1.shape[0] fused_experts = FusedMoEModularKernel( - BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0), + BatchedPrepareAndFinalize(max_num_tokens=a.shape[0], + world_size=1, + dp_size=1, + rank=0), 
BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1)) return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) @@ -229,9 +233,15 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: return t[(r * chunk):(r + 1) * chunk] -def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, - topk_weight: torch.Tensor, topk_ids: torch.Tensor, - num_experts: int) -> torch.Tensor: +def pplx_prepare_finalize( + pgi: ProcessGroupInfo, + dp_size: int, + a: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + group_name: Optional[str], +) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( PplxPrepareAndFinalize) @@ -245,7 +255,7 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, world_size = pgi.world_size max_num_tokens = rank_chunk(num_tokens, 0, world_size) - ata = AllToAll.internode( + args = dict( max_num_tokens=max_num_tokens, num_experts=num_experts, experts_per_token=topk, @@ -259,6 +269,12 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, torch.float32.itemsize)), ) + if group_name is None: + ata = AllToAll.internode(**args) + else: + args["group_name"] = group_name + ata = AllToAll.intranode(**args) + topk_ids = topk_ids.to(dtype=torch.uint32) prepare_finalize = PplxPrepareAndFinalize( @@ -318,11 +334,19 @@ def _pplx_prepare_finalize( score: torch.Tensor, topk: torch.Tensor, num_experts: int, + use_internode: bool, ): - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) - nvshmem_init(uid, pgi.rank, pgi.world_size) + if use_internode: + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + group_name = None + else: + group_ranks = list(range(pgi.world_size)) + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") + group_name = cpu_group.group_name + device = pgi.device topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) @@ -335,14 +359,15 @@ def _pplx_prepare_finalize( a.dtype) pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids, - num_experts) + num_experts, group_name) torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) - nvshmem_finalize() + if use_internode: + nvshmem_finalize() # TODO (bnell): this test point does not work for odd M due to how the test is @@ -353,6 +378,7 @@ def _pplx_prepare_finalize( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@pytest.mark.parametrize("use_internode", [False]) @requires_pplx def test_pplx_prepare_finalize( mnk: tuple[int, int, int], @@ -360,6 +386,7 @@ def test_pplx_prepare_finalize( topk: int, dtype: torch.dtype, world_dp_size: tuple[int, int], + use_internode: bool, ): current_platform.seed_everything(7) m, n, k = mnk @@ -369,10 +396,11 @@ def test_pplx_prepare_finalize( score = torch.randn((m, e), device=device, dtype=dtype) parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score, - topk, e) + topk, e, use_internode) def pplx_moe( + group_name: Optional[str], rank: int, world_size: int, dp_size: int, @@ -394,7 +422,7 @@ def pplx_moe( topk = topk_ids.shape[1] max_num_tokens = 
rank_chunk(a.shape[0], 0, world_size) - ata = AllToAll.internode( + args = dict( max_num_tokens=max_num_tokens, num_experts=num_experts, experts_per_token=topk, @@ -408,6 +436,12 @@ def pplx_moe( torch.float32.itemsize)), ) + if group_name is None: + ata = AllToAll.internode(**args) + else: + args["group_name"] = group_name + ata = AllToAll.intranode(**args) + topk_ids = topk_ids.to(dtype=torch.uint32) prepare_finalize = PplxPrepareAndFinalize( @@ -522,11 +556,18 @@ def _pplx_moe( w2: torch.Tensor, score: torch.Tensor, topk: int, + use_internode: bool, ): - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) - nvshmem_init(uid, pgi.rank, pgi.world_size) + if use_internode: + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + group_name = None + else: + group_ranks = list(range(pgi.world_size)) + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") + group_name = cpu_group.group_name m, k = a.shape e, _, n = w2.shape @@ -536,8 +577,8 @@ def _pplx_moe( with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2, - topk_weight, topk_ids) + pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, + a, w1, w2, topk_weight, topk_ids) # TODO (bnell): fix + re-enable #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, # topk_ids) @@ -548,7 +589,8 @@ def _pplx_moe( torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) #torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0) - nvshmem_finalize() + if use_internode: + nvshmem_finalize() @pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS) @@ -556,6 +598,7 @@ def _pplx_moe( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@pytest.mark.parametrize("use_internode", [False]) @requires_pplx def test_pplx_moe( mnk: tuple[int, int, int], @@ -563,6 +606,7 @@ def test_pplx_moe( topk: int, dtype: torch.dtype, world_dp_size: tuple[int, int], + use_internode: bool, ): current_platform.seed_everything(7) m, n, k = mnk @@ -572,4 +616,5 @@ def test_pplx_moe( w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 score = torch.randn((m, e), device="cuda", dtype=dtype) - parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk) + parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, + use_internode) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 8c5ee98743d..eec59573792 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm_shape, deep_gemm_moe_fp8) -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, modular_triton_fused_moe) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ 
-45,7 +46,7 @@ K = [256, 3884, 4096, 13824, 16384] # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. -M_moe = [1, 2, 7, 83, 128, 2048] +M_moe = [1, 2, 7, 83, 128, 2048, 1024 * 128] M_moe_dg = [128, 192, 1335, 2048] N_moe = [128, 256, 1024, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] @@ -214,6 +215,13 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) + m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=block_size) + # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): out = fused_moe( @@ -231,6 +239,16 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + m_out = m_fused_moe(a, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=E, + w1_scale=w1_s, + w2_scale=w2_s) + #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") @@ -239,6 +257,11 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.03 + rel_diff = (torch.mean( + torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.03 + def per_block_cast_to_fp8( x: torch.Tensor, diff --git a/tests/pplx_utils.py b/tests/pplx_utils.py deleted file mode 100644 index 2d5d5be80c3..00000000000 --- a/tests/pplx_utils.py +++ /dev/null @@ -1,123 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -import os -import traceback -from typing import Callable - -import torch -from torch.multiprocessing import ( - spawn) # pyright: ignore[reportPrivateImportUsage] -from typing_extensions import Concatenate, ParamSpec - -P = ParamSpec("P") - - -@dataclasses.dataclass -class ProcessGroupInfo: - world_size: int - world_local_size: int - rank: int - node_rank: int - local_rank: int - device: torch.device - - -def _worker_parallel_launch( - local_rank: int, - world_size: int, - world_local_size: int, - node_rank: int, - init_method: str, - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - rank = node_rank * world_local_size + local_rank - torch.cuda.set_device(local_rank) - device = torch.device("cuda", local_rank) - torch.distributed.init_process_group( - backend="cpu:gloo,cuda:nccl", - init_method=init_method, - rank=rank, - world_size=world_size, - device_id=device, - ) - barrier = torch.tensor([rank], device=device) - torch.distributed.all_reduce(barrier) - - try: - worker( - ProcessGroupInfo( - world_size=world_size, - world_local_size=world_local_size, - rank=rank, - node_rank=node_rank, - local_rank=local_rank, - device=device, - ), - *args, - **kwargs, - ) - except Exception as ex: - print(ex) - traceback.print_exc() - raise - finally: - torch.distributed.destroy_process_group() - - -def parallel_launch( - world_size: int, - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - assert not kwargs - spawn( - _worker_parallel_launch, - args=( - world_size, - world_size, - 0, - "tcp://localhost:29500", - worker, - ) + 
args, - nprocs=world_size, - join=True, - ) - - -def parallel_launch_from_env( - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - """ - Launches a worker function in parallel across all processes in the current - environment. The environment must have the following variables set: - - WORLD_SIZE: The total number of processes. - - WORLD_LOCAL_SIZE: The number of processes on the current node. - - NODE_RANK: The rank of the current - - MASTER_ADDR: The address of the master process. - - MASTER_PORT: The port of the master process. - """ - assert not kwargs - world_size = int(os.environ["WORLD_SIZE"]) - world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) - node_rank = int(os.environ["NODE_RANK"]) - assert "MASTER_ADDR" in os.environ - assert "MASTER_PORT" in os.environ - spawn( - _worker_parallel_launch, - args=( - world_size, - world_local_size, - node_rank, - "env://", - worker, - ) + args, - nprocs=world_local_size, - join=True, - ) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 76d71ca0885..30b74165657 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -36,6 +36,9 @@ def __init__(self, max_num_tokens: int, world_size: int, dp_size: int, assert (len(self.block_shape) == 2 and all( [v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape])) + def supports_chunking(self) -> bool: + return False + def workspace_shapes( self, a: torch.Tensor, @@ -45,17 +48,19 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 num_dp = self.world_size // self.dp_size max_num_tokens = a.size( 0) if self.max_num_tokens is None else self.max_num_tokens - workspace13 = num_experts * max_num_tokens * num_dp * max(K, N) - workspace2 = num_experts * max_num_tokens * num_dp * (N // 2) - return (workspace13, workspace2, a.dtype) + workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) + workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) + output = (num_experts, max_num_tokens * num_dp, K) + return (workspace13, workspace2, output, a.dtype) def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -72,7 +77,7 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): import deep_gemm as dg assert hidden_states.ndim == 3 @@ -89,7 +94,6 @@ def apply( workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) - workspace3 = _resize_cache(workspace13, (E, max_num_tokens, K)) # (from deepgemm docs) : A value hint (which is a value on CPU) # for the M expectation of each batch, correctly setting this value @@ -118,8 +122,6 @@ def apply( dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale), (w2, w2_scale), - out=workspace3, + out=output, masked_m=expert_num_tokens, expected_m=expected_m) - - return workspace3 diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index d62d519af8d..d0ce59ba1e6 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -64,6 +64,15 @@ def __init__(self, block_shape=self.block_shape, # type: ignore[arg-type] ) if (self.allow_deep_gemm and is_fp8_128_block_quantized) else None + assert (self.batched_deep_gemm_experts is not None + or self.batched_triton_experts is not None) + + def supports_chunking(self) -> bool: + bdge = self.batched_deep_gemm_experts + bte = self.batched_triton_experts + return ((bdge is None or bdge.supports_chunking()) + and (bte is None or bte.supports_chunking())) + def workspace_shapes( self, a: torch.Tensor, @@ -73,7 +82,7 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. @@ -87,6 +96,7 @@ def workspace_shapes( def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -103,7 +113,7 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): use_batched_deep_gemm_experts = (self.allow_deep_gemm and self.batched_deep_gemm_experts is not None) @@ -111,7 +121,7 @@ def apply( if use_batched_deep_gemm_experts else self.batched_triton_experts) assert experts is not None - return experts.apply(hidden_states, w1, w2, topk_ids, activation, - global_num_experts, expert_map, w1_scale, - w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, - workspace13, workspace2, expert_num_tokens) + experts.apply(output, hidden_states, w1, w2, topk_ids, activation, + global_num_experts, expert_map, w1_scale, w2_scale, + w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, + workspace2, expert_num_tokens) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 6e7b1a4f2b6..f380cb77c7e 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -14,6 +14,7 @@ def run_cutlass_moe_fp8( + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -31,7 +32,8 @@ def run_cutlass_moe_fp8( out_dtype: torch.dtype, per_act_token: bool, per_out_ch: bool, -) -> torch.Tensor: + use_batched_format: bool, +): a1q = hidden_states assert w1_scale is not None @@ -61,23 +63,20 @@ def run_cutlass_moe_fp8( if expert_map is not None: assert expert_num_tokens is None - # We have two modes: PPLX and non-PPLX. We differentiate them by checking - # if expert_num_tokens is None (expert_num_tokens is a tensor which PPLX - # uses to track the number of tokens per expert). - # In the non-PPLX mode, the input tokens are not padded: thus, the shape + # We have two modes: batched experts and non-batched experts. + # In the non-batched mode, the input tokens are not padded: thus, the shape # of the input is [total_num_tokens, hidden_size]. The input and output # require shuffling by a_map and c_map such that the tokens assigned to # each expert are contiguous. 
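The two activation layouts that run_cutlass_moe_fp8 now distinguishes via use_batched_format differ only in tensor shape; a minimal sketch of that assumption (tensor names here are illustrative, not from the patch):

```python
import torch

# Standard (non-batched) layout: one row per token, shuffled by a_map/c_map so
# each expert's tokens become contiguous.
num_tokens, hidden_size, num_local_experts, max_tokens_per_expert = 8, 16, 4, 6
a_standard = torch.randn(num_tokens, hidden_size)      # [total_num_tokens, hidden_size]

# Batched layout (use_batched_format=True): tokens are grouped and padded per
# local expert by the dispatch step, so no a_map/c_map shuffling is needed.
a_batched = torch.randn(num_local_experts, max_tokens_per_expert, hidden_size)

M = a_standard.shape[0]        # mirrors `M = a1q.shape[0]` in the kernel
padded_M = a_batched.shape[1]  # mirrors `padded_M = a1q.shape[1]`
```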
- # In the PPLX mode, the input tokens are padded per expert to ensure that - # the PPLX dispatch and combine functions work correctly: thus, the shape + # In the batched mode, the input tokens are padded per expert to ensure that + # the batched dispatch and combine functions work correctly: thus, the shape # of the input is [num_experts, max_num_tokens_per_expert, hidden_size]. - # The PPLX input and output require no shuffling by a_map and c_map since + # The batched input and output require no shuffling by a_map and c_map since # their tokens are already contiguous for each expert as a result of # the dispatch function. - is_pplx = expert_num_tokens is not None - M = a1q.shape[0] # no pplx - padded_M = a1q.shape[1] # pplx + M = a1q.shape[0] # non batched expert M + padded_M = a1q.shape[1] # batched expert M _, K, N = w2.shape device = a1q.device @@ -95,7 +94,9 @@ def run_cutlass_moe_fp8( topk = local_topk_ids.shape[1] local_E = w1.shape[0] - if is_pplx: + if use_batched_format: + assert expert_num_tokens is not None + expert_offsets = torch.empty((local_E), dtype=torch.int32, device=device) @@ -167,7 +168,7 @@ def run_cutlass_moe_fp8( device=device, dtype=torch.int64) - if is_pplx: + if use_batched_format: c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2)) c2 = _resize_cache(workspace2, (local_E * padded_M, N)) c3 = _resize_cache(workspace13, (local_E * padded_M, K)) @@ -192,12 +193,15 @@ def run_cutlass_moe_fp8( problem_sizes2, ab_strides2, ab_strides2, c_strides2, per_act_token, per_out_ch) - if is_pplx: - return c3.reshape(local_E, padded_M, K) + if use_batched_format: + output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True) else: - return c3[c_map].view(M, topk, K) + # We can't do this inplace because output may point to the same tensor + # as c3. + output.copy_(c3[c_map].view(M * topk, K), non_blocking=True) +# TODO (bnell): split class batched vs. non-batched? class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -206,12 +210,17 @@ def __init__( out_dtype: torch.dtype, per_act_token: bool, per_out_ch: bool, + use_batched_format: bool = False, ): super().__init__() self.max_experts_per_worker = max_experts_per_worker self.out_dtype = out_dtype self.per_act_token = per_act_token self.per_out_ch = per_out_ch + self.use_batched_format = use_batched_format + + def supports_chunking(self) -> bool: + return not self.use_batched_format def workspace_shapes( self, @@ -222,14 +231,24 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: - padded_M = aq.shape[1] - workspace1 = self.max_experts_per_worker * padded_M * max(N, K) - workspace2 = self.max_experts_per_worker * padded_M * (N // 2) - return (workspace1, workspace2, self.out_dtype) + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + workspace1: tuple[int, ...] = () + workspace2: tuple[int, ...] = () + output: tuple[int, ...] 
= () + if self.use_batched_format: + padded_M = aq.shape[1] + workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) + workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) + output = (self.max_experts_per_worker, padded_M, K) + else: + workspace1 = (M * topk, max(2 * N, K)) + workspace2 = (M * topk, N) + output = (M * topk, K) + return (workspace1, workspace2, output, self.out_dtype) def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -246,16 +265,17 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" activation_callable = lambda i, o: self.activation(activation, i, o) - return run_cutlass_moe_fp8(hidden_states, w1, w2, topk_ids, - activation_callable, global_num_experts, - expert_map, w1_scale, w2_scale, a1q_scale, - a2_scale, workspace13, workspace2, - expert_num_tokens, self.out_dtype, - self.per_act_token, self.per_out_ch) + run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids, + activation_callable, global_num_experts, + expert_map, w1_scale, w2_scale, a1q_scale, + a2_scale, workspace13, workspace2, + expert_num_tokens, self.out_dtype, + self.per_act_token, self.per_out_ch, + self.use_batched_format) def cutlass_moe_fp8( @@ -325,6 +345,7 @@ def cutlass_moe_fp8( out_dtype=out_dtype, per_act_token=per_act_token, per_out_ch=per_out_ch, + use_batched_format=False, ), ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 436c632be9c..595e8c99514 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -70,6 +70,9 @@ def __init__(self): super().__init__() self.block_shape = deep_gemm_block_shape() + def supports_chunking(self) -> bool: + return True + def workspace_shapes( self, a: torch.Tensor, @@ -79,18 +82,18 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: - + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: block_m = self.block_shape[0] M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = round_up(M_sum, block_m) - workspace1 = M_sum * max(N * 2, K) - workspace2 = M_sum * max(N, K) - - return (workspace1, workspace2, a.dtype) + workspace1 = (M_sum, max(N * 2, K)) + workspace2 = (M_sum, max(N, K)) + output = (M * topk, K) + return (workspace1, workspace2, output, a.dtype) def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -107,7 +110,7 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): import deep_gemm as dg a1q = hidden_states @@ -143,7 +146,6 @@ def apply( quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)) mm2_out = _resize_cache(workspace2, (M_sum, K)) - out = _resize_cache(workspace13, (inv_perm.size(0), K)) dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids) @@ -159,9 +161,7 @@ def apply( dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids) - torch.index_select(mm2_out, 0, inv_perm, out=out) - - return out + torch.index_select(mm2_out, 0, inv_perm, out=output) def deep_gemm_moe_fp8( diff --git 
a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 68a3485ff1f..fb66e96c794 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -335,9 +335,6 @@ def invoke_moe_batched_triton_kernel( BLOCK_M = config['BLOCK_SIZE_M'] BLOCK_N = config['BLOCK_SIZE_N'] BLOCK_K = config['BLOCK_SIZE_K'] - assert (torch.compiler.is_compiling() - or torch.cuda.is_current_stream_capturing() - or max_num_tokens % BLOCK_M == 0) grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N)) @@ -390,8 +387,8 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): that the PPLX dispatch/combine kernels use. """ - def __init__(self, max_num_tokens: Optional[int], world_size: int, - dp_size: int, rank: int): + def __init__(self, max_num_tokens: int, world_size: int, dp_size: int, + rank: int): super().__init__() self.world_size = world_size self.dp_size = dp_size @@ -430,14 +427,9 @@ def prepare( num_tokens, hidden_dim = a1.size() topk = topk_ids.size(1) - if self.max_num_tokens is None: - tokens_per_expert = torch.bincount(topk_ids.view(-1), - minlength=num_experts) - self.max_num_tokens = int(tokens_per_expert.max().item()) - else: - tokens_per_expert = torch.zeros(num_experts, - dtype=torch.int, - device=a1.device) + tokens_per_expert = torch.zeros(num_experts, + dtype=torch.int, + device=a1.device) assert num_experts % self.world_size == 0 @@ -497,9 +489,9 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + max_num_tokens: int, world_size: int, dp_size: int, - max_num_tokens: Optional[int] = None, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -518,6 +510,9 @@ def __init__( self.world_size = world_size self.dp_size = dp_size + def supports_chunking(self) -> bool: + return False + def workspace_shapes( self, a: torch.Tensor, @@ -527,18 +522,16 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 num_dp = self.world_size // self.dp_size - max_num_tokens = a.size( - 0) if self.max_num_tokens is None else self.max_num_tokens - #print(f"WORKSPACE {max_num_tokens} {num_dp}") - workspace13 = num_experts * max_num_tokens * num_dp * K - workspace2 = max_num_tokens * num_dp * N - return (workspace13, workspace2, a.dtype) + workspace13 = (num_experts, self.max_num_tokens * num_dp, K) + workspace2 = (self.max_num_tokens * num_dp, N) + return (workspace13, workspace2, workspace13, a.dtype) def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -555,20 +548,12 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): assert hidden_states.dim() == 3 assert expert_num_tokens is not None - hidden_dim = hidden_states.size(-1) - - if self.max_num_tokens is None: - max_num_tokens = hidden_states.size(1) - else: - max_num_tokens = self.max_num_tokens + max_num_tokens = self.max_num_tokens num_dp = self.world_size // self.dp_size - num_experts = global_num_experts - out = _resize_cache(workspace13, - (num_experts, max_num_tokens * num_dp, hidden_dim)) num_local_experts = w1.size(0) assert num_local_experts == w1.size(0), ( f"{num_local_experts} == {w1.size(0)}") @@ -585,15 +570,13 @@ 
def apply( # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor if (torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing()): - num = max_num_tokens * num_dp + num = hidden_states.shape[1] else: num = int(expert_num_tokens[expert].item()) tmp = _resize_cache(workspace2, (num, N)) input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) self.activation(activation, tmp, input) - out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) - - return out + output[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -630,6 +613,9 @@ def __init__( assert not use_int4_w4a16, "NYI" assert self.block_shape is None, "NYI" + def supports_chunking(self) -> bool: + return False + def workspace_shapes( self, a: torch.Tensor, @@ -639,17 +625,19 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 num_dp = self.world_size // self.dp_size max_num_tokens = a.size( 0) if self.max_num_tokens is None else self.max_num_tokens - workspace13 = num_experts * max_num_tokens * num_dp * max(K, N) - workspace2 = num_experts * max_num_tokens * num_dp * (N // 2) - return (workspace13, workspace2, a.dtype) + workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) + workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) + output = (num_experts, max_num_tokens * num_dp, K) + return (workspace13, workspace2, output, a.dtype) def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -666,7 +654,7 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): # Check constraints. 
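The reference per-expert loop in BatchedExperts.apply can be exercised in plain PyTorch; a rough sketch assuming a SiLU-and-mul activation and random weights (shapes and names are illustrative only, not vLLM API):

```python
import torch
import torch.nn.functional as F

E, max_tokens, K, N = 2, 4, 8, 12          # experts, padded tokens, hidden, 2x intermediate
hidden_states = torch.randn(E, max_tokens, K)
w1 = torch.randn(E, N, K)                  # produces gate/up halves of width N
w2 = torch.randn(E, K, N // 2)
expert_num_tokens = torch.tensor([3, 4])   # valid rows per expert
output = torch.zeros(E, max_tokens, K)     # caller-provided, written in place

for expert in range(E):
    num = int(expert_num_tokens[expert].item())
    x = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)  # [num, N]
    gate, up = x.chunk(2, dim=-1)
    act = F.silu(gate) * up                                          # SiluAndMul
    output[expert, :num, :] = act @ w2[expert].transpose(0, 1)       # [num, K]
```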
if self.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( @@ -723,8 +711,6 @@ def apply( (E, max_num_tokens, N)) intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) - intermediate_cache3 = _resize_cache(workspace13, - (E, max_num_tokens, K)) # MM1 invoke_moe_batched_triton_kernel(A=hidden_states, @@ -761,7 +747,7 @@ def apply( invoke_moe_batched_triton_kernel(A=qintermediate_cache2, B=w2, - C=intermediate_cache3, + C=output, expert_num_tokens=expert_num_tokens, compute_type=compute_type, A_scale=a2q_scale, @@ -772,4 +758,3 @@ def apply( use_int4_w4a16=self.use_int4_w4a16, config=config, block_shape=self.block_shape) - return intermediate_cache3 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ba1498e6531..d9b1ba13267 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1542,6 +1542,9 @@ def __init__( use_int4_w4a16=use_int4_w4a16) self.per_channel_quant = per_channel_quant + def supports_chunking(self) -> bool: + return True + def workspace_shapes( self, a: torch.Tensor, @@ -1551,14 +1554,15 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: - factor = num_experts if a.dim() == 3 else 1 - workspace1 = M * topk * max(N * 2, K) * factor - workspace2 = M * topk * N * factor - return (workspace1, workspace2, a.dtype) + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + workspace1 = (M, topk, max(N * 2, K)) + workspace2 = (M, topk, N) + output = (M, topk, K) + return (workspace1, workspace2, output, a.dtype) def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -1575,7 +1579,7 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): # Check constraints. if self.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( @@ -1632,8 +1636,6 @@ def apply( (num_tokens, top_k_num, N)) intermediate_cache2 = _resize_cache(workspace2, (num_tokens * top_k_num, N // 2)) - intermediate_cache3 = _resize_cache(workspace13, - (num_tokens, top_k_num, K)) sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], @@ -1671,7 +1673,7 @@ def apply( invoke_fused_moe_kernel(qintermediate_cache2, w2, - intermediate_cache3, + output, a2q_scale, w2_scale, w2_zp, @@ -1690,8 +1692,6 @@ def apply( per_channel_quant=self.per_channel_quant, block_shape=self.block_shape) - return intermediate_cache3 - def modular_triton_fused_moe( use_fp8_w8a8: bool, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index e7aaf62fb34..9ef6a126680 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -1,10 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from math import prod from typing import Optional import torch +import vllm.envs as envs +from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.utils import cdiv + # # This file defines a set of base classes used to make MoE kernels more modular. 
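The updated workspace_shapes contract returns shape tuples whose leading dimension tracks the token count, instead of flat element counts; a hypothetical standalone version of the non-batched Triton case might look like this, where M is whatever token count the caller passes (possibly a chunk size rather than the full batch):

```python
import torch

def workspace_shapes(M: int, N: int, K: int, topk: int,
                     dtype: torch.dtype = torch.bfloat16):
    # Leading dimension is the token count, so the caller can recompute these
    # shapes with CHUNK_SIZE instead of M when activation chunking is enabled.
    workspace13 = (M, topk, max(N * 2, K))  # shared by the first and last gemm
    workspace2 = (M, topk, N)               # activation scratch
    output = (M, topk, K)                   # exact shape of the final gemm output
    return workspace13, workspace2, output, dtype
```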
# The goal is to be able to utilize different communication mechanisms with @@ -115,9 +120,9 @@ def prepare( - quantized + dispatched a. - quantized + dispatched a1_scales. - Optional tensor as big as number of local experts that contains the - number of tokens assigned to each local expert. + number of tokens assigned to each local expert. - Optional dispatched expert topk IDs - - Optional dispatched expert topk weight + - Optional dispatched expert topk weight """ raise NotImplementedError @@ -159,7 +164,7 @@ def max_num_tokens_per_rank(self) -> Optional[int]: Some PrepareFinalize All2All implementations are batched. Meaning, they can processes only as set of tokens at a time. This function returns the batch size i.e the maximum number of tokens - the implementation can process at a time. + the implementation can process at a time. Return None if there are no such restrictions. """ raise NotImplementedError @@ -171,6 +176,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC): above. """ + # TODO (bnell): make this return a CHUNK_SIZE or None instead? + @abstractmethod + def supports_chunking(self) -> bool: + """ + A flag indicating whether or not this class supports activation + chunking. + """ + raise NotImplementedError + @abstractmethod def workspace_shapes( self, @@ -181,19 +195,22 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: """ - Compute the number of elements for the temporary outputs of the two - gemms and activation in the fused expert function. Since the - gemms are independent, the workspace for the first gemm can be shared - with the workspace for the last gemm. + Compute the shapes for the temporary and final outputs of the two gemms + and activation in the fused expert function. Since the gemms are + independent, the workspace for the first gemm can be shared with the + workspace for the last gemm. Returns a tuple of: - - Number of workspace13 elements: must be large enough to hold the + - workspace13 shape tuple: must be large enough to hold the result of either expert gemm. - - Number of workspace2 elements: must be large enough to hold the + - workspace2 shape tuple: must be large enough to hold the result of the activation function. + - output shape tuple: must be exact size of the final gemm output. - Workspace type: The dtype to use for the workspace tensors. + - Note: in order for activation chunking to work, the first dimension + of each tuple must be the number of tokens. """ raise NotImplementedError @@ -210,6 +227,7 @@ def activation(self, activation: str, output: torch.Tensor, @abstractmethod def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -226,12 +244,13 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): """ This function computes the intermediate result of a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2. Parameters: + - output: (torch.Tensor): The unweighted, unreduced output tensor. - hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. @@ -259,13 +278,20 @@ def apply( function. - expert_num_tokens: An optional tensor containing the number of tokens assigned to each expert when using batched experts format input. 
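The apply output being "unweighted, unreduced" means the per-(token, top-k) expert results are only combined later, in the finalize stage; conceptually (a simplification that ignores apply_router_weight_on_input and quantization):

```python
import torch

M, topk, K = 4, 2, 8
fused_out = torch.randn(M, topk, K)     # what apply() writes: one row per (token, expert slot)
topk_weights = torch.rand(M, topk)

# What finalize conceptually does with it: weight each slot, reduce over top-k.
combined = (fused_out * topk_weights.unsqueeze(-1)).sum(dim=1)   # [M, K]
```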
- - Returns: - - torch.Tensor: The unweighted, unreduced output tensor """ raise NotImplementedError +def _chunk_scales(scales: Optional[torch.Tensor], start: int, + end: int) -> Optional[torch.Tensor]: + if scales is not None: + if scales.numel() == 1: + return scales + else: + return scales[start:end] + return None + + class FusedMoEModularKernel(torch.nn.Module): """ This class combines a FusedMoEPrepareAndFinalize instance and @@ -288,61 +314,6 @@ def __init__( self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts - def _do_fused_experts( - self, - a1: torch.Tensor, # input to forward fn - a1q: torch.Tensor, # output of prepare fn - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - expert_num_tokens: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor]) -> torch.Tensor: - - _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) - - # Use a1 here to decipher the correct workspace datatype - workspace13_shape, workspace2_shape, workspace_dtype = ( - self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k, - global_num_experts)) - - # We can reuse the memory between cache1 and cache3 because by the time - # we need cache3, we're done with cache1 - workspace13 = torch.zeros(workspace13_shape, - device=a1.device, - dtype=workspace_dtype) - workspace2 = torch.zeros(workspace2_shape, - device=a1.device, - dtype=workspace_dtype) - - fused_out = self.fused_experts.apply( - a1q, - w1, - w2, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=a1q_scale, - a2_scale=a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_num_tokens=expert_num_tokens, - ) - - return fused_out - def forward( self, hidden_states: torch.Tensor, @@ -408,12 +379,14 @@ def forward( _expert_topk_weights) = self.prepare_finalize.prepare( a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts, expert_map, apply_router_weight_on_input) + # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids topk_weights = (topk_weights if _expert_topk_weights is None else _expert_topk_weights) fused_out = None + if a1q.numel() == 0: # This happens when none of the tokens from the all2all reach this # EP rank. Also, note that this is only relevant for CUDAGraph @@ -423,22 +396,107 @@ def forward( # and can never run into the tensor.numel() == 0 case. 
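_chunk_scales keeps per-tensor scales shared across chunks and slices per-token scales to the chunk's row range; a small usage sketch of the helper as defined above:

```python
import torch
from typing import Optional

def _chunk_scales(scales: Optional[torch.Tensor], start: int,
                  end: int) -> Optional[torch.Tensor]:
    if scales is None:
        return None
    if scales.numel() == 1:       # per-tensor scale: every chunk shares it
        return scales
    return scales[start:end]      # per-token scales: slice to the chunk

per_tensor = torch.tensor([0.5])
per_token = torch.rand(16, 1)
assert _chunk_scales(per_tensor, 4, 8) is per_tensor
assert _chunk_scales(per_token, 4, 8).shape[0] == 4
assert _chunk_scales(None, 4, 8) is None
```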
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype) else: - fused_out = self._do_fused_experts( - a1=a1, - a1q=a1q, - w1=w1, - w2=w2, - topk_ids=topk_ids, - expert_num_tokens=expert_num_tokens, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=a1q_scale, - a2_scale=a2_scale) + _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) + + if self.fused_experts.supports_chunking(): + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + num_chunks = cdiv(M, CHUNK_SIZE) + else: + CHUNK_SIZE = M + num_chunks = 1 + + if num_chunks == 1: + (workspace13_shape, workspace2_shape, fused_out_shape, + workspace_dtype) = self.fused_experts.workspace_shapes( + a1, a1q, M, N, K, top_k, global_num_experts) + else: + # Use the full M to get the final output shape. + _, _, fused_out_shape, _ = ( + self.fused_experts.workspace_shapes( + a1, a1q, M, N, K, top_k, global_num_experts)) + # Use the CHUNK_SIZE to get the workspace shapes. + workspace13_shape, workspace2_shape, _, workspace_dtype = ( + self.fused_experts.workspace_shapes( + a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts)) + + # We can reuse the memory between cache1 and cache3 because by the + # time we need cache3, we're done with cache1. + workspace13 = torch.zeros(prod(workspace13_shape), + device=a1.device, + dtype=workspace_dtype) + workspace2 = torch.zeros(prod(workspace2_shape), + device=a1.device, + dtype=workspace_dtype) + + if num_chunks == 1: + fused_out = _resize_cache(workspace13, fused_out_shape) + + self.fused_experts.apply( + fused_out, + a1q, + w1, + w2, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) + else: + # The leading output dimension may not be equal to M, so + # we compute output indices separately. 
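The chunking decision boils down to: when the expert implementation supports chunking, size the workspaces for a single chunk and the output for the full M; otherwise treat the whole batch as one chunk. A rough numeric sketch with small sizes (in vLLM the chunk budget comes from VLLM_FUSED_MOE_CHUNK_SIZE; the shapes below only mirror the Triton-style case):

```python
import torch
from math import prod

def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

M, CHUNK_SIZE, topk, N, K = 70, 32, 2, 16, 8
num_chunks = cdiv(M, CHUNK_SIZE)                        # 3

# Workspaces only need to hold one chunk's worth of tokens...
workspace13 = torch.zeros(prod((CHUNK_SIZE, topk, max(N * 2, K))),
                          dtype=torch.bfloat16)
workspace2 = torch.zeros(prod((CHUNK_SIZE, topk, N)), dtype=torch.bfloat16)

# ...while the fused output covers all M tokens and is filled chunk by chunk.
fused_out = torch.empty((M, topk, K), dtype=torch.bfloat16)
```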
+ M_out = fused_out_shape[0] + assert M_out >= M + factor = M_out // M + assert factor > 0 + OUT_CHUNK_SIZE = CHUNK_SIZE * factor + + fused_out = torch.empty(fused_out_shape, + device=a1q.device, + dtype=workspace_dtype) + + assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, ( + f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}") + + for chunk in range(num_chunks): + begin_chunk_idx = chunk * CHUNK_SIZE + end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M) + begin_out_idx = chunk * OUT_CHUNK_SIZE + end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out) + curr_a1q = a1q[begin_chunk_idx:end_chunk_idx] + curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx, + end_chunk_idx) + curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx, + end_chunk_idx) + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + + self.fused_experts.apply( + fused_out[begin_out_idx:end_out_idx], + curr_a1q, + w1, + w2, + curr_topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=curr_a1q_scale, + a2_scale=curr_a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) self.prepare_finalize.finalize(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 87de29444c0..d4233c23f53 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -34,6 +34,12 @@ def __init__(self, self.deep_gemm_expert = DeepGemmExperts( ) if self.allow_deep_gemm else None + def supports_chunking(self) -> bool: + dge = self.deep_gemm_expert + te = self.triton_expert + return ((dge is None or dge.supports_chunking()) + and (te is None or te.supports_chunking())) + def workspace_shapes( self, a: torch.Tensor, @@ -43,7 +49,7 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. 
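The separate output indexing exists because the leading output dimension can be a multiple of M (for example M * topk on the non-batched CUTLASS path), so output chunks are scaled by `factor`; a small sketch of the index arithmetic:

```python
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

M, topk, CHUNK_SIZE = 70, 2, 32
M_out = M * topk                      # leading dim of the expert output
factor = M_out // M                   # 2
OUT_CHUNK_SIZE = CHUNK_SIZE * factor

num_chunks = cdiv(M, CHUNK_SIZE)
assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks

for chunk in range(num_chunks):
    begin_in, end_in = chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, M)
    begin_out, end_out = chunk * OUT_CHUNK_SIZE, min((chunk + 1) * OUT_CHUNK_SIZE, M_out)
    # every input row in [begin_in, end_in) owns `factor` output rows
    assert (end_out - begin_out) == (end_in - begin_in) * factor
```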
@@ -57,6 +63,7 @@ def workspace_shapes( def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -73,45 +80,31 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): N = w1.size(1) - if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 - and _valid_deep_gemm(hidden_states, w1, w2)): - assert self.deep_gemm_expert is not None - return self.deep_gemm_expert.apply( - hidden_states, - w1, - w2, - topk_ids, - activation, - global_num_experts, - expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, - a1q_scale, - a2_scale, - workspace13, - workspace2, - expert_num_tokens, - ) - else: - return self.triton_expert.apply( - hidden_states, - w1, - w2, - topk_ids, - activation, - global_num_experts, - expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, - a1q_scale, - a2_scale, - workspace13, - workspace2, - expert_num_tokens, - ) + + use_deep_gemm = (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 + and _valid_deep_gemm(hidden_states, w1, w2)) + + experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert + assert experts is not None + + experts.apply( + output, + hidden_states, + w1, + w2, + topk_ids, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1q_scale, + a2_scale, + workspace13, + workspace2, + expert_num_tokens, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index bc9d399cf13..f14131c5f05 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -562,9 +562,12 @@ def select_gemm_impl(self, prepare_finalize, moe): (moe.num_experts + prepare_finalize.world_size - 1) // prepare_finalize.world_size) experts = CutlassExpertsFp8( - max_experts_per_worker, moe.in_dtype, + max_experts_per_worker, + moe.in_dtype, self.input_quant.strategy == QuantizationStrategy.TOKEN, - self.weight_quant.strategy == QuantizationStrategy.CHANNEL) + self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + use_batched_format=True, + ) if has_pplx and isinstance( prepare_finalize,
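With apply() no longer returning a tensor, the wrapper classes just pick an implementation and let it fill the caller-provided output in place; a minimal illustration of that calling convention with dummy "experts" (nothing here is vLLM API, it only mirrors the control flow of TritonOrDeepGemmExperts.apply):

```python
from typing import Optional
import torch

class DummyExpert:
    """Stand-in for an expert implementation: fills `output` in place."""
    def __init__(self, scale: float):
        self.scale = scale

    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor) -> None:
        output.copy_(hidden_states * self.scale)

def apply_selected(deep_gemm_expert: Optional[DummyExpert],
                   triton_expert: Optional[DummyExpert],
                   output: torch.Tensor, hidden_states: torch.Tensor,
                   use_deep_gemm: bool) -> None:
    experts = deep_gemm_expert if use_deep_gemm else triton_expert
    assert experts is not None
    experts.apply(output, hidden_states)   # no return value; output is written in place

x = torch.ones(2, 4)
out = torch.empty_like(x)
apply_selected(None, DummyExpert(2.0), out, x, use_deep_gemm=False)
assert torch.equal(out, x * 2)
```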