From c618001ba5171202b4f94b459ddc66ae5d7ad6a3 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 21 Sep 2024 07:14:52 +0000 Subject: [PATCH 01/14] Rename input data types: - `LLMInputs` to `TokenInputs` (preprocessing), `SingletonInputs` (sequence) or `DecoderOnlyInputs` (VLM processors) - `EncoderDecoderLLMInputs` to `EncoderDecoderInputs` - `is_valid_encoder_decoder_llm_inputs` to `is_encoder_decoder_inputs` --- .../input_processing/model_inputs_index.rst | 2 +- .../decoder_only/vision_language/test_qwen.py | 12 ++-- vllm/engine/llm_engine.py | 10 ++-- vllm/inputs/__init__.py | 15 +++-- vllm/inputs/data.py | 46 +++++++++------- vllm/inputs/parse.py | 12 ++-- vllm/inputs/preprocess.py | 34 ++++++------ vllm/inputs/registry.py | 13 +++-- vllm/model_executor/models/blip.py | 18 +++--- vllm/model_executor/models/blip2.py | 18 +++--- vllm/model_executor/models/chameleon.py | 19 ++++--- vllm/model_executor/models/clip.py | 18 +++--- vllm/model_executor/models/fuyu.py | 18 +++--- vllm/model_executor/models/internvl.py | 18 +++--- vllm/model_executor/models/llava.py | 12 ++-- vllm/model_executor/models/llava_next.py | 13 +++-- .../model_executor/models/llava_next_video.py | 18 +++--- vllm/model_executor/models/minicpmv.py | 16 +++--- vllm/model_executor/models/paligemma.py | 19 ++++--- vllm/model_executor/models/phi3v.py | 20 +++---- vllm/model_executor/models/pixtral.py | 12 ++-- vllm/model_executor/models/qwen.py | 22 ++++---- vllm/model_executor/models/qwen2_vl.py | 20 +++---- vllm/model_executor/models/siglip.py | 14 ++--- vllm/model_executor/models/ultravox.py | 18 +++--- vllm/sequence.py | 55 ++++++++++++------- 26 files changed, 262 insertions(+), 230 deletions(-) diff --git a/docs/source/dev/input_processing/model_inputs_index.rst b/docs/source/dev/input_processing/model_inputs_index.rst index 5d895837590b..f0ec1fea15dd 100644 --- a/docs/source/dev/input_processing/model_inputs_index.rst +++ b/docs/source/dev/input_processing/model_inputs_index.rst @@ -25,7 +25,7 @@ Module Contents LLM Engine Inputs ----------------- -.. autoclass:: vllm.inputs.LLMInputs +.. 
autoclass:: vllm.inputs.DecoderOnlyInputs :members: :show-inheritance: diff --git a/tests/models/decoder_only/vision_language/test_qwen.py b/tests/models/decoder_only/vision_language/test_qwen.py index e4f79092b760..ce82302ae7ac 100644 --- a/tests/models/decoder_only/vision_language/test_qwen.py +++ b/tests/models/decoder_only/vision_language/test_qwen.py @@ -6,7 +6,7 @@ from PIL.Image import Image from vllm.config import ModelConfig -from vllm.inputs import InputContext, LLMInputs +from vllm.inputs import DecoderOnlyInputs, InputContext from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size @@ -98,12 +98,12 @@ def test_input_processor_valid_mm_data(input_processor_for_qwen, """Happy cases for image inputs to Qwen's multimodal input processor.""" prompt = "".join( [f"Picture {num}: \n" for num in range(1, num_images + 1)]) - inputs = LLMInputs( + inputs = DecoderOnlyInputs( prompt=prompt, # When processing multimodal data for a multimodal model, the qwen # input processor will overwrite the provided prompt_token_ids with # the image prompts - prompt_token_ids=None, + prompt_token_ids=[], multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)}, ) proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs) @@ -161,9 +161,9 @@ def test_input_processor_invalid_mm_data(input_processor_for_qwen, trust_remote_code=True) prompt = "Picture 1: \n" prompt_token_ids = tokenizer.encode(prompt) - inputs = LLMInputs(prompt=prompt, - prompt_token_ids=prompt_token_ids, - multi_modal_data=mm_data) + inputs = DecoderOnlyInputs(prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=mm_data) # Should fail since we have too many or too few dimensions for embeddings with pytest.raises(ValueError): input_processor_for_qwen(qwen_vl_context, inputs) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 39409757d381..568917b6707e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -28,8 +28,8 @@ from vllm.executor.executor_base import ExecutorBase from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptType) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, + EncoderDecoderInputs, InputRegistry, PromptType) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -616,7 +616,7 @@ def _verify_args(self) -> None: def _add_processed_request( self, request_id: str, - processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs], + processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], @@ -1695,8 +1695,8 @@ def is_encoder_decoder_model(self): def is_embedding_model(self): return self.model_config.is_embedding_model - def _validate_model_inputs(self, inputs: Union[LLMInputs, - EncoderDecoderLLMInputs]): + def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, + EncoderDecoderInputs]): if self.is_encoder_decoder_model(): prompt_ids = inputs.get("encoder_prompt_token_ids") else: diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index ba1bef1ab3ec..29af87c3aeef 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,8 @@ -from .data import (EncoderDecoderLLMInputs, 
ExplicitEncoderDecoderPrompt, - LLMInputs, PromptType, SingletonPrompt, TextPrompt, - TokensPrompt, build_explicit_enc_dec_prompt, - to_enc_dec_tuple_list, zip_enc_dec_prompts) +from .data import (DecoderOnlyInputs, EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs, + SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, + build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, + zip_enc_dec_prompts) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -19,8 +20,10 @@ "PromptType", "SingletonPrompt", "ExplicitEncoderDecoderPrompt", - "LLMInputs", - "EncoderDecoderLLMInputs", + "TokenInputs", + "SingletonInputs", + "DecoderOnlyInputs", + "EncoderDecoderInputs", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index e072bb65714b..3b1a637c9143 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -35,7 +35,7 @@ class TokensPrompt(TypedDict): SingletonPrompt = Union[str, TextPrompt, TokensPrompt] """ -Set of possible schemas for a single LLM input: +Set of possible schemas for a single prompt: - A text prompt (:class:`str` or :class:`TextPrompt`) - A tokenized prompt (:class:`TokensPrompt`) @@ -46,7 +46,7 @@ class TokensPrompt(TypedDict): the user desires to express both the encoder & decoder prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt` -A prompt of type :class:`SingletonPromptType` may be employed +A prompt of type :class:`SingletonPrompt` may be employed as (1) input to a decoder-only model, (2) input to the encoder of an encoder/decoder model, in the scenario where the decoder-prompt is not specified explicitly, or @@ -66,22 +66,21 @@ class TokensPrompt(TypedDict): # TODO: Make fields ReadOnly once mypy supports it class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): - """Represents an encoder/decoder model input prompt, - comprising an explicit encoder prompt and a - decoder prompt. + """ + Represents an encoder/decoder model input prompt, + comprising an explicit encoder prompt and a decoder prompt. - The encoder and decoder prompts, respectively, - may formatted according to any of the - :class:`SingletonPromptType` schemas, and are not - required to have the same schema. + The encoder and decoder prompts, respectively, may be formatted + according to any of the :class:`SingletonPrompt` schemas, + and are not required to have the same schema. Only the encoder prompt may have multi-modal data. Note that an :class:`ExplicitEncoderDecoderPrompt` may not be used as an input to a decoder-only model, - and that the `encoder_prompt` and `decoder_prompt` + and that the :code:`encoder_prompt` and :code:`decoder_prompt` fields of this data structure themselves must be - :class:`SingletonPromptType` instances. + :class:`SingletonPrompt` instances. """ encoder_prompt: _T1_co @@ -101,13 +100,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): """ -class LLMInputs(TypedDict): - """ - The inputs in :class:`~vllm.LLMEngine` before they are - passed to the model executor. - - This specifies the data required for decoder-only models. 
- """ +class TokenInputs(TypedDict): + """Represents token-based inputs.""" prompt_token_ids: List[int] """The token IDs of the prompt.""" @@ -123,7 +117,21 @@ class LLMInputs(TypedDict): """ -class EncoderDecoderLLMInputs(LLMInputs): +SingletonInputs = TokenInputs +""" +A processed :class:`SingletonPrompt` which can be passed to +:class:`vllm.sequence.Sequence`. +""" + +DecoderOnlyInputs = TokenInputs +""" +The inputs in :class:`~vllm.LLMEngine` before they are +passed to the model executor. +This specifies the data required for decoder-only models. +""" + + +class EncoderDecoderInputs(TokenInputs): """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index e5fa1e418427..7f9152dd3347 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -4,9 +4,9 @@ from vllm.utils import is_list_of -from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptType, SingletonPrompt, TextPrompt, - TokensPrompt) +from .data import (DecoderOnlyInputs, EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt, + TextPrompt, TokensPrompt) class ParsedText(TypedDict): @@ -100,7 +100,7 @@ def is_explicit_encoder_decoder_prompt( return isinstance(prompt, dict) and "encoder_prompt" in prompt -def is_valid_encoder_decoder_llm_inputs( - inputs: Union[LLMInputs, EncoderDecoderLLMInputs], -) -> TypeIs[EncoderDecoderLLMInputs]: +def is_encoder_decoder_inputs( + inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], +) -> TypeIs[EncoderDecoderInputs]: return "encoder_prompt_token_ids" in inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 1f1b048d37e9..391651e9474f 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -9,7 +9,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType, +from .data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType, SingletonPrompt) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt @@ -291,7 +291,7 @@ def _build_enc_dec_llm_inputs( self, encoder_comps: PromptComponents, decoder_comps: DecoderPromptComponents, - ) -> EncoderDecoderLLMInputs: + ) -> EncoderDecoderInputs: encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps @@ -302,7 +302,7 @@ def _build_enc_dec_llm_inputs( decoder_prompt_ids = ( self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) - return EncoderDecoderLLMInputs( + return EncoderDecoderInputs( prompt_token_ids=decoder_prompt_ids, prompt=decoder_prompt, encoder_prompt_token_ids=encoder_prompt_ids, @@ -313,11 +313,11 @@ def _process_encoder_decoder_prompt( self, prompt: PromptType, request_id: str, - ) -> EncoderDecoderLLMInputs: + ) -> EncoderDecoderInputs: ''' For encoder/decoder models only: Process an input prompt into an - :class:`EncoderDecoderLLMInputs` instance. + :class:`EncoderDecoderInputs` instance. 
There are two types of input prompts: singleton prompts which carry only the @@ -344,7 +344,7 @@ def _process_encoder_decoder_prompt( Returns: - * :class:`EncoderDecoderLLMInputs` instance + * :class:`EncoderDecoderInputs` instance ''' encoder_comps: PromptComponents @@ -377,7 +377,7 @@ async def _process_encoder_decoder_prompt_async( self, prompt: PromptType, request_id: str, - ) -> EncoderDecoderLLMInputs: + ) -> EncoderDecoderInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents @@ -413,15 +413,15 @@ def _build_decoder_only_llm_inputs( self, prompt_comps: PromptComponents, prompt_adapter_request: Optional[PromptAdapterRequest], - ) -> LLMInputs: + ) -> DecoderOnlyInputs: prompt, prompt_token_ids, multi_modal_data = prompt_comps prompt_token_ids = self._apply_prompt_adapter( prompt_token_ids, prompt_adapter_request=prompt_adapter_request) - return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=prompt, - multi_modal_data=multi_modal_data) + return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids, + prompt=prompt, + multi_modal_data=multi_modal_data) def _process_decoder_only_prompt( self, @@ -429,10 +429,10 @@ def _process_decoder_only_prompt( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: + ) -> DecoderOnlyInputs: ''' For decoder-only models: - Process an input prompt into an :class:`LLMInputs` instance. + Process an input prompt into an :class:`DecoderOnlyInputs` instance. Arguments: @@ -443,7 +443,7 @@ def _process_decoder_only_prompt( Returns: - * :class:`LLMInputs` instance + * :class:`DecoderOnlyInputs` instance ''' prompt_comps = self._extract_prompt_components( @@ -463,7 +463,7 @@ async def _process_decoder_only_prompt_async( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: + ) -> DecoderOnlyInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" prompt_comps = await self._extract_prompt_components_async( prompt, @@ -482,7 +482,7 @@ def preprocess( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]: """Preprocess the input prompt.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of @@ -510,7 +510,7 @@ async def preprocess_async( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]: """Async version of :meth:`preprocess`.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 2df61a914962..ac0c23b2e1ae 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger -from .data import LLMInputs +from .data import DecoderOnlyInputs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -93,7 +93,7 @@ def __getitem__(self, key: str) -> int: raise KeyError(msg) from exc -InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs] +InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs] """Preprocess the inputs to the 
model.""" @@ -200,8 +200,11 @@ def dummy_data_for_profiling( return seq_data, mm_data - def _default_input_processor(self, ctx: InputContext, - inputs: LLMInputs) -> LLMInputs: + def _default_input_processor( + self, + ctx: InputContext, + inputs: DecoderOnlyInputs, + ) -> DecoderOnlyInputs: """The default input processor is a no-op.""" return inputs @@ -230,7 +233,7 @@ def wrapper(model_cls: N) -> N: return wrapper def process_input(self, model_config: "ModelConfig", - inputs: LLMInputs) -> LLMInputs: + inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: """ Apply an input processor to an instance of model inputs. diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index e943427eda8e..0d1db14f363f 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -10,7 +10,7 @@ from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import LLMInputs +from vllm.inputs import DecoderOnlyInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -88,14 +88,14 @@ def dummy_image_for_blip( def input_processor_for_blip( model_config: ModelConfig, hf_config: Union[BlipVisionConfig, Blip2VisionConfig], - llm_inputs: LLMInputs, + inputs: DecoderOnlyInputs, *, image_token_id: int, image_feature_size_override: Optional[int] = None, ): - multi_modal_data = llm_inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs tokenizer = cached_get_tokenizer(model_config.tokenizer) @@ -106,16 +106,16 @@ def input_processor_for_blip( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return DecoderOnlyInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 37fabf3f3f9a..81fdd8fb7c3f 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -8,7 +8,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -455,10 +455,10 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int, raise NotImplementedError(msg) -def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in 
multi_modal_data: - return llm_inputs + return inputs hf_config = ctx.get_hf_config(Blip2Config) image_feature_size = get_blip2_image_feature_size(hf_config) @@ -466,15 +466,15 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): # The original model places image tokens at the front # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514 new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size - new_token_ids += llm_inputs["prompt_token_ids"] + new_token_ids += inputs["prompt_token_ids"] - new_prompt = llm_inputs.get("prompt") + new_prompt = inputs.get("prompt") if new_prompt is not None: new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return DecoderOnlyInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) @MULTIMODAL_REGISTRY.register_image_input_mapper() diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 51a61485caf6..cf67a26e2842 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -11,7 +11,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -107,7 +107,8 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, return seq_data, mm_data -def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_chameleon(ctx: InputContext, + inputs: DecoderOnlyInputs): """ Processing input prompt to insert required tokens for image placeholder. 
@@ -115,16 +116,16 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58 """ # noqa - multi_modal_data = llm_inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer) new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID, repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH, pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID, @@ -138,9 +139,9 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): new_token_ids += [CHAMELEON_SEP_TOKEN_ID] # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return DecoderOnlyInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) class ChameleonLayerNorm(nn.LayerNorm): diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index a7754f70e278..0a260c107ef7 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -10,7 +10,7 @@ from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import LLMInputs +from vllm.inputs import DecoderOnlyInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -87,14 +87,14 @@ def dummy_image_for_clip( def input_processor_for_clip( model_config: ModelConfig, hf_config: CLIPVisionConfig, - llm_inputs: LLMInputs, + inputs: DecoderOnlyInputs, *, image_token_id: int, image_feature_size_override: Optional[Union[int, List[int]]] = None, ): - multi_modal_data = llm_inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs tokenizer = cached_get_tokenizer(model_config.tokenizer) @@ -111,16 +111,16 @@ def input_processor_for_clip( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return DecoderOnlyInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index beeae1422957..0aa005746353 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -27,7 +27,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, 
InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig @@ -153,10 +153,10 @@ def _fuyu_image_preprocess(image_processor: FuyuImageProcessor, return model_image_input -def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config image_data = multi_modal_data["image"] @@ -180,8 +180,8 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): raise TypeError(f"Invalid image type: {type(image_data)}") # process prompts - prompt = llm_inputs.get("prompt") - prompt_token_ids = llm_inputs["prompt_token_ids"] + prompt = inputs.get("prompt") + prompt_token_ids = inputs["prompt_token_ids"] tokenizer = cached_get_tokenizer(model_config.model) # dim0 is batch_size, dim1 is subseq_size which will always be 1 image_input_ids: List[List[ @@ -194,9 +194,9 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[ 1:] + boa_token - return LLMInputs(prompt=new_prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=new_multi_modal_data) + return DecoderOnlyInputs(prompt=new_prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=new_multi_modal_data) def input_mapper_for_fuyu(ctx: InputContext, data: object): diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 507d7014714a..91222089e59b 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -18,7 +18,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -193,10 +193,10 @@ def get_max_internvl_image_tokens(ctx: InputContext): return num_patches * max_dynamic_patch -def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_internvl(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config hf_config = ctx.get_hf_config() @@ -234,8 +234,8 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) - prompt = llm_inputs.get("prompt") - prompt_token_ids = llm_inputs["prompt_token_ids"] + prompt = inputs.get("prompt") + prompt_token_ids = inputs["prompt_token_ids"] if prompt is None: prompt = tokenizer.decode(prompt_token_ids) @@ -248,9 +248,9 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: 
LLMInputs): new_prompt = new_prompt.replace('', image_prompt, 1) new_prompt_token_ids = tokenizer.encode(new_prompt) - return LLMInputs(prompt=prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=multi_modal_data) + return DecoderOnlyInputs(prompt=prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data) def input_mapper_for_internvl(ctx: InputContext, data: object): diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 7a6c991fb133..e39b46d527e4 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -9,7 +9,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -126,10 +126,10 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, raise NotImplementedError(msg) -def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config hf_config = ctx.get_hf_config(LlavaConfig) @@ -152,7 +152,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): return input_processor_for_clip( model_config, vision_config, - llm_inputs, + inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) @@ -160,7 +160,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): return input_processor_for_siglip( model_config, vision_config, - llm_inputs, + inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index d550a249ee82..491c9d116ac3 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -12,7 +12,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -210,10 +210,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, raise NotImplementedError(msg) -def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_llava_next(ctx: InputContext, + inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config hf_config = ctx.get_hf_config(LlavaNextConfig) @@ -248,7 +249,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): return input_processor_for_clip( model_config, vision_config, - 
llm_inputs, + inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) @@ -256,7 +257,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): return input_processor_for_siglip( model_config, vision_config, - llm_inputs, + inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 7fe85e5e4ab3..69622ae8bb27 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -11,7 +11,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization.base_config import ( @@ -144,10 +144,10 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int, def input_processor_for_llava_next_video(ctx: InputContext, - llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") + inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "video" not in multi_modal_data: - return llm_inputs + return inputs video_data = multi_modal_data["video"] model_config = ctx.model_config @@ -165,15 +165,15 @@ def input_processor_for_llava_next_video(ctx: InputContext, new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=hf_config.video_token_index, repeat_count=video_feature_size, ) - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return DecoderOnlyInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) elif is_list_of(video_data, np.ndarray): raise NotImplementedError( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 5579205832aa..e5eded71b6cd 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -36,7 +36,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -277,10 +277,10 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, return seq_data, mm_data -def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config version = get_version_by_config(model_config.hf_config) tokenizer = cached_get_tokenizer(model_config.tokenizer, @@ -294,9 +294,9 @@ def get_placeholder(image_size: Tuple[int, int], num_image: 
int): return image_processor. \ get_slice_image_placeholder(image_size, num_image) - prompt = llm_inputs.get("prompt") + prompt = inputs.get("prompt") if prompt is None: - token_ids = llm_inputs.get("prompt_token_ids") + token_ids = inputs.get("prompt_token_ids") prompt = tokenizer.decode(token_ids) pattern = "(./)" @@ -320,12 +320,12 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int): new_prompt = "".join(new_prompt_chunks) new_token_ids = tokenizer.encode(new_prompt) - llm_inputs = LLMInputs( + inputs = DecoderOnlyInputs( prompt_token_ids=new_token_ids, prompt=new_prompt, multi_modal_data=multi_modal_data, ) - return llm_inputs + return inputs class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 5fd39b5e35be..64adbe137ccb 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -8,7 +8,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -71,7 +71,8 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, return seq_data, mm_data -def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_paligemma(ctx: InputContext, + inputs: DecoderOnlyInputs): """ The correct prompt format needs to be: @@ -80,9 +81,9 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55 """ # noqa - multi_modal_data = llm_inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config hf_config = ctx.get_hf_config(PaliGemmaConfig) @@ -94,8 +95,8 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): image_token_str_pad = image_token_str * image_feature_size image_token_ids_pad = [hf_config.image_token_index] * image_feature_size - orig_prompt = llm_inputs.get("prompt") - orig_prompt_ids = llm_inputs.get("prompt_token_ids") + orig_prompt = inputs.get("prompt") + orig_prompt_ids = inputs.get("prompt_token_ids") if orig_prompt is not None and image_token_str in orig_prompt: logger.warning( @@ -109,9 +110,9 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return DecoderOnlyInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) class PaliGemmaMultiModalProjector(nn.Module): diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 6f17f571ccae..5af9b9630f86 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -27,7 +27,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, 
MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -398,10 +398,10 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig, return image_placeholder_token_ids -def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_phi3v(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config hf_config = ctx.get_hf_image_processor_config() @@ -430,7 +430,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): else: raise TypeError(f"Invalid image type: {type(image_data)}") - prompt = llm_inputs.get("prompt") + prompt = inputs.get("prompt") if prompt is None: # for async server request, we assume prompt and its token_ids is always # in correct format. And num_image_tags == len(image_data) always True. @@ -447,7 +447,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): image_data), "The count of image_placeholder not match image's" new_prompt = prompt - prompt_token_ids = llm_inputs["prompt_token_ids"].copy() + prompt_token_ids = inputs["prompt_token_ids"].copy() # masked place_holder with image token id for idx in image_idx: @@ -485,10 +485,10 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): new_token_ids.append(token_id) # NOTE: Create a defensive copy of the original inputs - llm_inputs = LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) - return llm_inputs + inputs = DecoderOnlyInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + return inputs @MULTIMODAL_REGISTRY.register_image_input_mapper() diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index aa92e62a30d3..0db7b23702ba 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -13,7 +13,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -101,8 +101,8 @@ def input_mapper_for_pixtral(ctx: InputContext, return MultiModalInputs({"images": images}) -def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is not None and "image" in multi_modal_data: tokenizer = cached_get_tokenizer( ctx.model_config.tokenizer, @@ -111,15 +111,15 @@ def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs): mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder image_token_id = mm_encoder.special_ids.img - if image_token_id not in 
llm_inputs['prompt_token_ids']: + if image_token_id not in inputs['prompt_token_ids']: raise ValueError( - (f"You've passed {llm_inputs=} without {image_token_id=}" + (f"You've passed {inputs=} without {image_token_id=}" " Make sure to process your input via mistral_common's" " tokenizer or pass a chat completion request. For more" " For more info, see: " "https://github.com/vllm-project/vllm/issues/8411.")) - return llm_inputs + return inputs @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e62a841485f2..bc8c85c28711 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -22,7 +22,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm @@ -649,30 +649,30 @@ def get_image_text(image_num: int, padding: bool) -> str: def input_processor_for_qwen(ctx: InputContext, - llm_inputs: LLMInputs) -> LLMInputs: + inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: """Processes the inputs, which may or may not be multimodal. Multimodal inputs will only be processed if the model has a "visual" component in its model config, otherwise they'll be ignored. Args: ctx: Context of the loaded model. - llm_inputs: LLM inputs which may have a multi_modal_data attribute. + inputs: LLM inputs which may have a multi_modal_data attribute. Returns: If the model is language only or not multimodal inputs were provided, - returns llm_inputs unmodified. Otherwise, processes the multimodal + returns inputs unmodified. Otherwise, processes the multimodal images / image embeddings and adds the fixed-length image placeholders. 
""" - multi_modal_data = llm_inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data") # Only process images if we have multimodal data and a visual config hf_config = ctx.get_hf_config() if (multi_modal_data is None or "image" not in multi_modal_data or not hasattr(hf_config, "visual")): - return llm_inputs + return inputs - prompt = llm_inputs.get("prompt") - prompt_token_ids = llm_inputs["prompt_token_ids"] + prompt = inputs.get("prompt") + prompt_token_ids = inputs["prompt_token_ids"] model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) @@ -709,9 +709,9 @@ def input_processor_for_qwen(ctx: InputContext, new_prompt_token_ids = tokenizer.encode(new_prompt) - return LLMInputs(prompt=new_prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=multi_modal_data) + return DecoderOnlyInputs(prompt=new_prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data) def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 1011c9256793..6f97e641fc74 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -47,7 +47,7 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU @@ -721,11 +721,11 @@ def _get_llm_num_vision_tokens( return llm_num_vision_tokens -def input_processor_for_qwen2_vl(ctx: InputContext, - llm_inputs: LLMInputs) -> LLMInputs: - multi_modal_data = llm_inputs.get("multi_modal_data", None) +def input_processor_for_qwen2_vl( + ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: + multi_modal_data = inputs.get("multi_modal_data", None) if multi_modal_data is None: - return llm_inputs + return inputs image_inputs = multi_modal_data.get("image", None) video_inputs = multi_modal_data.get("video", None) @@ -739,7 +739,7 @@ def input_processor_for_qwen2_vl(ctx: InputContext, # `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`. 
# # The following code is equivalent to: - # prompt = llm_inputs["prompt"] + # prompt = inputs["prompt"] # inputs = processor(text=[prompt], # images=image_inputs, # videos=video_inputs, @@ -747,9 +747,9 @@ def input_processor_for_qwen2_vl(ctx: InputContext, # return_tensors="pt") # prompt_token_ids = inputs["input_ids"][0].tolist() - prompt_token_ids = llm_inputs.get("prompt_token_ids", None) + prompt_token_ids = inputs.get("prompt_token_ids", None) if prompt_token_ids is None: - prompt = llm_inputs["prompt"] + prompt = inputs["prompt"] prompt_token_ids = processor.tokenizer( prompt, padding=True, @@ -814,9 +814,9 @@ def input_processor_for_qwen2_vl(ctx: InputContext, 1:]) prompt_token_ids = prompt_token_ids_with_video - return LLMInputs( + return DecoderOnlyInputs( prompt_token_ids=prompt_token_ids, - prompt=llm_inputs["prompt"], + prompt=inputs["prompt"], multi_modal_data=multi_modal_data, ) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 5b332fa1a24d..05aa1e573c4c 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -12,7 +12,7 @@ from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import LLMInputs +from vllm.inputs import DecoderOnlyInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -92,14 +92,14 @@ def dummy_image_for_siglip( def input_processor_for_siglip( model_config: ModelConfig, hf_config: SiglipVisionConfig, - llm_inputs: LLMInputs, + inputs: DecoderOnlyInputs, *, image_token_id: int, image_feature_size_override: Optional[Union[int, List[int]]] = None, ): - multi_modal_data = llm_inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs tokenizer = cached_get_tokenizer(model_config.tokenizer) @@ -116,14 +116,14 @@ def input_processor_for_siglip( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) # NOTE: Create a defensive copy of the original inputs - return LLMInputs( + return DecoderOnlyInputs( prompt_token_ids=new_token_ids, prompt=new_prompt, multi_modal_data=multi_modal_data, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 87f59f487f87..8740fd994bca 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -19,7 +19,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY -from vllm.inputs.data import LLMInputs +from vllm.inputs.data import DecoderOnlyInputs from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn @@ -155,10 +155,10 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): return MultiModalInputs({"audio_features": audio_features}) -def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = 
inputs.get("multi_modal_data") if multi_modal_data is None or "audio" not in multi_modal_data: - return llm_inputs + return inputs feature_extractor = whisper_feature_extractor(ctx) audios = multi_modal_data["audio"] @@ -190,16 +190,16 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN, repeat_count=audio_token_counts, ) # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return DecoderOnlyInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) class StackAudioFrames(nn.Module): diff --git a/vllm/sequence.py b/vllm/sequence.py index d8e54ff1fc70..82fb76a600a3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -13,7 +13,7 @@ import msgspec import torch -from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs +from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -21,12 +21,17 @@ from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if TYPE_CHECKING: - from vllm.inputs import LLMInputs + from vllm.inputs import SingletonInputs from vllm.multimodal.base import MultiModalDataDict VLLM_TOKEN_ID_ARRAY_TYPE = "l" +def array_full(token_id: int, count: int): + """:class:`array` equivalent of :func:`numpy.full`.""" + return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count + + # We use dataclass for now because it is used for # openai server output, and msgspec is not serializable. # TODO(sang): Fix it. @@ -172,21 +177,32 @@ class SequenceData(msgspec.Struct, @staticmethod def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData": + """ + Construct a :class:`SequenceData` instance by concatenating + prompt token sequences. + + Each tuple represents one token sequence, expressed in the form + :code:`(token_id, count)`. + """ if len(token_counts) == 0: return SequenceData.from_seqs([]) - arrs = [ - array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count - for token_id, count in token_counts - ] + prompt_token_ids_arr = reduce( + array.__iadd__, + (array_full(token_id, count) for token_id, count in token_counts), + ) - return SequenceData(reduce(array.__add__, arrs)) + return SequenceData(prompt_token_ids_arr) @staticmethod def from_seqs( prompt_token_ids: GenericSequence[int], output_token_ids: Optional[GenericSequence[int]] = None, ) -> "SequenceData": + """ + Construct a :class:`SequenceData` instance from prompt and output + token sequences. + """ prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids) @@ -360,14 +376,14 @@ def __repr__(self) -> str: class Sequence: """Stores the data, status, and block information of a sequence. - The sequence is constructed from the LLMInputs instance passed - in through the `inputs` constructor argument. + The sequence is constructed from the :code:`SingletonInputs` instance + passed in through the :code:`inputs` constructor argument. 
- For encoder/decoder models, LLMInputs encapsulates both a + For encoder/decoder models, SingletonInputs encapsulates both a decoder and encoder prompt, creating an ambiguity about which prompt to construct the sequence from. The `from_decoder_prompt` constructor argument signals whether to construct the Sequence - from the LLMInputs decoder prompt, or encoder prompt. + from the SingletonInputs decoder prompt, or encoder prompt. Args: seq_id: The ID of the sequence. @@ -377,16 +393,16 @@ class Sequence: eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. lora_request: LoRA request. prompt_adapter_request: Prompt Adapter request. - from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt - (True) or encoder prompt (False.) Must be True - for decoder-only model. + from_decoder_prompt: Construct Sequence from SingletonInputs decoder + prompt (True) or encoder prompt (False.) Must be + True for decoder-only model. """ def __init__( self, seq_id: int, - inputs: "LLMInputs", + inputs: "SingletonInputs", block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, @@ -402,19 +418,19 @@ def __init__( self.from_decoder_prompt = from_decoder_prompt # For decoder-only models, a Sequence is constructed - # from an LLMInputs instance (the `inputs` arg.) + # from an DecoderOnlyInputs instance (the `inputs` arg.) # # For encoder/decoder models the same `inputs` # instance could be utilized to construct either an # encoder sequence or a decoder sequence, because - # `LLMInputs` has both decoder- and encoder-oriented + # `DecoderOnlyInputs` has both decoder- and encoder-oriented # member variables (i.e. it encapsulates both an encoder # and a decoder prompt.) The decision of which type of sequence # to generate is determined by the `from_decoder_prompt` argument. # # When constructing a encoder sequence # (`from_decoder_prompt` False) it matters that - # the `LLMInputs` instance stored in `inputs` is valid + # the `DecoderOnlyInputs` instance stored in `inputs` is valid # in the sense that its encoder-related member variables are # populated; below, an exception is raised if this is # not the case. @@ -422,8 +438,7 @@ def __init__( # When constructing a decoder sequence (`from_decoder_prompt` True) # it does not matter whether `inputs` has its encoder-related # member variables populated. 
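For illustration only (this sketch is not part of the diff): the renamed helper simply checks for the encoder-specific key, so a plain decoder-only input fails the check while an encoder/decoder input passes it. The token IDs below are made up, and only fields shown elsewhere in this patch are used.

from vllm.inputs import DecoderOnlyInputs, EncoderDecoderInputs
from vllm.inputs.parse import is_encoder_decoder_inputs

# Decoder-only inputs carry just the (decoder) prompt token IDs.
dec_inputs = DecoderOnlyInputs(prompt_token_ids=[1, 2, 3])

# Encoder/decoder inputs additionally carry the encoder prompt token IDs,
# which is exactly the key the helper looks for.
enc_dec_inputs = EncoderDecoderInputs(
    prompt_token_ids=[1, 2, 3],
    encoder_prompt_token_ids=[4, 5, 6],
)

assert not is_encoder_decoder_inputs(dec_inputs)
assert is_encoder_decoder_inputs(enc_dec_inputs)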
- if not (from_decoder_prompt - or is_valid_encoder_decoder_llm_inputs(inputs)): + if not (from_decoder_prompt or is_encoder_decoder_inputs(inputs)): raise ValueError("Cannot extract encoder input prompt from " f"invalid input {inputs}; did you forget the " "encoder input prompt fields?") From d4a5c21b1a94ce47f471c420b2f6ab71f4d37ecf Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 21 Sep 2024 07:18:10 +0000 Subject: [PATCH 02/14] Rename `from_token_counts` to `from_prompt_token_counts` --- vllm/inputs/registry.py | 2 +- vllm/model_executor/models/blip.py | 2 +- vllm/model_executor/models/blip2.py | 2 +- vllm/model_executor/models/chameleon.py | 2 +- vllm/model_executor/models/clip.py | 2 +- vllm/model_executor/models/minicpmv.py | 2 +- vllm/model_executor/models/pixtral.py | 2 +- vllm/model_executor/models/qwen.py | 2 +- vllm/model_executor/models/qwen2_vl.py | 2 +- vllm/model_executor/models/siglip.py | 2 +- vllm/sequence.py | 3 ++- 11 files changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index ac0c23b2e1ae..6c1c5baca808 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -125,7 +125,7 @@ def _default_dummy_data_factory( # Avoid circular import from vllm.sequence import SequenceData - dummy_seq_data = SequenceData.from_token_counts((0, seq_len)) + dummy_seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) dummy_multi_modal_data = None return dummy_seq_data, dummy_multi_modal_data diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 0d1db14f363f..b6313840f9d7 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -62,7 +62,7 @@ def dummy_seq_data_for_blip( else: image_feature_size = image_feature_size_override - return SequenceData.from_token_counts( + return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), ) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 81fdd8fb7c3f..70fce0abbe81 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -427,7 +427,7 @@ def dummy_seq_data_for_blip2( else: image_feature_size = image_feature_size_override - return SequenceData.from_token_counts( + return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), ) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index cf67a26e2842..1f812e39b519 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -70,7 +70,7 @@ def dummy_seq_data_for_chameleon( else: image_feature_size = image_feature_size_override - return SequenceData.from_token_counts( + return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), ) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 0a260c107ef7..25102997aa9b 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -61,7 +61,7 @@ def dummy_seq_data_for_clip( else: image_feature_size = image_feature_size_override - return SequenceData.from_token_counts( + return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), ) diff --git 
a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index e5eded71b6cd..499539b97952 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -257,7 +257,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext): def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): - return SequenceData.from_token_counts((0, seq_len)) + return SequenceData.from_prompt_token_counts((0, seq_len)) def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int): diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 0db7b23702ba..27dcea46713d 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -61,7 +61,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, image_feature_size = (size**2) // (patch_size**2) num_image_tokens = image_feature_size * num_images - seq_data = SequenceData.from_token_counts( + seq_data = SequenceData.from_prompt_token_counts( (image_token_id, num_image_tokens), (0, seq_len - num_image_tokens), ) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index bc8c85c28711..f97e44d41d48 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -817,7 +817,7 @@ def dummy_data_for_qwen( # The presence of a visual config indicates this is a multimodal model. # If we don't have it, the model is considered an LLM for warmup purposes. if not hasattr(hf_config, "visual"): - seq_data = SequenceData.from_token_counts((0, seq_len)) + seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) mm_data = None return seq_data, mm_data diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 6f97e641fc74..c3091b8dc913 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -680,7 +680,7 @@ def dummy_data_for_qwen2_vl( hf_config = ctx.get_hf_config(Qwen2VLConfig) - dummy_seqdata = SequenceData.from_token_counts( + dummy_seqdata = SequenceData.from_prompt_token_counts( (hf_config.vision_start_token_id, 1), (hf_config.image_token_id, max_llm_image_tokens), (hf_config.vision_end_token_id, 1), diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 05aa1e573c4c..38ed524941a5 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -66,7 +66,7 @@ def dummy_seq_data_for_siglip( else: image_feature_size = image_feature_size_override - return SequenceData.from_token_counts( + return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), ) diff --git a/vllm/sequence.py b/vllm/sequence.py index 82fb76a600a3..bb5bf60ddbf4 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -176,7 +176,8 @@ class SequenceData(msgspec.Struct, _mrope_position_delta: Optional[int] = None @staticmethod - def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData": + def from_prompt_token_counts( + *token_counts: Tuple[int, int]) -> "SequenceData": """ Construct a :class:`SequenceData` instance by concatenating prompt token sequences. 
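
For reference, a minimal usage sketch of the renamed `SequenceData.from_prompt_token_counts` helper as the per-model `dummy_seq_data_for_*` factories above call it; the token id and sizes here are illustrative placeholders, not values taken from any real model config:

    from vllm.sequence import SequenceData

    # Illustrative numbers only: a 2048-token dummy prompt that reserves
    # 576 placeholder positions for a single image.
    image_token_id = 32000      # assumed placeholder token id
    image_feature_size = 576
    seq_len = 2048

    # Each (token_id, count) pair is expanded and concatenated in order,
    # mirroring how the dummy-data factories above build their prompts.
    dummy_seq_data = SequenceData.from_prompt_token_counts(
        (image_token_id, image_feature_size),
        (0, seq_len - image_feature_size),
    )

    # 576 image placeholders followed by padding token id 0 up to seq_len.
    assert len(dummy_seq_data.prompt_token_ids) == seq_len
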
From 03923ea20824e4075596eb40522ff61ebbcd73df Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 23 Sep 2024 02:01:15 +0000 Subject: [PATCH 03/14] Cleanup --- vllm/model_executor/models/llava_onevision.py | 20 ++++++++----------- vllm/model_executor/models/ultravox.py | 3 +-- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 9099d4f88222..e1cefd84f6c6 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -14,11 +14,9 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.logger import init_logger +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -38,8 +36,6 @@ from .utils import (flatten_bn, group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) -logger = init_logger(__name__) - # Result in the max possible feature size (2x2 grid of 336x336px tiles) MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 @@ -253,7 +249,7 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, def input_processor_when_multimodal_input_image(ctx: InputContext, - llm_inputs: LLMInputs): + llm_inputs: DecoderOnlyInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs @@ -309,7 +305,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, def input_processor_when_multimodal_input_video(ctx: InputContext, - llm_inputs: LLMInputs): + llm_inputs: DecoderOnlyInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "video" not in multi_modal_data: return llm_inputs @@ -333,9 +329,9 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, repeat_count=video_feature_size, ) - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return DecoderOnlyInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) elif is_list_of(video_data, np.ndarray): raise NotImplementedError( @@ -346,7 +342,7 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, def input_processor_for_llava_onevision(ctx: InputContext, - llm_inputs: LLMInputs): + llm_inputs: DecoderOnlyInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or ("video" not in multi_modal_data and "image" not in multi_modal_data): diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 6e29d57d6f7f..f39d23e4f77b 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -22,8 +22,7 @@ from vllm.inputs.registry import InputContext from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from 
vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsMultiModal From eb00f713814b4b621eafae97f66553b07fd30f08 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 23 Sep 2024 02:15:44 +0000 Subject: [PATCH 04/14] Adopt `token_inputs` helper function --- .../decoder_only/vision_language/test_qwen.py | 10 +++++----- vllm/inputs/__init__.py | 3 ++- vllm/inputs/data.py | 16 ++++++++++++++++ vllm/model_executor/models/blip.py | 8 ++++---- vllm/model_executor/models/blip2.py | 9 +++++---- vllm/model_executor/models/chameleon.py | 9 +++++---- vllm/model_executor/models/clip.py | 8 ++++---- vllm/model_executor/models/fuyu.py | 9 +++++---- vllm/model_executor/models/internvl.py | 9 +++++---- vllm/model_executor/models/llava_next_video.py | 9 +++++---- vllm/model_executor/models/llava_onevision.py | 9 +++++---- vllm/model_executor/models/minicpmv.py | 6 +++--- vllm/model_executor/models/paligemma.py | 9 +++++---- vllm/model_executor/models/phi3v.py | 10 +++++----- vllm/model_executor/models/qwen.py | 9 +++++---- vllm/model_executor/models/qwen2_vl.py | 5 +++-- vllm/model_executor/models/siglip.py | 4 ++-- vllm/model_executor/models/ultravox.py | 8 ++++---- 18 files changed, 88 insertions(+), 62 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_qwen.py b/tests/models/decoder_only/vision_language/test_qwen.py index ce82302ae7ac..196dc8900ab4 100644 --- a/tests/models/decoder_only/vision_language/test_qwen.py +++ b/tests/models/decoder_only/vision_language/test_qwen.py @@ -6,7 +6,7 @@ from PIL.Image import Image from vllm.config import ModelConfig -from vllm.inputs import DecoderOnlyInputs, InputContext +from vllm.inputs import InputContext, token_inputs from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size @@ -98,7 +98,7 @@ def test_input_processor_valid_mm_data(input_processor_for_qwen, """Happy cases for image inputs to Qwen's multimodal input processor.""" prompt = "".join( [f"Picture {num}: \n" for num in range(1, num_images + 1)]) - inputs = DecoderOnlyInputs( + inputs = token_inputs( prompt=prompt, # When processing multimodal data for a multimodal model, the qwen # input processor will overwrite the provided prompt_token_ids with @@ -161,9 +161,9 @@ def test_input_processor_invalid_mm_data(input_processor_for_qwen, trust_remote_code=True) prompt = "Picture 1: \n" prompt_token_ids = tokenizer.encode(prompt) - inputs = DecoderOnlyInputs(prompt=prompt, - prompt_token_ids=prompt_token_ids, - multi_modal_data=mm_data) + inputs = token_inputs(prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=mm_data) # Should fail since we have too many or too few dimensions for embeddings with pytest.raises(ValueError): input_processor_for_qwen(qwen_vl_context, inputs) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 29af87c3aeef..9bc33c271c96 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -2,7 +2,7 @@ ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, - zip_enc_dec_prompts) + token_inputs, zip_enc_dec_prompts) from .registry import InputContext, 
InputRegistry INPUT_REGISTRY = InputRegistry() @@ -21,6 +21,7 @@ "SingletonPrompt", "ExplicitEncoderDecoderPrompt", "TokenInputs", + "token_inputs", "SingletonInputs", "DecoderOnlyInputs", "EncoderDecoderInputs", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 3b1a637c9143..d0d84799d50d 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -117,6 +117,22 @@ class TokenInputs(TypedDict): """ +def token_inputs( + prompt_token_ids: List[int], + prompt: Optional[str] = None, + multi_modal_data: Optional["MultiModalDataDict"] = None, +) -> TokenInputs: + """Construct :class:`TokenInputs` from optional values.""" + inputs = TokenInputs(prompt_token_ids=prompt_token_ids) + + if prompt is not None: + inputs["prompt"] = prompt + if multi_modal_data is not None: + inputs["multi_modal_data"] = multi_modal_data + + return inputs + + SingletonInputs = TokenInputs """ A processed :class:`SingletonPrompt` which can be passed to diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index dbf1c70b1de3..778162dd63ca 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -10,7 +10,7 @@ from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import DecoderOnlyInputs +from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -114,9 +114,9 @@ def input_processor_for_blip( ) # NOTE: Create a defensive copy of the original inputs - return DecoderOnlyInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index a04a574e3f06..91b0c46fc79f 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -8,7 +8,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -466,9 +467,9 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs): if new_prompt is not None: new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt - return DecoderOnlyInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) @MULTIMODAL_REGISTRY.register_image_input_mapper() diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 509e30187355..1d85f9dbc3d1 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -11,7 +11,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import 
get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -136,9 +137,9 @@ def input_processor_for_chameleon(ctx: InputContext, new_token_ids += [CHAMELEON_SEP_TOKEN_ID] # NOTE: Create a defensive copy of the original inputs - return DecoderOnlyInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) class ChameleonLayerNorm(nn.LayerNorm): diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 7430e25bb470..7b0981d611b2 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -11,7 +11,7 @@ from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import DecoderOnlyInputs +from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -137,9 +137,9 @@ def input_processor_for_clip( ) # NOTE: Create a defensive copy of the original inputs - return DecoderOnlyInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 439bb1d2504e..4f3a8f92b3ee 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -27,7 +27,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -191,9 +192,9 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[ 1:] + boa_token - return DecoderOnlyInputs(prompt=new_prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=new_multi_modal_data) + return token_inputs(prompt=new_prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=new_multi_modal_data) def input_mapper_for_fuyu(ctx: InputContext, data: object): diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 99c168cdc1a1..0db94c800cab 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -17,7 +17,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import 
(INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -247,9 +248,9 @@ def input_processor_for_internvl(ctx: InputContext, inputs: DecoderOnlyInputs): new_prompt = new_prompt.replace('', image_prompt, 1) new_prompt_token_ids = tokenizer.encode(new_prompt) - return DecoderOnlyInputs(prompt=prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=multi_modal_data) + return token_inputs(prompt=prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data) def input_mapper_for_internvl(ctx: InputContext, data: object): diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index e84bb817c8f6..e4dfcf2ec2f1 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -10,7 +10,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -167,9 +168,9 @@ def input_processor_for_llava_next_video(ctx: InputContext, repeat_count=video_feature_size, ) - return DecoderOnlyInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) elif is_list_of(video_data, np.ndarray): raise NotImplementedError( diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index e1cefd84f6c6..3b28f130eae2 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -14,7 +14,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -329,9 +330,9 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, repeat_count=video_feature_size, ) - return DecoderOnlyInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) elif is_list_of(video_data, np.ndarray): raise NotImplementedError( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 7e6400414aaa..06edf55b8cf7 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -36,7 +36,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from 
vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -317,12 +318,11 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int): new_prompt = "".join(new_prompt_chunks) new_token_ids = tokenizer.encode(new_prompt) - inputs = DecoderOnlyInputs( + return token_inputs( prompt_token_ids=new_token_ids, prompt=new_prompt, multi_modal_data=multi_modal_data, ) - return inputs class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 110a473bf086..f85b9a43315d 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -7,7 +7,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -109,9 +110,9 @@ def input_processor_for_paligemma(ctx: InputContext, new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline # NOTE: Create a defensive copy of the original inputs - return DecoderOnlyInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) class PaliGemmaMultiModalProjector(nn.Module): diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 5af9b9630f86..00bd5b326d2b 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -27,7 +27,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -485,10 +486,9 @@ def input_processor_for_phi3v(ctx: InputContext, inputs: DecoderOnlyInputs): new_token_ids.append(token_id) # NOTE: Create a defensive copy of the original inputs - inputs = DecoderOnlyInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) - return inputs + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) @MULTIMODAL_REGISTRY.register_image_input_mapper() diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index f97e44d41d48..fb15780302e2 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -22,7 +22,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from 
vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm @@ -709,9 +710,9 @@ def input_processor_for_qwen(ctx: InputContext, new_prompt_token_ids = tokenizer.encode(new_prompt) - return DecoderOnlyInputs(prompt=new_prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=multi_modal_data) + return token_inputs(prompt=new_prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data) def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index c3091b8dc913..25336d4bce30 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -47,7 +47,8 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU @@ -814,7 +815,7 @@ def input_processor_for_qwen2_vl( 1:]) prompt_token_ids = prompt_token_ids_with_video - return DecoderOnlyInputs( + return token_inputs( prompt_token_ids=prompt_token_ids, prompt=inputs["prompt"], multi_modal_data=multi_modal_data, diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 311dcfcb2643..4d14897d66a8 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -13,7 +13,7 @@ from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import DecoderOnlyInputs +from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -142,7 +142,7 @@ def input_processor_for_siglip( ) # NOTE: Create a defensive copy of the original inputs - return DecoderOnlyInputs( + return token_inputs( prompt_token_ids=new_token_ids, prompt=new_prompt, multi_modal_data=multi_modal_data, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index f39d23e4f77b..ecd9e8faea23 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -18,7 +18,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY -from vllm.inputs.data import DecoderOnlyInputs +from vllm.inputs.data import DecoderOnlyInputs, token_inputs from vllm.inputs.registry import InputContext from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm @@ -193,9 +193,9 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): ) # NOTE: Create a defensive copy of the original inputs - return DecoderOnlyInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) class StackAudioFrames(nn.Module): From b4f4cc12c4bc6d4dbc49aba356cc38d8a067e4b4 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 23 Sep 2024 
02:16:36 +0000 Subject: [PATCH 05/14] Cleanup --- vllm/model_executor/models/llava_onevision.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 3b28f130eae2..1041c956be7f 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -250,10 +250,10 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, def input_processor_when_multimodal_input_image(ctx: InputContext, - llm_inputs: DecoderOnlyInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") + inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config hf_config = ctx.get_hf_config(LlavaOnevisionConfig) @@ -288,7 +288,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, return input_processor_for_clip( model_config, vision_config, - llm_inputs, + inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) @@ -296,7 +296,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, return input_processor_for_siglip( model_config, vision_config, - llm_inputs, + inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) @@ -306,10 +306,10 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, def input_processor_when_multimodal_input_video(ctx: InputContext, - llm_inputs: DecoderOnlyInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") + inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "video" not in multi_modal_data: - return llm_inputs + return inputs video_data = multi_modal_data["video"] model_config = ctx.model_config @@ -324,8 +324,8 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=hf_config.video_token_index, repeat_count=video_feature_size, ) @@ -343,15 +343,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, def input_processor_for_llava_onevision(ctx: InputContext, - llm_inputs: DecoderOnlyInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") + inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or ("video" not in multi_modal_data and "image" not in multi_modal_data): - return llm_inputs + return inputs if "image" in multi_modal_data: - return input_processor_when_multimodal_input_image(ctx, llm_inputs) + return input_processor_when_multimodal_input_image(ctx, inputs) if "video" in multi_modal_data: - return input_processor_when_multimodal_input_video(ctx, llm_inputs) + return input_processor_when_multimodal_input_video(ctx, inputs) msg = "Unsupported multi data type" raise NotImplementedError(msg) From 5c0f091e132eaa39bc36bd35cf426af30d4902b7 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 25 Sep 2024 16:47:33 +0000 Subject: [PATCH 06/14] Add backward compatibility for `LLMInputs` --- vllm/inputs/__init__.py | 12 ++++++++++-- vllm/inputs/data.py | 12 ++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git 
a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 4f785850180c..79a0e9bde4fd 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -35,9 +35,9 @@ def __getattr__(name: str): - if name == "PromptInput": - import warnings + import warnings + if name == "PromptInput": msg = ("PromptInput has been renamed to PromptType. " "The original name will be removed in an upcoming version.") @@ -45,4 +45,12 @@ def __getattr__(name: str): return PromptType + if name == "LLMInputs": + msg = ("LLMInputs has been renamed to DecoderOnlyInputs. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return DecoderOnlyInputs + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 68626d97a58c..aa7338825c49 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -199,9 +199,9 @@ def to_enc_dec_tuple_list( def __getattr__(name: str): - if name == "PromptInput": - import warnings + import warnings + if name == "PromptInput": msg = ("PromptInput has been renamed to PromptType. " "The original name will be removed in an upcoming version.") @@ -209,4 +209,12 @@ def __getattr__(name: str): return PromptType + if name == "LLMInputs": + msg = ("LLMInputs has been renamed to DecoderOnlyInputs. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return DecoderOnlyInputs + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") From a8cb33949ad8bef098d387cab33bd1b349283bce Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 25 Sep 2024 16:50:38 +0000 Subject: [PATCH 07/14] Add backward compability for `EncoderDecoderLLMInputs` --- vllm/inputs/__init__.py | 9 +++++++++ vllm/inputs/data.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 79a0e9bde4fd..7b73922ddd2c 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -53,4 +53,13 @@ def __getattr__(name: str): return DecoderOnlyInputs + if name == "EncoderDecoderLLMInputs": + msg = ( + "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return EncoderDecoderInputs + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index aa7338825c49..41016b52d12f 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -217,4 +217,13 @@ def __getattr__(name: str): return DecoderOnlyInputs + if name == "EncoderDecoderLLMInputs": + msg = ( + "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. 
" + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return EncoderDecoderInputs + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") From ab5a9372749cd671e4060cd1fb009008ce00cc81 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 26 Sep 2024 00:36:47 +0800 Subject: [PATCH 08/14] rename PromptInputs and inputs with backward compatibility (#8760) --- benchmarks/benchmark_latency.py | 8 +- .../dev/multimodal/multimodal_index.rst | 2 +- .../dev/offline_inference/llm_inputs.rst | 2 +- docs/source/models/vlm.rst | 2 +- tests/async_engine/test_async_llm_engine.py | 8 +- tests/entrypoints/llm/test_encode.py | 34 ------ tests/entrypoints/llm/test_generate.py | 37 ------ tests/mq_llm_engine/test_error_handling.py | 12 +- tests/mq_llm_engine/utils.py | 2 +- vllm/__init__.py | 4 +- vllm/engine/async_llm_engine.py | 110 +++++++++++++++--- vllm/engine/llm_engine.py | 52 +++++++-- vllm/engine/multiprocessing/__init__.py | 61 +++++++++- vllm/engine/multiprocessing/client.py | 95 ++++++++++++--- vllm/engine/multiprocessing/engine.py | 2 +- vllm/engine/protocol.py | 8 +- vllm/entrypoints/llm.py | 68 +++++------ vllm/inputs/__init__.py | 20 +++- vllm/inputs/data.py | 48 +++++--- vllm/inputs/parse.py | 22 ++-- vllm/inputs/preprocess.py | 86 +++++++------- 21 files changed, 438 insertions(+), 245 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index a39d1cf842f0..eadf994cacd3 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -61,7 +61,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_inputs: List[PromptInputs] = [{ + dummy_prompts: List[PromptType] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] @@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index 241b2ccd0991..e112b43aade5 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -8,7 +8,7 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` -via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`. +via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`. Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities by following :ref:`this guide `. 
diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst index 9adf82d43f3e..0d47281db485 100644 --- a/docs/source/dev/offline_inference/llm_inputs.rst +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -1,7 +1,7 @@ LLM Inputs ========== -.. autodata:: vllm.inputs.PromptInputs +.. autodata:: vllm.inputs.PromptType .. autoclass:: vllm.inputs.TextPrompt :show-inheritance: diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 08db89166504..ca5b125369c8 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model. -To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: +To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`: * ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 6cae76f74603..1903a7582dc8 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine): @pytest.mark.asyncio async def test_new_requests_event(): + params = SamplingParams() + engine = MockAsyncLLMEngine() engine.start_background_loop() await asyncio.sleep(0.01) assert engine.engine.step_calls == 0 - await engine.add_request("1", "", None) + await engine.add_request("1", "", params) await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 1 assert engine.engine.step_calls == 1 - await engine.add_request("2", "", None) + await engine.add_request("2", "", params) engine.engine.generate("2") await asyncio.sleep(0) await asyncio.sleep(0) @@ -111,7 +113,7 @@ async def test_new_requests_event(): await asyncio.sleep(0.001) assert engine.engine.step_calls == old_step_calls - await engine.add_request("3", "", None) + await engine.add_request("3", "", params) await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 3 assert engine.engine.step_calls == old_step_calls + 1 diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index d1056a049050..1885f2e168d8 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput], assert [o.outputs for o in o1] == [o.outputs for o in o2] -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt', PROMPTS) -def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params) - - v2_output = llm.encode(prompt, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup 
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, @@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, assert_outputs_equal(v1_output, v2_output) -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params) - - v2_output = llm.encode(PROMPTS, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.encode( - [{ - "prompt": p - } for p in PROMPTS], - pooling_params=pooling_params, - ) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): pooling_params = PoolingParams() diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index cd989225e248..6543c4bb1b58 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): assert [o.outputs for o in o1] == [o.outputs for o in o2] -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt', PROMPTS) -def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.generate(prompts=prompt, - sampling_params=sampling_params) - - v2_output = llm.generate(prompt, sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.generate({"prompt": prompt}, - sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, @@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, assert_outputs_equal(v1_output, v2_output) -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.generate(prompts=PROMPTS, - sampling_params=sampling_params) - - v2_output = llm.generate(PROMPTS, sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.generate( - [{ - "prompt": p - } for p in PROMPTS], - sampling_params=sampling_params, - ) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): sampling_params = SamplingParams(temperature=0.0, top_p=1.0) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 76b2f494d5b2..616a15a1328d 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket): # Throws an error in first forward pass. 
with pytest.raises(RAISED_ERROR): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket): # Engine is errored, should get ENGINE_DEAD_ERROR. with pytest.raises(MQEngineDeadError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket): # Generate call should throw ENGINE_DEAD_ERROR with pytest.raises(MQEngineDeadError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket): # with reference to the original KeyError("foo") with pytest.raises(MQEngineDeadError) as execinfo: async for _ in client.generate( - inputs="Hello my name is", + prompt="Hello my name is", sampling_params=SamplingParams(max_tokens=10), request_id=uuid.uuid4()): pass @@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket): # Invalid request should fail, but not crash the server. with pytest.raises(ValueError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-1", lora_request=LoRARequest( @@ -192,7 +192,7 @@ async def test_bad_request(tmp_socket): pass # This request should be okay. - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-2"): pass diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index e27fd7792341..3ffa126070ca 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -20,7 +20,7 @@ async def generate( count = 0 async for out in client.generate( request_id=request_id, - inputs="Hello my name is Robert and", + prompt="Hello my name is Robert and", sampling_params=SamplingParams(max_tokens=num_tokens, temperature=0)): diff --git a/vllm/__init__.py b/vllm/__init__.py index 90363b3e49b7..8f477ea84756 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,7 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -19,7 +19,7 @@ "__version_tuple__", "LLM", "ModelRegistry", - "PromptInputs", + "PromptType", "TextPrompt", "TokensPrompt", "SamplingParams", diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 34e7e05341f0..54c5af2fe366 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -2,8 +2,8 @@ import time import weakref from functools import partial -from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, - Mapping, Optional, Set, Tuple, Type, Union) +from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable, + List, Mapping, Optional, Set, Tuple, Type, Union, overload) from weakref import 
ReferenceType import vllm.envs as envs @@ -17,7 +17,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -28,7 +28,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import weak_bind +from vllm.utils import deprecate_kwargs, weak_bind logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S @@ -402,17 +402,54 @@ async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() + @overload # DEPRECATED async def add_request_async( self, request_id: str, - inputs: PromptInputs, + *, + inputs: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @overload + async def add_request_async( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + async def add_request_async( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: """Async version of :meth:`add_request`.""" + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") @@ -420,7 +457,7 @@ async def add_request_async( arrival_time = time.time() preprocessed_inputs = await self.input_preprocessor.preprocess_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -774,16 +811,55 @@ async def run_engine_loop(engine_ref: ReferenceType): # This method does not need to be async, but kept that way # for backwards compatibility. - async def add_request( + @overload # DEPRECATED + def add_request( self, request_id: str, - inputs: PromptInputs, + *, + inputs: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Coroutine[None, None, AsyncGenerator[Union[ + RequestOutput, EmbeddingRequestOutput], None]]: + ... 
+ + @overload + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Coroutine[None, None, AsyncGenerator[Union[ + RequestOutput, EmbeddingRequestOutput], None]]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + async def add_request( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None, # DEPRECATED ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + if not self.is_running: if self.start_engine_loop: self.start_background_loop() @@ -797,7 +873,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, verbose=self.log_requests, - inputs=inputs, + prompt=prompt, params=params, arrival_time=arrival_time or time.time(), lora_request=lora_request, @@ -808,7 +884,7 @@ async def add_request( async def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -822,8 +898,7 @@ async def generate( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -881,7 +956,7 @@ async def generate( """ async for output in await self.add_request( request_id, - inputs, + prompt, sampling_params, lora_request=lora_request, trace_headers=trace_headers, @@ -891,7 +966,7 @@ async def generate( async def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -904,8 +979,7 @@ async def encode( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. 
@@ -959,7 +1033,7 @@ async def encode( """ async for output in await self.add_request( request_id, - inputs, + prompt, pooling_params, lora_request=lora_request, trace_headers=trace_headers, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 768ac69c3692..487255cb6b59 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,7 +6,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence -from typing import Set, Type, Union +from typing import Set, Type, Union, overload import torch from typing_extensions import TypeVar @@ -29,7 +29,7 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptInputs) + InputRegistry, LLMInputs, PromptType) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -51,7 +51,7 @@ BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter, Device, weak_bind +from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -689,16 +689,51 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() + @overload # DEPRECATED def add_request( self, request_id: str, - inputs: PromptInputs, + *, + inputs: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + ) -> None: + ... + + @overload + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def add_request( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: """Add a request to the engine's request pool. @@ -708,8 +743,7 @@ def add_request( Args: request_id: The unique ID of the request. - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. params: Parameters for sampling or pooling. :class:`~vllm.SamplingParams` for text generation. @@ -744,6 +778,10 @@ def add_request( >>> # continue the request processing >>> ... 
""" + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") @@ -756,7 +794,7 @@ def add_request( arrival_time = time.time() preprocessed_inputs = self.input_preprocessor.preprocess( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 1603189979a2..6d6d7895b210 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -1,13 +1,14 @@ from dataclasses import dataclass from enum import Enum -from typing import List, Mapping, Optional, Union +from typing import List, Mapping, Optional, Union, overload from vllm import PoolingParams -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams +from vllm.utils import deprecate_kwargs VLLM_RPC_SUCCESS_STR = "SUCCESS" @@ -23,13 +24,67 @@ class MQEngineDeadError(RuntimeError): @dataclass class RPCProcessRequest: - inputs: PromptInputs + prompt: PromptType params: Union[SamplingParams, PoolingParams] request_id: str lora_request: Optional[LoRARequest] = None trace_headers: Optional[Mapping[str, str]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None + @overload # DEPRECATED + def __init__( + self, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @overload + def __init__( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... 
+ + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def __init__( + self, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None, # DEPRECATED + ) -> None: + if inputs is not None: + prompt = inputs + assert (prompt is not None and params is not None + and request_id is not None) + + super().__init__() + + self.prompt = prompt + self.params = params + self.request_id = request_id + self.lora_request = lora_request + self.trace_headers = trace_headers + self.prompt_adapter_request = prompt_adapter_request + @dataclass class RPCError: diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 0ee56f7bf840..700e65000e05 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -3,7 +3,7 @@ import pickle from contextlib import contextmanager, suppress from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, - Union) + Union, overload) import cloudpickle import zmq @@ -25,13 +25,14 @@ RPCUProfileRequest) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.utils import deprecate_kwargs logger = init_logger(__name__) @@ -367,14 +368,45 @@ def errored(self) -> bool: def dead_error(self) -> BaseException: return ENGINE_DEAD_ERROR(self._errored_with) + @overload # DEPRECATED def generate( self, - inputs: PromptInputs, + *, + inputs: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> AsyncGenerator[RequestOutput, None]: + ... + + @overload + def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> AsyncGenerator[RequestOutput, None]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def generate( + self, + prompt: Optional[PromptType] = None, + sampling_params: Optional[SamplingParams] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None # DEPRECATED ) -> AsyncGenerator[RequestOutput, None]: """Generate outputs for a request. @@ -383,8 +415,7 @@ def generate( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. 
See :class:`~vllm.inputs.PromptType` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -393,17 +424,51 @@ def generate( prompt_adapter_request: Prompt Adapter request to use for generation, if any. """ - return self._process_request(inputs, sampling_params, request_id, + if inputs is not None: + prompt = inputs + assert (prompt is not None and sampling_params is not None + and request_id is not None) + + return self._process_request(prompt, sampling_params, request_id, lora_request, trace_headers, prompt_adapter_request) + @overload # DEPRECATED def encode( self, - inputs: PromptInputs, + *, + inputs: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ... + + @overload + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def encode( + self, + prompt: Optional[PromptType] = None, + pooling_params: Optional[PoolingParams] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + *, + inputs: Optional[PromptType] = None # DEPRECATED ) -> AsyncGenerator[EmbeddingRequestOutput, None]: """Generate outputs for a request from an embedding model. @@ -412,8 +477,7 @@ def encode( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -424,12 +488,17 @@ def encode( The output `EmbeddingRequestOutput` objects from the LLMEngine for the request. 
""" - return self._process_request(inputs, pooling_params, request_id, + if inputs is not None: + prompt = inputs + assert (prompt is not None and pooling_params is not None + and request_id is not None) + + return self._process_request(prompt, pooling_params, request_id, lora_request, trace_headers) async def _process_request( self, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], request_id: str, lora_request: Optional[LoRARequest] = None, @@ -462,7 +531,7 @@ async def _process_request( request_bytes = pickle.dumps( RPCProcessRequest( - inputs=inputs, + prompt=prompt, params=params, request_id=request_id, lora_request=lora_request, diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 1b2e7ccf8664..eecca82cd2f7 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -278,7 +278,7 @@ def _handle_process_request(self, request: RPCProcessRequest): try: self.engine.add_request( request_id=request_id, - inputs=request.inputs, + prompt=request.prompt, params=request.params, lora_request=request.lora_request, trace_headers=request.trace_headers, diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 70444faa670a..d0bbeb357b50 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -3,7 +3,7 @@ from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.inputs.data import PromptInputs +from vllm.inputs.data import PromptType from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -35,19 +35,19 @@ def dead_error(self) -> BaseException: def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[RequestOutput, None]: - """Generates outputs for a request""" + """Generate outputs for a request.""" ... def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 77ae7b088398..f4943cb38da4 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -12,7 +12,7 @@ apply_hf_chat_template, apply_mistral_chat_template, parse_chat_messages) -from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -293,8 +293,8 @@ def generate( @overload def generate( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], - /, # We may enable `inputs` keyword after removing the old API + prompts: Union[PromptType, Sequence[PromptType]], + /, *, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -304,14 +304,13 @@ def generate( ... 
@deprecate_kwargs( - "prompts", "prompt_token_ids", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter instead.", + additional_message="Please use the 'prompts' parameter instead.", ) def generate( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[Union[PromptType, Sequence[PromptType]], Optional[Union[str, List[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -330,7 +329,9 @@ def generate( into a single list and pass it to this method. Args: - inputs: A list of inputs to generate completions for. + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. @@ -358,12 +359,13 @@ def generate( "models (XForCausalLM, XForConditionalGeneration).") if prompt_token_ids is not None: - inputs = self._convert_v1_inputs( + parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: @@ -378,7 +380,7 @@ def generate( sampling_params = SamplingParams() self._validate_and_add_requests( - inputs=inputs, + prompts=parsed_prompts, params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -648,8 +650,8 @@ def encode( @overload def encode( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], - /, # We may enable `inputs` keyword after removing the old API + prompts: Union[PromptType, Sequence[PromptType]], + /, *, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -659,14 +661,13 @@ def encode( ... @deprecate_kwargs( - "prompts", "prompt_token_ids", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter instead.", + additional_message="Please use the 'prompts' parameter instead.", ) def encode( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[Union[PromptType, Sequence[PromptType]], Optional[Union[str, List[str]]]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -682,9 +683,9 @@ def encode( into a single list and pass it to this method. Args: - inputs: The inputs to the LLM. You may pass a sequence of inputs for - batch inference. See :class:`~vllm.inputs.PromptInputs` - for more details about the format of each input. + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. 
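Aside (illustrative only, not taken from this diff): with the `LLM` entry points renamed, typical usage passes `prompts` directly, and the legacy `prompt_token_ids` keyword is routed through `_convert_v1_inputs` with a deprecation warning. The model name below is an arbitrary placeholder:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any causal LM; placeholder choice
params = SamplingParams(temperature=0.8, max_tokens=16)

# New-style: a string, TextPrompt/TokensPrompt dict, or a sequence of them.
outputs = llm.generate(["Hello, my name is", "The capital of France is"], params)
for out in outputs:
    print(out.outputs[0].text)

# Legacy token-ID path (deprecated, still handled by _convert_v1_inputs):
#   llm.generate(prompt_token_ids=[[1, 2, 3]], sampling_params=params)
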
@@ -707,19 +708,20 @@ def encode( ) if prompt_token_ids is not None: - inputs = self._convert_v1_inputs( + parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() self._validate_and_add_requests( - inputs=inputs, + prompts=parsed_prompts, params=pooling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -763,9 +765,9 @@ def _convert_v1_inputs( raise ValueError("Either prompts or prompt_token_ids must be " "provided.") - inputs: List[PromptInputs] = [] + parsed_prompts: List[PromptType] = [] for i in range(num_requests): - item: PromptInputs + item: PromptType if prompts is not None: item = TextPrompt(prompt=prompts[i]) @@ -774,13 +776,13 @@ def _convert_v1_inputs( else: raise AssertionError - inputs.append(item) + parsed_prompts.append(item) - return inputs + return parsed_prompts def _validate_and_add_requests( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[PromptType, Sequence[PromptType]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], @@ -788,11 +790,11 @@ def _validate_and_add_requests( guided_options: Optional[GuidedDecodingRequest] = None, priority: Optional[List[int]] = None, ) -> None: - if isinstance(inputs, (str, dict)): + if isinstance(prompts, (str, dict)): # Convert a single prompt to a list. - inputs = [inputs] + prompts = [prompts] - num_requests = len(inputs) + num_requests = len(prompts) if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") @@ -809,9 +811,9 @@ def _validate_and_add_requests( sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. 
- for i, request_inputs in enumerate(inputs): + for i, prompt in enumerate(prompts): self._add_request( - request_inputs, + prompt, params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, @@ -821,7 +823,7 @@ def _validate_and_add_requests( def _add_request( self, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -830,7 +832,7 @@ def _add_request( request_id = str(next(self.request_counter)) self.llm_engine.add_request( request_id, - inputs, + prompt, params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 0b08e9691f91..a8c8672cb5fe 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,5 +1,5 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + LLMInputs, PromptType, SingletonPrompt, TextPrompt, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry @@ -16,8 +16,8 @@ __all__ = [ "TextPrompt", "TokensPrompt", - "PromptInputs", - "SingletonPromptInputs", + "PromptType", + "SingletonPrompt", "ExplicitEncoderDecoderPrompt", "LLMInputs", "EncoderDecoderLLMInputs", @@ -28,3 +28,17 @@ "InputContext", "InputRegistry", ] + + +def __getattr__(name: str): + if name == "PromptInput": + import warnings + + msg = ("PromptInput has been renamed to PromptType. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PromptType + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index a71e9a7b5db6..c85f03345c67 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -33,7 +33,7 @@ class TokensPrompt(TypedDict): """ -SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt] +SingletonPrompt = Union[str, TextPrompt, TokensPrompt] """ Set of possible schemas for a single LLM input: @@ -46,7 +46,7 @@ class TokensPrompt(TypedDict): the user desires to express both the encoder & decoder prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt` -A prompt of type :class:`SingletonPromptInputs` may be employed +A prompt of type :class:`SingletonPrompt` may be employed as (1) input to a decoder-only model, (2) input to the encoder of an encoder/decoder model, in the scenario where the decoder-prompt is not specified explicitly, or @@ -55,33 +55,33 @@ class TokensPrompt(TypedDict): """ _T1_co = TypeVar("_T1_co", - bound=SingletonPromptInputs, - default=SingletonPromptInputs, + bound=SingletonPrompt, + default=SingletonPrompt, covariant=True) _T2_co = TypeVar("_T2_co", - bound=SingletonPromptInputs, - default=SingletonPromptInputs, + bound=SingletonPrompt, + default=SingletonPrompt, covariant=True) # TODO: Make fields ReadOnly once mypy supports it class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): - """Represents an encoder/decoder model input prompt, - comprising an explicit encoder prompt and a - decoder prompt. + """ + Represents an encoder/decoder model input prompt, + comprising an explicit encoder prompt and a decoder prompt. 
The encoder and decoder prompts, respectively, may formatted according to any of the - :class:`SingletonPromptInputs` schemas, and are not + :class:`SingletonPrompt` schemas, and are not required to have the same schema. Only the encoder prompt may have multi-modal data. Note that an :class:`ExplicitEncoderDecoderPrompt` may not be used as an input to a decoder-only model, - and that the `encoder_prompt` and `decoder_prompt` + and that the :code:`encoder_prompt` and :code:`decoder_prompt` fields of this data structure themselves must be - :class:`SingletonPromptInputs` instances. + :class:`SingletonPrompt` instances. """ encoder_prompt: _T1_co @@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): decoder_prompt: Optional[_T2_co] -PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] +PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] """ Set of possible schemas for an LLM input, including both decoder-only and encoder/decoder input types: @@ -146,12 +146,8 @@ class EncoderDecoderLLMInputs(LLMInputs): """ -_T1 = TypeVar("_T1", - bound=SingletonPromptInputs, - default=SingletonPromptInputs) -_T2 = TypeVar("_T2", - bound=SingletonPromptInputs, - default=SingletonPromptInputs) +_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) +_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) def build_explicit_enc_dec_prompt( @@ -182,3 +178,17 @@ def to_enc_dec_tuple_list( return [(enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"]) for enc_dec_prompt in enc_dec_prompts] + + +def __getattr__(name: str): + if name == "PromptInput": + import warnings + + msg = ("PromptInput has been renamed to PromptType. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PromptType + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index ac9d355c64c8..e5fa1e418427 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -5,7 +5,7 @@ from vllm.utils import is_list_of from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + LLMInputs, PromptType, SingletonPrompt, TextPrompt, TokensPrompt) @@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict): def parse_singleton_prompt( - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: - if isinstance(inputs, str): - return ParsedStrPrompt(type="str", content=inputs) - elif isinstance(inputs, dict): - if "prompt_token_ids" in inputs: + if isinstance(prompt, str): + return ParsedStrPrompt(type="str", content=prompt) + elif isinstance(prompt, dict): + if "prompt_token_ids" in prompt: return ParsedTokensPrompt(type="tokens", - content=inputs) # type: ignore - elif "prompt" in inputs: - return ParsedTextPrompt(type="text", content=inputs) + content=prompt) # type: ignore + elif "prompt" in prompt: + return ParsedTextPrompt(type="text", content=prompt) raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") def is_explicit_encoder_decoder_prompt( - inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]: - return isinstance(inputs, dict) and "encoder_prompt" in inputs + prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(prompt, dict) and "encoder_prompt" in prompt def 
is_valid_encoder_decoder_llm_inputs( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 6d54a07e92cc..d4474a10f542 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -10,8 +10,8 @@ from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.utils import print_warning_once -from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, - SingletonPromptInputs) +from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType, + SingletonPrompt) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt if TYPE_CHECKING: @@ -209,7 +209,7 @@ async def _tokenize_prompt_async( def _extract_prompt_components( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: @@ -219,7 +219,7 @@ def _extract_prompt_components( Arguments: * request_id - * inputs: single encoder or decoder input prompt + * prompt: single encoder or decoder input prompt * lora_request: this is only valid for decoder prompts Returns: @@ -229,24 +229,24 @@ def _extract_prompt_components( * multi_modal_data ''' - parsed = parse_singleton_prompt(inputs) + parsed = parse_singleton_prompt(prompt) if parsed["type"] == "str": - prompt = parsed["content"] + prompt_text = parsed["content"] prompt_token_ids = self._tokenize_prompt( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt = None + prompt_text = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt = parsed["content"]["prompt"] + prompt_text = parsed["content"]["prompt"] prompt_token_ids = self._tokenize_prompt( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) @@ -254,33 +254,33 @@ def _extract_prompt_components( else: assert_never(parsed) - return prompt, prompt_token_ids, multi_modal_data + return prompt_text, prompt_token_ids, multi_modal_data async def _extract_prompt_components_async( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: """Async version of :meth:`_extract_prompt_components`.""" - parsed = parse_singleton_prompt(inputs) + parsed = parse_singleton_prompt(prompt) if parsed["type"] == "str": - prompt = parsed["content"] + prompt_text = parsed["content"] prompt_token_ids = await self._tokenize_prompt_async( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt = None + prompt_text = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt = parsed["content"]["prompt"] + prompt_text = parsed["content"]["prompt"] prompt_token_ids = await self._tokenize_prompt_async( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) @@ -288,7 +288,7 @@ async def _extract_prompt_components_async( else: assert_never(parsed) - return prompt, prompt_token_ids, multi_modal_data + return prompt_text, prompt_token_ids, multi_modal_data def _build_enc_dec_llm_inputs( self, @@ -321,7 +321,7 @@ def _build_enc_dec_llm_inputs( def _process_encoder_decoder_prompt( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, ) -> EncoderDecoderLLMInputs: ''' @@ -349,7 
+349,7 @@ def _process_encoder_decoder_prompt( Arguments: - * inputs: an input prompt + * prompt: an input prompt * request_id Returns: @@ -360,13 +360,13 @@ def _process_encoder_decoder_prompt( encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): encoder_comps = self._extract_prompt_components( - inputs["encoder_prompt"], + prompt["encoder_prompt"], request_id=request_id, ) - if (decoder_input := inputs["decoder_prompt"]) is None: + if (decoder_input := prompt["decoder_prompt"]) is None: decoder_comps = None, None, None else: decoder_comps = self._extract_prompt_components( @@ -375,7 +375,7 @@ def _process_encoder_decoder_prompt( ) else: encoder_comps = self._extract_prompt_components( - inputs, + prompt, request_id=request_id, ) @@ -385,20 +385,20 @@ def _process_encoder_decoder_prompt( async def _process_encoder_decoder_prompt_async( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, ) -> EncoderDecoderLLMInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): encoder_task = self._extract_prompt_components_async( - inputs["encoder_prompt"], + prompt["encoder_prompt"], request_id=request_id, ) - if (decoder_input := inputs["decoder_prompt"]) is None: + if (decoder_input := prompt["decoder_prompt"]) is None: encoder_comps = await encoder_task decoder_comps = None, None, None else: @@ -411,7 +411,7 @@ async def _process_encoder_decoder_prompt_async( encoder_task, decoder_task) else: encoder_comps = await self._extract_prompt_components_async( - inputs, + prompt, request_id=request_id, ) @@ -435,7 +435,7 @@ def _build_decoder_only_llm_inputs( def _process_decoder_only_prompt( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -446,7 +446,7 @@ def _process_decoder_only_prompt( Arguments: - * inputs: input prompt + * prompt: input prompt * request_id * lora_request * prompt_adapter_request @@ -457,7 +457,7 @@ def _process_decoder_only_prompt( ''' prompt_comps = self._extract_prompt_components( - inputs, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -469,14 +469,14 @@ def _process_decoder_only_prompt( async def _process_decoder_only_prompt_async( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" prompt_comps = await self._extract_prompt_components_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -488,7 +488,7 @@ async def _process_decoder_only_prompt_async( def preprocess( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -498,17 +498,17 @@ def preprocess( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return self._process_encoder_decoder_prompt( - inputs, + prompt, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): raise 
ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return self._process_decoder_only_prompt( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -516,7 +516,7 @@ def preprocess( async def preprocess_async( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -526,17 +526,17 @@ async def preprocess_async( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return await self._process_encoder_decoder_prompt_async( - inputs, + prompt, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return await self._process_decoder_only_prompt_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, From 4f5a5d58c87b19aa335473bbe29043f32be4cf95 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 26 Sep 2024 16:15:08 +0000 Subject: [PATCH 09/14] Fix doc --- vllm/inputs/data.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index c85f03345c67..dfbcf9526487 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -70,10 +70,9 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): Represents an encoder/decoder model input prompt, comprising an explicit encoder prompt and a decoder prompt. - The encoder and decoder prompts, respectively, - may formatted according to any of the - :class:`SingletonPrompt` schemas, and are not - required to have the same schema. + The encoder and decoder prompts, respectively, may be formatted + according to any of the :class:`SingletonPrompt` schemas, + and are not required to have the same schema. Only the encoder prompt may have multi-modal data. From 343e4c9ea7d539c0d80007dbcbe27aed9a70f432 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 27 Sep 2024 02:15:38 +0000 Subject: [PATCH 10/14] Update mllama --- vllm/model_executor/models/mllama.py | 48 ++++++++++++++-------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 45d6ad3c0efa..202131579f6b 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -14,7 +14,6 @@ # limitations under the License. 
"""PyTorch Mllama model.""" import math -from array import array from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -33,7 +32,7 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, EncoderDecoderInputs, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -47,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from vllm.sequence import SequenceData from .clip import CLIPMLP from .interfaces import SupportsMultiModal @@ -72,24 +71,25 @@ class MllamaImagePixelInputs(TypedDict): # TODO: support LlamaImageEmbeddingInputs -def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_mllama(ctx: InputContext, + inputs: EncoderDecoderInputs): # move encoder_prompt to prompt - if llm_inputs.get("prompt") is None: - llm_inputs["prompt"] = llm_inputs["encoder_prompt"] - llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] + if inputs.get("prompt") is None: + inputs["prompt"] = inputs["encoder_prompt"] + inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"] # process multi-modal data - assert "decoder_multi_modal_data" not in llm_inputs, \ + assert "multi_modal_data" not in inputs, \ "multi-modal data should be put in encoder message of mllama" - multi_modal_data = llm_inputs.get("encoder_multi_modal_data") + multi_modal_data = inputs.get("encoder_multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data \ or multi_modal_data["image"] is None: # text-only - llm_inputs["encoder_prompt"] = "" - llm_inputs["encoder_prompt_token_ids"] = [] - llm_inputs["encoder_multi_modal_data"] = {} - return llm_inputs + inputs["encoder_prompt"] = "" + inputs["encoder_prompt_token_ids"] = [] + inputs["encoder_multi_modal_data"] = {} + return inputs # get num_tiles if isinstance(multi_modal_data['image'], Image.Image): @@ -114,11 +114,10 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): "chunk size should be multiple of 14" token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 num_tokens = num_tiles * token_per_chunk - llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens - llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID - ] * num_tokens + inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens + inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens - return llm_inputs + return inputs def get_max_mllama_image_tokens(ctx: InputContext) -> int: @@ -131,17 +130,18 @@ def dummy_decoder_seq_data(seq_len: int, num_images: int): # <|image|> * num_images + 0 * (seq_len - num_images) assert seq_len >= num_images, \ "seq_len should be greater than or equal to num_images" - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [MLLAMA_IMAGE_TOKEN_ID]) * num_images - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - num_images) - return SequenceData(token_ids) + + return SequenceData.from_prompt_token_counts( + 
(MLLAMA_IMAGE_TOKEN_ID, num_images), + (0, seq_len - num_images), + ) def dummy_encoder_seq_data(ctx: InputContext, num_images: int): num_tokens = get_max_mllama_image_tokens(ctx) * num_images - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [MLLAMA_IMAGE_TOKEN_ID]) * num_tokens - return SequenceData(token_ids) + + return SequenceData.from_prompt_token_counts( + (MLLAMA_IMAGE_TOKEN_ID, num_tokens)) def dummy_image(num_images: int, ): From 3f099a1b8dc3e32e258bd2abd576484d52a84e3f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 27 Sep 2024 02:17:59 +0000 Subject: [PATCH 11/14] Fix type annotation --- vllm/model_executor/models/mllama.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 202131579f6b..c5f02ccfe1f6 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -32,7 +32,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, EncoderDecoderInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, + EncoderDecoderInputs, InputContext) from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -72,7 +73,8 @@ class MllamaImagePixelInputs(TypedDict): def input_processor_for_mllama(ctx: InputContext, - inputs: EncoderDecoderInputs): + inputs: Union[DecoderOnlyInputs, + EncoderDecoderInputs]): # move encoder_prompt to prompt if inputs.get("prompt") is None: inputs["prompt"] = inputs["encoder_prompt"] From c2db5e1bcbcc7b89ffce650e83d6a3b0357b6fbd Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 27 Sep 2024 06:41:49 +0000 Subject: [PATCH 12/14] Remove faulty assertion --- vllm/model_executor/models/mllama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index c5f02ccfe1f6..b7b74ac88970 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -81,8 +81,6 @@ def input_processor_for_mllama(ctx: InputContext, inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"] # process multi-modal data - assert "multi_modal_data" not in inputs, \ - "multi-modal data should be put in encoder message of mllama" multi_modal_data = inputs.get("encoder_multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data \ From 54bb0cf2432e6d43854288ce64189e78a18ac537 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 27 Sep 2024 07:22:32 +0000 Subject: [PATCH 13/14] Fix processor call --- .../decoder_only/vision_language/test_phi3v.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index 53f22ba6614f..bc6ed9413964 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -1,6 +1,6 @@ import os import re -from typing import Callable, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import pytest import torch @@ -311,7 +311,7 @@ def test_input_mapper_override(model: str, image_assets: _ImageAssets, (4, 781), (16, 2653), ]) -def 
test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str, +def test_max_tokens_override(get_max_phi3v_image_tokens, model: str, num_crops: int, expected_max_tokens: int): """Ensure get_max_phi3v_image_tokens handles num_crops properly.""" # NOTE: mm_processor_kwargs on the context in this test is unused, since @@ -343,7 +343,7 @@ def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str, (16, 2653, 1), (16, 2653, 2), ]) -def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str, +def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int, toks_per_img: int, num_imgs: int): """Ensure dummy_data_for_phi3v handles num_crops properly.""" # Same as the previous test - don't initialize mm_processor_kwargs @@ -374,7 +374,7 @@ def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str, (16, 1921, 1), (16, 1921, 2), ]) -def test_input_processor_override(input_processor_for_phi3v: Callable, +def test_input_processor_override(input_processor_for_phi3v, image_assets: _ImageAssets, model: str, num_crops: int, expected_toks_per_img: int, num_imgs: int): @@ -397,11 +397,8 @@ def test_input_processor_override(input_processor_for_phi3v: Callable, prompt=prompt, multi_modal_data={"image": images}) - processed_inputs = input_processor_for_phi3v( - ctx=ctx, - processed_inputs=inputs, - num_crops=num_crops, - ) + processed_inputs = input_processor_for_phi3v(ctx, inputs, + num_crops=num_crops) # Ensure we have the right number of placeholders per num_crops size img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) From 6c2f55faf16e214e1f6e25b0db84cc9e80784276 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 27 Sep 2024 07:28:34 +0000 Subject: [PATCH 14/14] format --- tests/models/decoder_only/vision_language/test_phi3v.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index bc6ed9413964..d911ea8c460c 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -343,8 +343,8 @@ def test_max_tokens_override(get_max_phi3v_image_tokens, model: str, (16, 2653, 1), (16, 2653, 2), ]) -def test_dummy_data_override(dummy_data_for_phi3v, model: str, - num_crops: int, toks_per_img: int, num_imgs: int): +def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int, + toks_per_img: int, num_imgs: int): """Ensure dummy_data_for_phi3v handles num_crops properly.""" # Same as the previous test - don't initialize mm_processor_kwargs # in this test and assume that the kwargs will be correctly expanded by @@ -397,7 +397,8 @@ def test_input_processor_override(input_processor_for_phi3v, prompt=prompt, multi_modal_data={"image": images}) - processed_inputs = input_processor_for_phi3v(ctx, inputs, + processed_inputs = input_processor_for_phi3v(ctx, + inputs, num_crops=num_crops) # Ensure we have the right number of placeholders per num_crops size