From d2d69fb0129883eea560093823f1a039de29d9fa Mon Sep 17 00:00:00 2001
From: Duncan Moss
Date: Thu, 10 Jul 2025 19:08:56 +0000
Subject: [PATCH 1/3] [fix]: disable cutlass block scaled group gemm for EP

Signed-off-by: Duncan Moss
---
 vllm/model_executor/layers/fused_moe/cutlass_moe.py | 11 +++++++++--
 vllm/model_executor/layers/fused_moe/fused_moe.py   |  2 +-
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index c8a8415baf23..4b3d716435e6 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -547,8 +547,9 @@ def cutlass_moe_fp4(a: torch.Tensor,
     return out.to(dtype=out_dtype)
 
 
-def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
-                                             w2: torch.Tensor) -> bool:
+def _valid_cutlass_block_scaled_grouped_gemm(
+        w1: torch.Tensor, w2: torch.Tensor,
+        expert_map: Optional[torch.Tensor]) -> bool:
 
     def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
         return N % 128 == 0 and K % 128 == 0
@@ -564,6 +565,12 @@ def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
             "CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s).")
         return False
 
+    if expert_map is not None:
+        logger.warning(
+            "CutlassBlockScaledGroupedGemm disabled: expert_parallel is"
+            " not supported.")
+        return False
+
     return True
 
 
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 26eeed1cd07f..36029c944409 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1190,7 +1190,7 @@ def fused_experts(
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
     elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
-          and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)):
+          and _valid_cutlass_block_scaled_grouped_gemm(w1, w2, expert_map)):
         assert apply_router_weight_on_input is False
         return run_cutlass_block_scaled_fused_experts(
             a=hidden_states,

From 0e4142add167157531fc45eeba522674560d086e Mon Sep 17 00:00:00 2001
From: Duncan Moss
Date: Thu, 10 Jul 2025 20:28:52 +0000
Subject: [PATCH 2/3] extra checks

Signed-off-by: Duncan Moss
---
 vllm/entrypoints/openai/serving_score.py            |  4 ++--
 .../layers/fused_moe/cutlass_moe.py                 | 22 ++++++++++++++++++--
 .../layers/fused_moe/fused_moe.py                   |  5 +++--
 3 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py
index b4fdbfcc7f60..8d47a417f9cd 100644
--- a/vllm/entrypoints/openai/serving_score.py
+++ b/vllm/entrypoints/openai/serving_score.py
@@ -216,8 +216,8 @@ async def _cross_encoding_score(
             # cross_encoder models defaults to using pad_token.
             tokenized_prompts = await asyncio.gather(*(
                 tokenize_async(
-                    text=t1, # type: ignore[arg-type]
-                    text_pair=t2, # type: ignore[arg-type]
+                    text=t1,  # type: ignore[arg-type]
+                    text_pair=t2,  # type: ignore[arg-type]
                     **tokenization_kwargs) for t1, t2 in input_pairs))
         else:
             # `llm as reranker` models defaults to not using pad_token.
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index 4b3d716435e6..836e0a64052a 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -548,7 +548,8 @@ def cutlass_moe_fp4(a: torch.Tensor,
 
 
 def _valid_cutlass_block_scaled_grouped_gemm(
-        w1: torch.Tensor, w2: torch.Tensor,
+        w1: torch.Tensor, w2: torch.Tensor, inplace: bool, activation: str,
+        apply_router_weight_on_input: bool,
         expert_map: Optional[torch.Tensor]) -> bool:
 
     def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
@@ -566,11 +567,28 @@ def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
         return False
 
     if expert_map is not None:
-        logger.warning(
+        logger.debug(
             "CutlassBlockScaledGroupedGemm disabled: expert_parallel is"
             " not supported.")
         return False
 
+    if activation != "silu":
+        logger.debug(
+            "CutlassBlockScaledGroupedGemm disabled: only activation silu is"
+            " supported.")
+        return False
+
+    if apply_router_weight_on_input:
+        logger.debug("CutlassBlockScaledGroupedGemm disabled:"
+                     " apply_router_weight_on_input is not supported.")
+        return False
+
+    if inplace:
+        logger.debug(
+            "CutlassBlockScaledGroupedGemm disabled: inplace is not supported."
+        )
+        return False
+
     return True
 
 
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 36029c944409..63496aac6290 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1190,8 +1190,9 @@ def fused_experts(
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
     elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
-          and _valid_cutlass_block_scaled_grouped_gemm(w1, w2, expert_map)):
-        assert apply_router_weight_on_input is False
+          and _valid_cutlass_block_scaled_grouped_gemm(
+              w1, w2, inplace, activation, apply_router_weight_on_input,
+              expert_map)):
         return run_cutlass_block_scaled_fused_experts(
             a=hidden_states,
             w1=w1,

From 05d9515764356f2aa4e74a8ef0c81dc71a853c94 Mon Sep 17 00:00:00 2001
From: Duncan Moss
Date: Thu, 10 Jul 2025 21:34:11 +0000
Subject: [PATCH 3/3] small opt

Signed-off-by: Duncan Moss
---
 .../cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu b/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu
index 236d76ed5208..6c8f6309ef43 100644
--- a/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu
+++ b/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu
@@ -201,11 +201,10 @@ void run_blockwise_scaled_group_mm(
       reinterpret_cast(
           layout_sfb.data_ptr())};
 
-  cutlass::KernelHardwareInfo hw_info;
-  hw_info.device_id = a_ptrs.get_device();
-  hw_info.sm_count =
-      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
-          hw_info.device_id);
+  int device_id = a_ptrs.device().index();
+  static const cutlass::KernelHardwareInfo hw_info{
+      device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
+                     device_id)};
 
   // Epilogue Arguments
   typename GemmKernel::EpilogueArguments epilogue_args{
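Note on the gating logic built up across patches 1 and 2: the series grows
_valid_cutlass_block_scaled_grouped_gemm from a shape/dtype check into a full
capability gate (no expert_map, silu activation only, no
apply_router_weight_on_input, no inplace), so fused_experts falls back to the
default path whenever any condition fails. Below is a minimal standalone sketch
of that gate, not part of the patch: the FP8 dtype (torch.float8_e4m3fn), the
assumed [num_experts, K, N] layout of w2, and all names are illustrative, not
the vLLM implementation.

    # Standalone sketch of the validity gate; assumptions are noted inline.
    from typing import Optional

    import torch


    def _valid_block_scaled_grouped_gemm_sketch(
            w1: torch.Tensor, w2: torch.Tensor, inplace: bool, activation: str,
            apply_router_weight_on_input: bool,
            expert_map: Optional[torch.Tensor]) -> bool:
        # Shape rule from the helper: N and K must both be multiples of 128.
        _, K, N = w2.shape  # assumed layout: [num_experts, K, N]
        if N % 128 != 0 or K % 128 != 0:
            return False
        # Weight dtype must be FP8 (assumed torch.float8_e4m3fn).
        if w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn:
            return False
        # Checks added by this series: no expert-parallel map, silu only,
        # no router weight applied on input, and no in-place output.
        if expert_map is not None:
            return False
        if activation != "silu":
            return False
        if apply_router_weight_on_input or inplace:
            return False
        return True


    if __name__ == "__main__":
        # Example: one call that passes every check, one rejected because an
        # expert_map is present (i.e. expert parallelism is in use).
        w1 = torch.empty(8, 512, 128, dtype=torch.float8_e4m3fn)
        w2 = torch.empty(8, 128, 256, dtype=torch.float8_e4m3fn)
        print(_valid_block_scaled_grouped_gemm_sketch(
            w1, w2, inplace=False, activation="silu",
            apply_router_weight_on_input=False, expert_map=None))  # True
        print(_valid_block_scaled_grouped_gemm_sketch(
            w1, w2, inplace=False, activation="silu",
            apply_router_weight_on_input=False,
            expert_map=torch.arange(8)))  # False: EP is not supported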