[v1] Add Whisper model support (encoder-decoder) #21088
Merged (+429 −92)
Changes from all commits (33 commits, all by russellb):
a3b89cc  v1: Add Whisper encoder-decoder model support
e689784  prevent voxtral from being detected as encoder-decoder
a0aeea2  Drop whisper-small from tests
d2b51e1  Use correct number of encoder tokens in attention metadata
a57ee42  Force spawn multiproc method in whisper example
44193e5  Warn if Whisper is used without spawn
21e5a90  Run whisper test with spawn multiproc method
8ecc6c7  whisper: simplify encoder attention by using pytorch
35b2125  Limit whisper encoder concurrency
2423682  Drop EncoderAttention abstractions no longer needed
dec14fe  remove debug logs
b9b228b  encoder-decoder does not use encoder-cache yet
be9add9  Simplify and reduce duplication in TorchAttention
7474c67  Use existing MultiHeadAttention instead of new TorchAttention
bf4c7c1  Remove some unnecessary variables that were not used
f99f5d7  Remove unused TorchAttention
5c858d5  Replace a slow Python loop with torch
11b9b9e  Move max_seq_len override into CrossAttentionBuilder
a875bfc  Move seq_lens / seq_lens_cpu overrides into CrossAttentionBuilder
8c7176f  Revert unnecessary code move to reduce diff size
7304b3e  further simplification of slot_mappings for cross attn
83fd244  move slot mapping calculation back into CrossAttentionBuilder
04cf403  improve how we get the number of blocks needed
431db03  make python loop more efficient
279c0b0  remove leftover docstring addition
d466655  Restore TODO that was accidentally removed
bc7277e  remove variables no longer needed
1c28542  Ensure AttentionMetadataBuilder subclasses call parent constructor
4b31447  remove old param from docstring
8305707  move MultiHeadAttention changes into whisper.py
3c10b1f  use vllm.utils.cdiv
7444d37  minor refactoring
4afbdcb  revert unnecessary changes left over
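Two of the commits above force the spawn multiprocessing start method for the Whisper example and add a warning when Whisper runs without it. Below is a rough sketch of what opting into spawn looks like in an offline script, assuming vLLM's VLLM_WORKER_MULTIPROC_METHOD environment variable and an illustrative model id; this is not the PR's example code.

import os

# Fork is unsafe once CUDA has been initialized in the parent process, so the
# worker processes are started with "spawn" instead.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm import LLM


def main() -> None:
    # Illustrative checkpoint; any Whisper model supported by vLLM would do.
    llm = LLM(model="openai/whisper-large-v3")
    # ... build audio prompts and run generation here.


if __name__ == "__main__":
    # With spawn, child processes re-import this module, so the entry point
    # must be guarded to avoid re-running the script body.
    main()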
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
from typing import Optional

import numpy as np
import torch
from transformers import CacheConfig

from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                              subclass_attention_backend)
from vllm.v1.kv_cache_interface import CrossAttentionSpec

logger = init_logger(__name__)


def _get_max_encoder_len(vllm_config: VllmConfig) -> int:
    return MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(
        vllm_config.model_config)


def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray,
                            block_table_tensor: torch.Tensor,
                            kv_cache_spec: CrossAttentionSpec,
                            device: torch.device) -> torch.Tensor:
    """Get cross-attention slot mappings."""

    block_size = kv_cache_spec.block_size
    slot_mappings = []

    # Find indices with non-zero encoder sequence lengths
    # The majority of parallel requests will be running the
    # decoder, so this list should be relatively small.
    active_indices = np.nonzero(encoder_seq_lens)[0]

    for req_index in active_indices:
        encoder_seq_len = encoder_seq_lens[req_index].item()

        # Calculate the number of blocks needed for this request
        num_blocks_needed = cdiv(encoder_seq_len, block_size)

        # Get the block IDs for this request from the tensor
        req_block_ids = block_table_tensor[req_index]

        # Get only the blocks we need (first num_blocks_needed blocks)
        needed_block_ids = req_block_ids[:num_blocks_needed]

        # All needed blocks are allocated
        i_values = torch.arange(encoder_seq_len,
                                dtype=torch.int64,
                                device=device)
        block_indices = i_values // block_size
        block_offsets = i_values % block_size
        block_numbers = needed_block_ids[block_indices]
        slot_mapping = block_numbers * block_size + block_offsets
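        # Worked example (illustration): with block_size = 16 and
        # encoder_seq_len = 20, num_blocks_needed = cdiv(20, 16) = 2 and
        # token i maps to slot needed_block_ids[i // 16] * 16 + (i % 16),
        # so tokens 0-15 fill the first allocated block and tokens 16-19
        # occupy the start of the second.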

        slot_mappings.append(slot_mapping)

    if slot_mappings:
        return torch.cat(slot_mappings)
    else:
        return torch.empty(0, dtype=torch.int64, device=device)


@functools.lru_cache
def create_cross_attention_backend(
        underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
    prefix = "CrossAttention_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class CrossAttentionBuilder(underlying_builder):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            new_metadata = copy(common_attn_metadata)
            new_metadata.causal = False
            max_encoder_len = _get_max_encoder_len(self.vllm_config)
            new_metadata.max_seq_len = max_encoder_len

            new_metadata.seq_lens = torch.full(
                (new_metadata.num_reqs, ),
                max_encoder_len,
                dtype=torch.int32,
                device=self.device,
            )
            new_metadata.seq_lens_cpu = torch.full(
                (new_metadata.num_reqs, ),
                max_encoder_len,
                dtype=torch.int32,
                device="cpu",
            )
            new_metadata.slot_mapping = _get_cross_slot_mapping(
                new_metadata.encoder_seq_lens, new_metadata.block_table_tensor,
                self.kv_cache_spec, self.device)
            return super().build(common_prefix_len, new_metadata, fast_build)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=CrossAttentionBuilder)

    return attn_backend


class CrossAttention(Attention):
    """
    Cross-attention for encoder-decoder models.
    Handles attention between decoder queries and encoder keys/values.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 cache_config: Optional[CacheConfig] = None,
                 attn_type: Optional[str] = None,
                 **kwargs):
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if envs.VLLM_USE_V1:
            underlying_attn_backend = get_attn_backend(head_size, dtype,
                                                       kv_cache_dtype,
                                                       block_size)

            attn_backend = create_cross_attention_backend(
                underlying_attn_backend)
        else:
            # in v0 cross attention is handled inside the backends
            attn_backend = None

        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_DECODER, (
                "CrossAttention only supports AttentionType.ENCODER_DECODER")

        super().__init__(num_heads=num_heads,
                         head_size=head_size,
                         scale=scale,
                         cache_config=cache_config,
                         attn_backend=attn_backend,
                         attn_type=AttentionType.ENCODER_DECODER,
                         **kwargs)
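For context, here is a rough sketch of how a decoder layer could instantiate the CrossAttention class defined above. The hidden size, head count, and layer prefix are hypothetical (loosely Whisper-large-shaped), and the sketch assumes CrossAttention is importable from wherever this new module lives, since the file path is not visible in this extract; it is not taken from the PR's whisper.py changes.

from vllm.attention.backends.abstract import AttentionType


def build_cross_attention(cache_config=None):
    # Hypothetical Whisper-large-like dimensions: d_model=1280 over 20 heads.
    hidden_size = 1280
    num_heads = 20
    head_size = hidden_size // num_heads  # 64

    # CrossAttention pins attn_type to ENCODER_DECODER and builds the
    # cross-attention backend wrapper itself; callers pass the usual
    # Attention arguments plus a layer prefix (value here is made up).
    return CrossAttention(
        num_heads=num_heads,
        head_size=head_size,
        scale=head_size**-0.5,
        cache_config=cache_config,
        attn_type=AttentionType.ENCODER_DECODER,
        prefix="model.decoder.layers.0.encoder_attn.attn",
    )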