diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index cf85a2135c81..97d03d5e3b40 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1596,7 +1596,6 @@ def schedule( multi_modal_placeholders=( seq_group.multi_modal_placeholders if scheduler_outputs.num_prefill_groups > 0 else None), - mm_processor_kwargs=seq_group.mm_processor_kwargs, prompt_adapter_request=seq_group.prompt_adapter_request, ) else: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index ba0cdee461a2..6cc9b881464e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -493,12 +493,11 @@ async def add_request_async( tokenizer = await self.get_tokenizer_async(lora_request) self._validate_token_prompt(prompt, tokenizer=tokenizer) - preprocessed_inputs = await self.input_preprocessor.preprocess_async( + processed_inputs = await self.input_preprocessor.preprocess_async( prompt, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) - processed_inputs = self.input_processor(preprocessed_inputs) if isinstance(params, SamplingParams) and \ params.guided_decoding is not None: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 276891489836..c23530990611 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -29,8 +29,7 @@ from vllm.entrypoints.openai.logits_processors import ( get_logits_processors as get_openai_logits_processors) from vllm.executor.executor_base import ExecutorBase -from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, - PromptType, SingletonInputs) +from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger @@ -213,7 +212,6 @@ def __init__( log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, ) -> None: @@ -274,11 +272,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.tokenizer, mm_registry) - self.input_registry = input_registry - self.input_processor = input_registry.create_input_processor( - self.model_config) - - self.model_executor = executor_class(vllm_config=vllm_config, ) + self.model_executor = executor_class(vllm_config=vllm_config) if self.model_config.runner_type != "pooling": self._initialize_kv_caches() @@ -762,12 +756,11 @@ def add_request( prompt, tokenizer=self.get_tokenizer(lora_request=lora_request)) - preprocessed_inputs = self.input_preprocessor.preprocess( + processed_inputs = self.input_preprocessor.preprocess( prompt, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) - processed_inputs = self.input_processor(preprocessed_inputs) self._add_processed_request( request_id=request_id, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 6f8f2cd758f7..ca706e202836 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -2,10 +2,9 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, - SingletonInputs, SingletonInputsAdapter, SingletonPrompt, - TextPrompt, TokenInputs, TokensPrompt, - build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, - token_inputs, zip_enc_dec_prompts) + SingletonInputs, 
SingletonPrompt, TextPrompt, TokenInputs, + TokensPrompt, build_explicit_enc_dec_prompt, + to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) from .registry import (DummyData, InputContext, InputProcessingContext, InputRegistry) @@ -27,7 +26,6 @@ "EncoderDecoderInputs", "ProcessorInputs", "SingletonInputs", - "SingletonInputsAdapter", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 138a8f61107b..970b36bca9be 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,17 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 - from collections.abc import Iterable -from dataclasses import dataclass -from functools import cached_property from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast -import torch -from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never +from typing_extensions import NotRequired, TypedDict, TypeVar if TYPE_CHECKING: - from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs, - MultiModalPlaceholderDict) - from vllm.multimodal.inputs import MultiModalInputs + from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs class TextPrompt(TypedDict): @@ -147,46 +141,11 @@ class TokenInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. """ - multi_modal_data: NotRequired["MultiModalDataDict"] - """ - Optional multi-modal data to pass to the model, - if the model supports it. - """ - - multi_modal_inputs: NotRequired["MultiModalKwargs"] - """ - Optional multi-modal inputs to pass to the model, - if the model supports it. - """ - - multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"] - """ - Placeholder ranges for the multi-modal data. - """ - - multi_modal_hashes: NotRequired[list[str]] - """ - The hashes of the multi-modal data. - """ - - mm_processor_kwargs: NotRequired[dict[str, Any]] - """ - Optional multi-modal processor kwargs to be forwarded to the - multimodal input mapper & processor. Note that if multiple modalities - have registered mappers etc for the model being considered, we attempt - to pass the mm_processor_kwargs to each of them. - """ - def token_inputs( prompt_token_ids: list[int], token_type_ids: Optional[list[int]] = None, prompt: Optional[str] = None, - multi_modal_data: Optional["MultiModalDataDict"] = None, - multi_modal_inputs: Optional["MultiModalKwargs"] = None, - multi_modal_hashes: Optional[list[str]] = None, - multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, - mm_processor_kwargs: Optional[dict[str, Any]] = None, ) -> TokenInputs: """Construct :class:`TokenInputs` from optional values.""" inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) @@ -195,16 +154,6 @@ def token_inputs( inputs["prompt"] = prompt if token_type_ids is not None: inputs["token_type_ids"] = token_type_ids - if multi_modal_data is not None: - inputs["multi_modal_data"] = multi_modal_data - if multi_modal_inputs is not None: - inputs["multi_modal_inputs"] = multi_modal_inputs - if multi_modal_hashes is not None: - inputs["multi_modal_hashes"] = multi_modal_hashes - if multi_modal_placeholders is not None: - inputs["multi_modal_placeholders"] = multi_modal_placeholders - if mm_processor_kwargs is not None: - inputs["mm_processor_kwargs"] = mm_processor_kwargs return inputs @@ -237,112 +186,6 @@ class EncoderDecoderInputs(TypedDict): :class:`vllm.sequence.Sequence`. 
""" - -@dataclass -class SingletonInputsAdapter: - """ - Unified interface to access the components of :class:`SingletonInputs`. - """ - inputs: SingletonInputs - - @cached_property - def prompt(self) -> Optional[str]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return inputs.get("prompt") - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def prompt_token_ids(self) -> list[int]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return inputs.get("prompt_token_ids", []) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def token_type_ids(self) -> list[int]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return inputs.get("token_type_ids", []) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def prompt_embeds(self) -> Optional[torch.Tensor]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return None - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def multi_modal_data(self) -> "MultiModalDataDict": - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("multi_modal_data", {}) - - if inputs["type"] == "multimodal": - return inputs.get("mm_kwargs", {}) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def multi_modal_inputs(self) -> Union[dict, "MultiModalKwargs"]: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("multi_modal_inputs", {}) - - if inputs["type"] == "multimodal": - return inputs.get("mm_kwargs", {}) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def multi_modal_hashes(self) -> list[str]: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("multi_modal_hashes", []) - - if inputs["type"] == "multimodal": - # only the case when we use MultiModalInputs - return inputs.get("mm_hashes", []) # type: ignore[return-value] - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("multi_modal_placeholders", {}) - - if inputs["type"] == "multimodal": - return inputs.get("mm_placeholders", {}) - - assert_never(inputs) # type: ignore[arg-type] - - @cached_property - def mm_processor_kwargs(self) -> dict[str, Any]: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("mm_processor_kwargs", {}) - - if inputs["type"] == "multimodal": - return {} - - assert_never(inputs) # type: ignore[arg-type] - - ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] """ The inputs to :data:`vllm.inputs.InputProcessor`. 
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index a4609290a900..0edb6da06209 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -223,28 +223,6 @@ async def _tokenize_prompt_async( lora_request=lora_request, add_special_tokens=add_special_tokens) - def _can_process_multimodal(self) -> bool: - model_config = self.model_config - - if not model_config.is_multimodal_model: - raise ValueError("Your model does not support multi-modal inputs") - - # Interim measure so we can handle models that have yet to be - # updated to use the new multi-modal processor - can_process_multimodal = self.mm_registry.has_processor(model_config) - if not can_process_multimodal: - from vllm.model_executor.models.registry import _VLLM_MODELS - if not any(arch in _VLLM_MODELS - for arch in model_config.architectures): - logger.warning_once( - "Your model uses the legacy input pipeline, which will be " - "removed in an upcoming release. " - "Please upgrade to the new multi-modal processing pipeline " - "(https://docs.vllm.ai/en/latest/design/mm_processing.html)" - ) - - return can_process_multimodal - def _process_multimodal( self, prompt: Union[str, list[int]], @@ -258,8 +236,7 @@ def _process_multimodal( returning the corresponding token IDs and metadata. """ # At the moment on model (PrithviGeoSpatialMAE) requires to be - # initialized without a tokenizer while using also multi-modal - # input. + # initialized without a tokenizer while using also multi-modal input if not self.tokenizer: tokenizer = object() # Dummy else: @@ -285,8 +262,7 @@ async def _process_multimodal_async( ) -> MultiModalInputs: """Async version of :meth:`_process_multimodal`.""" # At the moment on model (PrithviGeoSpatialMAE) requires to be - # initialized without a tokenizer while using also multi-modal - # input. 
+ # initialized without a tokenizer while using also multi-modal input if not self.tokenizer: tokenizer = object() # Dummy else: @@ -343,7 +319,7 @@ def _prompt_to_llm_inputs( multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") - if multi_modal_data is not None and self._can_process_multimodal(): + if multi_modal_data is not None: return self._process_multimodal( prompt_token_ids, multi_modal_data, @@ -355,8 +331,6 @@ def _prompt_to_llm_inputs( return token_inputs( prompt_token_ids=prompt_token_ids, token_type_ids=token_type_ids, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, ) if parsed["type"] == "text": @@ -366,7 +340,7 @@ def _prompt_to_llm_inputs( multi_modal_data = text_content.get("multi_modal_data") mm_processor_kwargs = text_content.get("mm_processor_kwargs") - if multi_modal_data is not None and self._can_process_multimodal(): + if multi_modal_data is not None: return self._process_multimodal( prompt_text, multi_modal_data, @@ -383,8 +357,6 @@ def _prompt_to_llm_inputs( return token_inputs( prompt=prompt_text, prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, ) assert_never(parsed) @@ -417,7 +389,7 @@ async def _prompt_to_llm_inputs_async( multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") - if multi_modal_data is not None and self._can_process_multimodal(): + if multi_modal_data is not None: return await self._process_multimodal_async( prompt_token_ids, multi_modal_data, @@ -426,11 +398,7 @@ async def _prompt_to_llm_inputs_async( return_mm_hashes=return_mm_hashes, ) - return token_inputs( - prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, - ) + return token_inputs(prompt_token_ids=prompt_token_ids) if parsed["type"] == "text": text_content = parsed["content"] @@ -439,7 +407,7 @@ async def _prompt_to_llm_inputs_async( multi_modal_data = text_content.get("multi_modal_data") mm_processor_kwargs = text_content.get("mm_processor_kwargs") - if multi_modal_data is not None and self._can_process_multimodal(): + if multi_modal_data is not None: return await self._process_multimodal_async( prompt_text, multi_modal_data, @@ -456,8 +424,6 @@ async def _prompt_to_llm_inputs_async( return token_inputs( prompt=prompt_text, prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, ) assert_never(parsed) @@ -594,15 +560,13 @@ def _process_encoder_decoder_prompt( decoder_inputs = self._prompt_to_llm_inputs(decoder_input) # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. - if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + if self.model_config.is_multimodal_model: encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( encoder_inputs, decoder_inputs)) else: inputs = self._prompt_to_llm_inputs(prompt) - if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( @@ -637,15 +601,13 @@ async def _process_encoder_decoder_prompt_async( # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. 
- if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + if self.model_config.is_multimodal_model: encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( encoder_inputs, decoder_inputs)) else: inputs = await self._prompt_to_llm_inputs_async(prompt) - if self.model_config.is_multimodal_model and ( - self._can_process_multimodal()): + if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 0579893e5d76..4c334ab62d3e 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,24 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 - -import functools -from collections import UserDict from collections.abc import Mapping from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, NamedTuple, Optional, - Protocol, Union) +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union -from torch import nn from transformers import BatchFeature, PretrainedConfig, ProcessorMixin -from typing_extensions import TypeVar, assert_never +from typing_extensions import TypeVar -from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, - resolve_mm_processor_kwargs) - -from .data import ProcessorInputs, SingletonInputs -from .parse import split_enc_dec_inputs +from vllm.utils import resolve_mm_processor_kwargs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -26,8 +16,6 @@ MultiModalRegistry) from vllm.sequence import SequenceData -logger = init_logger(__name__) - _T = TypeVar("_T") _C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig) _P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin) @@ -172,142 +160,23 @@ def call_hf_processor( raise RuntimeError(msg) from exc -N = TypeVar("N", bound=type[nn.Module]) - - class DummyData(NamedTuple): - """Dummy data used for profiling.""" + """ + Dummy data used for profiling. + + Note: This is only used in V0. + """ seq_data: "SequenceData" multi_modal_data: Optional["MultiModalDataDict"] = None multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None -class DummyDataFactory(Protocol): - - def __call__( - self, - ctx: InputContext, - seq_len: int, - mm_counts: Mapping[str, int], - **mm_processor_kwargs: Any, - ) -> DummyData: - """ - Create dummy data to be inputted into the model. - - Note: - :data:`InputProcessor` is not applied to the dummy data. - - The :code:`mm_processor_kwargs` are overrides provided at - initialization time to values in the config whose values - may affect the number of tokens per instance. - """ - ... - - -class _MultiModalCounts(UserDict[str, int]): - """ - Wraps `mm_counts` for a more informative error message - when attempting to access a plugin that does not exist. - """ - - def __getitem__(self, key: str) -> int: - try: - return super().__getitem__(key) - except KeyError as exc: - msg = (f"There is no multi-modal plugin with the key: {key}. 
" - f"Available keys: {set(self.keys())}") - raise KeyError(msg) from exc - - -InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs] -"""Preprocess the inputs to the model.""" - - class InputRegistry: """ - A registry to dispatch data processing - according to the target model. + Note: This is only used in V0. """ - def __init__(self) -> None: - self._dummy_factories_by_model_type = \ - ClassRegistry[nn.Module, DummyDataFactory]() - self._dummy_encoder_factories_by_model_type = \ - ClassRegistry[nn.Module, DummyDataFactory]() - self._input_processors_by_model_type = \ - ClassRegistry[nn.Module, InputProcessor]() - - def _default_dummy_data_factory( - self, - ctx: InputContext, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> DummyData: - """ - The default dummy data factory represents the longest possible text - that can be inputted to the model. - - Note: - :data:`InputProcessor` is not applied to the dummy data. - """ - # Avoid circular import - from vllm.sequence import SequenceData - - return DummyData(SequenceData.from_prompt_token_counts((0, seq_len))) - - def register_dummy_data(self, factory: DummyDataFactory): - """ - Register a dummy data factory to a model class. - - During memory profiling, the provided function is invoked to create - dummy data to be inputted into the model. The resulting memory usage - should be an upper bound of what the model would use at inference time. - """ - - def wrapper(model_cls: N) -> N: - if self._dummy_factories_by_model_type.contains(model_cls, - strict=True): - logger.warning( - "Model class %s already has dummy data " - "registered to %s. It is overwritten by the new one.", - model_cls, self) - - self._dummy_factories_by_model_type[model_cls] = factory - - return model_cls - - return wrapper - - def _get_dummy_data_factory(self, model_cls: type[nn.Module]): - return self._dummy_factories_by_model_type \ - .get(model_cls, self._default_dummy_data_factory) - - def register_dummy_encoder_data(self, factory: DummyDataFactory): - """ - Register a dummy encoder data factory to a model class - - This is similar to :meth:`~register_dummy_data`, but for encoder input. - """ - - def wrapper(model_cls: N) -> N: - if self._dummy_encoder_factories_by_model_type.contains( - model_cls, strict=True): - logger.warning( - "Model class %s already has dummy encoder data " - "registered to %s. It is overwritten by the new one.", - model_cls, self) - - self._dummy_encoder_factories_by_model_type[model_cls] = factory - - return model_cls - - return wrapper - - def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]): - return self._dummy_encoder_factories_by_model_type \ - .get(model_cls, self._default_dummy_data_factory) - def dummy_data_for_profiling( self, model_config: "ModelConfig", @@ -319,169 +188,25 @@ def dummy_data_for_profiling( Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. - - Note: - This should be called after - :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`. 
""" # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - from vllm.multimodal import MultiModalKwargs - from vllm.multimodal.profiling import MultiModalProfiler from vllm.sequence import SequenceData - if mm_registry.has_processor(model_config): - processor = mm_registry.create_processor(model_config, - disable_cache=True) - profiler = MultiModalProfiler(processor) - - dummy_data_v1 = (profiler.get_encoder_dummy_data(seq_len) - if is_encoder_data else - profiler.get_decoder_dummy_data(seq_len)) - _seq_data = SequenceData.from_seqs( - dummy_data_v1.prompt_token_ids) # type: ignore[attr-defined] - - dummy_data = DummyData( - seq_data=_seq_data, - multi_modal_data=getattr(dummy_data_v1, "multi_modal_data", - None), - multi_modal_placeholders=getattr(dummy_data_v1, - "multi_modal_placeholders", - None), - ) - else: - model_cls, _ = get_model_architecture(model_config) - if is_encoder_data: - dummy_factory = self._get_dummy_encoder_data_factory(model_cls) - else: - dummy_factory = self._get_dummy_data_factory(model_cls) - mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - dummy_factory, - overrides=model_config.mm_processor_kwargs, - requires_kw_only=False, - allow_var_kwargs=True, - ) - - dummy_data = dummy_factory(InputContext(model_config), seq_len, - _MultiModalCounts(mm_counts), - **mm_processor_kwargs) - - # Having more tokens is over-conservative but otherwise fine - num_tokens = dummy_data.seq_data.prompt_token_ids - if len(num_tokens) < seq_len: - if is_encoder_data: - logger.warning_once( - f"Expected at least {seq_len} dummy encoder tokens for " - f"profiling, but found {len(num_tokens)} tokens instead.") - else: - raise AssertionError( - f"Expected at least {seq_len} dummy tokens for profiling, " - f"but found {len(num_tokens)} tokens instead.") - - if (dummy_data.multi_modal_data is not None and - not isinstance(dummy_data.multi_modal_data, MultiModalKwargs)): - for k, v in dummy_data.multi_modal_data.items(): - num_items = len(v) if isinstance(v, list) else 1 - num_expected = mm_counts[k] - assert num_items >= num_expected, ( - f"Expected at least {num_expected} dummy '{k}' instances " - f"for profiling, but found {num_items} instances instead.") - - return dummy_data - - def _default_input_processor( - self, - ctx: InputContext, - inputs: ProcessorInputs, - **kwargs: object, - ) -> ProcessorInputs: - """The default input processor is a no-op.""" - return inputs - - def register_input_processor(self, processor: InputProcessor): - """ - Register an input processor to a model class. - - The provided function is invoked on each input to the model. This - happens before - :meth:`~vllm.multimodal.registry.MultiModalRegistry.map_input`. - """ - - def wrapper(model_cls: N) -> N: - if self._input_processors_by_model_type.contains(model_cls, - strict=True): - logger.warning( - "Model class %s already has input processor " - "registered to %s. 
It is overwritten by the new one.", - model_cls, self) - - self._input_processors_by_model_type[model_cls] = processor - - return model_cls + if not model_config.is_multimodal_model: + seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) + return DummyData(seq_data=seq_data) - return wrapper + # Encoder dummy data does not contain multi-modal data + if is_encoder_data: + enc_data = mm_registry.get_encoder_dummy_data( + model_config, seq_len) + seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) + return DummyData(seq_data=seq_data) - def _get_model_input_processor(self, model_cls: type[nn.Module]): - return self._input_processors_by_model_type \ - .get(model_cls, self._default_input_processor) - - def _ensure_mm_kwargs( - self, - inputs: SingletonInputs, - mm_processor_kwargs: dict[str, Any], - ): - if inputs["type"] == "token": - # In case the input processor for that model fails to set it - if "mm_processor_kwargs" not in inputs: - inputs["mm_processor_kwargs"] = mm_processor_kwargs - elif inputs["type"] == "multimodal": - # Be more strict in V2 - assert "mm_kwargs" in inputs - else: - assert_never(inputs["type"]) # type: ignore[arg-type] - - def process_input(self, model_config: "ModelConfig", - inputs: ProcessorInputs) -> ProcessorInputs: - """ - Apply an input processor to an instance of model inputs. - - The model is identified by ``model_config``. - """ - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - - model_cls, _ = get_model_architecture(model_config) - processor = self._get_model_input_processor(model_cls) - - # Handle multimodal processor kwargs with priority: - # Inference kwargs -> Init kwargs -> {} - # If it's empty, it'll fall back to the default kwarg values - mm_processor_kwargs = resolve_mm_processor_kwargs( - model_config.mm_processor_kwargs, - inputs.get("mm_processor_kwargs", {}), # type: ignore - processor, - requires_kw_only=False, - allow_var_kwargs=True, - ) + dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len) - processed_inputs = processor( - InputContext(model_config), - inputs, - **mm_processor_kwargs, + return DummyData( + seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), + multi_modal_data=dec_data.multi_modal_data, + multi_modal_placeholders=dec_data.multi_modal_placeholders, ) - - encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) - if encoder_inputs is not None: - self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs) - if decoder_inputs is not None: - self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs) - - return processed_inputs - - def create_input_processor(self, model_config: "ModelConfig"): - """ - Create an input processor (see :meth:`_process_input`) for a - specific model. 
- """ - return functools.partial(self.process_input, model_config) diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 741bd1a6a1c1..c65d9407dcd1 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 - -from .base import MultiModalPlaceholderMap, MultiModalPlugin +from .base import MultiModalPlaceholderMap from .hasher import MultiModalHashDict, MultiModalHasher from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, MultiModalDataDict, MultiModalKwargs, @@ -26,7 +25,6 @@ "MultiModalKwargs", "MultiModalPlaceholderDict", "MultiModalPlaceholderMap", - "MultiModalPlugin", "NestedTensors", "MULTIMODAL_REGISTRY", "MultiModalRegistry", diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index 70a912c9c9ef..1fd2ab7f87d1 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -7,11 +7,9 @@ import numpy as np import numpy.typing as npt -from vllm.inputs.registry import InputContext from vllm.utils import PlaceholderModule -from .base import MediaIO, MultiModalPlugin -from .inputs import AudioItem, ModalityData, MultiModalKwargs +from .base import MediaIO try: import librosa @@ -24,25 +22,6 @@ soundfile = PlaceholderModule("soundfile") # type: ignore[assignment] -class AudioPlugin(MultiModalPlugin): - """Plugin for audio data.""" - - def get_data_key(self) -> str: - return "audio" - - def _default_input_mapper( - self, - ctx: InputContext, - data: ModalityData[AudioItem], - **mm_processor_kwargs, - ) -> MultiModalKwargs: - raise NotImplementedError("There is no default audio input mapper") - - def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: - raise NotImplementedError( - "There is no default maximum multimodal tokens") - - def resample_audio_librosa( audio: npt.NDArray[np.floating], *, diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index ad95b982499c..2f93922fcedb 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,247 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from collections import defaultdict from collections.abc import Sequence from pathlib import Path -from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple, - Optional, TypeVar, Union) - -from torch import nn - -from vllm.inputs import InputContext -from vllm.logger import init_logger -from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, - resolve_mm_processor_kwargs) +from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar if TYPE_CHECKING: - from vllm.config import ModelConfig from vllm.sequence import SequenceGroupMetadata -from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs, - PlaceholderRange) - -logger = init_logger(__name__) - -MultiModalInputMapper = Callable[[InputContext, ModalityData[object]], - MultiModalKwargs] -""" -Return a dictionary to be passed as keyword arguments to -:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers -and processors in HuggingFace Transformers. - -If the data is not supported, throw :exc:`TypeError`. -""" - -MultiModalTokensCalc = Union[int, Callable[[InputContext], int]] -""" -Calculate the maximum number of multimodal tokens input to the language -model. This does not include tokens that correspond to the input text. 
-""" +from .inputs import MultiModalKwargs, PlaceholderRange _T = TypeVar("_T") -N = TypeVar("N", bound=type[nn.Module]) - - -class MultiModalPlugin(ABC): - """ - Base class that defines data processing logic for a specific modality. - - In particular, we adopt a registry pattern to dispatch data processing - according to the model being used (considering that different models may - process the same data differently). This registry is in turn used by - :class:`~MultiModalRegistry` which acts at a higher level - (i.e., the modality of the data). - """ - - def __init__(self) -> None: - self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]() - self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]() - - @abstractmethod - def get_data_key(self) -> str: - """ - Get the data key corresponding to the modality. - """ - raise NotImplementedError - - @abstractmethod - def _default_input_mapper( - self, - ctx: InputContext, - data: ModalityData[Any], - **mm_processor_kwargs, - ) -> MultiModalKwargs: - """ - Return a dictionary to be passed as keyword arguments to - :meth:`~torch.nn.Module.forward`. This is similar in concept to - tokenizers and processors in HuggingFace Transformers. - - If the data is not supported, throw :exc:`TypeError`. - """ - raise NotImplementedError - - def register_input_mapper( - self, - mapper: Optional[MultiModalInputMapper] = None, - ): - """ - Register an input mapper to a model class. - - When the model receives input data that matches the modality served by - this plugin (see :meth:`get_data_key`), the provided function is - invoked to transform the data into a dictionary of model inputs. - - If `None` is provided, then the default input mapper is used instead. - """ - - def wrapper(model_cls: N) -> N: - if self._input_mappers.contains(model_cls, strict=True): - logger.warning( - "Model class %s already has an input mapper " - "registered to %s. It is overwritten by the new one.", - model_cls, - self, - ) - - self._input_mappers[model_cls] = (mapper - or self._default_input_mapper) - - return model_cls - - return wrapper - - def map_input( - self, - model_config: "ModelConfig", - data: ModalityData[Any], - mm_processor_kwargs: Optional[dict[str, Any]], - ) -> MultiModalKwargs: - """ - Transform the data into a dictionary of model inputs using the - input mapper registered for that model. - - The model is identified by ``model_config``. - - Raises: - TypeError: If the data type is not supported. - """ - - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - - model_cls, _ = get_model_architecture(model_config) - - mapper = self._input_mappers.get(model_cls) - - if mapper is None: - raise KeyError(f"No input mapper in {self} is registered for " - f"model class {model_cls.__name__}.") - - if mm_processor_kwargs is None: - mm_processor_kwargs = {} - - # In the case of the default mapper, we have to get resource - # processor through its HuggingFace autoclass; since this goes - # through **kwargs, we can't inspect it the same way, so we allow - # drop mm_processor_kwargs based on signature inspection - # if we're using the default mapper. - # - # This should be safe in general due to the sanitation, since the - # transformers resource should filter unused kwargs anyway. 
- uses_default_mapper = mapper == self._default_input_mapper - mm_processor_kwargs = resolve_mm_processor_kwargs( - model_config.mm_processor_kwargs, - mm_processor_kwargs, - callable=mapper, - allow_var_kwargs=uses_default_mapper, - ) - return mapper(InputContext(model_config), data, **mm_processor_kwargs) - - @abstractmethod - def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: - """ - Calculate the maximum number of tokens, corresponding to a single - instance of multimodal data, that are passed to the language model. - """ - raise NotImplementedError - - def _validate_max_multimodal_tokens(self, max_mm_tokens: int): - if max_mm_tokens < 1: - raise ValueError("You should set the number of tokens to a " - f"positive integer. Found: {max_mm_tokens}") - - def register_max_multimodal_tokens( - self, - max_mm_tokens: Optional[MultiModalTokensCalc] = None, - ): - """ - Register the maximum number of tokens, corresponding to a single - instance of multimodal data, that are passed to the language model - for a model class. - - If `None` is provided, then the default calculation is used instead. - """ - - def wrapper(model_cls: N) -> N: - if self._max_mm_tokens.contains(model_cls, strict=True): - logger.warning( - "Model class %s already calculates maximum number of " - "tokens in %s. It is overwritten by the new one.", - model_cls, - self, - ) - - if isinstance(max_mm_tokens, int): - self._validate_max_multimodal_tokens(max_mm_tokens) - - self._max_mm_tokens[model_cls] = ( - max_mm_tokens or self._default_max_multimodal_tokens) - - return model_cls - - return wrapper - - def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: - """ - Get the maximum number of multi-modal tokens - for profiling the memory usage of a model. - - If this registry is not applicable to the model, `0` is returned. - - The model is identified by ``model_config``. - """ - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - from vllm.model_executor.models import supports_multimodal - - model_cls, _ = get_model_architecture(model_config) - - if not supports_multimodal(model_cls): - return 0 - - max_mm_tokens = self._max_mm_tokens.get(model_cls) - if max_mm_tokens is None: - return 0 - - if callable(max_mm_tokens): - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - max_mm_tokens, - overrides=model_config.mm_processor_kwargs, - requires_kw_only=False, - allow_var_kwargs=True, - ) - max_mm_tokens = max_mm_tokens(InputContext(model_config), - **mm_processor_kwargs) - - self._validate_max_multimodal_tokens(max_mm_tokens) - - return max_mm_tokens class MultiModalPlaceholderMap: """ Relates multi-modal embeddings to their corresponding placeholders. + + Note: This is only used in V0. 
""" class IndexMap(NamedTuple): @@ -279,8 +55,7 @@ def __init__(self): @classmethod def from_seq_group( cls, seq_group: "SequenceGroupMetadata", positions: range - ) -> tuple[Optional[MultiModalDataDict], dict[str, - "MultiModalPlaceholderMap"]]: + ) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]: """ Returns the multi-modal items that intersect with the portion of a prompt (``seq_group``) represented by ``positions``, as well as a @@ -323,48 +98,24 @@ def from_seq_group( seq_mm_placeholders = seq_group.multi_modal_placeholders if not seq_mm_data or not seq_mm_placeholders: - return seq_mm_data, {} - - # For merged processor, we directly use mm_kwargs as mm_data - if isinstance(seq_mm_data, MultiModalKwargs): - placeholder_maps = dict[str, MultiModalPlaceholderMap]() - - for modality, placeholders in seq_mm_placeholders.items(): - placeholder_map = MultiModalPlaceholderMap() + return MultiModalKwargs({}), {} - if positions: - placeholder_map.append_items_from_seq_group( - positions, - # Dummy, since we don't care about intersecting items - [None] * len(placeholders), - placeholders, - ) - - placeholder_maps[modality] = placeholder_map - - return seq_mm_data, placeholder_maps - - mm_data = {**seq_mm_data} - placeholder_maps = defaultdict[str, MultiModalPlaceholderMap]( - MultiModalPlaceholderMap) + placeholder_maps = dict[str, MultiModalPlaceholderMap]() for modality, placeholders in seq_mm_placeholders.items(): - mm_items = mm_data.pop(modality) - if not isinstance(mm_items, list): - mm_items = [mm_items] + placeholder_map = MultiModalPlaceholderMap() if positions: - intersecting_items = placeholder_maps[modality] \ - .append_items_from_seq_group( - positions, - mm_items, - placeholders, - ) + placeholder_map.append_items_from_seq_group( + positions, + # Dummy, since we don't care about intersecting items + [None] * len(placeholders), + placeholders, + ) - if intersecting_items: - mm_data[modality] = intersecting_items + placeholder_maps[modality] = placeholder_map - return mm_data, placeholder_maps + return seq_mm_data, placeholder_maps def append_items_from_seq_group( self, @@ -445,8 +196,7 @@ def index_map(self) -> "IndexMap": f"The number of source ({len(src_indices)}) and destination " f"indices ({len(dest_indices)}) must be the same.") - return MultiModalPlaceholderMap.IndexMap(src=src_indices, - dest=dest_indices) + return self.IndexMap(src=src_indices, dest=dest_indices) class MediaIO(ABC, Generic[_T]): diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 0c5a84c6508a..939928bbf108 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -3,89 +3,11 @@ import base64 from io import BytesIO from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional import torch from PIL import Image -from vllm.inputs.registry import InputContext -from vllm.logger import init_logger -from vllm.transformers_utils.processor import cached_get_image_processor -from vllm.utils import is_list_of - -from .base import MediaIO, MultiModalPlugin -from .inputs import ImageItem, ModalityData, MultiModalKwargs - -if TYPE_CHECKING: - from vllm.config import ModelConfig - -logger = init_logger(__name__) - - -class ImagePlugin(MultiModalPlugin): - """Plugin for image data.""" - - def get_data_key(self) -> str: - return "image" - - def _get_hf_image_processor( - self, - model_config: "ModelConfig", - mm_processor_kwargs: Optional[dict[str, Any]] = None, - ): - if mm_processor_kwargs is None: - mm_processor_kwargs = {} - return 
cached_get_image_processor( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - **mm_processor_kwargs) - - def _default_input_mapper( - self, - ctx: InputContext, - data: ModalityData[ImageItem], - **mm_processor_kwargs, - ) -> MultiModalKwargs: - model_config = ctx.model_config - - # PIL image - if isinstance(data, Image.Image) or is_list_of(data, Image.Image): - image_processor = self._get_hf_image_processor( - model_config, - mm_processor_kwargs, - ) - - if image_processor is None: - raise RuntimeError("No HuggingFace processor is available " - "to process the image object") - try: - # NOTE: It may make sense to forward the mm_processor_kwargs - # here too. For now, to keep it simple, we only allow it be - # used for the initialization call though, just in case the - # signatures of the preprocessor initializer don't match - # preprocess() - batch_data = image_processor \ - .preprocess(data, return_tensors="pt") \ - .data - except Exception: - logger.error( - "Failed to process image (%s) with the default mapper. " - "This is most likely an edge-case with this model's image " - "processor in transformers (type: %s), and not vLLM.", - data, - type(image_processor).__name__) - raise - - return MultiModalKwargs(batch_data) - - # Image embedding - elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor): - return MultiModalKwargs({"image_embeds": data}) - - raise TypeError(f"Invalid image type: {type(data)}") - - def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: - return 3000 +from .base import MediaIO def rescale_image_size(image: Image.Image, diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 5c687e49d22b..ec4f15681019 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,13 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 - -import functools -import json -from collections import UserDict -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar +from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar import torch.nn as nn +from typing_extensions import deprecated from vllm.envs import VLLM_MM_INPUT_CACHE_GIB from vllm.inputs import InputProcessingContext @@ -16,15 +13,10 @@ cached_tokenizer_from_config) from vllm.utils import ClassRegistry -from .audio import AudioPlugin -from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc -from .image import ImagePlugin -from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, ProcessingCache) from .profiling import (BaseDummyInputsBuilder, DummyDecoderData, DummyEncoderData, MultiModalProfiler) -from .video import VideoPlugin if TYPE_CHECKING: from vllm.config import ModelConfig @@ -85,169 +77,23 @@ def build_processor( return self.processor(info, dummy_inputs_builder, cache=cache) -class _MultiModalLimits(UserDict["ModelConfig", dict[str, int]]): - """ - Wraps `_limits_by_model` for a more informative error message - when attempting to access a model that does not exist. - """ - - def __getitem__(self, key: "ModelConfig") -> dict[str, int]: - try: - return super().__getitem__(key) - except KeyError as exc: - msg = (f"Cannot find `mm_limits` for model={key.model}. 
Did you " - "forget to call `init_mm_limits_per_prompt`?") - raise KeyError(msg) from exc - - class MultiModalRegistry: """ A registry that dispatches data processing according to the model. """ - DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin()) - - def __init__( - self, - *, - plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: - self._plugins = {p.get_data_key(): p for p in plugins} - + def __init__(self) -> None: self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]() - # This is used for non-multimodal models - self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} - - self._limits_by_model = _MultiModalLimits() - self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB) - def register_plugin(self, plugin: MultiModalPlugin) -> None: - """ - Register a multi-modal plugin so it can be recognized by vLLM. - """ - data_type_key = plugin.get_data_key() - - if data_type_key in self._plugins: - logger.warning( - "A plugin is already registered for data type %s, " - "and will be overwritten by the new plugin %s.", data_type_key, - plugin) - - self._plugins[data_type_key] = plugin - - def _get_plugin(self, data_type_key: str): - plugin = self._plugins.get(data_type_key) - if plugin is not None: - return plugin - - msg = f"Unknown multi-modal data type: {data_type_key}" - raise NotImplementedError(msg) - - def register_input_mapper( - self, - data_type_key: str, - mapper: Optional[MultiModalInputMapper] = None, - ): - """ - Register an input mapper for a specific modality to a model class. - - See :meth:`MultiModalPlugin.register_input_mapper` for more details. - """ - return self._get_plugin(data_type_key).register_input_mapper(mapper) - - def register_image_input_mapper( - self, - mapper: Optional[MultiModalInputMapper] = None, - ): - """ - Register an input mapper for image data to a model class. - - See :meth:`MultiModalPlugin.register_input_mapper` for more details. - """ - return self.register_input_mapper("image", mapper) - - def map_input( - self, - model_config: "ModelConfig", - data: MultiModalDataDict, - mm_processor_kwargs: Optional[dict[str, Any]] = None, - ) -> MultiModalKwargs: - """ - Apply an input mapper to the data passed to the model. - - The data belonging to each modality is passed to the corresponding - plugin which in turn converts the data into into keyword arguments - via the input mapper registered for that model. - - See :meth:`MultiModalPlugin.map_input` for more details. - - Note: - This should be called after :meth:`init_mm_limits_per_prompt`. - """ - merged_dict = dict[str, NestedTensors]() - - for data_key, data_value in data.items(): - plugin = self._get_plugin(data_key) - - num_items = len(data_value) if isinstance(data_value, list) else 1 - max_items = self._limits_by_model[model_config][data_key] - if num_items > max_items: - raise ValueError( - f"You set '{json.dumps({data_key: max_items})}' (or " - "defaulted to 1) in `--limit-mm-per-prompt`, but found " - f"{num_items} items in the same prompt.") - - input_dict = plugin.map_input(model_config, data_value, - mm_processor_kwargs) - for input_key, input_tensor in input_dict.items(): - if input_key in merged_dict: - raise ValueError(f"The input mappers (keys={set(data)}) " - f"resulted in a conflicting keyword " - f"argument to `forward()`: {input_key}") - - merged_dict[input_key] = input_tensor - - return MultiModalKwargs(merged_dict) - + @deprecated("Legacy input processor/mapper pipeline has been removed. 
" + "Please update your model runner to use " + "`seq_group_metadata.multi_modal_data` directly without " + "further processing.") def create_input_mapper(self, model_config: "ModelConfig"): - """ - Create an input mapper (see :meth:`map_input`) for a specific model. - """ - # NOTE - we currently make the assumption that if a model has multiple - # supported modalities, they take the same kwargs. For the default, - # this could be an issue in the future if it falls back to two HF - # resources and we can't inspect the signature easily since it's - # getting initialized through the autoclass. - # - # If this is a problem in the future, we should revisit it, but since - # it potentially introduces a lot of complexity for a currently - # uncommon case, we do not for simplicity of both use & implementation - return functools.partial(self.map_input, model_config) - - def register_max_multimodal_tokens( - self, - data_type_key: str, - max_mm_tokens: Optional[MultiModalTokensCalc] = None, - ): - """ - Register the maximum number of tokens, corresponding to a single - instance of multimodal data belonging to a specific modality, that are - passed to the language model for a model class. - """ - return self._get_plugin(data_type_key) \ - .register_max_multimodal_tokens(max_mm_tokens) - - def register_max_image_tokens( - self, - max_mm_tokens: Optional[MultiModalTokensCalc] = None, - ): - """ - Register the maximum number of image tokens, corresponding to a single - image, that are passed to the language model for a model class. - """ - return self.register_max_multimodal_tokens("image", max_mm_tokens) + return lambda data, mm_processor_kwargs: data def get_max_tokens_per_item_by_modality( self, @@ -257,25 +103,22 @@ def get_max_tokens_per_item_by_modality( Get the maximum number of tokens per data item from each modality based on underlying model configuration. """ - if self.has_processor(model_config): - processor = self.create_processor(model_config, disable_cache=True) - profiler = MultiModalProfiler(processor) - - seq_len = model_config.max_model_len - mm_limits = self.get_mm_limits_per_prompt(model_config) - - return profiler.get_mm_max_tokens( - seq_len, - { - modality: 1 - for modality, limit in mm_limits.items() if limit > 0 - }, - ) + if not model_config.is_multimodal_model: + return {} - return { - key: plugin.get_max_multimodal_tokens(model_config) - for key, plugin in self._plugins.items() - } + processor = self.create_processor(model_config, disable_cache=True) + profiler = MultiModalProfiler(processor) + + seq_len = model_config.max_model_len + mm_limits = self.get_mm_limits_per_prompt(model_config) + + return profiler.get_mm_max_tokens( + seq_len, + { + modality: 1 + for modality, limit in mm_limits.items() if limit > 0 + }, + ) def get_max_tokens_per_item_by_nonzero_modality( self, @@ -308,9 +151,6 @@ def get_max_tokens_by_modality( for profiling the memory usage of a model. See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. - - Note: - This should be called after :meth:`init_mm_limits_per_prompt`. """ mm_limits = self.get_mm_limits_per_prompt(model_config) @@ -326,47 +166,18 @@ def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: for profiling the memory usage of a model. See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. - - Note: - This should be called after :meth:`init_mm_limits_per_prompt`. 
""" return sum(self.get_max_tokens_by_modality(model_config).values()) + @deprecated("Legacy input processor/mapper pipeline has been removed. " + "Please update your model runner to use " + "`seq_group_metadata.multi_modal_data` directly without " + "further processing.") def init_mm_limits_per_prompt( self, model_config: "ModelConfig", ) -> None: - """ - Initialize the maximum number of multi-modal input instances for each - modality that are allowed per prompt for a model class. - """ - if model_config in self._limits_by_model: - logger.warning( - "`mm_limits` has already been set for model=%s, and will " - "be overwritten by the new values.", model_config.model) - - multimodal_config = model_config.multimodal_config - if multimodal_config is None: - limits_per_plugin = self._disabled_limits_per_plugin - else: - config_limits_per_plugin = multimodal_config.limit_per_prompt - - extra_keys = config_limits_per_plugin.keys() - self._plugins.keys() - if extra_keys: - logger.warning( - "Detected extra keys in `--limit-mm-per-prompt` which " - "are not registered as multi-modal plugins: %s. " - "They will be ignored.", extra_keys) - - # NOTE: Currently the default is set to 1 for each plugin - # TODO: Automatically determine the limits based on budget - # once more models support multi-image inputs - limits_per_plugin = { - key: multimodal_config.get_limit_per_prompt(key) - for key in self._plugins - } - - self._limits_by_model[model_config] = limits_per_plugin + pass def get_mm_limits_per_prompt( self, @@ -375,16 +186,13 @@ def get_mm_limits_per_prompt( """ Get the maximum number of multi-modal input instances for each modality that are allowed per prompt for a model class. - - Note: - This should be called after :meth:`init_mm_limits_per_prompt`. """ - if self.has_processor(model_config): - processor = self.create_processor(model_config, disable_cache=True) - profiler = MultiModalProfiler(processor) - return profiler.get_mm_limits() + if not model_config.is_multimodal_model: + return {} - return self._limits_by_model[model_config] + processor = self.create_processor(model_config, disable_cache=True) + profiler = MultiModalProfiler(processor) + return profiler.get_mm_limits() def register_processor( self, @@ -428,14 +236,12 @@ def _get_model_cls(self, model_config: "ModelConfig"): model_cls, _ = get_model_architecture(model_config) return model_cls + @deprecated("Legacy input processor/mapper pipeline has been removed. " + "Please update your model runner to use " + "`seq_group_metadata.multi_modal_data` directly without " + "further processing.") def has_processor(self, model_config: "ModelConfig") -> bool: - """ - Test whether a multi-modal processor is defined for a specific model. 
- - See also: - :ref:`mm-processing` - """ - return self._get_model_cls(model_config) in self._processor_factories + return True def create_processor( self, @@ -450,6 +256,9 @@ def create_processor( See also: :ref:`mm-processing` """ + if not model_config.is_multimodal_model: + raise ValueError(f"{model_config.model} is not a multimodal model") + if tokenizer is None: tokenizer = cached_tokenizer_from_config(model_config) if disable_cache is None: diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index f7c3f1052954..6d875a1c651e 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -4,80 +4,13 @@ from functools import partial from io import BytesIO from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional import numpy as np import numpy.typing as npt from PIL import Image -from vllm.inputs.registry import InputContext -from vllm.logger import init_logger -from vllm.transformers_utils.processor import cached_get_video_processor -from vllm.utils import is_list_of - -from .base import MediaIO, ModalityData -from .image import ImageMediaIO, ImagePlugin -from .inputs import MultiModalKwargs, VideoItem - -if TYPE_CHECKING: - from vllm.config import ModelConfig - -logger = init_logger(__name__) - - -class VideoPlugin(ImagePlugin): - """Plugin for video data.""" - - def get_data_key(self) -> str: - return "video" - - def _get_hf_video_processor( - self, - model_config: "ModelConfig", - mm_processor_kwargs: Optional[dict[str, Any]] = None, - ): - if mm_processor_kwargs is None: - mm_processor_kwargs = {} - return cached_get_video_processor( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - **mm_processor_kwargs) - - def _default_input_mapper( - self, - ctx: InputContext, - data: ModalityData[VideoItem], - **mm_processor_kwargs, - ) -> MultiModalKwargs: - model_config = ctx.model_config - - if isinstance(data, list) and len(data) == 1: - data = data[0] # type: ignore - - if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray): - video_processor = self._get_hf_video_processor( - model_config, - mm_processor_kwargs, - ) - if video_processor is None: - raise RuntimeError("No HuggingFace processor is available " - "to process the video object") - try: - # NOTE: Similar to image; it may be a good idea to filter and - # pass mm_processor_kwargs here too, but for now we don't to - # avoid extra complexity if the initializer and preprocess - # signatures of the processor don't align - batch_data = video_processor(data, return_tensors="pt").data - except Exception: - logger.error("Failed to process video (%s)", data) - raise - - return MultiModalKwargs(batch_data) - - raise TypeError(f"Invalid video type: {type(data)}") - - def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: - return 4096 +from .base import MediaIO +from .image import ImageMediaIO def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray: diff --git a/vllm/sequence.py b/vllm/sequence.py index 61867b025315..a97409523c94 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -14,9 +14,9 @@ import msgspec import torch -from vllm.inputs import SingletonInputs, SingletonInputsAdapter +from vllm.inputs import SingletonInputs from vllm.lora.request import LoRARequest -from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict +from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from 
@@ -419,7 +419,7 @@ def __init__(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         self.seq_id = seq_id
-        self.inputs = SingletonInputsAdapter(inputs)
+        self.inputs = inputs
         self.block_size = block_size
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
@@ -448,31 +448,29 @@ def n_blocks(self) -> int:

     @property
     def prompt(self) -> Optional[str]:
-        return self.inputs.prompt
+        return self.inputs.get("prompt")

     @property
     def prompt_token_ids(self) -> list[int]:
-        return self.inputs.prompt_token_ids
-
-    @property
-    def prompt_embeds(self) -> Optional[torch.Tensor]:
-        return self.inputs.prompt_embeds
+        return self.inputs["prompt_token_ids"]

     @property
     def token_type_ids(self) -> list[int]:
-        return self.inputs.token_type_ids
+        return self.inputs.get("token_type_ids", [])

     @property
-    def multi_modal_data(self) -> "MultiModalDataDict":
-        return self.inputs.multi_modal_data
+    def multi_modal_data(self) -> MultiModalKwargs:
+        if self.inputs["type"] == "multimodal":
+            return self.inputs["mm_kwargs"]
+
+        return MultiModalKwargs({})

     @property
     def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
-        return self.inputs.multi_modal_placeholders
+        if self.inputs["type"] == "multimodal":
+            return self.inputs["mm_placeholders"]

-    @property
-    def mm_processor_kwargs(self) -> dict[str, Any]:
-        return self.inputs.mm_processor_kwargs
+        return {}

     @property
     def lora_int_id(self) -> int:
@@ -723,12 +721,12 @@ def token_type_ids(self) -> Optional[list[int]]:
         return self.first_seq.token_type_ids

     @property
-    def multi_modal_data(self) -> MultiModalDataDict:
+    def multi_modal_data(self) -> MultiModalKwargs:
         if self.first_seq.multi_modal_data:
             return self.first_seq.multi_modal_data
         elif self.encoder_seq is not None:
             return self.encoder_seq.multi_modal_data
-        return {}
+        return MultiModalKwargs({})

     @property
     def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
@@ -738,14 +736,6 @@ def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
         if self.first_seq.multi_modal_placeholders:
             return self.first_seq.multi_modal_placeholders
         elif self.encoder_seq is not None:
             return self.encoder_seq.multi_modal_placeholders
         return {}

-    @property
-    def mm_processor_kwargs(self) -> dict[str, Any]:
-        if self.first_seq.multi_modal_data:
-            return self.first_seq.mm_processor_kwargs
-        elif self.encoder_seq is not None:
-            return self.encoder_seq.mm_processor_kwargs
-        return {}
-
     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
@@ -969,12 +959,9 @@ class SequenceGroupMetadata(
     computed_block_nums: Optional[list[int]] = None
     state: Optional[SequenceGroupState] = msgspec.field(
         default_factory=lambda: SequenceGroupState())
-    # "MultiModalDataDict" types. We have to use Any due to msgspec
-    # doesn't allow to have union of 2 different dicts.
     token_type_ids: Optional[list[int]] = None
-    multi_modal_data: Optional[Any] = None
+    multi_modal_data: Optional[MultiModalKwargs] = None
     multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
-    mm_processor_kwargs: Optional[dict[str, Any]] = None
     encoder_seq_data: Optional[SequenceData] = None
     cross_block_table: Optional[list[int]] = None
     prompt_adapter_request: Optional[PromptAdapterRequest] = None
diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py
index ed2f4b076ded..4f06950c42e2 100644
--- a/vllm/transformers_utils/processor.py
+++ b/vllm/transformers_utils/processor.py
@@ -208,38 +208,3 @@ def cached_image_processor_from_config(
         trust_remote_code=model_config.trust_remote_code,
         **_merge_mm_kwargs(model_config, **kwargs),
     )
-
-
-def get_video_processor(
-    processor_name: str,
-    *args: Any,
-    trust_remote_code: bool = False,
-    **kwargs: Any,
-):
-    """Load a video processor for the given model name via HuggingFace."""
-    # don't put this import at the top level
-    # it will call torch.cuda.device_count()
-    from transformers.image_processing_utils import BaseImageProcessor
-
-    processor = get_processor(
-        processor_name,
-        *args,
-        trust_remote_code=trust_remote_code,
-        **kwargs,
-    )
-
-    return cast(BaseImageProcessor, processor.video_processor)
-
-
-cached_get_video_processor = lru_cache(get_video_processor)
-
-
-def cached_video_processor_from_config(
-    model_config: "ModelConfig",
-    **kwargs: Any,
-):
-    return cached_get_video_processor(
-        model_config.model,
-        trust_remote_code=model_config.trust_remote_code,
-        **_merge_mm_kwargs(model_config, **kwargs),
-    )
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 87b7f02ab6db..710ca1a13b0c 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -22,8 +22,8 @@
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_lora, supports_multimodal
-from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
-                             MultiModalKwargs, MultiModalPlaceholderMap)
+from vllm.multimodal import (BatchedTensorInputs, MultiModalKwargs,
+                             MultiModalPlaceholderMap)
 from vllm.sequence import (IntermediateTensors, SequenceData,
                            SequenceGroupMetadata)
 from vllm.worker.model_runner_base import (
@@ -154,7 +154,6 @@ def __init__(self,
         self.sliding_window = self.runner.sliding_window
         self.block_size = self.runner.block_size
         self.device = self.runner.device
-        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
         self.enable_lora = self.runner.lora_config is not None
         if self.runner.attn_backend is not None:
             # spec decode (e.g. Medusa) does not have atten backend
@@ -359,22 +358,14 @@ def _compute_multi_modal_input(self,
         computed_len = seq_data.get_num_computed_tokens()
         seq_len = self.input_data.seq_lens[-1]

-        # NOTE: mm_data only includes the subset of multi-modal items that
+        # NOTE: mm_kwargs only includes the subset of multi-modal items that
         # intersect with the current prefill positions.
-        mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
+        mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
             seq_group_metadata, range(computed_len, seq_len))

-        if not mm_data:
+        if not mm_kwargs:
             return

-        if self.runner.mm_registry.has_processor(self.runner.model_config):
-            mm_kwargs = mm_data
-        else:
-            mm_kwargs = self.multi_modal_input_mapper(
-                mm_data,
-                seq_group_metadata.mm_processor_kwargs,
-            )
-
         # special processing for mrope position deltas.
         if self.runner.model_config.uses_mrope:
             assert not self.chunked_prefill, \
@@ -480,12 +471,6 @@ def __init__(
             use_mla=self.model_config.use_mla,
         ) if needs_attn_backend else None

-        # Multi-modal data support
-        self.mm_registry = MULTIMODAL_REGISTRY
-        self.multi_modal_input_mapper = self.mm_registry \
-            .create_input_mapper(self.model_config)
-        self.mm_registry.init_mm_limits_per_prompt(self.model_config)
-
         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model
         # Set after load_model.
diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py
index dee59041ec6f..4df192a8727c 100644
--- a/vllm/worker/enc_dec_model_runner.py
+++ b/vllm/worker/enc_dec_model_runner.py
@@ -100,6 +100,8 @@ def __init__(
             vllm_config=vllm_config,
             kv_cache_dtype=kv_cache_dtype,
             is_driver_worker=is_driver_worker,
+            input_registry=input_registry,
+            mm_registry=mm_registry,
         )

         # Crash for unsupported encoder/scenarios
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 7b606272adf8..e25864349e28 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -45,8 +45,7 @@
                                                VocabParallelEmbedding)
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SequenceGroupToSample
-from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
-                             MultiModalKwargs)
+from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceData, SequenceGroupMetadata,
@@ -545,10 +544,6 @@ def _set_gc_threshold(self) -> None:
         ]
         gc.set_threshold(*requested_gc_thrs)

-        # Multi-modal data support
-        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
-            .create_input_mapper(self.model_config)
-
         self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP',
                                           'false').lower() == 'true'
@@ -731,9 +726,8 @@ def _prepare_prompt(
                 # is always the first token in the sequence.
                 input_positions.append(list(range(context_len, seq_len)))

-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
-                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+            mm_kwargs = seq_group_metadata.multi_modal_data
+            if mm_kwargs:
                 multi_modal_kwargs_list.append(mm_kwargs)

             if seq_group_metadata.block_tables is None:
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 66b12d5be1a2..73e0eff9a8b7 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -457,7 +457,6 @@ def __init__(self,
         self.enable_lora = self.runner.lora_config is not None
         self.enable_prompt_adapter = (self.runner.prompt_adapter_config
                                       is not None)
-        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper

         # Attention metadata inputs.
         if self.attn_backend is not None:
@@ -675,23 +674,15 @@ def _compute_prompt_adapter_input(
     def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
                                    seq_group_metadata: SequenceGroupMetadata):
         """If multi-modal data is given, add it to the input."""
-        # NOTE: mm_data only includes the subset of multi-modal items that
+        # NOTE: mm_kwargs only includes the subset of multi-modal items that
         # intersect with the current prefill positions.
         positions = inter_data.input_positions[0]
-        mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
+        mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
             seq_group_metadata,
             range(positions[0], positions[0] + len(positions)))
-        if not mm_data:
+        if not mm_kwargs:
             return

-        if self.runner.mm_registry.has_processor(self.runner.model_config):
-            mm_kwargs = mm_data
-        else:
-            mm_kwargs = self.multi_modal_input_mapper(
-                mm_data,
-                seq_group_metadata.mm_processor_kwargs,
-            )
-
         inter_data.multi_modal_kwargs = mm_kwargs
         inter_data.multi_modal_placeholder_maps = placeholder_maps
@@ -1085,9 +1076,6 @@ def __init__(
         # Multi-modal data support
         self.input_registry = input_registry
         self.mm_registry = mm_registry
-        self.multi_modal_input_mapper = mm_registry \
-            .create_input_mapper(model_config)
-        self.mm_registry.init_mm_limits_per_prompt(self.model_config)

         # Lazy initialization
         self.model: nn.Module  # Set after load_model
@@ -1327,8 +1315,8 @@ def _dummy_run(self,
             dummy_data = self.input_registry \
                 .dummy_data_for_profiling(self.model_config,
-                                          seq_len,
-                                          self.mm_registry)
+                                         seq_len,
+                                         self.mm_registry)

             seq = SequenceGroupMetadata(
                 request_id=str(group_id),
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index f2093fc42ad1..e046ebc449de 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -15,8 +15,7 @@
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.neuron import get_neuron_model
-from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
-                             MultiModalKwargs)
+from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
@@ -69,11 +68,6 @@ def __init__(
         self.device = self.device_config.device
         self.pin_memory = is_pin_memory_available()

-        # Multi-modal data support
-        self.mm_registry = MULTIMODAL_REGISTRY
-        self.multi_modal_input_mapper = self.mm_registry \
-            .create_input_mapper(self.model_config)
-
         # Lazy initialization.
         self.model: nn.Module  # initialize after load_model.
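# A minimal sketch (not part of the diff): the pattern the _prepare_prompt /
# _compute_multi_modal_input hunks in this diff converge on. With the legacy
# input mapper removed, seq_group_metadata.multi_modal_data is assumed to
# already hold processed MultiModalKwargs, so runners consume it directly.
from vllm.multimodal import MultiModalKwargs

def gather_mm_kwargs(seq_group_metadata_list):
    multi_modal_kwargs_list = []
    for seq_group_metadata in seq_group_metadata_list:
        mm_kwargs = seq_group_metadata.multi_modal_data
        if mm_kwargs:
            multi_modal_kwargs_list.append(mm_kwargs)
    # Batching helper used by the runners; assumed available as shown here.
    return MultiModalKwargs.batch(multi_modal_kwargs_list)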
@@ -149,16 +143,8 @@ def _prepare_prompt(
             assert len(block_table) == 1
             input_block_ids.append(block_table[0])

-            mm_data = seq_group_metadata.multi_modal_data
-            if mm_data:
-                if self.mm_registry.has_processor(self.model_config):
-                    mm_kwargs = mm_data
-                else:
-                    mm_kwargs = self.multi_modal_input_mapper(
-                        mm_data,
-                        seq_group_metadata.mm_processor_kwargs,
-                    )
-
+            mm_kwargs = seq_group_metadata.multi_modal_data
+            if mm_kwargs:
                 multi_modal_kwargs_list.append(mm_kwargs)

         max_seq_len = max(seq_lens)
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index fc9461a24b90..7042b575aa78 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -188,20 +188,11 @@ def _prepare_prompt(
             input_positions.extend(list(positions_range))

             if seq_group_metadata.multi_modal_data:
-                # NOTE: mm_data only includes the subset of multi-modal items
+                # NOTE: mm_kwargs only includes the subset of multi-modal items
                 # that intersect with the current prefill positions.
-                mm_data, placeholder_maps = MultiModalPlaceholderMap \
+                mm_kwargs, placeholder_maps = MultiModalPlaceholderMap \
                     .from_seq_group(seq_group_metadata, positions_range)

-                if self.runner.mm_registry.has_processor(
-                        self.runner.model_config):
-                    mm_kwargs = mm_data
-                else:
-                    mm_kwargs = self.runner.multi_modal_input_mapper(
-                        mm_data,
-                        seq_group_metadata.mm_processor_kwargs,
-                    )
-
                 multi_modal_kwargs_list.append(mm_kwargs)

                 for modality, placeholder_map in placeholder_maps.items():
@@ -404,9 +395,6 @@ def __init__(
         # Multi-modal data support
         self.input_registry = input_registry
         self.mm_registry = mm_registry
-        self.multi_modal_input_mapper = mm_registry \
-            .create_input_mapper(model_config)
-        self.mm_registry.init_mm_limits_per_prompt(self.model_config)
-
         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model
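# A minimal sketch (not part of the diff): querying per-prompt multi-modal
# limits after this change. `model_config` is assumed to be an existing
# vllm.config.ModelConfig; no call to init_mm_limits_per_prompt() (now a
# deprecated no-op) is required beforehand.
from vllm.multimodal import MULTIMODAL_REGISTRY

def query_mm_limits(model_config):
    # Text-only models return {}; multimodal models are profiled through
    # their registered multi-modal processor on demand.
    return MULTIMODAL_REGISTRY.get_mm_limits_per_prompt(model_config)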