diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 2e1fa50e2ab3..2ac98976539e 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -# usage: VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py +# usage: +# VLLM_TEST_ENABLE_EP=1 VLLM_USE_V1=1 \ +# python examples/offline_inference/data_parallel.py # we need to have a launcher to create multiple data parallel # ranks. And each rank will create a vLLM instance to process its own prompts. import os @@ -7,6 +9,9 @@ from vllm import LLM, SamplingParams from vllm.utils import get_open_port +GPUs_per_dp_rank = 2 +DP_size = 2 + def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank): os.environ["VLLM_DP_RANK"] = str(dp_rank) @@ -48,8 +53,8 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank): max_tokens=16 * (dp_rank + 1)) # Create an LLM. - llm = LLM(model="facebook/opt-125m", - tensor_parallel_size=2, + llm = LLM(model="ibm-research/PowerMoE-3b", + tensor_parallel_size=GPUs_per_dp_rank, enforce_eager=True) outputs = llm.generate(prompts, sampling_params) # Print the outputs. @@ -62,14 +67,12 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank): if __name__ == "__main__": from multiprocessing import Process - dp_size = 2 - GPUs_per_dp_rank = 2 dp_master_ip = "127.0.0.1" dp_master_port = get_open_port() procs = [] - for i in range(dp_size): + for i in range(DP_size): proc = Process(target=main, - args=(dp_size, i, dp_master_ip, dp_master_port, + args=(DP_size, i, dp_master_ip, dp_master_port, GPUs_per_dp_rank)) proc.start() procs.append(proc) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 2f5c69046f48..52893f4329ec 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -217,6 +217,7 @@ def test_mixtral_moe(dtype: torch.dtype): intermediate_size=config.intermediate_size, params_dtype=dtype, tp_size=1, + dp_size=1, ).cuda() # Load the weights diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 58a3b4ee43ce..7810089a05c7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -324,7 +324,7 @@ def unified_attention( ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata - self = forward_context.attn_layers[layer_name] + self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] return self.impl.forward(self, query, key, value, kv_cache, attn_metadata) @@ -356,7 +356,7 @@ def unified_attention_with_output( ) -> None: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata - self = forward_context.attn_layers[layer_name] + self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, query, diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index b972f03c9685..afb63cf8319f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -396,8 +396,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: cache_dir = self.compilation_config.cache_dir os.makedirs(cache_dir, exist_ok=True) - local_cache_dir = os.path.join( - cache_dir, f"rank_{vllm_config.parallel_config.rank}") + rank = vllm_config.parallel_config.rank + dp_rank = 
vllm_config.parallel_config.data_parallel_rank + local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}") self.compilation_config.local_cache_dir = local_cache_dir disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE diff --git a/vllm/forward_context.py b/vllm/forward_context.py index c3d20cff426c..540a35e1ecb9 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -25,16 +25,22 @@ batchsize_forward_time: defaultdict = defaultdict(list) +@dataclass +class DPMetadata: + num_tokens_across_dp: list[int] + cu_tokens_across_dp_cpu: torch.Tensor + + @dataclass class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context - attn_layers: dict[str, Any] + no_compile_layers: dict[str, Any] # TODO: extend to support per-layer dynamic forward context attn_metadata: "AttentionMetadata" # set dynamically for each forward pass # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass - num_tokens_across_dp: Optional[ - list[int]] = None # set dynamically for each forward pass + # set dynamically for each forward pass + dp_metadata: Optional[DPMetadata] = None _forward_context: Optional[ForwardContext] = None @@ -61,7 +67,7 @@ def set_forward_context(attn_metadata: Any, need_to_track_batchsize = track_batchsize and attn_metadata is not None if need_to_track_batchsize: forward_start_time = time.perf_counter() - num_tokens_across_dp = None + dp_metadata: Optional[DPMetadata] = None if vllm_config.parallel_config.data_parallel_size > 1: dp_size = vllm_config.parallel_config.data_parallel_size dp_rank = vllm_config.parallel_config.data_parallel_rank @@ -82,15 +88,17 @@ def set_forward_context(attn_metadata: Any, dtype=torch.int32) from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) - num_tokens_across_dp = num_tokens_tensor.tolist() + cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) + dp_metadata = DPMetadata(num_tokens_across_dp, cu_tokens_across_dp_cpu) global _forward_context prev_context = _forward_context _forward_context = ForwardContext( - attn_layers=vllm_config.compilation_config.static_forward_context, + no_compile_layers=vllm_config.compilation_config. 
+ static_forward_context, virtual_engine=virtual_engine, attn_metadata=attn_metadata, - num_tokens_across_dp=num_tokens_across_dp) + dp_metadata=dp_metadata) try: yield finally: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 052d4d54601f..33d2896f3fd2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -8,9 +8,11 @@ from torch.nn.parameter import UninitializedParameter import vllm.envs as envs -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.config import get_current_vllm_config +from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization.base_config import ( @@ -18,6 +20,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum +from vllm.utils import direct_register_custom_op if current_platform.is_cuda_alike(): from .fused_moe import fused_experts @@ -246,6 +249,51 @@ def forward_tpu( forward_native = forward_cuda +def determine_expert_map( + ep_size: int, ep_rank: int, + global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]: + """ + Calculates how many experts should be assigned to each rank for EP and + creates a mapping from global to local expert index. Experts are + distributed evenly across ranks. Any remaining are assigned to the + last rank. + + Args: + ep_size (int): The size of the expert parallel group + global_num_experts (int): The total number of experts in the model. + + Returns: + Tuple[int, Optional[torch.Tensor]]: A tuple containing: + - local_num_experts (int): The number of experts assigned + to the current rank. + - expert_map (Optional[torch.Tensor]): A tensor of shape + (global_num_experts,) mapping from global to local index. + Contains -1 for experts not assigned to the current rank. + Returns None if ep_size is 1. + """ + assert ep_size > 0 + if ep_size == 1: + return (global_num_experts, None) + + local_num_experts = global_num_experts // ep_size + + # Create a tensor of size num_experts filled with -1 + expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32) + # Create a expert map for the local experts + if ep_rank < (ep_size - 1): + # Each non-last rank gets local_num_experts experts. + expert_map[ep_rank * local_num_experts: + (ep_rank + 1) * local_num_experts] = \ + torch.arange(0, local_num_experts, dtype=torch.int32) + else: + # All remaining experts are assigned to the last rank. + local_num_experts = (global_num_experts - ep_rank * local_num_experts) + + expert_map[-local_num_experts:] = \ + torch.arange(0, local_num_experts, dtype=torch.int32) + return (local_num_experts, expert_map) + + class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. 
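# A minimal standalone sketch of the expert-partitioning arithmetic implemented
# by determine_expert_map above: experts are split evenly across EP ranks and
# the last rank absorbs the remainder. This restates the function's logic with
# plain torch for readability; names and the worked example are illustrative.
import torch
from typing import Optional, Tuple

def sketch_expert_map(ep_size: int, ep_rank: int,
                      global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]:
    if ep_size == 1:
        return global_num_experts, None
    local = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    if ep_rank < ep_size - 1:
        expert_map[ep_rank * local:(ep_rank + 1) * local] = \
            torch.arange(local, dtype=torch.int32)
    else:
        # The last rank takes whatever is left over.
        local = global_num_experts - ep_rank * local
        expert_map[-local:] = torch.arange(local, dtype=torch.int32)
    return local, expert_map

# Example: 8 experts over 3 EP ranks -> ranks 0 and 1 own 2 experts each,
# rank 2 owns the remaining 4 (global ids 4..7 map to local ids 0..3).
print(sketch_expert_map(ep_size=3, ep_rank=2, global_num_experts=8))
# (4, tensor([-1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32))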
@@ -282,6 +330,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, ep_size: Optional[int] = None, + dp_size: Optional[int] = None, prefix: str = "", custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", @@ -293,16 +342,48 @@ def __init__( if params_dtype is None: params_dtype = torch.get_default_dtype() + # For smuggling this layer into the fused moe custom op + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError("Duplicate layer name: {}".format(prefix)) + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix + self.use_direct_call = not envs.VLLM_TEST_ENABLE_EP + + # Note: here we guard against accessing the TP and DP groups when + # uninitialized (this happens when testing) self.tp_size = (tp_size if tp_size is not None else get_tensor_model_parallel_world_size()) + tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank() + self.dp_size = (dp_size + if dp_size is not None else get_dp_group().world_size) + self.dp_rank = (0 + if self.dp_size == 1 else get_dp_group().rank_in_group) + self.global_num_experts = num_experts + if envs.VLLM_TEST_ENABLE_EP: - self.ep_size = self.tp_size + # Set TP size to 1 to adjust for EP and adjust EP size and rank + # for DP attention. + self.ep_rank = tp_rank + self.tp_size * self.dp_rank + self.tp_rank = 0 + self.ep_size = self.tp_size * self.dp_size self.tp_size = 1 + + self.local_num_experts, self.expert_map = determine_expert_map( + ep_size=self.ep_size, + ep_rank=self.ep_rank, + global_num_experts=self.global_num_experts) else: + # Adjust TP size for DP attention + self.tp_rank = tp_rank + self.tp_size * self.dp_rank + self.ep_rank = 0 + self.tp_size = self.tp_size * self.dp_size self.ep_size = 1 + self.local_num_experts = self.global_num_experts + self.expert_map = None self.top_k = top_k self.global_num_experts = num_experts - self.local_num_experts = self.global_num_experts // self.ep_size + assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results @@ -316,26 +397,6 @@ def __init__( self.scoring_func = scoring_func self.e_score_correction_bias = e_score_correction_bias self.activation = activation - self.expert_map = None - - if self.ep_size > 1: - # Create a tensor of size num_experts filled with -1 - self.expert_map = torch.full((self.global_num_experts, ), - -1, - dtype=torch.int32) - # Create a expert map for the local experts - ep_rank = get_tensor_model_parallel_rank() - if ep_rank < (self.ep_size - 1): - # Each non-last rank gets local_num_experts experts. - self.expert_map[ep_rank * self.local_num_experts: - (ep_rank + 1) * self.local_num_experts] = \ - torch.arange(0, self.local_num_experts, dtype=torch.int32) - else: - # All remaining experts are assigned to the last rank. 
- self.local_num_experts = (self.global_num_experts - - ep_rank * self.local_num_experts) - self.expert_map[-self.local_num_experts:] = \ - torch.arange(0, self.local_num_experts, dtype=torch.int32) if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " @@ -493,9 +554,6 @@ def weight_loader(self, param: torch.nn.Parameter, if expert_id == -1: return - # TP rank is set to 0 if EP is enabled - tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank() - # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality @@ -539,8 +597,7 @@ def weight_loader(self, param: torch.nn.Parameter, final_shape = list(loaded_weight.shape) if shard_id in ["w1", "w3"]: final_shape[1] *= 2 - final_shape[shard_dim] = final_shape[ - shard_dim] // get_tensor_model_parallel_world_size() + final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size param.materialize(final_shape, dtype=loaded_weight.dtype) expert_data = param.data if full_load else param.data[expert_id] @@ -567,7 +624,7 @@ def weight_loader(self, param: torch.nn.Parameter, shard_id=shard_id, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank) + tp_rank=self.tp_rank) return # Case weight scales and zero_points @@ -584,7 +641,7 @@ def weight_loader(self, param: torch.nn.Parameter, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank) + tp_rank=self.tp_rank) elif quant_method in [ FusedMoeWeightScaleSupported.GROUP.value, FusedMoeWeightScaleSupported.BLOCK.value, @@ -594,7 +651,7 @@ def weight_loader(self, param: torch.nn.Parameter, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank, + tp_rank=self.tp_rank, load_full_w2=getattr(param, "load_full_w2", False)) elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: self._load_per_tensor_weight_scale(shard_id=shard_id, @@ -621,7 +678,7 @@ def weight_loader(self, param: torch.nn.Parameter, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=tp_rank) + tp_rank=self.tp_rank) return @staticmethod @@ -665,10 +722,45 @@ def select_experts(hidden_states: torch.Tensor, return topk_weights, topk_ids + def naive_multicast(self, x: torch.Tensor, + cu_tokens_across_dp_cpu: torch.Tensor): + assert (len(x.shape) == 2) + buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), + device=x.device, + dtype=x.dtype) + + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + buffer[start:end, :].copy_(x) + for idx in range(get_dp_group().world_size): + start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] + end = cu_tokens_across_dp_cpu[idx] + get_dp_group().broadcast(buffer[start:end, :], idx) + + return buffer + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + if self.use_direct_call: + return self.forward_impl(hidden_states, router_logits) + else: + return torch.ops.vllm.moe_forward(hidden_states, router_logits, + self.layer_name) + + def forward_impl(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): assert self.quant_method is not None + if self.dp_size > 1: + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + + hidden_states = self.naive_multicast(hidden_states, + 
cu_tokens_across_dp_cpu) + router_logits = self.naive_multicast(router_logits, + cu_tokens_across_dp_cpu) + # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -687,6 +779,14 @@ def forward(self, hidden_states: torch.Tensor, activation=self.activation, ) + if self.dp_size > 1: + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + + all_hidden_states = get_dp_group().all_reduce(final_hidden_states) + final_hidden_states = all_hidden_states[start:end, :] + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( @@ -757,3 +857,26 @@ def extra_repr(self) -> str: s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'" # noqa: E501 return s + + +def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, + layer_name: str) -> torch.Tensor: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + assert self.quant_method is not None + + return self.forward_impl(hidden_states, router_logits) + + +def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, + layer_name: str) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="moe_forward", + op_func=moe_forward, + mutates_args=[], + fake_impl=moe_forward_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 061a9a5bd2bc..53872812b323 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -46,7 +46,7 @@ class AriaImagePixelInputs(TypedDict): pixel_values: torch.Tensor pixel_mask: Optional[torch.Tensor] """ - Shape: + Shape: pixel_values: `(batch_size * num_images, num_channels, height, width)` pixel_mask: `(batch_size * num_images, height, width)` """ @@ -135,11 +135,11 @@ class AriaProjector(nn.Module): query numbers, e.g., {1225: 128, 4900: 256}. This allows for different query sizes based on image resolution. - embed_dim (int): Embedding dimension. - num_heads (int): Number of attention heads. - kv_dim (int): Dimension of key and value. - ff_dim (int): Hidden dimension of the feed-forward network. - output_dim (int): Output dimension. + embed_dim (int): Embedding dimension. + num_heads (int): Number of attention heads. + kv_dim (int): Dimension of key and value. + ff_dim (int): Hidden dimension of the feed-forward network. + output_dim (int): Output dimension. norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm. 
Outputs: @@ -239,6 +239,7 @@ def __init__( self, config: AriaTextConfig, quant_config: Optional[QuantizationConfig], + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -254,6 +255,7 @@ def __init__( intermediate_size=config.intermediate_size, quant_config=quant_config, reduce_results=True, + prefix=f"{prefix}.experts", ) self.shared_experts = LlamaMLP( config.hidden_size, @@ -301,7 +303,9 @@ def __init__( prefix: str = "", ) -> None: super().__init__(config, cache_config, quant_config, prefix) - self.mlp = AriaTextMoELayer(config, quant_config=quant_config) + self.mlp = AriaTextMoELayer(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") class AriaTextModel(LlamaModel, SupportsQuant): diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 7830dd4ce2ec..b66529860bc2 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -65,6 +65,7 @@ def __init__( config: DbrxConfig, quant_config: Optional[QuantizationConfig] = None, params_dtype: Optional[torch.dtype] = None, + prefix: str = "", ): super().__init__( num_experts=config.ffn_config.moe_num_experts, @@ -76,6 +77,7 @@ def __init__( renormalize=True, quant_config=quant_config, tp_size=get_tensor_model_parallel_world_size(), + prefix=prefix, ) self.config = config self.tp_size = get_tensor_model_parallel_world_size() @@ -139,6 +141,7 @@ def __init__( config: DbrxConfig, quant_config: Optional[QuantizationConfig] = None, params_dtype: Optional[torch.dtype] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model @@ -150,7 +153,8 @@ def __init__( self.experts = DbrxExperts(config=config, quant_config=quant_config, - params_dtype=self.params_dtype) + params_dtype=self.params_dtype, + prefix=f"{prefix}.experts") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape @@ -291,7 +295,7 @@ def __init__( cache_config, quant_config, prefix=f"{prefix}.norm_attn_norm") - self.ffn = DbrxMoE(config, quant_config) + self.ffn = DbrxMoE(config, quant_config, prefix=f"{prefix}.ffn") def forward( self, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 58eccd6a6b87..92d40ae7d565 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -47,7 +47,8 @@ def __init__(self, top_k: Optional[int] = None, params_dtype: Optional[torch.dtype] = None, tp_size: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.num_total_experts = num_experts or config.num_experts self.top_k = top_k or config.num_experts_per_tok @@ -70,7 +71,8 @@ def __init__(self, reduce_results=True, renormalize=False, use_grouped_topk=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.experts") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape @@ -92,13 +94,15 @@ def __init__(self, config: JambaConfig, params_dtype: Optional[torch.dtype] = None, tp_size: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__(config, num_experts=1, top_k=1, params_dtype=params_dtype, tp_size=tp_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=prefix) class JambaMambaDecoderLayer(nn.Module): @@ -109,6 +113,7 @@ def __init__(self, cache_config: 
Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, is_lora_enabled: Optional[bool] = False, + prefix: str = "", **kwargs) -> None: super().__init__() self.config = config @@ -129,7 +134,9 @@ def __init__(self, num_experts = config.layers_num_experts[layer_idx] ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP - self.feed_forward = ffn_layer_class(config, quant_config=quant_config) + self.feed_forward = ffn_layer_class(config, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward") self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_ff_layernorm = RMSNorm(config.hidden_size, @@ -211,7 +218,9 @@ def __init__(self, num_experts = config.layers_num_experts[layer_idx] ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP - self.feed_forward = ffn_layer_class(config, quant_config=quant_config) + self.feed_forward = ffn_layer_class(config, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward") self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_ff_layernorm = RMSNorm(config.hidden_size, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index c8dea557e571..f91b20707031 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -71,6 +71,7 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, + dp_size: Optional[int] = None, prefix: str = ""): super().__init__() self.hidden_size = hidden_size @@ -93,6 +94,7 @@ def __init__(self, renormalize=True, quant_config=quant_config, tp_size=tp_size, + dp_size=dp_size, prefix=f"{prefix}.experts") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index e27ff5deace2..392e95575dc4 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -80,7 +80,8 @@ def __init__(self, reduce_results=True, renormalize=False, quant_config=quant_config, - tp_size=tp_size) + tp_size=tp_size, + prefix=f"{prefix}.experts") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -212,6 +213,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, + prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index c35c7e9fcce7..99bd58a83257 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -249,6 +249,7 @@ def __init__( params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, + prefix: str = "", ): super().__init__() self.hidden_size = hidden_size @@ -272,7 +273,8 @@ def __init__( renormalize=False, quant_config=quant_config, tp_size=tp_size, - custom_routing_function=phimoe_routing_function) + custom_routing_function=phimoe_routing_function, + prefix=f"{prefix}.experts") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
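# A minimal single-process sketch of the DP token bookkeeping introduced above
# (forward_context.DPMetadata plus FusedMoE.naive_multicast): per-rank token
# counts are turned into a cumulative sum, each rank's activations occupy rows
# [cu[rank-1]:cu[rank]) of one shared buffer, and after the MoE output is
# all-reduced across DP each rank slices its own rows back out. Tensors here
# stand in for the real collectives; shapes and names are assumptions.
import torch

hidden = 16
per_rank_tokens = [3, 5, 2]                      # tokens seen by DP ranks 0..2
cu = torch.cumsum(torch.tensor(per_rank_tokens, dtype=torch.int32), dim=0)
local = [torch.randn(n, hidden) for n in per_rank_tokens]

# naive_multicast: each rank owns rows [start, end) of the shared buffer;
# in the real layer the other rows are filled by per-rank broadcasts.
buffer = torch.zeros(int(cu[-1]), hidden)
for rank, x in enumerate(local):
    start = 0 if rank == 0 else int(cu[rank - 1])
    end = int(cu[rank])
    buffer[start:end] = x

# In the real layer the MoE output buffer is all-reduced across DP and each
# rank keeps only its own slice; here we just check that the slice arithmetic
# recovers rank 1's rows.
rank = 1
start, end = int(cu[rank - 1]), int(cu[rank])
assert torch.equal(buffer[start:end], local[rank])   # shape (5, hidden)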
@@ -396,6 +398,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe", ) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 41536b34b2f2..366e020f17d5 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -100,6 +100,7 @@ def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -115,7 +116,8 @@ def __init__( intermediate_size=config.moe_intermediate_size, reduce_results=False, renormalize=config.norm_topk_prob, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.experts") self.gate = ReplicatedLinear(config.hidden_size, config.num_experts, @@ -277,7 +279,8 @@ def __init__( config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0): self.mlp = Qwen2MoeSparseMoeBlock(config=config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.mlp") else: self.mlp = Qwen2MoeMLP( hidden_size=config.hidden_size, diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index bffa113cab89..519905539167 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -111,6 +111,7 @@ def log_warnings(cls): def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config + compilation_config = vllm_config.compilation_config if parallel_config.worker_cls == "auto": if scheduler_config.is_multi_step: @@ -150,6 +151,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "FlashMLA: Forcing kv cache block size to 64 since this" " is currently the only block size supported by the kernel.") + if (parallel_config.data_parallel_size > 1 + and compilation_config.use_cudagraph): + logger.info( + "Data Parallel: Forcing enforce eager to be True since DP is " + "currently not supported with CUDA Graphs.") + vllm_config.model_config.enforce_eager = True + compilation_config.use_cudagraph = False + @classmethod def get_current_memory_usage(cls, device: Optional[torch.types.Device] = None diff --git a/vllm/utils.py b/vllm/utils.py index 26c9e1a90837..114eb9b36dbc 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2194,8 +2194,8 @@ def bind_kv_cache( from vllm.model_executor.models.utils import extract_layer_index layer_need_kv_cache = [ layer_name for layer_name in ctx - if ctx[layer_name].attn_type in (AttentionType.DECODER, - AttentionType.ENCODER_DECODER) + if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type + in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) ] layer_index_sorted = sorted( set( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index b9bf8fac40f6..473a1a73cef8 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -149,7 +149,6 @@ def step(self) -> EngineCoreOutputs: if not self.scheduler.has_unfinished_requests(): return EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats()) - scheduler_output = self.scheduler.schedule() output = self.model_executor.execute_model(scheduler_output) engine_core_outputs = self.scheduler.update_from_output( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 
4a1fb0514c3f..a1a50e89676b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -17,6 +17,7 @@ from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs @@ -1357,7 +1358,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. Args: - kv_cache_config: Configuration for the KV cache, including the KV + kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ if len(kv_cache_config.groups) > 1: @@ -1389,10 +1390,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: def get_kv_cache_spec(self) -> KVCacheSpec: """ - Generates the KVCacheSpec by parsing the kv cache format from each + Generates the KVCacheSpec by parsing the kv cache format from each Attention module in the static forward context. Returns: - KVCacheSpec: A dictionary mapping layer names to their KV cache + KVCacheSpec: A dictionary mapping layer names to their KV cache format. Layers that do not need KV cache are not included. """ @@ -1400,6 +1401,9 @@ def get_kv_cache_spec(self) -> KVCacheSpec: block_size = self.vllm_config.cache_config.block_size kv_cache_spec: KVCacheSpec = {} for layer_name, attn_module in forward_ctx.items(): + if isinstance(attn_module, FusedMoE): + continue + # TODO: Support other attention modules, e.g., sliding window, # cross-attention, MLA. assert isinstance(attn_module, Attention)
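# A simplified sketch of the name-based indirection this patch relies on: each
# FusedMoE (like each Attention layer) registers itself in the compilation
# config's static_forward_context under its prefix, and the module-level
# moe_forward custom op resolves the layer through
# get_forward_context().no_compile_layers[layer_name] before calling
# forward_impl. The registry and layer below are stand-ins, not vLLM classes.
import torch

no_compile_layers: dict[str, "ToyMoE"] = {}

class ToyMoE(torch.nn.Module):
    def __init__(self, prefix: str):
        super().__init__()
        if prefix in no_compile_layers:
            raise ValueError(f"Duplicate layer name: {prefix}")
        no_compile_layers[prefix] = self      # mirrors static_forward_context
        self.layer_name = prefix
        self.proj = torch.nn.Linear(8, 8)

    def forward_impl(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)

def toy_moe_forward(hidden_states: torch.Tensor, layer_name: str) -> torch.Tensor:
    # The op sees only tensors plus a layer name; the layer itself is
    # recovered from the registry, keeping it out of the compiled graph.
    self = no_compile_layers[layer_name]
    return self.forward_impl(hidden_states)

layer = ToyMoE("model.layers.0.mlp.experts")
out = toy_moe_forward(torch.randn(4, 8), "model.layers.0.mlp.experts")
print(out.shape)  # torch.Size([4, 8])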