diff --git a/tests/conftest.py b/tests/conftest.py index f9dfabc82639..a37fc0f576cf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,8 +27,9 @@ from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, initialize_model_parallel) -from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, - to_enc_dec_tuple_list, zip_enc_dec_prompts) +from vllm.inputs import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, + TextPrompt, to_enc_dec_tuple_list, + zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.platforms import current_platform @@ -654,21 +655,26 @@ def __init__( def get_inputs( self, - prompts: List[str], + prompts_or_prompt_embeds: Union[List[str], List[torch.Tensor]], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, - ) -> List[TextPrompt]: + ) -> List[Union[TextPrompt, EmbedsPrompt]]: if images is not None: - assert len(prompts) == len(images) + assert len(prompts_or_prompt_embeds) == len(images) if videos is not None: - assert len(prompts) == len(videos) + assert len(prompts_or_prompt_embeds) == len(videos) if audios is not None: - assert len(prompts) == len(audios) + assert len(prompts_or_prompt_embeds) == len(audios) + + inputs = [ + EmbedsPrompt(prompt_embeds=prompt) if isinstance( + prompt, torch.Tensor) else TextPrompt(prompt=prompt) + for prompt in prompts_or_prompt_embeds + ] - inputs = [TextPrompt(prompt=prompt) for prompt in prompts] if images is not None: for i, image in enumerate(images): if image is not None: @@ -696,13 +702,13 @@ def classify(self, prompts: List[str]) -> List[str]: def generate( self, - prompts: List[str], + prompts_or_prompt_embeds: Union[List[str], List[torch.Tensor]], sampling_params: SamplingParams, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, ) -> List[Tuple[List[List[int]], List[str]]]: - inputs = self.get_inputs(prompts, + inputs = self.get_inputs(prompts_or_prompt_embeds, images=images, videos=videos, audios=audios) @@ -720,7 +726,7 @@ def generate( output_str = sample.text output_ids = list(sample.token_ids) req_sample_output_ids.append(prompt_ids + output_ids) - req_sample_output_strs.append(prompt_str + output_str) + req_sample_output_strs.append((prompt_str or "") + output_str) outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs @@ -785,14 +791,14 @@ def generate_encoder_decoder_w_logprobs( def generate_greedy( self, - prompts: List[str], + prompts_or_prompt_embeds: Union[List[str], List[torch.Tensor]], max_tokens: int, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - outputs = self.generate(prompts, + outputs = self.generate(prompts_or_prompt_embeds, greedy_params, images=images, videos=videos, diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index 05117666f8c3..ab6f15900744 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -54,9 +54,22 @@ def test_models( hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) + prompt_embeds = [] + prompt_token_ids = [] + for prompt in example_prompts: + 
token_ids = hf_model.tokenizer(prompt, + return_tensors="pt").input_ids.to( + hf_model.model.device) + prompt_token_ids.append(token_ids) + prompt_embeds.append( + hf_model.model.get_input_embeddings()(token_ids).squeeze(0)) + with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) + vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs( + prompt_embeds, max_tokens, num_logprobs) + # This test is for verifying whether the model's extra_repr # can be printed correctly. print(vllm_model.model.llm_engine.model_executor.driver_worker. @@ -68,3 +81,10 @@ def test_models( name_0="hf", name_1="vllm", ) + + check_logprobs_close( + outputs_0_lst=vllm_outputs, + outputs_1_lst=vllm_outputs_from_embeds, + name_0="vllm", + name_1="vllm_from_embeds", + ) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 433a9b30ba57..aa2952268789 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,3 +1,4 @@ +import random from typing import List import pytest @@ -22,8 +23,9 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: return model_runner -@pytest.mark.parametrize("batch_size", list(range(1, 257))) -def test_prepare_prompt(batch_size): +@pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) +@pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0)) +def test_prepare_prompt(batch_size, prompt_embeds_ratio): model_runner = _create_model_runner( "facebook/opt-125m", max_num_batched_tokens=100000, @@ -34,11 +36,19 @@ def test_prepare_prompt(batch_size): seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] block_tables = {0: [1]} + input_embeds_len = 0 for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData.from_seqs(range(seq_len)) + if random.random() < prompt_embeds_ratio: + seq_data = SequenceData.from_seqs( + range(seq_len), + prompt_embeds=torch.rand(seq_len, 10), + ) + input_embeds_len += seq_len + else: + seq_data = SequenceData.from_seqs(range(seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -59,6 +69,8 @@ def test_prepare_prompt(batch_size): seq_group_metadata_list) input_tokens = model_input.input_tokens input_positions = model_input.input_positions + input_embeds = model_input.input_embeds + input_embeds_masks = model_input.input_embeds_masks attn_metadata = model_input.attn_metadata return_seq_lens = model_input.seq_lens slot_mapping = attn_metadata.slot_mapping @@ -112,7 +124,12 @@ def test_prepare_prompt(batch_size): assert len(input_tokens) == sum(seq_lens) assert len(input_positions) == sum(seq_lens) - torch.testing.assert_close(input_tokens, input_positions) + assert len(input_embeds_masks) == sum(seq_lens) + if input_embeds_len == 0: + torch.testing.assert_close(input_tokens, input_positions) + assert input_embeds is None + else: + assert len(input_embeds) == input_embeds_len sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, @@ -136,8 +153,9 @@ def test_prepare_prompt(batch_size): torch.testing.assert_close(actual, expected) -@pytest.mark.parametrize("batch_size", list(range(1, 257))) -def test_prepare_decode_cuda_graph(batch_size): +@pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) +@pytest.mark.parametrize("prompt_embeds_ratio", (0.0, 0.5, 1.0)) +def 
test_prepare_decode_cuda_graph(batch_size, prompt_embeds_ratio): model_runner = _create_model_runner( "facebook/opt-125m", seed=0, @@ -151,11 +169,19 @@ def test_prepare_decode_cuda_graph(batch_size): context_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] # Assume each seq group finishes prefill. + input_embeds_len = 0 for i in range(batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) - seq_data = SequenceData.from_seqs(range(context_len)) + if random.random() < prompt_embeds_ratio: + seq_data = SequenceData.from_seqs( + [], + prompt_embeds=torch.rand(context_len, 10), + ) + input_embeds_len += context_len + else: + seq_data = SequenceData.from_seqs(range(context_len)) seq_data.update_num_computed_tokens(context_len) # Append one token ID since prefill is finished. seq_data.append_token_id(1, 0) @@ -171,9 +197,13 @@ def test_prepare_decode_cuda_graph(batch_size): model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - input_tokens, input_positions, attn_metadata, slot_mapping = ( - model_input.input_tokens, model_input.input_positions, - model_input.attn_metadata, model_input.attn_metadata.slot_mapping) + (input_tokens, input_positions, input_embeds, input_embeds_masks, + attn_metadata, + slot_mapping) = (model_input.input_tokens, model_input.input_positions, + model_input.input_embeds, model_input.input_embeds_masks, + model_input.attn_metadata, + model_input.attn_metadata.slot_mapping) + assert len(slot_mapping) == len(input_tokens) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) @@ -226,6 +256,8 @@ def test_prepare_decode_cuda_graph(batch_size): assert len(input_tokens) == expected_bs assert len(input_positions) == expected_bs torch.allclose(input_tokens, input_positions) + assert input_embeds is None + assert input_embeds_masks is None # Verify Sampling expected_selected_token_indices = [] @@ -256,14 +288,19 @@ def test_empty_seq_group(): seq_group_metadata_list: List[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - input_tokens, input_positions, attn_metadata = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - ) + (input_tokens, input_positions, input_embeds, input_embeds_masks, + attn_metadata) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.input_embeds, + model_input.input_embeds_masks, + model_input.attn_metadata, + ) assert input_tokens is None assert input_positions is None assert attn_metadata is None + assert input_embeds is None + assert input_embeds_masks is None model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) @@ -289,9 +326,11 @@ def distributed_init(): ensure_model_parallel_initialized(1, 1) -@pytest.mark.parametrize("batch_size", list(range(2, 128))) +@pytest.mark.parametrize("batch_size", list(range(2, 128, 3))) @pytest.mark.parametrize("enforce_eager", [True, False]) -def test_hybrid_batches(batch_size, enforce_eager, distributed_init): +@pytest.mark.parametrize('prompt_embeds_ratio', [0.0, 0.5, 1.0]) +def test_hybrid_batches(batch_size, enforce_eager, prompt_embeds_ratio, + distributed_init): model_runner = _create_model_runner( "facebook/opt-125m", seed=0, @@ -310,11 +349,19 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): block_tables = {0: [1]} prefill_batch_size = batch_size // 2 decode_batch_size = batch_size - 
prefill_batch_size + input_embeds_len = 0 for i in range(prefill_batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData.from_seqs(range(seq_len)) + if random.random() < prompt_embeds_ratio: + seq_data = SequenceData.from_seqs( + [], + prompt_embeds=torch.rand(seq_len, 10), + ) + input_embeds_len += seq_len + else: + seq_data = SequenceData.from_seqs(range(seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -330,7 +377,13 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData.from_seqs(range(context_len)) + if random.random() < prompt_embeds_ratio: + seq_data = SequenceData.from_seqs( + [], + prompt_embeds=torch.rand(context_len, 10), + ) + else: + seq_data = SequenceData.from_seqs(range(context_len)) seq_data.append_token_id(1, 0) seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( @@ -345,11 +398,14 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): decode_metadata_list.append(seq_group_metadata) model_input = model_runner.prepare_model_input(seq_group_metadata_list) - (input_tokens, input_positions, attn_metadata) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - ) + (input_tokens, input_positions, attn_metadata, input_embeds, + input_embeds_masks) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.input_embeds, + model_input.input_embeds_masks, + ) prefill_meta_actual = attn_metadata.prefill_metadata decode_meta_actual = attn_metadata.decode_metadata @@ -359,6 +415,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): assert attn_metadata.num_prefills == prefill_batch_size assert attn_metadata.num_decode_tokens == decode_batch_size assert attn_metadata.num_prefill_tokens == sum(seq_lens) + assert len(input_embeds_masks) == sum(seq_lens) + if input_embeds_len == 0: + assert input_embeds is None + else: + assert len(input_embeds) == input_embeds_len # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5d321fc98aeb..4194d5de473b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -7,7 +7,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence -from typing import Set, Type, Union, cast, overload +from typing import Set, Tuple, Type, Union, cast, overload import torch from typing_extensions import TypeVar @@ -60,6 +60,7 @@ usage_message) from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind from vllm.version import __version__ as VLLM_VERSION +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -1983,6 +1984,17 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: def is_encoder_decoder_model(self): return self.input_preprocessor.is_encoder_decoder_model() + def _support_prompt_embeds(self) -> Tuple[bool, str]: + if self.speculative_config is not None: + return False, "Speculative decoding does not support prompt_embeds."
+ driver_worker = self.model_executor.driver_worker + model_runner = driver_worker.worker.model_runner if isinstance( + driver_worker, WorkerWrapperBase) else driver_worker.model_runner + if model_runner.model_supports_input_embeds: + return True, "" + return False, (f"Model {self.model_config.model} does not support " + "input embeddings, but prompt_embeds was provided.") + def _validate_model_inputs(self, inputs: ProcessorInputs, lora_request: Optional[LoRARequest]): if is_encoder_decoder_inputs(inputs): @@ -1994,8 +2006,22 @@ def _validate_model_inputs(self, inputs: ProcessorInputs, prompt_inputs = inputs prompt_ids = prompt_inputs.get("prompt_token_ids") + prompt_embeds = prompt_inputs.get("prompt_embeds") + + if prompt_ids is None: + if prompt_embeds is None: + raise ValueError("You must provide a prompt") + else: + self._validate_prompt_embeds(prompt_embeds) + else: + if prompt_embeds is None: + self._validate_prompt_ids(prompt_ids) + else: + raise ValueError("You can only provide either tokens or " + "embeddings, not both") - if prompt_ids is None or len(prompt_ids) == 0: + def _validate_prompt_ids(self, prompt_ids: List[int]): + if len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") if self.model_config.is_multimodal_model: @@ -2014,6 +2040,14 @@ def _validate_model_inputs(self, inputs: ProcessorInputs, # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens + def _validate_prompt_embeds(self, prompt_embeds: torch.Tensor): + if len(prompt_embeds) == 0: + raise ValueError("Prompt cannot be empty") + + support_prompt_embeds, error_msg = self._support_prompt_embeds() + if not support_prompt_embeds: + raise ValueError(error_msg) + def _build_logits_processors( self, sampling_params: SamplingParams, lora_request: Optional[LoRARequest]) -> SamplingParams: diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 68ac50a2c5a1..15ab36acf925 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,8 @@ -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, - ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, - SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, - TokensPrompt, build_explicit_enc_dec_prompt, +from .data import (DecoderOnlyInputs, EmbedInputs, EmbedsPrompt, + EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, + ProcessorInputs, PromptType, SingletonInputs, + SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, + build_explicit_enc_dec_prompt, embed_inputs, to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) from .registry import DummyData, InputContext, InputRegistry @@ -17,11 +18,14 @@ __all__ = [ "TextPrompt", "TokensPrompt", + "EmbedsPrompt", "PromptType", "SingletonPrompt", "ExplicitEncoderDecoderPrompt", "TokenInputs", "token_inputs", + "EmbedInputs", + "embed_inputs", "DecoderOnlyInputs", "EncoderDecoderInputs", "ProcessorInputs", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 46b41f431bec..3a5c73a3a8e9 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,6 +1,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal, Optional, Tuple, Union, cast) +import torch from typing_extensions import NotRequired, TypedDict, TypeVar if TYPE_CHECKING: @@ -49,12 +50,26 @@ class TokensPrompt(TypedDict): """ -SingletonPrompt = Union[str, TextPrompt, TokensPrompt] +class EmbedsPrompt(TypedDict): + """Schema for a prompt provided via token embeddings.""" + + prompt_embeds: torch.Tensor + """Embeddings of the prompt to pass
to the model.""" + + multi_modal_data: NotRequired["MultiModalDataDict"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] """ Set of possible schemas for a single prompt: - A text prompt (:class:`str` or :class:`TextPrompt`) - A tokenized prompt (:class:`TokensPrompt`) +- Embeddings of a prompt (:class:`EmbedsPrompt`) Note that "singleton" is as opposed to a data structure which encapsulates multiple prompts, i.e. of the sort @@ -176,10 +191,40 @@ def token_inputs( return inputs -DecoderOnlyInputs = TokenInputs +class EmbedInputs(TypedDict): + """Represents embedding-based inputs.""" + + type: Literal["embed"] + """The type of inputs.""" + + prompt_embeds: torch.Tensor + """The embeddings of the prompt.""" + + multi_modal_data: NotRequired["MultiModalDataDict"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +def embed_inputs( + prompt_embeds: torch.Tensor, + multi_modal_data: Optional["MultiModalDataDict"] = None, +) -> EmbedInputs: + """Construct :class:`EmbedInputs` from optional values.""" + inputs = EmbedInputs(type="embed", prompt_embeds=prompt_embeds) + + if multi_modal_data is not None: + inputs["multi_modal_data"] = multi_modal_data + + return inputs + + +DecoderOnlyInputs = Union[TokenInputs, EmbedInputs] """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. + This specifies the data required for decoder-only models. """ @@ -198,7 +243,7 @@ class EncoderDecoderInputs(TypedDict): """The inputs for the decoder portion.""" -SingletonInputs = TokenInputs +SingletonInputs = Union[TokenInputs, EmbedInputs] """ A processed :class:`SingletonPrompt` which can be passed to :class:`vllm.sequence.Sequence`. 
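For illustration only, a minimal usage sketch of the embedding-prompt path introduced above, mirroring the flow in the updated tests/conftest.py and tests/models/decoder_only/language/test_models.py. This is not part of the patch: the model name is just a placeholder (GPT-2 is one of the architectures touched below), and it assumes the chosen model passes the new supports_input_embeds check and that the embedding dtype matches the dtype vLLM loads the model with.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from vllm import LLM, SamplingParams

    # Placeholder model; any decoder-only model whose forward() accepts
    # inputs_embeds / inputs_embeds_masks after this change should work.
    model_name = "gpt2"

    # Build prompt embeddings offline with the HF model, as the updated test
    # does; the result must be 2D (num_tokens, embed_dim) to satisfy
    # _validate_embed_inputs.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_model = AutoModelForCausalLM.from_pretrained(model_name)
    token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
    prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

    # EmbedsPrompt is a TypedDict, so a plain dict carrying prompt_embeds is
    # accepted; the engine routes it through embed_inputs() instead of
    # tokenizing a text prompt.
    llm = LLM(model=model_name, dtype="float32")  # keep dtype equal to the HF embeddings
    outputs = llm.generate(
        [{"prompt_embeds": prompt_embeds}],
        SamplingParams(temperature=0.0, max_tokens=16),
    )
    print(outputs[0].outputs[0].text)
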
diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 09f1ff2cb42e..3b3e3e60458a 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -4,9 +4,9 @@ from vllm.utils import is_list_of -from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, - ProcessorInputs, PromptType, SingletonPrompt, TextPrompt, - TokensPrompt) +from .data import (EmbedsPrompt, EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, + SingletonPrompt, TextPrompt, TokensPrompt) class ParsedText(TypedDict): @@ -83,13 +83,22 @@ class ParsedTokensPrompt(TypedDict): content: TokensPrompt +class ParsedEmbedsPrompt(TypedDict): + type: Literal["embeds"] + content: EmbedsPrompt + + def parse_singleton_prompt( prompt: SingletonPrompt, -) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: +) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, + ParsedEmbedsPrompt]: if isinstance(prompt, str): return ParsedStrPrompt(type="str", content=prompt) elif isinstance(prompt, dict): - if "prompt_token_ids" in prompt: + if 'prompt_embeds' in prompt: + return ParsedEmbedsPrompt(type="embeds", + content=prompt) # type: ignore + elif "prompt_token_ids" in prompt: return ParsedTokensPrompt(type="tokens", content=prompt) # type: ignore elif "prompt" in prompt: diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index a5c787a56b5a..ddf586ad9479 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,6 +1,7 @@ import asyncio from typing import List, Optional +import torch from typing_extensions import assert_never from vllm.config import ModelConfig @@ -11,7 +12,8 @@ from vllm.utils import print_warning_once from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, - PromptType, SingletonInputs, SingletonPrompt, token_inputs) + PromptType, SingletonInputs, SingletonPrompt, embed_inputs, + token_inputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt logger = init_logger(__name__) @@ -169,6 +171,13 @@ def _apply_prompt_adapter( return prompt_token_ids + def _validate_embed_inputs(self, prompt_embeds: torch.Tensor): + if len(prompt_embeds.shape) != 2: + raise ValueError("Embeddings should be a 2D input with shape " + "`(num_tokens, embed_dim)`") + + return prompt_embeds + def _tokenize_prompt( self, prompt: str, @@ -268,6 +277,17 @@ def _prompt_to_llm_inputs( mm_processor_kwargs=mm_processor_kwargs, ) + if parsed["type"] == "embeds": + embeds_content = parsed["content"] + + prompt_embeds = embeds_content["prompt_embeds"] + multi_modal_data = embeds_content.get("multi_modal_data") + + return embed_inputs( + prompt_embeds=prompt_embeds, + multi_modal_data=multi_modal_data, + ) + assert_never(parsed) async def _prompt_to_llm_inputs_async( @@ -324,6 +344,17 @@ async def _prompt_to_llm_inputs_async( mm_processor_kwargs=mm_processor_kwargs, ) + if parsed["type"] == "embeds": + embeds_content = parsed["content"] + + prompt_embeds = embeds_content["prompt_embeds"] + multi_modal_data = embeds_content.get("multi_modal_data") + + return embed_inputs( + prompt_embeds=prompt_embeds, + multi_modal_data=multi_modal_data, + ) + assert_never(parsed) def _build_enc_dec_llm_inputs( @@ -333,6 +364,11 @@ def _build_enc_dec_llm_inputs( ) -> EncoderDecoderInputs: if encoder_inputs["type"] == "token": pass + elif encoder_inputs["type"] == "embed": + raise NotImplementedError("Embedding inputs are not supported for " + "encoder-decoder models yet") + encoder_inputs["prompt_embeds"] = 
self._validate_embed_inputs( + encoder_inputs["prompt_embeds"]) else: assert_never(encoder_inputs) @@ -348,6 +384,11 @@ def _build_enc_dec_llm_inputs( if "multi_modal_data" in decoder_inputs: raise ValueError("Multi-modal decoder inputs of encoder-" "decoder models are not supported yet") + elif decoder_inputs["type"] == "embed": + raise NotImplementedError("Embedding inputs are not supported for " + "encoder-decoder models yet") + decoder_inputs["prompt_embeds"] = self._validate_embed_inputs( + decoder_inputs["prompt_embeds"]) else: assert_never(encoder_inputs) @@ -465,6 +506,9 @@ def _build_decoder_only_llm_inputs( prompt_inputs["prompt_token_ids"], prompt_adapter_request=prompt_adapter_request, ) + elif prompt_inputs["type"] == "embed": + prompt_inputs["prompt_embeds"] = self._validate_embed_inputs( + prompt_inputs["prompt_embeds"]) else: assert_never(prompt_inputs) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index d66373512b95..9720065c4267 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,5 +1,6 @@ -from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, - SupportsPP, has_inner_state, supports_lora, +from .interfaces import (HasInnerState, SupportsInputEmbeds, SupportsLoRA, + SupportsMultiModal, SupportsPP, has_inner_state, + supports_input_embeds, supports_lora, supports_multimodal, supports_pp) from .interfaces_base import (VllmModelForEmbedding, VllmModelForTextGeneration, is_embedding_model, @@ -14,6 +15,8 @@ "is_text_generation_model", "HasInnerState", "has_inner_state", + "SupportsInputEmbeds", + "supports_input_embeds", "SupportsLoRA", "supports_lora", "SupportsMultiModal", diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 5b712ba83c25..997fe5bb8c12 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -33,7 +33,7 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) logger = init_logger(__name__) @@ -396,9 +396,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -447,9 +451,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 1fbf4135add7..04e3a989c678 100644 --- a/vllm/model_executor/models/baichuan.py +++ 
b/vllm/model_executor/models/baichuan.py @@ -45,7 +45,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -286,9 +286,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -363,9 +367,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 83ff39a30fbe..b5ecd196752d 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -41,7 +41,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -258,9 +258,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.word_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.word_embeddings, + inputs_embeds, + inputs_embeds_masks) hidden_states = self.word_embeddings_layernorm(hidden_states) else: assert intermediate_tensors is not None @@ -309,9 +313,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 835682ca3b37..031f62117897 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -48,7 +48,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -287,9 +287,13 @@ def forward( kv_caches: 
List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -367,9 +371,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 3e60eee2d8fe..77d5ab0d4ef4 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -24,7 +24,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -328,9 +328,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.wte(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.wte, + inputs_embeds, + inputs_embeds_masks) else: assert intermediate_tensors hidden_states = intermediate_tensors["hidden_states"] @@ -384,9 +388,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index d278ea5b6a99..e594daf65e97 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -360,9 +360,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + 
inputs_embeds_masks) residual = None else: hidden_states = intermediate_tensors["hidden_states"] @@ -410,9 +414,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 834be78bce87..007086518e63 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -50,7 +50,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, +from .utils import (PPMissingLayer, get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -452,9 +452,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -507,9 +511,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 23efe0359cb4..089ebebb0b26 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -53,7 +53,7 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig from .interfaces import SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, +from .utils import (PPMissingLayer, get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -370,12 +370,13 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, + self.get_input_embeddings, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -492,9 +493,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: 
Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return model_output def compute_logits( diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index ad07fc3b3776..dc853496f42c 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -47,7 +47,7 @@ from vllm.transformers_utils.configs import RWConfig from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) FalconConfig = Union[HF_FalconConfig, RWConfig] @@ -371,9 +371,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.word_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.word_embeddings, + inputs_embeds, + inputs_embeds_masks) else: hidden_states = intermediate_tensors["hidden_states"] for i in range(self.start_layer, self.end_layer): @@ -437,9 +441,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 3db82a898159..6459899c2f00 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -57,7 +57,7 @@ class FuyuImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor """ - Shape: + Shape: (batch_size, num_patches, patch_size_x * patch_size_y * num_channels) """ @@ -104,6 +104,7 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int): token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - image_feature_size * num_images) + return SequenceData(token_ids), { "image": consecutive_placeholder_ranges(num_items=num_images, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index fc3f5cb20afb..a65739b29ce2 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -41,7 +41,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -300,12 +300,13 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> 
Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, + self.get_input_embeddings, + inputs_embeds, + inputs_embeds_masks) hidden_states *= self.normalizer residual = None else: @@ -404,9 +405,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index c365880109ef..6c8a23d9bb4e 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, get_inputs_embeds, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) logger = init_logger(__name__) @@ -282,12 +283,12 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) hidden_states *= self.normalizer residual = None else: @@ -425,9 +426,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index a06200c4b7e0..07e15e8c5260 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -41,7 +41,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -217,9 +217,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - inputs_embeds = self.wte(input_ids) + inputs_embeds = get_inputs_embeds(input_ids, self.wte, + inputs_embeds, + inputs_embeds_masks) position_embeds = 
self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds else: @@ -271,9 +275,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 7612ea641d95..696e346e5fb0 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -41,7 +41,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -227,9 +227,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - inputs_embeds = self.wte(input_ids) + inputs_embeds = get_inputs_embeds(input_ids, self.wte, + inputs_embeds, + inputs_embeds_masks) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds else: @@ -297,9 +301,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index b28a6081b868..f5a881391e7e 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -41,7 +41,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -209,9 +209,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.wte(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.wte, + inputs_embeds, + inputs_embeds_masks) else: hidden_states = intermediate_tensors["hidden_states"] for i in range(self.start_layer, self.end_layer): @@ -259,9 +263,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> 
Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 931052c7cccf..642138ab9117 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -40,7 +40,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -222,9 +222,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_in(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_in, + inputs_embeds, + inputs_embeds_masks) else: hidden_states = intermediate_tensors["hidden_states"] for i in range(self.start_layer, self.end_layer): @@ -272,9 +276,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.gpt_neox(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index bee48f377e0f..e4dee4bcae5a 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -52,7 +52,8 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .utils import (PPMissingLayer, get_inputs_embeds, is_pp_missing_parameter, + make_layers) class GraniteMLP(nn.Module): @@ -305,12 +306,13 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, + self.get_input_embeddings, + inputs_embeds, + inputs_embeds_masks) residual = None hidden_states *= self.config.embedding_multiplier @@ -422,9 +424,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return model_output def compute_logits( diff --git 
a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index dcead6511513..11c15a8051b5 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -14,6 +14,45 @@ logger = init_logger(__name__) +@runtime_checkable +class SupportsInputEmbeds(Protocol): + """The interface required to support embedding inputs.""" + + def forward( + self, + *, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + ... + + +@overload +def supports_input_embeds( + model: Type[object]) -> TypeIs[Type[SupportsInputEmbeds]]: + ... + + +@overload +def supports_input_embeds(model: object) -> TypeIs[SupportsInputEmbeds]: + ... + + +def supports_input_embeds( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsInputEmbeds]], TypeIs[SupportsInputEmbeds]]: + """Check if the model supports input_embeds and input_embeds_masks.""" + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + required_kws = ("inputs_embeds", "inputs_embeds_masks") + missing_kws = tuple(kw for kw in required_kws + if not supports_kw(model_forward, kw)) + + return len(missing_kws) == 0 + + @runtime_checkable class SupportsMultiModal(Protocol): """The interface required for all multi-modal models.""" diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index afefb6cd9fa9..21734b4a7490 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -28,7 +28,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -286,12 +286,12 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.tok_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.tok_embeddings, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -349,9 +349,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], - ) -> torch.Tensor: + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 81d88a47c194..1ed6be3f6ebb 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -29,6 +29,7 @@ _get_graph_batch_size) from .interfaces import HasInnerState, SupportsLoRA +from .utils import get_inputs_embeds KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -298,8 +299,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, + inputs_embeds: 
Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, inputs_embeds_masks) residual = None for i in range(len(self.layers)): layer = self.layers[i] @@ -391,6 +395,8 @@ def forward(self, kv_caches: List[KVCache], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: max_batch_size = (_get_graph_batch_size( @@ -412,8 +418,16 @@ def forward(self, mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1], state_indices_tensor) - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_params) + hidden_states = self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + mamba_cache_params, + inputs_embeds=inputs_embeds, + inputs_embeds_masks=inputs_embeds_masks, + ) + return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 6c0a8b5ef845..877d4ae6f51e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -53,7 +53,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, get_inputs_embeds, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -324,12 +325,13 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, + self.get_input_embeddings, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -554,9 +556,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return model_output def compute_logits( diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index acf03cd8cb8a..e06cc6fb26c5 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -52,7 +52,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -404,12 +404,13 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: 
Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, + self.get_input_embeddings, + inputs_embeds, + inputs_embeds_masks) residual = None else: hidden_states = intermediate_tensors["hidden_states"] @@ -513,9 +514,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e9b9c4d838fa..2aa7b827a3c9 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -47,7 +47,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -288,9 +288,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -377,9 +381,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 9647d69be8a0..dc260cd532cf 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -48,7 +48,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -325,9 +325,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + 
inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -377,9 +381,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index fdd8af79b547..44e9665896f3 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -25,7 +25,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -244,9 +244,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.wte(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.wte, + inputs_embeds, + inputs_embeds_masks) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -292,9 +296,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index b649064536dc..ed1ebb8dc93f 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -46,7 +46,7 @@ from vllm.transformers_utils.configs import NemotronConfig from .interfaces import SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, +from .utils import (PPMissingLayer, get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) # The architecture is pretty similar to Llama, with these changes: @@ -344,12 +344,13 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, + self.get_input_embeddings, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -454,9 +455,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: 
Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return model_output def compute_logits( diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index dd3f58289a22..1d9d4e8dc3a4 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -45,7 +45,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -253,6 +253,8 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: """ :param input_ids: A tensor of shape `(batch_size, seq_len)`. @@ -260,7 +262,9 @@ def forward( if get_pp_group().is_first_rank: # Get embeddings of input. # shape: (batch_size, seq_len, d_model) - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) # embed positions hidden_states = inputs_embeds @@ -320,6 +324,8 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( input_ids=input_ids, @@ -327,6 +333,8 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + inputs_embeds_masks=inputs_embeds_masks, ) return hidden_states diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 7a76e4a0906d..ac67092b0afc 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -41,7 +41,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -264,10 +264,13 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings(input_ids) + inputs_embeds = get_inputs_embeds(input_ids, + self.get_input_embeddings, + inputs_embeds, + inputs_embeds_masks) pos_embeds = self.embed_positions(positions) if self.project_in is not None: inputs_embeds, _ = self.project_in(inputs_embeds) @@ -321,13 +324,15 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: return 
self.decoder(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, - inputs_embeds=inputs_embeds) + inputs_embeds=inputs_embeds, + inputs_embeds_masks=inputs_embeds_masks) class OPTForCausalLM(nn.Module, SupportsPP): @@ -374,9 +379,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index a338a93c2dd9..6974829281b2 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -28,7 +28,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -244,9 +244,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -295,9 +299,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index bd4a9f698bac..fee5f87c9650 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -44,7 +44,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -241,12 +241,12 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -291,6 +291,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: 
Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ): hidden_states = self.model( input_ids=input_ids, @@ -299,6 +300,7 @@ def forward( attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + inputs_embeds_masks=inputs_embeds_masks, ) return hidden_states diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 492122450b23..bb61c877f803 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -59,7 +59,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -223,9 +223,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -311,10 +315,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) - + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 3a7afc606bb9..195491311b7e 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -23,7 +23,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -337,9 +337,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) if (self.mup_embedding_multiplier is not None and self.mup_embedding_multiplier > 0.0): hidden_states = hidden_states * self.mup_embedding_multiplier @@ -436,6 +440,8 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: output_hidden_states = self.model( input_ids=input_ids, @@ -443,6 +449,8 @@ def forward( kv_caches=kv_caches, 
attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + inputs_embeds_masks=inputs_embeds_masks, ) output_hidden_states = output_hidden_states return output_hidden_states diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 1c41891ced41..ad9b49b7fcf5 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -676,6 +676,8 @@ def forward(self, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, **kwargs: object): if intermediate_tensors is not None: inputs_embeds = None diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 59843ae3dfd5..98e8d0041e82 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -47,7 +47,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -472,9 +472,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -574,9 +578,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 49b3de1304cc..f73e11aaa1df 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -48,7 +48,8 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, get_inputs_embeds, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -291,12 +292,12 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -455,9 
+456,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 98bb48a274e4..c979b33b338d 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -53,7 +53,7 @@ from vllm.utils import print_warning_once from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -351,9 +351,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) residual = None else: assert intermediate_tensors is not None @@ -404,9 +408,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 1b233ac7427d..6ab64c9379bb 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -52,7 +52,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, +from .utils import (PPMissingLayer, get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -320,12 +320,14 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) + hidden_states = get_inputs_embeds(input_ids, + self.get_input_embeddings, + inputs_embeds, + inputs_embeds_masks) + residual = None else: assert intermediate_tensors is not None @@ -463,9 +465,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = 
self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return model_output def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 34389b645a7c..2a5ee20c6fb0 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -42,7 +42,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -223,9 +223,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -272,9 +276,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index b24c5dadb2b2..5305a8c45979 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -42,7 +42,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -227,9 +227,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -281,9 +285,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 
fee97e8922a7..6013eae63b28 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -564,6 +564,32 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: for missing_layer_name in get_pp_missing_layer_names(model)) +def get_inputs_embeds( + input_ids: Optional[torch.Tensor], + embeddings_module: Callable[[torch.Tensor], torch.Tensor], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Get the input embeddings from either `input_ids` and `inputs_embeds`.""" + if inputs_embeds is not None: + if inputs_embeds_masks is None: + hidden_states = inputs_embeds + else: + msg = "inputs_embeds should not be masked out for multimodal models" + assert input_ids is not None, msg + + hidden_states = embeddings_module(input_ids) + hidden_states[inputs_embeds_masks] = inputs_embeds + else: + msg = "inputs_embeds should be set for multimodal models" + assert input_ids is not None, msg + + hidden_states = embeddings_module(input_ids) + + assert hidden_states is not None + return hidden_states + + def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): def make_empty_intermediate_tensors( diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index e559988ada75..1bbcbef57348 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -45,7 +45,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (get_inputs_embeds, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -261,9 +261,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + hidden_states = get_inputs_embeds(input_ids, self.embed_tokens, + inputs_embeds, + inputs_embeds_masks) residual = None else: hidden_states = intermediate_tensors["hidden_states"] @@ -345,9 +349,12 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds_masks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds, inputs_embeds_masks) return hidden_states def compute_logits( diff --git a/vllm/outputs.py b/vllm/outputs.py index 951976310e7a..a1ea23356976 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Dict, List, Optional from typing import Sequence as GenericSequence -from typing import Union +from typing import Tuple, Union from vllm.lora.request import LoRARequest from vllm.sampling_params import RequestOutputKind @@ -83,10 +83,11 @@ class RequestOutput: finished: Whether the whole request is finished. metrics: Metrics associated with the request. lora_request: The LoRA request that was used to generate the output. 
- encoder_prompt: The encoder prompt string of the request; + encoder_prompt: The encoder prompt string of the request; None if decoder-only encoder_prompt_token_ids: The token IDs of the encoder prompt; None if decoder-only + prompt_embeds_shape: The shape of the prompt embeddings. """ def __init__( @@ -101,10 +102,12 @@ def __init__( lora_request: Optional[LoRARequest] = None, encoder_prompt: Optional[str] = None, encoder_prompt_token_ids: Optional[List[int]] = None, + prompt_embeds_shape: Optional[Tuple[int, int]] = None, ) -> None: self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids + self.prompt_embeds_shape = prompt_embeds_shape self.prompt_logprobs = prompt_logprobs self.outputs = outputs self.finished = finished @@ -227,12 +230,17 @@ def from_seq_group( if include_prompt: prompt = seq_group.prompt prompt_token_ids = seq_group.prompt_token_ids + if (prompt_embeds := seq_group.prompt_embeds) is not None: + prompt_embeds_shape = tuple(prompt_embeds.shape) + else: + prompt_embeds_shape = None encoder_prompt = seq_group.encoder_prompt encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids prompt_logprobs = seq_group.prompt_logprobs else: prompt = None prompt_token_ids = None + prompt_embeds_shape = None encoder_prompt = None encoder_prompt_token_ids = None prompt_logprobs = None @@ -242,7 +250,7 @@ def from_seq_group( init_args = (seq_group.request_id, prompt, prompt_token_ids, prompt_logprobs, outputs, finished, seq_group.metrics, seq_group.lora_request, encoder_prompt, - encoder_prompt_token_ids) + encoder_prompt_token_ids, prompt_embeds_shape) if use_cache: request_output = seq_group.cached_request_output @@ -257,6 +265,7 @@ def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"prompt_embeds_shape={self.prompt_embeds_shape}, " f"encoder_prompt={self.encoder_prompt!r}, " f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " f"prompt_logprobs={self.prompt_logprobs}, " diff --git a/vllm/sequence.py b/vllm/sequence.py index 7d7ddc7ec444..5cc2f15aa406 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -9,7 +9,7 @@ from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, Mapping, Optional) from typing import Sequence as GenericSequence -from typing import Set, Tuple, Union +from typing import Set, Tuple, Union, overload import msgspec import torch @@ -144,20 +144,21 @@ class SequenceDataDelta( class SequenceData(msgspec.Struct, omit_defaults=True): # type: ignore[call-arg] """Data associated with a sequence. - Args: prompt_token_ids: The token IDs of the prompt. + prompt_embeds: The embeddings of the prompt. output_token_ids: The token IDs of the output. Set to an empty list if None. - Attributes: prompt_token_ids: The token IDs of the prompt. + prompt_embeds: The embeddings of the prompt. output_token_ids: The token IDs of the output. cumulative_logprob: The cumulative log probability of the output. """ # NOTE: we cannot use Union[List, array] because msgspec cannot support # union of 2 list types. _prompt_token_ids: array + _prompt_embeds: Optional[torch.Tensor] = None _output_token_ids: array = msgspec.field( default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) @@ -169,11 +170,9 @@ class SequenceData(msgspec.Struct, _num_computed_tokens: int = 0 _stage: SequenceStage = SequenceStage.PREFILL _cached_all_token_ids: List[int] = msgspec.field(default_factory=list) - # It is used to get delta input. 
It is reset when `get_delta_and_reset` # is called. _new_appended_tokens: List[int] = msgspec.field(default_factory=list) - # It is used to compute mrope_position_ids. _mrope_position_delta: Optional[int] = None @@ -201,6 +200,8 @@ def from_prompt_token_counts( def from_seqs( prompt_token_ids: GenericSequence[int], output_token_ids: Optional[GenericSequence[int]] = None, + *, + prompt_embeds: Optional[torch.Tensor] = None, ) -> "SequenceData": """ Construct a :class:`SequenceData` instance from prompt and output @@ -210,13 +211,15 @@ def from_seqs( prompt_token_ids) if output_token_ids is None: - return SequenceData(prompt_token_ids_arr) + return SequenceData(prompt_token_ids_arr, + _prompt_embeds=prompt_embeds) output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, output_token_ids) return SequenceData(prompt_token_ids_arr, - _output_token_ids=output_token_ids_arr) + _output_token_ids=output_token_ids_arr, + _prompt_embeds=prompt_embeds) def __post_init__(self) -> None: assert self._prompt_token_ids.typecode == "l" @@ -239,14 +242,13 @@ def cumulative_logprob(self) -> float: def prompt_token_ids(self) -> Tuple[int, ...]: return self._prompt_token_ids_tuple - @prompt_token_ids.setter - def prompt_token_ids(self, new_prompt_token_ids) -> None: - raise NotImplementedError + @property + def prompt_embeds(self) -> Optional[torch.Tensor]: + return self._prompt_embeds @property def prompt_token_ids_array(self) -> array: """Return the prompt token ids in array type. - Note that the array is in "I" type, and it is not compatible with torch.long (2 bytes vs 4 bytes). So beware of the usage. """ @@ -410,7 +412,15 @@ def __init__( self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request - self.data = SequenceData.from_seqs(self.prompt_token_ids) + data: SequenceData + if self.prompt_token_ids: + data = SequenceData.from_seqs(self.prompt_token_ids) + else: + assert isinstance(self.prompt_embeds, torch.Tensor) + data = SequenceData.from_seqs([], prompt_embeds=self.prompt_embeds) + + self.data = data + self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -438,6 +448,9 @@ def prompt(self) -> Optional[str]: if inputs["type"] == "token": return inputs.get("prompt") + if inputs["type"] == "embed": + return None + assert_never(inputs) @cached_property @@ -447,6 +460,9 @@ def prompt_token_ids(self) -> List[int]: if inputs["type"] == "token": return inputs.get("prompt_token_ids", []) + if inputs["type"] == "embed": + return [] + assert_never(inputs) @cached_property @@ -456,6 +472,9 @@ def prompt_embeds(self) -> Optional[torch.Tensor]: if inputs["type"] == "token": return None + if inputs["type"] == "embed": + return inputs.get("prompt_embeds", []) + assert_never(inputs) @cached_property @@ -465,23 +484,32 @@ def multi_modal_data(self) -> "MultiModalDataDict": if inputs["type"] == "token": return inputs.get("multi_modal_data", {}) + if inputs["type"] == "embed": + return inputs.get("multi_modal_data", {}) + assert_never(inputs) - @cached_property - def mm_processor_kwargs(self) -> Dict[str, Any]: + @property + def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: inputs = self.inputs if inputs["type"] == "token": - return inputs.get("mm_processor_kwargs", {}) + return inputs.get("multi_modal_placeholders", {}) + + if inputs["type"] == "embed": + return {} assert_never(inputs) - @property - def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: + @cached_property + def mm_processor_kwargs(self) -> Dict[str, Any]: inputs = self.inputs if 
inputs["type"] == "token": - return inputs.get("multi_modal_placeholders", {}) + return inputs.get("mm_processor_kwargs", {}) + + if inputs["type"] == "embed": + return {} assert_never(inputs) @@ -692,6 +720,12 @@ def prompt(self) -> Optional[str]: def prompt_token_ids(self) -> List[int]: return self.first_seq.prompt_token_ids + @property + def prompt_embeds(self) -> Optional[torch.Tensor]: + # All sequences in the group should have the same prompt. + # We use the prompt of an arbitrary sequence. + return self.seqs[0].prompt_embeds + @property def encoder_prompt(self) -> Optional[str]: # There are either 0 or 1 encoder sequences @@ -906,7 +940,7 @@ class SequenceGroupMetadata( multi_modal_data: Multi modal data. mm_processor_kwargs: Multimodal input processor / mapper overrides. encoder_seq_data: Optional sequence data for encoder prompt - (SequenceGroup.encoder_seq). Should be None + (SequenceGroup.encoder_seq). Should be None unless you are working with an encoder/decoder model. cross_block_table: Optional cross-attention block table associated @@ -1091,12 +1125,25 @@ class IntermediateTensors: tensors: Dict[str, torch.Tensor] - def __getitem__(self, key: Union[str, slice]): + @overload + def __getitem__(self, key: str) -> torch.Tensor: + ... + + @overload + def __getitem__(self, key: slice) -> "IntermediateTensors": + ... + + def __getitem__( + self, + key: Union[str, slice], + ) -> Union[torch.Tensor, "IntermediateTensors"]: if isinstance(key, str): return self.tensors[key] - elif isinstance(key, slice): + if isinstance(key, slice): return self.__class__({k: v[key] for k, v in self.tensors.items()}) + assert_never(key) + def __setitem__(self, key: str, value): self.tensors[key] = value diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1e8ea4e8e79c..006089bf19e0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -35,7 +35,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.model_executor.models import supports_lora, supports_multimodal +from vllm.model_executor.models import (supports_input_embeds, supports_lora, + supports_multimodal) from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs, MultiModalPlaceholderMap, @@ -93,6 +94,8 @@ class ModelInputForGPU(ModelRunnerInputBase): additional fields. 
""" input_tokens: Optional[torch.Tensor] = None + input_embeds: Optional[torch.Tensor] = None + input_embeds_masks: Optional[torch.BoolTensor] = None input_positions: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None @@ -112,6 +115,8 @@ class ModelInputForGPU(ModelRunnerInputBase): def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, + "input_embeds": self.input_embeds, + "input_embeds_masks": self.input_embeds_masks, "input_positions": self.input_positions, "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, @@ -163,6 +168,8 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, + "input_embeds": self.input_embeds, + "input_embeds_masks": self.input_embeds_masks, "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, @@ -230,6 +237,10 @@ def __init__( input_positions: Optional[List[List[int]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None, + # Input embeddings and masks. + input_embeds: Optional[torch.Tensor] = None, + input_embeds_mask: Optional[torch.BoolTensor] = None, + # The sequence length (may be capped to the sliding window). seq_lens: Optional[List[int]] = None, # The original sequence length (before applying sliding window). @@ -275,6 +286,8 @@ def __init__( self.block_tables = block_tables self.computed_block_nums = computed_block_nums self.n_seqs = n_seqs + self.input_embeds = input_embeds + self.input_embeds_mask = input_embeds_mask self.encoder_seq_len = encoder_seq_len if reinit: @@ -499,7 +512,16 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, context_len = seq_data.get_num_computed_tokens() # Compute tokens. 
- tokens = seq_data.get_token_ids()[context_len:seq_len] + if seq_data.prompt_embeds is None: + tokens = seq_data.get_token_ids()[context_len:seq_len] + input_embeds = None + input_embeds_mask = torch.zeros(seq_len - context_len, + dtype=torch.bool) + else: + tokens = [0] * seq_len + input_embeds = seq_data.prompt_embeds[context_len:seq_len] + input_embeds_mask = torch.ones(seq_len - context_len, + dtype=torch.bool) inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len @@ -507,6 +529,8 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, inter_data.input_tokens[seq_idx].extend(tokens) inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) inter_data.query_lens[seq_idx] = seq_len - context_len + inter_data.input_embeds = input_embeds + inter_data.input_embeds_mask = input_embeds_mask if seq_data.mrope_position_delta is not None: if inter_data.mrope_input_positions is None: @@ -829,11 +853,20 @@ def build(self) -> ModelInputForGPU: seq_lens = [] query_lens = [] + input_embeds_lst = [] + input_embeds_masks_lst = [] max_decode_seq_len = 0 max_encoder_seq_len = 0 + for inter_data in self.inter_data_list: seq_lens.extend(inter_data.seq_lens) query_lens.extend(inter_data.query_lens) + + if inter_data.input_embeds is not None: + input_embeds_lst.append(inter_data.input_embeds) + if inter_data.input_embeds_mask is not None: + input_embeds_masks_lst.append(inter_data.input_embeds_mask) + if not inter_data.is_prompt: max_decode_seq_len = max(max_decode_seq_len, max(inter_data.seq_lens)) @@ -841,6 +874,19 @@ def build(self) -> ModelInputForGPU: max_encoder_seq_len = max(max_encoder_seq_len, inter_data.encoder_seq_len) + if input_embeds_lst: + input_embeds = torch.cat(input_embeds_lst).to( + device=self.runner.device, + dtype=self.runner.model_config.dtype) + else: + input_embeds = None + + if input_embeds_masks_lst: + input_embeds_masks = torch.cat(input_embeds_masks_lst).to( + self.runner.device) + else: + input_embeds_masks = None + # Mapping from request IDs to sequence IDs. Used for Jamba models # that manages the cache by itself. request_ids_to_seq_ids = { @@ -945,6 +991,8 @@ def build(self) -> ModelInputForGPU: return self.model_input_cls( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, + input_embeds=input_embeds, + input_embeds_masks=input_embeds_masks, attn_metadata=attn_metadata, seq_lens=seq_lens, query_lens=query_lens, @@ -1041,6 +1089,7 @@ def __init__( # Lazy initialization self.model: nn.Module # Set after load_model + self.model_supports_input_embeds = False # Set after load_model # Set after load_model. 
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None @@ -1065,6 +1114,8 @@ def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: self.model = get_model(vllm_config=self.vllm_config) + self.model_supports_input_embeds = supports_input_embeds( + self.model) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -1650,15 +1701,28 @@ def execute_model( model_forward_start.record() with set_forward_context(model_input.attn_metadata): - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + model_params = dict(input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs( + multi_modal_kwargs, device=self.device), + **seqlen_agnostic_kwargs) + if self.model_supports_input_embeds: + input_embeds = model_input.input_embeds + + input_embeds_masks = model_input.input_embeds_masks + if (input_embeds_masks is not None + and input_embeds_masks.all().item()): + input_embeds_masks = None + + model_params.update( + inputs_embeds=input_embeds, + inputs_embeds_masks=input_embeds_masks, + ) + + hidden_or_intermediate_states = model_executable(**model_params) if (self.observability_config is not None and self.observability_config.collect_model_forward_time):
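
The `supports_input_embeds` helper added in vllm/model_executor/models/interfaces.py decides at load time whether a model's `forward` accepts the new keywords by inspecting its signature via `supports_kw`. A standalone approximation of that check, using only `inspect` (the `accepts_kw` helper and `ToyModel` class below are illustrative names, not part of the patch):

    import inspect
    from typing import Optional

    import torch
    from torch import nn


    def accepts_kw(func, kw: str) -> bool:
        # True if `func` declares `kw` explicitly or accepts **kwargs.
        params = inspect.signature(func).parameters
        return kw in params or any(p.kind is inspect.Parameter.VAR_KEYWORD
                                   for p in params.values())


    class ToyModel(nn.Module):

        def forward(
            self,
            input_ids: torch.Tensor,
            inputs_embeds: Optional[torch.Tensor] = None,
            inputs_embeds_masks: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            return input_ids


    required = ("inputs_embeds", "inputs_embeds_masks")
    print(all(accepts_kw(ToyModel().forward, kw) for kw in required))  # True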
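The merge performed by `get_inputs_embeds` in vllm/model_executor/models/utils.py, when a mask is supplied, looks up the placeholder token ids and then overwrites the masked rows with the caller-provided embeddings. A minimal sketch of that merge in isolation; the vocabulary size, hidden size, and token ids are made up for illustration:

    import torch
    from torch import nn

    vocab_size, hidden_size = 100, 8
    embed_tokens = nn.Embedding(vocab_size, hidden_size)

    # Positions covered by caller-provided embeddings are marked True in the
    # mask; their token ids are placeholders and never reach the final result.
    input_ids = torch.tensor([3, 0, 0, 7, 0])
    inputs_embeds_masks = torch.tensor([False, True, True, False, True])
    inputs_embeds = torch.rand(int(inputs_embeds_masks.sum()), hidden_size)

    with torch.no_grad():
        hidden_states = embed_tokens(input_ids)
        # Same merge as in get_inputs_embeds: provided embeddings overwrite
        # the rows selected by the mask.
        hidden_states[inputs_embeds_masks] = inputs_embeds

    print(hidden_states.shape)  # torch.Size([5, 8])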
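With the updated `SequenceData.from_seqs` in vllm/sequence.py, `prompt_embeds` is keyword-only and surfaces through the new read-only `prompt_embeds` property. A small sketch mirroring the usage in tests/worker/test_model_runner.py; the sequence length and hidden size here are arbitrary:

    import torch

    from vllm.sequence import SequenceData

    seq_len, hidden_size = 4, 16

    # Token-backed prompt, as before this change.
    token_data = SequenceData.from_seqs(range(seq_len))

    # Embedding-backed prompt: prompt_embeds must be passed by keyword.
    embeds_data = SequenceData.from_seqs(
        range(seq_len),
        prompt_embeds=torch.rand(seq_len, hidden_size),
    )

    assert token_data.prompt_embeds is None
    assert embeds_data.prompt_embeds is not None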
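The `__getitem__` overloads added to `IntermediateTensors` only tighten the typing: a string key yields a single tensor, a slice yields a new `IntermediateTensors` over the sliced tensors. A hedged sketch of both access patterns, constructing the object directly with made-up shapes:

    import torch

    from vllm.sequence import IntermediateTensors

    tensors = IntermediateTensors({
        "hidden_states": torch.zeros(8, 16),
        "residual": torch.zeros(8, 16),
    })

    hidden = tensors["hidden_states"]   # a str key returns a torch.Tensor
    head = tensors[:4]                  # a slice returns IntermediateTensors

    assert isinstance(head, IntermediateTensors)
    assert head["hidden_states"].shape == (4, 16)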
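In vllm/worker/model_runner.py, `_compute_lens` emits one mask entry per prompt position (all-False for token-backed sequences, all-True for embedding-backed ones, whose token ids are zero placeholders), and `build` concatenates the per-sequence pieces across the batch. A standalone sketch of that bookkeeping, with made-up lengths and hidden size:

    import torch

    hidden_size = 16

    # Sequence 0 is token-backed (3 prompt tokens); sequence 1 is
    # embedding-backed (2 prompt tokens whose ids are placeholders).
    token_seq_mask = torch.zeros(3, dtype=torch.bool)
    embeds_seq_mask = torch.ones(2, dtype=torch.bool)
    embeds_seq_embeds = torch.rand(2, hidden_size)

    # The builder concatenates the per-sequence pieces across the batch.
    input_embeds_masks = torch.cat([token_seq_mask, embeds_seq_mask])  # (5,)
    input_embeds = torch.cat([embeds_seq_embeds])                      # (2, 16)

    # One provided embedding row per True mask position.
    assert int(input_embeds_masks.sum()) == input_embeds.shape[0]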
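Finally, `execute_model` drops the mask when every position in the batch is embedding-backed: an all-True mask carries no information beyond "use `inputs_embeds` as-is", and passing `None` instead routes `get_inputs_embeds` into the branch that skips the embedding lookup entirely. A minimal sketch of that check:

    import torch

    input_embeds_masks = torch.ones(5, dtype=torch.bool)

    # All positions already have embeddings, so the mask is replaced with None
    # before the model is called.
    if input_embeds_masks is not None and input_embeds_masks.all().item():
        input_embeds_masks = None

    print(input_embeds_masks)  # None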