diff --git a/docs/source/dev/input_processing/model_inputs_index.rst b/docs/source/dev/input_processing/model_inputs_index.rst index 5d895837590b..f0ec1fea15dd 100644 --- a/docs/source/dev/input_processing/model_inputs_index.rst +++ b/docs/source/dev/input_processing/model_inputs_index.rst @@ -25,7 +25,7 @@ Module Contents LLM Engine Inputs ----------------- -.. autoclass:: vllm.inputs.LLMInputs +.. autoclass:: vllm.inputs.DecoderOnlyInputs :members: :show-inheritance: diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index 00c1b9975ef3..12e8a961877c 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -1,12 +1,12 @@ import os import re -from typing import Callable, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import pytest import torch from transformers import AutoImageProcessor, AutoTokenizer -from vllm.inputs import InputContext, LLMInputs +from vllm.inputs import InputContext, token_inputs from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID from vllm.multimodal import MultiModalRegistry from vllm.multimodal.utils import rescale_image_size @@ -311,7 +311,7 @@ def test_input_mapper_override(model: str, image_assets: _ImageAssets, (4, 781), (16, 2653), ]) -def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str, +def test_max_tokens_override(get_max_phi3v_image_tokens, model: str, num_crops: int, expected_max_tokens: int): """Ensure get_max_phi3v_image_tokens handles num_crops properly.""" # NOTE: mm_processor_kwargs on the context in this test is unused, since @@ -343,8 +343,8 @@ def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str, (16, 2653, 1), (16, 2653, 2), ]) -def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str, - num_crops: int, toks_per_img: int, num_imgs: int): +def 
test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int, + toks_per_img: int, num_imgs: int): """Ensure dummy_data_for_phi3v handles num_crops properly.""" # Same as the previous test - don't initialize mm_processor_kwargs # in this test and assume that the kwargs will be correctly expanded by @@ -374,7 +374,7 @@ def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str, (16, 1921, 1), (16, 1921, 2), ]) -def test_input_processor_override(input_processor_for_phi3v: Callable, +def test_input_processor_override(input_processor_for_phi3v, image_assets: _ImageAssets, model: str, num_crops: int, expected_toks_per_img: int, num_imgs: int): @@ -393,16 +393,14 @@ def test_input_processor_override(input_processor_for_phi3v: Callable, prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" images = [image_assets[0].pil_image] * num_imgs - llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt), - prompt=prompt, - multi_modal_data={"image": images}) + inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt), + prompt=prompt, + multi_modal_data={"image": images}) - proc_llm_inputs = input_processor_for_phi3v( - ctx=ctx, - llm_inputs=llm_inputs, - num_crops=num_crops, - ) + processed_inputs = input_processor_for_phi3v(ctx, + inputs, + num_crops=num_crops) # Ensure we have the right number of placeholders per num_crops size - img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) + img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) assert img_tok_count == expected_toks_per_img * num_imgs diff --git a/tests/models/decoder_only/vision_language/test_qwen.py b/tests/models/decoder_only/vision_language/test_qwen.py index d2d0c62f5b2c..db5ab485f872 100644 --- a/tests/models/decoder_only/vision_language/test_qwen.py +++ b/tests/models/decoder_only/vision_language/test_qwen.py @@ -5,7 +5,7 @@ import torch from PIL.Image import Image -from vllm.inputs import InputContext, LLMInputs +from 
vllm.inputs import InputContext, token_inputs from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size @@ -71,12 +71,12 @@ def test_input_processor_valid_mm_data(input_processor_for_qwen, """Happy cases for image inputs to Qwen's multimodal input processor.""" prompt = "".join( [f"Picture {num}: \n" for num in range(1, num_images + 1)]) - inputs = LLMInputs( + inputs = token_inputs( prompt=prompt, # When processing multimodal data for a multimodal model, the qwen # input processor will overwrite the provided prompt_token_ids with # the image prompts - prompt_token_ids=None, + prompt_token_ids=[], multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)}, ) proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs) @@ -134,9 +134,9 @@ def test_input_processor_invalid_mm_data(input_processor_for_qwen, trust_remote_code=True) prompt = "Picture 1: \n" prompt_token_ids = tokenizer.encode(prompt) - inputs = LLMInputs(prompt=prompt, - prompt_token_ids=prompt_token_ids, - multi_modal_data=mm_data) + inputs = token_inputs(prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=mm_data) # Should fail since we have too many or too few dimensions for embeddings with pytest.raises(ValueError): input_processor_for_qwen(qwen_vl_context, inputs) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index efc6903c373b..7b9e0b6e5234 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -5,7 +5,7 @@ import pytest import torch -from vllm.inputs import InputContext, LLMInputs +from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs from vllm.inputs.registry import InputRegistry from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -31,7 +31,7 @@ def use_processor_mock(): """Patches the internal model input processor 
with an override callable.""" def custom_processor(ctx: InputContext, - llm_inputs: LLMInputs, + inputs: DecoderOnlyInputs, *, num_crops=DEFAULT_NUM_CROPS): # For testing purposes, we don't worry about the llm inputs / return @@ -84,7 +84,7 @@ def test_default_processor_is_a_noop(): dummy_registry = InputRegistry() ctx = build_model_context(DUMMY_MODEL_ID) processor = dummy_registry.create_input_processor(ctx.model_config) - proc_inputs = LLMInputs(prompt_token_ids=[], prompt="") + proc_inputs = token_inputs(prompt_token_ids=[], prompt="") proc_outputs = processor(inputs=proc_inputs) assert proc_inputs is proc_outputs @@ -125,9 +125,9 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops, ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) num_crops_val = processor( - LLMInputs(prompt_token_ids=[], - prompt="", - mm_processor_kwargs=inference_kwargs)) + token_inputs(prompt_token_ids=[], + prompt="", + mm_processor_kwargs=inference_kwargs)) assert num_crops_val == expected_seq_count @@ -154,9 +154,9 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, processor = dummy_registry.create_input_processor(ctx.model_config) # Should filter out the inference time kwargs num_crops_val = processor( - LLMInputs(prompt_token_ids=[], - prompt="", - mm_processor_kwargs=mm_processor_kwargs)) + token_inputs(prompt_token_ids=[], + prompt="", + mm_processor_kwargs=mm_processor_kwargs)) assert num_crops_val == DEFAULT_NUM_CROPS diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 563e52a37d93..eb806075eb7e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -29,8 +29,8 @@ from vllm.executor.executor_base import ExecutorBase from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, 
LLMInputs, PromptType) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, + EncoderDecoderInputs, InputRegistry, PromptType) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -635,7 +635,7 @@ def _verify_args(self) -> None: def _add_processed_request( self, request_id: str, - processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs], + processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], @@ -1855,8 +1855,8 @@ def is_encoder_decoder_model(self): def is_embedding_model(self): return self.model_config.is_embedding_model - def _validate_model_inputs(self, inputs: Union[LLMInputs, - EncoderDecoderLLMInputs]): + def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, + EncoderDecoderInputs]): if self.model_config.is_multimodal_model: # For encoder-decoder multimodal models, the max_prompt_len # restricts the decoder prompt length diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index a8c8672cb5fe..7b73922ddd2c 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,8 @@ -from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptType, SingletonPrompt, TextPrompt, - TokensPrompt, build_explicit_enc_dec_prompt, - to_enc_dec_tuple_list, zip_enc_dec_prompts) +from .data import (DecoderOnlyInputs, EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs, + SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, + build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, + token_inputs, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -19,8 +20,11 @@ "PromptType", "SingletonPrompt", "ExplicitEncoderDecoderPrompt", - "LLMInputs", - "EncoderDecoderLLMInputs", + "TokenInputs", + "token_inputs", + "SingletonInputs", 
+ "DecoderOnlyInputs", + "EncoderDecoderInputs", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", @@ -31,9 +35,9 @@ def __getattr__(name: str): - if name == "PromptInput": - import warnings + import warnings + if name == "PromptInput": msg = ("PromptInput has been renamed to PromptType. " "The original name will be removed in an upcoming version.") @@ -41,4 +45,21 @@ def __getattr__(name: str): return PromptType + if name == "LLMInputs": + msg = ("LLMInputs has been renamed to DecoderOnlyInputs. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return DecoderOnlyInputs + + if name == "EncoderDecoderLLMInputs": + msg = ( + "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return EncoderDecoderInputs + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 724cdd2e6e80..9a094191eda3 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,5 +1,5 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, - Optional, Tuple, Union) + Optional, Tuple, Union, cast) from typing_extensions import NotRequired, TypedDict, TypeVar @@ -51,7 +51,7 @@ class TokensPrompt(TypedDict): SingletonPrompt = Union[str, TextPrompt, TokensPrompt] """ -Set of possible schemas for a single LLM input: +Set of possible schemas for a single prompt: - A text prompt (:class:`str` or :class:`TextPrompt`) - A tokenized prompt (:class:`TokensPrompt`) @@ -120,13 +120,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): """ -class LLMInputs(TypedDict): - """ - The inputs in :class:`~vllm.LLMEngine` before they are - passed to the model executor. - - This specifies the data required for decoder-only models. 
- """ +class TokenInputs(TypedDict): + """Represents token-based inputs.""" prompt_token_ids: List[int] """The token IDs of the prompt.""" @@ -150,7 +145,40 @@ class LLMInputs(TypedDict): """ -class EncoderDecoderLLMInputs(LLMInputs): +def token_inputs( + prompt_token_ids: List[int], + prompt: Optional[str] = None, + multi_modal_data: Optional["MultiModalDataDict"] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, +) -> TokenInputs: + """Construct :class:`TokenInputs` from optional values.""" + inputs = TokenInputs(prompt_token_ids=prompt_token_ids) + + if prompt is not None: + inputs["prompt"] = prompt + if multi_modal_data is not None: + inputs["multi_modal_data"] = multi_modal_data + if mm_processor_kwargs is not None: + inputs["mm_processor_kwargs"] = mm_processor_kwargs + + return inputs + + +SingletonInputs = TokenInputs +""" +A processed :class:`SingletonPrompt` which can be passed to +:class:`vllm.sequence.Sequence`. +""" + +DecoderOnlyInputs = TokenInputs +""" +The inputs in :class:`~vllm.LLMEngine` before they are +passed to the model executor. +This specifies the data required for decoder-only models. +""" + + +class EncoderDecoderInputs(TokenInputs): """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. @@ -204,11 +232,12 @@ def zip_enc_dec_prompts( be zipped with the encoder/decoder prompts. 
""" if mm_processor_kwargs is None: - mm_processor_kwargs = {} - if isinstance(mm_processor_kwargs, Dict): + mm_processor_kwargs = cast(Dict[str, Any], {}) + if isinstance(mm_processor_kwargs, dict): return [ - build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, - mm_processor_kwargs) + build_explicit_enc_dec_prompt( + encoder_prompt, decoder_prompt, + cast(Dict[str, Any], mm_processor_kwargs)) for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts) ] @@ -229,9 +258,9 @@ def to_enc_dec_tuple_list( def __getattr__(name: str): - if name == "PromptInput": - import warnings + import warnings + if name == "PromptInput": msg = ("PromptInput has been renamed to PromptType. " "The original name will be removed in an upcoming version.") @@ -239,4 +268,21 @@ def __getattr__(name: str): return PromptType + if name == "LLMInputs": + msg = ("LLMInputs has been renamed to DecoderOnlyInputs. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return DecoderOnlyInputs + + if name == "EncoderDecoderLLMInputs": + msg = ( + "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. 
" + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return EncoderDecoderInputs + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index e5fa1e418427..7f9152dd3347 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -4,9 +4,9 @@ from vllm.utils import is_list_of -from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptType, SingletonPrompt, TextPrompt, - TokensPrompt) +from .data import (DecoderOnlyInputs, EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt, + TextPrompt, TokensPrompt) class ParsedText(TypedDict): @@ -100,7 +100,7 @@ def is_explicit_encoder_decoder_prompt( return isinstance(prompt, dict) and "encoder_prompt" in prompt -def is_valid_encoder_decoder_llm_inputs( - inputs: Union[LLMInputs, EncoderDecoderLLMInputs], -) -> TypeIs[EncoderDecoderLLMInputs]: +def is_encoder_decoder_inputs( + inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], +) -> TypeIs[EncoderDecoderInputs]: return "encoder_prompt_token_ids" in inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 64387fd2fa47..82ce7d392b71 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -10,7 +10,7 @@ from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.utils import print_warning_once -from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType, +from .data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType, SingletonPrompt) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt @@ -306,7 +306,7 @@ def _build_enc_dec_llm_inputs( encoder_comps: PromptComponents, decoder_comps: DecoderPromptComponents, mm_processor_kwargs: Dict[str, Any], - ) -> EncoderDecoderLLMInputs: + ) -> EncoderDecoderInputs: encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = 
encoder_comps decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps @@ -324,7 +324,7 @@ def _build_enc_dec_llm_inputs( decoder_prompt_ids, force_bos=(encoder_mm_data is None and decoder_mm_data is None))) - return EncoderDecoderLLMInputs( + return EncoderDecoderInputs( prompt_token_ids=decoder_prompt_ids, prompt=decoder_prompt, multi_modal_data=decoder_mm_data, @@ -338,11 +338,11 @@ def _process_encoder_decoder_prompt( self, prompt: PromptType, request_id: str, - ) -> EncoderDecoderLLMInputs: + ) -> EncoderDecoderInputs: ''' For encoder/decoder models only: Process an input prompt into an - :class:`EncoderDecoderLLMInputs` instance. + :class:`EncoderDecoderInputs` instance. There are two types of input prompts: singleton prompts which carry only the @@ -369,7 +369,7 @@ def _process_encoder_decoder_prompt( Returns: - * :class:`EncoderDecoderLLMInputs` instance + * :class:`EncoderDecoderInputs` instance ''' encoder_comps: PromptComponents @@ -411,7 +411,7 @@ async def _process_encoder_decoder_prompt_async( self, prompt: PromptType, request_id: str, - ) -> EncoderDecoderLLMInputs: + ) -> EncoderDecoderInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents @@ -455,17 +455,17 @@ def _build_decoder_only_llm_inputs( self, prompt_comps: PromptComponents, prompt_adapter_request: Optional[PromptAdapterRequest], - ) -> LLMInputs: + ) -> DecoderOnlyInputs: (prompt, prompt_token_ids, multi_modal_data, mm_processor_kwargs) = prompt_comps prompt_token_ids = self._apply_prompt_adapter( prompt_token_ids, prompt_adapter_request=prompt_adapter_request) - return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=prompt, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs) + return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids, + prompt=prompt, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs) def 
_process_decoder_only_prompt( self, @@ -473,10 +473,10 @@ def _process_decoder_only_prompt( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: + ) -> DecoderOnlyInputs: ''' For decoder-only models: - Process an input prompt into an :class:`LLMInputs` instance. + Process an input prompt into an :class:`DecoderOnlyInputs` instance. Arguments: @@ -487,7 +487,7 @@ def _process_decoder_only_prompt( Returns: - * :class:`LLMInputs` instance + * :class:`DecoderOnlyInputs` instance ''' prompt_comps = self._extract_prompt_components( @@ -507,7 +507,7 @@ async def _process_decoder_only_prompt_async( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: + ) -> DecoderOnlyInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" prompt_comps = await self._extract_prompt_components_async( prompt, @@ -526,7 +526,7 @@ def preprocess( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]: """Preprocess the input prompt.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of @@ -554,7 +554,7 @@ async def preprocess_async( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]: """Async version of :meth:`preprocess`.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 5bd3e1c86f66..4cebc91ce715 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -12,7 +12,7 @@ from vllm.utils import 
(get_allowed_kwarg_only_overrides, print_warning_once, resolve_mm_processor_kwargs) -from .data import LLMInputs +from .data import DecoderOnlyInputs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -100,7 +100,7 @@ def __getitem__(self, key: str) -> int: raise KeyError(msg) from exc -InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs] +InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs] """Preprocess the inputs to the model.""" @@ -134,7 +134,7 @@ def _default_dummy_data_factory( # Avoid circular import from vllm.sequence import SequenceData - dummy_seq_data = SequenceData.from_token_counts((0, seq_len)) + dummy_seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) dummy_multi_modal_data = None return dummy_seq_data, dummy_multi_modal_data @@ -245,8 +245,11 @@ def dummy_data_for_profiling( return seq_data, mm_data - def _default_input_processor(self, ctx: InputContext, - inputs: LLMInputs) -> LLMInputs: + def _default_input_processor( + self, + ctx: InputContext, + inputs: DecoderOnlyInputs, + ) -> DecoderOnlyInputs: """The default input processor is a no-op.""" return inputs @@ -279,7 +282,7 @@ def _get_model_input_processor(self, model_cls: Type[nn.Module]): .get(model_cls, self._default_input_processor) def process_input(self, model_config: "ModelConfig", - inputs: LLMInputs) -> LLMInputs: + inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: """ Apply an input processor to an instance of model inputs. 
diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 7c8e76461dd6..778162dd63ca 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -10,7 +10,7 @@ from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import LLMInputs +from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -63,7 +63,7 @@ def dummy_seq_data_for_blip( else: image_feature_size = image_feature_size_override - return SequenceData.from_token_counts( + return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), ) @@ -89,14 +89,14 @@ def dummy_image_for_blip( def input_processor_for_blip( model_config: ModelConfig, hf_config: Union[BlipVisionConfig, Blip2VisionConfig], - llm_inputs: LLMInputs, + inputs: DecoderOnlyInputs, *, image_token_id: int, image_feature_size_override: Optional[int] = None, ): - multi_modal_data = llm_inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs tokenizer = cached_get_tokenizer(model_config.tokenizer) @@ -107,16 +107,16 @@ def input_processor_for_blip( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + 
multi_modal_data=multi_modal_data) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 3ab235754a40..d6fe7d150336 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -9,7 +9,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -421,7 +422,7 @@ def dummy_seq_data_for_blip2( else: image_feature_size = image_feature_size_override - return SequenceData.from_token_counts( + return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), ) @@ -449,10 +450,10 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int, raise NotImplementedError(msg) -def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs hf_config = ctx.get_hf_config(Blip2Config) image_feature_size = get_blip2_image_feature_size(hf_config) @@ -460,15 +461,15 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): # The original model places image tokens at the front # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514 new_token_ids = 
[BLIP2_IMAGE_TOKEN_ID] * image_feature_size - new_token_ids += llm_inputs["prompt_token_ids"] + new_token_ids += inputs["prompt_token_ids"] - new_prompt = llm_inputs.get("prompt") + new_prompt = inputs.get("prompt") if new_prompt is not None: new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) @MULTIMODAL_REGISTRY.register_image_input_mapper() diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 03c7419f6f6a..aaf559ca386c 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -11,7 +11,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -69,7 +70,7 @@ def dummy_seq_data_for_chameleon( else: image_feature_size = image_feature_size_override - return SequenceData.from_token_counts( + return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), ) @@ -106,7 +107,8 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, return seq_data, mm_data -def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_chameleon(ctx: InputContext, + inputs: DecoderOnlyInputs): """ Processing input prompt to insert required tokens for image 
placeholder. @@ -114,16 +116,16 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58 """ # noqa - multi_modal_data = llm_inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer) new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID, repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH, pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID, @@ -137,9 +139,9 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): new_token_ids += [CHAMELEON_SEP_TOKEN_ID] # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) class ChameleonLayerNorm(nn.LayerNorm): diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index f26c9f950dd3..8283975b9d8e 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -14,7 +14,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.activation 
import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -149,20 +149,20 @@ def find_all_positions(input_ids: List[int], target: int) -> List[int]: return [index for index, value in enumerate(input_ids) if value == target] -def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs): hf_config = ctx.get_hf_config(ChatGLMConfig) vision_config = getattr(hf_config, 'vision_config', None) if vision_config is None: - return llm_inputs + return inputs elif isinstance(vision_config, dict): image_placeholder_length = calculate_image_placeholder(vision_config) else: msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) - input_ids = llm_inputs.get("prompt_token_ids") - position_ids = llm_inputs.get("position_ids") + input_ids = inputs.get("prompt_token_ids") + position_ids = inputs.get("position_ids") tokenizer = cached_get_tokenizer( ctx.model_config.model, trust_remote_code=ctx.model_config.trust_remote_code) @@ -171,15 +171,15 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs): raw_batch_data = tokenizer.apply_chat_template( conversation=[{ "role": "user", - "image": llm_inputs['multi_modal_data']["image"], - "content": llm_inputs['prompt'] + "image": inputs['multi_modal_data']["image"], + "content": inputs['prompt'] }], add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True).data except Exception: - logger.error("Failed to process content (%s)", llm_inputs['prompt']) + logger.error("Failed to process content (%s)", inputs['prompt']) raise input_ids = raw_batch_data['input_ids'][0].tolist() @@ -214,9 +214,9 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs): assert len(new_input_ids) == len(new_position_ids) - llm_inputs["prompt_token_ids"] = new_input_ids - llm_inputs["position_ids"] = new_position_ids - return llm_inputs + inputs["prompt_token_ids"] = 
new_input_ids + inputs["position_ids"] = new_position_ids + return inputs class GLMAttention(nn.Module): diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index edfb0c2b5e19..7b0981d611b2 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -11,7 +11,7 @@ from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import LLMInputs +from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -62,7 +62,7 @@ def dummy_seq_data_for_clip( else: image_feature_size = image_feature_size_override - return SequenceData.from_token_counts( + return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), ) @@ -106,14 +106,14 @@ def dummy_video_for_clip( def input_processor_for_clip( model_config: ModelConfig, hf_config: CLIPVisionConfig, - llm_inputs: LLMInputs, + inputs: DecoderOnlyInputs, *, image_token_id: int, image_feature_size_override: Optional[Union[int, List[int]]] = None, ): - multi_modal_data = llm_inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs tokenizer = cached_get_tokenizer(model_config.tokenizer) @@ -130,16 +130,16 @@ def input_processor_for_clip( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) 
+ return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 62a1b1f8cd4c..358d1dd288c4 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -27,7 +27,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -149,10 +150,10 @@ def _fuyu_image_preprocess(image_processor: FuyuImageProcessor, return model_image_input -def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config image_data = multi_modal_data["image"] @@ -176,8 +177,8 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): raise TypeError(f"Invalid image type: {type(image_data)}") # process prompts - prompt = llm_inputs.get("prompt") - prompt_token_ids = llm_inputs["prompt_token_ids"] + prompt = inputs.get("prompt") + prompt_token_ids = inputs["prompt_token_ids"] tokenizer = cached_get_tokenizer(model_config.model) # dim0 is batch_size, dim1 is subseq_size which will always be 1 image_input_ids: List[List[ @@ -190,9 +191,9 @@ def input_processor_for_fuyu(ctx: 
InputContext, llm_inputs: LLMInputs): new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[ 1:] + boa_token - return LLMInputs(prompt=new_prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=new_multi_modal_data) + return token_inputs(prompt=new_prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=new_multi_modal_data) def input_mapper_for_fuyu(ctx: InputContext, data: object): diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 6adb1e29d656..aada92cdf245 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -17,7 +17,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.models.intern_vit import InternVisionModel @@ -276,13 +277,13 @@ def _expand_image_prompt( def input_processor( self, ctx: InputContext, - llm_inputs: LLMInputs, + inputs: DecoderOnlyInputs, *, max_dynamic_patch: Optional[int] = None, - ) -> LLMInputs: - multi_modal_data = llm_inputs.get("multi_modal_data") + ) -> DecoderOnlyInputs: + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config hf_config = ctx.get_hf_config() @@ -311,8 +312,8 @@ def input_processor( model_config.tokenizer, trust_remote_code=model_config.trust_remote_code) - prompt = llm_inputs.get("prompt") - prompt_token_ids = llm_inputs["prompt_token_ids"] + prompt = inputs.get("prompt") + prompt_token_ids = inputs["prompt_token_ids"] if prompt is None: prompt = tokenizer.decode(prompt_token_ids) @@ -320,9 
+321,9 @@ def input_processor( num_patches) new_prompt_token_ids = tokenizer.encode(new_prompt) - return LLMInputs(prompt=prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=multi_modal_data) + return token_inputs(prompt=prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data) def input_mapper( self, diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 864b9ff66a84..fd2827c0eff0 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -9,7 +9,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -125,10 +125,10 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, raise NotImplementedError(msg) -def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config hf_config = ctx.get_hf_config(LlavaConfig) @@ -151,7 +151,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): return input_processor_for_clip( model_config, vision_config, - llm_inputs, + inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) @@ -159,7 +159,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): return input_processor_for_siglip( model_config, vision_config, - 
llm_inputs, + inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 766f6a4cc83f..4dd472b04bb1 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -12,7 +12,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -201,10 +201,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, raise NotImplementedError(msg) -def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_llava_next(ctx: InputContext, + inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config hf_config = ctx.get_hf_config(LlavaNextConfig) @@ -239,7 +240,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): return input_processor_for_clip( model_config, vision_config, - llm_inputs, + inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) @@ -247,7 +248,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): return input_processor_for_siglip( model_config, vision_config, - llm_inputs, + inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) diff --git a/vllm/model_executor/models/llava_next_video.py 
b/vllm/model_executor/models/llava_next_video.py index e10c1f9e6e04..4a354b616c2f 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -11,7 +11,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -139,10 +140,10 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int, def input_processor_for_llava_next_video(ctx: InputContext, - llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") + inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "video" not in multi_modal_data: - return llm_inputs + return inputs video_data = multi_modal_data["video"] model_config = ctx.model_config @@ -160,15 +161,15 @@ def input_processor_for_llava_next_video(ctx: InputContext, new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=hf_config.video_token_index, repeat_count=video_feature_size, ) - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) elif is_list_of(video_data, np.ndarray): raise NotImplementedError( diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 46e97e78d482..5bd3055ca181 100644 --- 
a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -15,8 +15,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.logger import init_logger +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -37,8 +37,6 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, merge_multimodal_embeddings) -logger = init_logger(__name__) - # Result in the max possible feature size (2x2 grid of 336x336px tiles) MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 @@ -252,10 +250,10 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, def input_processor_when_multimodal_input_image(ctx: InputContext, - llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") + inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config hf_config = ctx.get_hf_config(LlavaOnevisionConfig) @@ -290,7 +288,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, return input_processor_for_clip( model_config, vision_config, - llm_inputs, + inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) @@ -298,7 +296,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, return input_processor_for_siglip( model_config, vision_config, - llm_inputs, + inputs, image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) @@ -308,10 +306,10 @@ def 
input_processor_when_multimodal_input_image(ctx: InputContext, def input_processor_when_multimodal_input_video(ctx: InputContext, - llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") + inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "video" not in multi_modal_data: - return llm_inputs + return inputs video_data = multi_modal_data["video"] model_config = ctx.model_config @@ -326,15 +324,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=hf_config.video_token_index, repeat_count=video_feature_size, ) - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) elif is_list_of(video_data, np.ndarray): raise NotImplementedError( @@ -345,15 +343,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, def input_processor_for_llava_onevision(ctx: InputContext, - llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") + inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or ("video" not in multi_modal_data and "image" not in multi_modal_data): - return llm_inputs + return inputs if "image" in multi_modal_data: - return input_processor_when_multimodal_input_image(ctx, llm_inputs) + return input_processor_when_multimodal_input_image(ctx, inputs) if "video" in multi_modal_data: - return input_processor_when_multimodal_input_video(ctx, llm_inputs) + return input_processor_when_multimodal_input_video(ctx, inputs) msg = "Unsupported multi data type" raise NotImplementedError(msg) diff --git 
a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 9ee4dd0f0623..ca7c2be5a038 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -36,7 +36,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, @@ -256,7 +257,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext): def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): - return SequenceData.from_token_counts((0, seq_len)) + return SequenceData.from_prompt_token_counts((0, seq_len)) def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig, @@ -279,10 +280,10 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, return seq_data, mm_data -def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config version = get_version_by_config(model_config.hf_config) tokenizer = cached_get_tokenizer( @@ -297,8 +298,8 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int): return image_processor. 
\ get_slice_image_placeholder(image_size, num_image) - prompt = llm_inputs.get("prompt") - token_ids = llm_inputs.get("prompt_token_ids") + prompt = inputs.get("prompt") + token_ids = inputs.get("prompt_token_ids") if prompt is None: prompt = tokenizer.decode(token_ids) @@ -332,12 +333,11 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int): _build_image_input(ctx, image) for image in images ] - llm_inputs = LLMInputs( + return token_inputs( prompt_token_ids=new_token_ids, prompt=new_prompt, multi_modal_data=multi_modal_data, ) - return llm_inputs def input_mapper_for_minicpmv(ctx: InputContext, data: object): diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 66e9b2844620..378231f14455 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch Mllama model.""" import math -from array import array from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -37,7 +36,8 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, + EncoderDecoderInputs, InputContext) from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -51,7 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from vllm.sequence import SequenceData from .clip import CLIPMLP from .interfaces import SupportsMultiModal @@ -86,24 +86,24 @@ def 
_get_num_image_in_last_group(prompt_token_ids: List[int]) -> int: return num_images -def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_mllama(ctx: InputContext, + inputs: Union[DecoderOnlyInputs, + EncoderDecoderInputs]): # move encoder_prompt to prompt - if llm_inputs.get("prompt") is None: - llm_inputs["prompt"] = llm_inputs["encoder_prompt"] - llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] + if inputs.get("prompt") is None: + inputs["prompt"] = inputs["encoder_prompt"] + inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"] # process multi-modal data - assert "decoder_multi_modal_data" not in llm_inputs, \ - "multi-modal data should be put in encoder message of mllama" - multi_modal_data = llm_inputs.get("encoder_multi_modal_data") + multi_modal_data = inputs.get("encoder_multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data \ or multi_modal_data["image"] is None: # text-only - llm_inputs["encoder_prompt"] = "" - llm_inputs["encoder_prompt_token_ids"] = [] - llm_inputs["encoder_multi_modal_data"] = {} - return llm_inputs + inputs["encoder_prompt"] = "" + inputs["encoder_prompt_token_ids"] = [] + inputs["encoder_multi_modal_data"] = {} + return inputs if isinstance(multi_modal_data['image'], Image.Image): multi_modal_data['image'] = [multi_modal_data['image']] @@ -111,7 +111,7 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): # are attended by the decoded tokens, we only need to # get the number of tiles for those images. 
num_decode_images = _get_num_image_in_last_group( - llm_inputs["prompt_token_ids"]) + inputs["prompt_token_ids"]) hf_config = ctx.model_config.hf_config num_tiles = 0 for image in multi_modal_data["image"][::-1]: @@ -137,11 +137,10 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): "chunk size should be multiple of 14" token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 num_tokens = num_tiles * token_per_chunk - llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens - llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID - ] * num_tokens + inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens + inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens - return llm_inputs + return inputs def get_max_mllama_image_tokens(ctx: InputContext) -> int: @@ -154,17 +153,18 @@ def dummy_decoder_seq_data(seq_len: int, num_images: int): # <|image|> * num_images + 0 * (seq_len - num_images) assert seq_len >= num_images, \ "seq_len should be greater than or equal to num_images" - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [MLLAMA_IMAGE_TOKEN_ID]) * num_images - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - num_images) - return SequenceData(token_ids) + + return SequenceData.from_prompt_token_counts( + (MLLAMA_IMAGE_TOKEN_ID, num_images), + (0, seq_len - num_images), + ) def dummy_encoder_seq_data(ctx: InputContext, num_images: int): num_tokens = get_max_mllama_image_tokens(ctx) * num_images - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [MLLAMA_IMAGE_TOKEN_ID]) * num_tokens - return SequenceData(token_ids) + + return SequenceData.from_prompt_token_counts( + (MLLAMA_IMAGE_TOKEN_ID, num_tokens)) def dummy_image(num_images: int, ): diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index b04916f17088..b2f0f5ea6953 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -23,7 +23,8 @@ 
get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather) -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -945,9 +946,9 @@ def pad_images( return images, image_input_idx, image_masks -def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): - prompt = llm_inputs.get("prompt", None) - multi_modal_data = llm_inputs.get("multi_modal_data", None) +def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): + prompt = inputs.get("prompt", None) + multi_modal_data = inputs.get("multi_modal_data", None) if multi_modal_data is not None: image = multi_modal_data.get("image", None) else: @@ -965,9 +966,7 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): elif prompt is not None: out = processor.process(prompt, image) else: - out = processor.process(None, - image, - tokens=llm_inputs["prompt_token_ids"]) + out = processor.process(None, image, tokens=inputs["prompt_token_ids"]) image_processor = processor.image_processor max_total_crops = 1 + image_processor.max_crops @@ -1020,9 +1019,9 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data = dict(image=image_data) - return LLMInputs( + return token_inputs( prompt_token_ids=out["input_ids"], - prompt=llm_inputs["prompt"], + prompt=inputs["prompt"], multi_modal_data=multi_modal_data, ) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 99d000ea13a2..7806cd6ab460 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -7,7 +7,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, 
MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -68,7 +69,8 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, return seq_data, mm_data -def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_paligemma(ctx: InputContext, + inputs: DecoderOnlyInputs): """ The correct prompt format needs to be: @@ -77,9 +79,9 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55 """ # noqa - multi_modal_data = llm_inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config hf_config = ctx.get_hf_config(PaliGemmaConfig) @@ -91,8 +93,8 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): image_token_str_pad = image_token_str * image_feature_size image_token_ids_pad = [hf_config.image_token_index] * image_feature_size - orig_prompt = llm_inputs.get("prompt") - orig_prompt_ids = llm_inputs.get("prompt_token_ids") + orig_prompt = inputs.get("prompt") + orig_prompt_ids = inputs.get("prompt_token_ids") if orig_prompt is not None and image_token_str in orig_prompt: logger.warning( @@ -106,9 +108,9 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - 
multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) class PaliGemmaMultiModalProjector(nn.Module): diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index bcd5cd2154e6..91c14e32c946 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -27,7 +27,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig @@ -410,12 +411,12 @@ def _get_image_placeholder_token_id_candidates( def input_processor_for_phi3v(ctx: InputContext, - llm_inputs: LLMInputs, + inputs: DecoderOnlyInputs, *, num_crops: Optional[int] = None): - multi_modal_data = llm_inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs model_config = ctx.model_config hf_config = ctx.get_hf_image_processor_config() @@ -447,7 +448,7 @@ def input_processor_for_phi3v(ctx: InputContext, else: raise TypeError(f"Invalid image type: {type(image_data)}") - prompt = llm_inputs.get("prompt") + prompt = inputs.get("prompt") if prompt is None: # for async server request, we assume prompt and its token_ids is always # in correct format. And num_image_tags == len(image_data) always True. 
@@ -464,7 +465,7 @@ def input_processor_for_phi3v(ctx: InputContext, image_data), "The count of image_placeholder not match image's" new_prompt = prompt - prompt_token_ids = llm_inputs["prompt_token_ids"].copy() + prompt_token_ids = inputs["prompt_token_ids"].copy() print("prompt_token_ids (old)", prompt_token_ids) @@ -506,10 +507,9 @@ def input_processor_for_phi3v(ctx: InputContext, new_token_ids.append(token_id) # NOTE: Create a defensive copy of the original inputs - llm_inputs = LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) - return llm_inputs + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) @MULTIMODAL_REGISTRY.register_image_input_mapper() diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index c8957dcae6b1..f34d21fdef56 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -14,7 +14,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -62,7 +62,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, image_feature_size = (size**2) // (patch_size**2) num_image_tokens = image_feature_size * num_images - seq_data = SequenceData.from_token_counts( + seq_data = SequenceData.from_prompt_token_counts( (image_token_id, num_image_tokens), (0, seq_len - num_image_tokens), ) @@ -102,8 +102,8 @@ def input_mapper_for_pixtral(ctx: InputContext, return MultiModalInputs({"images": images}) -def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = 
llm_inputs.get("multi_modal_data") +def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is not None and "image" in multi_modal_data: tokenizer = cached_get_tokenizer( ctx.model_config.tokenizer, @@ -112,15 +112,15 @@ def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs): mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder image_token_id = mm_encoder.special_ids.img - if image_token_id not in llm_inputs['prompt_token_ids']: + if image_token_id not in inputs['prompt_token_ids']: raise ValueError( - (f"You've passed {llm_inputs=} without {image_token_id=}" + (f"You've passed {inputs=} without {image_token_id=}" " Make sure to process your input via mistral_common's" " tokenizer or pass a chat completion request. For more" " For more info, see: " "https://github.com/vllm-project/vllm/issues/8411.")) - return llm_inputs + return inputs @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index fd8a27eec3b9..cd3f7c1b6c4d 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -22,7 +22,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm @@ -652,30 +653,30 @@ def get_image_text(image_num: int, padding: bool) -> str: def input_processor_for_qwen(ctx: InputContext, - llm_inputs: LLMInputs) -> LLMInputs: + inputs: DecoderOnlyInputs) -> 
DecoderOnlyInputs: """Processes the inputs, which may or may not be multimodal. Multimodal inputs will only be processed if the model has a "visual" component in its model config, otherwise they'll be ignored. Args: ctx: Context of the loaded model. - llm_inputs: LLM inputs which may have a multi_modal_data attribute. + inputs: LLM inputs which may have a multi_modal_data attribute. Returns: If the model is language only or not multimodal inputs were provided, - returns llm_inputs unmodified. Otherwise, processes the multimodal + returns inputs unmodified. Otherwise, processes the multimodal images / image embeddings and adds the fixed-length image placeholders. """ - multi_modal_data = llm_inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data") # Only process images if we have multimodal data and a visual config hf_config = ctx.get_hf_config() if (multi_modal_data is None or "image" not in multi_modal_data or not hasattr(hf_config, "visual")): - return llm_inputs + return inputs - prompt = llm_inputs.get("prompt") - prompt_token_ids = llm_inputs["prompt_token_ids"] + prompt = inputs.get("prompt") + prompt_token_ids = inputs["prompt_token_ids"] model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, @@ -713,9 +714,9 @@ def input_processor_for_qwen(ctx: InputContext, new_prompt_token_ids = tokenizer.encode(new_prompt) - return LLMInputs(prompt=new_prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=multi_modal_data) + return token_inputs(prompt=new_prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data) def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: @@ -822,7 +823,7 @@ def dummy_data_for_qwen( # The presence of a visual config indicates this is a multimodal model. # If we don't have it, the model is considered an LLM for warmup purposes. 
if not hasattr(hf_config, "visual"): - seq_data = SequenceData.from_token_counts((0, seq_len)) + seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) mm_data = None return seq_data, mm_data diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index bdc21df8b656..94c7d6507770 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -46,7 +46,8 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import utils as dist_utils -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU @@ -716,7 +717,7 @@ def dummy_data_for_qwen2_vl( hf_config = ctx.get_hf_config(Qwen2VLConfig) - dummy_seqdata = SequenceData.from_token_counts( + dummy_seqdata = SequenceData.from_prompt_token_counts( (hf_config.vision_start_token_id, 1), (hf_config.image_token_id, max_llm_image_tokens), (hf_config.vision_end_token_id, 1), @@ -799,11 +800,13 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, return prompt_token_ids_with_data -def input_processor_for_qwen2_vl(ctx: InputContext, - llm_inputs: LLMInputs) -> LLMInputs: - multi_modal_data = llm_inputs.get("multi_modal_data", None) +def input_processor_for_qwen2_vl( + ctx: InputContext, + inputs: DecoderOnlyInputs, +) -> DecoderOnlyInputs: + multi_modal_data = inputs.get("multi_modal_data", None) if multi_modal_data is None: - return llm_inputs + return inputs image_inputs = multi_modal_data.get("image", None) video_inputs = multi_modal_data.get("video", None) @@ -817,7 +820,7 @@ def input_processor_for_qwen2_vl(ctx: InputContext, # `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`. 
# # The following code is equivalent to: - # prompt = llm_inputs["prompt"] + # prompt = inputs["prompt"] # inputs = processor(text=[prompt], # images=image_inputs, # videos=video_inputs, @@ -825,9 +828,9 @@ def input_processor_for_qwen2_vl(ctx: InputContext, # return_tensors="pt") # prompt_token_ids = inputs["input_ids"][0].tolist() - prompt_token_ids = llm_inputs.get("prompt_token_ids", None) + prompt_token_ids = inputs.get("prompt_token_ids", None) if prompt_token_ids is None: - prompt = llm_inputs["prompt"] + prompt = inputs["prompt"] prompt_token_ids = processor.tokenizer( prompt, padding=True, @@ -868,9 +871,9 @@ def input_processor_for_qwen2_vl(ctx: InputContext, image_processor, prompt_token_ids) - return LLMInputs( + return token_inputs( prompt_token_ids=prompt_token_ids, - prompt=llm_inputs["prompt"], + prompt=inputs["prompt"], multi_modal_data=multi_modal_data, ) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 743a81f8f9e9..e717ab108c77 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -13,7 +13,7 @@ from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.inputs import LLMInputs +from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -67,7 +67,7 @@ def dummy_seq_data_for_siglip( else: image_feature_size = image_feature_size_override - return SequenceData.from_token_counts( + return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), ) @@ -111,14 +111,14 @@ def dummy_video_for_siglip( def input_processor_for_siglip( model_config: ModelConfig, hf_config: SiglipVisionConfig, - llm_inputs: LLMInputs, + inputs: DecoderOnlyInputs, *, image_token_id: int, 
image_feature_size_override: Optional[Union[int, List[int]]] = None, ): - multi_modal_data = llm_inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs tokenizer = cached_get_tokenizer(model_config.tokenizer) @@ -135,14 +135,14 @@ def input_processor_for_siglip( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) # NOTE: Create a defensive copy of the original inputs - return LLMInputs( + return token_inputs( prompt_token_ids=new_token_ids, prompt=new_prompt, multi_modal_data=multi_modal_data, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index e162e3af008e..49c32cbeaa36 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -18,7 +18,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY -from vllm.inputs.data import LLMInputs +from vllm.inputs.data import DecoderOnlyInputs, token_inputs from vllm.inputs.registry import InputContext from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm @@ -156,10 +156,10 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): return MultiModalInputs({"audio_features": audio_features}) -def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") +def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "audio" not in multi_modal_data: - return llm_inputs + return 
inputs feature_extractor = whisper_feature_extractor(ctx) audios = multi_modal_data["audio"] @@ -196,16 +196,16 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], + inputs.get("prompt"), + inputs["prompt_token_ids"], placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN, repeat_count=audio_token_counts, ) # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) class StackAudioFrames(nn.Module): diff --git a/vllm/sequence.py b/vllm/sequence.py index 728445cb4b54..03f774df1693 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -13,8 +13,7 @@ import msgspec import torch -from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs -from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs +from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -22,6 +21,7 @@ from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if TYPE_CHECKING: + from vllm.inputs import SingletonInputs from vllm.multimodal.base import MultiModalDataDict VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -29,6 +29,11 @@ VLLM_INVALID_TOKEN_ID = -1 +def array_full(token_id: int, count: int): + """:class:`array` equivalent of :func:`numpy.full`.""" + return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count + + # We use dataclass for now because it is used for # openai server output, and msgspec is not serializable. # TODO(sang): Fix it. 
@@ -173,22 +178,34 @@ class SequenceData(msgspec.Struct, _mrope_position_delta: Optional[int] = None @staticmethod - def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData": + def from_prompt_token_counts( + *token_counts: Tuple[int, int]) -> "SequenceData": + """ + Construct a :class:`SequenceData` instance by concatenating + prompt token sequences. + + Each tuple represents one token sequence, expressed in the form + :code:`(token_id, count)`. + """ if len(token_counts) == 0: return SequenceData.from_seqs([]) - arrs = [ - array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count - for token_id, count in token_counts - ] + prompt_token_ids_arr = reduce( + array.__iadd__, + (array_full(token_id, count) for token_id, count in token_counts), + ) - return SequenceData(reduce(array.__add__, arrs)) + return SequenceData(prompt_token_ids_arr) @staticmethod def from_seqs( prompt_token_ids: GenericSequence[int], output_token_ids: Optional[GenericSequence[int]] = None, ) -> "SequenceData": + """ + Construct a :class:`SequenceData` instance from prompt and output + token sequences. + """ prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids) @@ -362,14 +379,14 @@ def __repr__(self) -> str: class Sequence: """Stores the data, status, and block information of a sequence. - The sequence is constructed from the LLMInputs instance passed - in through the `inputs` constructor argument. + The sequence is constructed from the :code:`SingletonInputs` instance + passed in through the :code:`inputs` constructor argument. - For encoder/decoder models, LLMInputs encapsulates both a + For encoder/decoder models, SingletonInputs encapsulates both a decoder and encoder prompt, creating an ambiguity about which prompt to construct the sequence from. The `from_decoder_prompt` constructor argument signals whether to construct the Sequence - from the LLMInputs decoder prompt, or encoder prompt. + from the SingletonInputs decoder prompt, or encoder prompt. 
     Args:
         seq_id: The ID of the sequence.
@@ -379,16 +396,16 @@ class Sequence:
         eos_token_id: The end-of-sequence (EOS) token id
                       recognized by this LLM.
         lora_request: LoRA request.
         prompt_adapter_request: Prompt Adapter request.
-        from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
-                             (True) or encoder prompt (False.) Must be True
-                             for decoder-only model.
+        from_decoder_prompt: Construct Sequence from SingletonInputs decoder
+                             prompt (True) or encoder prompt (False). Must be
+                             True for decoder-only model.
     """
 
     def __init__(
         self,
         seq_id: int,
-        inputs: "LLMInputs",
+        inputs: "SingletonInputs",
         block_size: int,
         eos_token_id: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
@@ -404,19 +421,19 @@ def __init__(
         self.from_decoder_prompt = from_decoder_prompt
 
         # For decoder-only models, a Sequence is constructed
-        # from an LLMInputs instance (the `inputs` arg.)
+        # from a SingletonInputs instance (the `inputs` arg.)
         #
         # For encoder/decoder models the same `inputs`
         # instance could be utilized to construct either an
         # encoder sequence or a decoder sequence, because
-        # `LLMInputs` has both decoder- and encoder-oriented
+        # `SingletonInputs` has both decoder- and encoder-oriented
         # member variables (i.e. it encapsulates both an encoder
         # and a decoder prompt.) The decision of which type of sequence
         # to generate is determined by the `from_decoder_prompt` argument.
         #
         # When constructing a encoder sequence
         # (`from_decoder_prompt` False) it matters that
-        # the `LLMInputs` instance stored in `inputs` is valid
+        # the `SingletonInputs` instance stored in `inputs` is valid
        # in the sense that its encoder-related member variables are
         # populated; below, an exception is raised if this is
         # not the case.
@@ -424,8 +441,7 @@ def __init__(
         # When constructing a decoder sequence (`from_decoder_prompt` True)
         # it does not matter whether `inputs` has its encoder-related
         # member variables populated.
- if not (from_decoder_prompt - or is_valid_encoder_decoder_llm_inputs(inputs)): + if not (from_decoder_prompt or is_encoder_decoder_inputs(inputs)): raise ValueError("Cannot extract encoder input prompt from " f"invalid input {inputs}; did you forget the " "encoder input prompt fields?") @@ -471,15 +487,19 @@ def prompt_token_ids(self) -> List[int]: @property def multi_modal_data(self) -> "MultiModalDataDict": - if self.inputs.get("multi_modal_data") and self.inputs.get( - "encoder_multi_modal_data"): + inputs = self.inputs + + if (inputs.get("multi_modal_data") + and inputs.get("encoder_multi_modal_data")): raise ValueError( "Multi-modal data in both encoder and decoder is not supported." ) - inputs = self.inputs - return self.inputs.get("multi_modal_data") or (cast( - EncoderDecoderLLMInputs, - inputs).get("encoder_multi_modal_data")) or {} + + return cast( + "MultiModalDataDict", + (inputs.get("multi_modal_data") + or inputs.get("encoder_multi_modal_data") or {}), + ) @property def mm_processor_kwargs(self) -> Dict[str, Any]: