diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 75a9c4a22cb4..06cb557e2ecf 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -321,7 +321,6 @@ steps: - python3 offline_inference/vision_language_pooling.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0 - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - - python3 offline_inference/encoder_decoder.py - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/basic/classify.py - python3 offline_inference/basic/embed.py @@ -644,7 +643,7 @@ steps: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip freeze | grep -E 'torch' - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work + - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - label: Multi-Modal Models Test (Extended) 1 mirror_hardwares: [amdexperimental] @@ -818,7 +817,8 @@ steps: # Avoid importing model tests that cause CUDA reinitialization error - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' - pytest models/language -v -s -m 'distributed(num_gpus=2)' - - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' + - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py + - VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)' # test sequence parallel - pytest -v -s distributed/test_sequence_parallel.py # this test fails consistently. diff --git a/examples/offline_inference/encoder_decoder.py b/examples/offline_inference/encoder_decoder.py index df6c1eaf4a21..957db3c23b86 100644 --- a/examples/offline_inference/encoder_decoder.py +++ b/examples/offline_inference/encoder_decoder.py @@ -5,6 +5,8 @@ encoder/decoder models, specifically BART and mBART. This script is refactored to allow model selection via command-line arguments. + +NOTE: This example is not yet supported in V1. """ import argparse diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index 655f9f3fce7a..35e9203d1caf 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -5,6 +5,7 @@ the explicit/implicit prompt format on enc-dec LMMs for text generation. 
""" +import os import time from collections.abc import Sequence from dataclasses import asdict @@ -130,6 +131,8 @@ def run_mllama(): def run_whisper(): + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + engine_args = EngineArgs( model="openai/whisper-large-v3-turbo", max_model_len=448, diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index 8b99d9d6e21f..3cf4c377fb58 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -63,6 +63,7 @@ def clear_cache(): current_platform.is_cpu(), reason="CPU backend is not currently supported with encoder/decoder models" ) +@pytest.mark.skip(reason="bart not supported in V1") def test_encoder_decoder_e2e( hf_runner, vllm_runner, diff --git a/tests/entrypoints/openai/test_encoder_decoder.py b/tests/entrypoints/openai/test_encoder_decoder.py index 9c2aef23e877..75612962c95f 100644 --- a/tests/entrypoints/openai/test_encoder_decoder.py +++ b/tests/entrypoints/openai/test_encoder_decoder.py @@ -30,6 +30,7 @@ async def client(server): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.skip(reason="bart is not yet supported in V1") async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): completion = await client.completions.create(model=model_name, prompt="Hello, my name is", diff --git a/tests/models/language/generation/test_bart.py b/tests/models/language/generation/test_bart.py index b4c771840196..22ceb27869ac 100644 --- a/tests/models/language/generation/test_bart.py +++ b/tests/models/language/generation/test_bart.py @@ -178,6 +178,7 @@ def run_test( @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) +@pytest.mark.skip(reason="bart not supported in V1") def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: @@ -201,6 +202,7 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) +@pytest.mark.skip(reason="bart not supported in V1") def test_models_distributed(hf_runner, vllm_runner, example_encoder_decoder_prompts, distributed_executor_backend, model, dtype, diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index 4a65e8c95204..e0e9980b8833 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -122,8 +122,7 @@ def run_test( @pytest.mark.core_model -@pytest.mark.parametrize( - "model", ["openai/whisper-small", "openai/whisper-large-v3-turbo"]) +@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"]) @create_new_process_for_each_test() def test_models(vllm_runner, model) -> None: run_test( diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index b678313752d6..3b87b669dbbe 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -31,6 +31,7 @@ ARCH_TO_SKIP = { "MolmoForCausalLM": "incompatible requirements", + "Florence2ForConditionalGeneration": "not supported in V1", } ARCH_NEEDS_EXTRAS = [ 
"InternVLChatModel", diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index aaa04f52f779..792b93fbcd0f 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -68,6 +68,12 @@ def _initialize_kv_caches_v1(self, vllm_config): # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when # L4 supports FA3. m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1") + if model_arch == "Florence2ForConditionalGeneration": + # An encoder-decoder model that's V0-only. Just skip it + # since V0 is about to be removed. + pytest.skip("Skipping Florence2ForConditionalGeneration") + if model_arch == "WhisperForConditionalGeneration": + m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") LLM( model_info.default, tokenizer=model_info.tokenizer, diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index 1f16e92f657e..efa604dd6b5a 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -10,7 +10,6 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine UNSUPPORTED_MODELS_V1 = [ - "openai/whisper-large-v3", # transcription "facebook/bart-large-cnn", # encoder decoder ] diff --git a/vllm/attention/layers/cross_attention.py b/vllm/attention/layers/cross_attention.py new file mode 100644 index 000000000000..5f814b23888b --- /dev/null +++ b/vllm/attention/layers/cross_attention.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +from copy import copy +from typing import Optional + +import numpy as np +import torch +from transformers import CacheConfig + +from vllm import envs +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, AttentionType) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.utils import cdiv +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, + subclass_attention_backend) +from vllm.v1.kv_cache_interface import CrossAttentionSpec + +logger = init_logger(__name__) + + +def _get_max_encoder_len(vllm_config: VllmConfig) -> int: + return MULTIMODAL_REGISTRY.get_encdec_max_encoder_len( + vllm_config.model_config) + + +def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray, + block_table_tensor: torch.Tensor, + kv_cache_spec: CrossAttentionSpec, + device: torch.device) -> torch.Tensor: + """Get cross-attention slot mappings.""" + + block_size = kv_cache_spec.block_size + slot_mappings = [] + + # Find indices with non-zero encoder sequence lengths + # The majority of parallel requests will be running the + # decoder, so this list should be relatively small. 
+ active_indices = np.nonzero(encoder_seq_lens)[0] + + for req_index in active_indices: + encoder_seq_len = encoder_seq_lens[req_index].item() + + # Calculate the number of blocks needed for this request + num_blocks_needed = cdiv(encoder_seq_len, block_size) + + # Get the block IDs for this request from the tensor + req_block_ids = block_table_tensor[req_index] + + # Get only the blocks we need (first num_blocks_needed blocks) + needed_block_ids = req_block_ids[:num_blocks_needed] + + # All needed blocks are allocated + i_values = torch.arange(encoder_seq_len, + dtype=torch.int64, + device=device) + block_indices = i_values // block_size + block_offsets = i_values % block_size + block_numbers = needed_block_ids[block_indices] + slot_mapping = block_numbers * block_size + block_offsets + + slot_mappings.append(slot_mapping) + + if slot_mappings: + return torch.cat(slot_mappings) + else: + return torch.empty(0, dtype=torch.int64, device=device) + + +@functools.lru_cache +def create_cross_attention_backend( + underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: + prefix = "CrossAttention_" + underlying_builder = underlying_attn_backend.get_builder_cls() + + class CrossAttentionBuilder(underlying_builder): # type: ignore + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> AttentionMetadata: + new_metadata = copy(common_attn_metadata) + new_metadata.causal = False + max_encoder_len = _get_max_encoder_len(self.vllm_config) + new_metadata.max_seq_len = max_encoder_len + + new_metadata.seq_lens = torch.full( + (new_metadata.num_reqs, ), + max_encoder_len, + dtype=torch.int32, + device=self.device, + ) + new_metadata.seq_lens_cpu = torch.full( + (new_metadata.num_reqs, ), + max_encoder_len, + dtype=torch.int32, + device="cpu", + ) + new_metadata.slot_mapping = _get_cross_slot_mapping( + new_metadata.encoder_seq_lens, new_metadata.block_table_tensor, + self.kv_cache_spec, self.device) + return super().build(common_prefix_len, new_metadata, fast_build) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=CrossAttentionBuilder) + + return attn_backend + + +class CrossAttention(Attention): + """ + Cross-attention for encoder-decoder models. + Handles attention between decoder queries and encoder keys/values. 
+ """ + + def __init__(self, + num_heads: int, + head_size: int, + scale: float, + cache_config: Optional[CacheConfig] = None, + attn_type: Optional[str] = None, + **kwargs): + dtype = torch.get_default_dtype() + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + if envs.VLLM_USE_V1: + underlying_attn_backend = get_attn_backend(head_size, dtype, + kv_cache_dtype, + block_size) + + attn_backend = create_cross_attention_backend( + underlying_attn_backend) + else: + # in v0 cross attention is handled inside the backends + attn_backend = None + + if attn_type is not None: + assert attn_type == AttentionType.ENCODER_DECODER, ( + "CrossAttention only supports AttentionType.ENCODER_DECODER") + + super().__init__(num_heads=num_heads, + head_size=head_size, + scale=scale, + cache_config=cache_config, + attn_backend=attn_backend, + attn_type=AttentionType.ENCODER_DECODER, + **kwargs) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index c8d531f12a2e..168d12de0bb2 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -8,6 +8,7 @@ import hashlib import inspect import json +import os import textwrap import warnings from collections.abc import Mapping @@ -41,6 +42,7 @@ from vllm.config.utils import ConfigType, config from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, @@ -3512,16 +3514,33 @@ def __post_init__(self): disable_chunked_prefill_reasons: list[str] = [] - if self.model_config and self.model_config.pooler_config: - pooling_type = self.model_config.pooler_config.pooling_type - if pooling_type is None or pooling_type.lower() != "last": - disable_chunked_prefill_reasons.append( - "Only \"last\" pooling supports chunked " - "prefill and prefix caching; disabling both.") - elif not getattr(self.model_config.hf_config, "is_causal", True): + if self.model_config: + if self.model_config.pooler_config: + pooling_type = self.model_config.pooler_config.pooling_type + if pooling_type is None or pooling_type.lower() != "last": + disable_chunked_prefill_reasons.append( + "Only \"last\" pooling supports chunked " + "prefill and prefix caching; disabling both.") + elif self.model_config.is_encoder_decoder: + self.scheduler_config.max_num_encoder_input_tokens = \ + MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config) + logger.debug( + "Encoder-decoder model detected: setting " + "`max_num_encoder_input_tokens` to encoder length (%s)", + self.scheduler_config.max_num_encoder_input_tokens) + self.scheduler_config.disable_chunked_mm_input = True disable_chunked_prefill_reasons.append( - "Only models using causal attention supports chunked " - "prefill and prefix caching; disabling both.") + "Encoder-decoder models do not support chunked prefill nor" + " prefix caching; disabling both.") + if (self.model_config.architecture + == "WhisperForConditionalGeneration" + and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") + != "spawn"): + logger.warning( + "Whisper is known to have issues with " + "forked workers. 
If startup is hanging, " + "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " + "to 'spawn'.") if disable_chunked_prefill_reasons: for reason in disable_chunked_prefill_reasons: diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 9451670bc60c..27e8b6fa5535 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -600,7 +600,6 @@ def __init__( self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config, prefix=maybe_prefix( prefix, "whisper_encoder"), - is_standalone_encoder=True, init_in_fp32=True) mel_filters = mel_filter_bank( num_frequency_bins=1 + self.config.window_size // 2, diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 97e8cd6e7695..41ae7b129782 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -15,6 +15,7 @@ from vllm.attention import Attention, AttentionType from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.cross_attention import CrossAttention from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig) from vllm.distributed import get_tensor_model_parallel_world_size @@ -43,7 +44,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, - SupportsTranscription, SupportsV0Only) + SupportsTranscription) from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, make_layers) @@ -124,6 +125,34 @@ class WhisperAudioInputs(TensorSchema): TensorShape("b", "nmb", "t")] +class WhisperEncoderAttention(MultiHeadAttention): + """Multi-headed attention for Whisper encoder with 2D tensor support.""" + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> torch.Tensor: + """ + Input shape: batch_size x seq_len x hidden_size + or seq_len x hidden_size + """ + is_2d = query.dim() == 2 + if is_2d: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + # Call the parent forward method + out = super().forward(query, key, value) + + if is_2d: + out = out.squeeze(0) + + return out + + class WhisperPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int): @@ -144,7 +173,6 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - standalone_encoder: bool = False, ): super().__init__() self.embed_dim = embed_dim @@ -180,14 +208,25 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.out_proj", ) - if standalone_encoder: - self.attn = MultiHeadAttention( + if attn_type == AttentionType.ENCODER: + self.attn = WhisperEncoderAttention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, ) - else: + elif self.attn_type == AttentionType.ENCODER_DECODER: + self.attn = CrossAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=self.attn_type, + ) + else: # AttentionType.DECODER (regular decoder self-attention) self.attn = Attention( self.num_heads, self.head_dim, @@ -332,11 +371,7 @@ def forward(self, hidden_states: torch.Tensor): class WhisperEncoderLayer(nn.Module): - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - is_standalone_encoder: bool = False): + def __init__(self, *, vllm_config: VllmConfig, prefix: 
str = ""): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -350,7 +385,6 @@ def __init__(self, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", - standalone_encoder=is_standalone_encoder, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.mlp = WhisperMLP( @@ -446,12 +480,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", - is_standalone_encoder: bool = False, init_in_fp32: bool = False): super().__init__() config = vllm_config.model_config.hf_config embed_dim = config.d_model - self.is_standalone_encoder = is_standalone_encoder self.num_mel_bins = config.num_mel_bins self.max_source_positions = config.max_source_positions self.embed_scale = (math.sqrt(embed_dim) @@ -469,9 +501,7 @@ def __init__(self, self.start_layer, self.end_layer, self.layers = make_layers( config.encoder_layers, lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, - prefix=f"{prefix}.layers", - is_standalone_encoder= - is_standalone_encoder), + prefix=f"{prefix}.layers"), prefix=f"{prefix}.layers", ) self.layer_norm = nn.LayerNorm(config.d_model) @@ -752,7 +782,7 @@ def _get_prompt_updates( info=WhisperProcessingInfo, dummy_inputs=WhisperDummyInputsBuilder) class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, - SupportsMultiModal, SupportsV0Only): + SupportsMultiModal): packed_modules_mapping = { "self_attn.qkv_proj": [ "self_attn.q_proj", @@ -880,19 +910,17 @@ def get_language_model(self) -> torch.nn.Module: def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: - # TODO: This method does not obey the interface for SupportsMultiModal. - # Refactor this once encoder/decoder support is implemented in V1. + # Required as part of SupportsMultiModal interface. audio_input = self._parse_and_validate_audio_input(**kwargs) - return self.model.get_encoder_outputs(audio_input["input_features"]) + return [self.model.get_encoder_outputs(audio_input["input_features"])] def get_input_embeddings( self, input_ids: torch.Tensor, multimodal_embeddings: Optional[NestedTensors] = None, ) -> torch.Tensor: - # TODO: This method just returns the decoder sequence embeddings since - # Whisper does not have encoder text tokens. Refactor this once - # encoder/decoder support is implemented in V1. + # This method just returns the decoder sequence embeddings since + # Whisper does not have encoder text tokens. 
return self.model.decoder.get_input_embeddings(input_ids) def _parse_and_validate_audio_input( diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py index 8a9c660b882f..5d9206e18832 100644 --- a/vllm/transformers_utils/configs/mistral.py +++ b/vllm/transformers_utils/configs/mistral.py @@ -157,6 +157,7 @@ def _remap_mistral_audio_args(config: dict) -> dict: encoder_attention_heads=encoder_args["n_heads"], vocab_size=encoder_args["vocab_size"], max_source_positions=encoder_args["max_source_positions"], + is_encoder_decoder=False, # Override WhisperConfig default ) } if quant_config: diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index ced8234a7b43..ab87f3bb4e3c 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -317,8 +317,8 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device) -> None: - self.kv_cache_spec = kv_cache_spec - self.vllm_config = vllm_config + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.scheduler_config = vllm_config.scheduler_config # For reorder diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 3cc67acd04c6..20f1904b3be6 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -177,12 +177,11 @@ class FlashAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - self.vllm_config = vllm_config + super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config self.compilation_config = vllm_config.compilation_config - self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( self.parallel_config) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index c7a565810b45..39e7deaa9160 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -163,11 +163,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - self.device = device - self.vllm_config = vllm_config + super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.cache_config = vllm_config.cache_config self.model_config = vllm_config.model_config - self.kv_cache_spec = kv_cache_spec self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index d5b1c15e68d0..cb983494216a 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -516,10 +516,11 @@ class FlexAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.model_config = vllm_config.model_config self.parallel_config = 
vllm_config.parallel_config self.cache_config = vllm_config.cache_config - self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( self.parallel_config) diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index ac0034b5dcf0..3ff201d83a79 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -39,8 +39,8 @@ class LinearAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec def build(self, common_prefix_len: int, diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 07ef7cb69a16..9970331a6042 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -22,12 +22,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec - self.device = device - self.vllm_config = vllm_config - self.layer_names = layer_names + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + assert isinstance(kv_cache_spec, MambaSpec) self.compilation_config = vllm_config.compilation_config self.decode_cudagraph_max_bs = min( self.vllm_config.scheduler_config.max_num_seqs, @@ -52,4 +49,4 @@ def build_for_cudagraph_capture( m.max_query_len = 1 # decode-only - return self.build(0, m) \ No newline at end of file + return self.build(0, m) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 173a0a255e49..a4e2758bd311 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -236,11 +236,11 @@ class AiterFlashAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - self.vllm_config = vllm_config + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config - self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( self.parallel_config) @@ -248,7 +248,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. 
self.aot_sliding_window: Optional[tuple[int, int]] = None diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index fcbf0c7b5356..f5ad65b02b4d 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -45,8 +45,8 @@ class ShortConvAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) assert isinstance(kv_cache_spec, MambaSpec) - self.kv_cache_spec = kv_cache_spec def build(self, common_prefix_len: int, diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index b96d957a150b..10238f36455d 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -165,7 +165,8 @@ def __init__( vllm_config: VllmConfig, device: torch.device, ): - self.kv_cache_spec = kv_cache_spec + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.block_size = kv_cache_spec.block_size spec_config = vllm_config.speculative_config diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 104cebb45d74..fe2894eaa075 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -66,9 +66,9 @@ class TritonAttentionMetadataBuilder( def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): - self.device = device + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec model_config = vllm_config.model_config self.num_heads_q = model_config.get_num_attention_heads( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 8e3d530fc1f9..009943fa743d 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -72,6 +72,9 @@ class CommonAttentionMetadata: logits_indices_padded: Optional[torch.Tensor] = None num_logits_indices: Optional[int] = None + # Needed by CrossAttentionBuilder + encoder_seq_lens: Optional[np.ndarray] = None + @dataclass class UbatchSlice: @@ -193,6 +196,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.kv_cache_spec = kv_cache_spec + self.layer_names = layer_names + self.vllm_config = vllm_config + self.device = device @abstractmethod def build(self, diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index c59ff32cf7c2..a6ca33491235 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -206,8 +206,9 @@ def __init__( vllm_config: VllmConfig, device: torch.device, ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + assert XFORMERS_AVAILABLE - self.kv_cache_spec = kv_cache_spec self.block_size = kv_cache_spec.block_size self._num_decodes = 0 self._num_decode_tokens = 0 diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ed7c16dc520f..d1a6dd73e85c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -144,8 +144,8 @@ def __init__( ) # NOTE(woosuk): Here, "encoder" includes the vision encoder (and - # projector if needed). 
Currently, we assume that the encoder also - # has the Transformer architecture (e.g., ViT). + # projector if needed) for MM models as well as encoder-decoder + # transformers. self.max_num_encoder_input_tokens = encoder_compute_budget # NOTE: For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized because cache size is 0 @@ -775,15 +775,19 @@ def _try_schedule_encoder_inputs( # in the decoder's KV cache. continue - # The same encoder input has already been scheduled in the current - # step. - if request.mm_hashes[i] in mm_hashes_to_schedule: - continue + if not self.is_encoder_decoder: + # We are not using the encoder cache for encoder-decoder models, + # yet. + if request.mm_hashes[i] in mm_hashes_to_schedule: + # The same encoder input has already been scheduled in the + # current step. + continue - if self.encoder_cache_manager.check_and_update_cache(request, i): - # The encoder input is already computed and cached from a - # previous step. - continue + if self.encoder_cache_manager.check_and_update_cache( + request, i): + # The encoder input is already computed and cached from a + # previous step. + continue # If no encoder input chunking is allowed, we do not want to # partially schedule a multimodal item. If the scheduled range would @@ -1047,7 +1051,13 @@ def _free_encoder_inputs(self, request: Request) -> None: mm_positions = request.mm_positions[input_id] start_pos = mm_positions.offset num_tokens = mm_positions.length - if start_pos + num_tokens <= request.num_computed_tokens: + if self.is_encoder_decoder and request.num_computed_tokens > 0: + # With Whisper, as soon as we've generated a single token, + # we know we're done with the encoder input. Cross Attention + # KVs have been calculated and cached already. + self.encoder_cache_manager.free_encoder_input( + request, input_id) + elif start_pos + num_tokens <= request.num_computed_tokens: # The encoder output is already processed and stored # in the decoder's KV cache. self.encoder_cache_manager.free_encoder_input( diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 8ce070e4d6fb..96e89eeac556 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -325,7 +325,6 @@ def process_inputs( ) -> tuple[Optional[str], EngineCoreRequest]: # TODO(woosuk): Support pooling models. - # TODO(woosuk): Support encoder-decoder models. 
self._validate_lora(lora_request) self._validate_params(params, lora_request) if trace_headers is not None: @@ -384,10 +383,6 @@ def process_inputs( encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) - # TODO: Impl encoder-decoder - if encoder_inputs is not None: - raise NotImplementedError - sampling_params = None pooling_params = None if isinstance(params, SamplingParams): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 33f4d65a7a11..d3822251b5b6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -61,12 +61,16 @@ create_fast_prefill_custom_backend, reorder_batch_to_split_decodes_and_prefills) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher +# yapf conflicts with isort for this block +# yapf: disable from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, + CrossAttentionSpec, EncoderOnlyAttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, MambaSpec, SlidingWindowSpec) +# yapf: enable from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, LogprobsLists, LogprobsTensors, ModelRunnerOutput, SamplerOutput) @@ -208,6 +212,14 @@ def __init__( self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( model_config) + if self.model_config.is_encoder_decoder: + # Maximum length of the encoder input, only for encoder-decoder + # models. + self.max_encoder_len = self.mm_registry.\ + get_encdec_max_encoder_len(model_config) + else: + self.max_encoder_len = 0 + # Sampler self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) @@ -265,7 +277,9 @@ def __init__( # the block_sizes in the kv cache config. self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, + # We need to use the encoder length for encoder-decoer + # because of KV cache for cross-attention. + max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -798,6 +812,24 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, src=self.input_batch.prev_sampled_token_ids[ prev_common_req_indices_tensor, 0]) + def _get_encoder_seq_lens( + self, + scheduler_output: "SchedulerOutput", + kv_cache_spec: KVCacheSpec, + num_reqs: int, + ) -> Optional[np.ndarray]: + if not isinstance(kv_cache_spec, CrossAttentionSpec): + return None + + # Build encoder_seq_lens array mapping request indices to + # encoder lengths for inputs scheduled in this batch + encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32) + for req_id in scheduler_output.scheduled_encoder_inputs: + req_index = self.input_batch.req_id_to_index[req_id] + encoder_seq_lens[req_index] = self.max_encoder_len + + return encoder_seq_lens + def _prepare_inputs( self, scheduler_output: "SchedulerOutput", @@ -937,6 +969,8 @@ def _prepare_inputs( # in the same group share the same metadata. 
for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): + encoder_seq_lens = self._get_encoder_seq_lens( + scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs) if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): @@ -981,6 +1015,7 @@ def _prepare_inputs( logits_indices_padded=logits_indices_padded, num_logits_indices=logits_indices.size(0), causal=True, + encoder_seq_lens=encoder_seq_lens, ) if self.speculative_config and \ @@ -1253,10 +1288,24 @@ def _prepare_kv_sharing_fast_prefill( self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) return logits_indices_padded - def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + def _batch_mm_kwargs_from_scheduler( + self, + scheduler_output: "SchedulerOutput", + ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]: + """Batch multimodal kwargs from scheduled encoder inputs. + + Args: + scheduler_output: The scheduler output containing scheduled encoder + inputs. + + Returns: + A tuple of (mm_kwargs, req_ids_pos) where: + - mm_kwargs: List of multimodal kwargs items to be batched + - mm_hashes_pos: List of (mm_hash, position_info) tuples + """ scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: - return + return [], [] # Batch the multi-modal inputs. mm_kwargs = list[MultiModalKwargsItem]() # list of tuple (mm_hash, position_info) @@ -1270,6 +1319,16 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): mm_hashes_pos.append( (mm_hash, req_state.mm_positions[mm_input_id])) + return mm_kwargs, mm_hashes_pos + + def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): + # Batch the multi-modal inputs using the helper method. + mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( + scheduler_output) + + if not mm_kwargs: + return + # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, # we process it separately to preserve item order. @@ -1360,6 +1419,35 @@ def _gather_mm_embeddings( mm_embeds.append(mm_embeds_item) return mm_embeds + def _extract_encoder_inputs( + self, + scheduler_output: "SchedulerOutput", + ) -> dict[str, torch.Tensor]: + """Extract encoder inputs for encoder-decoder models. + + This method extracts multimodal input features from scheduled encoder + inputs and formats them for the encoder-decoder model forward pass. + """ + # Batch the multi-modal inputs using the helper method. + mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output) + + if not mm_kwargs: + return {} + + # Group MM kwargs by modality and extract features + encoder_features = {} + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + ): + # Add the grouped features to encoder_features dict + # This allows the model to receive them as kwargs (e.g., + # input_features=...) + encoder_features.update(mm_kwargs_group) + + return encoder_features + def get_model(self) -> nn.Module: # get raw model out of the cudagraph wrapper. 
if isinstance(self.model, CUDAGraphWrapper): @@ -1631,7 +1719,8 @@ def _preprocess( # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order - if self.supports_mm_inputs and get_pp_group().is_first_rank: + if (self.supports_mm_inputs and get_pp_group().is_first_rank + and not self.model_config.is_encoder_decoder): # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output) @@ -1673,6 +1762,11 @@ def _preprocess( intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_input_tokens, intermediate_tensors, True) + if (self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs): + encoder_inputs = self._extract_encoder_inputs(scheduler_output) + model_kwargs.update(encoder_inputs) + return ( num_scheduled_tokens, num_input_tokens, @@ -2591,17 +2685,18 @@ def _dummy_run( with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens, remove_lora): - if self.supports_mm_inputs: + model_kwargs = self._init_model_kwargs(num_tokens) + if (self.supports_mm_inputs + and not self.model_config.is_encoder_decoder): input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens] model_kwargs = { - **self._init_model_kwargs(num_tokens), + **model_kwargs, **self._dummy_mm_kwargs(num_reqs), } else: input_ids = self.input_ids.gpu[:num_tokens] inputs_embeds = None - model_kwargs = self._init_model_kwargs(num_tokens) if self.uses_mrope: positions = self.mrope_positions.gpu[:, :num_tokens] @@ -2823,7 +2918,6 @@ def profile_run(self) -> None: mm_budget = self.mm_budget assert mm_budget is not None - # TODO: handle encoder-decoder models once we support them. if (encoder_budget := mm_budget.get_encoder_budget()) > 0: # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when @@ -3170,7 +3264,7 @@ def may_reinitialize_input_batch(self, "for more details.") self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, + max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -3443,7 +3537,7 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): if attn_module.attn_type == AttentionType.ENCODER_ONLY: - attn_spec = EncoderOnlyAttentionSpec( + attn_spec: AttentionSpec = EncoderOnlyAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, @@ -3485,7 +3579,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: self.shared_kv_cache_layers[layer_name] = kv_tgt_layer continue - # TODO: Support other attention modules, e.g., cross-attention # TODO(lucas): move the attention specs into the model layers like # the attention backends if attn_module.attn_type == AttentionType.DECODER: @@ -3513,12 +3606,17 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, use_mla=use_mla) + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + kv_cache_spec[layer_name] = CrossAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, 
AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError else: raise ValueError( f"Unknown attention type: {attn_module.attn_type}") diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 6767804c71b9..be05d02ff29f 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -12,6 +12,7 @@ from vllm.model_executor.models.utils import extract_layer_index from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.registry import MultiModalRegistry +from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec @@ -269,7 +270,17 @@ def bind_kv_cache( # One typical case is encoder-decoder model, e.g., bart. # The cross attention and self attention in the same decoder layer # has different layer_name but the same layer_index. - raise NotImplementedError + + # TODO - analyze where runner_kv_caches is used and the right + # way to ensure it properly reflects multiple attention layers + # in the same decoder block. + if current_platform.is_cuda(): + # We know that the GPU runner is not impacted by this + # case. Some test code depends on runner_kv_caches, but + # not in a way that's impacted by ignoring this. + pass + else: + raise NotImplementedError layer_name = layer_names[0] runner_kv_caches.append(kv_caches[layer_name])
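
For reference, the slot-mapping arithmetic introduced in `_get_cross_slot_mapping` (see `vllm/attention/layers/cross_attention.py` above) can be exercised standalone. The following is a minimal sketch with toy values, plain NumPy/PyTorch, and none of vLLM's real types; the function name and the example block table are illustrative only, not part of the patch.

```python
import numpy as np
import torch


def cross_slot_mapping(encoder_seq_lens: np.ndarray,
                       block_table: torch.Tensor,
                       block_size: int) -> torch.Tensor:
    """Map each encoder token of every active request to a KV-cache slot."""
    slot_mappings = []
    # Only requests with a non-zero encoder length need cross-attention slots.
    for req_index in np.nonzero(encoder_seq_lens)[0]:
        seq_len = int(encoder_seq_lens[req_index])
        num_blocks = -(-seq_len // block_size)            # cdiv(seq_len, block_size)
        block_ids = block_table[req_index, :num_blocks]   # blocks owned by this request
        positions = torch.arange(seq_len, dtype=torch.int64)
        block_numbers = block_ids[positions // block_size]
        slot_mappings.append(block_numbers * block_size + positions % block_size)
    if slot_mappings:
        return torch.cat(slot_mappings)
    return torch.empty(0, dtype=torch.int64)


# Request 0 has a 5-token encoder input, request 1 is decode-only (length 0).
block_table = torch.tensor([[7, 3, 0],
                            [2, 0, 0]])
print(cross_slot_mapping(np.array([5, 0]), block_table, block_size=4))
# tensor([28, 29, 30, 31, 12])  -> tokens 0-3 land in block 7, token 4 in block 3
```

With `block_size=4` and an encoder length of 5, the request needs `cdiv(5, 4) = 2` blocks, and token `i` maps to `block_table[i // 4] * 4 + i % 4`, matching what the builder writes into `slot_mapping` for cross-attention KV writes.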
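The `CrossAttentionBuilder` in the same file follows a wrap-and-delegate pattern: copy the common metadata, force non-causal attention over a fixed encoder length, then hand off to the underlying backend's builder. The sketch below illustrates that pattern with toy stand-in classes (not vLLM's `AttentionMetadataBuilder` or `CommonAttentionMetadata`); names and fields are assumptions chosen only to show the shape of the delegation.

```python
from copy import copy
from dataclasses import dataclass


@dataclass
class ToyCommonMetadata:
    num_reqs: int
    causal: bool = True
    max_seq_len: int = 0


class ToyBuilder:
    """Stand-in for an existing per-backend metadata builder."""

    def build(self, common_prefix_len: int, meta: ToyCommonMetadata) -> dict:
        return {"causal": meta.causal, "max_seq_len": meta.max_seq_len}


def make_cross_builder(base_cls, max_encoder_len: int):
    """Wrap an existing builder, overriding metadata before delegating."""

    class CrossBuilder(base_cls):
        def build(self, common_prefix_len, meta):
            meta = copy(meta)                  # do not mutate the shared metadata
            meta.causal = False                # decoder attends to all encoder slots
            meta.max_seq_len = max_encoder_len  # fixed encoder length per request
            return super().build(common_prefix_len, meta)

    return CrossBuilder


builder = make_cross_builder(ToyBuilder, max_encoder_len=1500)()
print(builder.build(0, ToyCommonMetadata(num_reqs=2)))
# {'causal': False, 'max_seq_len': 1500}
```

The real builder additionally fills `seq_lens`/`seq_lens_cpu` with the model's maximum encoder length and computes the cross-attention `slot_mapping` shown in the previous sketch before calling `super().build()`.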