From 24f730a9628957cc73d5687295669a9877dca483 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 2 Feb 2025 21:42:16 +0800 Subject: [PATCH 01/17] impl idefics3 multimodal processor and v1 support Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/idefics3.py | 533 +++++++++++-------------- 1 file changed, 233 insertions(+), 300 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index d16a77f862d9..eb2607644fe9 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -14,21 +14,17 @@ """Inference-only Idefics3 model compatible with HuggingFace weights.""" import math -from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple, - Optional, Set, Tuple, TypedDict, Union) +from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Set, + Tuple, TypedDict, Union) import torch import torch.utils.checkpoint -from PIL import Image from torch import nn -# Temporary solution for transformers below 4.46.0. -from transformers import PretrainedConfig as Idefics3Config -from transformers import ProcessorMixin as Idefics3ImageProcessor +from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor, + Idefics3Processor) from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -38,11 +34,15 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.inputs import NestedTensors -from vllm.sequence import IntermediateTensors, SequenceData -from vllm.transformers_utils.processor import cached_get_processor -from vllm.utils import is_list_of +from vllm.multimodal.parse import ImageProcessorItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalDataItems, + MultiModalFieldConfig, + PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors # yapf: disable from .idefics2_vision_model import ( @@ -75,307 +75,238 @@ class Idefics3ImageEmbeddingInputs(TypedDict): """ -class Idefics3ProcessorSize(NamedTuple): - """Hashable wrapper for unhashable `size` dict of Idefics3Processor.""" - # NOTE: cached_get_processor/cached_get_image_processor uses lru_cache, - # we need to use NamedTuple instead of TypedDict to avoid hashing issues. - longest_edge: int - - def __contains__(self, key: str) -> bool: - return key in self._asdict() and getattr(self, key) is not None - - def __getitem__(self, key: str) -> int: - return getattr(self, key) - - ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] -def get_mm_processor_kwargs(size: Optional[Dict[str, int]] = None) -> Dict: - mm_processor_kwargs = {} - if size: - mm_processor_kwargs["size"] = Idefics3ProcessorSize(**size) - return mm_processor_kwargs - - -def input_mapper_for_idefics3( - ctx: InputContext, - data: object, - *, - size: Optional[Dict[str, int]] = None, -): - model_config = ctx.model_config - mm_processor_kwargs = get_mm_processor_kwargs(size) - image_processor = cached_get_image_processor( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - **mm_processor_kwargs) - if image_processor is None: - raise RuntimeError("No HuggingFace processor is available " - "to process the image object") - - if isinstance(data, Image.Image): - images = [[data]] - elif is_list_of(data, Image.Image): - images = [data] - else: - raise TypeError(f"Invalid image type: {type(data)}") - - try: - batch_data = image_processor(images, - return_tensors="pt", - return_row_col_info=True).data - except Exception: - logger.error("Failed to process image (%s)", data) - raise - - return MultiModalKwargs(batch_data) - - -def _resize_output_size(height: int, - width: int, - max_len: Optional[int] = None, - min_len: Optional[int] = 1, - max_size: Optional[int] = None) -> Tuple[int, int]: - # Set default value for max_len if not provided - max_len = max(height, width) if max_len is None else max_len - aspect_ratio = width / height - - # Handle the maximum size constraint - if max_size is not None: - max_len = min(max_len, max_size) - - # Adjust dimensions according to the aspect ratio - if width >= height: - width = max_len - height = int(width / aspect_ratio) - else: - height = max_len - width = int(height * aspect_ratio) - - # Ensure both width and height are even (if needed) - height += 1 if height % 2 != 0 else 0 - width += 1 if width % 2 != 0 else 0 - - # Ensure dimensions are not smaller than the minimum length - height = max(height, min_len) - width = max(width, min_len) - - return height, width - - -def _get_resize_output_image_size( - image_size: Tuple[int, int], - resolution_max_side: int, - max_image_size: int = 1820, -) -> Tuple[int, int]: - if resolution_max_side > max_image_size: - raise ValueError( - "`resolution_max_side` cannot be larger than `max_image_size`") - - height, width = image_size - - # Find the output size, when rescaling the longest edge to max_len and - # preserving the aspect ratio - height, width = _resize_output_size(height, - width, - max_len=resolution_max_side) - - return height, width - - -def _prompt_split_image(image_seq_len: int, image_rows: int, image_cols: int, - fake_token_around_image: str, image_token: str, - global_img_token: str) -> str: - """ - Prompt with expanded image tokens for when the image is split - into patches. - """ - text_split_images = "" - for n_h in range(image_rows): - for n_w in range(image_cols): - text_split_images += (fake_token_around_image + - f"" + - image_token * image_seq_len) - text_split_images += "\n" - - text_split_images += "\n" + _prompt_single_image( - image_seq_len=image_seq_len, - fake_token_around_image=fake_token_around_image, - image_token=image_token, - global_img_token=global_img_token) - return text_split_images - - -def _prompt_single_image(image_seq_len: int, fake_token_around_image: str, - image_token: str, global_img_token: str): - """Prompt with expanded image tokens for a single image.""" - return (fake_token_around_image + global_img_token + - image_token * image_seq_len + fake_token_around_image) - - -def _get_image_prompt_string(image_rows: int, image_cols: int, - image_seq_len: int, fake_token_around_image: str, - image_token: str, global_img_token: str): - if image_rows == 0 and image_cols == 0: - return _prompt_single_image( - image_seq_len=image_seq_len, - fake_token_around_image=fake_token_around_image, - image_token=image_token, - global_img_token=global_img_token, - ) - return _prompt_split_image(image_seq_len, image_rows, image_cols, - fake_token_around_image, image_token, - global_img_token) - - -def input_processor_for_idefics3(ctx: InputContext, - inputs: DecoderOnlyInputs, - *, - size: Optional[Dict[str, int]] = None): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - - model_config = ctx.model_config - mm_processor_kwargs = get_mm_processor_kwargs(size) - processor = cached_get_processor(model_config.model, **mm_processor_kwargs) - image_processor = processor.image_processor - tokenizer = processor.tokenizer - size = image_processor.size['longest_edge'] - max_image_size = image_processor.max_image_size['longest_edge'] - - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - image_list = [image_data] - elif is_list_of(image_data, Image.Image): - image_list = image_data - else: - raise TypeError(f"Invalid image type: {type(image_data)}") - - image_rows = [] - image_cols = [] - for image in image_list: - height, width = _get_resize_output_image_size(image.size, size) - - rows = math.ceil(height / max_image_size) - cols = math.ceil(width / max_image_size) - image_rows.append(rows) - image_cols.append(cols) - image_rows = [image_rows] - image_cols = [image_cols] - - n_images_in_text = [] - - text = inputs.get("prompt") - if text is None: - prompt_token_ids = inputs.get("prompt_token_ids", []) - assert prompt_token_ids - text = tokenizer.decode(prompt_token_ids) - - if isinstance(text, str): - text = [text] - elif not isinstance(text, list) and not isinstance(text[0], str): - raise ValueError("Invalid input text. Please provide a string, " - "or a list of strings") - - fake_image_token = processor.fake_image_token.content - image_token = processor.image_token.content - global_img_token = processor.global_image_tag - - prompt_strings = [] - for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols): - n_images_in_text.append(sample.count(image_token)) - - # Replace the image token with fake tokens around the expanded - # image token sequence of length `image_seq_len` - image_prompt_strings = [] - for n_rows, n_cols in zip(sample_rows, sample_cols): - image_prompt_string = _get_image_prompt_string( - n_rows, - n_cols, - processor.image_seq_len, - image_token=image_token, - fake_token_around_image=fake_image_token, - global_img_token=global_img_token, - ) - image_prompt_strings.append(image_prompt_string) - - split_sample = sample.split(image_token) - if len(split_sample) == 0: - raise ValueError("The image token should be present in the text.") - - # Place in the image prompt strings where the image tokens are - sample = split_sample[0] - for i, image_prompt_string in enumerate(image_prompt_strings): - sample += image_prompt_string + split_sample[i + 1] - prompt_strings.append(sample) +class Idefics3ProcessingInfo(BaseProcessingInfo): - prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids + def get_hf_processor( + self, + *, + size: Optional[Dict[str, int]] = None) -> Idefics3Processor: + if size is not None: + return self.ctx.get_hf_processor(Idefics3Processor, size=size) - return token_inputs( - prompt_token_ids=prompt_token_ids, - prompt=prompt_strings[0], - multi_modal_data=multi_modal_data, - ) + return self.ctx.get_hf_processor(Idefics3Processor) + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} -def _get_max_num_image_patch(image_processor: Idefics3ImageProcessor) -> int: - size = image_processor.size['longest_edge'] - max_image_size = image_processor.max_image_size['longest_edge'] - resized_height, resized_width = size, size + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + hf_processor = self.get_hf_processor() + image_processor: Idefics3ImageProcessor = hf_processor.image_processor + grid_w, grid_h = self._get_image_feature_grid_size( + image_width=image_processor.size['longest_edge'], + image_height=image_processor.size['longest_edge'], + ) + return {"image": (grid_w * grid_h + 1) * hf_processor.image_seq_len} + + def _resize_output_size(self, + *, + height: int, + width: int, + max_len: Optional[int] = None, + min_len: Optional[int] = 1, + max_size: Optional[int] = None) -> tuple[int, int]: + # Set default value for max_len if not provided + max_len = max(height, width) if max_len is None else max_len + aspect_ratio = width / height + + # Handle the maximum size constraint + if max_size is not None: + max_len = min(max_len, max_size) + + # Adjust dimensions according to the aspect ratio + if width >= height: + width = max_len + height = int(width / aspect_ratio) + else: + height = max_len + width = int(height * aspect_ratio) - grid_h = resized_height // max_image_size - grid_w = resized_width // max_image_size - return (grid_h * grid_w + 1) + # Ensure both width and height are even (if needed) + height += 1 if height % 2 != 0 else 0 + width += 1 if width % 2 != 0 else 0 + # Ensure dimensions are not smaller than the minimum length + height = max(height, min_len) + width = max(width, min_len) -def get_max_idefics3_image_tokens(ctx: InputContext, - *, - size: Optional[Dict[str, - int]] = None) -> int: - model_config = ctx.model_config - mm_processor_kwargs = get_mm_processor_kwargs(size) - processor = cached_get_processor(model_config.model, **mm_processor_kwargs) - image_seq_len = processor.image_seq_len - image_processor = processor.image_processor + return height, width - max_num_image_patches = _get_max_num_image_patch(image_processor) + def _get_resize_output_image_size( + self, + *, + image_width: int, + image_height: int, + resolution_max_side: int, + ) -> tuple[int, int]: + hf_processor = self.get_hf_processor() + image_processor: Idefics3ImageProcessor = hf_processor.image_processor + max_image_size = image_processor.size['longest_edge'] + if resolution_max_side > max_image_size: + raise ValueError( + "`resolution_max_side` cannot be larger than `max_image_size`") + + height, width = image_height, image_width + + # Find the output size, when rescaling the longest edge to max_len and + # preserving the aspect ratio + height, width = self._resize_output_size(height=height, + width=width, + max_len=resolution_max_side) + return height, width + + def _get_image_feature_grid_size( + self, + *, + image_width: int, + image_height: int, + ) -> tuple[int, int]: + hf_processor = self.get_hf_processor() + image_processor: Idefics3ImageProcessor = hf_processor.image_processor + max_image_size = image_processor.max_image_size['longest_edge'] + size = image_processor.size['longest_edge'] + assert size % max_image_size == 0, ( + "`longest_edge` in image_processor's `size` must be divisible by " + "`longest_edge` in `max_image_size`, this may cause by incorrect " + "mm_kwargs override.") + + resized_height, resized_width = self._get_resize_output_image_size( + image_width=image_width, + image_height=image_height, + resolution_max_side=size, + ) + if resized_height > max_image_size or resized_width > max_image_size: + grid_h = math.ceil(resized_height / max_image_size) + grid_w = math.ceil(resized_width / max_image_size) + else: + grid_h = grid_w = 0 + return grid_w, grid_h - return max_num_image_patches * image_seq_len +class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] + ): -def dummy_data_for_idefics3( - ctx: InputContext, + def get_dummy_processor_inputs( + self, seq_len: int, mm_counts: Mapping[str, int], - *, - size: Optional[Dict[str, int]] = None) -> DummyData: - hf_config = ctx.get_hf_config() - num_images = mm_counts["image"] + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + hf_processor = self.info.get_hf_processor() + image_processor: Idefics3ImageProcessor = hf_processor.image_processor + longest_edge = image_processor.max_image_size['longest_edge'] + image_token: str = hf_processor.image_token.content + + mm_data = { + "image": + self._get_dummy_images(width=longest_edge, + height=longest_edge, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + - mm_processor_kwargs = get_mm_processor_kwargs(size) - processor = cached_get_processor(ctx.model_config.model, - **mm_processor_kwargs) - max_num_image_patches = _get_max_num_image_patch(processor.image_processor) - image_seq_len = processor.image_seq_len - max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images +class Idefics3MultimodalProcessor( + BaseMultiModalProcessor[Idefics3ProcessingInfo]): - if seq_len - max_llm_image_tokens < 0: - raise RuntimeError( - f"Idefics3 cannot process {num_images} images in a prompt, " - "please increase max_model_len or reduce image limit by " - "--limit-mm-per-prompt.") + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + image_grids = [ + self.info._get_image_feature_grid_size( + image_width=img.width, + image_height=img.height, + ) for img in mm_data["images"] + ] + patches_per_image = list( + map(lambda x: math.prod(x) + 1, image_grids)) + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs) + processed_outputs["patches_per_image"] = patches_per_image + for key in ("pixel_values", "pixel_attention_mask"): + processed_outputs[key] = processed_outputs[key].flatten( + 0, 1).split(patches_per_image) + else: + tokenizer = self.info.get_tokenizer() + processed_outputs = tokenizer(prompt, + add_special_tokens=True, + return_tensors="pt") + return processed_outputs - seq_data = SequenceData.from_prompt_token_counts( - (hf_config.image_token_id, max_llm_image_tokens), - (0, seq_len - max_llm_image_tokens)) + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_patches = hf_inputs.pop("patches_per_image", []) + slices_idxs = [0] + [ + sum(num_patches[:i]) for i in range(1, + len(num_patches) + 1) + ] + slices = [ + slice(slices_idxs[i], slices_idxs[i + 1]) + for i in range(len(num_patches)) + ] + return dict( + pixel_values=MultiModalFieldConfig.flat("image", slices=slices), + pixel_attention_mask=MultiModalFieldConfig.flat("image", + slices=slices), + image_embeds=MultiModalFieldConfig.batched("image"), + ) - width = height = hf_config.vision_config.image_size - image = Image.new("RGB", (width, height), color=0) - mm_data = {"image": [image] if num_images == 1 else [image] * num_images} + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_processor = self.info.get_hf_processor() + + image_token = hf_processor.image_token.content + fake_image_token = hf_processor.fake_image_token.content + global_img_token = hf_processor.global_image_tag + image_seq_len = hf_processor.image_seq_len + grid_placeholder = "" + + global_img_placeholder = global_img_token + image_token * image_seq_len + tile_img_placeholder = grid_placeholder + image_token * image_seq_len + + def get_replacement_idefics3(item_idx: int) -> str: + images = mm_items.get_items("image", ImageProcessorItems) + + image_size = images.get_image_size(item_idx) + grid_w, grid_h = self.info._get_image_feature_grid_size( + image_width=image_size.width, + image_height=image_size.height, + ) - return DummyData(seq_data, mm_data) + if grid_w == 1 and grid_h == 1: + image_placeholder = global_img_placeholder + else: + tiles_placeholder = "".join( + tile_img_placeholder.format(n_h=i + 1, n_w=j + 1) + for i in range(grid_h) for j in range(grid_w)) + image_placeholder = (tiles_placeholder + "\n\n" + + global_img_placeholder) + return fake_image_token + image_placeholder + fake_image_token + + return [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement_idefics3, + ) + ] class Idefics3SimpleMLP(nn.Module): @@ -539,7 +470,7 @@ def _image_pixels_to_features( self, pixel_values: torch.Tensor, pixel_attention_mask: Optional[torch.BoolTensor] = None, - ) -> torch.Tensor: + ) -> NestedTensors: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower batch_size, num_images, num_channels, height, width = pixel_values.shape @@ -585,10 +516,10 @@ def _image_pixels_to_features( patch_attention_mask=patch_attention_mask, ) - return image_hidden_states + return image_hidden_states.split([num_images] * batch_size) def _process_image_pixels( - self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor: + self, inputs: Idefics3ImagePixelInputs) -> NestedTensors: assert self.vision_model is not None pixel_values = inputs["data"] @@ -603,7 +534,9 @@ def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor: assert self.vision_model is not None image_features = self._process_image_pixels(image_input) - return self.connector(image_features) + num_patches = [x.size(0) for x in image_features] + image_features = torch.cat(image_features) + return self.connector(image_features).split(num_patches) def get_input_embeddings( self, @@ -632,10 +565,10 @@ def forward( return hidden_states -@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_idefics3) -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_idefics3_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3) -@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3) +@MULTIMODAL_REGISTRY.register_processor( + Idefics3MultimodalProcessor, + info=Idefics3ProcessingInfo, + dummy_inputs=Idefics3DummyInputsBuilder) class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA): packed_modules_mapping = { From cd7f2a51032db6d2d515a3d422806c5656707dd0 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 2 Feb 2025 21:57:45 +0800 Subject: [PATCH 02/17] hash dict Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/inputs/registry.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 4b73ade7af5f..676a6b8fe605 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -29,6 +29,15 @@ P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin) +class HashableDict(dict): + """ + A dictionary that can be hashed by lru_cache. + """ + + def __hash__(self) -> int: + return hash(frozenset(self.items())) + + @dataclass(frozen=True) class InputContext: """ @@ -102,6 +111,13 @@ def get_hf_processor( if isinstance(typ, type): merged_kwargs["processor_cls"] = typ + # NOTE: Pythonic dict is not hashable and will raise unhashable type + # error when calling `cached_get_processor`, therefore we need to + # wrap it to a hashable dict. + for key, value in merged_kwargs.items(): + if isinstance(value, dict): + merged_kwargs[key] = HashableDict(value) + hf_processor = cached_get_processor( self.model_config.model, trust_remote_code=self.model_config.trust_remote_code, From 0dbbc2504b427dd03d909495674cab8c27d7eb40 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 3 Feb 2025 23:38:04 +0800 Subject: [PATCH 03/17] update processor test Signed-off-by: Isotr0py <2037008807@qq.com> --- .../multimodal/processing/test_common.py | 1 + .../multimodal/processing/test_idefics3.py | 179 ++++-------------- vllm/model_executor/models/idefics3.py | 28 ++- 3 files changed, 51 insertions(+), 157 deletions(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index ca28da268fa0..038f941aca48 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -146,6 +146,7 @@ def _test_processing_correctness( "facebook/chameleon-7b", "deepseek-ai/deepseek-vl2-tiny", "adept/fuyu-8b", + "HuggingFaceM4/Idefics3-8B-Llama3", "llava-hf/llava-1.5-7b-hf", "llava-hf/llava-v1.6-mistral-7b-hf", "llava-hf/LLaVA-NeXT-Video-7B-hf", diff --git a/tests/models/multimodal/processing/test_idefics3.py b/tests/models/multimodal/processing/test_idefics3.py index 69b91ad4a5df..c6deac4dfdfc 100644 --- a/tests/models/multimodal/processing/test_idefics3.py +++ b/tests/models/multimodal/processing/test_idefics3.py @@ -1,12 +1,8 @@ """Tests for Idefics3's multimodal preprocessing kwargs.""" -from typing import Optional - import pytest -import torch -from transformers import AutoImageProcessor, AutoTokenizer -from vllm.inputs import InputContext, token_inputs -from vllm.multimodal import MultiModalRegistry +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.utils import cached_get_tokenizer from ....conftest import _ImageAssets from ...utils import build_model_context @@ -14,163 +10,52 @@ models = ["HuggingFaceM4/Idefics3-8B-Llama3"] -# Wrap lazy imports to avoid initializing CUDA during test collection -@pytest.fixture() -def input_processor_for_idefics3(): - from vllm.model_executor.models.idefics3 import ( - input_processor_for_idefics3) - return input_processor_for_idefics3 - - -@pytest.fixture() -def dummy_data_for_idefics3(): - from vllm.model_executor.models.idefics3 import dummy_data_for_idefics3 - return dummy_data_for_idefics3 - - -@pytest.fixture() -def get_max_idefics3_image_tokens(): - from vllm.model_executor.models.idefics3 import ( - get_max_idefics3_image_tokens) - return get_max_idefics3_image_tokens - - @pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("longest_edge", [None, 168, 336, 400, 2 * 336]) -def test_input_mapper_override(model: str, image_assets: _ImageAssets, - longest_edge: Optional[int]): - """Ensure that the [default] input mapper handles size properly.""" - - mm_processor_kwargs = { - "size": { - "longest_edge": longest_edge - } - } if longest_edge is not None else {} - ctx = build_model_context( - model_name=model, - tokenizer_name=model, - trust_remote_code=True, - mm_processor_kwargs=mm_processor_kwargs, - ) - - hf_processor = AutoImageProcessor.from_pretrained(model, - trust_remote_code=True, - **mm_processor_kwargs) - - mm_registry = MultiModalRegistry() - mm_registry.init_mm_limits_per_prompt(ctx.model_config) - - image = image_assets[0].pil_image - hf_result = hf_processor.preprocess( - image, - return_tensors="pt", - ) - - vllm_result = mm_registry.map_input( - ctx.model_config, - {"image": image}, - ) - - assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"]) - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("longest_edge, expected_max_tokens", [ - (None, 2873), - (168, 169), - (336, 169), - (400, 338), - (672, 338), -]) -def test_max_tokens_override(get_max_idefics3_image_tokens, model: str, - longest_edge: Optional[int], - expected_max_tokens: int): - """Ensure get_max_idefics3_image_tokens handles mm_processor_kwargs.""" - size = {"longest_edge": longest_edge} if longest_edge is not None else None - ctx = build_model_context( - model_name=model, - tokenizer_name=model, - trust_remote_code=True, - mm_processor_kwargs=None, - ) - - actual_max_tokens = get_max_idefics3_image_tokens( - ctx=InputContext(ctx.model_config), - size=size, - ) - - assert expected_max_tokens == actual_max_tokens - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("longest_edge, toks_per_img, num_imgs", [ - (168, 169, 1), - (168, 169, 2), - (400, 338, 1), - (400, 338, 2), -]) -def test_dummy_data_override(dummy_data_for_idefics3, model: str, - longest_edge: int, toks_per_img: int, - num_imgs: int): - """Ensure dummy_data_for_idefics3 handles num_crops properly.""" - # Same as the previous test - don't initialize mm_processor_kwargs - # in this test and assume that the kwargs will be correctly expanded by - # the partial when calling the dummy data func. - size = {"longest_edge": longest_edge} if longest_edge is not None else None - ctx = build_model_context( - model_name=model, - tokenizer_name=model, - trust_remote_code=True, - mm_processor_kwargs=None, - ) - - dummy_data = dummy_data_for_idefics3( - ctx=ctx, - seq_len=8192, # Should be bigger than num_imgs * toks_per_img - mm_counts={"image": num_imgs}, - size=size) - sequence_data = dummy_data.seq_data - # Ensure we have the right number of placeholders per size - image_token_id = ctx.get_hf_config().image_token_id - img_tok_count = sequence_data.get_token_ids().count(image_token_id) - assert img_tok_count == toks_per_img * num_imgs - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("longest_edge,expected_toks_per_img,num_imgs", [ - (336, 169 * (1**2 + 1), 1), - (336, 169 * (1**2 + 1), 2), - (400, 169 * (2**2 + 1), 1), - (400, 169 * (2**2 + 1), 2), -]) -def test_input_processor_override(input_processor_for_idefics3, - image_assets: _ImageAssets, model: str, - longest_edge: int, - expected_toks_per_img: int, num_imgs: int): +# yapf: disable +@pytest.mark.parametrize( + ("mm_processor_kwargs", "expected_toks_per_img"), + [ + ({"size": {"longest_edge": 364}}, 169), + ({"size": {"longest_edge": 728}}, 169 * (2**2 + 1)), + ]) +# yapf: enable +@pytest.mark.parametrize("num_imgs", [1, 2]) +def test_processor_override(image_assets: _ImageAssets, model: str, + mm_processor_kwargs: dict[str, object], + expected_toks_per_img: int, num_imgs: int): """Ensure input_processor_for_idefics3 handles num_crops properly.""" # Same as the previous test - don't initialize mm_processor_kwargs # in this test and assume that the kwargs will be correctly expanded by # the partial when calling the custom input processor. - size = {"longest_edge": longest_edge} if longest_edge is not None else None ctx = build_model_context( model_name=model, tokenizer_name=model, trust_remote_code=True, mm_processor_kwargs=None, + limit_mm_per_prompt={"image": num_imgs}, + ) + tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) + processor = MULTIMODAL_REGISTRY.create_processor( + ctx.model_config, + tokenizer=tokenizer, ) + hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs) # Build the image str / prompt based on the number of images we pass - tokenizer = AutoTokenizer.from_pretrained(model) + # placeholders = "" if num_imgs == 1 else "\n".join( + # f"Image-{i}: \n" for i in range(1, num_imgs + 1)) placeholders = "" if num_imgs == 1 else "\n".join( f"Image-{i}: \n" for i in range(1, num_imgs + 1)) prompt = f"<|begin_of_text|>User:{placeholders}\n\nAssistant:" # noqa: E501 - images = [image_assets[0].pil_image.resize((336 * 4, 336 * 4))] * num_imgs - - inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt), - prompt=prompt, - multi_modal_data={"image": images}) - - processed_inputs = input_processor_for_idefics3(ctx, inputs, size=size) + mm_data = { + "image": [image_assets[0].pil_image.resize( + (336 * 4, 336 * 4))] * num_imgs + } + + processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) + hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"]) + assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[ + "input_ids"][0] # Ensure we have the right number of placeholders per num_crops size image_token_id = ctx.get_hf_config().image_token_id diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index eb2607644fe9..912de874ebc3 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -234,8 +234,8 @@ def _call_hf_processor( prompt, mm_data, mm_kwargs) processed_outputs["patches_per_image"] = patches_per_image for key in ("pixel_values", "pixel_attention_mask"): - processed_outputs[key] = processed_outputs[key].flatten( - 0, 1).split(patches_per_image) + processed_outputs[key] = list(processed_outputs[key].flatten( + 0, 1).split(patches_per_image)) else: tokenizer = self.info.get_tokenizer() processed_outputs = tokenizer(prompt, @@ -278,7 +278,8 @@ def _get_prompt_replacements( image_seq_len = hf_processor.image_seq_len grid_placeholder = "" - global_img_placeholder = global_img_token + image_token * image_seq_len + global_img_placeholder = fake_image_token + global_img_token + ( + image_token * image_seq_len) tile_img_placeholder = grid_placeholder + image_token * image_seq_len def get_replacement_idefics3(item_idx: int) -> str: @@ -289,16 +290,23 @@ def get_replacement_idefics3(item_idx: int) -> str: image_width=image_size.width, image_height=image_size.height, ) - - if grid_w == 1 and grid_h == 1: + if grid_w == 0 and grid_h == 0: image_placeholder = global_img_placeholder else: - tiles_placeholder = "".join( - tile_img_placeholder.format(n_h=i + 1, n_w=j + 1) - for i in range(grid_h) for j in range(grid_w)) - image_placeholder = (tiles_placeholder + "\n\n" + + placeholder_per_tile = [] + for i in range(grid_h): + for j in range(grid_w): + image_placeholder = tile_img_placeholder.format( + n_h=i + 1, n_w=j + 1) + if j == grid_w - 1: + image_placeholder += "\n" + placeholder_per_tile.append(image_placeholder) + + tiles_placeholder = fake_image_token + fake_image_token.join( + placeholder_per_tile) + image_placeholder = (tiles_placeholder + "\n" + global_img_placeholder) - return fake_image_token + image_placeholder + fake_image_token + return image_placeholder + fake_image_token return [ PromptReplacement( From 08ee4d1a3504f2404fb5efb35154d48b881f53ec Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 4 Feb 2025 00:50:21 +0800 Subject: [PATCH 04/17] fix batching Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/idefics3.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 912de874ebc3..303f2786589b 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -234,8 +234,7 @@ def _call_hf_processor( prompt, mm_data, mm_kwargs) processed_outputs["patches_per_image"] = patches_per_image for key in ("pixel_values", "pixel_attention_mask"): - processed_outputs[key] = list(processed_outputs[key].flatten( - 0, 1).split(patches_per_image)) + processed_outputs[key] = processed_outputs[key].flatten(0, 1) else: tokenizer = self.info.get_tokenizer() processed_outputs = tokenizer(prompt, @@ -390,7 +389,7 @@ class Idefics3Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config + config: Idefics3Config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config @@ -463,9 +462,6 @@ def _parse_and_validate_image_input( if isinstance(pixel_values, list): pixel_values = torch.cat(pixel_values, dim=1) pixel_attention_mask = torch.cat(pixel_attention_mask, dim=1) - else: - pixel_values = flatten_bn(pixel_values) - pixel_attention_mask = flatten_bn(pixel_attention_mask) return Idefics3ImagePixelInputs( type="pixel_values", @@ -481,12 +477,10 @@ def _image_pixels_to_features( ) -> NestedTensors: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - batch_size, num_images, num_channels, height, width = pixel_values.shape + num_patches = [x.size(0) for x in pixel_values] pixel_values = pixel_values.to( dtype=self.vision_model.embeddings.patch_embedding.weight.dtype ) # fp16 compatibility - pixel_values = pixel_values.view(batch_size * num_images, - *pixel_values.shape[2:]) # Remove padding images - padding images are full 0. nb_values_per_image = pixel_values.shape[1:].numel() @@ -504,8 +498,6 @@ def _image_pixels_to_features( ) else: # Remove padding images from the mask - pixel_attention_mask = pixel_attention_mask.view( - batch_size * num_images, *pixel_attention_mask.shape[2:]) pixel_attention_mask = pixel_attention_mask[ real_images_inds].contiguous() @@ -524,7 +516,7 @@ def _image_pixels_to_features( patch_attention_mask=patch_attention_mask, ) - return image_hidden_states.split([num_images] * batch_size) + return image_hidden_states.split(num_patches) def _process_image_pixels( self, inputs: Idefics3ImagePixelInputs) -> NestedTensors: From a71a670f2b43ea5425b455f3af2246264c9c8c03 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 4 Feb 2025 11:54:24 +0800 Subject: [PATCH 05/17] simplify batching Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/idefics3.py | 29 ++++++++++---------------- 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 303f2786589b..2642548292f3 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -222,19 +222,19 @@ def _call_hf_processor( mm_kwargs: Mapping[str, object], ) -> BatchFeature: if mm_data: + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs) image_grids = [ self.info._get_image_feature_grid_size( image_width=img.width, image_height=img.height, ) for img in mm_data["images"] ] - patches_per_image = list( - map(lambda x: math.prod(x) + 1, image_grids)) - processed_outputs = super()._call_hf_processor( - prompt, mm_data, mm_kwargs) - processed_outputs["patches_per_image"] = patches_per_image + image_patches = list(map(lambda x: math.prod(x) + 1, image_grids)) for key in ("pixel_values", "pixel_attention_mask"): - processed_outputs[key] = processed_outputs[key].flatten(0, 1) + data = processed_outputs.pop(key) + data = data.flatten(0, 1).split(image_patches) + processed_outputs[key] = data else: tokenizer = self.info.get_tokenizer() processed_outputs = tokenizer(prompt, @@ -247,19 +247,9 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - num_patches = hf_inputs.pop("patches_per_image", []) - slices_idxs = [0] + [ - sum(num_patches[:i]) for i in range(1, - len(num_patches) + 1) - ] - slices = [ - slice(slices_idxs[i], slices_idxs[i + 1]) - for i in range(len(num_patches)) - ] return dict( - pixel_values=MultiModalFieldConfig.flat("image", slices=slices), - pixel_attention_mask=MultiModalFieldConfig.flat("image", - slices=slices), + pixel_values=MultiModalFieldConfig.batched("image"), + pixel_attention_mask=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), ) @@ -462,6 +452,9 @@ def _parse_and_validate_image_input( if isinstance(pixel_values, list): pixel_values = torch.cat(pixel_values, dim=1) pixel_attention_mask = torch.cat(pixel_attention_mask, dim=1) + else: + pixel_values = flatten_bn(pixel_values) + pixel_attention_mask = flatten_bn(pixel_attention_mask) return Idefics3ImagePixelInputs( type="pixel_values", From 8ed33cd04e468e4a15e57f36ec8b62fe8159c40f Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 4 Feb 2025 13:08:12 +0800 Subject: [PATCH 06/17] fix override test Signed-off-by: Isotr0py <2037008807@qq.com> --- .../models/multimodal/processing/test_idefics3.py | 14 ++++++++------ vllm/model_executor/models/idefics3.py | 7 +++++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/models/multimodal/processing/test_idefics3.py b/tests/models/multimodal/processing/test_idefics3.py index c6deac4dfdfc..5701e649f91d 100644 --- a/tests/models/multimodal/processing/test_idefics3.py +++ b/tests/models/multimodal/processing/test_idefics3.py @@ -1,5 +1,6 @@ """Tests for Idefics3's multimodal preprocessing kwargs.""" import pytest +from transformers import Idefics3Config from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer @@ -42,17 +43,18 @@ def test_processor_override(image_assets: _ImageAssets, model: str, hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs) # Build the image str / prompt based on the number of images we pass - # placeholders = "" if num_imgs == 1 else "\n".join( - # f"Image-{i}: \n" for i in range(1, num_imgs + 1)) placeholders = "" if num_imgs == 1 else "\n".join( f"Image-{i}: \n" for i in range(1, num_imgs + 1)) prompt = f"<|begin_of_text|>User:{placeholders}\n\nAssistant:" # noqa: E501 - mm_data = { - "image": [image_assets[0].pil_image.resize( - (336 * 4, 336 * 4))] * num_imgs - } + + # Build mm_data + image_size = ctx.get_hf_config(Idefics3Config).vision_config.image_size + dummy_image_size = (image_size * 4, image_size * 4) + dummy_image = image_assets[0].pil_image.resize(dummy_image_size) + mm_data = {"image": [dummy_image] * num_imgs} processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) + # Ensure the placeholders format are correct hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"]) assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[ "input_ids"][0] diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 2642548292f3..8b719d32667a 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -162,8 +162,9 @@ def _get_image_feature_grid_size( *, image_width: int, image_height: int, + size: Optional[dict[str, object]] = None, ) -> tuple[int, int]: - hf_processor = self.get_hf_processor() + hf_processor = self.get_hf_processor(size=size) image_processor: Idefics3ImageProcessor = hf_processor.image_processor max_image_size = image_processor.max_image_size['longest_edge'] size = image_processor.size['longest_edge'] @@ -228,6 +229,7 @@ def _call_hf_processor( self.info._get_image_feature_grid_size( image_width=img.width, image_height=img.height, + size=mm_kwargs.get("size", None), ) for img in mm_data["images"] ] image_patches = list(map(lambda x: math.prod(x) + 1, image_grids)) @@ -259,7 +261,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_processor = self.info.get_hf_processor() + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_token = hf_processor.image_token.content fake_image_token = hf_processor.fake_image_token.content @@ -278,6 +280,7 @@ def get_replacement_idefics3(item_idx: int) -> str: grid_w, grid_h = self.info._get_image_feature_grid_size( image_width=image_size.width, image_height=image_size.height, + **hf_processor_mm_kwargs, ) if grid_w == 0 and grid_h == 0: image_placeholder = global_img_placeholder From fce1417ee86768e728f4ca462b886a16b417d40b Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 4 Feb 2025 14:19:14 +0800 Subject: [PATCH 07/17] fix v1 profiling Signed-off-by: Isotr0py <2037008807@qq.com> --- docs/source/models/supported_models.md | 2 +- vllm/model_executor/models/idefics3.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index afaad8818bdc..fd03f14895f4 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -657,7 +657,7 @@ See [this page](#generative-models) for more information on how to use generativ * `HuggingFaceM4/Idefics3-8B-Llama3` etc. * ✅︎ * - * + * ✅︎ - * `InternVLChatModel` * InternVL 2.5, Mono-InternVL, InternVL 2.0 * T + IE+ diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 8b719d32667a..3369f186b3bf 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -29,7 +29,7 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -99,7 +99,13 @@ def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: image_width=image_processor.size['longest_edge'], image_height=image_processor.size['longest_edge'], ) - return {"image": (grid_w * grid_h + 1) * hf_processor.image_seq_len} + # Non-image-token: + # cost 2 token per patch + # each row has one line break cost 1 token + # at the last + num_image_token = (grid_w * grid_h + 1) * hf_processor.image_seq_len + non_image_token = (grid_w * grid_h + 1) * 2 + grid_h + 1 + return {"image": num_image_token + non_image_token} def _resize_output_size(self, *, @@ -616,7 +622,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if self.config.text_config.tie_word_embeddings: self.lm_head.weight = self.model.text_model.wte.weight self.logits_processor = LogitsProcessor(config.text_config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: image_input = self.model._parse_and_validate_image_input(**kwargs) From 4f2837de325999c0cf9d230f8d0aab7a8a03b864 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 4 Feb 2025 15:29:39 +0800 Subject: [PATCH 08/17] migrate model tests Signed-off-by: Isotr0py <2037008807@qq.com> --- tests/models/decoder_only/vision_language/test_models.py | 4 ++-- .../decoder_only/vision_language/vlm_utils/model_utils.py | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 62c644f73d62..621ab441808d 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -252,14 +252,14 @@ patch_hf_runner=model_utils.h2ovl_patch_hf_runner, ), "idefics3": VLMTestInfo( - models=["HuggingFaceM4/Idefics3-8B-Llama3"], + models=["HuggingFaceTB/SmolVLM-256M-Instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}\nAssistant:", # noqa: E501 img_idx_to_prompt=lambda idx: "", max_model_len=8192, max_num_seqs=2, auto_cls=AutoModelForVision2Seq, - marks=[large_gpu_mark(min_gb=48)], + hf_output_post_proc=model_utils.idefics3_trunc_hf_output, ), "intern_vl": VLMTestInfo( models=[ diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 07bdb2cee44d..a6b3763bc928 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -191,6 +191,14 @@ def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput, return output_ids, output_str, out_logprobs +def idefics3_trunc_hf_output(hf_output: RunnerOutput, + model: str) -> RunnerOutput: + output_ids, output_str, out_logprobs = hf_output + if output_str.endswith(""): + output_str = output_str.split("")[0] + return output_ids, output_str, out_logprobs + + def minicpmv_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output From 90006395ff306d4c600376686227f498c8d0011e Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 4 Feb 2025 15:39:28 +0800 Subject: [PATCH 09/17] clean up Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/idefics3.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 3369f186b3bf..e8cd70834615 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -235,7 +235,7 @@ def _call_hf_processor( self.info._get_image_feature_grid_size( image_width=img.width, image_height=img.height, - size=mm_kwargs.get("size", None), + **mm_kwargs, ) for img in mm_data["images"] ] image_patches = list(map(lambda x: math.prod(x) + 1, image_grids)) @@ -275,9 +275,9 @@ def _get_prompt_replacements( image_seq_len = hf_processor.image_seq_len grid_placeholder = "" - global_img_placeholder = fake_image_token + global_img_token + ( - image_token * image_seq_len) - tile_img_placeholder = grid_placeholder + image_token * image_seq_len + p_img = image_token * image_seq_len + global_img_placeholder = fake_image_token + global_img_token + p_img + tile_img_placeholder = fake_image_token + grid_placeholder + p_img def get_replacement_idefics3(item_idx: int) -> str: images = mm_items.get_items("image", ImageProcessorItems) @@ -291,17 +291,16 @@ def get_replacement_idefics3(item_idx: int) -> str: if grid_w == 0 and grid_h == 0: image_placeholder = global_img_placeholder else: - placeholder_per_tile = [] + tiles_placeholder = "" for i in range(grid_h): for j in range(grid_w): - image_placeholder = tile_img_placeholder.format( + placeholder_per_tile = tile_img_placeholder.format( n_h=i + 1, n_w=j + 1) + # Add line break if it is the last tile in the row if j == grid_w - 1: - image_placeholder += "\n" - placeholder_per_tile.append(image_placeholder) + placeholder_per_tile += "\n" + tiles_placeholder += placeholder_per_tile - tiles_placeholder = fake_image_token + fake_image_token.join( - placeholder_per_tile) image_placeholder = (tiles_placeholder + "\n" + global_img_placeholder) return image_placeholder + fake_image_token From c72b0cbcdc9882981144c54aac3e93e3ab7200b3 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 4 Feb 2025 15:46:48 +0800 Subject: [PATCH 10/17] fix mypy Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/inputs/registry.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 676a6b8fe605..6945548148eb 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -34,7 +34,9 @@ class HashableDict(dict): A dictionary that can be hashed by lru_cache. """ - def __hash__(self) -> int: + # NOTE: pythonic dict is not hashable, + # we override on it directly for simplicity + def __hash__(self) -> int: # type: ignore[override] return hash(frozenset(self.items())) From 163571770303c35dfd37061908f85a814c9f5d17 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 4 Feb 2025 16:04:35 +0800 Subject: [PATCH 11/17] Update vllm/model_executor/models/idefics3.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/idefics3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index e8cd70834615..cedd96bcd3fd 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -131,8 +131,8 @@ def _resize_output_size(self, width = int(height * aspect_ratio) # Ensure both width and height are even (if needed) - height += 1 if height % 2 != 0 else 0 - width += 1 if width % 2 != 0 else 0 + height += height % 2 + width += width % 2 # Ensure dimensions are not smaller than the minimum length height = max(height, min_len) From 370486c27b35aa6d6a0fbc8411c5310f20b46512 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 4 Feb 2025 16:04:41 +0800 Subject: [PATCH 12/17] Update vllm/model_executor/models/idefics3.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/idefics3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index cedd96bcd3fd..c41dd6ea9828 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -176,8 +176,8 @@ def _get_image_feature_grid_size( size = image_processor.size['longest_edge'] assert size % max_image_size == 0, ( "`longest_edge` in image_processor's `size` must be divisible by " - "`longest_edge` in `max_image_size`, this may cause by incorrect " - "mm_kwargs override.") + "`longest_edge` in `max_image_size`, this may be caused by " + "incorrect mm_kwargs override.") resized_height, resized_width = self._get_resize_output_image_size( image_width=image_width, From d1f9547a0eb7ac5058c3664335a065bdd5c0444a Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 4 Feb 2025 16:04:50 +0800 Subject: [PATCH 13/17] Update vllm/model_executor/models/idefics3.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/idefics3.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index c41dd6ea9828..51c76d7ef047 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -291,18 +291,18 @@ def get_replacement_idefics3(item_idx: int) -> str: if grid_w == 0 and grid_h == 0: image_placeholder = global_img_placeholder else: - tiles_placeholder = "" + tiles_placeholder = [] for i in range(grid_h): for j in range(grid_w): placeholder_per_tile = tile_img_placeholder.format( n_h=i + 1, n_w=j + 1) # Add line break if it is the last tile in the row if j == grid_w - 1: - placeholder_per_tile += "\n" - tiles_placeholder += placeholder_per_tile + placeholder_per_tile += ["\n"] + tiles_placeholder += [placeholder_per_tile] - image_placeholder = (tiles_placeholder + "\n" + - global_img_placeholder) + image_placeholder = "".join([tiles_placeholder, "\n", + global_img_placeholder]) return image_placeholder + fake_image_token return [ From fcff87f4f0a58f3da66b51c4a46269d1c0c6490b Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 4 Feb 2025 16:14:17 +0800 Subject: [PATCH 14/17] fix typo Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/idefics3.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 51c76d7ef047..30978074eac6 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -296,13 +296,13 @@ def get_replacement_idefics3(item_idx: int) -> str: for j in range(grid_w): placeholder_per_tile = tile_img_placeholder.format( n_h=i + 1, n_w=j + 1) + tiles_placeholder.append(placeholder_per_tile) # Add line break if it is the last tile in the row if j == grid_w - 1: - placeholder_per_tile += ["\n"] - tiles_placeholder += [placeholder_per_tile] + tiles_placeholder.append("\n") - image_placeholder = "".join([tiles_placeholder, "\n", - global_img_placeholder]) + image_placeholder = "".join( + [*tiles_placeholder, "\n", global_img_placeholder]) return image_placeholder + fake_image_token return [ From 99527dea91b68e8c2b43f8ff0543c78a85ff1a6e Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 4 Feb 2025 16:19:10 +0800 Subject: [PATCH 15/17] add annotation Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/idefics3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 30978074eac6..385e1879dc9e 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -291,7 +291,7 @@ def get_replacement_idefics3(item_idx: int) -> str: if grid_w == 0 and grid_h == 0: image_placeholder = global_img_placeholder else: - tiles_placeholder = [] + tiles_placeholder = list[str]() for i in range(grid_h): for j in range(grid_w): placeholder_per_tile = tile_img_placeholder.format( From 06f5de9fc49c2d0c9328c62c70e1d8c587603b32 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 4 Feb 2025 17:11:33 +0800 Subject: [PATCH 16/17] fix mm_max_token calculation Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/idefics3.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 385e1879dc9e..ba4fc187fcec 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -99,12 +99,18 @@ def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: image_width=image_processor.size['longest_edge'], image_height=image_processor.size['longest_edge'], ) - # Non-image-token: - # cost 2 token per patch - # each row has one line break cost 1 token - # at the last num_image_token = (grid_w * grid_h + 1) * hf_processor.image_seq_len - non_image_token = (grid_w * grid_h + 1) * 2 + grid_h + 1 + # Calculate Non-image-token length + # NOTE: and are special token for SmolVLM + # but not for Idefic3, so we need to tokenize them to get actual length. + tokenizer = self.get_tokenizer() + tile_token_len = len(tokenizer.tokenize("")) + glob_token_len = len(tokenizer.tokenize(hf_processor.global_image_tag)) + # linebreak and always cost 1 token + fake_token_len = lb_len = 1 + non_image_token = (grid_w * grid_h) * ( + tile_token_len + fake_token_len) + glob_token_len + ( + grid_h + 1) * lb_len + fake_token_len return {"image": num_image_token + non_image_token} def _resize_output_size(self, From b6bf1ad19425f145aee63dffc1fad2ef58c86409 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 4 Feb 2025 17:19:26 +0800 Subject: [PATCH 17/17] update signature Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/idefics3.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index a0358728365d..fdfabbaafce3 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -94,7 +94,11 @@ def get_hf_processor( def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: hf_processor = self.get_hf_processor() image_processor: Idefics3ImageProcessor = hf_processor.image_processor grid_w, grid_h = self._get_image_feature_grid_size(