diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index b0cb4a62333a..e7e73f446df2 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -24,7 +24,6 @@
 from typing import Iterable, Optional, Set, Tuple, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 from transformers import CohereConfig
 
diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py
index 432f26141048..327ec4640f03 100644
--- a/vllm/model_executor/models/idefics3.py
+++ b/vllm/model_executor/models/idefics3.py
@@ -17,16 +17,14 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Dict, List, Literal, Optional, Set, Tuple, TypedDict, Union
+from typing import Dict, Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
                           Idefics3Processor)
 
 from vllm.config import VllmConfig
-from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -35,13 +33,16 @@
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.multimodal.inputs import NestedTensors
-from vllm.multimodal.parse import ImageProcessorItems
+from vllm.multimodal.parse import ImageProcessorItems, ImageSize
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo,
                                         MultiModalDataItems,
                                         MultiModalFieldConfig,
-                                        PromptReplacement, PromptUpdate)
+                                        PromptReplacement, PromptUpdate,
+                                        encode_tokens)
+# yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -53,18 +54,28 @@
 from .llama import LlamaModel
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
-
-logger = init_logger(__name__)
+from .vision import scatter_patch_features, select_patch_features
 
 
 class Idefics3ImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
+    pixel_values: torch.Tensor
     """
     Shape: `(batch_size * num_images * num_patches, num_channels, height, width)`
     """
-    pixel_attention_mask: Optional[torch.BoolTensor]
+    pixel_attention_mask: torch.Tensor
+
+    num_patches: torch.Tensor
+    """Shape: `(batch_size * num_images)`"""
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
 
 
 class Idefics3ImageEmbeddingInputs(TypedDict):
@@ -75,6 +86,14 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
     `hidden_size` must match the hidden size of language model backbone.
     """
 
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
+
 
 ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
 
@@ -100,32 +119,14 @@ def get_mm_max_tokens_per_item(
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
-        hf_processor = self.get_hf_processor()
-        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
-        grid_w, grid_h = self._get_image_feature_grid_size(
-            image_width=image_processor.size['longest_edge'],
-            image_height=image_processor.size['longest_edge'],
-        )
-        num_image_token = (grid_w * grid_h + 1) * hf_processor.image_seq_len
-        # Calculate Non-image-token length
-        # NOTE: <row_i_col_j> and <global-img> are special token for SmolVLM
-        # but not for Idefic3, so we need to tokenize them to get actual length.
-        tokenizer = self.get_tokenizer()
-        tile_token_len = len(tokenizer.tokenize("<row_1_col_1>"))
-        glob_token_len = len(tokenizer.tokenize(hf_processor.global_image_tag))
-        # linebreak and <fake_token_around_image> always cost 1 token
-        fake_token_len = lb_len = 1
-        non_image_token = (grid_w * grid_h) * (
-            tile_token_len + fake_token_len) + glob_token_len + (
-                grid_h + 1) * lb_len + fake_token_len
-        return {"image": num_image_token + non_image_token}
+        return {"image": self.get_max_image_tokens()}
 
     def _resize_output_size(self,
                             *,
                             height: int,
                             width: int,
                             max_len: Optional[int] = None,
-                            min_len: Optional[int] = 1,
+                            min_len: int = 1,
                             max_size: Optional[int] = None) -> tuple[int, int]:
         # Set default value for max_len if not provided
         max_len = max(height, width) if max_len is None else max_len
@@ -181,10 +182,13 @@ def _get_image_feature_grid_size(
         *,
         image_width: int,
         image_height: int,
-        size: Optional[dict[str, object]] = None,
+        processor: Optional[Idefics3Processor],
     ) -> tuple[int, int]:
-        hf_processor = self.get_hf_processor(size=size)
-        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        image_processor: Idefics3ImageProcessor = processor.image_processor
+
         max_image_size = image_processor.max_image_size['longest_edge']
         size = image_processor.size['longest_edge']
         assert size % max_image_size == 0, (
@@ -204,6 +208,105 @@ def _get_image_feature_grid_size(
             grid_h = grid_w = 0
         return grid_w, grid_h
 
+    def get_num_patches(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[Idefics3Processor],
+    ) -> int:
+        grid_w, grid_h = self._get_image_feature_grid_size(
+            image_width=image_width,
+            image_height=image_height,
+            processor=processor,
+        )
+
+        return grid_w * grid_h + 1
+
+    def get_image_repl(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[Idefics3Processor],
+    ) -> str:
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        image_token = processor.image_token.content
+        fake_image_token = processor.fake_image_token.content
+        global_img_token = processor.global_image_tag
+        image_seq_len = processor.image_seq_len
+        grid_placeholder = "<row_{n_h}_col_{n_w}>"
+
+        p_img = image_token * image_seq_len
+        global_img_placeholder = fake_image_token + global_img_token + p_img
+        tile_img_placeholder = fake_image_token + grid_placeholder + p_img
+
+        grid_w, grid_h = self._get_image_feature_grid_size(
+            image_width=image_width,
+            image_height=image_height,
+            processor=processor,
+        )
+        if grid_w == 0 and grid_h == 0:
+            return global_img_placeholder + fake_image_token
+
+        tiles_placeholder = list[str]()
+        for i in range(grid_h):
+            for j in range(grid_w):
+                placeholder_per_tile = tile_img_placeholder.format(n_h=i + 1,
+                                                                   n_w=j + 1)
+                tiles_placeholder.append(placeholder_per_tile)
+                # Add line break if it is the last tile in the row
+                if j == grid_w - 1:
+                    tiles_placeholder.append("\n")
+
+        return "".join([
+            *tiles_placeholder,
+            "\n",
+            global_img_placeholder,
+            fake_image_token,
+        ])
+
+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[Idefics3Processor],
+    ) -> int:
+        tokenizer = self.get_tokenizer()
+        image_repl = self.get_image_repl(
+            image_width=image_width,
+            image_height=image_height,
+            processor=processor,
+        )
+
+        image_repl_tokens = encode_tokens(
+            tokenizer,
+            image_repl,
+            add_special_tokens=False,
+        )
+        return len(image_repl_tokens)
+
+    def get_image_size_with_most_features(self) -> ImageSize:
+        processor = self.get_hf_processor()
+        image_processor: Idefics3ImageProcessor = processor.image_processor
+
+        return ImageSize(
+            width=image_processor.size["longest_edge"],
+            height=image_processor.size["longest_edge"],
+        )
+
+    def get_max_image_tokens(self) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+
+        return self.get_num_image_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            processor=None,
+        )
+
 
 class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
                                  ):
@@ -217,7 +320,7 @@ def get_dummy_processor_inputs(
         hf_processor = self.info.get_hf_processor()
         image_processor: Idefics3ImageProcessor = hf_processor.image_processor
         longest_edge = image_processor.max_image_size['longest_edge']
-        image_token: str = hf_processor.image_token.content
+        image_token = hf_processor.image_token.content
 
         mm_data = {
             "image":
@@ -241,26 +344,61 @@ def _call_hf_processor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        if mm_data:
-            processed_outputs = super()._call_hf_processor(
-                prompt, mm_data, mm_kwargs)
-            image_grids = [
-                self.info._get_image_feature_grid_size(
-                    image_width=img.width,
-                    image_height=img.height,
-                    **mm_kwargs,
-                ) for img in mm_data["images"]
-            ]
-            image_patches = list(map(lambda x: math.prod(x) + 1, image_grids))
-            for key in ("pixel_values", "pixel_attention_mask"):
-                data = processed_outputs.pop(key)
-                data = data.flatten(0, 1).split(image_patches)
-                processed_outputs[key] = data
-        else:
-            tokenizer = self.info.get_tokenizer()
-            processed_outputs = tokenizer(prompt,
-                                          add_special_tokens=True,
-                                          return_tensors="pt")
+        # Text-only input not supported in composite processor
+        if not (images := mm_data.get("images", [])):
+            prompt_ids = self.info.get_tokenizer().encode(prompt)
+            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        processed_outputs = super()._call_hf_processor(
+            prompt,
+            mm_data,
+            mm_kwargs,
+        )
+
+        parsed_images = (self._get_data_parser().parse_mm_data({
+            "image": images
+        }).get_items("image", ImageProcessorItems))
+        image_sizes = [
+            parsed_images.get_image_size(i) for i in range(len(parsed_images))
+        ]
+        hf_processor = self.info.get_hf_processor(**mm_kwargs)
+
+        image_repl_features = [
+            self.info.get_image_repl(image_width=size.width,
+                                     image_height=size.height,
+                                     processor=hf_processor)
+            for size in image_sizes
+        ]
+
+        tokenizer = self.info.get_tokenizer()
+        image_repls_feature_tokens = [
+            tokenizer.encode(image_repl, add_special_tokens=False)
+            for image_repl in image_repl_features
+        ]
+
+        vocab = tokenizer.get_vocab()
+        image_token_id = vocab[hf_processor.image_token.content]
+
+        embed_is_patch = [
+            torch.tensor(image_repl_tokens) == image_token_id
+            for image_repl_tokens in image_repls_feature_tokens
+        ]
+        processed_outputs["embed_is_patch"] = embed_is_patch
+
+        num_patches = [
+            self.info.get_num_patches(
+                image_width=size.width,
+                image_height=size.height,
+                processor=hf_processor,
+            ) for size in image_sizes
+        ]
+        processed_outputs["num_patches"] = torch.tensor(num_patches)
+
+        # Remove the extra batch dimension
+        processed_outputs["pixel_values"].squeeze_(0)
+        processed_outputs["pixel_attention_mask"].squeeze_(0)
+
+        return processed_outputs
 
     def _get_mm_fields_config(
@@ -268,10 +406,16 @@ def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        num_patches = hf_inputs.get("num_patches", torch.empty(0))
+
         return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            pixel_attention_mask=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_patches),
+            pixel_attention_mask=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_patches),
             image_embeds=MultiModalFieldConfig.batched("image"),
+            num_patches=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -281,42 +425,18 @@ def _get_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
-        image_token = hf_processor.image_token.content
-        fake_image_token = hf_processor.fake_image_token.content
-        global_img_token = hf_processor.global_image_tag
-        image_seq_len = hf_processor.image_seq_len
-        grid_placeholder = "<row_{n_h}_col_{n_w}>"
-
-        p_img = image_token * image_seq_len
-        global_img_placeholder = fake_image_token + global_img_token + p_img
-        tile_img_placeholder = fake_image_token + grid_placeholder + p_img
 
         def get_replacement_idefics3(item_idx: int) -> str:
             images = mm_items.get_items("image", ImageProcessorItems)
             image_size = images.get_image_size(item_idx)
-            grid_w, grid_h = self.info._get_image_feature_grid_size(
+
+            return self.info.get_image_repl(
                 image_width=image_size.width,
                 image_height=image_size.height,
-                **hf_processor_mm_kwargs,
+                processor=hf_processor,
             )
-            if grid_w == 0 and grid_h == 0:
-                image_placeholder = global_img_placeholder
-            else:
-                tiles_placeholder = list[str]()
-                for i in range(grid_h):
-                    for j in range(grid_w):
-                        placeholder_per_tile = tile_img_placeholder.format(
-                            n_h=i + 1, n_w=j + 1)
-                        tiles_placeholder.append(placeholder_per_tile)
-                        # Add line break if it is the last tile in the row
-                        if j == grid_w - 1:
-                            tiles_placeholder.append("\n")
-
-                image_placeholder = "".join(
-                    [*tiles_placeholder, "\n", global_img_placeholder])
-            return image_placeholder + fake_image_token
 
         return [
             PromptReplacement(
@@ -424,73 +544,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
               config.vision_config.patch_size)**2) / (config.scale_factor**2))
         self.image_token_id = self.config.image_token_id
 
-    def _validate_pixel_values(
-        self, data: Union[torch.Tensor, List[torch.Tensor]]
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
-
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape[1:])
-
-            if actual_dims != expected_dims:
-                expected_expr = ("num_patches", *map(str, expected_dims))
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
-    def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[ImageInputs]:
-        pixel_values = kwargs.pop("pixel_values", None)
-        image_embeds = kwargs.pop("image_embeds", None)
-        pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)
-
-        if pixel_values is None and image_embeds is None:
-            return None
-
-        if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
-            return Idefics3ImageEmbeddingInputs(
-                type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
-            )
-
-        if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
-            if isinstance(pixel_values, list):
-                pixel_values = torch.cat(pixel_values, dim=1)
-                pixel_attention_mask = torch.cat(pixel_attention_mask, dim=1)
-            else:
-                pixel_values = flatten_bn(pixel_values)
-                pixel_attention_mask = flatten_bn(pixel_attention_mask)
-
-            return Idefics3ImagePixelInputs(
-                type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
-                pixel_attention_mask=pixel_attention_mask)
-
-        raise AssertionError("This line should be unreachable.")
-
-    def _image_pixels_to_features(
+    def image_pixels_to_features(
         self,
         pixel_values: torch.Tensor,
-        pixel_attention_mask: Optional[torch.BoolTensor] = None,
-    ) -> NestedTensors:
+        pixel_attention_mask: torch.Tensor,
+    ) -> torch.Tensor:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        num_patches = [x.size(0) for x in pixel_values]
         pixel_values = pixel_values.to(
             dtype=self.vision_model.embeddings.patch_embedding.weight.dtype
         )  # fp16 compatibility
@@ -502,17 +562,9 @@ def _image_pixels_to_features(
         pixel_values = pixel_values[real_images_inds].contiguous()
 
         # Handle the vision attention mask
-        if pixel_attention_mask is None:
-            pixel_attention_mask = torch.ones(
-                size=(pixel_values.size(0), pixel_values.size(2),
-                      pixel_values.size(3)),
-                dtype=torch.bool,
-                device=pixel_values.device,
-            )
-        else:
-            # Remove padding images from the mask
-            pixel_attention_mask = pixel_attention_mask[
-                real_images_inds].contiguous()
+        # Remove padding images from the mask
+        pixel_attention_mask = pixel_attention_mask[
+            real_images_inds].contiguous()
 
         patch_size = self.config.vision_config.patch_size
         patches_subgrid = pixel_attention_mask.unfold(dimension=1,
@@ -529,27 +581,7 @@ def _image_pixels_to_features(
             patch_attention_mask=patch_attention_mask,
         )
 
-        return image_hidden_states.split(num_patches)
-
-    def _process_image_pixels(
-            self, inputs: Idefics3ImagePixelInputs) -> NestedTensors:
-        assert self.vision_model is not None
-
-        pixel_values = inputs["data"]
-        pixel_attention_mask = inputs["pixel_attention_mask"]
-
-        return self._image_pixels_to_features(pixel_values,
-                                              pixel_attention_mask)
-
-    def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
-        if image_input["type"] == "image_embeds":
-            return image_input["data"]
-
-        assert self.vision_model is not None
-        image_features = self._process_image_pixels(image_input)
-        num_patches = [x.size(0) for x in image_features]
-        image_features = torch.cat(image_features)
-        return self.connector(image_features).split(num_patches)
+        return image_hidden_states
 
     def get_input_embeddings(
         self,
@@ -616,13 +648,113 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
         self.sampler = get_sampler()
 
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape)
+
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    "The expected shape of pixel values per image per batch "
+                    f" per patch is {expected_expr}. "
+                    f"You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)
+
+        return data
+
+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[ImageInputs]:
+        pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+
+        if pixel_values is None and image_embeds is None:
+            return None
+
+        embed_is_patch = kwargs.pop("embed_is_patch")
+        if not isinstance(embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of embed_is_patch. "
+                             f"Got type: {type(embed_is_patch)}")
+
+        embed_is_patch = flatten_bn(embed_is_patch)
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+
+            return Idefics3ImageEmbeddingInputs(
+                type="image_embeds",
+                data=flatten_bn(image_embeds, concat=True),
+                embed_is_patch=embed_is_patch,
+            )
+
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
+
+            pixel_attention_mask = kwargs.pop("pixel_attention_mask")
+            if not isinstance(pixel_attention_mask, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel_attention_mask. "
+                                 f"Got type: {type(pixel_attention_mask)}")
+
+            num_patches = kwargs.pop("num_patches")
+            if not isinstance(num_patches, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of num_patches. "
" + f"Got type: {type(num_patches)}") + + pixel_values = flatten_bn(pixel_values, concat=True) + pixel_attention_mask = flatten_bn(pixel_attention_mask, + concat=True) + num_patches = flatten_bn(num_patches, concat=True) + + return Idefics3ImagePixelInputs( + type="pixel_values", + pixel_values=self._validate_pixel_values(pixel_values), + pixel_attention_mask=pixel_attention_mask, + num_patches=num_patches, + embed_is_patch=embed_is_patch, + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_pixels( + self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor: + pixel_values = inputs["pixel_values"] + pixel_attention_mask = inputs["pixel_attention_mask"] + + return self.model.image_pixels_to_features( + pixel_values, + pixel_attention_mask=pixel_attention_mask, + ) + + def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor: + if image_input["type"] == "image_embeds": + return image_input["data"] + + image_features = self._process_image_pixels(image_input) + image_features = self.model.connector(image_features) + + num_patches = image_input["num_patches"] + return image_features.split(num_patches.tolist()) + def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - image_input = self.model._parse_and_validate_image_input(**kwargs) + image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return None - vision_embeddings = self.model._process_image_input(image_input) - return vision_embeddings + + image_features = self._process_image_input(image_input) + + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) def get_input_embeddings( self, @@ -632,8 +764,11 @@ def get_input_embeddings( inputs_embeds = self.model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.config.image_token_id) + input_ids, + inputs_embeds, + select_patch_features(multimodal_embeddings), + self.config.image_token_id, + ) return inputs_embeds def forward( diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index d2c8fb723727..ac4bdbc41e44 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -21,7 +21,6 @@ import numpy as np import torch import torch.nn.functional as F -import torch.utils.checkpoint import transformers.models.mllama.configuration_mllama as config_mllama from PIL.Image import Image from torch import nn diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index f63bd0a11459..ccb5a3f600b2 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -160,7 +160,7 @@ def _call_hf_processor( mm_kwargs: Mapping[str, Any], ) -> BatchFeature: # Text-only input not supported in composite processor - if not mm_data or not mm_data.get("audios", []): + if not mm_data.get("audios", []): prompt_ids = self.info.get_tokenizer().encode(prompt) prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index cb1e14383849..6e73a2ae656c 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -8,7 +8,6 @@ from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union import torch 
-import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F
 from transformers import BatchFeature, ProcessorMixin
@@ -160,7 +159,7 @@ def _call_hf_processor(
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
         # Text-only input not supported in composite processor
-        if not mm_data or not mm_data.get("audios", []):
+        if not mm_data.get("audios", []):
             prompt_ids = self.info.get_tokenizer().encode(
                 prompt, add_special_tokens=False)
             prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
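
Illustration (not part of the patch): the Idefics3 processor changes above revolve around two per-image records. `num_patches` lets the flat `(batch * num_images * num_patches, ...)` pixel tensors be split back into per-image chunks (what `MultiModalFieldConfig.flat_from_sizes` and the final `.split(num_patches.tolist())` rely on), while `embed_is_patch` marks which positions of the tokenized image placeholder are `<image>` patch tokens, so vision features are scattered only into those slots. A minimal standalone sketch of that bookkeeping, using hypothetical token ids rather than vLLM's real vocabulary or APIs:

import torch

# Hypothetical stand-ins for tokenizer ids; in vLLM these come from the HF
# processor's vocabulary (e.g. the id of "<image>").
IMAGE_TOKEN_ID = 32001
FAKE_IMAGE_TOKEN_ID = 32002


def build_embed_is_patch(image_repl_token_ids: list[int]) -> torch.Tensor:
    """Mark which placeholder positions are <image> patch tokens."""
    return torch.tensor(image_repl_token_ids) == IMAGE_TOKEN_ID


def split_features_per_image(
        flat_features: torch.Tensor,
        num_patches: torch.Tensor) -> tuple[torch.Tensor, ...]:
    """Undo the flattening over (batch * num_images * num_patches)."""
    return flat_features.split(num_patches.tolist())


# Two images in one batch: one tiled into 2 tiles + 1 global patch,
# one small enough to keep only the global patch.
num_patches = torch.tensor([3, 1])
flat_features = torch.randn(4, 169, 2048)  # (total_patches, tokens, hidden)
per_image = split_features_per_image(flat_features, num_patches)
assert [f.shape[0] for f in per_image] == [3, 1]

# Placeholder ids for one image: framing tokens around 5 patch tokens.
repl = [FAKE_IMAGE_TOKEN_ID] + [IMAGE_TOKEN_ID] * 5 + [FAKE_IMAGE_TOKEN_ID]
assert build_embed_is_patch(repl).sum().item() == 5

Keeping the mask per image, rather than assuming a fixed placeholder length, is what lets tiled and global-only images coexist in a single batch.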