diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 563d2fe4ef65..13c6178941cf 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -45,6 +45,9 @@ struct SSMParamsBase { index_t out_d_stride; index_t out_z_batch_stride; index_t out_z_d_stride; + index_t ssm_states_batch_stride; + index_t ssm_states_dim_stride; + index_t ssm_states_dstate_stride; // Common data pointers. void *__restrict__ A_ptr; diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 5766fbab4e87..c4ddbc142791 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -132,8 +132,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { input_t *Bvar = reinterpret_cast(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride; weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride; - input_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate; - + input_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + + cache_index * params.ssm_states_batch_stride + + dim_id * kNRows * params.ssm_states_dim_stride; + float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { #pragma unroll @@ -248,7 +250,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } // Initialize running total - scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]): 0.0); + scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0); SSMScanPrefixCallbackOp prefix_op(running_prefix); typename Ktraits::BlockScanT(smem_scan).InclusiveScan( @@ -259,7 +261,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if (threadIdx.x == 0) { smem_running_prefix[state_idx] = prefix_op.running_prefix; if (chunk == n_chunks - 1) { - ssm_states[state_idx] = input_t(prefix_op.running_prefix.y); + ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y); } } #pragma unroll @@ -481,6 +483,10 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.out_batch_stride = out.stride(1); params.out_d_stride = out.stride(0); + params.ssm_states_batch_stride = ssm_states.stride(0); + params.ssm_states_dim_stride = ssm_states.stride(1); + params.ssm_states_dstate_stride = ssm_states.stride(2); + } else{ if (!is_variable_B) { @@ -509,6 +515,10 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, } params.out_batch_stride = out.stride(0); params.out_d_stride = out.stride(1); + + params.ssm_states_batch_stride = ssm_states.stride(0); + params.ssm_states_dim_stride = ssm_states.stride(1); + params.ssm_states_dstate_stride = ssm_states.stride(2); } } diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index c058c20f1ed7..49ca2a3b8c74 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -367,9 +367,9 @@ th { | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. 
| ✅︎ | ✅︎ | ✅︎ |
 | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
-| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | |
+| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | |
+| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ |
 | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
 | `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ |
diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md
index 38399c6633bd..d30144e8a825 100644
--- a/docs/usage/v1_guide.md
+++ b/docs/usage/v1_guide.md
@@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
 | **Decoder-only Models** | 🚀 Optimized |
 | **Encoder-Decoder Models** | 🟠 Delayed |
 | **Embedding Models** | 🟢 Functional |
-| **Mamba Models** | 🟢 (Mamba-2), 🟡 (Mamba-1) |
+| **Mamba Models** | 🟢 (Mamba-2), 🟢 (Mamba-1) |
 | **Multimodal Models** | 🟢 Functional |

 vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol.
@@ -104,13 +104,11 @@ to enable simultaneous generation and embedding using the same engine instance i

 #### Mamba Models

-Models using selective state-space mechanisms instead of standard transformer attention are partially supported.
-Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers
-(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet supported. Please note that these models currently require
-disabling prefix caching in V1.
+Models using selective state-space mechanisms instead of standard transformer attention are supported.
+Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`.

-Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
-`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
+Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
+`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM`, `GraniteMoeHybridForCausalLM` and `JambaForCausalLM`). Please note that
 these models currently require disabling prefix caching and using the FlashInfer attention backend in V1.
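To make the constraints above concrete, here is a minimal sketch (not part of the patch) of serving a Mamba-1 model on V1 under these requirements. The model name and sampling settings are illustrative; hybrid Mamba/attention models such as Jamba would additionally need the FlashInfer backend.

```python
# Illustrative sketch only: running a Mamba-1 model on vLLM V1 with the
# constraints described above (no prefix caching, eager execution).
import os

from vllm import LLM, SamplingParams

os.environ["VLLM_USE_V1"] = "1"
# Hybrid Mamba/attention models (e.g. Jamba) additionally require:
# os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

llm = LLM(
    model="state-spaces/mamba-130m-hf",  # example model
    enforce_eager=True,                  # required for Mamba-1 on V1
    enable_prefix_caching=False,         # prefix caching is not yet supported
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```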
#### Encoder-Decoder Models diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 2238924c1b50..67ba2f25593d 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -53,6 +53,8 @@ ] V1_SUPPORTED_MODELS = [ + "state-spaces/mamba-130m-hf", + "ai21labs/Jamba-tiny-dev", "mistralai/Mamba-Codestral-7B-v0.1", "ibm-ai-platform/Bamba-9B-v1", "Zyphra/Zamba2-1.2B-instruct", diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index b68ed298a189..a756c89b520f 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -12,7 +12,6 @@ UNSUPPORTED_MODELS_V1 = [ "openai/whisper-large-v3", # transcription "facebook/bart-large-cnn", # encoder decoder - "state-spaces/mamba-130m-hf", # mamba1 ] MODEL = "meta-llama/Llama-3.2-1B-Instruct" diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 60cf3e11885a..17b7f84a933f 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -1,30 +1,37 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + import torch from torch import nn from torch.nn.parameter import Parameter -from vllm.attention.backends.abstract import AttentionMetadata +from vllm import envs +from vllm.config import get_current_vllm_config from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.forward_context import get_forward_context +from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs +from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer @CustomOp.register("mamba_mixer") -class MambaMixer(CustomOp): +class MambaMixer(MambaBase, CustomOp): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
A, D are input independent @@ -47,13 +54,16 @@ def __init__(self, rms_norm_has_weight: bool = True, rms_norm_eps: float = 1e-5, activation="silu", - is_lora_enabled: bool = False): + is_lora_enabled: bool = False, + prefix: str = ""): super().__init__() self.time_step_rank = time_step_rank self.ssm_state_size = ssm_state_size self.use_rms_norm = use_rms_norm self.activation = activation self.is_lora_enabled = is_lora_enabled + self.conv_kernel_size = conv_kernel_size + self.intermediate_size = intermediate_size self.conv1d = ColumnParallelLinear( input_size=conv_kernel_size, @@ -131,14 +141,62 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): has_weight=rms_norm_has_weight, ) if use_rms_norm else None - def forward_native(self, hidden_states: torch.Tensor, - conv_state: torch.Tensor, ssm_state: torch.Tensor): + if envs.VLLM_USE_V1: + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. + # The inner tuple is (conv_state, ssm_state) + self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + + self.prefix = prefix + + def forward(self, + hidden_states: torch.Tensor, + mamba_cache_params: Optional[MambaCacheParams] = None): + if not envs.VLLM_USE_V1: + return CustomOp.forward(self, hidden_states, mamba_cache_params) + else: + return self.forward_cuda(hidden_states, mamba_cache_params) + + def forward_native(self, + hidden_states: torch.Tensor, + mamba_cache_params: Optional[MambaCacheParams] = None): pass - def forward_cuda(self, hidden_states: torch.Tensor, - mamba_cache_params: MambaCacheParams): + def forward_cuda(self, + hidden_states: torch.Tensor, + mamba_cache_params: Optional[MambaCacheParams] = None): + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + + if envs.VLLM_USE_V1: + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + mamba1_metadata = attn_metadata + assert isinstance(mamba1_metadata, Mamba1AttentionMetadata) + query_start_loc = mamba1_metadata.query_start_loc + state_indices_tensor = mamba1_metadata.state_indices_tensor + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + has_initial_state = mamba1_metadata.has_initial_states + context_lens_tensor = mamba1_metadata.context_lens_tensor + else: + assert mamba_cache_params is not None + conv_state = mamba_cache_params.conv_state + ssm_state = mamba_cache_params.ssm_state + state_indices_tensor = mamba_cache_params.state_indices_tensor + query_start_loc = attn_metadata.query_start_loc + context_lens_tensor = attn_metadata.context_lens_tensor - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + if context_lens_tensor is not None: + has_initial_state = context_lens_tensor > 0 # 1. 
Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -148,8 +206,12 @@ def forward_cuda(self, hidden_states: torch.Tensor, conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: + if envs.VLLM_USE_V1 and attn_metadata is None: + # V1 profile run + hidden_states = hidden_states.contiguous() + return self.out_proj(hidden_states.transpose(-2, -1))[0] + + if query_start_loc is not None and context_lens_tensor is not None: # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| @@ -161,18 +223,18 @@ def forward_cuda(self, hidden_states: torch.Tensor, conv_weights, bias=self.conv1d.bias, activation=self.activation, - conv_states=mamba_cache_params.conv_state, - has_initial_state=attn_metadata.context_lens_tensor > 0, - cache_indices=mamba_cache_params.state_indices_tensor, - query_start_loc=attn_metadata.query_start_loc) + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=state_indices_tensor, + query_start_loc=query_start_loc) else: hidden_states = causal_conv1d_update( hidden_states.transpose(0, 1), - mamba_cache_params.conv_state, + conv_state, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=mamba_cache_params.state_indices_tensor) + conv_state_indices=state_indices_tensor) hidden_states = hidden_states.transpose(0, 1) # 3. State Space Model sequence transformation @@ -203,11 +265,10 @@ def forward_cuda(self, hidden_states: torch.Tensor, time_proj_bias = (self.dt_proj.bias.float() if hasattr( self.dt_proj, "bias") else None) - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: + if query_start_loc is not None and context_lens_tensor is not None: scan_outputs = selective_scan_fn( hidden_states, - mamba_cache_params.ssm_state, + ssm_state, discrete_time_step, self.A, B.transpose(-2, -1), @@ -216,24 +277,23 @@ def forward_cuda(self, hidden_states: torch.Tensor, gate, time_proj_bias, delta_softplus=True, - cache_indices=mamba_cache_params.state_indices_tensor, - has_initial_state=attn_metadata.context_lens_tensor > 0, - query_start_loc=attn_metadata.query_start_loc) + cache_indices=state_indices_tensor, + has_initial_state=has_initial_state, + query_start_loc=query_start_loc) else: scan_outputs = torch.empty_like(hidden_states.transpose(0, 1)) - selective_state_update( - mamba_cache_params.ssm_state, - hidden_states.transpose(0, 1), - discrete_time_step.transpose(0, 1), - self.A, - B, - C, - self.D, - gate.transpose(0, 1), - time_proj_bias, - dt_softplus=True, - state_batch_indices=mamba_cache_params.state_indices_tensor, - out=scan_outputs) + selective_state_update(ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + state_batch_indices=state_indices_tensor, + out=scan_outputs) scan_outputs = scan_outputs.transpose(0, 1) # 4. 
Final linear projection @@ -245,3 +305,15 @@ def forward_cuda(self, hidden_states: torch.Tensor, contextualized_states = self.out_proj( scan_outputs.transpose(-2, -1))[0] return contextualized_states + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.mamba1_state_shape( + tp_world_size=get_tensor_model_parallel_world_size(), + intermediate_size=self.intermediate_size, + state_size=self.ssm_state_size, + conv_kernel=self.conv_kernel_size, + ) + + @property + def mamba_type(self) -> str: + return "mamba1" diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 5ac9a7f9ab3e..d5f4877135c9 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, update_metadata) from vllm.model_executor.layers.mamba.mamba_utils import ( - extra_groups_for_head_shards, get_mamba_state_shape) + MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated @@ -278,8 +278,9 @@ def __init__( # - for TP we shard conv_dim by sharding on n_groups, # - but if n_groups cannot divide tp_size, we need to # extend some extra groups - self.n_groups = n_groups + extra_groups_for_head_shards( + groups = MambaStateShapeCalculator.extra_groups_for_head_shards( n_groups, self.tp_size) + self.n_groups = n_groups + groups self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size self.conv1d = ColumnParallelLinear( @@ -732,7 +733,7 @@ def forward_cuda( output[:num_actual_tokens], _ = self.out_proj(hidden_states) def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=self.intermediate_size, tp_world_size=get_tensor_model_parallel_world_size(), n_groups=self.n_groups, diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 99a582066c0d..42c815b08f04 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -3,53 +3,70 @@ from vllm.distributed import divide -def extra_groups_for_head_shards(ngroups: int, tp_size: int): - """Compute the increase in group numbers to account for - replication in order to accompany the head shards.""" - - # in the case ngoups % tp_size == 0, this will be zero - if ngroups % tp_size == 0: - return 0 - - # for n_groups == 1, this is exactly tp_size - n_groups - return tp_size - ngroups - - -def get_mamba_state_shape( - intermediate_size: int, - tp_world_size: int, - n_groups: int, - num_heads: int, - head_dim: int, - state_size: int, - conv_kernel: int, - use_v1: bool = True, -) -> tuple[tuple[int, int], tuple[int, int, int]]: - """ Get the shape of mamba state.""" - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = (n_groups + - extra_groups_for_head_shards(n_groups, tp_world_size)) - - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + 2 * n_groups * state_size) - # contiguous along 'dim' axis - conv_state_shape = ( - conv_kernel - 1, - divide(conv_dim, tp_world_size), - ) - - if not use_v1: - conv_state_shape = 
(conv_state_shape[1], conv_state_shape[0])
-
-    # These are not TP-ed as they depend on A, dt_bias, D
-    #  - they are typically small
-    #   e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
-    temporal_state_shape = (
-        divide(num_heads, tp_world_size),
-        head_dim,
-        state_size,
-    )
-
-    return conv_state_shape, temporal_state_shape
+class MambaStateShapeCalculator:
+
+    @classmethod
+    def mamba1_state_shape(
+        cls,
+        tp_world_size: int,
+        intermediate_size: int,
+        state_size: int,
+        conv_kernel: int,
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int], tuple[int, int]]:
+        conv_state_shape = (divide(intermediate_size,
+                                   tp_world_size), conv_kernel - 1)
+
+        temporal_state_shape = (divide(intermediate_size,
+                                       tp_world_size), state_size)
+
+        # In V0, the conv_state shape was swapped during allocation in
+        # MambaCacheManager, but in V1 it needs to be determined here at the
+        # calculation level
+        if use_v1:
+            conv_state_shape = conv_state_shape[1], conv_state_shape[0]
+
+        return conv_state_shape, temporal_state_shape
+
+    @classmethod
+    def mamba2_state_shape(
+        cls,
+        tp_world_size: int,
+        intermediate_size: int,
+        n_groups: int,
+        num_heads: int,
+        head_dim: int,
+        state_size: int,
+        conv_kernel: int,
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+        # if n_groups is not divisible by world_size, need to extend the shards
+        # to ensure all groups needed by a head are sharded along with it
+        n_groups = n_groups + cls.extra_groups_for_head_shards(
+            n_groups, tp_world_size)
+        # heads and n_groups are TP-ed
+        conv_dim = intermediate_size + 2 * n_groups * state_size
+
+        # contiguous along 'dim' axis
+        conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
+        if not use_v1:
+            conv_state_shape = conv_state_shape[1], conv_state_shape[0]
+
+        # These are not TP-ed as they depend on A, dt_bias, D
+        #  - they are typically small
+        #   e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
+        temporal_state_shape = (divide(num_heads,
+                                       tp_world_size), head_dim, state_size)
+        return conv_state_shape, temporal_state_shape
+
+    @classmethod
+    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
+        """Compute the increase in group numbers to account for
+        replication in order to accompany the head shards."""
+
+        # in the case ngroups % tp_size == 0, this will be zero
+        if ngroups % tp_size == 0:
+            return 0
+
+        # for n_groups == 1, this is exactly tp_size - n_groups
+        return tp_size - ngroups
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
index 0f5494427634..4a2ae07581f3 100644
--- a/vllm/model_executor/models/bamba.py
+++ b/vllm/model_executor/models/bamba.py
@@ -25,7 +25,8 @@
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
-from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -457,7 +458,7 @@ def get_mamba_state_shape_from_config(
         hf_config = vllm_config.model_config.hf_config
         intermediate_size = hf_config.mamba_expand * hf_config.hidden_size

-        return get_mamba_state_shape(
+        return MambaStateShapeCalculator.mamba2_state_shape(
             intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.mamba_n_groups, diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 6a58b1501fe6..85d64af5bd28 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -24,7 +24,8 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -543,7 +544,7 @@ def get_mamba_state_shape_from_config( if hf_config.mamba_d_ssm is None else hf_config.mamba_d_ssm) - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.mamba_n_groups, diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 59c1dce48ee7..e59502f12a1c 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -23,7 +23,8 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -547,7 +548,7 @@ def get_mamba_state_shape_from_config( hf_config = vllm_config.model_config.hf_config intermediate_size = hf_config.mamba_expand * hf_config.hidden_size - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.mamba_n_groups, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 263f4c8379cf..8a9efd4d7247 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -8,6 +8,7 @@ from torch import nn from transformers import JambaConfig +from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -19,6 +20,8 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig @@ -32,8 +35,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType -from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsV0Only) +from .interfaces import HasInnerState, IsHybrid, 
SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -112,7 +114,8 @@ def __init__(self, use_rms_norm=True, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, - is_lora_enabled = self.is_lora_enabled + is_lora_enabled = self.is_lora_enabled, + prefix=f"{prefix}.mixer", ) num_experts = config.layers_num_experts[layer_idx] @@ -344,7 +347,8 @@ def forward( layer_mamba_cache_params = None if isinstance(layer, JambaAttentionDecoderLayer): kv_cache_index += 1 - if isinstance(layer, JambaMambaDecoderLayer): + if isinstance(layer, + JambaMambaDecoderLayer) and mamba_cache_params: current_state_layer = mamba_cache_index layer_mamba_cache_params = mamba_cache_params.at_layer_idx( current_state_layer) @@ -442,7 +446,7 @@ def load_weights(self, weights: Iterable[tuple[str, class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsV0Only): + IsHybrid): hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={ ".self_attn.": ".", ".A_log": ".A" @@ -509,14 +513,19 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + # NOTE: mamba_cache_params is not needed for v1 + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + state_shape = self.get_mamba_state_shape_from_config( + self.vllm_config) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.lm_head.weight.dtype, + num_layers, *state_shape) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) @@ -529,19 +538,22 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - conv_state_shape = ( - self.config.mamba_expand * hidden_size // world_size, - self.config.mamba_d_conv - 1, - ) - temporal_state_shape = ( - self.config.mamba_expand * hidden_size // world_size, - self.config.mamba_d_state, + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + hidden_size = hf_config.hidden_size + + return MambaStateShapeCalculator.mamba1_state_shape( + tp_world_size=parallel_config.tensor_parallel_size, + intermediate_size=hf_config.mamba_expand * hidden_size, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, + use_v1=envs.VLLM_USE_V1, ) - return conv_state_shape, temporal_state_shape def compute_logits( self, diff --git 
a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 8162ac3f7597..80b63e15377a 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -8,20 +8,21 @@ from torch import nn from transformers import MambaConfig +from vllm import envs from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (HasInnerState, - IsAttentionFree, SupportsPP, - SupportsV0Only) + IsAttentionFree, SupportsPP) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -41,7 +42,8 @@ def __init__(self, config: MambaConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - is_lora_enabled: Optional[bool] = False) -> None: + is_lora_enabled: Optional[bool] = False, + prefix: str = "") -> None: super().__init__() self.config = config self.is_falcon_mamba = config.model_type == "falcon_mamba" @@ -58,7 +60,8 @@ def __init__(self, rms_norm_has_weight=not self.is_falcon_mamba, rms_norm_eps=mixer_rms_eps, activation=config.hidden_act, - is_lora_enabled=self.is_lora_enabled) + is_lora_enabled=self.is_lora_enabled, + prefix=f"{prefix}.mixer") self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -107,7 +110,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lambda prefix: MambaDecoderLayer(config, cache_config=cache_config, quant_config=quant_config, - is_lora_enabled=is_lora_enabled), + is_lora_enabled=is_lora_enabled, + prefix=prefix), prefix=f"{prefix}.layers") self.norm_f = RMSNorm(config.hidden_size, @@ -123,7 +127,7 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, + mamba_cache_params: Optional[MambaCacheParams] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -140,12 +144,17 @@ def forward( for i in range(self.start_layer, self.end_layer): layer = self.layers[i] + + layer_cache_params = None + if mamba_cache_params is not None: + layer_cache_params = mamba_cache_params.at_layer_idx( + i - self.start_layer) + hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=mamba_cache_params.at_layer_idx( - i - self.start_layer)) + mamba_cache_params=layer_cache_params) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -176,8 +185,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, - SupportsV0Only): +class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): 
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config @@ -227,20 +235,40 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + state_shape = self.get_mamba_state_shape_from_config( + self.vllm_config) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.lm_head.weight.dtype, + num_layers, *state_shape) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.backbone(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) return hidden_states + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + return MambaStateShapeCalculator.mamba1_state_shape( + tp_world_size=parallel_config.tensor_parallel_size, + intermediate_size=hf_config.intermediate_size, + state_size=hf_config.state_size, + conv_kernel=hf_config.conv_kernel, + use_v1=envs.VLLM_USE_V1) + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): return self.mamba_cache.copy_inputs_before_cuda_graphs( input_buffers, **kwargs) @@ -248,19 +276,6 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - conv_state_shape = ( - self.config.intermediate_size // world_size, - self.config.conv_kernel - 1, - ) - temporal_state_shape = ( - self.config.intermediate_size // world_size, - self.config.state_size, - ) - return conv_state_shape, temporal_state_shape - def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states, diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index adad181617e6..75e92b01762d 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -19,7 +19,8 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -220,7 +221,7 @@ def get_mamba_state_shape_from_config( hf_config = vllm_config.model_config.hf_config intermediate_size = hf_config.expand * 
hf_config.hidden_size - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.n_groups, diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 6a999e2254e7..eb62d5a53c1a 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -39,7 +39,8 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -482,7 +483,7 @@ def get_mamba_state_shape_from_config( hf_config = vllm_config.model_config.hf_config intermediate_size = hf_config.expand * hf_config.hidden_size - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.n_groups, diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 7764fd9b9e08..4cb0becf302f 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -32,7 +32,8 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -869,7 +870,7 @@ def get_mamba_state_shape_from_config( hf_config = vllm_config.model_config.hf_config intermediate_size = hf_config.mamba_expand * hf_config.hidden_size - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.mamba_ngroups, diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py new file mode 100644 index 000000000000..f0e4636fdb52 --- /dev/null +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import ClassVar + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + + +class Mamba1AttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]: + return Mamba1AttentionMetadataBuilder + + +@dataclass +class Mamba1AttentionMetadata: + query_start_loc: torch.Tensor + context_lens_tensor: torch.Tensor + 
state_indices_tensor: torch.Tensor + has_initial_states: torch.Tensor + + +class Mamba1AttentionMetadataBuilder( + AttentionMetadataBuilder[Mamba1AttentionMetadata]): + + reorder_batch_threshold: ClassVar[int] = 1 + + def __init__( + self, + kv_cache_spec: AttentionSpec, + vllm_config: VllmConfig, + device: torch.device, + layer_names: list[str], + ): + assert isinstance(kv_cache_spec, MambaSpec) + self.kv_cache_spec = kv_cache_spec + self.device = device + self.vllm_config = vllm_config + self.layer_names = layer_names + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> Mamba1AttentionMetadata: + query_start_loc = common_attn_metadata.query_start_loc + + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to( + query_start_loc.device) + has_initial_states = (context_lens_tensor > 0) + + return Mamba1AttentionMetadata( + query_start_loc=query_start_loc, + context_lens_tensor=context_lens_tensor, + has_initial_states=has_initial_states, + state_indices_tensor=state_indices_tensor, + ) diff --git a/vllm/v1/attention/backends/mamba_selectors.py b/vllm/v1/attention/backends/mamba_selectors.py index 80021a216556..f56f2fb7bf69 100644 --- a/vllm/v1/attention/backends/mamba_selectors.py +++ b/vllm/v1/attention/backends/mamba_selectors.py @@ -1,10 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.attention.backends.abstract import AttentionBackend +from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: + if mamba_type == "mamba1": + return Mamba1AttentionBackend + if mamba_type == "mamba2": return Mamba2AttentionBackend
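As a closing illustration (a sketch under assumed hyperparameters, not part of the patch), the new `MambaStateShapeCalculator.mamba1_state_shape` helper and the backend selector can be exercised directly. The numbers below assume typical Mamba-130M settings (`intermediate_size=1536`, `state_size=16`, `conv_kernel=4`) and no tensor parallelism.

```python
# Hedged sketch: what the new state-shape helper returns for typical
# Mamba-130M hyperparameters, and how the V1 backend is selected.
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator)
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend

conv_shape, ssm_shape = MambaStateShapeCalculator.mamba1_state_shape(
    tp_world_size=1,
    intermediate_size=1536,
    state_size=16,
    conv_kernel=4,
    use_v1=True,
)
# In V1 the conv state is stored as (conv_kernel - 1, intermediate_size) and
# transposed back inside MambaMixer.forward_cuda, so:
assert conv_shape == (3, 1536)
assert ssm_shape == (1536, 16)

# The per-layer mamba_type property ("mamba1") picks the matching V1 backend.
assert get_mamba_attn_backend("mamba1").__name__ == "Mamba1AttentionBackend"
```

In V0 the same helper is called with `use_v1=False`, which keeps the conv state in its historical `(intermediate_size, conv_kernel - 1)` layout, matching what `MambaCacheManager` previously allocated.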