From 8d657d65d03556cfc80b33b70451cc189ccb5476 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Thu, 6 Mar 2025 21:50:10 +0000 Subject: [PATCH 01/11] Porting over MoE padding feature from ROCm/vllm Signed-off-by: Gregory Shtrasberg --- tests/kernels/test_moe.py | 40 ++++++++++++++----- vllm/envs.py | 5 +++ .../layers/fused_moe/fused_moe.py | 8 ++-- vllm/model_executor/layers/fused_moe/layer.py | 13 ++++++ 4 files changed, 54 insertions(+), 12 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 52893f4329ec..fb8baa90a2b7 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -5,6 +5,8 @@ """ import pytest import torch +from torch.nn import Parameter +from torch.nn import functional as F from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock @@ -12,6 +14,7 @@ from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, torch_moe, torch_moe_single) from vllm import _custom_ops as ops +from vllm import envs from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size) @@ -65,16 +68,7 @@ def test_fused_moe( else: e_map = None - triton_output = fused_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk, e_map) - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) iterative_output = iterative_moe(a, w1, w2, @@ -83,6 +77,23 @@ def test_fused_moe( global_num_experts=e, expert_map=e_map, renormalize=False) + + # Pad the weight if moe padding is enabled + if envs.VLLM_ROCM_MOE_PADDING: + w1 = F.pad(w1, (0, 128), "constant", 0) + torch.cuda.empty_cache() + w2 = F.pad(w2, (0, 128), "constant", 0) + torch.cuda.empty_cache() + + triton_output = fused_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(iterative_output, torch_output, atol=2e-2, @@ -233,6 +244,17 @@ def test_mixtral_moe(dtype: torch.dtype): # vLLM uses 1D query [num_tokens, hidden_dim] vllm_inputs = hf_inputs.flatten(0, 1) + # Pad the weight if moe padding is enabled + if envs.VLLM_ROCM_MOE_PADDING: + vllm_moe.experts.w13_weight = Parameter(F.pad( + vllm_moe.experts.w13_weight, (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() + vllm_moe.experts.w2_weight = Parameter(F.pad( + vllm_moe.experts.w2_weight, (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() + # Run forward passes for both MoE blocks hf_states, _ = hf_moe.forward(hf_inputs) vllm_states = vllm_moe.forward(vllm_inputs) diff --git a/vllm/envs.py b/vllm/envs.py index 2489affbcbd2..1da1fdd2d0cb 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -75,6 +75,7 @@ VLLM_DISABLED_KERNELS: list[str] = [] VLLM_USE_V1: bool = False VLLM_ROCM_FP8_PADDING: bool = True + VLLM_ROCM_MOE_PADDING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False @@ -521,6 +522,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), # Divisor for dynamic key scale factor calculation for FP8 KV Cache + + # Pad the weights for the moe kernel + "VLLM_ROCM_MOE_PADDING": + lambda: 
bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "0"))), "K_SCALE_CONSTANT": lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 5336b3c10023..c2b18c241f09 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -18,6 +18,7 @@ from vllm.utils import direct_register_custom_op logger = init_logger(__name__) +padding_size = 128 if envs.VLLM_ROCM_MOE_PADDING else 0 @triton.jit @@ -769,7 +770,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, expert_ids, num_tokens_post_padded, B.shape[1], - A.shape[1], + B.shape[2] - padding_size, EM, topk_ids.numel(), A.stride(0), @@ -1205,7 +1206,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert hidden_states.shape[1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[ + 1] == w1.shape[2] - padding_size, "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" @@ -1232,7 +1234,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, get_config_func = functools.partial( try_get_optimal_moe_config, w1.shape, - w2.shape, + (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size), top_k_num, config_dtype, block_shape=block_shape, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index d0209eb40e8c..86dc4e0059ba 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -5,8 +5,10 @@ from typing import Callable, List, Optional, Tuple import torch +import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter +from vllm import envs from vllm.config import get_current_vllm_config from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -98,6 +100,17 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) + if envs.VLLM_ROCM_MOE_PADDING: + layer.w13_weight = torch.nn.Parameter(F.pad( + layer.w13_weight.data, (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter(F.pad(layer.w2_weight.data, + (0, 128), "constant", + 0), + requires_grad=False) + torch.cuda.empty_cache() + if current_platform.is_cpu(): if current_platform.get_cpu_architecture() == CpuArchEnum.X86: import intel_extension_for_pytorch as ipex From ece4e83014f8a76e95e53af1ab37028eb8a4cdd6 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Fri, 7 Mar 2025 18:10:08 +0000 Subject: [PATCH 02/11] Adding MoE padding to unit tests Signed-off-by: Gregory Shtrasberg --- tests/kernels/test_moe.py | 29 ++++++++++++++----- .../layers/fused_moe/fused_moe.py | 2 ++ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index fb8baa90a2b7..c562ddc4c35e 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -3,6 +3,8 @@ Run `pytest tests/kernels/test_moe.py`. 
""" +import unittest.mock as mock + import pytest import torch from torch.nn import Parameter @@ -40,6 +42,7 @@ @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("padding", [True, False]) def test_fused_moe( m: int, n: int, @@ -48,7 +51,14 @@ def test_fused_moe( topk: int, ep_size: int, dtype: torch.dtype, + padding: bool, ): + if padding: + padding_size = 128 + envs.VLLM_ROCM_MOE_PADDING = True + else: + padding_size = 0 + envs.VLLM_ROCM_MOE_PADDING = False a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -85,14 +95,17 @@ def test_fused_moe( w2 = F.pad(w2, (0, 128), "constant", 0) torch.cuda.empty_cache() - triton_output = fused_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) + with mock.patch( + 'vllm.model_executor.layers.fused_moe.fused_moe.padding_size', + padding_size): + triton_output = fused_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(iterative_output, torch_output, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c2b18c241f09..e5ace5173c65 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -719,6 +719,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 assert B_zp is None or B_zp.ndim == 3 + assert padding_size == 0, "MoE padding is not supported " \ + "with GPTQ/AWQ quantization" fused_moe_kernel_gptq_awq[grid]( A, From b401d67c6170cf3e71dc13595b1520f9aa58a848 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Fri, 7 Mar 2025 18:19:47 +0000 Subject: [PATCH 03/11] Parameterized mixtral moe test for padding as well Signed-off-by: Gregory Shtrasberg --- tests/kernels/test_moe.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index c562ddc4c35e..e33537fb8f14 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -53,12 +53,8 @@ def test_fused_moe( dtype: torch.dtype, padding: bool, ): - if padding: - padding_size = 128 - envs.VLLM_ROCM_MOE_PADDING = True - else: - padding_size = 0 - envs.VLLM_ROCM_MOE_PADDING = False + envs.VLLM_ROCM_MOE_PADDING = padding + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -97,7 +93,7 @@ def test_fused_moe( with mock.patch( 'vllm.model_executor.layers.fused_moe.fused_moe.padding_size', - padding_size): + 128 if padding else 0): triton_output = fused_moe(a, w1, w2, @@ -226,8 +222,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("padding", [True, False]) @torch.inference_mode() -def test_mixtral_moe(dtype: torch.dtype): +def test_mixtral_moe(dtype: torch.dtype, padding: bool): """Make sure our Mixtral MoE implementation agrees with the one from huggingface.""" @@ -244,6 
+241,8 @@ def test_mixtral_moe(dtype: torch.dtype): dp_size=1, ).cuda() + envs.VLLM_ROCM_MOE_PADDING = padding + # Load the weights vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data for i in range(config.num_local_experts): @@ -270,7 +269,10 @@ def test_mixtral_moe(dtype: torch.dtype): # Run forward passes for both MoE blocks hf_states, _ = hf_moe.forward(hf_inputs) - vllm_states = vllm_moe.forward(vllm_inputs) + with mock.patch( + 'vllm.model_executor.layers.fused_moe.fused_moe.padding_size', + 128 if padding else 0): + vllm_states = vllm_moe.forward(vllm_inputs) mixtral_moe_tol = { torch.float32: 1e-3, From 83d473a7e55307e2bbdb0b6494f8fb03a4c610a1 Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 11 Mar 2025 22:24:28 +0000 Subject: [PATCH 04/11] implmenting moe padding with tensor slicing Signed-off-by: charlifu --- tests/kernels/test_moe.py | 40 +++++++------------ .../layers/fused_moe/fused_moe.py | 13 +++--- vllm/model_executor/layers/fused_moe/layer.py | 23 ++++++----- 3 files changed, 33 insertions(+), 43 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index e33537fb8f14..b56574723d2e 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -53,8 +53,6 @@ def test_fused_moe( dtype: torch.dtype, padding: bool, ): - envs.VLLM_ROCM_MOE_PADDING = padding - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -85,23 +83,20 @@ def test_fused_moe( renormalize=False) # Pad the weight if moe padding is enabled - if envs.VLLM_ROCM_MOE_PADDING: - w1 = F.pad(w1, (0, 128), "constant", 0) + if padding: + w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] torch.cuda.empty_cache() - w2 = F.pad(w2, (0, 128), "constant", 0) + w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] torch.cuda.empty_cache() - with mock.patch( - 'vllm.model_executor.layers.fused_moe.fused_moe.padding_size', - 128 if padding else 0): - triton_output = fused_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) + triton_output = fused_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(iterative_output, torch_output, @@ -241,8 +236,6 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool): dp_size=1, ).cuda() - envs.VLLM_ROCM_MOE_PADDING = padding - # Load the weights vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data for i in range(config.num_local_experts): @@ -257,22 +250,19 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool): vllm_inputs = hf_inputs.flatten(0, 1) # Pad the weight if moe padding is enabled - if envs.VLLM_ROCM_MOE_PADDING: + if padding: vllm_moe.experts.w13_weight = Parameter(F.pad( - vllm_moe.experts.w13_weight, (0, 128), "constant", 0), + vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128], requires_grad=False) torch.cuda.empty_cache() vllm_moe.experts.w2_weight = Parameter(F.pad( - vllm_moe.experts.w2_weight, (0, 128), "constant", 0), + vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], requires_grad=False) torch.cuda.empty_cache() # Run forward passes for both MoE blocks hf_states, _ = hf_moe.forward(hf_inputs) - with mock.patch( - 'vllm.model_executor.layers.fused_moe.fused_moe.padding_size', - 128 if padding else 0): - vllm_states = vllm_moe.forward(vllm_inputs) + vllm_states 
= vllm_moe.forward(vllm_inputs) mixtral_moe_tol = { torch.float32: 1e-3, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e5ace5173c65..a0df1999fe1f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -18,7 +18,6 @@ from vllm.utils import direct_register_custom_op logger = init_logger(__name__) -padding_size = 128 if envs.VLLM_ROCM_MOE_PADDING else 0 @triton.jit @@ -719,8 +718,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 assert B_zp is None or B_zp.ndim == 3 - assert padding_size == 0, "MoE padding is not supported " \ - "with GPTQ/AWQ quantization" fused_moe_kernel_gptq_awq[grid]( A, @@ -772,7 +769,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, expert_ids, num_tokens_post_padded, B.shape[1], - B.shape[2] - padding_size, + B.shape[2], EM, topk_ids.numel(), A.stride(0), @@ -1209,12 +1206,12 @@ def fused_experts_impl(hidden_states: torch.Tensor, 2], "Hidden size mismatch" else: assert hidden_states.shape[ - 1] == w1.shape[2] - padding_size, "Hidden size mismatch" + 1] == w1.shape[2], "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must 1" + assert w2.stride(-1) == 1, "Stride of last dimension must 1" assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16 ] @@ -1236,7 +1233,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, get_config_func = functools.partial( try_get_optimal_moe_config, w1.shape, - (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size), + w2.shape, top_k_num, config_dtype, block_shape=block_shape, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 86dc4e0059ba..af57d20df102 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -97,19 +97,22 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) + def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor: + # Pad the weight tensor. 
This is an optimization on ROCm platform, which + # can benefit from tensors located far enough from one another in memory + if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0): + num_pad = 256 // weight.element_size() + weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] + torch.cuda.empty_cache() + return weight + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) - if envs.VLLM_ROCM_MOE_PADDING: - layer.w13_weight = torch.nn.Parameter(F.pad( - layer.w13_weight.data, (0, 128), "constant", 0), - requires_grad=False) - torch.cuda.empty_cache() - layer.w2_weight = torch.nn.Parameter(F.pad(layer.w2_weight.data, - (0, 128), "constant", - 0), - requires_grad=False) - torch.cuda.empty_cache() + layer.w13_weight = self.add_padding_to_weight(layer.w13_weight.data) + layer.w2_weight = self.add_padding_to_weight(layer.w2_weight.data) if current_platform.is_cpu(): if current_platform.get_cpu_architecture() == CpuArchEnum.X86: From 718cc26e49c296b53b95cc94ba4e913aec774117 Mon Sep 17 00:00:00 2001 From: charlifu Date: Wed, 12 Mar 2025 14:56:34 +0000 Subject: [PATCH 05/11] fix grammar issue of error message Signed-off-by: charlifu --- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a0df1999fe1f..7d083a4cdb33 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1210,8 +1210,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.stride(-1) == 1, "Stride of last dimension must 1" - assert w2.stride(-1) == 1, "Stride of last dimension must 1" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16 ] From a0f3706ed8182d7beaef3fa54ce15b2473de78d1 Mon Sep 17 00:00:00 2001 From: charlifu Date: Wed, 12 Mar 2025 15:23:46 +0000 Subject: [PATCH 06/11] assign Parameter to weight Signed-off-by: charlifu --- vllm/model_executor/layers/fused_moe/layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index af57d20df102..5e91cde9b84b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -111,8 +111,8 @@ def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor: def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) - layer.w13_weight = self.add_padding_to_weight(layer.w13_weight.data) - layer.w2_weight = self.add_padding_to_weight(layer.w2_weight.data) + layer.w13_weight = torch.nn.Parameter(self.add_padding_to_weight(layer.w13_weight.data), requires_grad=False) + layer.w2_weight = torch.nn.Parameter(self.add_padding_to_weight(layer.w2_weight.data), requires_grad=False) if current_platform.is_cpu(): if current_platform.get_cpu_architecture() == CpuArchEnum.X86: From fa2b8d1b802ad264b4377ce3bb127c7edcf4c2d1 Mon Sep 17 00:00:00 2001 From: charlifu 
Date: Wed, 12 Mar 2025 15:48:07 +0000 Subject: [PATCH 07/11] linting Signed-off-by: charlifu --- vllm/model_executor/layers/fused_moe/layer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5e91cde9b84b..f19c7a6c3807 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -111,8 +111,12 @@ def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor: def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) - layer.w13_weight = torch.nn.Parameter(self.add_padding_to_weight(layer.w13_weight.data), requires_grad=False) - layer.w2_weight = torch.nn.Parameter(self.add_padding_to_weight(layer.w2_weight.data), requires_grad=False) + layer.w13_weight = torch.nn.Parameter(self.add_padding_to_weight( + layer.w13_weight.data), + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(self.add_padding_to_weight( + layer.w2_weight.data), + requires_grad=False) if current_platform.is_cpu(): if current_platform.get_cpu_architecture() == CpuArchEnum.X86: From 6e36f513bf6e30a734237cda6e19ffb506166b74 Mon Sep 17 00:00:00 2001 From: charlifu Date: Wed, 19 Mar 2025 18:43:03 +0000 Subject: [PATCH 08/11] fix pre-commit Signed-off-by: charlifu --- tests/kernels/test_moe.py | 2 -- vllm/model_executor/layers/fused_moe/fused_moe.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index b56574723d2e..653d2734afe8 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -3,7 +3,6 @@ Run `pytest tests/kernels/test_moe.py`. """ -import unittest.mock as mock import pytest import torch @@ -16,7 +15,6 @@ from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, torch_moe, torch_moe_single) from vllm import _custom_ops as ops -from vllm import envs from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index cdfac5255b8b..4de020ff81c0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1318,8 +1318,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert hidden_states.shape[1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: - assert hidden_states.shape[ - 1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" From fe3025dc4fdc227c66d3850a8887536b9e7628e9 Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 20 Mar 2025 21:35:53 +0000 Subject: [PATCH 09/11] change padding function name to maybe pad weight Signed-off-by: charlifu --- vllm/model_executor/layers/fused_moe/layer.py | 6 +++--- vllm/model_executor/layers/quantization/fp8.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 0adf9af4f439..8aa5375c646a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -97,7 +97,7 @@ def create_weights(self, 
layer: torch.nn.Module, num_experts: int, layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor: + def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: # Pad the weight tensor. This is an optimization on ROCm platform, which # can benefit from tensors located far enough from one another in memory if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm() @@ -111,10 +111,10 @@ def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor: def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) - layer.w13_weight = torch.nn.Parameter(self.add_padding_to_weight( + layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight( layer.w13_weight.data), requires_grad=False) - layer.w2_weight = torch.nn.Parameter(self.add_padding_to_weight( + layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight( layer.w2_weight.data), requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 2d5d8e6adc9c..1cefafe316c5 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -255,7 +255,7 @@ def create_weights( else: layer.register_parameter("input_scale", None) - def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor: + def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: # Pad the weight tensor. This is an optimization on ROCm platform, which # can benefit from tensors located far enough from one another in memory if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm() @@ -279,7 +279,7 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight.data weight_scale_inv = layer.weight_scale_inv.data - weight = self.add_padding_to_weight(weight) + weight = self._maybe_pad_weight(weight) # Torch.compile cannot use Parameter subclasses. 
layer.weight = Parameter(weight, requires_grad=False) From e30941771b8756a20d9a950998dc4c8627b93e20 Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 20 Mar 2025 21:39:56 +0000 Subject: [PATCH 10/11] enable moe padding by default Signed-off-by: charlifu --- vllm/envs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index a90b67ffb6db..99586ee53f28 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -526,11 +526,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # Pad the fp8 weights to 256 bytes for ROCm "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), - # Divisor for dynamic key scale factor calculation for FP8 KV Cache # Pad the weights for the moe kernel "VLLM_ROCM_MOE_PADDING": - lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "0"))), + lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))), + + # Divisor for dynamic key scale factor calculation for FP8 KV Cache "K_SCALE_CONSTANT": lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), From 0b5b5901e3d76640dfbaf10c46fe445e17415f41 Mon Sep 17 00:00:00 2001 From: charlifu Date: Fri, 21 Mar 2025 15:23:59 +0000 Subject: [PATCH 11/11] fix pre-commit Signed-off-by: charlifu --- vllm/model_executor/layers/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1cefafe316c5..d92b0931a6ee 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -343,7 +343,7 @@ def process_weights_after_loading(self, layer: Module) -> None: logical_widths=layer.logical_widths, ) - weight = self.add_padding_to_weight(weight) + weight = self._maybe_pad_weight(weight) # Update layer with new values. layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False)
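
A minimal, self-contained sketch of the pad-then-slice trick these patches apply (illustrative only, not taken from the diffs; the function name and tensor shapes below are hypothetical, the in-tree helper is _maybe_pad_weight): the weight's last dimension is padded out to a 256-byte boundary and then sliced back off, so the logical shape is unchanged while stride(-2) grows and consecutive rows land farther apart in memory, which the ROCm kernels benefit from.

# Sketch of the padding optimization, assuming only PyTorch is available.
import torch
import torch.nn.functional as F


def pad_moe_weight(weight: torch.Tensor, alignment_bytes: int = 256) -> torch.Tensor:
    # Pad the last dim by `alignment_bytes` worth of elements, then slice the
    # padding off again: same logical shape, larger stride(-2).
    num_pad = alignment_bytes // weight.element_size()
    padded = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # as in the patches, drop cached blocks left over from the copy
    return padded


w2 = torch.randn(8, 512, 1024, dtype=torch.bfloat16)  # shaped like the test's w2: (e, k, n)
w2_padded = pad_moe_weight(w2)
assert w2_padded.shape == w2.shape                  # logical shape unchanged
assert w2_padded.stride(-1) == 1                    # last dim stays unit-stride
assert w2_padded.stride(-2) == w2.shape[-1] + 128   # bf16: 256 bytes == 128 elements

Patches 04 onward prefer this slicing approach over keeping the padded shape: the kernel reads the true hidden size directly from w1.shape[2], the module-level padding_size global and the GPTQ/AWQ restriction it forced are dropped, and only w1.stride(-1) == 1 / w2.stride(-1) == 1 needs to be asserted instead of full contiguity.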