From 8a9c13ac7639aba3c469db1dc127635010bf6f85 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 7 Jun 2025 02:26:14 -0700 Subject: [PATCH 1/8] finish model init Signed-off-by: Chen Zhang --- .../layers/mamba/mamba_mixer2.py | 109 ++++++++++++---- vllm/model_executor/models/mamba2.py | 57 +++++---- vllm/v1/attention/backends/mamba_attn.py | 110 +++++++++++++++++ vllm/v1/core/single_type_kv_cache_manager.py | 44 ++++++- vllm/v1/kv_cache_interface.py | 24 ++++ vllm/v1/worker/gpu_model_runner.py | 116 ++++++++++++------ 6 files changed, 378 insertions(+), 82 deletions(-) create mode 100644 vllm/v1/attention/backends/mamba_attn.py diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 6d9ea5387879..672323e434e3 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -6,7 +6,9 @@ import torch from torch import nn +from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import get_current_vllm_config from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, @@ -241,6 +243,7 @@ def __init__( activation: str = "silu", use_rms_norm: bool = True, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() @@ -273,6 +276,7 @@ def __init__( ), "Tensor parallel currently not supported for quantized models." self.ssm_state_size = ssm_state_size + self.conv_kernel_size = conv_kernel_size self.activation = activation self.intermediate_size = intermediate_size @@ -411,6 +415,17 @@ def __init__( self.use_rms_norm, eps=rms_norm_eps) + if envs.VLLM_USE_V1: + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. 
+ # The inner tuple is (conv_state, ssm_state) + self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + def forward_native( self, hidden_states: torch.Tensor, @@ -426,19 +441,34 @@ def forward_cuda( mamba2_metadata: Mamba2Metadata, mup_vector: Optional[torch.Tensor] = None, ): + forward_context = get_forward_context() # mamba2_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill # modes; they are computed at top-level model forward since they # stay the same and reused for all mamba layers in the same iteration - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - - num_prefills = attn_metadata.num_prefills # request count - num_decodes = attn_metadata.num_decode_tokens # token count (=request) - num_prefill_tokens = attn_metadata.num_prefill_tokens # token count - has_prefill = num_prefills > 0 - has_decode = num_decodes > 0 + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if envs.VLLM_USE_V1: + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0] + ssm_state = self_kv_cache[1] + else: + conv_state = mamba_cache_params.conv_state + ssm_state = mamba_cache_params.ssm_state + # - get hidden_states, B and C after depthwise convolution. groups_time_state_size = self.n_groups * self.ssm_state_size + split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split( + hidden_states_B_C, + [ + self.intermediate_size // self.tp_size, + groups_time_state_size // self.tp_size, + groups_time_state_size // self.tp_size, + ], + dim=-1, + ) # 1. Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) @@ -459,6 +489,22 @@ def forward_cuda( conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if envs.VLLM_USE_V1 and attn_metadata is None: + # V1 profile run + hidden_states_B_C = (hidden_states_B_C.transpose( + 0, 1).clone().transpose(0, 1)).contiguous() + hidden_states, _B, _C = split_hidden_states_B_C_fn( + hidden_states_B_C) + hidden_states = self.norm(hidden_states, gate) + out, _ = self.out_proj(hidden_states) + return out + + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + # Separate prefill and decode by splitting varlen input # Split along token dimension hidden_states_B_C_p, hidden_states_B_C_d = torch.split( @@ -480,17 +526,6 @@ def forward_cuda( query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1] if has_prefill else None) - # - get hidden_states, B and C after depthwise convolution. 
- split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split( - hidden_states_B_C, - [ - self.intermediate_size // self.tp_size, - groups_time_state_size // self.tp_size, - groups_time_state_size // self.tp_size, - ], - dim=-1, - ) - ssd_output_list = [] # Process prefill requests @@ -503,7 +538,7 @@ def forward_cuda( conv_weights, self.conv1d.bias, activation=self.activation, - conv_states=mamba_cache_params.conv_state, + conv_states=conv_state, has_initial_state=mamba2_metadata.has_initial_states, cache_indices=state_indices_tensor_p, query_start_loc=query_start_loc_p).transpose( @@ -521,7 +556,7 @@ def forward_cuda( # making a copy of the states initial_states = torch.where( mamba2_metadata.has_initial_states[:, None, None, None], - mamba_cache_params.ssm_state[state_indices_tensor_p], 0) + ssm_state[state_indices_tensor_p], 0) scan_output, varlen_state = mamba_chunk_scan_combined( hidden_states_p.view(1, num_prefill_tokens, @@ -550,7 +585,7 @@ def forward_cuda( # update ssm states # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor - mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state + ssm_state[state_indices_tensor_p] = varlen_state # - reshape ssd_output_list.append(scan_output.view(num_prefill_tokens, -1)) @@ -560,7 +595,7 @@ def forward_cuda( # 2. Convolution sequence transformation hidden_states_B_C_d = causal_conv1d_update( hidden_states_B_C_d, - mamba_cache_params.conv_state, + conv_state, conv_weights, self.conv1d.bias, self.activation, @@ -586,7 +621,7 @@ def forward_cuda( # using state_indices_tensor_d hidden_states_d = selective_state_update( - mamba_cache_params.ssm_state, + ssm_state, hidden_states_d, dt_d, A_d, @@ -614,3 +649,31 @@ def forward_cuda( # 5. Final linear projection out, _ = self.out_proj(hidden_states) return out + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + world_size = get_tensor_model_parallel_world_size() + + conv_state_shape, temporal_state_shape = None, None + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (self.n_groups + + extra_groups_for_head_shards(self.n_groups, world_size)) + + # - heads and n_groups are TP-ed + conv_dim = (self.intermediate_size + + 2 * n_groups * self.ssm_state_size) + conv_state_shape = ( + divide(conv_dim, world_size), + self.conv_kernel_size - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(self.num_heads, world_size), + self.head_dim, + self.ssm_state_size, + ) + return conv_state_shape, temporal_state_shape diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index cf9e1bd03e98..e659d61e437d 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -8,6 +8,7 @@ from torch import nn from transformers import MambaConfig +from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -25,8 +26,7 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (HasInnerState, - IsAttentionFree, - SupportsV0Only) + IsAttentionFree) from vllm.model_executor.models.mamba_cache 
import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -44,7 +44,8 @@ class Mamba2DecoderLayer(nn.Module): def __init__(self, config: MambaConfig, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: super().__init__() self.config = config self.mixer = MambaMixer2(hidden_size=config.hidden_size, @@ -60,7 +61,8 @@ def __init__(self, head_dim=config.head_dim, rms_norm_eps=config.layer_norm_epsilon, activation=config.hidden_act, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.mixer") self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -108,8 +110,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Mamba2DecoderLayer(config, - quant_config=quant_config), + lambda prefix: Mamba2DecoderLayer( + config, quant_config=quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm_f = RMSNorm(config.hidden_size, @@ -142,10 +144,14 @@ def forward( attn_metadata: AttentionMetadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.chunk_size, - attn_metadata=attn_metadata, - ) + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None for i in range(len(self.layers)): layer = self.layers[i] @@ -155,7 +161,7 @@ def forward( hidden_states=hidden_states, residual=residual, mamba_cache_params=mamba_cache_params.at_layer_idx( - i - self.start_layer), + i - self.start_layer) if mamba_cache_params else None, mamba2_metadata=mamba2_metadata) if not get_pp_group().is_last_rank: @@ -190,8 +196,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree, - SupportsV0Only): +class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config @@ -242,14 +247,24 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + # TODO in this PR: default to v0 but allows v1 + # TODO in this PR: check whether the kernel is cuda graph compatible + # TODO in this PR: current design for mamba seems incompatible with spec + # decode + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_mamba_layers = ( + self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, + LayerBlockType.mamba)) + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, + num_mamba_layers, *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + else: + # NOTE: mamba_cache_params is not needed for v1 + mamba_cache_params = None hidden_states = self.backbone(input_ids, 
positions, mamba_cache_params, intermediate_tensors, inputs_embeds) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py new file mode 100644 index 000000000000..6add59e7fe27 --- /dev/null +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + +import torch + +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import MambaSpec +from vllm.v1.worker.block_table import BlockTable + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + + +class Mamba2AttentionMetadataBuilder: + + def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec, + block_table: BlockTable): + self.runner = runner + self.kv_cache_spec = kv_cache_spec + self.block_table = block_table + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # NOTE (Chen): Copied from FlashInferMetadataBuilder. This is not + # elegant and should be refactored. + # We now want to reorder the batch so that the "decode" requests are and + # the front and the "prefill" requests are at the using the least amount + # swaps possible. (NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + num_decode_tokens = 0 + num_prefill_tokens = 0 + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + # for now treat 1 scheduled token as "decode" even if its not, + # we should update this to something like < 8 in the future but + # currently the decode run only supports num_tokens = 1 + if num_tokens == 1: + decodes.append(i) + num_decode_tokens += num_tokens + else: + prefills.append(i) + num_prefill_tokens += num_tokens + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. 
+ # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + decode_idx = decodes[num_decodes - i] + if decode_idx < num_decodes: + break + + input_batch.swap_states(prefills[i - 1], decode_idx) + modified_batch = True + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + self._num_decodes = num_decodes + self._num_prefills = num_prefills + self._num_decode_tokens = num_decode_tokens + self._num_prefill_tokens = num_prefill_tokens + + return modified_batch + + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): + raise NotImplementedError("Mamba2AttentionBackend is not implemented.") + + +class Mamba2AttentionBackend: + + @staticmethod + def get_builder_cls() -> type[Mamba2AttentionMetadataBuilder]: + return Mamba2AttentionMetadataBuilder + + +class Mamba2AttentionMetadata: + has_initial_states: torch.Tensor + prep_initial_states: bool + + chunk_size: int + seq_idx: torch.Tensor + chunk_indices: torch.Tensor + chunk_offsets: torch.Tensor + + def __init__(self, query_start_loc, context_lens_tensor): + self.query_start_loc = query_start_loc + self.context_lens_tensor = context_lens_tensor diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 98d758f820ad..4913cfdddce7 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -8,7 +8,7 @@ from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, - SlidingWindowSpec) + MambaSpec, SlidingWindowSpec) from vllm.v1.request import Request @@ -52,6 +52,7 @@ def __init__( self.caching_hash_fn = caching_hash_fn self.kv_cache_group_id = kv_cache_group_id + self._null_block = block_pool.null_block def get_num_blocks_to_allocate( self, request_id: str, num_tokens: int, @@ -392,9 +393,50 @@ def get_num_common_prefix_blocks(self, request_id: str, return 0 +class MambaManager(SingleTypeKVCacheManager): + + @classmethod + def find_longest_cache_hit( + cls, + block_hashes: list[BlockHash], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + ) -> list[list[KVCacheBlock]]: + assert isinstance( + kv_cache_spec, + MambaSpec), ("MambaManager can only be used for mamba groups") + # NOTE(Chen): prefix caching is not supported for mamba now. Always + # return empty list. + computed_blocks: list[list[KVCacheBlock]] = [ + [] for _ in range(len(kv_cache_group_ids)) + ] + return computed_blocks + + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: + # NOTE(Chen): each request will always have 1 block at this moment, so + # no need to remove blocks. 
+ pass + + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> int: + return 0 + + def allocate_new_blocks(self, request_id: str, + num_tokens: int) -> list[KVCacheBlock]: + new_blocks = super().allocate_new_blocks(request_id, num_tokens) + assert len(self.req_to_blocks[request_id]) == 1, ( + "MambaManager should only allocate 1 block for each request.") + return new_blocks + + spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, + MambaSpec: MambaManager, } diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index e938f3bfc671..3090647180b6 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -3,6 +3,7 @@ import copy from dataclasses import dataclass +from math import prod from typing import Optional import torch @@ -154,6 +155,29 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes +@dataclass +class MambaSpec(KVCacheSpec): + shapes: tuple[tuple[int, ...], ...] + dtype: torch.dtype + + def __post_init__(self): + self.num_elements = sum(prod(shape) for shape in self.shapes) + + @property + def type_id(self) -> str: + return f"mamba_{self.shapes}_{self.dtype}" + + @property + def page_size_bytes(self) -> int: + return self.num_elements * get_dtype_size(self.dtype) + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + # NOTE(Chen): we allocate 1 block for each request now, so + # max_memory_usage_bytes is the same as page_size_bytes. + # Need to update this when supporting prefix caching. + return self.page_size_bytes + + @dataclass class KVCacheTensor: """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a90c294a9749..1bfb3ba814f5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -28,6 +28,7 @@ from vllm.forward_context import (DPMetadata, get_forward_context, set_forward_context) from vllm.logger import init_logger +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.multimodal import MULTIMODAL_REGISTRY @@ -37,11 +38,13 @@ from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, - check_use_alibi, is_pin_memory_available) + check_use_alibi, get_dtype_size, + is_pin_memory_available) +from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, - KVCacheConfig, KVCacheSpec, + KVCacheConfig, KVCacheSpec, MambaSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) @@ -2019,39 +2022,42 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: for i, kv_cache_group_spec in enumerate( kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group_spec.kv_cache_spec - if not isinstance(kv_cache_spec, AttentionSpec): - raise NotImplementedError( - "Only AttentionSpec is supported for now.") - attn_backend_i = get_attn_backend( - 
kv_cache_spec.head_size, - self.dtype, - kv_cache_spec.dtype, - kv_cache_spec.block_size, - self.model_config.is_attention_free, - use_mla=kv_cache_spec.use_mla, - ) - if attn_backend_i is None: - error_msg = ( - f"Error with get_attn_backend: {kv_cache_spec.head_size=}, " - f"{self.dtype=}, {kv_cache_spec.dtype=}, " - f"{kv_cache_spec.block_size=}, " - f"{self.model_config.is_attention_free=}, " - f"{kv_cache_spec.use_mla=}") - logger.error(error_msg) - raise NotImplementedError( - "Non-Attention backend is not supported by V1 " - "GPUModelRunner.") - - if self.vllm_config.compilation_config.full_cuda_graph: - attn_backend_name = attn_backend_i.__name__ - flash_attn_version = get_flash_attn_version() - if attn_backend_name != "FlashAttentionBackend" or \ - flash_attn_version != 3: - raise ValueError( - f"full_cuda_graph is only supported with " - f"FA3. Current attention backend is " - f"{attn_backend_name}, FlashAttention version is " - f"{flash_attn_version}.") + if isinstance(kv_cache_spec, AttentionSpec): + attn_backend_i = get_attn_backend( + kv_cache_spec.head_size, + self.dtype, + kv_cache_spec.dtype, + kv_cache_spec.block_size, + self.model_config.is_attention_free, + use_mla=kv_cache_spec.use_mla, + ) + if attn_backend_i is None: + error_msg = (f"Error with get_attn_backend: " + f"{kv_cache_spec.head_size=}, " + f"{self.dtype=}, {kv_cache_spec.dtype=}, " + f"{kv_cache_spec.block_size=}, " + f"{self.model_config.is_attention_free=}, " + f"{kv_cache_spec.use_mla=}") + logger.error(error_msg) + raise NotImplementedError( + "Non-Attention backend is not supported by V1 " + "GPUModelRunner.") + + if self.vllm_config.compilation_config.full_cuda_graph: + attn_backend_name = attn_backend_i.__name__ + flash_attn_version = get_flash_attn_version() + if attn_backend_name != "FlashAttentionBackend" or \ + flash_attn_version != 3: + raise ValueError( + f"full_cuda_graph is only supported with " + f"FA3. Current attention backend is " + f"{attn_backend_name}, FlashAttention version is " + f"{flash_attn_version}.") + elif isinstance(kv_cache_spec, MambaSpec): + attn_backend_i = Mamba2AttentionBackend + else: + raise ValueError( + f"Unknown KV cache spec type: {type(kv_cache_spec)}") block_table_i = self.input_batch.block_table[i] attn_metadata_builder_i = attn_backend_i.get_builder_cls()( @@ -2168,6 +2174,22 @@ def _reshape_kv_cache_tensors( kv_caches[layer_name] = kv_cache_raw_tensors[ layer_name].view(dtype).view(kv_cache_shape).permute( *inv_order) + elif isinstance(kv_cache_spec, MambaSpec): + raw_tensor = kv_cache_raw_tensors[layer_name] + dtype = kv_cache_spec.dtype + state_tensors = [] + start_pos = 0 + for shape in kv_cache_spec.shapes: + target_shape = (num_blocks, *shape) + size_in_bytes = np.prod(shape) * get_dtype_size( + dtype) * num_blocks + tensor = raw_tensor[start_pos:start_pos + + size_in_bytes] + tensor = tensor.view(dtype).view(target_shape) + state_tensors.append(tensor) + start_pos += size_in_bytes + assert start_pos == raw_tensor.numel() + kv_caches[layer_name] = tuple(state_tensors) else: raise NotImplementedError return kv_caches @@ -2234,11 +2256,11 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. 
""" - layers = get_layers_from_vllm_config(self.vllm_config, Attention) block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in layers.items(): + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: # The layer doesn't need its own KV cache and will use that of @@ -2278,4 +2300,24 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: raise ValueError( f"Unknown attention type: {attn_module.attn_type}") + mamba_layers = get_layers_from_vllm_config(self.vllm_config, + MambaMixer2) + if len(mamba_layers) > 0: + if self.vllm_config.speculative_config is not None: + raise NotImplementedError( + "Mamba with speculative decoding is not supported yet.") + if not self.vllm_config.model_config.enforce_eager: + raise NotImplementedError( + "Mamba with cuda graph is not supported yet.") + if self.vllm_config.cache_config.enable_prefix_caching: + raise NotImplementedError( + "Prefix caching is not supported for Mamba yet.") + max_model_len = self.vllm_config.model_config.max_model_len + # NOTE(Chen): set block_size to max_model_len, so that mamba model + # will always have only one block in the KV cache. + for layer_name, mamba_module in mamba_layers.items(): + kv_cache_spec[layer_name] = MambaSpec( + shapes=mamba_module.get_state_shape(), + dtype=self.kv_cache_dtype, + block_size=max_model_len) return kv_cache_spec From f8eeae5fb7216aad15dcf2eb4bfcdfe280e6c1be Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 8 Jun 2025 00:55:44 -0700 Subject: [PATCH 2/8] can say human words now Signed-off-by: Chen Zhang --- .../layers/mamba/mamba_mixer2.py | 155 ++++++++++++------ vllm/model_executor/models/mamba2.py | 3 +- vllm/v1/attention/backends/mamba_attn.py | 96 ++++++++++- 3 files changed, 200 insertions(+), 54 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 672323e434e3..6cee3e9c754a 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -29,6 +29,7 @@ LoaderFunction, composed_weight_loader, sharded_weight_loader) from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs +from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata # Added by the IBM Team, 2024 @@ -229,21 +230,22 @@ class MambaMixer2(CustomOp): """ def __init__( - self, - hidden_size: int, - ssm_state_size: int, - conv_kernel_size: int, - intermediate_size: int, - use_conv_bias: bool, - use_bias: bool, - n_groups: int = 1, - num_heads: int = 128, - head_dim: int = 64, - rms_norm_eps: float = 1e-5, - activation: str = "silu", - use_rms_norm: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation: str = "silu", + use_rms_norm: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + chunk_size: int = -1, # the chunk size used by v1 ): super().__init__() @@ -425,6 +427,11 @@ def __init__( # of Attention + v0 PP. 
# The inner tuple is (conv_state, ssm_state) self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + assert chunk_size != -1, "chunk_size must be set for v1" + + # NOTE: chunk_size may be -1 for models without v1 support + self.chunk_size = chunk_size + self.prefix = prefix def forward_native( self, @@ -451,12 +458,27 @@ def forward_cuda( if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, Mamba2AttentionMetadata) self_kv_cache = self.kv_cache[forward_context.virtual_engine] conv_state = self_kv_cache[0] ssm_state = self_kv_cache[1] + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states + prep_initial_states = attn_metadata.prep_initial_states + chunk_size = attn_metadata.chunk_size + seq_idx_p = attn_metadata.seq_idx + chunk_indices_p = attn_metadata.chunk_indices + chunk_offsets_p = attn_metadata.chunk_offsets else: conv_state = mamba_cache_params.conv_state ssm_state = mamba_cache_params.ssm_state + state_indices_tensor = mamba_cache_params.state_indices_tensor + has_initial_states_p = mamba2_metadata.has_initial_states + prep_initial_states = mamba2_metadata.prep_initial_states + chunk_size = mamba2_metadata.chunk_size + seq_idx_p = mamba2_metadata.seq_idx + chunk_indices_p = mamba2_metadata.chunk_indices + chunk_offsets_p = mamba2_metadata.chunk_offsets # - get hidden_states, B and C after depthwise convolution. groups_time_state_size = self.n_groups * self.ssm_state_size @@ -505,26 +527,49 @@ def forward_cuda( has_prefill = num_prefills > 0 has_decode = num_decodes > 0 + # NOTE: V0 put prefill before decode, v1 puts decode before prefill # Separate prefill and decode by splitting varlen input # Split along token dimension - hidden_states_B_C_p, hidden_states_B_C_d = torch.split( - hidden_states_B_C, - [num_prefill_tokens, num_decodes], - dim=0, - ) - dt_p, dt_d = torch.split( - dt, - [num_prefill_tokens, num_decodes], - dim=0, - ) - # Split along batch dimension - state_indices_tensor_p, state_indices_tensor_d = torch.split( - mamba_cache_params.state_indices_tensor, - [num_prefills, num_decodes], - dim=0, - ) - query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1] - if has_prefill else None) + if envs.VLLM_USE_V1: + hidden_states_B_C_d, hidden_states_B_C_p = torch.split( + hidden_states_B_C, + [num_decodes, num_prefill_tokens], + dim=0, + ) + dt_d, dt_p = torch.split( + dt, + [num_decodes, num_prefill_tokens], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) + else: + hidden_states_B_C_p, hidden_states_B_C_d = torch.split( + hidden_states_B_C, + [num_prefill_tokens, num_decodes], + dim=0, + ) + dt_p, dt_d = torch.split( + dt, + [num_prefill_tokens, num_decodes], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_p, state_indices_tensor_d = torch.split( + state_indices_tensor, + [num_prefills, num_decodes], + dim=0, + ) + query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + + 1] + if has_prefill else None) ssd_output_list = [] @@ -532,14 +577,26 @@ def forward_cuda( if has_prefill: # 2. 
Convolution sequence transformation # - "cache_indices" updates the conv_state cache in positions - # pointed to by "mamba_cache_params.state_indices_tensor" + # pointed to by "state_indices_tensor" + # if ".0." in self.prefix: + # print("hidden_states_B_C_p", hidden_states_B_C_p.shape) + # print("conv_weights", conv_weights.shape) + # print("conv_state", conv_state.shape) + # print("state_indices_tensor_p", state_indices_tensor_p) + # print("query_start_loc_p", query_start_loc_p.shape) + # print("query_start_loc_p", query_start_loc_p) + # print("chunk_size", chunk_size) + # print("chunk_indices_p", chunk_indices_p) + # print("chunk_offsets_p", chunk_offsets_p) + # print("has_initial_states_p", has_initial_states_p) + # print("prep_initial_states", prep_initial_states) hidden_states_B_C_p = causal_conv1d_fn( hidden_states_B_C_p.transpose(0, 1), conv_weights, self.conv1d.bias, activation=self.activation, conv_states=conv_state, - has_initial_state=mamba2_metadata.has_initial_states, + has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, query_start_loc=query_start_loc_p).transpose( 0, 1)[:num_prefill_tokens] @@ -551,11 +608,10 @@ def forward_cuda( # 3. State Space Model sequence transformation initial_states = None - if (mamba2_metadata.has_initial_states is not None - and mamba2_metadata.prep_initial_states): + if (has_initial_states_p is not None and prep_initial_states): # making a copy of the states initial_states = torch.where( - mamba2_metadata.has_initial_states[:, None, None, None], + has_initial_states_p[:, None, None, None], ssm_state[state_indices_tensor_p], 0) scan_output, varlen_state = mamba_chunk_scan_combined( @@ -568,14 +624,14 @@ def forward_cuda( -1), C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1), - chunk_size=mamba2_metadata.chunk_size, + chunk_size=chunk_size, D=self.D, z=None, dt_bias=self.dt_bias, - seq_idx=mamba2_metadata.seq_idx, - chunk_indices=mamba2_metadata.chunk_indices, - chunk_offsets=mamba2_metadata.chunk_offsets, - cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1], + seq_idx=seq_idx_p, + chunk_indices=chunk_indices_p, + chunk_offsets=chunk_offsets_p, + cu_seqlens=query_start_loc_p, initial_states=initial_states, return_varlen_states=True, return_final_states=False, @@ -633,9 +689,16 @@ def forward_cuda( dt_softplus=True, state_batch_indices=state_indices_tensor_d, ) - ssd_output_list.append( - hidden_states_d.view(-1, (self.num_heads // self.tp_size) * - self.head_dim)) + + if envs.VLLM_USE_V1: + ssd_output_list.insert( + 0, + hidden_states_d.view(-1, (self.num_heads // self.tp_size) * + self.head_dim)) + else: + ssd_output_list.append( + hidden_states_d.view(-1, (self.num_heads // self.tp_size) * + self.head_dim)) # Merge prefill and decode outputs before passing to gated MLP hidden_states = torch.vstack(ssd_output_list) diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index e659d61e437d..60c015a0128c 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -62,7 +62,8 @@ def __init__(self, rms_norm_eps=config.layer_norm_epsilon, activation=config.hidden_act, quant_config=quant_config, - prefix=f"{prefix}.mixer") + prefix=f"{prefix}.mixer", + chunk_size=config.chunk_size) self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 6add59e7fe27..854ae395618a 100644 --- a/vllm/v1/attention/backends/mamba_attn.py 
+++ b/vllm/v1/attention/backends/mamba_attn.py @@ -1,9 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass from typing import TYPE_CHECKING import torch +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.model_executor.layers.mamba.mamba2_metadata import ( + _query_start_loc_to_chunk_indices_offsets) from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import MambaSpec from vllm.v1.worker.block_table import BlockTable @@ -14,6 +18,15 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner +def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int: + from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 + layers = get_layers_from_vllm_config(vllm_config, MambaMixer2) + chunk_sizes = set(layer.chunk_size for layer in layers.values()) + assert len( + chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size" + return chunk_sizes.pop() + + class Mamba2AttentionMetadataBuilder: def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec, @@ -21,11 +34,12 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec, self.runner = runner self.kv_cache_spec = kv_cache_spec self.block_table = block_table + self.chunk_size = get_mamba2_chunk_size(runner.vllm_config) def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: - # NOTE (Chen): Copied from FlashInferMetadataBuilder. This is not - # elegant and should be refactored. + # NOTE (Chen): Copied from FlashInferMetadataBuilder. Should be + # refactored later to avoid code duplication. # We now want to reorder the batch so that the "decode" requests are and # the front and the "prefill" requests are at the using the least amount # swaps possible. (NOTE for now we loosely use "decode" to mean requests @@ -86,7 +100,70 @@ def reorder_batch(self, input_batch: "InputBatch", def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata): - raise NotImplementedError("Mamba2AttentionBackend is not implemented.") + # print("query_start_loc", common_attn_metadata.query_start_loc) + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + + seq_idx = None + chunk_indices, chunk_offsets = None, None + # Need flags to indicate if there are initial states + # currently we really only support the FlashAttention backend + has_initial_states = None + prep_initial_states = False + + state_indices_tensor = self.block_table.block_table[:num_reqs, 0] + + # Compute seq_idx, chunk_indices and chunk_offsets for prefill only + if self._num_prefills > 0: + #[batch,] + has_initial_states_cpu = ( + self.runner.input_batch. 
+ num_computed_tokens_cpu_tensor[num_reqs - + self._num_prefills:num_reqs] + > 0) + prep_initial_states = torch.any(has_initial_states_cpu).item() + has_initial_states = has_initial_states_cpu.to( + query_start_loc.device) + + query_start_loc_p = common_attn_metadata.query_start_loc[ + -self._num_prefills - 1:] - self._num_decode_tokens + # TODO: remove this debug check + assert query_start_loc_p[0] == 0 + assert query_start_loc_p[-1] == self._num_prefill_tokens + seq_idx = torch.repeat_interleave( + torch.arange(self._num_prefills, + dtype=torch.int32, + device=query_start_loc_p.device), + query_start_loc_p.diff(), + output_size=self._num_prefill_tokens) + seq_idx.unsqueeze_(0) + + # We compute metadata for chunked prefill once at the top level + # model forward and reuse them in mamba layers. If not needed, + # they will be ignored inside mamba kernels. + if prep_initial_states: + chunk_indices, chunk_offsets = ( + _query_start_loc_to_chunk_indices_offsets( + query_start_loc_p, self.chunk_size, + self._num_prefill_tokens)) + + attn_metadata = Mamba2AttentionMetadata( + num_prefills=self._num_prefills, + num_prefill_tokens=self._num_prefill_tokens, + num_decodes=self._num_decodes, + num_decode_tokens=self._num_decode_tokens, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + has_initial_states=has_initial_states, + prep_initial_states=prep_initial_states, + chunk_size=self.chunk_size, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + state_indices_tensor=state_indices_tensor, + ) + # print("attn_metadata", attn_metadata) + return attn_metadata class Mamba2AttentionBackend: @@ -96,15 +173,20 @@ def get_builder_cls() -> type[Mamba2AttentionMetadataBuilder]: return Mamba2AttentionMetadataBuilder +@dataclass class Mamba2AttentionMetadata: + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + has_initial_states: torch.Tensor prep_initial_states: bool - chunk_size: int seq_idx: torch.Tensor chunk_indices: torch.Tensor chunk_offsets: torch.Tensor - def __init__(self, query_start_loc, context_lens_tensor): - self.query_start_loc = query_start_loc - self.context_lens_tensor = context_lens_tensor + state_indices_tensor: torch.Tensor # shape: [batch,] From 05fe002ae8a8fc77b5de15666ae1509e6e4e4e13 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 8 Jun 2025 08:57:50 -0700 Subject: [PATCH 3/8] default to v0 and some small fix Signed-off-by: Chen Zhang --- vllm/engine/arg_utils.py | 7 +++- .../layers/mamba/mamba_mixer2.py | 33 +++++++------------ vllm/model_executor/models/mamba2.py | 4 --- 3 files changed, 17 insertions(+), 27 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 81f160968897..095257f2580c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1350,12 +1350,17 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # No Mamba or Encoder-Decoder so far. + # No Encoder-Decoder, not all Mamba so far. if not model_config.is_v1_compatible: _raise_or_fallback(feature_name=model_config.architectures, recommend_to_remove=False) return False + # V1 mamba models are unoptimized. + if model_config.has_inner_state and _warn_or_fallback( + feature_name="Mamba"): + return False + # No Concurrent Partial Prefills so far. 
if (self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 6cee3e9c754a..2fd05babf2ad 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -480,17 +480,7 @@ def forward_cuda( chunk_indices_p = mamba2_metadata.chunk_indices chunk_offsets_p = mamba2_metadata.chunk_offsets - # - get hidden_states, B and C after depthwise convolution. groups_time_state_size = self.n_groups * self.ssm_state_size - split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split( - hidden_states_B_C, - [ - self.intermediate_size // self.tp_size, - groups_time_state_size // self.tp_size, - groups_time_state_size // self.tp_size, - ], - dim=-1, - ) # 1. Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) @@ -511,6 +501,17 @@ def forward_cuda( conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + # - get hidden_states, B and C after depthwise convolution. + split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split( + hidden_states_B_C, + [ + self.intermediate_size // self.tp_size, + groups_time_state_size // self.tp_size, + groups_time_state_size // self.tp_size, + ], + dim=-1, + ) + if envs.VLLM_USE_V1 and attn_metadata is None: # V1 profile run hidden_states_B_C = (hidden_states_B_C.transpose( @@ -578,18 +579,6 @@ def forward_cuda( # 2. Convolution sequence transformation # - "cache_indices" updates the conv_state cache in positions # pointed to by "state_indices_tensor" - # if ".0." in self.prefix: - # print("hidden_states_B_C_p", hidden_states_B_C_p.shape) - # print("conv_weights", conv_weights.shape) - # print("conv_state", conv_state.shape) - # print("state_indices_tensor_p", state_indices_tensor_p) - # print("query_start_loc_p", query_start_loc_p.shape) - # print("query_start_loc_p", query_start_loc_p) - # print("chunk_size", chunk_size) - # print("chunk_indices_p", chunk_indices_p) - # print("chunk_offsets_p", chunk_offsets_p) - # print("has_initial_states_p", has_initial_states_p) - # print("prep_initial_states", prep_initial_states) hidden_states_B_C_p = causal_conv1d_fn( hidden_states_B_C_p.transpose(0, 1), conv_weights, diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 60c015a0128c..d2403ccbb972 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -248,10 +248,6 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - # TODO in this PR: default to v0 but allows v1 - # TODO in this PR: check whether the kernel is cuda graph compatible - # TODO in this PR: current design for mamba seems incompatible with spec - # decode if not envs.VLLM_USE_V1: if self.mamba_cache is None: num_mamba_layers = ( From b8d5517535a128572fb1ef207df5f86e0562d24d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 8 Jun 2025 09:08:46 -0700 Subject: [PATCH 4/8] update comments Signed-off-by: Chen Zhang --- vllm/v1/core/single_type_kv_cache_manager.py | 8 ++++---- vllm/v1/kv_cache_interface.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 4913cfdddce7..32d6fd58366d 100644 --- 
a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -408,8 +408,8 @@ def find_longest_cache_hit( assert isinstance( kv_cache_spec, MambaSpec), ("MambaManager can only be used for mamba groups") - # NOTE(Chen): prefix caching is not supported for mamba now. Always - # return empty list. + # Prefix caching is not supported for mamba now. Always return empty + # list. computed_blocks: list[list[KVCacheBlock]] = [ [] for _ in range(len(kv_cache_group_ids)) ] @@ -417,8 +417,8 @@ def find_longest_cache_hit( def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: - # NOTE(Chen): each request will always have 1 block at this moment, so - # no need to remove blocks. + # Each request will always have 1 block at this moment, so no need to + # remove blocks. pass def get_num_common_prefix_blocks(self, request_id: str, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 3090647180b6..c48775adc9b8 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -172,8 +172,8 @@ def page_size_bytes(self) -> int: return self.num_elements * get_dtype_size(self.dtype) def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - # NOTE(Chen): we allocate 1 block for each request now, so - # max_memory_usage_bytes is the same as page_size_bytes. + # We allocate 1 block for each request now, so max_memory_usage_bytes is + # the same as page_size_bytes. # Need to update this when supporting prefix caching. return self.page_size_bytes diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1bfb3ba814f5..191cd25a7292 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2313,8 +2313,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: raise NotImplementedError( "Prefix caching is not supported for Mamba yet.") max_model_len = self.vllm_config.model_config.max_model_len - # NOTE(Chen): set block_size to max_model_len, so that mamba model - # will always have only one block in the KV cache. + # Set block_size to max_model_len, so that mamba model will always + # have only one block in the KV cache. for layer_name, mamba_module in mamba_layers.items(): kv_cache_spec[layer_name] = MambaSpec( shapes=mamba_module.get_state_shape(), From 0425c46487906c91da3c0c195140b55c955d474d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 13 Jun 2025 14:37:41 -0700 Subject: [PATCH 5/8] update comments Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/mamba_attn.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 854ae395618a..e6aecfa742ef 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -38,8 +38,9 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec, def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: - # NOTE (Chen): Copied from FlashInferMetadataBuilder. Should be - # refactored later to avoid code duplication. + # NOTE (Chen): Copied from MLACommonMetadataBuilder and + # FlashInferMetadataBuilder. Should be refactored later to avoid code + # duplication of these 3 functions. # We now want to reorder the batch so that the "decode" requests are and # the front and the "prefill" requests are at the using the least amount # swaps possible. 
(NOTE for now we loosely use "decode" to mean requests @@ -100,7 +101,6 @@ def reorder_batch(self, input_batch: "InputBatch", def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata): - # print("query_start_loc", common_attn_metadata.query_start_loc) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens @@ -127,9 +127,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, query_start_loc_p = common_attn_metadata.query_start_loc[ -self._num_prefills - 1:] - self._num_decode_tokens - # TODO: remove this debug check - assert query_start_loc_p[0] == 0 - assert query_start_loc_p[-1] == self._num_prefill_tokens + seq_idx = torch.repeat_interleave( torch.arange(self._num_prefills, dtype=torch.int32, @@ -162,7 +160,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, chunk_offsets=chunk_offsets, state_indices_tensor=state_indices_tensor, ) - # print("attn_metadata", attn_metadata) return attn_metadata From 727d8dad1843f5efd54ac7930af99090702a473c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 13 Jun 2025 15:02:18 -0700 Subject: [PATCH 6/8] test v1 Signed-off-by: Chen Zhang --- .../models/language/generation/test_hybrid.py | 53 +++++++++++++++---- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 3eaadcb45fe1..90c4cd968e7a 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -17,9 +17,10 @@ "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", # TODO: Compare to a Mamba2 model. The HF transformers implementation of - # Mamba2 is buggy for Codestral as it doesn't handle n_groups. + # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test + # doesn't compare vLLM output with HF output. 
# See https://github.com/huggingface/transformers/pull/35943 - # "mistralai/Mamba-Codestral-7B-v0.1", + "mistralai/Mamba-Codestral-7B-v0.1", ] HYBRID_MODELS = [ @@ -35,6 +36,10 @@ "hmellor/tiny-random-BambaForCausalLM", ] +V1_SUPPORTED_MODELS = [ + "mistralai/Mamba-Codestral-7B-v0.1", +] + # Avoid OOM MAX_NUM_SEQS = 4 @@ -46,24 +51,50 @@ def test_models( hf_runner, vllm_runner, example_prompts, + monkeypatch, model: str, max_tokens: int, num_logprobs: int, ) -> None: with hf_runner(model) as hf_model: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + if model != "mistralai/Mamba-Codestral-7B-v0.1": + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + else: + hf_outputs = None with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) + if model in V1_SUPPORTED_MODELS: + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + with vllm_runner(model, + max_num_seqs=MAX_NUM_SEQS, + enforce_eager=True, + enable_prefix_caching=False) as vllm_model: + vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + else: + vllm_v1_outputs = None + + if hf_outputs is not None: + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_v0_outputs, + name_0="hf", + name_1="vllm-v0", + ) + + if model in V1_SUPPORTED_MODELS: + ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs + check_logprobs_close( + outputs_0_lst=ref_outputs, + outputs_1_lst=vllm_v1_outputs, + name_0="hf" if hf_outputs is not None else "vllm-v0", + name_1="vllm-v1", + ) @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) From 1beef4171290aebfe5f59f271d1f4b96debb47fb Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 13 Jun 2025 15:25:39 -0700 Subject: [PATCH 7/8] fix ci Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/mamba_attn.py | 63 +++++++++++++----------- 1 file changed, 33 insertions(+), 30 deletions(-) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index e6aecfa742ef..74d619aadbdc 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -5,10 +5,12 @@ import torch +from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.model_executor.layers.mamba.mamba2_metadata import ( _query_start_loc_to_chunk_indices_offsets) -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.kv_cache_interface import MambaSpec from vllm.v1.worker.block_table import BlockTable @@ -27,7 +29,34 @@ def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int: return chunk_sizes.pop() -class Mamba2AttentionMetadataBuilder: +class Mamba2AttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]: + return Mamba2AttentionMetadataBuilder + + +@dataclass +class Mamba2AttentionMetadata: + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + query_start_loc: torch.Tensor + seq_lens: 
torch.Tensor + + has_initial_states: torch.Tensor + prep_initial_states: bool + chunk_size: int + seq_idx: torch.Tensor + chunk_indices: torch.Tensor + chunk_offsets: torch.Tensor + + state_indices_tensor: torch.Tensor # shape: [batch,] + + +class Mamba2AttentionMetadataBuilder( + AttentionMetadataBuilder[Mamba2AttentionMetadata]): def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec, block_table: BlockTable): @@ -98,9 +127,9 @@ def reorder_batch(self, input_batch: "InputBatch", return modified_batch - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int, + def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata): + num_reqs = common_attn_metadata.num_reqs query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens @@ -161,29 +190,3 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, state_indices_tensor=state_indices_tensor, ) return attn_metadata - - -class Mamba2AttentionBackend: - - @staticmethod - def get_builder_cls() -> type[Mamba2AttentionMetadataBuilder]: - return Mamba2AttentionMetadataBuilder - - -@dataclass -class Mamba2AttentionMetadata: - num_prefills: int - num_prefill_tokens: int - num_decodes: int - num_decode_tokens: int - query_start_loc: torch.Tensor - seq_lens: torch.Tensor - - has_initial_states: torch.Tensor - prep_initial_states: bool - chunk_size: int - seq_idx: torch.Tensor - chunk_indices: torch.Tensor - chunk_offsets: torch.Tensor - - state_indices_tensor: torch.Tensor # shape: [batch,] From 9bf1ebc1f0865ee1ea738f70a58044bb88d2fd33 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 18 Jun 2025 09:52:44 -0700 Subject: [PATCH 8/8] fix v1 tests Signed-off-by: Chen Zhang --- tests/v1/test_oracle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index e5eadfd4e9da..1787b9a0b469 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -12,7 +12,7 @@ UNSUPPORTED_MODELS_V1 = [ "openai/whisper-large-v3", # transcription "facebook/bart-large-cnn", # encoder decoder - "mistralai/Mamba-Codestral-7B-v0.1", # mamba + "state-spaces/mamba-130m-hf", # mamba1 "hmellor/tiny-random-BambaForCausalLM", # hybrid "BAAI/bge-m3", # embedding ]
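
For reference, the following standalone Python sketch (not part of the patch series above) mirrors how MambaSpec sizes one per-request page and how GPUModelRunner._reshape_kv_cache_tensors carves each layer's flat byte buffer into (conv_state, ssm_state) views. The concrete shape values, dtype, and block count below are illustrative assumptions, not values taken from any particular model config.

from math import prod

import torch

# Example per-layer state shapes, in the order MambaMixer2.get_state_shape()
# returns them: (conv_state_shape, temporal_state_shape). Values are examples.
shapes = ((1792, 3), (128, 64, 128))
dtype = torch.float16
num_blocks = 4  # with block_size == max_model_len, one block per request

elem_size = torch.empty(0, dtype=dtype).element_size()

# MambaSpec.page_size_bytes: all states of one request, flattened.
page_size_bytes = sum(prod(shape) for shape in shapes) * elem_size

# The runner hands each layer one flat byte buffer covering all blocks.
raw_tensor = torch.zeros(num_blocks * page_size_bytes, dtype=torch.uint8)

# Split the flat buffer into one view per state, as in
# _reshape_kv_cache_tensors: slice by bytes, reinterpret as dtype, reshape.
state_tensors = []
start_pos = 0
for shape in shapes:
    size_in_bytes = prod(shape) * elem_size * num_blocks
    tensor = raw_tensor[start_pos:start_pos + size_in_bytes]
    state_tensors.append(tensor.view(dtype).view(num_blocks, *shape))
    start_pos += size_in_bytes
assert start_pos == raw_tensor.numel()

conv_state, ssm_state = state_tensors
print(conv_state.shape)  # torch.Size([4, 1792, 3])
print(ssm_state.shape)   # torch.Size([4, 128, 64, 128])

Because get_kv_cache_spec sets block_size to max_model_len, each request maps to exactly one block, so a single page holds that request's entire conv and SSM state and max_memory_usage_bytes equals page_size_bytes, as noted in the MambaSpec comment above.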