From 9e4e312f780a8ff3478dbf0a4f1d9705b0eb92b8 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 31 Mar 2025 16:41:49 +0000 Subject: [PATCH 01/23] [V1] Scatter and gather placeholders in the model runner Signed-off-by: DarkLight1337 --- docs/source/contributing/model/multimodal.md | 16 +-- vllm/model_executor/models/chameleon.py | 6 +- vllm/model_executor/models/fuyu.py | 6 +- vllm/model_executor/models/gemma3_mm.py | 69 +++--------- vllm/model_executor/models/h2ovl.py | 2 +- vllm/model_executor/models/idefics3.py | 69 ++---------- vllm/model_executor/models/internvl.py | 43 +------- vllm/model_executor/models/llava.py | 55 ++-------- vllm/model_executor/models/minicpmo.py | 68 ++---------- vllm/model_executor/models/minicpmv.py | 106 +++---------------- vllm/model_executor/models/molmo.py | 60 ++--------- vllm/model_executor/models/nvlm_d.py | 9 +- vllm/model_executor/models/paligemma.py | 6 +- vllm/model_executor/models/phi3v.py | 11 +- vllm/model_executor/models/pixtral.py | 42 ++------ vllm/model_executor/models/qwen2_audio.py | 6 +- vllm/model_executor/models/qwen_vl.py | 6 +- vllm/model_executor/models/vision.py | 77 +------------- vllm/multimodal/inputs.py | 6 ++ vllm/multimodal/processing.py | 73 ++++++++----- vllm/v1/worker/gpu_model_runner.py | 56 ++++++---- vllm/v1/worker/tpu_model_runner.py | 81 ++++++++++---- vllm/v1/worker/utils.py | 45 ++++++++ 23 files changed, 297 insertions(+), 621 deletions(-) diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md index 9cbfc32991f0..c4894d39edc9 100644 --- a/docs/source/contributing/model/multimodal.md +++ b/docs/source/contributing/model/multimodal.md @@ -860,8 +860,8 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch( ) ``` -To accommodate this, instead of a string you can return an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails` -with different `full` and `feature` attributes: +To assign the vision embeddings to only the image tokens, instead of a string +you can return an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails`: ```python hf_config = self.info.get_hf_config() @@ -879,9 +879,9 @@ def get_replacement_fuyu(item_idx: int): image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows - return PromptUpdateDetails( - full=image_tokens + [bos_token_id], - features=image_tokens, + return PromptUpdateDetails.select_token_id( + image_tokens + [bos_token_id], + embed_token_id=_IMAGE_TOKEN_ID, ) ``` @@ -914,9 +914,9 @@ def _get_prompt_updates( image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows - return PromptUpdateDetails( - full=image_tokens + [bos_token_id], - features=image_tokens, + return PromptUpdateDetails.select_token_id( + image_tokens + [bos_token_id], + embed_token_id=_IMAGE_TOKEN_ID, ) return [ diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index ebcd36148e07..3b47b0db20bc 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -161,9 +161,9 @@ def _get_prompt_updates( PromptReplacement( modality="image", target=[image_token_id], - replacement=PromptUpdateDetails( - full=([image_start_id] + image_tokens + [image_end_id]), - features=image_tokens, + replacement=PromptUpdateDetails.select_token_id( + [image_start_id] + image_tokens + [image_end_id], + embed_token_id=image_token_id, ), ) ] diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 
a807b047a1aa..7f33e1f17d56 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -252,9 +252,9 @@ def get_replacement_fuyu(item_idx: int): image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows - return PromptUpdateDetails( - full=image_tokens + [bos_token_id], - features=image_tokens, + return PromptUpdateDetails.select_token_id( + image_tokens + [bos_token_id], + embed_token_id=_IMAGE_TOKEN_ID, ) return [ diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index bbdea70a7bcf..eb41ba94cfaa 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -36,7 +36,6 @@ from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import scatter_patch_features, select_patch_features logger = init_logger(__name__) @@ -54,14 +53,6 @@ class Gemma3ImagePixelInputs(TypedDict): num_patches: torch.Tensor """Shape: `(batch_size * num_images)`""" - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which image embeddings correspond - to patch tokens. - - Shape: `(batch_size * num_images, num_embeds)` - """ - Gemma3ImageInputs = Gemma3ImagePixelInputs @@ -183,7 +174,7 @@ def get_image_repl( if processor is None: processor = self.get_hf_processor() - image_token = processor.boi_token + boi_token = processor.boi_token num_crops = self.get_num_crops( image_width=image_width, @@ -192,19 +183,21 @@ def get_image_repl( ) if num_crops == 0: - image_text = image_token + image_text = boi_token else: - crops_image_tokens = " ".join(image_token - for _ in range(num_crops)) + crops_image_tokens = " ".join(boi_token for _ in range(num_crops)) image_text = ( - f"Here is the original image {image_token} and here are some " + f"Here is the original image {boi_token} and here are some " f"crops to help you see better {crops_image_tokens}") - repl_full = image_text.replace(image_token, + repl_full = image_text.replace(boi_token, processor.full_image_sequence) - repl_features = repl_full.strip("\n") - return PromptUpdateDetails(full=repl_full, features=repl_features) + tokenizer = processor.tokenizer + vocab = tokenizer.get_vocab() + image_token_id = vocab[tokenizer.image_token] + + return PromptUpdateDetails.select_token_id(repl_full, image_token_id) def get_num_image_tokens( self, @@ -222,7 +215,7 @@ def get_num_image_tokens( image_repl_tokens = encode_tokens( tokenizer, - image_repl.features, + image_repl.full, add_special_tokens=False, ) return len(image_repl_tokens) @@ -301,28 +294,6 @@ def _call_hf_processor( ] hf_processor = self.info.get_hf_processor(**mm_kwargs) - image_repl_features = [ - self.info.get_image_repl(image_width=size.width, - image_height=size.height, - processor=hf_processor).features - for size in image_sizes - ] - - tokenizer = self.info.get_tokenizer() - image_repls_feature_tokens = [ - tokenizer.encode(image_repl, add_special_tokens=False) - for image_repl in image_repl_features - ] - - vocab = tokenizer.get_vocab() - image_token_id = vocab[tokenizer.image_token] - - embed_is_patch = [ - torch.tensor(image_repl_tokens) == image_token_id - for image_repl_tokens in image_repls_feature_tokens - ] - processed_outputs["embed_is_patch"] = embed_is_patch - num_crops = [ self.info.get_num_crops(image_width=size.width, image_height=size.height, @@ -344,7 +315,6 @@ def _get_mm_fields_config( 
pixel_values=MultiModalFieldConfig.flat_from_sizes( "image", num_crops + 1), num_crops=MultiModalFieldConfig.batched("image"), - embed_is_patch=MultiModalFieldConfig.batched("image"), ) def _get_prompt_updates( @@ -454,6 +424,7 @@ def get_repl_toks(tok: int) -> list[int]: item_idx=p.item_idx, start_idx=repl_orig_idxs[p.start_idx], tokens=p.tokens, + is_embed=p.is_embed, ) for p in placeholders ] for modality, placeholders in repls.items() @@ -572,7 +543,6 @@ def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[Gemma3ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) num_crops = kwargs.pop("num_crops", None) - embed_is_patch = kwargs.pop("embed_is_patch", None) image_embeds = kwargs.pop("image_embeds", None) assert image_embeds is None, "Gemma3 does not support image_embeds." if pixel_values is None: @@ -586,19 +556,13 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of num_crops. " f"Got type: {type(num_crops)}") - if not isinstance(embed_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of embed_is_patch. " - f"Got type: {type(embed_is_patch)}") - pixel_values = flatten_bn(pixel_values, concat=True) num_crops = flatten_bn(num_crops, concat=True) - embed_is_patch = flatten_bn(embed_is_patch) return Gemma3ImagePixelInputs( type="pixel_values", pixel_values=self._validate_pixel_values(pixel_values), num_patches=num_crops + 1, - embed_is_patch=embed_is_patch, ) def _image_pixels_to_features( @@ -635,12 +599,7 @@ def get_multimodal_embeddings( if image_input is None: return None - image_features = self._process_image_input(image_input) - - return scatter_patch_features( - image_features, - image_input["embed_is_patch"], - ) + return self._process_image_input(image_input) def get_input_embeddings( self, @@ -652,7 +611,7 @@ def get_input_embeddings( inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, - select_patch_features(multimodal_embeddings), + multimodal_embeddings, self.config.image_token_index, ) return inputs_embeds diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 3b2ad695f83e..bd9ad23fdee1 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -257,7 +257,7 @@ def get_image_repl( repl_features = IMG_CONTEXT * feature_size repl_full = IMG_START + repl_features + IMG_END - return PromptUpdateDetails(full=repl_full, features=repl_features) + return PromptUpdateDetails.select_token_text(repl_full, IMG_CONTEXT) def resolve_min_max_num( self, diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index da4a44346c32..8fe3306e8ec7 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -41,7 +41,7 @@ MultiModalDataItems, MultiModalFieldConfig, PromptReplacement, PromptUpdate, - encode_tokens) + PromptUpdateDetails, encode_tokens) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -54,7 +54,6 @@ from .llama import LlamaModel from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, merge_multimodal_embeddings) -from .vision import scatter_patch_features, select_patch_features class Idefics3ImagePixelInputs(TypedDict): @@ -69,14 +68,6 @@ class Idefics3ImagePixelInputs(TypedDict): num_patches: torch.Tensor """Shape: `(batch_size * num_images)`""" - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which image 
embeddings correspond - to patch tokens. - - Shape: `(batch_size * num_images, num_embeds)` - """ - class Idefics3ImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] @@ -86,14 +77,6 @@ class Idefics3ImageEmbeddingInputs(TypedDict): `hidden_size` must match the hidden size of language model backbone. """ - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which image embeddings correspond - to patch tokens. - - Shape: `(batch_size * num_images, num_embeds)` - """ - ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] @@ -364,28 +347,6 @@ def _call_hf_processor( ] hf_processor = self.info.get_hf_processor(**mm_kwargs) - image_repl_features = [ - self.info.get_image_repl(image_width=size.width, - image_height=size.height, - processor=hf_processor) - for size in image_sizes - ] - - tokenizer = self.info.get_tokenizer() - image_repls_feature_tokens = [ - tokenizer.encode(image_repl, add_special_tokens=False) - for image_repl in image_repl_features - ] - - vocab = tokenizer.get_vocab() - image_token_id = vocab[hf_processor.image_token.content] - - embed_is_patch = [ - torch.tensor(image_repl_tokens) == image_token_id - for image_repl_tokens in image_repls_feature_tokens - ] - processed_outputs["embed_is_patch"] = embed_is_patch - num_patches = [ self.info.get_num_patches( image_width=size.width, @@ -415,7 +376,6 @@ def _get_mm_fields_config( "image", num_patches), image_embeds=MultiModalFieldConfig.batched("image"), num_patches=MultiModalFieldConfig.batched("image"), - embed_is_patch=MultiModalFieldConfig.batched("image"), ) def _get_prompt_updates( @@ -427,17 +387,22 @@ def _get_prompt_updates( hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token = hf_processor.image_token.content - def get_replacement_idefics3(item_idx: int) -> str: + def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails: images = mm_items.get_items("image", ImageProcessorItems) image_size = images.get_image_size(item_idx) - return self.info.get_image_repl( + image_repl = self.info.get_image_repl( image_width=image_size.width, image_height=image_size.height, processor=hf_processor, ) + return PromptUpdateDetails.select_token_text( + image_repl, + embed_token_text=image_token, + ) + return [ PromptReplacement( modality="image", @@ -675,13 +640,6 @@ def _parse_and_validate_image_input( if pixel_values is None and image_embeds is None: return None - embed_is_patch = kwargs.pop("embed_is_patch") - if not isinstance(embed_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of embed_is_patch. " - f"Got type: {type(embed_is_patch)}") - - embed_is_patch = flatten_bn(embed_is_patch) - if image_embeds is not None: if not isinstance(image_embeds, (torch.Tensor, list)): raise ValueError("Incorrect type of image embeddings. 
" @@ -690,7 +648,6 @@ def _parse_and_validate_image_input( return Idefics3ImageEmbeddingInputs( type="image_embeds", data=flatten_bn(image_embeds, concat=True), - embed_is_patch=embed_is_patch, ) if pixel_values is not None: @@ -718,7 +675,6 @@ def _parse_and_validate_image_input( pixel_values=self._validate_pixel_values(pixel_values), pixel_attention_mask=pixel_attention_mask, num_patches=num_patches, - embed_is_patch=embed_is_patch, ) raise AssertionError("This line should be unreachable.") @@ -754,12 +710,7 @@ def get_multimodal_embeddings( if image_input is None: return None - image_features = self._process_image_input(image_input) - - return scatter_patch_features( - image_features, - image_input["embed_is_patch"], - ) + return self._process_image_input(image_input) def get_input_embeddings( self, @@ -771,7 +722,7 @@ def get_input_embeddings( inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, - select_patch_features(multimodal_embeddings), + multimodal_embeddings, self.config.image_token_id, ) return inputs_embeds diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 0729f4c7d203..467673fdf825 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -39,7 +39,6 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import scatter_patch_features, select_patch_features IMG_START = '' IMG_END = '' @@ -60,14 +59,6 @@ class InternVLImagePixelInputs(TypedDict): num_patches: torch.Tensor """Shape: `(batch_size * num_images)`""" - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which image embeddings correspond - to patch tokens. 
- - Shape: `(batch_size * num_images, num_embeds)` - """ - class InternVLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] @@ -419,24 +410,12 @@ def __call__( torch.tensor([len(item) for item in pixel_values_lst]), } - tokenizer = self.tokenizer - image_token_id = self.image_token_id - - embed_is_patch = list[torch.Tensor]() - for pixel_values in pixel_values_lst: num_patches = pixel_values.shape[0] feature_size = num_patches * self.num_image_token image_repl = self.get_image_repl(feature_size, num_patches) - feature_tokens = tokenizer.encode(image_repl.features, - add_special_tokens=False) - text = [t.replace('', image_repl.full, 1) for t in text] - embed_is_patch.append( - torch.tensor(feature_tokens) == image_token_id) - - image_inputs["embed_is_patch"] = embed_is_patch text_inputs = self.tokenizer(text) @@ -460,7 +439,7 @@ def get_image_repl( repl_features = IMG_CONTEXT * feature_size repl_full = IMG_START + repl_features + IMG_END - return PromptUpdateDetails(full=repl_full, features=repl_features) + return PromptUpdateDetails.select_token_text(repl_full, IMG_CONTEXT) class BaseInternVLProcessingInfo(BaseProcessingInfo): @@ -599,7 +578,6 @@ def _get_mm_fields_config( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( "image", image_num_patches), image_num_patches=MultiModalFieldConfig.batched("image"), - embed_is_patch=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), image_token_id=MultiModalFieldConfig.shared("image", num_images), ) @@ -831,7 +809,6 @@ def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[InternVLImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", None) - embed_is_patch = kwargs.pop("embed_is_patch", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values_flat is None and image_embeds is None: @@ -860,20 +837,14 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image_num_patches. " f"Got type: {type(image_num_patches)}") - if not isinstance(embed_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of embed_is_patch. 
" - f"Got type: {type(embed_is_patch)}") - pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True) - embed_is_patch = flatten_bn(embed_is_patch) return InternVLImagePixelInputs( type="pixel_values", pixel_values_flat=self._validate_pixel_values( pixel_values_flat), num_patches=image_num_patches, - embed_is_patch=embed_is_patch, ) raise AssertionError("This line should be unreachable.") @@ -919,15 +890,7 @@ def get_multimodal_embeddings( if image_input is None: return None - image_features = self._process_image_input(image_input) - - if image_input["type"] != "pixel_values": - return image_features - - return scatter_patch_features( - image_features, - image_input["embed_is_patch"], - ) + return self._process_image_input(image_input) def get_input_embeddings( self, @@ -941,7 +904,7 @@ def get_input_embeddings( inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, - select_patch_features(multimodal_embeddings), + multimodal_embeddings, self.img_context_token_id, ) return inputs_embeds diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 45a0bf73b837..b34ac38f6807 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -32,7 +32,8 @@ ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate) + PromptReplacement, PromptUpdate, + PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -42,8 +43,7 @@ from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import (get_vision_encoder_info, scatter_patch_features, - select_patch_features) +from .vision import get_vision_encoder_info class LlavaImagePixelInputs(TypedDict): @@ -67,14 +67,6 @@ class PixtralHFImagePixelInputs(TypedDict): in which case the data is passed as a list instead of a batched tensor. """ - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which image embeddings correspond - to patch tokens. 
- - Shape: `(batch_size * num_images, num_embeds)` - """ - class LlavaImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] @@ -343,23 +335,6 @@ def _call_hf_processor( for p, (h, w) in zip(pixel_values, image_sizes) ] - hf_config = self.info.get_hf_config() - vision_config = hf_config.vision_config - assert isinstance(vision_config, PixtralVisionConfig) - encoder_info = PixtralHFEncoderInfo(vision_config) - - tile_sizes = [ - encoder_info.get_patch_grid_size( - image_width=pixel_value.shape[-1], - image_height=pixel_value.shape[-2], - ) for pixel_value in processed_outputs["pixel_values"] - ] - embed_is_patch = [ - torch.tensor(([True] * ncols + [False]) * nrows) - for ncols, nrows in tile_sizes - ] - processed_outputs["embed_is_patch"] = embed_is_patch - return processed_outputs def _get_mm_fields_config( @@ -369,7 +344,6 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return dict( pixel_values=MultiModalFieldConfig.batched("image"), - embed_is_patch=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), ) @@ -404,7 +378,7 @@ def get_replacement(item_idx: int): tokens = ([image_token_id] * ncols + [image_break_id]) * nrows tokens[-1] = image_end_id - return tokens + return PromptUpdateDetails.select_token_id(tokens, image_token_id) return [ PromptReplacement( @@ -612,17 +586,9 @@ def _parse_and_validate_image_input( f"Got type: {type(pixel_values)}") if self.config.vision_config.model_type == "pixtral": - embed_is_patch = kwargs.pop("embed_is_patch") - if not isinstance(embed_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of embed_is_patch. " - f"Got type: {type(embed_is_patch)}") - - embed_is_patch = flatten_bn(embed_is_patch) - return PixtralHFImagePixelInputs( type="pixel_values_pixtral", pixel_values=flatten_bn(pixel_values), - embed_is_patch=embed_is_patch, ) return LlavaImagePixelInputs( @@ -714,16 +680,7 @@ def get_multimodal_embeddings( if image_input is None: return None - image_features = self._process_image_input(image_input) - - if image_input["type"] != "pixel_values_pixtral": - # The path is used for pixtral (V0 only) and llava (V0/V1) - return image_features - - return scatter_patch_features( - image_features, - image_input["embed_is_patch"], - ) + return self._process_image_input(image_input) def get_input_embeddings( self, @@ -735,7 +692,7 @@ def get_input_embeddings( inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, - select_patch_features(multimodal_embeddings), + multimodal_embeddings, self.config.image_token_index, ) return inputs_embeds diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index c74e086d3748..5fe683225094 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -40,7 +40,8 @@ DictEmbeddingItems, ModalityData, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) -from vllm.multimodal.processing import PromptReplacement, PromptUpdate +from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, + PromptUpdateDetails) from vllm.multimodal.profiling import ProcessorInputs from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, @@ -50,7 +51,6 @@ _minicpmv_field_config) from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, maybe_prefix) -from .vision import scatter_patch_features CPU_DEVICE = torch.device("cpu") @@ -73,14 +73,6 @@ class MiniCPMOAudioFeatureInputs(TypedDict): which equals to `audio_features.shape[-1]` """ 
- embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which audio embeddings correspond - to patch tokens. - - Shape: `(batch_size * num_audios, num_embeds)` - """ - class MiniCPMOAudioEmbeddingInputs(TypedDict): type: Literal["audio_embeds"] @@ -93,14 +85,6 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict): Length of each slice may vary, so pass it as a list. """ - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which audio embeddings correspond - to patch tokens. - - Shape: `(batch_size * num_audios, num_embeds)` - """ - MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, MiniCPMOAudioEmbeddingInputs] @@ -115,7 +99,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): audio_features=MultiModalFieldConfig.batched("audio"), audio_feature_lens=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"), - audio_embed_is_patch=MultiModalFieldConfig.batched("audio"), audio_token_id=MultiModalFieldConfig.shared("audio", num_audios), ) @@ -295,13 +278,6 @@ def process_audios( if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems): audio_inputs = {} - - audio_lens = [ - self.info.get_audio_len_by_num_chunks( - sum(map(len, - parsed_audios.get(i)["audio_embeds"]))) - for i in range(len(parsed_audios)) - ] else: audio_inputs = self._base_call_hf_processor( prompts=[self.info.audio_pattern] * len(parsed_audios), @@ -323,27 +299,7 @@ def process_audios( ] audio_inputs["audio_features"] = unpadded_audio_features - audio_lens = [ - parsed_audios.get_audio_length(i) - for i in range(len(parsed_audios)) - ] - - audio_repl_features = [ - self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens - ] - tokenizer = self.info.get_tokenizer() - audio_repls_feature_tokens = [ - tokenizer.encode(audio_repl, add_special_tokens=False) - for audio_repl in audio_repl_features - ] - - embed_is_patch = [ - self.get_embed_is_patch(audio_repl_tokens) - for audio_repl_tokens in audio_repls_feature_tokens - ] - audio_inputs["audio_embed_is_patch"] = embed_is_patch - unk_token_id = tokenizer.get_vocab()[""] audio_inputs["audio_token_id"] = torch.tensor(unk_token_id) @@ -384,7 +340,10 @@ def get_audio_replacement(item_idx: int): else: audio_len = audios.get_audio_length(item_idx) - return self.get_audio_prompt_texts(audio_len) + return PromptUpdateDetails.select_token_text( + self.get_audio_prompt_texts(audio_len), + "", + ) return [ *base_updates, @@ -713,13 +672,6 @@ def _parse_and_validate_audio_input( assert isinstance(audio_token_id, torch.Tensor) self.mm_token_ids.add(audio_token_id.flatten().unique().item()) - audio_embed_is_patch = kwargs.pop("audio_embed_is_patch") - if not isinstance(audio_embed_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_embed_is_patch. " - f"Got type: {type(audio_embed_is_patch)}") - - audio_embed_is_patch = flatten_bn(audio_embed_is_patch) - if audio_embeds is not None: if not isinstance(audio_embeds, (torch.Tensor, list)): raise ValueError("Incorrect type of audio_embeds. 
" @@ -730,7 +682,6 @@ def _parse_and_validate_audio_input( return MiniCPMOAudioEmbeddingInputs( type="audio_embeds", audio_embeds=audio_embeds_flat, - embed_is_patch=audio_embed_is_patch, ) if not isinstance(audio_features, (torch.Tensor, list)): @@ -749,7 +700,6 @@ def _parse_and_validate_audio_input( type="audio_features", audio_features=audio_features_flat, audio_feature_lens=audio_feature_lens_flat, - embed_is_patch=audio_embed_is_patch, ) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: @@ -781,10 +731,6 @@ def _process_multimodal_inputs(self, modalities: dict): if modality == "audios": audio_input = modalities["audios"] audio_features = self._process_audio_input(audio_input) - multimodal_embeddings += tuple( - scatter_patch_features( - audio_features, - audio_input["embed_is_patch"], - )) + multimodal_embeddings += tuple(audio_features) return multimodal_embeddings diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 5fab9df3f8f9..25f16eeb6932 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -56,7 +56,7 @@ VideoItem, VideoProcessorItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptUpdate) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -67,7 +67,6 @@ SupportsMultiModal, SupportsPP) from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, merge_multimodal_embeddings) -from .vision import scatter_patch_features, select_patch_features # For profile run _MAX_FRAMES_PER_VIDEO = 16 @@ -90,14 +89,6 @@ class MiniCPMVImagePixelInputs(TypedDict): This should be in `(height, width)` format. """ - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which image embeddings correspond - to patch tokens. - - Shape: `(batch_size * num_images, num_embeds)` - """ - num_slices: torch.Tensor """Shape: `(batch_size * num_images)`""" @@ -112,14 +103,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict): instead of a batched tensor. """ - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which image embeddings correspond - to patch tokens. 
- - Shape: `(batch_size * num_images, num_embeds)` - """ - MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageEmbeddingInputs] @@ -245,12 +228,10 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]): image_sizes=MultiModalFieldConfig.batched("image"), tgt_sizes=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), - embed_is_patch=MultiModalFieldConfig.batched("image"), video_pixel_values=MultiModalFieldConfig.batched("video"), video_image_sizes=MultiModalFieldConfig.batched("video"), video_tgt_sizes=MultiModalFieldConfig.batched("video"), video_embeds=MultiModalFieldConfig.batched("video"), - video_embed_is_patch=MultiModalFieldConfig.batched("video"), image_token_id=MultiModalFieldConfig.shared("image", num_images), video_token_id=MultiModalFieldConfig.shared("video", num_videos), ) @@ -539,14 +520,6 @@ def get_video_prompt_texts(self, image_size: ImageSize, use_image_id=False, ) * num_frames - def get_embed_is_patch( - self, - input_ids: list[int], - ) -> torch.Tensor: - tokenizer = self.info.get_tokenizer() - unk_token_id = tokenizer.get_vocab()[""] - return torch.tensor(input_ids) == unk_token_id - def process_images( self, mm_data: Mapping[str, object], @@ -570,26 +543,7 @@ def process_images( out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, ) - image_sizes = [ - parsed_images.get_image_size(i) for i in range(len(parsed_images)) - ] - image_repl_features = [ - self.get_image_prompt_texts(size, idx) - for idx, size in enumerate(image_sizes) - ] - tokenizer = self.info.get_tokenizer() - image_repls_feature_tokens = [ - tokenizer.encode(image_repl, add_special_tokens=False) - for image_repl in image_repl_features - ] - - embed_is_patch = [ - self.get_embed_is_patch(image_repl_tokens) - for image_repl_tokens in image_repls_feature_tokens - ] - image_inputs["embed_is_patch"] = embed_is_patch - unk_token_id = tokenizer.get_vocab()[""] image_inputs["image_token_id"] = torch.tensor(unk_token_id) @@ -625,31 +579,9 @@ def process_videos( out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, ) - frame_sizes = [ - parsed_videos.get_frame_size(i) for i in range(len(parsed_videos)) - ] - num_frames = [ - parsed_videos.get_num_frames(i) for i in range(len(parsed_videos)) - ] - video_repl_features = [ - self.get_video_prompt_texts(size, nframes) - for size, nframes in zip(frame_sizes, num_frames) - ] - - tokenizer = self.info.get_tokenizer() - video_repls_feature_tokens = [ - tokenizer.encode(video_repl, add_special_tokens=False) - for video_repl in video_repl_features - ] - - embed_is_patch = [ - self.get_embed_is_patch(video_repl_tokens) - for video_repl_tokens in video_repls_feature_tokens - ] - video_inputs["embed_is_patch"] = embed_is_patch - video_inputs = {f"video_{k}": v for k, v in video_inputs.items()} + tokenizer = self.info.get_tokenizer() unk_token_id = tokenizer.get_vocab()[""] video_inputs["video_token_id"] = torch.tensor(unk_token_id) @@ -740,7 +672,10 @@ def get_image_replacement(item_idx: int): image_size = images.get_image_size(item_idx) - return self.get_image_prompt_texts(image_size, item_idx) + return PromptUpdateDetails.select_token_text( + self.get_image_prompt_texts(image_size, item_idx), + "", + ) def get_video_replacement(item_idx: int): videos = mm_items.get_items( @@ -749,7 +684,10 @@ def get_video_replacement(item_idx: int): frame_size = videos.get_frame_size(item_idx) num_frames = videos.get_num_frames(item_idx) - return self.get_video_prompt_texts(frame_size, num_frames) + return 
PromptUpdateDetails.select_token_text( + self.get_video_prompt_texts(frame_size, num_frames), + "", + ) get_replacement = { "image": get_image_replacement, @@ -832,14 +770,6 @@ def _parse_and_validate_vision_input( assert isinstance(image_token_id, torch.Tensor) self.mm_token_ids.add(image_token_id.flatten().unique().item()) - embed_is_patch = kwargs.pop("embed_is_patch") - if not isinstance(embed_is_patch, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of embed_is_patch for {modality=}. " - f"Got type: {type(embed_is_patch)}") - - embed_is_patch = flatten_bn(embed_is_patch) - if image_embeds is not None: if not isinstance(image_embeds, (torch.Tensor, list)): raise ValueError( @@ -851,7 +781,6 @@ def _parse_and_validate_vision_input( return MiniCPMVImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds_flat, - embed_is_patch=embed_is_patch, ) if not isinstance(pixel_values, (torch.Tensor, list)): @@ -879,7 +808,6 @@ def _parse_and_validate_vision_input( type="pixel_values", pixel_values=pixel_values_flat, tgt_sizes=tgt_sizes_flat, - embed_is_patch=embed_is_patch, num_slices=num_slices_flat, ) @@ -936,19 +864,11 @@ def _process_multimodal_inputs(self, modalities: dict): if modality == "images": image_input = modalities["images"] image_features = self._process_vision_input(image_input) - multimodal_embeddings += tuple( - scatter_patch_features( - image_features, - image_input["embed_is_patch"], - )) + multimodal_embeddings += tuple(image_features) if modality == "videos": video_input = modalities["videos"] video_features = self._process_vision_input(video_input) - multimodal_embeddings += tuple( - scatter_patch_features( - video_features, - video_input["embed_is_patch"], - )) + multimodal_embeddings += tuple(video_features) return multimodal_embeddings @@ -971,7 +891,7 @@ def get_input_embeddings( inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, - select_patch_features(multimodal_embeddings), + multimodal_embeddings, list(self.mm_token_ids), ) return inputs_embeds diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index b2f795155f17..1c6d9b4bbf42 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -46,7 +46,8 @@ MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptIndexTargets, - PromptInsertion, PromptUpdate) + PromptInsertion, PromptUpdate, + PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -56,7 +57,6 @@ is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, merge_multimodal_embeddings) -from .vision import scatter_patch_features, select_patch_features # TODO: hard-coded for now. Consider making it configurable. VIT_LAYERS = [-2, -9] @@ -84,14 +84,6 @@ class MolmoImageInputs(TypedDict): Shape: `(batch_size * num_images, num_crops, num_patch)` """ - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which image embeddings correspond - to patch tokens. 
- - Shape: `(batch_size * num_images, num_embeds)` - """ - num_crops: torch.Tensor """Shape: `(batch_size * num_images)`""" @@ -1146,30 +1138,6 @@ def __call__( if image_input_idx is not None: feat_is_patch = image_input_idx >= 0 - input_is_embed = torch.isin( - input_ids, - torch.tensor([ - self.image_patch_id, - self.im_col_id, - self.im_start_id, - self.im_end_id, - ]), - ) - embed_ids = input_ids[input_is_embed] - embed_is_patch = embed_ids == self.image_patch_id - assert embed_is_patch.sum() == feat_is_patch.sum() - - # image_tokens = extra_joint + joint - # Both `extra_joint` and `joint` have `im_start_id` and `im_end_id` - embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0] - embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0] - assert len(embed_start) == len(embed_end) == len(images) - - embed_is_patch = [ - embed_is_patch[start:end + 1] - for start, end in zip(embed_start, embed_end) - ] - tilings = [ self.select_tiling( image_width=image.size[0], @@ -1181,7 +1149,6 @@ def __call__( assert num_crops.sum() == len(feat_is_patch) outputs["feat_is_patch"] = feat_is_patch - outputs["embed_is_patch"] = embed_is_patch outputs["num_crops"] = num_crops outputs["img_patch_id"] = self.image_patch_id @@ -1328,7 +1295,6 @@ def _get_mm_fields_config( "image", num_crops), feat_is_patch=MultiModalFieldConfig.flat_from_sizes( "image", num_crops), - embed_is_patch=MultiModalFieldConfig.batched("image"), num_crops=MultiModalFieldConfig.batched("image"), img_patch_id=MultiModalFieldConfig.shared("image", num_images), ) @@ -1368,8 +1334,10 @@ def get_insertion_molmo(item_idx: int): joint = ([img_start_id] + joint_row * ((nrows + 1) // pooling_size) + [img_end_id]) - image_tokens = extra_joint + joint - return image_tokens + return PromptUpdateDetails.select_token_id( + extra_joint + joint, + embed_token_id=img_patch_id, + ) return [ PromptInsertion( @@ -1475,11 +1443,6 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of feat_is_patch. " f"Got type: {type(feat_is_patch)}") - embed_is_patch = kwargs.pop("embed_is_patch", None) - if not isinstance(embed_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of embed_is_patch. " - f"Got type: {type(embed_is_patch)}") - num_crops = kwargs.pop("num_crops", None) if not isinstance(num_crops, (torch.Tensor, list)): raise ValueError("Incorrect type of num_crops. 
" @@ -1491,14 +1454,12 @@ def _parse_and_validate_image_input( f"Got type: {type(img_patch_id)}") self.img_patch_id = img_patch_id.flatten().unique().item() - embed_is_patch = flatten_bn(embed_is_patch) num_crops = flatten_bn(num_crops, concat=True) return MolmoImageInputs( images=images, image_masks=image_masks, feat_is_patch=feat_is_patch, - embed_is_patch=embed_is_patch, num_crops=num_crops, ) @@ -1537,12 +1498,7 @@ def get_multimodal_embeddings( if image_input is None: return None - image_features = self._process_image_input(image_input) - - return scatter_patch_features( - image_features, - image_input["embed_is_patch"], - ) + return self._process_image_input(image_input) def get_input_embeddings( self, @@ -1556,7 +1512,7 @@ def get_input_embeddings( inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, - select_patch_features(multimodal_embeddings), + multimodal_embeddings, self.img_patch_id, ) return inputs_embeds diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index 9d04f30c8f3f..ffe5399ba7d5 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -57,7 +57,7 @@ def get_image_repl( # when trying to find "" - return PromptUpdateDetails(full=repl, features=repl) + return PromptUpdateDetails.select_token_text(repl, IMG_PAD) class NVLMProcessingInfo(BaseInternVLProcessingInfo): @@ -175,12 +175,7 @@ def get_replacement_nvlm(item_idx: int): if num_patches is not None: assert isinstance(num_patches, int) - repl = hf_processor.get_image_repl(feature_size, num_patches) - - return PromptUpdateDetails( - full=repl.full + "\n", - features=repl.features + "\n", - ) + return hf_processor.get_image_repl(feature_size, num_patches) # See note in dummy data regarding why we have the extra newline return [ diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 6fedb8c81984..845f77ac39ce 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -162,9 +162,9 @@ def _get_prompt_updates( modality="image", target=PromptIndexTargets.prefix( [bos_token_id] if tokenizer.add_bos_token else []), - insertion=PromptUpdateDetails( - full=image_tokens + [bos_token_id], - features=image_tokens, + insertion=PromptUpdateDetails.select_token_id( + image_tokens + [bos_token_id], + embed_token_id=image_token_id, ), ) ] diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index d5c64989e64d..d3b0688f21c3 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -40,8 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, BoundPromptUpdate, PlaceholderFeaturesInfo, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) + PromptReplacement, PromptUpdate) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -443,12 +442,7 @@ def get_replacement_phi3v(item_idx: int): processor=hf_processor, ) - image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens - - return PromptUpdateDetails( - full=image_tokens, - features=image_tokens, - ) + return [_IMAGE_TOKEN_ID] * num_image_tokens num_images = mm_items.get_count("image", strict=False) @@ -517,6 +511,7 @@ def _apply_prompt_updates( item_idx=p.item_idx, start_idx=p.start_idx - 1, tokens=p.tokens, + is_embed=p.is_embed, ) for p in ps ] for modality, ps in placeholders.items() diff --git 
a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index da2017c987d4..597b28d422e5 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -37,7 +37,7 @@ MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptUpdate) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import (MistralTokenizer, @@ -46,8 +46,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import (VisionEncoderInfo, resolve_visual_encoder_outputs, - scatter_patch_features, select_patch_features) +from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs try: from xformers import ops as xops @@ -68,14 +67,6 @@ class PixtralImagePixelInputs(TypedDict): The result of stacking :attr:`ImageEncoding.tokens` from each prompt. """ - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which image embeddings correspond - to patch tokens. - - Shape: `(batch_size * num_images, num_embeds)` - """ - class PixtralProcessorAdapter: """ @@ -144,11 +135,8 @@ def __call__( "For more info, see: " "https://github.com/vllm-project/vllm/issues/8411.") - image_token_id = self.image_token_id - images_processed = list[torch.Tensor]() images_tokens = list[torch.Tensor]() - images_embed_is_patch = list[torch.Tensor]() for image in images: image_inputs = self.image_processor(ImageChunk(image=image)) @@ -157,12 +145,10 @@ def __call__( images_processed.append(image_processed) images_tokens.append(image_tokens) - images_embed_is_patch.append(image_tokens == image_token_id) return { "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1), "images": images_processed, - "embed_is_patch": images_embed_is_patch, } @@ -263,10 +249,7 @@ def _get_mm_fields_config( hf_inputs: Mapping[str, NestedTensors], hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - images=MultiModalFieldConfig.batched("image"), - embed_is_patch=MultiModalFieldConfig.batched("image"), - ) + return dict(images=MultiModalFieldConfig.batched("image")) def _get_prompt_updates( self, @@ -290,7 +273,7 @@ def get_replacement(item_idx: int): tokens = ([image_token_id] * ncols + [image_break_id]) * nrows tokens[-1] = image_end_id - return tokens + return PromptUpdateDetails.select_token_id(tokens, image_token_id) return [ PromptReplacement( @@ -381,17 +364,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of images. " f"Got type: {type(images)}") - embed_is_patch = kwargs.pop("embed_is_patch") - if not isinstance(embed_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of embed_is_patch. 
" - f"Got type: {type(embed_is_patch)}") - - embed_is_patch = flatten_bn(embed_is_patch) - return PixtralImagePixelInputs( type="pixel_values", images=flatten_bn(images), - embed_is_patch=embed_is_patch, ) def _process_image_input( @@ -427,12 +402,7 @@ def get_multimodal_embeddings( if image_input is None: return None - image_features = self._process_image_input(image_input) - - return scatter_patch_features( - image_features, - image_input["embed_is_patch"], - ) + return self._process_image_input(image_input) def get_input_embeddings( self, @@ -444,7 +414,7 @@ def get_input_embeddings( inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, - select_patch_features(multimodal_embeddings), + multimodal_embeddings, self.vision_args.image_token_id, ) return inputs_embeds diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index ccb5a3f600b2..54220037d253 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -229,9 +229,9 @@ def get_replacement_qwen2_audio(item_idx: int): audio_tokens = [audio_token_id] * num_features - return PromptUpdateDetails( - full=[audio_bos_id] + audio_tokens + [audio_eos_id], - features=audio_tokens, + return PromptUpdateDetails.select_token_id( + [audio_bos_id] + audio_tokens + [audio_eos_id], + embed_token_id=audio_token_id, ) return [ diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 4e9d02ae0abd..a2ec9a9a4d17 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -647,9 +647,9 @@ def _get_prompt_updates( PromptReplacement( modality="image", target=[img_start_id, img_end_id], - replacement=PromptUpdateDetails( - full=[img_start_id] + image_tokens + [img_end_id], - features=image_tokens, + replacement=PromptUpdateDetails.select_token_id( + [img_start_id] + image_tokens + [img_end_id], + embed_token_id=img_pad_id, ), ) ] diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 5c21fb2d4ad2..9a6fac2eec56 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from collections.abc import Sequence -from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast +from typing import Final, Generic, Optional, Protocol, TypeVar, Union import torch from transformers import PretrainedConfig @@ -10,12 +9,9 @@ import vllm.envs as envs from vllm.attention.selector import (backend_name_to_enum, get_global_forced_attn_backend) -from vllm.jsontree import JSONTree, json_map_leaves from vllm.logger import init_logger from vllm.platforms import _Backend, current_platform -from .interfaces import MultiModalEmbeddings - logger = init_logger(__name__) _C = TypeVar("_C", bound=PretrainedConfig) @@ -152,74 +148,3 @@ def resolve_visual_encoder_outputs( if post_layer_norm is not None and uses_last_layer: hs_pool[-1] = post_layer_norm(encoder_outputs) return torch.cat(hs_pool, dim=-1) - - -def scatter_patch_features( - patches: Union[torch.Tensor, Sequence[torch.Tensor]], - embed_is_patch: Union[torch.Tensor, Sequence[torch.Tensor]], -) -> tuple[torch.Tensor, ...]: - """ - Scatter the patch features into a contiguous tensor that corresponds - to the embedding tokens defined by the multimodal processor. 
- - The rest of the values in the tensor are set to NaN so that they - can be filtered out by :func`select_patch_features`. - - Args: - patches: The patch features for each image. - Shape: `(num_images, , feature_depth)` - embed_is_patch: A boolean mask indicating which image embeddings - correspond to patch tokens for each image. - Shape: `(num_images, num_embeds)` - - Note: - The original code only considers patch tokens as feature - tokens, but our processor considers all image-related tokens - as feature tokens because the feature tokens need to be - consecutive in `input_ids`. - - Example: - A simplified example for one image: - - .. code-block:: - - Embedding tokens (from HF processor): - [ ] - - embed_is_patch (from HF processor): - [ False True True False True True False False ] - - Encoder outputs (from model): - [ p1 p2 p3 p4 ] - - The resulting embedding tensor is: - [ nan p1 p2 nan p3 p4 nan nan ] - """ - if len(patches) != len(embed_is_patch): - raise ValueError(f"Inconsistent num_images: {len(patches)=} vs. " - f"{len(embed_is_patch)=}") - - def get_embed_one(patches_one: torch.Tensor, e_is_patch: torch.Tensor): - embed_one = patches_one.new_full( - (e_is_patch.shape[0], patches_one.shape[-1]), - fill_value=torch.nan, - ) - embed_one[e_is_patch] = patches_one - return embed_one - - return tuple( - get_embed_one(patches_one, e_is_patch) - for patches_one, e_is_patch in zip(patches, embed_is_patch)) - - -def select_patch_features( - multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings: - """ - Given the outputs of :func:`scatter_patch_features`, return only - the values that correspond to patch features. - """ - selected_features = json_map_leaves( - lambda x: x[~x.isnan()].view(-1, *x.shape[1:]), - cast(JSONTree[torch.Tensor], multimodal_embeddings), - ) - return cast(MultiModalEmbeddings, selected_features) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 81d72ff19022..adf6ff7cf08d 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -131,6 +131,12 @@ class PlaceholderRange(TypedDict): length: int """The length of the placeholder.""" + is_embed: NotRequired[Optional[torch.Tensor]] + """ + A boolean mask of shape `(length,)` indicating which positions + between `offset` and `offset + length` to assign embeddings to. + """ + NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]] diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index c8864c33fe37..20177651a514 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -108,16 +108,42 @@ class PromptUpdateDetails(Generic[_S]): full: _S """The full content.""" - features: _S + is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None """ - The part of the content that corresponds to feature placeholders; - this will be replaced by the output of the vision encoder during model - inference. + Given :attr:`full`, return a boolean mask of shape `(len(full),)` + indicating which positions of `full` to assign embeddings to. + + `None` (default) means to assign embeddings to all positions of `full`. + + The embeddings are obtained by calling + :class:`SupportsMultiModal.get_multimodal_embeddings`. 
""" @staticmethod def from_seq(seq: _S) -> "PromptUpdateDetails[_S]": - return PromptUpdateDetails(full=seq, features=seq) + return PromptUpdateDetails(full=seq) + + @staticmethod + def select_token_text( + seq: _S, + embed_token_text: str, + ) -> "PromptUpdateDetails[_S]": + + def is_embed(full: "_BoundPromptSequence") -> torch.Tensor: + embed_token_id, = encode_tokens(full.tokenizer, embed_token_text) + return torch.tensor(full.token_ids) == embed_token_id + + return PromptUpdateDetails(full=seq, is_embed=is_embed) + + @staticmethod + def select_token_id( + seq: _S, + embed_token_id: int, + ) -> "PromptUpdateDetails[_S]": + return PromptUpdateDetails( + full=seq, + is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id, + ) PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails] @@ -406,7 +432,7 @@ def token_ids(self) -> list[int]: @dataclass class _BoundPromptContent: full: _BoundPromptSequence - features: _BoundPromptSequence + is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] @dataclass @@ -466,10 +492,8 @@ def get_content(self, item_idx: int) -> _BoundPromptContent: bound_full = _BoundPromptSequence.from_seq(self.tokenizer, content.full) - bound_features = _BoundPromptSequence.from_seq(self.tokenizer, - content.features) bound_content = _BoundPromptContent(full=bound_full, - features=bound_features) + is_embed=content.is_embed) if cache_key is not None: self._content_cache[cache_key] = bound_content @@ -605,15 +629,19 @@ class PlaceholderFeaturesInfo: item_idx: int start_idx: int tokens: list[int] + is_embed: Optional[torch.Tensor] @property def length(self) -> int: return len(self.tokens) def to_range(self) -> PlaceholderRange: + # TODO: Is it worth it to optimize this by stripping the + # leading and ending positions where `is_embed=False`? 
return PlaceholderRange( offset=self.start_idx, length=self.length, + is_embed=self.is_embed, ) @@ -806,22 +834,17 @@ def _iter_placeholders( continue if prompt[start_idx:end_idx_full] == content_tokens_full: - content_tokens_feat = content.features.token_ids - - try: - match = next( - iter_token_matches(content_tokens_full, - content_tokens_feat)) - yield PlaceholderFeaturesInfo( - modality=modality, - item_idx=item_idx, - start_idx=start_idx + match.start_idx, - tokens=content_tokens_feat, - ) - except StopIteration: - raise AssertionError( - f"{content_tokens_feat=} should be a " - f"subsequence of {content_tokens_full=}") from None + content_is_embed = content.is_embed + if content_is_embed is not None: + content_is_embed = content_is_embed(content.full) + + yield PlaceholderFeaturesInfo( + modality=modality, + item_idx=item_idx, + start_idx=start_idx, + tokens=content_tokens_full, + is_embed=content_is_embed, + ) # Exclude overlapping matches start_idx = end_idx_full diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c7374cc3d330..e71b6b78e0e3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -19,7 +19,8 @@ from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors @@ -41,7 +42,8 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from .utils import sanity_check_mm_encoder_outputs +from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, + scatter_mm_placeholders) if TYPE_CHECKING: import xgrammar as xgr @@ -830,19 +832,22 @@ def _calc_spec_decode_metadata( ) return metadata - def _execute_encoder(self, scheduler_output: "SchedulerOutput"): + def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: return # Batch the multi-modal inputs. - mm_inputs: list[MultiModalKwargs] = [] - req_input_ids: list[tuple[str, int]] = [] + mm_inputs = list[MultiModalKwargs]() + req_ids_pos = list[tuple[str, int, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] - for input_id in encoder_input_ids: + for input_id, pos_info in zip( + encoder_input_ids, + req_state.mm_positions, + ): mm_inputs.append(req_state.mm_inputs[input_id]) - req_input_ids.append((req_id, input_id)) + req_ids_pos.append((req_id, input_id, pos_info)) # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -878,16 +883,23 @@ def _execute_encoder(self, scheduler_output: "SchedulerOutput"): encoder_outputs.append(output) # Cache the encoder outputs. 
- for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): + for (req_id, input_id, pos_info), output in zip( + req_ids_pos, + encoder_outputs, + ): if req_id not in self.encoder_cache: self.encoder_cache[req_id] = {} - self.encoder_cache[req_id][input_id] = output - def _gather_encoder_outputs( + self.encoder_cache[req_id][input_id] = scatter_mm_placeholders( + output, + is_embed=pos_info.get("is_embed"), + ) + + def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", ) -> list[torch.Tensor]: - encoder_outputs: list[torch.Tensor] = [] + mm_embeds: list[torch.Tensor] = [] for req_id in self.input_batch.req_ids: num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] @@ -918,8 +930,16 @@ def _gather_encoder_outputs( assert req_id in self.encoder_cache assert i in self.encoder_cache[req_id] encoder_output = self.encoder_cache[req_id][i] - encoder_outputs.append(encoder_output[start_idx:end_idx]) - return encoder_outputs + + if (is_embed := pos_info.get("is_embed")) is not None: + is_embed = is_embed[start_idx:end_idx] + + mm_embeds_item = gather_mm_placeholders( + encoder_output[start_idx:end_idx], + is_embed=is_embed, + ) + mm_embeds.append(mm_embeds_item) + return mm_embeds def get_model(self) -> nn.Module: return self.model @@ -984,10 +1004,10 @@ def execute_model( if self.is_multimodal_model: # Run the multimodal encoder if any. - self._execute_encoder(scheduler_output) - encoder_outputs = self._gather_encoder_outputs(scheduler_output) + self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) else: - encoder_outputs = [] + mm_embeds = [] # Prepare the decoder inputs. attn_metadata, logits_indices, spec_decode_metadata = ( @@ -1009,9 +1029,9 @@ def execute_model( # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. input_ids = self.input_ids[:num_scheduled_tokens] - if encoder_outputs: + if mm_embeds: inputs_embeds = self.model.get_input_embeddings( - input_ids, encoder_outputs) + input_ids, mm_embeds) else: inputs_embeds = self.model.get_input_embeddings(input_ids) # TODO(woosuk): Avoid the copy. Optimize. 
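For reference, a minimal sketch of the scatter/gather round trip the runner performs on cached encoder outputs. The helper bodies mirror the `_scatter_placeholders` / `_gather_placeholders` methods added to `tpu_model_runner.py` later in this patch; the toy mask and shapes are illustrative assumptions only.

```python
# Sketch only: mirrors the placeholder scatter/gather helpers in this patch.
from typing import Optional

import torch


def scatter_mm_placeholders(
    embeds: torch.Tensor,  # (num_embeds, embed_dim)
    is_embed: Optional[torch.Tensor],  # (num_placeholder_tokens,), bool
) -> torch.Tensor:
    """Expand encoder outputs to placeholder length, padding with NaN."""
    if is_embed is None:
        return embeds

    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    placeholders[is_embed] = embeds
    return placeholders


def gather_mm_placeholders(
    placeholders: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """Recover only the positions that should receive embeddings."""
    if is_embed is None:
        return placeholders

    return placeholders[is_embed]


# Toy placeholder layout: <bos> [IMG] [IMG] <eos> -> only the middle two
# positions receive vision embeddings.
is_embed = torch.tensor([False, True, True, False])
vision_embeds = torch.randn(2, 8)

cached = scatter_mm_placeholders(vision_embeds, is_embed)  # shape (4, 8)
restored = gather_mm_placeholders(cached, is_embed)        # shape (2, 8)
assert torch.equal(restored, vision_embeds)
```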
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 8f6a54892a4e..9b8a2f24fe84 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -19,7 +19,8 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors @@ -37,7 +38,8 @@ from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -from .utils import sanity_check_mm_encoder_outputs +from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, + scatter_mm_placeholders) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -475,19 +477,47 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): logits_indices = logits_indices.to(self.device) return attn_metadata, logits_indices - def _execute_encoder(self, scheduler_output: "SchedulerOutput"): + def _scatter_placeholders( + self, + embeds: torch.Tensor, + is_embed: Optional[torch.Tensor], + ) -> torch.Tensor: + if is_embed is None: + return embeds + + placeholders = embeds.new_full( + (is_embed.shape[0], embeds.shape[-1]), + fill_value=torch.nan, + ) + placeholders[is_embed] = embeds + return placeholders + + def _gather_placeholders( + self, + placeholders: torch.Tensor, + is_embed: Optional[torch.Tensor], + ) -> torch.Tensor: + if is_embed is None: + return placeholders + + return placeholders[is_embed] + + def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: return # Batch the multi-modal inputs. - mm_inputs: list[MultiModalKwargs] = [] - req_input_ids: list[tuple[str, int]] = [] + mm_inputs = list[MultiModalKwargs]() + req_ids_pos = list[tuple[str, int, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] - for input_id in encoder_input_ids: + for input_id, pos_info in zip( + encoder_input_ids, + req_state.mm_positions, + ): mm_inputs.append(req_state.mm_inputs[input_id]) - req_input_ids.append((req_id, input_id)) + req_ids_pos.append((req_id, input_id, pos_info)) # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -523,16 +553,23 @@ def _execute_encoder(self, scheduler_output: "SchedulerOutput"): encoder_outputs.append(output) # Cache the encoder outputs. 
- for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): + for (req_id, input_id, pos_info), output in zip( + req_ids_pos, + encoder_outputs, + ): if req_id not in self.encoder_cache: self.encoder_cache[req_id] = {} - self.encoder_cache[req_id][input_id] = output - def _gather_encoder_outputs( + self.encoder_cache[req_id][input_id] = scatter_mm_placeholders( + output, + is_embed=pos_info.get("is_embed"), + ) + + def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", ) -> list[torch.Tensor]: - encoder_outputs: list[torch.Tensor] = [] + mm_embeds: list[torch.Tensor] = [] for req_id in self.input_batch.req_ids: num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] @@ -563,8 +600,16 @@ def _gather_encoder_outputs( assert req_id in self.encoder_cache assert i in self.encoder_cache[req_id] encoder_output = self.encoder_cache[req_id][i] - encoder_outputs.append(encoder_output[start_idx:end_idx]) - return encoder_outputs + + if (is_embed := pos_info.get("is_embed")) is not None: + is_embed = is_embed[start_idx:end_idx] + + mm_embeds_item = gather_mm_placeholders( + encoder_output[start_idx:end_idx], + is_embed=is_embed, + ) + mm_embeds.append(mm_embeds_item) + return mm_embeds @torch.no_grad() def execute_model( @@ -580,10 +625,10 @@ def execute_model( if self.is_multimodal_model: # Run the multimodal encoder if any. - self._execute_encoder(scheduler_output) - encoder_outputs = self._gather_encoder_outputs(scheduler_output) + self._execute_mm_encoder(scheduler_output) + mm_embeds = self._gather_mm_embeddings(scheduler_output) else: - encoder_outputs = [] + mm_embeds = [] # Prepare inputs attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) @@ -591,9 +636,9 @@ def execute_model( # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. - if encoder_outputs: + if mm_embeds: inputs_embeds = self.model.get_input_embeddings( - self.input_ids, encoder_outputs) + self.input_ids, mm_embeds) else: inputs_embeds = self.model.get_input_embeddings(self.input_ids) input_ids = None diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index b1d3aa7cd8af..e46ca0c90fe3 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional + import torch @@ -27,3 +29,46 @@ def sanity_check_mm_encoder_outputs( f"but got tensors with shapes {[e.shape for e in mm_embeddings]} " "instead. This is most likely due to incorrect implementation " "of the model's `get_multimodal_embeddings` method.") + + +def scatter_mm_placeholders( + embeds: torch.Tensor, + is_embed: Optional[torch.Tensor], +) -> torch.Tensor: + """ + Scatter the multimodal embeddings into a contiguous tensor that represents + the placeholder tokens. + + :class:`vllm.multimodal.processing.PromptUpdateDetails.is_embed`. + + Args: + embeds: The multimodal embeddings. + Shape: `(num_embeds, embed_dim)` + is_embed: A boolean mask indicating which positions in the placeholder + tokens need to be filled with multimodal embeddings. 
+ Shape: `(num_placeholders, num_embeds)` + """ + if is_embed is None: + return embeds + + placeholders = embeds.new_full( + (is_embed.shape[0], embeds.shape[-1]), + fill_value=torch.nan, + ) + placeholders[is_embed] = embeds + return placeholders + + +def gather_mm_placeholders( + placeholders: torch.Tensor, + is_embed: Optional[torch.Tensor], +) -> torch.Tensor: + """ + Reconstructs the embeddings from the placeholder tokens. + + This is the operation of :func:`scatter_mm_placeholders`. + """ + if is_embed is None: + return placeholders + + return placeholders[is_embed] From d2ad7e634a054feafde40c20c8e56353be68591b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 31 Mar 2025 16:57:58 +0000 Subject: [PATCH 02/23] Fix placeholder token calculation Signed-off-by: DarkLight1337 --- vllm/multimodal/inputs.py | 7 +++++++ vllm/multimodal/profiling.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index adf6ff7cf08d..b0ded8cbcae2 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -138,6 +138,13 @@ class PlaceholderRange(TypedDict): """ +def get_num_embeds(placeholder_info: PlaceholderRange) -> int: + if (is_embed := placeholder_info.get("is_embed")) is None: + return placeholder_info["length"] + + return int(is_embed.sum().item()) + + NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]] """ diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 1df9a1f5eba1..f57559db6356 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -14,7 +14,7 @@ from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalInputs, MultiModalKwargs, - MultiModalPlaceholderDict) + MultiModalPlaceholderDict, get_num_embeds) from .processing import BaseMultiModalProcessor, BaseProcessingInfo logger = init_logger(__name__) @@ -180,7 +180,7 @@ def get_and_validate_mm_inputs( placeholders_by_modality = mm_inputs["mm_placeholders"] total_placeholders_by_modality = { - modality: sum(item["length"] for item in placeholders) + modality: sum(get_num_embeds(item) for item in placeholders) for modality, placeholders in placeholders_by_modality.items() } expected_placeholders_by_modality = { From 776074b8c942071e705d47b4d32572b500f88864 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 31 Mar 2025 16:59:54 +0000 Subject: [PATCH 03/23] Improve error message Signed-off-by: DarkLight1337 --- vllm/multimodal/processing.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 20177651a514..5329f08336f8 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -130,8 +130,11 @@ def select_token_text( ) -> "PromptUpdateDetails[_S]": def is_embed(full: "_BoundPromptSequence") -> torch.Tensor: - embed_token_id, = encode_tokens(full.tokenizer, embed_token_text) - return torch.tensor(full.token_ids) == embed_token_id + embed_token_id = encode_tokens(full.tokenizer, embed_token_text) + if len(embed_token_id) > 1: + raise ValueError(f"{embed_token_text} is not a single token") + + return torch.tensor(full.token_ids) == embed_token_id[0] return PromptUpdateDetails(full=seq, is_embed=is_embed) From d8c10c803eb7042abf5b51bc22dfc793af358e2d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 31 Mar 2025 17:02:47 +0000 Subject: [PATCH 04/23] Loosen check Signed-off-by: DarkLight1337 --- 
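
For context, the loosened check below builds the `is_embed` mask with `torch.isin` over every token ID of the embed text, so an embed text that tokenizes into more than one ID is accepted rather than raising. A toy sketch of the resulting mask, with token IDs invented for illustration:

```python
import torch

# Hypothetical prompt token IDs; here the embed text happens to tokenize
# into two IDs (51 and 52), both of which should be marked for embeddings.
full_token_ids = torch.tensor([50, 51, 51, 52, 51, 1])
embed_token_ids = torch.tensor([51, 52])

is_embed = torch.isin(full_token_ids, embed_token_ids)
print(is_embed)  # tensor([False, True, True, True, True, False])
```
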
vllm/multimodal/processing.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 5329f08336f8..76ef88feed82 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -130,11 +130,12 @@ def select_token_text( ) -> "PromptUpdateDetails[_S]": def is_embed(full: "_BoundPromptSequence") -> torch.Tensor: - embed_token_id = encode_tokens(full.tokenizer, embed_token_text) - if len(embed_token_id) > 1: - raise ValueError(f"{embed_token_text} is not a single token") + embed_token_ids = encode_tokens(full.tokenizer, embed_token_text) - return torch.tensor(full.token_ids) == embed_token_id[0] + return torch.isin( + torch.tensor(full.token_ids), + torch.tensor(embed_token_ids), + ) return PromptUpdateDetails(full=seq, is_embed=is_embed) From 6257621135db32b873a1e6d1c2b066cb283c98fd Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 31 Mar 2025 17:17:33 +0000 Subject: [PATCH 05/23] Fix Molmo Signed-off-by: DarkLight1337 --- vllm/model_executor/models/molmo.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 1c6d9b4bbf42..6857bfa810e3 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1187,17 +1187,13 @@ def get_num_image_tokens( ) pooling_size = processor.pooling_size - base_image_input_size = processor.base_image_input_size - base_image_input_d = processor.image_patch_size - - crop_patches = base_image_input_size[0] // base_image_input_d + image_token_length_w = processor.image_token_length_w + image_token_length_h = processor.image_token_length_h - per_row = ncols // pooling_size + 1 - joint = per_row * (nrows // pooling_size) + 2 - image_token_length = (crop_patches + pooling_size - 1) // pooling_size - resize = (image_token_length + 1) * image_token_length + 2 + extra = image_token_length_w * image_token_length_h + joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size) - return resize + joint + return extra + joint def get_max_image_tokens(self) -> int: target_width, target_height = self.get_image_size_with_most_features() From 60964067d3a1c87be3a1426d196562a921720235 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 31 Mar 2025 17:26:15 +0000 Subject: [PATCH 06/23] Rename Signed-off-by: DarkLight1337 --- vllm/model_executor/models/h2ovl.py | 2 +- vllm/model_executor/models/idefics3.py | 4 ++-- vllm/model_executor/models/internvl.py | 2 +- vllm/model_executor/models/minicpmo.py | 2 +- vllm/model_executor/models/minicpmv.py | 4 ++-- vllm/model_executor/models/nvlm_d.py | 2 +- vllm/multimodal/processing.py | 6 +++--- 7 files changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index bd9ad23fdee1..f975a19a364e 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -257,7 +257,7 @@ def get_image_repl( repl_features = IMG_CONTEXT * feature_size repl_full = IMG_START + repl_features + IMG_END - return PromptUpdateDetails.select_token_text(repl_full, IMG_CONTEXT) + return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT) def resolve_min_max_num( self, diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 8fe3306e8ec7..1ffceb995e65 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -398,9 +398,9 @@ def 
get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails: processor=hf_processor, ) - return PromptUpdateDetails.select_token_text( + return PromptUpdateDetails.select_text( image_repl, - embed_token_text=image_token, + embed_text=image_token, ) return [ diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 467673fdf825..cf5608e3de7b 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -439,7 +439,7 @@ def get_image_repl( repl_features = IMG_CONTEXT * feature_size repl_full = IMG_START + repl_features + IMG_END - return PromptUpdateDetails.select_token_text(repl_full, IMG_CONTEXT) + return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT) class BaseInternVLProcessingInfo(BaseProcessingInfo): diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index 5fe683225094..968f1efcdc87 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -340,7 +340,7 @@ def get_audio_replacement(item_idx: int): else: audio_len = audios.get_audio_length(item_idx) - return PromptUpdateDetails.select_token_text( + return PromptUpdateDetails.select_text( self.get_audio_prompt_texts(audio_len), "", ) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 25f16eeb6932..579ec8a55183 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -672,7 +672,7 @@ def get_image_replacement(item_idx: int): image_size = images.get_image_size(item_idx) - return PromptUpdateDetails.select_token_text( + return PromptUpdateDetails.select_text( self.get_image_prompt_texts(image_size, item_idx), "", ) @@ -684,7 +684,7 @@ def get_video_replacement(item_idx: int): frame_size = videos.get_frame_size(item_idx) num_frames = videos.get_num_frames(item_idx) - return PromptUpdateDetails.select_token_text( + return PromptUpdateDetails.select_text( self.get_video_prompt_texts(frame_size, num_frames), "", ) diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index ffe5399ba7d5..c0bd18b12d93 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -57,7 +57,7 @@ def get_image_repl( # when trying to find "" - return PromptUpdateDetails.select_token_text(repl, IMG_PAD) + return PromptUpdateDetails.select_text(repl, IMG_PAD) class NVLMProcessingInfo(BaseInternVLProcessingInfo): diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 76ef88feed82..a37d2975e9d2 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -124,13 +124,13 @@ def from_seq(seq: _S) -> "PromptUpdateDetails[_S]": return PromptUpdateDetails(full=seq) @staticmethod - def select_token_text( + def select_text( seq: _S, - embed_token_text: str, + embed_text: str, ) -> "PromptUpdateDetails[_S]": def is_embed(full: "_BoundPromptSequence") -> torch.Tensor: - embed_token_ids = encode_tokens(full.tokenizer, embed_token_text) + embed_token_ids = encode_tokens(full.tokenizer, embed_text) return torch.isin( torch.tensor(full.token_ids), From d1568adfc24dcef917589e0e2c736811b57e7a64 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 31 Mar 2025 17:48:11 +0000 Subject: [PATCH 07/23] Fix gemma3 Signed-off-by: DarkLight1337 --- vllm/model_executor/models/gemma3_mm.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/gemma3_mm.py 
b/vllm/model_executor/models/gemma3_mm.py index eb41ba94cfaa..9552ee1f0b3a 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -25,7 +25,7 @@ PlaceholderFeaturesInfo, PromptReplacement, PromptTargetMatch, PromptUpdate, PromptUpdateDetails, - encode_tokens, find_mm_placeholders, + find_mm_placeholders, replace_token_matches) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs @@ -206,19 +206,17 @@ def get_num_image_tokens( image_height: int, processor: Optional[Gemma3Processor], ) -> int: - tokenizer = self.get_tokenizer() - image_repl = self.get_image_repl( + if processor is None: + processor = self.get_hf_processor() + + num_crops = self.get_num_crops( image_width=image_width, image_height=image_height, processor=processor, ) + image_seq_len = processor.image_seq_length - image_repl_tokens = encode_tokens( - tokenizer, - image_repl.full, - add_special_tokens=False, - ) - return len(image_repl_tokens) + return (num_crops + 1) * image_seq_len def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() From a68228ecf233be8e02d949883a9cf1151bcb238f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 31 Mar 2025 17:52:45 +0000 Subject: [PATCH 08/23] Skip h2ovl for current transformers Signed-off-by: DarkLight1337 --- tests/models/registry.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 69ebfe4c9241..26a4172a63ec 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -274,7 +274,9 @@ def check_available_online( trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", - extras={"2b": "h2oai/h2ovl-mississippi-2b"}), # noqa: E501 + extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501 + max_transformers_version="4.48", # noqa: E501 + transformers_version_reason="HF model is not compatible."), # noqa: E501 "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B", extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501 trust_remote_code=True), From cb3cccee3fa46d5a761d7ad66628812da38c93ab Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Apr 2025 12:51:25 +0000 Subject: [PATCH 09/23] Fix MiniCPM Signed-off-by: DarkLight1337 --- examples/offline_inference/audio_language.py | 2 +- vllm/model_executor/models/minicpmo.py | 6 ++-- vllm/model_executor/models/minicpmv.py | 36 +++++++++++++++----- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 840892ea0701..f33efbab955e 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -47,7 +47,7 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: model=model_name, trust_remote_code=True, max_model_len=4096, - max_num_seqs=5, + max_num_seqs=2, limit_mm_per_prompt={"audio": audio_count}, ) diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index 968f1efcdc87..a4fb0cb1741e 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -180,8 +180,7 @@ def get_max_audio_tokens_per_chunk(self) -> int: pool_step = self.get_default_audio_pool_step() fbank_feat_in_chunk = 100 cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1 - num_audio_tokens = (cnn_feat_in_chunk 
- pool_step) // pool_step + 1 - return num_audio_tokens + 2 # + return (cnn_feat_in_chunk - pool_step) // pool_step + 1 def get_max_audio_chunks_with_most_features(self) -> int: return 30 @@ -192,8 +191,7 @@ def get_max_audio_tokens(self) -> int: def get_audio_len_by_num_chunks(self, num_chunks: int) -> int: sampling_rate = self.get_default_audio_sampling_rate() - # exclude - num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2 + num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1 def get_num_frames_with_most_features( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 579ec8a55183..eb20a963ae2a 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -379,22 +379,43 @@ def get_slice_image_placeholder( use_image_id=use_image_id, ) + def get_sliced_grid( + self, + image_size: ImageSize, + # For MiniCPM V/O 2.6 + max_slice_nums: Optional[int] = None, + ) -> Optional[tuple[int, int]]: + image_processor = self.get_image_processor() + version = self.get_model_version() + + if version == (2, 0) or version == (2, 5): + return image_processor.get_sliced_grid(image_size) + + if max_slice_nums is None: + max_slice_nums = image_processor.max_slice_nums + + return image_processor.get_sliced_grid( + image_size, + max_slice_nums=max_slice_nums, + ) + def get_num_image_tokens( self, image_size: ImageSize, max_slice_nums: Optional[int] = None, - use_image_id: bool = True, ) -> int: - tokenizer = self.get_tokenizer() - image_placeholders = self.get_slice_image_placeholder( + image_processor = self.get_image_processor() + + grid = self.get_sliced_grid( image_size, max_slice_nums=max_slice_nums, - use_image_id=use_image_id, ) - image_token_ids = tokenizer.encode(image_placeholders, - add_special_tokens=False) + if grid is None: + ncols = nrows = 0 + else: + ncols, nrows = grid - return len(image_token_ids) + return (ncols * nrows + 1) * image_processor.image_feature_size def get_max_image_tokens(self) -> int: image_size = self.get_image_size_with_most_features() @@ -414,7 +435,6 @@ def get_max_video_frame_tokens(self) -> int: return self.get_num_image_tokens( frame_size, max_slice_nums=self.get_video_max_slice_num(), - use_image_id=False, ) def get_max_video_tokens( From cff414e36fa55fd396168b09c88be84ed8fafa6e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Apr 2025 12:56:27 +0000 Subject: [PATCH 10/23] Fix Fuyu Signed-off-by: DarkLight1337 --- vllm/model_executor/models/fuyu.py | 79 ++++++++++++------------------ 1 file changed, 30 insertions(+), 49 deletions(-) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 7f33e1f17d56..189b91db4a86 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -18,7 +18,7 @@ """ PyTorch Fuyu model.""" import math from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Literal, Optional, Set, Tuple, TypedDict import torch import torch.nn as nn @@ -43,7 +43,6 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, merge_multimodal_embeddings) -from .vision import scatter_patch_features, select_patch_features # Cannot find the following 2 numbers from hf config. 
_IMAGE_TOKEN_ID = 71011 @@ -66,14 +65,6 @@ class FuyuImagePatchInputs(TypedDict): flattened just like `flat_data`. """ - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which image embeddings correspond - to patch tokens. - - Shape: `(batch_size * num_images, num_embeds)` - """ - class FuyuProcessingInfo(BaseProcessingInfo): @@ -94,15 +85,7 @@ def get_mm_max_tokens_per_item( seq_len: int, mm_counts: Mapping[str, int], ) -> Mapping[str, int]: - target_width, target_height = self.get_image_size_with_most_features() - - max_ncols, max_nrows = self.get_image_feature_grid_size( - image_width=target_width, - image_height=target_height, - ) - max_image_tokens = (max_ncols + 1) * max_nrows - - return {"image": max_image_tokens} + return {"image": self.get_max_image_tokens()} def get_image_feature_grid_size( self, @@ -128,11 +111,32 @@ def get_image_feature_grid_size( nrows = math.ceil(image_height / patch_height) return ncols, nrows + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + ncols, nrows = self.get_image_feature_grid_size( + image_width=image_width, + image_height=image_height, + ) + + return ncols * nrows + def get_image_size_with_most_features(self) -> ImageSize: image_processor = self.get_image_processor() return ImageSize(width=image_processor.size["width"], height=image_processor.size["height"]) + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) + class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]): @@ -192,19 +196,6 @@ def _call_hf_processor( processed_outputs["image_patches"] = image_patches[0] - # get patch grid size for each image - embed_is_patch = [] - for image in images: - ncols, nrows = self.info.get_image_feature_grid_size( - image_width=image.width, - image_height=image.height, - ) - - mask = torch.tensor(([True] * ncols + [False]) * nrows) - embed_is_patch.append(mask) - - processed_outputs["embed_is_patch"] = embed_is_patch - return processed_outputs def _apply_hf_processor_tokens_only( @@ -224,8 +215,7 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict(image_patches=MultiModalFieldConfig.batched("image"), - embed_is_patch=MultiModalFieldConfig.batched("image")) + return dict(image_patches=MultiModalFieldConfig.batched("image")) def _get_prompt_updates( self, @@ -329,20 +319,13 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image patches. " f"Got type: {type(image_patches)}") - embed_is_patch = kwargs.pop("embed_is_patch") - if not isinstance(embed_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of embed_is_patch. 
" - f"Got type: {type(embed_is_patch)}") - image_patches_flat = flatten_bn(image_patches) - embed_is_patch = flatten_bn(embed_is_patch) return FuyuImagePatchInputs( type="image_patches", flat_data=self._validate_pixel_values( flatten_bn(image_patches_flat, concat=True)), patches_per_image=[x.size(0) for x in image_patches_flat], - embed_is_patch=embed_is_patch, ) return None @@ -364,12 +347,7 @@ def get_multimodal_embeddings( if image_input is None: return None - image_features = self._process_image_input(image_input) - - return scatter_patch_features( - image_features, - image_input["embed_is_patch"], - ) + return self._process_image_input(image_input) def get_input_embeddings( self, @@ -379,8 +357,11 @@ def get_input_embeddings( inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, - select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID) + input_ids, + inputs_embeds, + multimodal_embeddings, + _IMAGE_TOKEN_ID, + ) return inputs_embeds def forward( From 699c34616bcdae5b094dcaa63ac6de6d01c670f1 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Apr 2025 14:20:14 +0000 Subject: [PATCH 11/23] Fix NVLM Signed-off-by: DarkLight1337 --- vllm/model_executor/models/nvlm_d.py | 29 +++------------------------- 1 file changed, 3 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index c0bd18b12d93..314f75c20301 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -84,31 +84,6 @@ def get_hf_processor( **kwargs, ) - def get_max_image_tokens(self) -> int: - hf_processor = self.get_hf_processor() - tokenizer = hf_processor.tokenizer - - max_num_patches = hf_processor.max_dynamic_patch - # we need +1 here because max_dynamic_patch in config doesn't - # include the thumbnail patch - tile_pos_identifiers = [ - f"" for i in range(max_num_patches) - ] - if hf_processor.use_thumbnail and max_num_patches != 1: - tile_pos_identifiers += [""] - - # "<", "tile"] - # so we include in the start_str - start_str = "" + tile_pos_identifiers.pop(0) - end_str = "" - start_token_len = len(tokenizer.encode(start_str)) - end_token_len = len(tokenizer.encode(end_str)) - tile_token_len = sum( - len(tokenizer.encode(identifier)) - for identifier in tile_pos_identifiers) - non_image_tokens_num = start_token_len + end_token_len + tile_token_len - return super().get_max_image_tokens() + non_image_tokens_num - class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]): @@ -175,7 +150,9 @@ def get_replacement_nvlm(item_idx: int): if num_patches is not None: assert isinstance(num_patches, int) - return hf_processor.get_image_repl(feature_size, num_patches) + repl = hf_processor.get_image_repl(feature_size, num_patches) + + return PromptUpdateDetails.select_text(repl.full + "\n", IMG_PAD) # See note in dummy data regarding why we have the extra newline return [ From cd868618026d352a82216e292fbecddc36565fc9 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Apr 2025 14:21:58 +0000 Subject: [PATCH 12/23] Enable `PlaceholderRange` equality check Signed-off-by: DarkLight1337 --- .../vision_language/test_pixtral.py | 24 ++++------ .../multimodal/processing/test_llava_next.py | 4 +- .../processing/test_llava_onevision.py | 4 +- tests/v1/core/test_kv_cache_utils.py | 46 +++++++------------ vllm/multimodal/base.py | 4 +- vllm/multimodal/inputs.py | 29 ++++++++---- 
vllm/multimodal/profiling.py | 4 +- vllm/v1/core/kv_cache_utils.py | 7 ++- vllm/v1/core/sched/scheduler.py | 8 ++-- vllm/v1/request.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 4 +- vllm/v1/worker/tpu_model_runner.py | 4 +- 12 files changed, 66 insertions(+), 74 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index ee619d8d80c4..400b016a4dd1 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -198,22 +198,14 @@ def test_chat( @large_gpu_test(min_gb=48) -@pytest.mark.parametrize( - "prompt,expected_ranges", - [(_create_engine_inputs_hf(IMG_URLS[:1]), [{ - "offset": 11, - "length": 494 - }]), - (_create_engine_inputs_hf(IMG_URLS[1:4]), [{ - "offset": 11, - "length": 266 - }, { - "offset": 277, - "length": 1056 - }, { - "offset": 1333, - "length": 418 - }])]) +@pytest.mark.parametrize("prompt,expected_ranges", + [(_create_engine_inputs_hf(IMG_URLS[:1]), + [PlaceholderRange(offset=11, length=494)]), + (_create_engine_inputs_hf(IMG_URLS[1:4]), [ + PlaceholderRange(offset=11, length=266), + PlaceholderRange(offset=277, length=1056), + PlaceholderRange(offset=1333, length=418) + ])]) def test_multi_modal_placeholders(vllm_runner, prompt, expected_ranges: list[PlaceholderRange], monkeypatch) -> None: diff --git a/tests/models/multimodal/processing/test_llava_next.py b/tests/models/multimodal/processing/test_llava_next.py index fe56a200a330..b82bfe483dbb 100644 --- a/tests/models/multimodal/processing/test_llava_next.py +++ b/tests/models/multimodal/processing/test_llava_next.py @@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one( first_placeholder = image_placeholders[0] # NOTE: There is a BOS token - assert first_placeholder["offset"] == 1 - assert first_placeholder["length"] == ( + assert first_placeholder.offset == 1 + assert first_placeholder.length == ( len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs except Exception as exc: diff --git a/tests/models/multimodal/processing/test_llava_onevision.py b/tests/models/multimodal/processing/test_llava_onevision.py index 7cefdd37ee49..dcc8dc8dab5a 100644 --- a/tests/models/multimodal/processing/test_llava_onevision.py +++ b/tests/models/multimodal/processing/test_llava_onevision.py @@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one( first_placeholder = image_placeholders[0] - assert first_placeholder["offset"] == 0 - assert first_placeholder["length"] == len( + assert first_placeholder.offset == 0 + assert first_placeholder.length == len( processed_inputs["prompt_token_ids"]) // num_imgs except Exception as exc: failed_size_excs.append((image_size, exc)) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 8362af24a67e..51836644b325 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -3,7 +3,7 @@ import pytest import torch -from vllm.multimodal.inputs import MultiModalKwargs +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import sha256 # disable yapf here as it formats differently than isort such that both fail @@ -158,13 +158,10 @@ def test_generate_block_hash_extra_keys(): request = make_request( request_id=0, prompt_token_ids=[_ for _ in range(20)], - mm_positions=[{ - "offset": 0, - "length": 5 - }, { - "offset": 10, - "length": 5 - }], + mm_positions=[ + 
PlaceholderRange(offset=0, length=5), + PlaceholderRange(offset=10, length=5), + ], mm_hashes=["hash1", "hash2"], ) @@ -222,13 +219,10 @@ def test_hash_request_tokens(hash_fn): request = make_request( request_id=0, prompt_token_ids=[_ for _ in range(6)], - mm_positions=[{ - "offset": 0, - "length": 3 - }, { - "offset": 3, - "length": 3 - }], + mm_positions=[ + PlaceholderRange(offset=0, length=3), + PlaceholderRange(offset=3, length=3), + ], mm_hashes=["hash1", "hash2"], ) @@ -253,25 +247,19 @@ def test_hash_tokens_different_mm_input(hash_fn): request1 = make_request( request_id=0, prompt_token_ids=[_ for _ in range(6)], - mm_positions=[{ - "offset": 0, - "length": 3 - }, { - "offset": 3, - "length": 3 - }], + mm_positions=[ + PlaceholderRange(offset=0, length=3), + PlaceholderRange(offset=3, length=3), + ], mm_hashes=["hash1", "hash2"], ) request2 = make_request( request_id=1, prompt_token_ids=[_ for _ in range(6)], - mm_positions=[{ - "offset": 0, - "length": 3 - }, { - "offset": 3, - "length": 3 - }], + mm_positions=[ + PlaceholderRange(offset=0, length=3), + PlaceholderRange(offset=3, length=3), + ], mm_hashes=["hash3", "hash2"], ) block_size = 3 diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 5159b0bca8c1..ad95b982499c 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -385,8 +385,8 @@ def append_items_from_seq_group( for placeholder_dict, mm_item in zip(multi_modal_placeholders, multi_modal_items): placeholder = range( - placeholder_dict["offset"], - placeholder_dict["offset"] + placeholder_dict["length"], + placeholder_dict.offset, + placeholder_dict.offset + placeholder_dict.length, ) intersection = range( max(positions.start, placeholder.start), diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index b0ded8cbcae2..53729799b629 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -109,7 +109,8 @@ class MultiModalDataBuiltins(TypedDict, total=False): """ -class PlaceholderRange(TypedDict): +@dataclass(frozen=True) +class PlaceholderRange: """ Placeholder location information for multi-modal data. @@ -121,8 +122,8 @@ class PlaceholderRange(TypedDict): .. code-block:: - A: { "offset": 0, "length": 4 } - B: { "offset": 5, "length": 4 } + A: PlaceholderRange(offset=0, length=4) + B: PlaceholderRange(offset=5, length=4) """ offset: int @@ -131,18 +132,30 @@ class PlaceholderRange(TypedDict): length: int """The length of the placeholder.""" - is_embed: NotRequired[Optional[torch.Tensor]] + is_embed: Optional[torch.Tensor] = None """ A boolean mask of shape `(length,)` indicating which positions between `offset` and `offset + length` to assign embeddings to. 
""" + def get_num_embeds(self) -> int: + if self.is_embed is None: + return self.length + + return int(self.is_embed.sum().item()) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return False + if not (self.offset, self.length) == (other.offset, other.length): + return False -def get_num_embeds(placeholder_info: PlaceholderRange) -> int: - if (is_embed := placeholder_info.get("is_embed")) is None: - return placeholder_info["length"] + if self.is_embed is None: + return other.is_embed is None + if other.is_embed is None: + return self.is_embed is None - return int(is_embed.sum().item()) + return nested_tensors_equal(self.is_embed, other.is_embed) NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor, diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index f57559db6356..4616e4e95785 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -14,7 +14,7 @@ from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalInputs, MultiModalKwargs, - MultiModalPlaceholderDict, get_num_embeds) + MultiModalPlaceholderDict) from .processing import BaseMultiModalProcessor, BaseProcessingInfo logger = init_logger(__name__) @@ -180,7 +180,7 @@ def get_and_validate_mm_inputs( placeholders_by_modality = mm_inputs["mm_placeholders"] total_placeholders_by_modality = { - modality: sum(get_num_embeds(item) for item in placeholders) + modality: sum(item.get_num_embeds() for item in placeholders) for modality, placeholders in placeholders_by_modality.items() } expected_placeholders_by_modality = { diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 34bc9369b125..afcf7e344a0f 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -310,8 +310,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, # Note that we assume mm_positions is sorted by offset. # We do not need to check all mm inputs if the start token index is out of # range. This usually happens in the late prefill phase and decoding phase. - if mm_positions[-1]["offset"] + mm_positions[-1][ - "length"] < start_token_idx: + if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx: return extra_keys, start_mm_idx # Support start_mm_idx == -1 to indicate the last mm input. @@ -322,8 +321,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, curr_mm_idx = start_mm_idx while mm_positions and curr_mm_idx < len(mm_positions): assert mm_hashes[curr_mm_idx] is not None - offset = mm_positions[curr_mm_idx]["offset"] - length = mm_positions[curr_mm_idx]["length"] + offset = mm_positions[curr_mm_idx].offset + length = mm_positions[curr_mm_idx].length if end_token_idx > offset: if start_token_idx > offset + length: # This block has passed the current mm input. 
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4d477567b9b6..58840c400d1a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -504,8 +504,8 @@ def _try_schedule_encoder_inputs( assert mm_positions is not None assert len(mm_positions) > 0 for i, pos_info in enumerate(mm_positions): - start_pos = pos_info["offset"] - num_encoder_tokens = pos_info["length"] + start_pos = pos_info.offset + num_encoder_tokens = pos_info.length # The encoder output is needed if the two ranges overlap: # [num_computed_tokens, num_computed_tokens + num_new_tokens) and @@ -590,8 +590,8 @@ def update_from_output( if cached_encoder_input_ids: for input_id in list(cached_encoder_input_ids): mm_positions = request.mm_positions[input_id] - start_pos = mm_positions["offset"] - num_tokens = mm_positions["length"] + start_pos = mm_positions.offset + num_tokens = mm_positions.length if start_pos + num_tokens <= request.num_computed_tokens: # The encoder output is already processed and stored # in the decoder's KV cache. diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 490fe4e83d3a..daf59fd76e9a 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -121,7 +121,7 @@ def get_finished_reason(self) -> Union[FinishReason, None]: def get_num_encoder_tokens(self, input_id: int) -> int: assert input_id < len(self.mm_positions) - num_tokens = self.mm_positions[input_id]["length"] + num_tokens = self.mm_positions[input_id].length return num_tokens @property diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 078bc63b4f04..c7450b38804e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -908,8 +908,8 @@ def _gather_mm_embeddings( num_computed_tokens = req_state.num_computed_tokens mm_positions = req_state.mm_positions for i, pos_info in enumerate(mm_positions): - start_pos = pos_info["offset"] - num_encoder_tokens = pos_info["length"] + start_pos = pos_info.offset + num_encoder_tokens = pos_info.length # The encoder output is needed if the two ranges overlap: # [num_computed_tokens, diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 994396da789e..c1175cfa5d30 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -613,8 +613,8 @@ def _gather_mm_embeddings( num_computed_tokens = req_state.num_computed_tokens mm_positions = req_state.mm_positions for i, pos_info in enumerate(mm_positions): - start_pos = pos_info["offset"] - num_encoder_tokens = pos_info["length"] + start_pos = pos_info.offset + num_encoder_tokens = pos_info.length # The encoder output is needed if the two ranges overlap: # [num_computed_tokens, From f9b92dfa618d9a7416ac94b80cdff2a9438c7bed Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Apr 2025 15:37:06 +0000 Subject: [PATCH 13/23] Fix Signed-off-by: DarkLight1337 --- vllm/multimodal/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index fc0fb8929b1e..77c83f0c2b21 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -340,7 +340,7 @@ def merge_and_sort_multimodal_metadata( all_items.append((modality, placeholder, hash_value)) # Sort all items by offset - all_items.sort(key=lambda x: x[1]['offset']) + all_items.sort(key=lambda x: x[1].offset) # Split into separate lists sorted_modalities = [item[0] for item in all_items] From 53f550786860908079a1f481e4f95a1910daef02 Mon Sep 17 
00:00:00 2001 From: mgoin Date: Tue, 1 Apr 2025 17:44:39 +0000 Subject: [PATCH 14/23] Fix pos_info.is_embed Signed-off-by: mgoin --- vllm/v1/worker/gpu_model_runner.py | 4 ++-- vllm/v1/worker/tpu_model_runner.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c7450b38804e..e440eda3d863 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -893,7 +893,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): self.encoder_cache[req_id][input_id] = scatter_mm_placeholders( output, - is_embed=pos_info.get("is_embed"), + is_embed=pos_info.is_embed, ) def _gather_mm_embeddings( @@ -932,7 +932,7 @@ def _gather_mm_embeddings( assert i in self.encoder_cache[req_id] encoder_output = self.encoder_cache[req_id][i] - if (is_embed := pos_info.get("is_embed")) is not None: + if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] mm_embeds_item = gather_mm_placeholders( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index c1175cfa5d30..8416f2e5c555 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -598,7 +598,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): self.encoder_cache[req_id][input_id] = scatter_mm_placeholders( output, - is_embed=pos_info.get("is_embed"), + is_embed=pos_info.is_embed, ) def _gather_mm_embeddings( @@ -637,7 +637,7 @@ def _gather_mm_embeddings( assert i in self.encoder_cache[req_id] encoder_output = self.encoder_cache[req_id][i] - if (is_embed := pos_info.get("is_embed")) is not None: + if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] mm_embeds_item = gather_mm_placeholders( From f798d892fee09fe7bfbf19d5b5ecfe0dee675b56 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Apr 2025 21:32:31 +0000 Subject: [PATCH 15/23] Fix Idefics3 Signed-off-by: DarkLight1337 --- vllm/model_executor/models/idefics3.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 1ffceb995e65..347106bc4dcf 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -41,7 +41,7 @@ MultiModalDataItems, MultiModalFieldConfig, PromptReplacement, PromptUpdate, - PromptUpdateDetails, encode_tokens) + PromptUpdateDetails) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -258,19 +258,16 @@ def get_num_image_tokens( image_height: int, processor: Optional[Idefics3Processor], ) -> int: - tokenizer = self.get_tokenizer() - image_repl = self.get_image_repl( + if processor is None: + processor = self.get_hf_processor() + + num_patches = self.get_num_patches( image_width=image_width, image_height=image_height, processor=processor, ) - image_repl_tokens = encode_tokens( - tokenizer, - image_repl, - add_special_tokens=False, - ) - return len(image_repl_tokens) + return num_patches * processor.image_seq_len def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() From 06bf57d44e7ac6b84815dea5085b6bbb56fd4333 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 1 Apr 2025 21:39:07 +0000 Subject: [PATCH 16/23] Remove embed_is_patch from new models Signed-off-by: DarkLight1337 --- 
vllm/model_executor/models/aya_vision.py | 109 +++++------------------ vllm/model_executor/models/mistral3.py | 42 ++------- vllm/model_executor/models/skyworkr1v.py | 50 +---------- 3 files changed, 28 insertions(+), 173 deletions(-) diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index b4bf1d82c083..24d5f851bcc1 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -27,7 +27,7 @@ BaseProcessingInfo, MultiModalFieldConfig, PromptReplacement, PromptUpdate, - encode_tokens) + PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -35,7 +35,6 @@ from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import scatter_patch_features, select_patch_features class AyaVisionImagePixelInputs(TypedDict): @@ -51,13 +50,6 @@ class AyaVisionImagePixelInputs(TypedDict): num_patches: torch.Tensor """Shape: `(batch_size * num_images)`""" - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which image embeddings correspond to patch tokens. - - Shape: `(batch_size * num_images, num_embeds)` - """ - class AyaVisionMultiModalProjector(nn.Module): @@ -135,21 +127,20 @@ def get_mm_max_tokens_per_item( def get_max_image_tokens(self) -> int: hf_processor = self.get_hf_processor() image_processor = hf_processor.image_processor + image_size = self.get_image_size_with_most_features() - tokenizer = hf_processor.tokenizer num_patches = self.get_num_patches( image_width=image_size.width, image_height=image_size.height, size=image_processor.size, min_patches=image_processor.min_patches, - max_patches=image_processor.max_patches) - image_string = hf_processor._prompt_split_image(num_patches) - x = encode_tokens( - tokenizer, - image_string, - add_special_tokens=False, + max_patches=image_processor.max_patches, ) - return len(x) + + img_patches_per_tile = (hf_processor.img_size // + hf_processor.patch_size)**2 + + return num_patches * img_patches_per_tile def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} @@ -207,60 +198,6 @@ def get_dummy_processor_inputs( class AyaVisionMultiModalProcessor( BaseMultiModalProcessor[AyaVisionProcessingInfo]): - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - ) -> BatchFeature: - processed_outputs = super()._call_hf_processor( - prompt, - mm_data, - mm_kwargs, - ) - hf_processor = self.info.get_hf_processor(**mm_kwargs) - image_processor = hf_processor.image_processor - - hf_config = self.info.get_hf_config() - # HF processor pops the `num_patches` kwarg, which is needed by vLLM - if (images := - mm_data.get("images")) is not None and '' in prompt: - assert isinstance(images, list) - parsed_images = (self._get_data_parser().parse_mm_data({ - "image": - images - }).get_items("image", ImageProcessorItems)) - image_sizes = [ - parsed_images.get_image_size(i) - for i in range(len(parsed_images)) - ] - num_patches = [ - self.info.get_num_patches( - image_width=image_size.width, - image_height=image_size.height, - size=image_processor.size, - min_patches=image_processor.min_patches, - max_patches=image_processor.max_patches) - for image_size in image_sizes - ] - image_tokens_list = [ - hf_processor._prompt_split_image(num_patch) - for num_patch in 
num_patches - ] - tokenizer = self.info.get_tokenizer() - image_token_ids = [ - tokenizer.encode(image_tokens, add_special_tokens=False) - for image_tokens in image_tokens_list - ] - embed_is_patch = [ - torch.tensor(image_repl_tokens) == hf_config.image_token_index - for image_repl_tokens in image_token_ids - ] - processed_outputs["embed_is_patch"] = embed_is_patch - processed_outputs["num_patches"] = torch.tensor(num_patches) - - return processed_outputs - def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -271,7 +208,6 @@ def _get_mm_fields_config( pixel_values=MultiModalFieldConfig.flat_from_sizes( "image", num_patches), num_patches=MultiModalFieldConfig.batched("image"), - embed_is_patch=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), ) @@ -283,6 +219,7 @@ def _get_prompt_updates( ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token = hf_processor.image_token + img_patch_token = hf_processor.img_patch_token image_processor = hf_processor.image_processor def get_replacement(item_idx: int): @@ -294,8 +231,11 @@ def get_replacement(item_idx: int): image_height=image_size.height, size=image_processor.size, min_patches=image_processor.min_patches, - max_patches=image_processor.max_patches) - return hf_processor._prompt_split_image(num_patches=num_patches) + max_patches=image_processor.max_patches, + ) + repl = hf_processor._prompt_split_image(num_patches=num_patches) + + return PromptUpdateDetails.select_text(repl, img_patch_token) return [ PromptReplacement( @@ -424,7 +364,6 @@ def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) num_patches = kwargs.pop("num_patches", None) - embed_is_patch = kwargs.pop("embed_is_patch", None) image_embeds = kwargs.pop("image_embeds", None) assert image_embeds is None, "Aya Vision does not support image_embeds." @@ -436,18 +375,13 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of num_patches. " f"Got type: {type(num_patches)}") - if not isinstance(embed_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of embed_is_patch. 
" - f"Got type: {type(embed_is_patch)}") - pixel_values = flatten_bn(pixel_values, concat=True) num_patches = flatten_bn(num_patches, concat=True) - embed_is_patch = flatten_bn(embed_is_patch) + return AyaVisionImagePixelInputs( type="pixel_values", pixel_values=self._validate_pixel_values(pixel_values), num_patches=num_patches, - embed_is_patch=embed_is_patch, ) def get_multimodal_embeddings( @@ -455,11 +389,8 @@ def get_multimodal_embeddings( image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return None - image_features = self._process_image_input(image_input, **kwargs) - return scatter_patch_features( - image_features, - image_input["embed_is_patch"], - ) + + return self._process_image_input(image_input, **kwargs) def get_input_embeddings( self, @@ -471,9 +402,9 @@ def get_input_embeddings( inputs_embeds = merge_multimodal_embeddings( input_ids=input_ids, inputs_embeds=inputs_embeds, - multimodal_embeddings=select_patch_features( - multimodal_embeddings), - placeholder_token_id=self.config.image_token_index) + multimodal_embeddings=multimodal_embeddings, + placeholder_token_id=self.config.image_token_index, + ) return inputs_embeds diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 4cd9a7bf58e7..64ecaa29665c 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -27,7 +27,8 @@ MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate) + PromptReplacement, PromptUpdate, + PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -36,7 +37,7 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import get_vision_encoder_info, select_patch_features +from .vision import get_vision_encoder_info class Mistral3ImagePixelInputs(TypedDict): @@ -49,14 +50,6 @@ class Mistral3ImagePixelInputs(TypedDict): in which case the data is passed as a list instead of a batched tensor. """ - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which image embeddings correspond - to patch tokens. 
- - Shape: `(batch_size, num_images, num_embeds)` - """ - class Mistral3PatchMerger(nn.Module): """ @@ -266,23 +259,6 @@ def _call_hf_processor( p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes) ] - hf_config = self.info.get_hf_config() - vision_config = hf_config.vision_config - assert isinstance(vision_config, PixtralVisionConfig) - encoder_info = PixtralHFEncoderInfo(vision_config) - - tile_sizes = [ - encoder_info.get_patch_grid_size( - image_width=pixel_value.shape[-1], - image_height=pixel_value.shape[-2], - ) for pixel_value in processed_outputs["pixel_values"] - ] - embed_is_patch = [ - torch.tensor(([True] * ncols + [False]) * nrows) - for ncols, nrows in tile_sizes - ] - processed_outputs["embed_is_patch"] = embed_is_patch - return processed_outputs def _get_mm_fields_config( @@ -292,7 +268,6 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: return dict( pixel_values=MultiModalFieldConfig.batched("image"), - embed_is_patch=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), ) @@ -327,7 +302,7 @@ def get_replacement(item_idx: int): tokens = ([image_token_id] * ncols + [image_break_id]) * nrows tokens[-1] = image_end_id - return tokens + return PromptUpdateDetails.select_token_id(tokens, image_token_id) return [ PromptReplacement( @@ -509,16 +484,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - assert self.config.vision_config.model_type == "pixtral" - embed_is_patch = kwargs.pop("embed_is_patch") - if not isinstance(embed_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of embed_is_patch. " - f"Got type: {type(embed_is_patch)}") - return Mistral3ImagePixelInputs( type="pixel_values_pixtral", pixel_values=flatten_bn(pixel_values), - embed_is_patch=embed_is_patch, ) def _process_image_input( @@ -569,7 +537,7 @@ def get_input_embeddings( inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, - select_patch_features(multimodal_embeddings), + multimodal_embeddings, self.config.image_token_index, ) return inputs_embeds diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index ac5de0e36b89..09c0e528b24f 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -40,7 +40,6 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import scatter_patch_features, select_patch_features IMG_START = '' IMG_END = '' @@ -61,14 +60,6 @@ class SkyworkR1VImagePixelInputs(TypedDict): num_patches: torch.Tensor """Shape: `(batch_size * num_images)`""" - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] - """ - A boolean mask indicating which image embeddings correspond - to patch tokens. 
- - Shape: `(batch_size * num_images, num_embeds)` - """ - class SkyworkR1VImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] @@ -419,25 +410,6 @@ def __call__( torch.tensor([len(item) for item in pixel_values_lst]), } - tokenizer = self.tokenizer - image_token_id = self.image_token_id - - embed_is_patch = list[torch.Tensor]() - - for pixel_values in pixel_values_lst: - num_patches = pixel_values.shape[0] - feature_size = num_patches * self.num_image_token - - image_repl = self.get_image_repl(feature_size, num_patches) - feature_tokens = tokenizer.encode(image_repl.features, - add_special_tokens=False) - - text = [t.replace('', image_repl.full, 1) for t in text] - embed_is_patch.append( - torch.tensor(feature_tokens) == image_token_id) - - image_inputs["embed_is_patch"] = embed_is_patch - text_inputs = self.tokenizer(text) return { @@ -460,7 +432,7 @@ def get_image_repl( repl_features = IMG_CONTEXT * feature_size repl_full = IMG_START + repl_features + IMG_END - return PromptUpdateDetails(full=repl_full, features=repl_features) + return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT) class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo): @@ -599,7 +571,6 @@ def _get_mm_fields_config( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( "image", image_num_patches), image_num_patches=MultiModalFieldConfig.batched("image"), - embed_is_patch=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), image_token_id=MultiModalFieldConfig.shared("image", num_images), ) @@ -835,7 +806,6 @@ def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", None) - embed_is_patch = kwargs.pop("embed_is_patch", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values_flat is None and image_embeds is None: @@ -864,20 +834,14 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image_num_patches. " f"Got type: {type(image_num_patches)}") - if not isinstance(embed_is_patch, (torch.Tensor, list)): - raise ValueError("Incorrect type of embed_is_patch. 
" - f"Got type: {type(embed_is_patch)}") - pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True) - embed_is_patch = flatten_bn(embed_is_patch) return SkyworkR1VImagePixelInputs( type="pixel_values", pixel_values_flat=self._validate_pixel_values( pixel_values_flat), num_patches=image_num_patches, - embed_is_patch=embed_is_patch, ) raise AssertionError("This line should be unreachable.") @@ -923,15 +887,7 @@ def get_multimodal_embeddings( if image_input is None: return None - image_features = self._process_image_input(image_input) - - if image_input["type"] != "pixel_values": - return image_features - - return scatter_patch_features( - image_features, - image_input["embed_is_patch"], - ) + return self._process_image_input(image_input) def get_input_embeddings( self, @@ -945,7 +901,7 @@ def get_input_embeddings( inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, - select_patch_features(multimodal_embeddings), + multimodal_embeddings, self.img_context_token_id, ) return inputs_embeds From df174e6c68a56326f3ca69b3bcd4b978de9ae8bc Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 2 Apr 2025 09:43:04 +0000 Subject: [PATCH 17/23] Fix missing replacement Signed-off-by: DarkLight1337 --- vllm/model_executor/models/skyworkr1v.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 09c0e528b24f..e3deae828a33 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -410,6 +410,14 @@ def __call__( torch.tensor([len(item) for item in pixel_values_lst]), } + for pixel_values in pixel_values_lst: + num_patches = pixel_values.shape[0] + feature_size = num_patches * self.num_image_token + + image_repl = self.get_image_repl(feature_size, num_patches) + + text = [t.replace('', image_repl.full, 1) for t in text] + text_inputs = self.tokenizer(text) return { From 1e9fb1bd35286d330f2762305e20096a954cd2d9 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 2 Apr 2025 09:58:21 +0000 Subject: [PATCH 18/23] Fix Pixtral token calculation and update docs Signed-off-by: DarkLight1337 --- docs/source/models/supported_models.md | 5 +---- vllm/model_executor/models/mistral3.py | 7 ++----- vllm/model_executor/models/pixtral.py | 6 ++---- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index af0f7304c665..87fd52a856e7 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -883,7 +883,7 @@ See [this page](#generative-models) for more information on how to use generativ * `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. * * ✅︎ - * + * ✅︎ - * `MllamaForConditionalGeneration` * Llama 3.2 * T + I+ @@ -984,9 +984,6 @@ See [this page](#generative-models) for more information on how to use generativ + Multiple items can be inputted per text prompt for this modality. :::{important} -To use Gemma3 series models, you have to install Hugging Face Transformers library from source via -`pip install git+https://github.com/huggingface/transformers`. - Pan-and-scan image pre-processing is currently supported on V0 (but not V1). You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": True}'`. 
::: diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 64ecaa29665c..b6fbc6b1fa3d 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -32,8 +32,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, - SupportsV0Only) +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -393,14 +392,12 @@ def init_vision_tower_for_llava( ) -# TODO(mgoin): Support V1, there are issues with image batching/chunking -# that need to be resolved first. @MULTIMODAL_REGISTRY.register_processor( _build_mistral3_processor, info=_build_mistral3_info, dummy_inputs=Mistral3DummyInputsBuilder) class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP, SupportsV0Only): + SupportsPP): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index f68d3720f59e..e07c6516aef2 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -199,7 +199,7 @@ def get_num_image_tokens( ncols, nrows = processor.image_processor._image_to_num_tokens( Image.new("RGB", (image_width, image_height))) - return (ncols + 1) * nrows + return ncols * nrows def get_image_size_with_most_features(self) -> ImageSize: image_processor = self.get_hf_processor().image_processor @@ -933,9 +933,7 @@ def get_num_image_tokens( image_width=image_width, image_height=image_height, ) - - # Consider the image_break_token - return (ncols + 1) * nrows + return ncols * nrows def get_max_image_tokens(self) -> int: image_size = self.get_image_size() From e3ec92d17b98448ad03d50e217c50fe37c938172 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 3 Apr 2025 14:12:10 +0000 Subject: [PATCH 19/23] Fix CI Signed-off-by: DarkLight1337 --- tests/multimodal/test_processing.py | 9 ++++++ vllm/model_executor/models/aya_vision.py | 40 ++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index da112bd7a921..fa9588a05096 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -785,6 +785,7 @@ def test_find_update_tokens( item_idx=0, start_idx=6, tokens=[32000, 32000], + is_embed=None, ), ], "pattern_4": [ @@ -793,6 +794,7 @@ def test_find_update_tokens( item_idx=0, start_idx=3, tokens=[32000], + is_embed=None, ), ], } @@ -807,12 +809,14 @@ def test_find_update_tokens( item_idx=0, start_idx=1, tokens=[32000, 32000], + is_embed=None, ), PlaceholderFeaturesInfo( modality="pattern_1", item_idx=1, start_idx=5, tokens=[32000, 32000], + is_embed=None, ), ], "pattern_3": [ @@ -821,6 +825,7 @@ def test_find_update_tokens( item_idx=0, start_idx=7, tokens=[1550, 918, 1550], + is_embed=None, ), ], # No match for pattern_4 as it has lower priority than pattern_1 @@ -835,12 +840,14 @@ def test_find_update_tokens( item_idx=0, start_idx=1, tokens=[32000, 32000], + is_embed=None, ), PlaceholderFeaturesInfo( modality="pattern_1", item_idx=1, start_idx=3, tokens=[32000, 32000], + is_embed=None, ), ], "pattern_4": [ @@ -849,6 +856,7 @@ def 
test_find_update_tokens( item_idx=0, start_idx=5, tokens=[32000], + is_embed=None, ), ], "pattern_3": [ @@ -857,6 +865,7 @@ def test_find_update_tokens( item_idx=0, start_idx=6, tokens=[1550, 918, 1550], + is_embed=None, ), ], } diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 24d5f851bcc1..6b68885d375a 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -198,6 +198,46 @@ def get_dummy_processor_inputs( class AyaVisionMultiModalProcessor( BaseMultiModalProcessor[AyaVisionProcessingInfo]): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt, + mm_data, + mm_kwargs, + ) + hf_processor = self.info.get_hf_processor(**mm_kwargs) + image_processor = hf_processor.image_processor + + # HF processor pops the `num_patches` kwarg, which is needed by vLLM + if (images := + mm_data.get("images")) is not None and '' in prompt: + assert isinstance(images, list) + parsed_images = (self._get_data_parser().parse_mm_data({ + "image": + images + }).get_items("image", ImageProcessorItems)) + image_sizes = [ + parsed_images.get_image_size(i) + for i in range(len(parsed_images)) + ] + + num_patches = [ + self.info.get_num_patches( + image_width=image_size.width, + image_height=image_size.height, + size=image_processor.size, + min_patches=image_processor.min_patches, + max_patches=image_processor.max_patches) + for image_size in image_sizes + ] + processed_outputs["num_patches"] = torch.tensor(num_patches) + + return processed_outputs + def _get_mm_fields_config( self, hf_inputs: BatchFeature, From 6d511188090fee17fb8994ddd02f99e5f1cc1a42 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 3 Apr 2025 21:11:43 -0700 Subject: [PATCH 20/23] increase default timeout for testing Signed-off-by: Roger Wang --- tests/models/decoder_only/audio_language/test_ultravox.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index 83ece5d22bfb..ab50c17d4bff 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -55,7 +55,10 @@ def server(request, audio_assets): for key, value in request.param.items() ] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + with RemoteOpenAIServer(MODEL_NAME, + args, + env={"VLLM_AUDIO_FETCH_TIMEOUT": + "30"}) as remote_server: yield remote_server From 4940d1678459f66890128f3fe59cc3d0f7edefd7 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 3 Apr 2025 21:19:02 -0700 Subject: [PATCH 21/23] typo Signed-off-by: Roger Wang --- tests/models/decoder_only/audio_language/test_ultravox.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index ab50c17d4bff..242f3398b921 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -57,8 +57,8 @@ def server(request, audio_assets): with RemoteOpenAIServer(MODEL_NAME, args, - env={"VLLM_AUDIO_FETCH_TIMEOUT": - "30"}) as remote_server: + env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": + "30"}) as remote_server: yield remote_server From 522833b183e5ac9123d861b9ec08ced31dbefa1c Mon 
Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 3 Apr 2025 23:49:05 -0700 Subject: [PATCH 22/23] lower max model len Signed-off-by: Roger Wang --- tests/models/decoder_only/vision_language/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 3b34f012f626..b984cd6f5488 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -167,7 +167,7 @@ "cherry_blossom": "What is the season?", # noqa: E501 }), multi_image_prompt="Describe the two images in detail.", # noqa: E501 - max_model_len=8192, + max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}} From b64983c271c034e8207bdfe52eaf5285fc7df3f3 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Fri, 4 Apr 2025 11:59:33 -0700 Subject: [PATCH 23/23] update pixtral test Signed-off-by: Roger Wang --- tests/models/decoder_only/vision_language/test_pixtral.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index 400b016a4dd1..6ebe75f0e812 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -176,6 +176,8 @@ def test_chat( model, dtype=dtype, tokenizer_mode="mistral", + load_format="mistral", + config_format="mistral", max_model_len=max_model_len, limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, ) as vllm_model:
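The hunks above move models onto `PromptUpdateDetails.select_token_id` / `PromptUpdateDetails.select_text` and thread an `is_embed` mask through `PlaceholderFeaturesInfo`, instead of each model building its own `embed_is_patch` tensors. A minimal sketch of the selection idea, assuming a standalone `ReplacementTokens` helper and made-up token ids (not the actual vLLM classes):

```python
# Minimal sketch, assuming a standalone helper and hypothetical token ids;
# only the "select embedding positions by token id" idea mirrors the diffs.
from dataclasses import dataclass

import torch


@dataclass
class ReplacementTokens:
    """Full replacement tokens plus a mask of embedding positions."""
    full: list[int]
    is_embed: torch.Tensor  # bool tensor, same length as `full`

    @classmethod
    def select_token_id(cls, full: list[int],
                        embed_token_id: int) -> "ReplacementTokens":
        # Positions equal to the embedding token id receive vision features;
        # everything else (separators, BOS, etc.) keeps its text embedding.
        tokens = torch.tensor(full)
        return cls(full=full, is_embed=tokens == embed_token_id)


# Hypothetical ids: patch tokens (32000) interleaved with a separator (32001)
# and a trailing BOS (1).
repl = ReplacementTokens.select_token_id(
    [32000, 32000, 32001, 32000, 32000, 32001, 1],
    embed_token_id=32000,
)
print(repl.is_embed)
# tensor([ True,  True, False,  True,  True, False, False])
```

Selecting by a single token id also means the processor no longer has to re-encode the replacement text just to build the mask, which is why the `tokenizer.encode(image_repl.features, ...)` bookkeeping is deleted above.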
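Several models above now pass the gathered embeddings straight to `merge_multimodal_embeddings(input_ids, inputs_embeds, multimodal_embeddings, placeholder_token_id)`. A minimal sketch of that merge pattern, assuming flattened 2-D tensors and a single placeholder id (a simplified stand-in, not the vLLM implementation):

```python
# Minimal sketch, assuming flattened inputs and one placeholder token id;
# the real helper also handles lists of embeddings and multiple ids.
import torch


def merge_multimodal_embeddings(
    input_ids: torch.Tensor,              # (num_tokens,)
    inputs_embeds: torch.Tensor,          # (num_tokens, hidden_size)
    multimodal_embeddings: torch.Tensor,  # (num_mm_tokens, hidden_size)
    placeholder_token_id: int,
) -> torch.Tensor:
    # Every position holding the placeholder token is overwritten with the
    # corresponding multimodal embedding, in order.
    mask = input_ids == placeholder_token_id
    if int(mask.sum()) != multimodal_embeddings.shape[0]:
        raise ValueError(
            "Number of placeholder tokens does not match the number of "
            "multimodal embeddings to merge.")
    inputs_embeds[mask] = multimodal_embeddings
    return inputs_embeds


# Usage with hypothetical ids: token 32000 marks the image positions.
input_ids = torch.tensor([1, 32000, 32000, 5, 6])
inputs_embeds = torch.zeros(5, 8)
image_embeds = torch.ones(2, 8)
out = merge_multimodal_embeddings(input_ids, inputs_embeds, image_embeds,
                                  placeholder_token_id=32000)
assert torch.equal(out[1:3], image_embeds)
```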