From a888bef22cb686c42bf4a0ff8397762fbf9e39fd Mon Sep 17 00:00:00 2001 From: amirk Date: Mon, 11 Aug 2025 17:13:36 +0300 Subject: [PATCH 1/9] feat: Added split to prefil and decode in mamba1 mixer Signed-off-by: amirk --- .../layers/mamba/mamba_mixer.py | 309 ++++++++++++------ vllm/v1/attention/backends/mamba1_attn.py | 53 ++- 2 files changed, 249 insertions(+), 113 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 17b7f84a933f..1c573be52d33 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Optional, NamedTuple import torch from torch import nn @@ -154,13 +154,23 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): self.prefix = prefix - def forward(self, - hidden_states: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None): - if not envs.VLLM_USE_V1: - return CustomOp.forward(self, hidden_states, mamba_cache_params) + def _ssm_transform(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Applies x_proj, splits into time_step, B, C, applies RMS norm, and dt_proj. + Returns (discrete_time_step, time_step, B, C) + """ + if self.is_lora_enabled: + ssm_params = self.x_proj(x.contiguous())[0] else: - return self.forward_cuda(hidden_states, mamba_cache_params) + ssm_params = self.x_proj(x)[0] + time_step, B, C = torch.split(ssm_params, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1) + if self.use_rms_norm: + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + return discrete_time_step, B, C def forward_native(self, hidden_states: torch.Tensor, @@ -170,6 +180,24 @@ def forward_native(self, def forward_cuda(self, hidden_states: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): + """ + Run the Mamba-1 SSM pipeline. + + Steps + ----- + 1. Apply the gated-MLP linear projection to the raw input. + 2. Pass the projected sequence through the convolutional mixing layer. + 3. Feed the result into the State-Space Model (SSM) block. + 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) to produce contextual representations. + 5. Project the contextualised sequence back to the output embedding dimension. + + Batch handling + -------------- + Prefill and decode tokens are processed by dedicated CUDA kernels for both the + convolutional (conv1d) and SSM stages. In the case of a mixed batch (containing + both prefill and decode tokens), both sets of kernels are executed independently + and their outputs are concatenated before the final output projection. 
+ """ forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -185,8 +213,7 @@ def forward_cuda(self, self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] - has_initial_state = mamba1_metadata.has_initial_states - context_lens_tensor = mamba1_metadata.context_lens_tensor + has_initial_states_p = mamba1_metadata.has_initial_states else: assert mamba_cache_params is not None conv_state = mamba_cache_params.conv_state @@ -194,117 +221,135 @@ def forward_cuda(self, state_indices_tensor = mamba_cache_params.state_indices_tensor query_start_loc = attn_metadata.query_start_loc context_lens_tensor = attn_metadata.context_lens_tensor - + has_initial_states_p = None if context_lens_tensor is not None: - has_initial_state = context_lens_tensor > 0 + has_initial_states_p = context_lens_tensor > 0 # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) - hidden_states, gate = projected_states.chunk(2, dim=-2) + hidden_states_BC, gate = projected_states.chunk(2, dim=-2) - # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) if envs.VLLM_USE_V1 and attn_metadata is None: # V1 profile run - hidden_states = hidden_states.contiguous() - return self.out_proj(hidden_states.transpose(-2, -1))[0] - - if query_start_loc is not None and context_lens_tensor is not None: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - hidden_states = causal_conv1d_fn( - hidden_states, + hidden_states_BC = hidden_states_BC.contiguous() + return self.out_proj(hidden_states_BC.transpose(-2, -1))[0] + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefills = attn_metadata.num_prefills + num_decodes = attn_metadata.num_decode_tokens + has_prefill = num_prefill_tokens > 0 + has_decode = num_decode_tokens > 0 + + pefill_decode_split = split_batch_to_prefill_and_decode( + hidden_states_BC, + gate, + state_indices_tensor, + query_start_loc, + has_initial_states_p, + num_prefill_tokens, + num_decode_tokens, + num_prefills, + num_decodes, + ) + hidden_states_BC_p = pefill_decode_split.hidden_states_BC_p + hidden_states_BC_d = pefill_decode_split.hidden_states_BC_d + gate_p = pefill_decode_split.gate_p + gate_d = pefill_decode_split.gate_d + state_indices_tensor_p = pefill_decode_split.state_indices_tensor_p + state_indices_tensor_d = pefill_decode_split.state_indices_tensor_d + query_start_loc_p = pefill_decode_split.query_start_loc_p + initial_states = pefill_decode_split.initial_states + + ssm_outputs = [] + + if has_prefill: + # 2.a Prefill: Convolution sequence transformation + conv_input_p = hidden_states_BC_p + conv_out_p = causal_conv1d_fn( + conv_input_p, conv_weights, - bias=self.conv1d.bias, + self.conv1d.bias, activation=self.activation, conv_states=conv_state, - has_initial_state=has_initial_state, - cache_indices=state_indices_tensor, - query_start_loc=query_start_loc) - else: - hidden_states = causal_conv1d_update( - hidden_states.transpose(0, 1), + has_initial_state=initial_states, + cache_indices=state_indices_tensor_p, + query_start_loc=query_start_loc_p + ) + # 3. 
State Space Model sequence transformation. Lora kernel requires contiguous tensor. + discrete_time_step_p, B_p, C_p = self._ssm_transform(conv_out_p.transpose(-2, -1)) + time_proj_bias = self._time_proj_bias() + + # 4.a Prefill: perform the recurrence y ← SSM(A, B, C, Δ)(x) + scan_out_p = selective_scan_fn( + conv_out_p, + ssm_state, + discrete_time_step_p, + self.A, + B_p.transpose(-2, -1), + C_p.transpose(-2, -1), + self.D.float(), + gate_p, + time_proj_bias, + delta_softplus=True, + cache_indices=state_indices_tensor_p, + has_initial_state=initial_states, + query_start_loc=query_start_loc_p + ) + ssm_outputs.append(scan_out_p) + + if has_decode: + # 2.b Decode: Convolution sequence transformation + conv_input_d = hidden_states_BC_d.transpose(0, 1) + conv_out_d = causal_conv1d_update( + conv_input_d, conv_state, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor) - hidden_states = hidden_states.transpose(0, 1) - - # 3. State Space Model sequence transformation - # 3.a. input varying initialization of time_step, B and C + conv_state_indices=state_indices_tensor_d + ).transpose(0, 1) - if self.is_lora_enabled: - # lora kernel requires contiguous tensor - ssm_parameters = self.x_proj( - hidden_states.transpose(-2, -1).contiguous())[0] - else: - ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] - - time_step, B, C = torch.split( - ssm_parameters, - [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], - dim=-1, - ) - if self.use_rms_norm: - assert self.dt_layernorm is not None - assert self.b_layernorm is not None - assert self.c_layernorm is not None - time_step = self.dt_layernorm(time_step.contiguous()) - B = self.b_layernorm(B.contiguous()) - C = self.c_layernorm(C.contiguous()) + # 3. State Space Model sequence transformation. Lora kernel requires contiguous tensor. + discrete_time_step_d, B_d, C_d = self._ssm_transform(conv_out_d.transpose(-2, -1)) + time_proj_bias = self._time_proj_bias() - discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = (self.dt_proj.bias.float() if hasattr( - self.dt_proj, "bias") else None) - - if query_start_loc is not None and context_lens_tensor is not None: - scan_outputs = selective_scan_fn( - hidden_states, + # 4.b Decode: perform the recurrence y ← SSM(A, B, C, Δ)(x) + scan_outputs_d = torch.empty_like(hidden_states_BC_d.transpose(0, 1)) + selective_state_update( ssm_state, - discrete_time_step, + conv_out_d.transpose(0, 1), + discrete_time_step_d.transpose(0, 1), self.A, - B.transpose(-2, -1), - C.transpose(-2, -1), - self.D.float(), - gate, + B_d, + C_d, + self.D, + gate_d.transpose(0, 1), time_proj_bias, - delta_softplus=True, - cache_indices=state_indices_tensor, - has_initial_state=has_initial_state, - query_start_loc=query_start_loc) + dt_softplus=True, + state_batch_indices=state_indices_tensor_d, + out=scan_outputs_d + ) + scan_outputs_d = scan_outputs_d.transpose(0, 1) + + if envs.VLLM_USE_V1: + ssm_outputs.insert(0, scan_outputs_d) + else: + ssm_outputs.append(scan_outputs_d) + + scan_outputs_combined = ssm_outputs[0] if len(ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1) + + # 5. Final output projection + if self.is_lora_enabled: # Lora kernel requires contiguous tensor. 
+ scan_outputs_combined = scan_outputs_combined.transpose(-2, -1).contiguous() + out = self.out_proj(scan_outputs_combined)[0] else: - scan_outputs = torch.empty_like(hidden_states.transpose(0, 1)) - selective_state_update(ssm_state, - hidden_states.transpose(0, 1), - discrete_time_step.transpose(0, 1), - self.A, - B, - C, - self.D, - gate.transpose(0, 1), - time_proj_bias, - dt_softplus=True, - state_batch_indices=state_indices_tensor, - out=scan_outputs) - scan_outputs = scan_outputs.transpose(0, 1) - - # 4. Final linear projection - if self.is_lora_enabled: - # lora kernel requires contiguous tensor - contextualized_states = self.out_proj( - scan_outputs.transpose(-2, -1).contiguous())[0] - else: - contextualized_states = self.out_proj( - scan_outputs.transpose(-2, -1))[0] - return contextualized_states + out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0] + + return out def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.mamba1_state_shape( @@ -317,3 +362,73 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: @property def mamba_type(self) -> str: return "mamba1" + + def _time_proj_bias(self) -> Optional[torch.Tensor]: + """Return the time projection bias as float tensor, or None if not present.""" + if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None: + return self.dt_proj.bias.float() + return None + + +class PrefillDecodeSplit(NamedTuple): + hidden_states_BC_p: torch.Tensor + hidden_states_BC_d: torch.Tensor + gate_p: torch.Tensor + gate_d: torch.Tensor + state_indices_tensor_p: torch.Tensor + state_indices_tensor_d: torch.Tensor + query_start_loc_p: torch.Tensor + initial_states: torch.Tensor + + +def split_batch_to_prefill_and_decode( + hidden_states_BC: torch.Tensor, + gate: torch.Tensor, + state_indices_tensor: torch.Tensor, + query_start_loc: torch.Tensor, + has_initial_states_p: Optional[torch.Tensor], + num_prefill_tokens: int, + num_decode_tokens: int, + num_prefills: int, + num_decodes: int, +) -> PrefillDecodeSplit: + if envs.VLLM_USE_V1: + # In v1, decode tokens come first, then prefill tokens. + hidden_states_BC_d, hidden_states_BC_p = torch.split(hidden_states_BC, + [num_decode_tokens, num_prefill_tokens], + dim=-1) + gate_d, gate_p = torch.split(gate, + [num_decode_tokens, num_prefill_tokens], + dim=-1) + state_indices_tensor_d, state_indices_tensor_p = torch.split(state_indices_tensor, + [num_decodes, num_prefills], + dim=0) + query_start_loc_p = (query_start_loc[-num_prefills - 1:] - num_decodes if num_prefills > 0 else None) + has_initial_states_p_split = has_initial_states_p[-num_prefills:] if ( + has_initial_states_p is not None and num_prefills > 0) else None + initial_states = has_initial_states_p_split[:] if has_initial_states_p_split is not None else None + else: + # In v0, prefill tokens come first, then decode tokens. 
+ hidden_states_BC_p, hidden_states_BC_d = torch.split(hidden_states_BC, + [num_prefill_tokens, num_decode_tokens], + dim=-1) + gate_p, gate_d = torch.split(gate, + [num_prefill_tokens, num_decode_tokens], + dim=-1) + state_indices_tensor_p, state_indices_tensor_d = torch.split(state_indices_tensor, + [num_prefills, num_decodes], + dim=0) + query_start_loc_p = (query_start_loc[:num_prefills + 1] if num_prefills > 0 else None) + initial_states = has_initial_states_p[:num_prefills] if ( + has_initial_states_p is not None and num_prefills > 0) else None + + return PrefillDecodeSplit( + hidden_states_BC_p=hidden_states_BC_p, + hidden_states_BC_d=hidden_states_BC_d, + gate_p=gate_p, + gate_d=gate_d, + state_indices_tensor_p=state_indices_tensor_p, + state_indices_tensor_d=state_indices_tensor_d, + query_start_loc_p=query_start_loc_p, + initial_states=initial_states, + ) \ No newline at end of file diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index f0e4636fdb52..ab2b19ee3b36 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -2,16 +2,21 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, + split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + class Mamba1AttentionBackend(AttentionBackend): @@ -26,19 +31,22 @@ class Mamba1AttentionMetadata: context_lens_tensor: torch.Tensor state_indices_tensor: torch.Tensor has_initial_states: torch.Tensor + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int class Mamba1AttentionMetadataBuilder( - AttentionMetadataBuilder[Mamba1AttentionMetadata]): - + AttentionMetadataBuilder[Mamba1AttentionMetadata]): reorder_batch_threshold: ClassVar[int] = 1 def __init__( - self, - kv_cache_spec: AttentionSpec, - vllm_config: VllmConfig, - device: torch.device, - layer_names: list[str], + self, + kv_cache_spec: AttentionSpec, + vllm_config: VllmConfig, + device: torch.device, + layer_names: list[str], ): assert isinstance(kv_cache_spec, MambaSpec) self.kv_cache_spec = kv_cache_spec @@ -47,21 +55,34 @@ def __init__( self.layer_names = layer_names def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False, + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, ) -> Mamba1AttentionMetadata: query_start_loc = common_attn_metadata.query_start_loc state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to( query_start_loc.device) - has_initial_states = (context_lens_tensor > 0) + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=1)) + + has_initial_states = None + + if num_prefills > 0: + has_initial_states = context_lens_tensor > 0 + return 
Mamba1AttentionMetadata( query_start_loc=query_start_loc, context_lens_tensor=context_lens_tensor, has_initial_states=has_initial_states, state_indices_tensor=state_indices_tensor, - ) + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + ) \ No newline at end of file From 8b9cf42416a9d1baf083f4dc1a2b83b052d42061 Mon Sep 17 00:00:00 2001 From: amirk Date: Mon, 11 Aug 2025 17:25:53 +0300 Subject: [PATCH 2/9] fix: Lint Signed-off-by: amirk --- .../layers/mamba/mamba_mixer.py | 156 +++++++++--------- vllm/v1/attention/backends/mamba1_attn.py | 35 ++-- 2 files changed, 93 insertions(+), 98 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 1c573be52d33..8c491d939ca1 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, NamedTuple +from typing import NamedTuple, Optional import torch from torch import nn @@ -154,17 +154,17 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): self.prefix = prefix - def _ssm_transform(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Applies x_proj, splits into time_step, B, C, applies RMS norm, and dt_proj. - Returns (discrete_time_step, time_step, B, C) - """ + def _ssm_transform( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: if self.is_lora_enabled: ssm_params = self.x_proj(x.contiguous())[0] else: ssm_params = self.x_proj(x)[0] - time_step, B, C = torch.split(ssm_params, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], - dim=-1) + time_step, B, C = torch.split( + ssm_params, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1) if self.use_rms_norm: time_step = self.dt_layernorm(time_step.contiguous()) B = self.b_layernorm(B.contiguous()) @@ -187,15 +187,18 @@ def forward_cuda(self, ----- 1. Apply the gated-MLP linear projection to the raw input. 2. Pass the projected sequence through the convolutional mixing layer. - 3. Feed the result into the State-Space Model (SSM) block. - 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) to produce contextual representations. - 5. Project the contextualised sequence back to the output embedding dimension. + 3. Feed the result into the State-Space Model (SSM) blocks. + 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) + to produce contextual representations. + 5. Project the contextualised sequence back + to the output embedding dimension. Batch handling -------------- - Prefill and decode tokens are processed by dedicated CUDA kernels for both the - convolutional (conv1d) and SSM stages. In the case of a mixed batch (containing - both prefill and decode tokens), both sets of kernels are executed independently + Prefill and decode tokens are processed by dedicated CUDA + kernels for both the convolutional (conv1d) and SSM stages. + In the case of a mixed batch (containing both prefill and + decode tokens), both sets of kernels are executed independently and their outputs are concatenated before the final output projection. 
""" @@ -269,18 +272,18 @@ def forward_cuda(self, if has_prefill: # 2.a Prefill: Convolution sequence transformation conv_input_p = hidden_states_BC_p - conv_out_p = causal_conv1d_fn( - conv_input_p, - conv_weights, - self.conv1d.bias, - activation=self.activation, - conv_states=conv_state, - has_initial_state=initial_states, - cache_indices=state_indices_tensor_p, - query_start_loc=query_start_loc_p - ) - # 3. State Space Model sequence transformation. Lora kernel requires contiguous tensor. - discrete_time_step_p, B_p, C_p = self._ssm_transform(conv_out_p.transpose(-2, -1)) + conv_out_p = causal_conv1d_fn(conv_input_p, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=initial_states, + cache_indices=state_indices_tensor_p, + query_start_loc=query_start_loc_p) + # 3. State Space Model sequence transformations. + # Lora kernel requires contiguous tensor. + discrete_time_step_p, B_p, C_p = self._ssm_transform( + conv_out_p.transpose(-2, -1)) time_proj_bias = self._time_proj_bias() # 4.a Prefill: perform the recurrence y ← SSM(A, B, C, Δ)(x) @@ -297,8 +300,7 @@ def forward_cuda(self, delta_softplus=True, cache_indices=state_indices_tensor_p, has_initial_state=initial_states, - query_start_loc=query_start_loc_p - ) + query_start_loc=query_start_loc_p) ssm_outputs.append(scan_out_p) if has_decode: @@ -310,29 +312,29 @@ def forward_cuda(self, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor_d - ).transpose(0, 1) + conv_state_indices=state_indices_tensor_d).transpose(0, 1) - # 3. State Space Model sequence transformation. Lora kernel requires contiguous tensor. - discrete_time_step_d, B_d, C_d = self._ssm_transform(conv_out_d.transpose(-2, -1)) + # 3. State Space Model sequence transformation. + # Lora kernel requires contiguous tensor. + discrete_time_step_d, B_d, C_d = self._ssm_transform( + conv_out_d.transpose(-2, -1)) time_proj_bias = self._time_proj_bias() # 4.b Decode: perform the recurrence y ← SSM(A, B, C, Δ)(x) - scan_outputs_d = torch.empty_like(hidden_states_BC_d.transpose(0, 1)) - selective_state_update( - ssm_state, - conv_out_d.transpose(0, 1), - discrete_time_step_d.transpose(0, 1), - self.A, - B_d, - C_d, - self.D, - gate_d.transpose(0, 1), - time_proj_bias, - dt_softplus=True, - state_batch_indices=state_indices_tensor_d, - out=scan_outputs_d - ) + scan_outputs_d = torch.empty_like( + hidden_states_BC_d.transpose(0, 1)) + selective_state_update(ssm_state, + conv_out_d.transpose(0, 1), + discrete_time_step_d.transpose(0, 1), + self.A, + B_d, + C_d, + self.D, + gate_d.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + state_batch_indices=state_indices_tensor_d, + out=scan_outputs_d) scan_outputs_d = scan_outputs_d.transpose(0, 1) if envs.VLLM_USE_V1: @@ -340,11 +342,13 @@ def forward_cuda(self, else: ssm_outputs.append(scan_outputs_d) - scan_outputs_combined = ssm_outputs[0] if len(ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1) + scan_outputs_combined = ssm_outputs[0] if len( + ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1) # 5. Final output projection if self.is_lora_enabled: # Lora kernel requires contiguous tensor. 
- scan_outputs_combined = scan_outputs_combined.transpose(-2, -1).contiguous() + scan_outputs_combined = scan_outputs_combined.transpose( + -2, -1).contiguous() out = self.out_proj(scan_outputs_combined)[0] else: out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0] @@ -364,7 +368,6 @@ def mamba_type(self) -> str: return "mamba1" def _time_proj_bias(self) -> Optional[torch.Tensor]: - """Return the time projection bias as float tensor, or None if not present.""" if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None: return self.dt_proj.bias.float() return None @@ -382,45 +385,42 @@ class PrefillDecodeSplit(NamedTuple): def split_batch_to_prefill_and_decode( - hidden_states_BC: torch.Tensor, - gate: torch.Tensor, - state_indices_tensor: torch.Tensor, - query_start_loc: torch.Tensor, - has_initial_states_p: Optional[torch.Tensor], - num_prefill_tokens: int, - num_decode_tokens: int, - num_prefills: int, - num_decodes: int, + hidden_states_BC: torch.Tensor, + gate: torch.Tensor, + state_indices_tensor: torch.Tensor, + query_start_loc: torch.Tensor, + has_initial_states_p: Optional[torch.Tensor], + num_prefill_tokens: int, + num_decode_tokens: int, + num_prefills: int, + num_decodes: int, ) -> PrefillDecodeSplit: if envs.VLLM_USE_V1: # In v1, decode tokens come first, then prefill tokens. - hidden_states_BC_d, hidden_states_BC_p = torch.split(hidden_states_BC, - [num_decode_tokens, num_prefill_tokens], - dim=-1) + hidden_states_BC_d, hidden_states_BC_p = torch.split( + hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1) gate_d, gate_p = torch.split(gate, [num_decode_tokens, num_prefill_tokens], dim=-1) - state_indices_tensor_d, state_indices_tensor_p = torch.split(state_indices_tensor, - [num_decodes, num_prefills], - dim=0) - query_start_loc_p = (query_start_loc[-num_prefills - 1:] - num_decodes if num_prefills > 0 else None) - has_initial_states_p_split = has_initial_states_p[-num_prefills:] if ( - has_initial_states_p is not None and num_prefills > 0) else None - initial_states = has_initial_states_p_split[:] if has_initial_states_p_split is not None else None + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, [num_decodes, num_prefills], dim=0) + query_start_loc_p = (query_start_loc[-num_prefills - 1:] - + num_decodes if num_prefills > 0 else None) + initial_states = has_initial_states_p[-num_prefills:] if ( + has_initial_states_p is not None and num_prefills > 0) else None else: # In v0, prefill tokens come first, then decode tokens. 
- hidden_states_BC_p, hidden_states_BC_d = torch.split(hidden_states_BC, - [num_prefill_tokens, num_decode_tokens], - dim=-1) + hidden_states_BC_p, hidden_states_BC_d = torch.split( + hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1) gate_p, gate_d = torch.split(gate, [num_prefill_tokens, num_decode_tokens], dim=-1) - state_indices_tensor_p, state_indices_tensor_d = torch.split(state_indices_tensor, - [num_prefills, num_decodes], - dim=0) - query_start_loc_p = (query_start_loc[:num_prefills + 1] if num_prefills > 0 else None) + state_indices_tensor_p, state_indices_tensor_d = torch.split( + state_indices_tensor, [num_prefills, num_decodes], dim=0) + query_start_loc_p = (query_start_loc[:num_prefills + + 1] if num_prefills > 0 else None) initial_states = has_initial_states_p[:num_prefills] if ( - has_initial_states_p is not None and num_prefills > 0) else None + has_initial_states_p is not None and num_prefills > 0) else None return PrefillDecodeSplit( hidden_states_BC_p=hidden_states_BC_p, @@ -431,4 +431,4 @@ def split_batch_to_prefill_and_decode( state_indices_tensor_d=state_indices_tensor_d, query_start_loc_p=query_start_loc_p, initial_states=initial_states, - ) \ No newline at end of file + ) diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index ab2b19ee3b36..44f116e2e8c0 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -2,21 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar +from typing import ClassVar import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch - class Mamba1AttentionBackend(AttentionBackend): @@ -38,15 +34,15 @@ class Mamba1AttentionMetadata: class Mamba1AttentionMetadataBuilder( - AttentionMetadataBuilder[Mamba1AttentionMetadata]): + AttentionMetadataBuilder[Mamba1AttentionMetadata]): reorder_batch_threshold: ClassVar[int] = 1 def __init__( - self, - kv_cache_spec: AttentionSpec, - vllm_config: VllmConfig, - device: torch.device, - layer_names: list[str], + self, + kv_cache_spec: AttentionSpec, + vllm_config: VllmConfig, + device: torch.device, + layer_names: list[str], ): assert isinstance(kv_cache_spec, MambaSpec) self.kv_cache_spec = kv_cache_spec @@ -55,10 +51,10 @@ def __init__( self.layer_names = layer_names def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False, + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, ) -> Mamba1AttentionMetadata: query_start_loc = common_attn_metadata.query_start_loc @@ -75,7 +71,6 @@ def build( if num_prefills > 0: has_initial_states = context_lens_tensor > 0 - return Mamba1AttentionMetadata( query_start_loc=query_start_loc, context_lens_tensor=context_lens_tensor, @@ -85,4 +80,4 @@ def build( num_prefill_tokens=num_prefill_tokens, num_decodes=num_decodes, 
num_decode_tokens=num_decode_tokens, - ) \ No newline at end of file + ) From b0483b0c80effe9cd74029b36840d5438887d6b9 Mon Sep 17 00:00:00 2001 From: amirk Date: Tue, 12 Aug 2025 12:05:26 +0300 Subject: [PATCH 3/9] fix: Lint Signed-off-by: amirk --- .../layers/mamba/mamba_mixer.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 8c491d939ca1..e1bbff124aa4 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -8,6 +8,8 @@ from torch.nn.parameter import Parameter from vllm import envs +from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionMetadata) from vllm.config import get_current_vllm_config from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -155,8 +157,8 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): self.prefix = prefix def _ssm_transform( - self, x: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if self.is_lora_enabled: ssm_params = self.x_proj(x.contiguous())[0] else: @@ -166,9 +168,12 @@ def _ssm_transform( [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1) if self.use_rms_norm: - time_step = self.dt_layernorm(time_step.contiguous()) - B = self.b_layernorm(B.contiguous()) - C = self.c_layernorm(C.contiguous()) + if self.dt_layernorm is not None: + time_step = self.dt_layernorm(time_step.contiguous()) + if self.b_layernorm is not None: + B = self.b_layernorm(B.contiguous()) + if self.c_layernorm is not None: + C = self.c_layernorm(C.contiguous()) discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) return discrete_time_step, B, C @@ -218,6 +223,7 @@ def forward_cuda(self, ssm_state = self_kv_cache[1] has_initial_states_p = mamba1_metadata.has_initial_states else: + assert isinstance(attn_metadata, PlaceholderAttentionMetadata) assert mamba_cache_params is not None conv_state = mamba_cache_params.conv_state ssm_state = mamba_cache_params.ssm_state @@ -240,10 +246,10 @@ def forward_cuda(self, hidden_states_BC = hidden_states_BC.contiguous() return self.out_proj(hidden_states_BC.transpose(-2, -1))[0] - num_prefill_tokens = attn_metadata.num_prefill_tokens + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count num_decode_tokens = attn_metadata.num_decode_tokens - num_prefills = attn_metadata.num_prefills - num_decodes = attn_metadata.num_decode_tokens + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) has_prefill = num_prefill_tokens > 0 has_decode = num_decode_tokens > 0 @@ -380,8 +386,8 @@ class PrefillDecodeSplit(NamedTuple): gate_d: torch.Tensor state_indices_tensor_p: torch.Tensor state_indices_tensor_d: torch.Tensor - query_start_loc_p: torch.Tensor - initial_states: torch.Tensor + query_start_loc_p: Optional[torch.Tensor] + initial_states: Optional[torch.Tensor] def split_batch_to_prefill_and_decode( From 827bc700621b762d5edf8e7f4acfba90bb473ee2 Mon Sep 17 00:00:00 2001 From: amirk Date: Tue, 12 Aug 2025 16:44:14 +0300 Subject: [PATCH 4/9] feat: remove redundant params Signed-off-by: amirk --- vllm/model_executor/layers/mamba/mamba_mixer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git 
a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index e1bbff124aa4..9a2949bad6d6 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -277,8 +277,7 @@ def forward_cuda(self, if has_prefill: # 2.a Prefill: Convolution sequence transformation - conv_input_p = hidden_states_BC_p - conv_out_p = causal_conv1d_fn(conv_input_p, + conv_out_p = causal_conv1d_fn(hidden_states_BC_p, conv_weights, self.conv1d.bias, activation=self.activation, @@ -311,9 +310,8 @@ def forward_cuda(self, if has_decode: # 2.b Decode: Convolution sequence transformation - conv_input_d = hidden_states_BC_d.transpose(0, 1) conv_out_d = causal_conv1d_update( - conv_input_d, + hidden_states_BC_d.transpose(0, 1), conv_state, conv_weights, self.conv1d.bias, From c39d11e9f2c42d674c9dd1a01c5935fd78803773 Mon Sep 17 00:00:00 2001 From: amirk Date: Wed, 13 Aug 2025 14:22:22 +0300 Subject: [PATCH 5/9] feat: fix _ssm_transform layernorm assertions, change has_initial_states to has_initial_states_p in attn md, fix typo Signed-off-by: amirk --- .../layers/mamba/mamba_mixer.py | 36 +++++++++---------- vllm/v1/attention/backends/mamba1_attn.py | 12 ++++--- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 9a2949bad6d6..4f587bddd474 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -160,6 +160,7 @@ def _ssm_transform( self, x: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if self.is_lora_enabled: + # Lora kernel requires contiguous tensor. ssm_params = self.x_proj(x.contiguous())[0] else: ssm_params = self.x_proj(x)[0] @@ -168,12 +169,12 @@ def _ssm_transform( [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1) if self.use_rms_norm: - if self.dt_layernorm is not None: - time_step = self.dt_layernorm(time_step.contiguous()) - if self.b_layernorm is not None: - B = self.b_layernorm(B.contiguous()) - if self.c_layernorm is not None: - C = self.c_layernorm(C.contiguous()) + assert self.dt_layernorm is not None + assert self.b_layernorm is not None + assert self.c_layernorm is not None + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) return discrete_time_step, B, C @@ -193,7 +194,7 @@ def forward_cuda(self, 1. Apply the gated-MLP linear projection to the raw input. 2. Pass the projected sequence through the convolutional mixing layer. 3. Feed the result into the State-Space Model (SSM) blocks. - 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) + 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) to produce contextual representations. 5. Project the contextualised sequence back to the output embedding dimension. 
@@ -221,7 +222,7 @@ def forward_cuda(self, self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] - has_initial_states_p = mamba1_metadata.has_initial_states + has_initial_states_p = mamba1_metadata.has_initial_states_p else: assert isinstance(attn_metadata, PlaceholderAttentionMetadata) assert mamba_cache_params is not None @@ -253,7 +254,7 @@ def forward_cuda(self, has_prefill = num_prefill_tokens > 0 has_decode = num_decode_tokens > 0 - pefill_decode_split = split_batch_to_prefill_and_decode( + prefill_decode_split = split_batch_to_prefill_and_decode( hidden_states_BC, gate, state_indices_tensor, @@ -264,14 +265,14 @@ def forward_cuda(self, num_prefills, num_decodes, ) - hidden_states_BC_p = pefill_decode_split.hidden_states_BC_p - hidden_states_BC_d = pefill_decode_split.hidden_states_BC_d - gate_p = pefill_decode_split.gate_p - gate_d = pefill_decode_split.gate_d - state_indices_tensor_p = pefill_decode_split.state_indices_tensor_p - state_indices_tensor_d = pefill_decode_split.state_indices_tensor_d - query_start_loc_p = pefill_decode_split.query_start_loc_p - initial_states = pefill_decode_split.initial_states + hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p + hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d + gate_p = prefill_decode_split.gate_p + gate_d = prefill_decode_split.gate_d + state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p + state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d + query_start_loc_p = prefill_decode_split.query_start_loc_p + initial_states = prefill_decode_split.initial_states ssm_outputs = [] @@ -286,7 +287,6 @@ def forward_cuda(self, cache_indices=state_indices_tensor_p, query_start_loc=query_start_loc_p) # 3. State Space Model sequence transformations. - # Lora kernel requires contiguous tensor. discrete_time_step_p, B_p, C_p = self._ssm_transform( conv_out_p.transpose(-2, -1)) time_proj_bias = self._time_proj_bias() diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 44f116e2e8c0..1a8ada205175 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar +from typing import ClassVar, Optional import torch @@ -26,7 +26,9 @@ class Mamba1AttentionMetadata: query_start_loc: torch.Tensor context_lens_tensor: torch.Tensor state_indices_tensor: torch.Tensor - has_initial_states: torch.Tensor + # has_initial_states_p only contain prefill requests and will be None if + # the batch has no prefill request. 
+ has_initial_states_p: Optional[torch.Tensor] num_prefills: int num_prefill_tokens: int num_decodes: int @@ -66,15 +68,15 @@ def build( split_decodes_and_prefills(common_attn_metadata, decode_threshold=1)) - has_initial_states = None + has_initial_states_p = None if num_prefills > 0: - has_initial_states = context_lens_tensor > 0 + has_initial_states_p = context_lens_tensor > 0 return Mamba1AttentionMetadata( query_start_loc=query_start_loc, context_lens_tensor=context_lens_tensor, - has_initial_states=has_initial_states, + has_initial_states_p=has_initial_states_p, state_indices_tensor=state_indices_tensor, num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, From 29fea2b9ac99fcb1c82dd694c12f548963154525 Mon Sep 17 00:00:00 2001 From: amirk Date: Wed, 13 Aug 2025 14:49:07 +0300 Subject: [PATCH 6/9] feat: change comment Signed-off-by: amirk --- vllm/model_executor/layers/mamba/mamba_mixer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 4f587bddd474..dc71dba01f0d 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -277,7 +277,7 @@ def forward_cuda(self, ssm_outputs = [] if has_prefill: - # 2.a Prefill: Convolution sequence transformation + # 2. Convolution sequence transformation conv_out_p = causal_conv1d_fn(hidden_states_BC_p, conv_weights, self.conv1d.bias, @@ -291,7 +291,7 @@ def forward_cuda(self, conv_out_p.transpose(-2, -1)) time_proj_bias = self._time_proj_bias() - # 4.a Prefill: perform the recurrence y ← SSM(A, B, C, Δ)(x) + # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) scan_out_p = selective_scan_fn( conv_out_p, ssm_state, @@ -309,7 +309,7 @@ def forward_cuda(self, ssm_outputs.append(scan_out_p) if has_decode: - # 2.b Decode: Convolution sequence transformation + # 2. Convolution sequence transformation conv_out_d = causal_conv1d_update( hidden_states_BC_d.transpose(0, 1), conv_state, @@ -319,12 +319,11 @@ def forward_cuda(self, conv_state_indices=state_indices_tensor_d).transpose(0, 1) # 3. State Space Model sequence transformation. - # Lora kernel requires contiguous tensor. discrete_time_step_d, B_d, C_d = self._ssm_transform( conv_out_d.transpose(-2, -1)) time_proj_bias = self._time_proj_bias() - # 4.b Decode: perform the recurrence y ← SSM(A, B, C, Δ)(x) + # 4. 
Perform the recurrence y ← SSM(A, B, C, Δ)(x) scan_outputs_d = torch.empty_like( hidden_states_BC_d.transpose(0, 1)) selective_state_update(ssm_state, From 86f94e0511dd38e59ee0bd0b26b9e3c6d702c186 Mon Sep 17 00:00:00 2001 From: amirk Date: Wed, 13 Aug 2025 15:20:15 +0300 Subject: [PATCH 7/9] fix: cr comments Signed-off-by: amirk --- .../layers/mamba/mamba_mixer.py | 40 ++++++++++++------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index dc71dba01f0d..7a5435fa47ca 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -178,6 +178,17 @@ def _ssm_transform( discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) return discrete_time_step, B, C + def forward(self, + hidden_states: torch.Tensor, + mamba_cache_params: Optional[MambaCacheParams] = None): + if not envs.VLLM_USE_V1: + return CustomOp.forward(self, hidden_states, mamba_cache_params) + else: + return self.forward_cuda( + hidden_states, + mamba_cache_params, + ) + def forward_native(self, hidden_states: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): @@ -272,20 +283,21 @@ def forward_cuda(self, state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d query_start_loc_p = prefill_decode_split.query_start_loc_p - initial_states = prefill_decode_split.initial_states + has_initial_states_p = prefill_decode_split.has_initial_states_p ssm_outputs = [] if has_prefill: # 2. Convolution sequence transformation - conv_out_p = causal_conv1d_fn(hidden_states_BC_p, - conv_weights, - self.conv1d.bias, - activation=self.activation, - conv_states=conv_state, - has_initial_state=initial_states, - cache_indices=state_indices_tensor_p, - query_start_loc=query_start_loc_p) + conv_out_p = causal_conv1d_fn( + hidden_states_BC_p, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_states_p, + cache_indices=state_indices_tensor_p, + query_start_loc=query_start_loc_p) # 3. State Space Model sequence transformations. discrete_time_step_p, B_p, C_p = self._ssm_transform( conv_out_p.transpose(-2, -1)) @@ -304,7 +316,7 @@ def forward_cuda(self, time_proj_bias, delta_softplus=True, cache_indices=state_indices_tensor_p, - has_initial_state=initial_states, + has_initial_state=has_initial_states_p, query_start_loc=query_start_loc_p) ssm_outputs.append(scan_out_p) @@ -384,7 +396,7 @@ class PrefillDecodeSplit(NamedTuple): state_indices_tensor_p: torch.Tensor state_indices_tensor_d: torch.Tensor query_start_loc_p: Optional[torch.Tensor] - initial_states: Optional[torch.Tensor] + has_initial_states_p: Optional[torch.Tensor] def split_batch_to_prefill_and_decode( @@ -409,7 +421,7 @@ def split_batch_to_prefill_and_decode( state_indices_tensor, [num_decodes, num_prefills], dim=0) query_start_loc_p = (query_start_loc[-num_prefills - 1:] - num_decodes if num_prefills > 0 else None) - initial_states = has_initial_states_p[-num_prefills:] if ( + has_initial_states_p = has_initial_states_p[-num_prefills:] if ( has_initial_states_p is not None and num_prefills > 0) else None else: # In v0, prefill tokens come first, then decode tokens. 
@@ -422,7 +434,7 @@ def split_batch_to_prefill_and_decode( state_indices_tensor, [num_prefills, num_decodes], dim=0) query_start_loc_p = (query_start_loc[:num_prefills + 1] if num_prefills > 0 else None) - initial_states = has_initial_states_p[:num_prefills] if ( + has_initial_states_p = has_initial_states_p[:num_prefills] if ( has_initial_states_p is not None and num_prefills > 0) else None return PrefillDecodeSplit( @@ -433,5 +445,5 @@ def split_batch_to_prefill_and_decode( state_indices_tensor_p=state_indices_tensor_p, state_indices_tensor_d=state_indices_tensor_d, query_start_loc_p=query_start_loc_p, - initial_states=initial_states, + has_initial_states_p=has_initial_states_p, ) From c4323a94d0af280ae2f6a3ae8a90abcea1764ae2 Mon Sep 17 00:00:00 2001 From: amirk Date: Wed, 13 Aug 2025 15:44:27 +0300 Subject: [PATCH 8/9] fix: return has_initial_states_p naming Signed-off-by: amirk --- .../model_executor/layers/mamba/mamba_mixer.py | 18 +++++++++--------- vllm/v1/attention/backends/mamba1_attn.py | 10 ++++------ 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 7a5435fa47ca..e21d3a3c6ffc 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -233,7 +233,7 @@ def forward_cuda(self, self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] - has_initial_states_p = mamba1_metadata.has_initial_states_p + has_initial_states = mamba1_metadata.has_initial_states else: assert isinstance(attn_metadata, PlaceholderAttentionMetadata) assert mamba_cache_params is not None @@ -242,9 +242,9 @@ def forward_cuda(self, state_indices_tensor = mamba_cache_params.state_indices_tensor query_start_loc = attn_metadata.query_start_loc context_lens_tensor = attn_metadata.context_lens_tensor - has_initial_states_p = None + has_initial_states = None if context_lens_tensor is not None: - has_initial_states_p = context_lens_tensor > 0 + has_initial_states = context_lens_tensor > 0 # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -270,7 +270,7 @@ def forward_cuda(self, gate, state_indices_tensor, query_start_loc, - has_initial_states_p, + has_initial_states, num_prefill_tokens, num_decode_tokens, num_prefills, @@ -404,7 +404,7 @@ def split_batch_to_prefill_and_decode( gate: torch.Tensor, state_indices_tensor: torch.Tensor, query_start_loc: torch.Tensor, - has_initial_states_p: Optional[torch.Tensor], + has_initial_states: Optional[torch.Tensor], num_prefill_tokens: int, num_decode_tokens: int, num_prefills: int, @@ -421,8 +421,8 @@ def split_batch_to_prefill_and_decode( state_indices_tensor, [num_decodes, num_prefills], dim=0) query_start_loc_p = (query_start_loc[-num_prefills - 1:] - num_decodes if num_prefills > 0 else None) - has_initial_states_p = has_initial_states_p[-num_prefills:] if ( - has_initial_states_p is not None and num_prefills > 0) else None + has_initial_states_p = has_initial_states[-num_prefills:] if ( + has_initial_states is not None and num_prefills > 0) else None else: # In v0, prefill tokens come first, then decode tokens. 
hidden_states_BC_p, hidden_states_BC_d = torch.split( @@ -434,8 +434,8 @@ def split_batch_to_prefill_and_decode( state_indices_tensor, [num_prefills, num_decodes], dim=0) query_start_loc_p = (query_start_loc[:num_prefills + 1] if num_prefills > 0 else None) - has_initial_states_p = has_initial_states_p[:num_prefills] if ( - has_initial_states_p is not None and num_prefills > 0) else None + has_initial_states_p = has_initial_states[:num_prefills] if ( + has_initial_states is not None and num_prefills > 0) else None return PrefillDecodeSplit( hidden_states_BC_p=hidden_states_BC_p, diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 1a8ada205175..6cdc509083ae 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -26,9 +26,7 @@ class Mamba1AttentionMetadata: query_start_loc: torch.Tensor context_lens_tensor: torch.Tensor state_indices_tensor: torch.Tensor - # has_initial_states_p only contain prefill requests and will be None if - # the batch has no prefill request. - has_initial_states_p: Optional[torch.Tensor] + has_initial_states: Optional[torch.Tensor] num_prefills: int num_prefill_tokens: int num_decodes: int @@ -68,15 +66,15 @@ def build( split_decodes_and_prefills(common_attn_metadata, decode_threshold=1)) - has_initial_states_p = None + has_initial_states = None if num_prefills > 0: - has_initial_states_p = context_lens_tensor > 0 + has_initial_states = context_lens_tensor > 0 return Mamba1AttentionMetadata( query_start_loc=query_start_loc, context_lens_tensor=context_lens_tensor, - has_initial_states_p=has_initial_states_p, + has_initial_states=has_initial_states, state_indices_tensor=state_indices_tensor, num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, From 0a1c9f0f76219e072c8b7d8d046ce7d31683bb43 Mon Sep 17 00:00:00 2001 From: asafg Date: Fri, 15 Aug 2025 09:21:57 +0300 Subject: [PATCH 9/9] fix: Added enforce_eager=True to test_hybrid for mamba1 models Signed-off-by: asafg --- tests/models/language/generation/test_hybrid.py | 13 +++++++++++++ vllm/model_executor/layers/mamba/mamba_mixer.py | 5 ++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 19fcbf561640..e75677347f03 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -57,6 +57,13 @@ # Avoid OOM MAX_NUM_SEQS = 4 +# Once we add support for FCG in Mamba1, this list will be removed and tests +# all test cases will use enforce_eager=False +ENFORCE_EAGER_MODELS_V1 = [ + "state-spaces/mamba-130m-hf", + "ai21labs/Jamba-tiny-dev", +] + @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @@ -94,13 +101,19 @@ def test_models( example_prompts, max_tokens, num_logprobs) if model in V1_SUPPORTED_MODELS: + enforce_eager = False with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") if model in HYBRID_MODELS: # required due to reorder_batch behaviour m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") + + if model in ENFORCE_EAGER_MODELS_V1: + enforce_eager = True + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, + enforce_eager=enforce_eager, enable_prefix_caching=False) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py 
b/vllm/model_executor/layers/mamba/mamba_mixer.py index e21d3a3c6ffc..3b17fb0ca8c7 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -8,8 +8,7 @@ from torch.nn.parameter import Parameter from vllm import envs -from vllm.attention.backends.placeholder_attn import ( - PlaceholderAttentionMetadata) +from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import get_current_vllm_config from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -235,7 +234,7 @@ def forward_cuda(self, ssm_state = self_kv_cache[1] has_initial_states = mamba1_metadata.has_initial_states else: - assert isinstance(attn_metadata, PlaceholderAttentionMetadata) + assert isinstance(attn_metadata, AttentionMetadata) assert mamba_cache_params is not None conv_state = mamba_cache_params.conv_state ssm_state = mamba_cache_params.ssm_state
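
For readers following the series, below is a small standalone sketch (not part of any patch above; tensor sizes and variable names are illustrative only) of the batch-layout convention that the new split_batch_to_prefill_and_decode helper relies on: in v1 the decode tokens come first along the token dimension, while in v0 the prefill tokens come first. It only mirrors the torch.split calls shown in the patches, on made-up data.

    # Illustrative sketch -- mimics the split convention used by
    # split_batch_to_prefill_and_decode in the patches above.
    # All sizes here are invented for the example.
    import torch

    num_decode_tokens, num_prefill_tokens = 3, 5   # mixed batch
    hidden_size = 8

    # Token dimension is last, as in hidden_states_BC after in_proj/transpose:
    # shape (hidden_size, total_tokens)
    hidden_states_BC = torch.randn(
        hidden_size, num_decode_tokens + num_prefill_tokens)

    use_v1 = True
    if use_v1:
        # v1 layout: decode tokens first, then prefill tokens
        hidden_d, hidden_p = torch.split(
            hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
    else:
        # v0 layout: prefill tokens first, then decode tokens
        hidden_p, hidden_d = torch.split(
            hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)

    assert hidden_d.shape[-1] == num_decode_tokens
    assert hidden_p.shape[-1] == num_prefill_tokens

The same ordering is applied to gate and (per request rather than per token) to state_indices_tensor, which is why the prefill slice of query_start_loc in v1 is shifted by num_decodes before being handed to the prefill kernels.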