From 6f1bf3ed124706bc2f61cd1476f0bb085cf837d0 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 1 Jul 2025 02:08:56 +0000 Subject: [PATCH] modular kernel unit test / profiling / tools Signed-off-by: Varun Sundar Rabindranath --- .../moe/modular_kernel_tools/__init__.py | 0 .../moe/modular_kernel_tools/cli_args.py | 160 +++++ .../moe/modular_kernel_tools/common.py | 641 ++++++++++++++++++ .../make_feature_matrix.py | 173 +++++ .../moe/modular_kernel_tools/mk_objects.py | 87 +++ .../modular_kernel_tools/parallel_utils.py | 138 ++++ .../profile_modular_kernel.py | 127 ++++ .../kernels/moe/modular_kernel_tools/utils.py | 142 ++++ tests/kernels/moe/parallel_utils.py | 6 +- .../moe/test_modular_kernel_combinations.py | 214 ++++++ tests/kernels/utils.py | 30 +- .../base_device_communicator.py | 3 +- .../batched_triton_or_deep_gemm_moe.py | 1 - vllm/model_executor/layers/fused_moe/layer.py | 18 +- .../layers/fused_moe/triton_deep_gemm_moe.py | 9 +- 15 files changed, 1727 insertions(+), 22 deletions(-) create mode 100644 tests/kernels/moe/modular_kernel_tools/__init__.py create mode 100644 tests/kernels/moe/modular_kernel_tools/cli_args.py create mode 100644 tests/kernels/moe/modular_kernel_tools/common.py create mode 100644 tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py create mode 100644 tests/kernels/moe/modular_kernel_tools/mk_objects.py create mode 100644 tests/kernels/moe/modular_kernel_tools/parallel_utils.py create mode 100644 tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py create mode 100644 tests/kernels/moe/modular_kernel_tools/utils.py create mode 100644 tests/kernels/moe/test_modular_kernel_combinations.py diff --git a/tests/kernels/moe/modular_kernel_tools/__init__.py b/tests/kernels/moe/modular_kernel_tools/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/kernels/moe/modular_kernel_tools/cli_args.py b/tests/kernels/moe/modular_kernel_tools/cli_args.py new file mode 100644 index 000000000000..261f1eb6e5c3 --- /dev/null +++ b/tests/kernels/moe/modular_kernel_tools/cli_args.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig + +from .common import Config +from .mk_objects import (MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES, + MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) + + +def make_config_arg_parser(description: str): + + def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize: + for pf in MK_ALL_PREPARE_FINALIZE_TYPES: + if pf.__name__ == s: + return pf + raise ValueError( + f"Cannot find a PrepareFinalize type that matches {s}") + + def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute: + for fe in MK_FUSED_EXPERT_TYPES: + if fe.__name__ == s: + return fe + raise ValueError(f"Cannot find a FusedExperts type that matches {s}") + + def to_quant_torch_dtype(s: str) -> torch.dtype: + if s == "torch.float8_e4m3fn": + return torch.float8_e4m3fn + raise ValueError(f"Unsupported quant type {s}") + + parser = argparse.ArgumentParser(description=description) + + parser.add_argument( + "--world-size", + type=int, + default=2, + help="Number of ranks that participate in all2all", + ) + parser.add_argument( + "--pf-type", + type=to_pf_class_type, + required=True, + help=("Choose a PrepareFinalize Type : " + f"{[x.__name__ for x in 
MK_ALL_PREPARE_FINALIZE_TYPES]}"), + ) + parser.add_argument( + "--experts-type", + type=to_experts_class_type, + required=True, + help=(f"Choose a FusedExpert type : " + f"{[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"), + ) + parser.add_argument( + "-m", + nargs="+", + type=int, + default=[64], + help="num tokens per rank", + ) + parser.add_argument( + "-k", + type=int, + default=7168, + help="hidden-size", + ) + parser.add_argument( + "-n", + type=int, + default=1024, + help="N dimension of the first fused-moe matmul", + ) + parser.add_argument("--num-experts", + type=int, + default=32, + help="Global num experts") + parser.add_argument("--topk", + nargs="+", + type=int, + default=[4, 1], + help="num topk") + parser.add_argument( + "--fused-moe-chunk-size", + nargs="+", + type=int, + help="Fused moe chunk size used for the non-batched fused experts impl." + ) + + # Quant args + parser.add_argument("--quant-dtype", + type=to_quant_torch_dtype, + help="Quant datatype") + parser.add_argument("--per-token-quantized-activations", + action='store_true', + help=("The input activations must be per-token " + "quantized")) + parser.add_argument("--per-channel-quantized-weights", + action="store_true", + help="The weights must be per-channel quantized.") + parser.add_argument("--block-shape", + nargs="+", + type=int, + help="Quantization block shape") + + # Torch trace profile generation args + parser.add_argument("--torch-trace-dir-path", + type=str, + default=None, + help="Get torch trace for single execution") + + return parser + + +def _validate_args(args: argparse.Namespace): + + if args.quant_dtype is not None: + assert args.quant_dtype == torch.float8_e4m3fn + if args.block_shape is not None: + assert len(args.block_shape) == 2, ( + f"block shape must have 2 elements. 
got {args.block_shape}") + + if args.experts_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: + assert args.world_size == 1, ( + "Single GPU objects need world size set to 1") + + if args.torch_trace_dir_path is not None: + from pathlib import Path + assert Path(args.torch_trace_dir_path).is_dir(), ( + f"Please create {args.torch_trace_dir_path}") + + +def make_config(args: argparse.Namespace) -> Config: + + _validate_args(args) + + quant_config = None + if args.quant_dtype is not None: + quant_config = FusedMoEQuantConfig( + quant_dtype=args.quant_dtype, + per_act_token_quant=args.per_token_quantized_activations, + per_out_ch_quant=args.per_channel_quantized_weights, + block_shape=args.block_shape) + + return Config( + Ms=args.m, + K=args.k, + N=args.n, + E=args.num_experts, + topks=args.topk, + dtype=torch.bfloat16, # hard-code + quant_config=quant_config, + prepare_finalize_type=args.pf_type, + fused_experts_type=args.experts_type, + fused_moe_chunk_size=args.fused_moe_chunk_size, + world_size=args.world_size, + torch_trace_dir_path=args.torch_trace_dir_path) diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py new file mode 100644 index 000000000000..a1319ab0509b --- /dev/null +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -0,0 +1,641 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch + +import vllm._custom_ops as ops +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.utils import torch_experts +from vllm.config import VllmConfig +from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size +# Fused experts and PrepareFinalize imports +from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 + BatchedTritonOrDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig) +from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts, NaiveBatchedExperts) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase, + TritonExperts) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) +from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx + +from .parallel_utils import ProcessGroupInfo +from .utils import (make_block_quant_fp8_weights, make_non_quant_weights, + make_quant_fp8_weights, per_token_cast_to_fp8) + +if has_pplx(): + from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( + PplxPrepareAndFinalize) +if has_deep_ep(): + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 + DeepEPHTPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 + DeepEPLLPrepareAndFinalize) + + +def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str: + if t is None: + return f"{name} 
: None" + else: + return f"{name} : {t.shape} {t.dtype} {t.device}" + + +@dataclass +class Config: + Ms: Union[list[int], int] + K: int + N: int + E: int + topks: Union[list[int], int] + dtype: torch.dtype + quant_config: Optional[FusedMoEQuantConfig] + + prepare_finalize_type: mk.FusedMoEPrepareAndFinalize + fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute + + fused_moe_chunk_size: Optional[int] + world_size: int + + torch_trace_dir_path: Optional[str] = None + + def describe(self) -> str: + s = "" + s += "== Config: \n" + s += f" world_size={self.world_size} \n" + s += f" PF={self.prepare_finalize_type.__name__} \n" + s += f" FE={self.fused_experts_type.__name__} \n" + s += f" topk={self.topks} \n" + s += f" dtype={self.dtype} \n" + s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n" + s += " Quant: \n" + s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n " + if self.quant_config is not None: + s += f" q_dtype={self.quant_dtype} \n" + s += f" q_block_shape={self.quant_block_shape} \n" + s += f" q_per_out_ch_quant={self.is_per_out_ch_quant} \n" + s += f" q_per_act_token={self.is_per_act_token_quant} \n" + else: + s += " quant=None \n" + return s + + @property + def M(self) -> int: + assert isinstance(self.Ms, int) + return self.Ms + + @property + def quant_dtype(self) -> Optional[torch.dtype]: + if self.quant_config is None: + return None + return self.quant_config.quant_dtype + + @property + def is_per_act_token_quant(self) -> bool: + if self.quant_config is None: + return False + return self.quant_config.per_act_token_quant + + @property + def is_per_tensor_act_quant(self) -> bool: + if self.quant_config is None: + return False + return (not self.is_per_act_token_quant + and self.quant_block_shape is None) + + @property + def is_per_out_ch_quant(self) -> bool: + if self.quant_config is None: + return False + return self.quant_config.per_out_ch_quant + + @property + def quant_block_shape(self) -> Optional[list[int]]: + if self.quant_config is None: + return None + return self.quant_config.block_shape + + @property + def topk(self) -> int: + assert isinstance(self.topks, int) + return self.topks + + @property + def topk_ids_dtype(self) -> Optional[torch.dtype]: + topk_ids_dtype = None + if self.prepare_finalize_type == PplxPrepareAndFinalize: + topk_ids_dtype = torch.uint32 + elif self.prepare_finalize_type in [ + DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize + ]: + topk_ids_dtype = torch.int64 + return topk_ids_dtype + + @property + def num_local_experts(self) -> int: + return self.E // self.world_size + + def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]: + """ + make env data for vllm launch. 
+ """ + vllm_config = VllmConfig() + vllm_config.parallel_config.data_parallel_size = self.world_size + vllm_config.parallel_config.enable_expert_parallel = True + + env_dict = { + "VLLM_ALL2ALL_BACKEND": self.all2all_backend(), + "VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())), + } + if self.fused_moe_chunk_size is not None: + env_dict.update( + {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}) + return vllm_config, env_dict + + def is_fp8_block_quantized(self): + return (self.quant_dtype == torch.float8_e4m3fn + and self.quant_block_shape is not None) + + def is_batched_prepare_finalize(self): + return self.prepare_finalize_type in [ + PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize + ] + + def is_batched_fused_experts(self): + return self.fused_experts_type in [ + CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts, + NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts + ] + + def is_standard_fused_experts(self): + return self.fused_experts_type in [ + CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, + TritonExperts + ] + + def is_fe_16bit_supported(self): + return self.fused_experts_type in [ + BatchedTritonExperts, BatchedTritonOrDeepGemmExperts, + NaiveBatchedExperts, TritonExperts + ] + + def is_fe_fp8_supported(self): + return self.fused_experts_type in [ + BatchedDeepGemmExperts, + BatchedTritonExperts, + BatchedTritonOrDeepGemmExperts, + CutlassExpertsFp8, + DeepGemmExperts, + TritonExperts, + TritonOrDeepGemmExperts, + NaiveBatchedExperts, + ] + + def is_fe_block_fp8_supported(self): + return self.fused_experts_type in [ + BatchedDeepGemmExperts, + BatchedTritonOrDeepGemmExperts, + DeepGemmExperts, + TritonExperts, + TritonOrDeepGemmExperts, + BatchedTritonExperts, + NaiveBatchedExperts, + ] + + def is_fe_supports_chunking(self): + return self.fused_experts_type in [ + CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, + TritonExperts + ] + + def needs_deep_gemm(self): + return self.fused_experts_type in [ + BatchedDeepGemmExperts, + DeepGemmExperts, + ] + + def needs_pplx(self): + return self.prepare_finalize_type in [PplxPrepareAndFinalize] + + def needs_deep_ep(self): + return self.prepare_finalize_type in [ + DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize + ] + + def all2all_backend(self): + if self.needs_pplx(): + return "pplx" + if self.prepare_finalize_type == DeepEPHTPrepareAndFinalize: + return "deepep_high_throughput" + if self.prepare_finalize_type == DeepEPLLPrepareAndFinalize: + return "deepep_low_latency" + return "naive" + + def needs_all2all(self): + return self.prepare_finalize_type in [ + PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize, + DeepEPLLPrepareAndFinalize + ] + + def is_valid(self): + # Check prepare-finalize and fused-experts compatibility + if self.is_batched_prepare_finalize(): + if not self.is_batched_fused_experts(): + return False + else: + if not self.is_standard_fused_experts(): + return False + + use_chunking = self.fused_moe_chunk_size is not None + if use_chunking and not self.is_fe_supports_chunking(): + return False + + # Check quantization sanity + if (int(self.is_per_act_token_quant) + + int(self.is_per_tensor_act_quant) + + int(self.quant_block_shape is not None)) > 1: + # invalid quant config + return False + + # check bf16 / fp16 support + is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None) + if is_16bit and not self.is_fe_16bit_supported(): + return False + + # Check fp8 support + is_fp8 = self.quant_dtype == torch.float8_e4m3fn + if is_fp8 and not 
self.is_fe_fp8_supported(): + return False + + # Check fp8 block quanization support + is_block_quatized = self.quant_block_shape is not None + if is_block_quatized and not is_fp8: + return False + if is_block_quatized and not self.is_fe_block_fp8_supported(): + return False + + # deep_gemm only works with block-quantized + if self.needs_deep_gemm() and not is_block_quatized: + return False + + # Check dependencies + if self.needs_deep_ep() and not has_deep_ep(): + return False + if self.needs_deep_gemm() and not has_deep_gemm(): + return False + if self.needs_pplx() and not has_pplx(): # noqa: SIM103 + return False + + return True + + +@dataclass +class WeightTensors: + w1: torch.Tensor + w2: torch.Tensor + w1_scale: Optional[torch.Tensor] + w2_scale: Optional[torch.Tensor] + + def describe(self): + s = "" + s += "== Weight Tensors: \n" + s += f' - {_describe_tensor(self.w1, "w1")} \n' + s += f' - {_describe_tensor(self.w2, "w2")} \n' + s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n' + s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n' + return s + + def to_current_device(self): + self.w1 = self.w1.to(device=torch.cuda.current_device()) + self.w2 = self.w2.to(device=torch.cuda.current_device()) + is_quantized = self.w1.dtype == torch.float8_e4m3fn + if is_quantized: + assert self.w1_scale is not None + assert self.w2_scale is not None + self.w1_scale = self.w1_scale.to( + device=torch.cuda.current_device()) + self.w2_scale = self.w2_scale.to( + device=torch.cuda.current_device()) + + def slice_weights(self, rank: int, + num_local_experts: int) -> "WeightTensors": + s = rank * num_local_experts + e = s + num_local_experts + w1 = self.w1[s:e, :, :] + w2 = self.w2[s:e, :, :] + is_quantized = self.w1.dtype == torch.float8_e4m3fn + w1_scale, w2_scale = (None, None) + if is_quantized: + assert self.w1_scale is not None + assert self.w2_scale is not None + w1_scale = self.w1_scale[s:e, :, :] + w2_scale = self.w2_scale[s:e, :, :] + return WeightTensors(w1, w2, w1_scale, w2_scale) + + @staticmethod + def make(config: Config) -> "WeightTensors": + + if config.quant_dtype is None: + # just make normal dtype weights + w1, w2 = make_non_quant_weights(e=config.E, + n=config.N, + k=config.K, + dtype=config.dtype) + return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None) + + assert config.quant_dtype == torch.float8_e4m3fn + if not config.is_fp8_block_quantized(): + w1, w2, w1_scale, w2_scale = make_quant_fp8_weights( + e=config.E, + n=config.N, + k=config.K, + per_out_channel_quant=config.is_per_out_ch_quant, + ) + return WeightTensors(w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale) + + assert config.quant_block_shape is not None + w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( + e=config.E, + n=config.N, + k=config.K, + block_size=config.quant_block_shape, + ) + return WeightTensors(w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale) + + +@dataclass +class RankTensors: + hidden_states: torch.Tensor + hidden_states_scale: Optional[torch.Tensor] + + topk_weights: torch.Tensor + topk_ids: torch.Tensor + expert_map: Optional[torch.Tensor] + + quant_config: Optional[FusedMoEQuantConfig] + + def describe(self): + s = "" + s += "== Rank Tensors: \n" + s += f' - {_describe_tensor(self.hidden_states, "HS")} \n' + s += f' - {_describe_tensor(self.hidden_states_scale, "HS_scale")} \n' + s += f' - {_describe_tensor(self.topk_weights, "topk_weights")} \n' + s += f' - {_describe_tensor(self.topk_ids, "topk_ids")} \n' + s += f' - 
{_describe_tensor(self.expert_map, "expert_map")} \n' + return s + + @staticmethod + def make_hidden_states( + config: Config) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Return hidden_states + """ + m, k, dtype = (config.M, config.K, config.dtype) + a = (torch.randn( + (m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0) + + if config.quant_dtype is None: + return a, None + + # We dequant and use that as hidden_states so the tests are stable. + # quantizing and dequantizing yield slightly different results + # depending on the hardware. Here we, quantize and dequantize + # first - so further quantize and dequantize will yeild the same + # values. + if config.is_per_tensor_act_quant: + a_q, a_scales = ops.scaled_fp8_quant( + a, use_per_token_if_dynamic=False) + return a_q.float().mul(a_scales).to(dtype), a_scales + + if config.is_per_act_token_quant: + a_q, a_scales = ops.scaled_fp8_quant(a, + use_per_token_if_dynamic=True) + return a_q.float().mul(a_scales).to(dtype), None + + assert config.quant_block_shape is not None + block_k = config.quant_block_shape[1] + a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k) + return a_q.float().view( + (-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(dtype), None + + @staticmethod + def make(config: Config, pgi: ProcessGroupInfo): + + dtype = config.dtype + topk, m, _ = (config.topk, config.M, config.K) + hidden_states, hidden_states_scale = RankTensors.make_hidden_states( + config) + + num_local_experts, global_num_experts = (config.num_local_experts, + config.E) + score = torch.randn((m, global_num_experts), + device="cuda", + dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, + False) + topk_ids = topk_ids.to(config.topk_ids_dtype) + + # distribute topk_ids evenly + for mi in range(m): + topk_ids[mi] = torch.randperm(config.E)[:topk] + topk_ids = topk_ids.to(device=torch.cuda.current_device()) + + expert_map = None + if config.world_size > 1: + expert_map = torch.full((global_num_experts, ), + fill_value=-1, + dtype=torch.int32) + s = pgi.rank * num_local_experts + e = s + num_local_experts + expert_map[s:e] = torch.tensor(list(range(num_local_experts))) + expert_map = expert_map.to(device=torch.cuda.current_device(), + dtype=torch.int32) + + return RankTensors( + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map, + quant_config=config.quant_config, + ) + + +def reference_moe_impl(config: Config, weights: WeightTensors, + rank_tensors: RankTensors) -> torch.Tensor: + + return torch_experts(a=rank_tensors.hidden_states, + w1=weights.w1, + w2=weights.w2, + topk_weight=rank_tensors.topk_weights, + topk_ids=rank_tensors.topk_ids, + global_num_experts=config.E, + expert_map=None, + w1_scale=weights.w1_scale, + w2_scale=weights.w2_scale, + a1_scale=rank_tensors.hidden_states_scale, + quant_dtype=config.quant_dtype, + per_act_token_quant=config.is_per_act_token_quant, + block_shape=config.quant_block_shape, + apply_router_weights_on_input=config.topk == 1) + + +def make_fused_experts( + config: Config, moe: FusedMoEConfig, + num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute: + + use_fp8 = config.quant_dtype == torch.float8_e4m3fn + batch_kwargs = { + "max_num_tokens": moe.max_num_tokens, + "num_dispatchers": num_dispatchers, + } + quant_kwargs = { + "use_fp8_w8a8": use_fp8, + "use_int8_w8a8": False, + "use_int8_w8a16": False, + "use_int4_w4a16": False, + "block_shape": 
config.quant_block_shape, + "per_act_token_quant": config.is_per_act_token_quant, + } + deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()} + + if config.fused_experts_type == BatchedDeepGemmExperts: + kwargs = batch_kwargs | { + "block_shape": config.quant_block_shape, + "per_act_token_quant": config.is_per_act_token_quant, + } + print(f"Making BatchedDeepGemmExperts {kwargs} ...") + experts = BatchedDeepGemmExperts(**kwargs) + elif config.fused_experts_type == BatchedTritonExperts: + kwargs = batch_kwargs | quant_kwargs + print(f"Making BatchedTritonExperts {kwargs} ...") + experts = BatchedTritonExperts(**kwargs) + elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts: + kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs + print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") + experts = BatchedTritonOrDeepGemmExperts(**kwargs) + elif config.fused_experts_type == DeepGemmExperts: + print("Making DeepGemmExperts () ...") + experts = DeepGemmExperts() + elif config.fused_experts_type == TritonExperts: + kwargs = quant_kwargs + print(f"Making TritonExperts {kwargs} ...") + experts = TritonExperts(**kwargs) + elif config.fused_experts_type == TritonOrDeepGemmExperts: + kwargs = quant_kwargs | deepgemm_kwargs + print(f"Making TritonOrDeepGemmExperts {kwargs} ...") + experts = TritonOrDeepGemmExperts(**kwargs) + elif config.fused_experts_type == NaiveBatchedExperts: + kwargs = batch_kwargs | quant_kwargs + print(f"Making NaiveBatchedExperts {kwargs} ...") + experts = NaiveBatchedExperts(**kwargs) + elif config.fused_experts_type == CutlassExpertsFp8: + use_batched_format = config.is_batched_prepare_finalize() + num_experts = (moe.num_local_experts + if use_batched_format else moe.num_experts) + kwargs = { + "max_experts_per_worker": num_experts, + "out_dtype": moe.in_dtype, + "per_act_token_quant": config.is_per_act_token_quant, + "per_out_ch_quant": config.is_per_out_ch_quant, + "block_shape": config.quant_block_shape, + "num_dispatchers": num_dispatchers, + "use_batched_format": use_batched_format + } + print(f"Making CutlassExpertsFp8 {kwargs} ...") + experts = CutlassExpertsFp8(**kwargs) + + return experts + + +def make_modular_kernel(config: Config, + vllm_config: VllmConfig) -> mk.FusedMoEModularKernel: + + def next_power_of_2(x): + import math + if x == 0: + return 1 + return 2**math.ceil(math.log2(x)) + + # make moe config + moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( + tp_size_=get_tensor_model_parallel_world_size(), + dp_size_=get_dp_group().world_size, + vllm_parallel_config=vllm_config.parallel_config, + ) + moe = FusedMoEConfig( + num_experts=config.E, + experts_per_token=config.topk, + hidden_dim=config.K, + num_local_experts=config.num_local_experts, + moe_parallel_config=moe_parallel_config, + in_dtype=config.dtype, + quant_config=config.quant_config, + max_num_tokens=next_power_of_2(config.M), + ) + + # make modular kernel + prepare_finalize = None + if config.needs_all2all(): + prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe) + assert prepare_finalize is not None + else: + prepare_finalize = MoEPrepareAndFinalizeNoEP() + + fused_experts = make_fused_experts(config, moe, + prepare_finalize.num_dispatchers()) + + modular_kernel = mk.FusedMoEModularKernel( + prepare_finalize=prepare_finalize, fused_experts=fused_experts) + + return modular_kernel + + +def run_modular_kernel( + pgi: ProcessGroupInfo, + vllm_config: VllmConfig, + config: Config, + weights: WeightTensors, + rank_tensors: RankTensors, +) 
-> torch.Tensor: + assert isinstance(config.Ms, int) + assert isinstance(config.topks, int) + + # weights for rank + rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) + + mk = make_modular_kernel(config, vllm_config) + + mk_kwargs = { + "hidden_states": rank_tensors.hidden_states.clone( + ), # impls might update the tensor in place + "w1": rank_weights.w1, + "w2": rank_weights.w2, + "topk_weights": rank_tensors.topk_weights, + "topk_ids": rank_tensors.topk_ids, + "expert_map": rank_tensors.expert_map, + "w1_scale": rank_weights.w1_scale, + "w2_scale": rank_weights.w2_scale, + "a1_scale": rank_tensors.hidden_states_scale, + "global_num_experts": config.E, + "apply_router_weight_on_input": config.topk == 1, + } + out = mk.forward(**mk_kwargs) + + return out diff --git a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py new file mode 100644 index 000000000000..5dbfdfc153f9 --- /dev/null +++ b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +from enum import Enum +from itertools import product +from typing import Optional + +import torch +from tqdm import tqdm + +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.platforms import current_platform + +from .common import (Config, RankTensors, WeightTensors, reference_moe_impl, + run_modular_kernel) +from .mk_objects import (MK_FUSED_EXPERT_TYPES, + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_QUANT_CONFIGS) +from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config + + +class Result(Enum): + PASS = 1 + FAIL = 2 + SKIP = 3 + + +def rank_worker( + pgi: ProcessGroupInfo, + vllm_config: VllmConfig, + cpu_group, + config: Config, + weights: WeightTensors, +): + current_platform.seed_everything(pgi.rank) + + # sanity check + from vllm import envs + if config.fused_moe_chunk_size is not None: + assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + + # get weights to this device + weights.to_current_device() + + Ms = config.Ms + assert isinstance(Ms, list) + TOPKs = config.topks + assert isinstance(TOPKs, list) + + for m, topk in product(Ms, TOPKs): + print(f"Running m={m}, topk={topk} ...") + # override m and topk + cfgx = copy.deepcopy(config) + cfgx.Ms = m + cfgx.topks = topk + + # inputs for rank + rank_tensors = RankTensors.make(cfgx, pgi) + + # modular kernel out + mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, + rank_tensors) + + with set_current_vllm_config(vllm_config): + ref_out = reference_moe_impl(cfgx, weights, rank_tensors) + + torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2) + + +def make_feature_matrix(csv_file_path: str): + + from dataclasses import asdict + + import pandas as pd + + def add_to_results(config: Config, + success: Result, + results_df: Optional[pd.DataFrame] = None): + config_dict = asdict(config) + config_dict['prepare_finalize_type'] = config_dict[ + 'prepare_finalize_type'].__name__ + config_dict['fused_experts_type'] = config_dict[ + 'fused_experts_type'].__name__ + config_dict['per_tensor_act_quant'] = config.is_per_tensor_act_quant + quant_config_dict = config_dict['quant_config'] + del config_dict['quant_config'] + if quant_config_dict is None: + quant_config = FusedMoEQuantConfig(None) + quant_config_dict = 
asdict(quant_config) + + config_dict |= quant_config_dict + result_dict = config_dict | {'success': success.name} + + result_df = pd.DataFrame([result_dict]) + if results_df is None: + results_df = result_df + else: + results_df = pd.concat([results_df, result_df], ignore_index=True) + + return results_df + + Ms = [64] + Ks = [7168] # hidden sizes + Ns = [2048] + TOPKs = [[4, 1]] + Es = [32] + DTYPEs = [torch.bfloat16] + PF_TYPES = MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + FE_TYPES = MK_FUSED_EXPERT_TYPES + Q_TYPES = MK_QUANT_CONFIGS + + combinations = list( + product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES)) + + results_df: Optional[pd.DataFrame] = None + for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm( + combinations): #noqa: E501 + config = Config(Ms=[m], + K=k, + N=n, + E=e, + topks=topks, + dtype=dtype, + prepare_finalize_type=pf_type, + fused_experts_type=experts_type, + quant_config=quant_config, + world_size=2, + fused_moe_chunk_size=None) + + success = None + if config.is_valid(): + print(f"Running config : {config.describe()} ...") + try: + weights: WeightTensors = WeightTensors.make(config) + vllm_config, env_dict = config.make_env_data() + parallel_launch_with_config(config.world_size, rank_worker, + vllm_config, env_dict, config, + weights) + success = Result.PASS + except Exception as _: + success = Result.FAIL + else: + success = Result.SKIP + + results_df = add_to_results(config, success, results_df) + + if results_df is not None: + results_df.to_csv(f"{csv_file_path}") + + +if __name__ == '__main__': + import argparse + from pathlib import Path + parser = argparse.ArgumentParser(description=( + "Make ModularKernel feature matrix \n" + "Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " #noqa: E501 + "-f ./feature_matrices/feature_matrix.csv")) + + parser.add_argument("-f", + "--feature-matrix-csv-file-path", + type=str, + required=True, + help="File name to Generate a .csv file") + args = parser.parse_args() + + csv_path = args.feature_matrix_csv_file_path + assert csv_path.endswith( + 'csv'), f"Need a file path ending with .csv, got {csv_path}" + assert Path(csv_path).parent.is_dir( + ), f"Cannot find parent directory for {Path(csv_path).parent}" + + make_feature_matrix(args.feature_matrix_csv_file_path) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py new file mode 100644 index 000000000000..73214066f7ea --- /dev/null +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +# Fused experts and PrepareFinalize imports +from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 + BatchedTritonOrDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts, NaiveBatchedExperts) +from vllm.model_executor.layers.fused_moe.layer import TritonExperts +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +from 
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) +from vllm.utils import has_deep_ep, has_pplx + +if has_deep_ep(): + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 + DeepEPHTPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 + DeepEPLLPrepareAndFinalize) + +if has_pplx(): + from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( + PplxPrepareAndFinalize) + +MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = [] +if has_pplx(): + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize] +if has_deep_ep(): + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [ + DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize + ] + +MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP] + +MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + + MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) + +MK_FUSED_EXPERT_TYPES = [ + BatchedDeepGemmExperts, + BatchedTritonExperts, + NaiveBatchedExperts, + BatchedTritonOrDeepGemmExperts, + CutlassExpertsFp8, + DeepGemmExperts, + TritonOrDeepGemmExperts, + TritonExperts, +] + +MK_QUANT_CONFIGS = [ + None, + # per-channel / per-column weights and per-tensor activations + FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=True, + per_act_token_quant=False, + block_shape=None), + # per-channel / per-column weights and per-token activations + FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=True, + per_act_token_quant=True, + block_shape=None), + # per-tensor weights and per-tensor activations + FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=None), + # per-tensor weights and per-token activations + FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=True, + block_shape=None), + # block-quantized weights and 128 block per-token activations + FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=[128, 128]), + # TODO (varun) : Should we test the following combinations ? 
+ # block-quantized weights and per-token activations + # block-quantized weights and per-tensor activations +] diff --git a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py new file mode 100644 index 000000000000..1f8d21a7a702 --- /dev/null +++ b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import dataclasses +import os +import traceback +from typing import Any, Callable, Optional + +import torch +from torch.multiprocessing import ( + spawn) # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import Concatenate, ParamSpec + +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.distributed import (init_distributed_environment, + initialize_model_parallel) +from vllm.utils import get_open_port + +## Parallel Processes Utils + +P = ParamSpec("P") + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int, + local_rank: int): + + import tempfile + temp_file = tempfile.mkstemp()[1] + + set_current_vllm_config(vllm_config) + with set_current_vllm_config(vllm_config): + init_distributed_environment( + world_size=world_size, + rank=rank, + distributed_init_method=f"file://{temp_file}", + local_rank=local_rank, + backend="nccl", + ) + + initialize_model_parallel( + tensor_model_parallel_size=vllm_config.parallel_config. + tensor_parallel_size, + pipeline_model_parallel_size=vllm_config.parallel_config. + pipeline_parallel_size, + ) + cpu_group = torch.distributed.new_group(list(range(world_size)), + backend="gloo") + return cpu_group + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, + P], None], + vllm_config: Optional[VllmConfig], + env_dict: Optional[dict], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + if env_dict is not None: + os.environ.update(env_dict) + + cpu_group = None + if vllm_config is not None: + cpu_group = _set_vllm_config(vllm_config, world_size, rank, local_rank) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + vllm_config, + cpu_group, + *args, + **kwargs, + ) + except Exception as ex: + print(ex) + traceback.print_exc() + raise + finally: + torch.distributed.destroy_process_group() + + +def parallel_launch_with_config( + world_size: int, + worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig, Any, P], None], + vllm_config: VllmConfig, + env_dict: dict[Any, Any], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + assert not kwargs + spawn( + _worker_parallel_launch, + args=( + world_size, + world_size, + 0, + f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}", + 
worker, + vllm_config, + env_dict, + ) + args, + nprocs=world_size, + join=True, + ) diff --git a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py new file mode 100644 index 000000000000..dd16ffb2eabe --- /dev/null +++ b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +from itertools import product +from typing import Any, Callable + +import torch + +from vllm.config import VllmConfig +from vllm.platforms import current_platform + +from .common import Config, RankTensors, WeightTensors, make_modular_kernel +from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config + + +def do_profile(fn: Callable, + fn_kwargs: dict[Any, Any], + pgi: ProcessGroupInfo, + config: Config, + num_warmups: int = 5): + for _ in range(num_warmups): + fn(**fn_kwargs) + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + record_shapes=True, + ) as tprof: + fn(**fn_kwargs) + torch.cuda.synchronize(torch.cuda.current_device()) + + # TODO (varun): Add a descriptive trace file name + tprof.export_chrome_trace( + f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json") + + +def profile_modular_kernel( + pgi: ProcessGroupInfo, + vllm_config: VllmConfig, + config: Config, + weights: WeightTensors, + rank_tensors: RankTensors, +) -> None: + assert isinstance(config.Ms, int) + assert isinstance(config.topks, int) + + # weights for rank + rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) + + # make modular kernel + mk = make_modular_kernel(config, vllm_config) + + mk_kwargs = { + "hidden_states": rank_tensors.hidden_states, + "w1": rank_weights.w1, + "w2": rank_weights.w2, + "topk_weights": rank_tensors.topk_weights, + "topk_ids": rank_tensors.topk_ids, + "expert_map": rank_tensors.expert_map, + "w1_scale": rank_weights.w1_scale, + "w2_scale": rank_weights.w2_scale, + "a1_scale": rank_tensors.hidden_states_scale, + "global_num_experts": config.E, + "apply_router_weight_on_input": config.topk == 1, + } + + do_profile(mk.forward, mk_kwargs, pgi, config) + + +def rank_worker( + pgi: ProcessGroupInfo, + vllm_config: VllmConfig, + cpu_group, + config: Config, + weights: WeightTensors, +): + current_platform.seed_everything(pgi.rank) + + # sanity check + from vllm import envs + if config.fused_moe_chunk_size is not None: + assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + + # get weights to this device + weights.to_current_device() + + Ms = config.Ms + assert isinstance(Ms, list) + TOPKs = config.topks + assert isinstance(TOPKs, list) + + for m, topk in product(Ms, TOPKs): + print(f"Running m={m}, topk={topk} ...") + # override m and topk + cfgx = copy.deepcopy(config) + cfgx.Ms = m + cfgx.topks = topk + + # inputs for rank + rank_tensors = RankTensors.make(cfgx, pgi) + profile_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors) + + +def run(config: Config): + weights: WeightTensors = WeightTensors.make(config) + vllm_config, env_dict = config.make_env_data() + parallel_launch_with_config(config.world_size, rank_worker, vllm_config, + env_dict, config, weights) + + +if __name__ == '__main__': + from .cli_args import make_config, make_config_arg_parser + parser = make_config_arg_parser(description=( + "Run single 
prepare-finalize & fused-experts combination test" + "Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " #noqa: E501 + "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" + )) + args = parser.parse_args() + assert args.torch_trace_dir_path is not None, ( + "Please pass in a directory to store torch traces") + config = make_config(args) + + run(config) diff --git a/tests/kernels/moe/modular_kernel_tools/utils.py b/tests/kernels/moe/modular_kernel_tools/utils.py new file mode 100644 index 000000000000..09bb4a34f318 --- /dev/null +++ b/tests/kernels/moe/modular_kernel_tools/utils.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math + +import torch + +import vllm._custom_ops as ops + + +def per_token_cast_to_fp8( + x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + pad_size = (block_size - (n % block_size)) % block_size + x = torch.nn.functional.pad(x, + (0, pad_size), value=0) if pad_size > 0 else x + x_view = x.view(m, -1, block_size) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) + return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) + + +def per_block_cast_to_fp8( + x: torch.Tensor, block_size_k: int, + block_size_n: int) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + ( + int(math.ceil(m / block_size_k)) * block_size_k, + int(math.ceil(n / block_size_n)) * block_size_n, + ), + dtype=x.dtype, + device=x.device, + ) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, block_size_k, + x_padded.size(1) // block_size_k, block_size_n) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() + scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + return x_scaled_sub, scales + + +def make_non_quant_weights( + e: int, + n: int, + k: int, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Return weights w1, w2 + """ + device = torch.cuda.current_device() + w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15 + w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15 + return w1, w2 + + +def make_block_quant_fp8_weights( + e: int, + n: int, + k: int, + block_size: list[int], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Return weights w1, w2, w1_scale, w2_scale + """ + dtype = torch.bfloat16 + device = torch.cuda.current_device() + + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype) + w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype) + w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype) + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = ((2 * n) + block_n - 1) // block_n + k_tiles_w1 = (k + block_k - 1) // block_k + n_tiles_w2 = (k + block_n - 1) // block_n + k_tiles_w2 = (n + block_k - 1) // block_k + + w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device) + w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device) + + w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1), + device=device, + dtype=torch.float32) + w2_s = 
torch.empty((e, n_tiles_w2, k_tiles_w2), + device=device, + dtype=torch.float32) + + assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n, + (k + (block_k - 1)) // block_k) + assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] + + for i in range(e): + w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i], + block_size_k=block_k, + block_size_n=block_n) + w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i], + block_size_k=block_k, + block_size_n=block_n) + + return w1, w2, w1_s, w2_s + + +def make_quant_fp8_weights( + e: int, + n: int, + k: int, + per_out_channel_quant: bool, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Return w1, w2, w1_scale, w2_scale + """ + q_dtype = torch.float8_e4m3fn + + w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16) + + # w1 -> w1_q, w2 -> w2_q + w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) + w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) + + n_b_scales = 2 * n if per_out_channel_quant else 1 + k_b_scales = k if per_out_channel_quant else 1 + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=per_out_channel_quant) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=per_out_channel_quant) + return w1_q, w2_q, w1_scale, w2_scale diff --git a/tests/kernels/moe/parallel_utils.py b/tests/kernels/moe/parallel_utils.py index f4049eb0d095..1ad361ae0733 100644 --- a/tests/kernels/moe/parallel_utils.py +++ b/tests/kernels/moe/parallel_utils.py @@ -4,7 +4,6 @@ DeepEP test utilities """ import dataclasses -import importlib import os import traceback from typing import Callable, Optional @@ -15,10 +14,9 @@ spawn) # pyright: ignore[reportPrivateImportUsage] from typing_extensions import Concatenate, ParamSpec -from vllm.utils import get_open_port +from vllm.utils import get_open_port, has_deep_ep -has_deep_ep = importlib.util.find_spec("deep_ep") is not None -if has_deep_ep: +if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 DeepEPHTPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py new file mode 100644 index 000000000000..6f2869c3a61d --- /dev/null +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +from itertools import product +from typing import Optional + +import pytest +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.config import VllmConfig, current_platform, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 + BatchedTritonOrDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts) +from vllm.model_executor.layers.fused_moe.layer import TritonExperts +from 
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) +from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx + +from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, + reference_moe_impl, + run_modular_kernel) +from .modular_kernel_tools.mk_objects import ( + MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, + MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) +from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo, + parallel_launch_with_config) + +# TODO (varun): These requirements are very strict and could be relaxed. +has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx()) + +meets_package_requirements = pytest.mark.skipif( + not has_all_packages, + reason="Requires deep_ep & deep_gemm & pplx packages", +) + + +def rank_worker( + pgi: ProcessGroupInfo, + vllm_config: VllmConfig, + cpu_group, + config: Config, + weights: WeightTensors, +): + current_platform.seed_everything(pgi.rank) + + # sanity check + from vllm import envs + if config.fused_moe_chunk_size is not None: + assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + + # get weights to this device + weights.to_current_device() + + Ms = config.Ms + assert isinstance(Ms, list) + TOPKs = config.topks + assert isinstance(TOPKs, list) + + for m, topk in product(Ms, TOPKs): + print(f"Running m={m}, topk={topk} ...") + # override m and topk + cfgx = copy.deepcopy(config) + cfgx.Ms = m + cfgx.topks = topk + + # inputs for rank + rank_tensors = RankTensors.make(cfgx, pgi) + + # modular kernel out + mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, + rank_tensors) + + with set_current_vllm_config(vllm_config): + ref_out = reference_moe_impl(cfgx, weights, rank_tensors) + + torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2) + + +def run(config: Config): + assert config.is_valid() + print(f"Testing config \n{config.describe()} ...") + + weights: WeightTensors = WeightTensors.make(config) + + vllm_config, env_dict = config.make_env_data() + parallel_launch_with_config(config.world_size, rank_worker, vllm_config, + env_dict, config, weights) + + +Ms = [32, 64] +Ks = [7168] # hidden sizes +Ns = [2048] +TOPKs = [4, 1] +Es = [32] +DTYPEs = [torch.bfloat16] +FUSED_MOE_CHUNK_SIZEs = [None, 16] + + +def is_nyi_config(config: Config) -> bool: + # We know these configs to be legitimate. but still fail. + + if (config.fused_experts_type in [ + BatchedTritonExperts, BatchedTritonOrDeepGemmExperts, + TritonExperts, TritonOrDeepGemmExperts + ]): + # The triton kernels expect both per-act-token-quant and + # per-out-ch-quant or neither. + unsupported_quant_config = ((config.is_per_act_token_quant + + config.is_per_out_ch_quant) == 1) + return unsupported_quant_config + + # cutlass kernels dont support expert_maps yet. 
+ return config.fused_experts_type == CutlassExpertsFp8 + + +@pytest.mark.parametrize("k", Ks) +@pytest.mark.parametrize("n", Ns) +@pytest.mark.parametrize("e", Es) +@pytest.mark.parametrize("dtype", DTYPEs) +@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) +@pytest.mark.parametrize( + "combination", + product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) +@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) +@pytest.mark.parametrize("world_size", [2]) +@meets_package_requirements +def test_modular_kernel_combinations_multigpu( + k: int, n: int, e: int, dtype: torch.dtype, + quant_config: FusedMoEQuantConfig, + combination: tuple[mk.FusedMoEPrepareAndFinalize, + mk.FusedMoEPermuteExpertsUnpermute], + fused_moe_chunk_size: Optional[int], world_size: int): + + config = Config( + Ms=Ms, + K=k, + N=n, + E=e, + topks=TOPKs, + dtype=dtype, + quant_config=quant_config, + prepare_finalize_type=combination[0], + fused_experts_type=combination[1], + fused_moe_chunk_size=fused_moe_chunk_size, + world_size=world_size, + ) + if not config.is_valid(): + pytest.skip(f"Tests config {config} is not valid. Skipping ...") + + if is_nyi_config(config): + pytest.skip(f"Tests config {config} is nyi. Skipping ...") + + print(f"{config.describe()}") + run(config) + + +@pytest.mark.parametrize("k", Ks) +@pytest.mark.parametrize("n", Ns) +@pytest.mark.parametrize("e", Es) +@pytest.mark.parametrize("dtype", DTYPEs) +@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) +@pytest.mark.parametrize( + "combination", + product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) +@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) +@pytest.mark.parametrize("world_size", [1]) +@meets_package_requirements +def test_modular_kernel_combinations_singlegpu( + k: int, n: int, e: int, dtype: torch.dtype, + quant_config: FusedMoEQuantConfig, + combination: tuple[mk.FusedMoEPrepareAndFinalize, + mk.FusedMoEPermuteExpertsUnpermute], + fused_moe_chunk_size: Optional[int], world_size: int): + config = Config( + Ms=Ms, + K=k, + N=n, + E=e, + topks=TOPKs, + dtype=dtype, + quant_config=quant_config, + prepare_finalize_type=combination[0], + fused_experts_type=combination[1], + fused_moe_chunk_size=fused_moe_chunk_size, + world_size=world_size, + ) + + if not config.is_valid(): + pytest.skip(f"Tests config {config} is not valid. Skipping ...") + + if is_nyi_config(config): + pytest.skip(f"Tests config {config} is nyi. 
Skipping ...") + + run(config) + + +if __name__ == '__main__': + # Ability to test individual PrepareAndFinalize and FusedExperts combination + from .modular_kernel_tools.cli_args import (make_config, + make_config_arg_parser) + parser = make_config_arg_parser(description=( + "Run single prepare-finalize & fused-experts combination test" + "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " #noqa: E501 + "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" + )) + args = parser.parse_args() + config = make_config(args) + + run(config) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index fcaa93762856..2e8febbdcf26 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1072,6 +1072,7 @@ def torch_experts( quant_dtype: Optional[torch.dtype] = None, per_act_token_quant=False, block_shape: Optional[list[int]] = None, + apply_router_weights_on_input: bool = False, ) -> torch.Tensor: assert (global_num_experts == -1 or (global_num_experts == w1.shape[0] and expert_map is None) @@ -1081,11 +1082,17 @@ def torch_experts( M, K = a.shape topk = topk_ids.shape[1] + if apply_router_weights_on_input: + assert topk == 1 + a = a * topk_weight.to(a.dtype) + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) - a, a_scale = moe_kernel_quantize_input(a, None, quant_dtype, + if a1_scale: + assert not per_act_token_quant and block_shape is None + a, a_scale = moe_kernel_quantize_input(a, a1_scale, quant_dtype, per_act_token_quant, block_shape) num_experts = w1.shape[0] @@ -1104,6 +1111,7 @@ def torch_experts( tmp2 = SiluAndMul()(tmp1) out[mask] = tmp2 @ w2[i].transpose(0, 1) elif block_shape is not None: + # block quantized assert (a_scale is not None and w1_scale is not None and w2_scale is not None) tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], @@ -1121,15 +1129,27 @@ def torch_experts( assert (a_scale is not None and w1_scale is not None and w2_scale is not None) scales = a_scale if a_scale.numel() == 1 else a_scale[mask] + tmp1 = a[mask].to(f32) * scales w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1) - tmp1 = tmp1 @ w1_dq - tmp2 = SiluAndMul()(tmp1) + tmp1 = (tmp1 @ w1_dq).to(out.dtype) + + tmp2 = SiluAndMul()(tmp1).to(out.dtype) + + tmp2, b_scale = moe_kernel_quantize_input( + tmp2, a2_scale, quant_dtype, per_act_token_quant, + block_shape) + assert b_scale is not None + + tmp2 = tmp2.to(f32) * b_scale w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1) out[mask] = (tmp2 @ w2_dq).to(out.dtype) - return (out.view(M, -1, w2.shape[1]).to(f32) * - topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype) + if apply_router_weights_on_input: + return out + else: + return (out.view(M, -1, w2.shape[1]).to(f32) * + topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype) def torch_moe(a: torch.Tensor, diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 1bc2d8e0281c..eb467bb0736a 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -240,8 +240,7 @@ def prepare_communication_buffer_for_model(self, if module.__class__.__name__ == "FusedMoE" ] for module in moe_modules: - module.quant_method.init_prepare_finalize(module.moe_config, - module.quant_config) + module.quant_method.init_prepare_finalize(module.moe_config) def dispatch( self, hidden_states: 
torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 66abd8d7db7b..41faced58f1a 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -37,7 +37,6 @@ def __init__(self, block_shape=block_shape, per_act_token_quant=per_act_token_quant, )) - self.allow_deep_gemm = allow_deep_gemm self.batched_triton_experts = BatchedTritonExperts( max_num_tokens=max_num_tokens, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 36ac75a8df4b..1d4c75228856 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -81,13 +81,12 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError - def init_prepare_finalize(self, moe: FusedMoEConfig, - quant_config: Optional[QuantizationConfig]): + @staticmethod + def maybe_make_prepare_finalize( + moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]: all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None - self.moe = moe - prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None if moe.use_pplx_kernels: @@ -160,8 +159,6 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, and moe.quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE) - # Note (varun): Whether to use FP8 dispatch or not needs some - # profiling. Turning it off for now. prepare_finalize = DeepEPLLPrepareAndFinalize( handle, max_tokens_per_rank=moe.max_num_tokens, @@ -169,11 +166,18 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, use_fp8_dispatch=use_fp8_dispatch, ) + return prepare_finalize + + def init_prepare_finalize(self, moe: FusedMoEConfig): + self.moe = moe + prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize( + self.moe) + self.topk_indices_dtype = None if prepare_finalize is not None: logger.debug("%s", prepare_finalize.__class__.__name__) self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() - experts = self.select_gemm_impl(prepare_finalize, moe) + experts = self.select_gemm_impl(prepare_finalize, self.moe) self.fused_experts = FusedMoEModularKernel( prepare_finalize, experts, diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 891ffd1c79b4..7188dc8707f7 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -7,7 +7,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) + DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape, + deep_gemm_block_shape) from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts @@ -43,8 +44,10 @@ def __init__( per_act_token_quant=per_act_token_quant, block_shape=block_shape, ) - self.allow_deep_gemm = (allow_deep_gemm and not per_act_token_quant - and use_fp8_w8a8) + + self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and + self.block_shape == deep_gemm_block_shape()) + self.deep_gemm_expert = DeepGemmExperts( ) if 
self.allow_deep_gemm else None
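
Usage note (not part of the diff): a minimal sketch of how the tooling added above can be driven programmatically, mirroring what run() and the __main__ block in tests/kernels/moe/test_modular_kernel_combinations.py already do (the same path the CLI entry points such as "python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type ... --experts-type ..." exercise). All class and function names come from the files added in this patch; the concrete sizes (M, K, N, E, topk) are illustrative values taken from the CLI defaults, and the sketch assumes it is launched from the repo root on a CUDA machine with the optional deep_ep / deep_gemm / pplx dependencies that the test module checks for.

    import torch

    from tests.kernels.moe.modular_kernel_tools.common import Config
    from tests.kernels.moe.test_modular_kernel_combinations import run
    from vllm.model_executor.layers.fused_moe.layer import TritonExperts
    from vllm.model_executor.layers.fused_moe.prepare_finalize import (
        MoEPrepareAndFinalizeNoEP)

    # Unquantized bf16 MoE on a single GPU: MoEPrepareAndFinalizeNoEP +
    # TritonExperts is one of the combinations exercised by
    # test_modular_kernel_combinations_singlegpu.
    config = Config(
        Ms=[64],                      # tokens per rank
        K=7168,                       # hidden size
        N=1024,                       # N of the first fused-moe matmul
        E=32,                         # global number of experts
        topks=[4],
        dtype=torch.bfloat16,
        quant_config=None,            # no quantization
        prepare_finalize_type=MoEPrepareAndFinalizeNoEP,
        fused_experts_type=TritonExperts,
        fused_moe_chunk_size=None,
        world_size=1,
    )

    assert config.is_valid()
    # Spawns the worker process, runs the modular kernel for every
    # (M, topk) combination and compares against reference_moe_impl.
    run(config)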