diff --git a/vllm/config.py b/vllm/config.py
index 3eca09b232a7..5f61c2d592c9 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2405,6 +2405,15 @@ def compute_hash(self) -> str:
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()
         return hash_str
 
+    def get_limit_per_prompt(self, modality: str) -> int:
+        """
+        Get the maximum number of input items allowed per prompt
+        for the given modality.
+
+        If not set by the user, this defaults to `1`.
+        """
+        return self.limit_per_prompt.get(modality, 1)
+
     # TODO: Add configs to init vision tower or not.
 
 
diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index 532400b3b425..fd5d5a564b5e 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -14,7 +14,6 @@
 from transformers import BatchFeature
 
 from vllm.config import VllmConfig
-from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -25,8 +24,8 @@
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, ProcessingCache,
-                                        PromptReplacement, PromptUpdate)
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
@@ -42,8 +41,6 @@
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
 
-logger = init_logger(__name__)
-
 # The image token id may be various
 _IMAGE_TOKEN = "<image>"
 
@@ -216,30 +213,6 @@ def get_dummy_processor_inputs(
 class DeepseekVL2MultiModalProcessor(
         BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]):
 
-    def __init__(
-            self,
-            info: DeepseekVL2ProcessingInfo,
-            dummy_inputs: "BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]",
-            *,
-            cache: Optional[ProcessingCache] = None,
-            enable_sanity_checks: bool = True) -> None:
-        super().__init__(
-            info,
-            dummy_inputs,
-            cache=cache,
-            enable_sanity_checks=enable_sanity_checks,
-        )
-
-        mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
-        if self.cache is not None and mm_limit["image"] > 2:
-            # The processor output depends on the number of images passed,
-            # making it incompatible with processing cache which is supposed
-            # to be invariant of how many images are passed per prompt
-            self.cache = None
-            logger.warning_once(
-                f"{type(self).__name__} does not support processing cache with "
-                "image limit larger than 2.")
-
     def _call_hf_processor(
         self,
         prompt: str,
@@ -316,6 +289,31 @@ def get_replacement_deepseek_vl2(item_idx: int):
             )
         ]
 
+    def _cached_apply_hf_processor(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> tuple[list[int], MultiModalKwargs, bool]:
+        # The processor logic is different for len(images) <= 2 vs > 2
+        # Since the processing cache assumes that the processor output is
+        # invariant of how many images are passed per prompt, we only
+        # perform caching for the most common case
+        if mm_data_items.get_count("image", strict=False) > 2:
+            # This code path corresponds to the cache being disabled
+            return self._apply_hf_processor_main(
+                prompt=prompt,
+                mm_items=mm_data_items,
+                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+                enable_hf_prompt_update=True,
+            )
+
+        return super()._cached_apply_hf_processor(
+            prompt=prompt,
+            mm_data_items=mm_data_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+        )
+
 
 @MULTIMODAL_REGISTRY.register_processor(
     DeepseekVL2MultiModalProcessor,
diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py
index d336d7521a27..e23765cc4fb5 100644
--- a/vllm/model_executor/models/h2ovl.py
+++ b/vllm/model_executor/models/h2ovl.py
@@ -8,21 +8,19 @@
 # Licensed under Apache 2.0 License [see LICENSE for details]
 # --------------------------------------------------------
 from collections.abc import Mapping, Sequence
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from PIL import Image
 from transformers import PretrainedConfig
 
-from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    MultiModalDataItems)
-from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
-from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
+                                        PromptUpdateDetails)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 from .intern_vit import InternVisionModel
@@ -32,8 +30,6 @@
                        InternVLMultiModalProcessor, build_transform,
                        find_closest_aspect_ratio, get_internvl_target_ratios)
 
-logger = init_logger(__name__)
-
 
 def resolve_h2ovl_min_max_num(
     *,
@@ -465,29 +461,6 @@ def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int:
 class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
                                ):
 
-    def __init__(self,
-                 info: H2OVLProcessingInfo,
-                 dummy_inputs: "BaseDummyInputsBuilder[H2OVLProcessingInfo]",
-                 *,
-                 cache: Optional[ProcessingCache] = None,
-                 enable_sanity_checks: bool = True) -> None:
-        super().__init__(
-            info,
-            dummy_inputs,
-            cache=cache,
-            enable_sanity_checks=enable_sanity_checks,
-        )
-
-        mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
-        if self.cache is not None and mm_limit["image"] >= 2:
-            # The processor output depends on the number of images passed,
-            # making it incompatible with processing cache which is supposed
-            # to be invariant of how many images are passed per prompt
-            self.cache = None
-            logger.warning_once(
-                f"{type(self).__name__} does not support processing cache with "
-                "multi-image support enabled.")
-
     def _get_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
@@ -543,6 +516,31 @@ def get_replacement_internvl(item_idx: int):
             )
         ]
 
+    def _cached_apply_hf_processor(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> tuple[list[int], MultiModalKwargs, bool]:
+        # The processor logic is different for len(images) <= 1 vs > 1
+        # Since the processing cache assumes that the processor output is
+        # invariant of how many images are passed per prompt, we only
+        # perform caching for the most common case
+        if mm_data_items.get_count("image", strict=False) > 1:
+            # This code path corresponds to the cache being disabled
+            return self._apply_hf_processor_main(
+                prompt=prompt,
+                mm_items=mm_data_items,
+                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+                enable_hf_prompt_update=True,
+            )
+
+        return super()._cached_apply_hf_processor(
+            prompt=prompt,
+            mm_data_items=mm_data_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+        )
+
 
 @MULTIMODAL_REGISTRY.register_processor(
     H2OVLMultiModalProcessor,
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index 508b47d13519..d974c3d22409 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -133,7 +133,7 @@ def _get_max_video_frames(self, max_tokens: int) -> int:
 
     def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
-        max_videos = mm_config.limit_per_prompt.get("video", 1)
+        max_videos = mm_config.get_limit_per_prompt("video")
 
         max_total_frames = self._get_max_video_frames(seq_len)
 
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index e87ef24ce2ca..f41f45e3e409 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -206,8 +206,8 @@ def _get_max_video_frames(self, max_tokens: int) -> int:
 
     def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.limit_per_prompt.get("image", 1)
-        max_videos = mm_config.limit_per_prompt.get("video", 1)
+        max_images = mm_config.get_limit_per_prompt("image")
+        max_videos = mm_config.get_limit_per_prompt("video")
 
         max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = self._get_max_video_frames(seq_len -
diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py
index f35c230c0cea..bf6c38d27963 100644
--- a/vllm/model_executor/models/minicpmo.py
+++ b/vllm/model_executor/models/minicpmo.py
@@ -201,9 +201,9 @@ def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
 
     def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.limit_per_prompt.get("image", 1)
-        max_videos = mm_config.limit_per_prompt.get("video", 1)
-        max_audios = mm_config.limit_per_prompt.get("audio", 1)
+        max_images = mm_config.get_limit_per_prompt("image")
+        max_videos = mm_config.get_limit_per_prompt("video")
+        max_audios = mm_config.get_limit_per_prompt("audio")
 
         # count <image_idx></image_idx> tokens
         # which are not in get_max_image_tokens
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index cf103edd0bcc..48f0c09cdfb3 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -446,8 +446,8 @@ def get_max_video_frames(self, max_tokens: int) -> int:
 
     def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.limit_per_prompt.get("image", 1)
-        max_videos = mm_config.limit_per_prompt.get("video", 1)
+        max_images = mm_config.get_limit_per_prompt("image")
+        max_videos = mm_config.get_limit_per_prompt("video")
 
         # count <image_idx></image_idx> tokens
         # which are not in get_max_image_tokens
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 8acc07ac353a..f17f9fb8e0c7 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -68,7 +68,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
     image_token_id = mm_encoder.special_ids.img
 
     mm_config = ctx.get_mm_config()
-    num_images = mm_config.limit_per_prompt.get("image", 1)
+    num_images = mm_config.get_limit_per_prompt("image")
 
     # dummy size
     size = 256
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 523b53d5ee41..ac3d154dd881 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -911,8 +911,8 @@ def _get_max_video_frames(self, max_tokens: int) -> int:
 
     def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.limit_per_prompt.get("image", 1)
-        max_videos = mm_config.limit_per_prompt.get("video", 1)
+        max_images = mm_config.get_limit_per_prompt("image")
+        max_videos = mm_config.get_limit_per_prompt("video")
 
         max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = self._get_max_video_frames(seq_len -
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 3f13cd8582fe..ba8a458e84c8 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -984,10 +984,10 @@ def _to_mm_items(
         before passing them to :meth:`_get_hf_mm_data`.
         """
         mm_items = self.data_parser.parse_mm_data(mm_data)
+        mm_config = self.info.ctx.get_mm_config()
 
-        mm_limits = self.info.ctx.get_mm_config().limit_per_prompt
         for modality, items in mm_items.items():
-            limit = mm_limits.get(modality, 1)
+            limit = mm_config.get_limit_per_prompt(modality)
             if len(items) > limit:
                 raise ValueError(
                     f"You set {modality}={limit} (or defaulted to 1) in "
diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 3178b0f8c3e6..57f045dae6bd 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -110,12 +110,10 @@ def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
 
     def get_mm_limits(self) -> Mapping[str, int]:
         mm_config = self.processing_info.ctx.get_mm_config()
-        mm_limit_per_prompt = mm_config.limit_per_prompt
-
         supported_mm_limits = self.processing_info.get_supported_mm_limits()
 
         mm_limits = {
-            modality: mm_limit_per_prompt.get(modality, 1)
+            modality: mm_config.get_limit_per_prompt(modality)
             for modality in supported_mm_limits
         }
 
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index a9eb250cb877..4987cdc4a2e8 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -355,7 +355,7 @@ def init_mm_limits_per_prompt(
         # TODO: Automatically determine the limits based on budget
         # once more models support multi-image inputs
         limits_per_plugin = {
-            key: config_limits_per_plugin.get(key, 1)
+            key: multimodal_config.get_limit_per_prompt(key)
             for key in self._plugins
         }
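
For reference, a minimal usage sketch of the new `MultiModalConfig.get_limit_per_prompt` helper (illustrative only, not part of the patch). It assumes `MultiModalConfig` can be constructed directly with the `limit_per_prompt` mapping that `--limit-mm-per-prompt` normally populates:

from vllm.config import MultiModalConfig

# Hypothetical limits, e.g. what `--limit-mm-per-prompt image=4` would set.
mm_config = MultiModalConfig(limit_per_prompt={"image": 4})

assert mm_config.get_limit_per_prompt("image") == 4
# Modalities the user did not configure fall back to 1 item per prompt,
# matching the old `limit_per_prompt.get(modality, 1)` call sites.
assert mm_config.get_limit_per_prompt("video") == 1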