diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index fbdca189af62..d8e284292951 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -733,7 +733,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `HuggingFaceM4/Idefics3-8B-Llama3` etc.
* ✅︎
*
- *
+ * ✅︎
- * `InternVLChatModel`
* InternVL 2.5, Mono-InternVL, InternVL 2.0
* T + I<sup>E+</sup>
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 7a14ba2f3b60..5fe46bd75b9f 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -254,14 +254,14 @@
patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
),
"idefics3": VLMTestInfo(
- models=["HuggingFaceM4/Idefics3-8B-Llama3"],
+ models=["HuggingFaceTB/SmolVLM-256M-Instruct"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501
img_idx_to_prompt=lambda idx: "<image>",
max_model_len=8192,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
- marks=[large_gpu_mark(min_gb=48)],
+ hf_output_post_proc=model_utils.idefics3_trunc_hf_output,
),
"intern_vl": VLMTestInfo(
models=[
diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
index d2401b222558..ced891e1e2c2 100644
--- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
+++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -192,6 +192,14 @@ def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput,
return output_ids, output_str, out_logprobs
+def idefics3_trunc_hf_output(hf_output: RunnerOutput,
+ model: str) -> RunnerOutput:
+ output_ids, output_str, out_logprobs = hf_output
+ if output_str.endswith("<end_of_utterance>"):
+ output_str = output_str.split("<end_of_utterance>")[0]
+ return output_ids, output_str, out_logprobs
+
+
def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 07906a71d06e..5cd749cbd779 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -149,6 +149,7 @@ def _test_processing_correctness(
"adept/fuyu-8b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
+ "HuggingFaceM4/Idefics3-8B-Llama3",
"llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-mistral-7b-hf",
"llava-hf/LLaVA-NeXT-Video-7B-hf",
diff --git a/tests/models/multimodal/processing/test_idefics3.py b/tests/models/multimodal/processing/test_idefics3.py
index 00c1dae51158..07ab1bbd4b5e 100644
--- a/tests/models/multimodal/processing/test_idefics3.py
+++ b/tests/models/multimodal/processing/test_idefics3.py
@@ -1,13 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for Idefics3's multimodal preprocessing kwargs."""
-from typing import Optional
-
import pytest
-import torch
-from transformers import AutoImageProcessor, AutoTokenizer
+from transformers import Idefics3Config
-from vllm.inputs import InputContext, token_inputs
-from vllm.multimodal import MultiModalRegistry
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.utils import cached_get_tokenizer
from ....conftest import _ImageAssets
from ...utils import build_model_context
@@ -15,163 +12,53 @@
models = ["HuggingFaceM4/Idefics3-8B-Llama3"]
-# Wrap lazy imports to avoid initializing CUDA during test collection
-@pytest.fixture()
-def input_processor_for_idefics3():
- from vllm.model_executor.models.idefics3 import (
- input_processor_for_idefics3)
- return input_processor_for_idefics3
-
-
-@pytest.fixture()
-def dummy_data_for_idefics3():
- from vllm.model_executor.models.idefics3 import dummy_data_for_idefics3
- return dummy_data_for_idefics3
-
-
-@pytest.fixture()
-def get_max_idefics3_image_tokens():
- from vllm.model_executor.models.idefics3 import (
- get_max_idefics3_image_tokens)
- return get_max_idefics3_image_tokens
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("longest_edge", [None, 168, 336, 400, 2 * 336])
-def test_input_mapper_override(model: str, image_assets: _ImageAssets,
- longest_edge: Optional[int]):
- """Ensure that the [default] input mapper handles size properly."""
-
- mm_processor_kwargs = {
- "size": {
- "longest_edge": longest_edge
- }
- } if longest_edge is not None else {}
- ctx = build_model_context(
- model_name=model,
- tokenizer_name=model,
- trust_remote_code=True,
- mm_processor_kwargs=mm_processor_kwargs,
- )
-
- hf_processor = AutoImageProcessor.from_pretrained(model,
- trust_remote_code=True,
- **mm_processor_kwargs)
-
- mm_registry = MultiModalRegistry()
- mm_registry.init_mm_limits_per_prompt(ctx.model_config)
-
- image = image_assets[0].pil_image
- hf_result = hf_processor.preprocess(
- image,
- return_tensors="pt",
- )
-
- vllm_result = mm_registry.map_input(
- ctx.model_config,
- {"image": image},
- )
-
- assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"])
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("longest_edge, expected_max_tokens", [
- (None, 2873),
- (168, 169),
- (336, 169),
- (400, 338),
- (672, 338),
-])
-def test_max_tokens_override(get_max_idefics3_image_tokens, model: str,
- longest_edge: Optional[int],
- expected_max_tokens: int):
- """Ensure get_max_idefics3_image_tokens handles mm_processor_kwargs."""
- size = {"longest_edge": longest_edge} if longest_edge is not None else None
- ctx = build_model_context(
- model_name=model,
- tokenizer_name=model,
- trust_remote_code=True,
- mm_processor_kwargs=None,
- )
-
- actual_max_tokens = get_max_idefics3_image_tokens(
- ctx=InputContext(ctx.model_config),
- size=size,
- )
-
- assert expected_max_tokens == actual_max_tokens
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("longest_edge, toks_per_img, num_imgs", [
- (168, 169, 1),
- (168, 169, 2),
- (400, 338, 1),
- (400, 338, 2),
-])
-def test_dummy_data_override(dummy_data_for_idefics3, model: str,
- longest_edge: int, toks_per_img: int,
- num_imgs: int):
- """Ensure dummy_data_for_idefics3 handles num_crops properly."""
- # Same as the previous test - don't initialize mm_processor_kwargs
- # in this test and assume that the kwargs will be correctly expanded by
- # the partial when calling the dummy data func.
- size = {"longest_edge": longest_edge} if longest_edge is not None else None
- ctx = build_model_context(
- model_name=model,
- tokenizer_name=model,
- trust_remote_code=True,
- mm_processor_kwargs=None,
- )
-
- dummy_data = dummy_data_for_idefics3(
- ctx=ctx,
- seq_len=8192, # Should be bigger than num_imgs * toks_per_img
- mm_counts={"image": num_imgs},
- size=size)
- sequence_data = dummy_data.seq_data
- # Ensure we have the right number of placeholders per size
- image_token_id = ctx.get_hf_config().image_token_id
- img_tok_count = sequence_data.get_token_ids().count(image_token_id)
- assert img_tok_count == toks_per_img * num_imgs
-
-
@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("longest_edge,expected_toks_per_img,num_imgs", [
- (336, 169 * (1**2 + 1), 1),
- (336, 169 * (1**2 + 1), 2),
- (400, 169 * (2**2 + 1), 1),
- (400, 169 * (2**2 + 1), 2),
-])
-def test_input_processor_override(input_processor_for_idefics3,
- image_assets: _ImageAssets, model: str,
- longest_edge: int,
- expected_toks_per_img: int, num_imgs: int):
+# yapf: disable
+@pytest.mark.parametrize(
+ ("mm_processor_kwargs", "expected_toks_per_img"),
+ [
+ ({"size": {"longest_edge": 364}}, 169),
+ ({"size": {"longest_edge": 728}}, 169 * (2**2 + 1)),
+ ])
+# yapf: enable
+@pytest.mark.parametrize("num_imgs", [1, 2])
+def test_processor_override(image_assets: _ImageAssets, model: str,
+ mm_processor_kwargs: dict[str, object],
+ expected_toks_per_img: int, num_imgs: int):
"""Ensure input_processor_for_idefics3 handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
- size = {"longest_edge": longest_edge} if longest_edge is not None else None
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
+ limit_mm_per_prompt={"image": num_imgs},
)
+ tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
+ processor = MULTIMODAL_REGISTRY.create_processor(
+ ctx.model_config,
+ tokenizer=tokenizer,
+ )
+ hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs)
# Build the image str / prompt based on the number of images we pass
- tokenizer = AutoTokenizer.from_pretrained(model)
placeholders = "" if num_imgs == 1 else "\n".join(
f"Image-{i}: \n" for i in range(1, num_imgs + 1))
prompt = f"<|begin_of_text|>User:{placeholders}\n\nAssistant:" # noqa: E501
- images = [image_assets[0].pil_image.resize((336 * 4, 336 * 4))] * num_imgs
-
- inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
- prompt=prompt,
- multi_modal_data={"image": images})
- processed_inputs = input_processor_for_idefics3(ctx, inputs, size=size)
+ # Build mm_data
+ image_size = ctx.get_hf_config(Idefics3Config).vision_config.image_size
+ dummy_image_size = (image_size * 4, image_size * 4)
+ dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
+ mm_data = {"image": [dummy_image] * num_imgs}
+
+ processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
+ # Ensure the placeholder format is correct
+ hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"])
+ assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[
+ "input_ids"][0]
# Ensure we have the right number of placeholders per image size
image_token_id = ctx.get_hf_config().image_token_id
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index 0ec726b8b05f..cd421443981f 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -31,6 +31,17 @@
P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin)
+class HashableDict(dict):
+ """
+ A dictionary that can be hashed by lru_cache.
+ """
+
+ # NOTE: A plain dict is not hashable; we define __hash__ directly here
+ # for simplicity.
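+ # This is a shallow hash: it assumes every value in the dict is itself
+ # hashable (e.g. {"longest_edge": 364}); nested unhashable values would
+ # still raise a TypeError.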
+ def __hash__(self) -> int: # type: ignore[override]
+ return hash(frozenset(self.items()))
+
+
@dataclass(frozen=True)
class InputContext:
"""
@@ -104,6 +115,13 @@ def get_hf_processor(
if isinstance(typ, type):
merged_kwargs["processor_cls"] = typ
+ # NOTE: A plain dict is not hashable and would raise an unhashable-type
+ # error when calling `cached_get_processor`, so we wrap any dict values
+ # in a hashable dict.
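+ # e.g. Idefics3 passes size={"longest_edge": 364} through mm_processor_kwargs.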
+ for key, value in merged_kwargs.items():
+ if isinstance(value, dict):
+ merged_kwargs[key] = HashableDict(value)
+
hf_processor = cached_get_processor(
self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code,
diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py
index 9e2e677a652e..fdfabbaafce3 100644
--- a/vllm/model_executor/models/idefics3.py
+++ b/vllm/model_executor/models/idefics3.py
@@ -16,35 +16,35 @@
"""Inference-only Idefics3 model compatible with HuggingFace weights."""
import math
-from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple,
- Optional, Set, Tuple, TypedDict, Union)
+from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Set,
+ Tuple, TypedDict, Union)
import torch
import torch.utils.checkpoint
-from PIL import Image
from torch import nn
-# Temporary solution for transformers below 4.46.0.
-from transformers import PretrainedConfig as Idefics3Config
-from transformers import ProcessorMixin as Idefics3ImageProcessor
+from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
+ Idefics3Processor)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
- InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import NestedTensors
-from vllm.sequence import IntermediateTensors, SequenceData
-from vllm.transformers_utils.processor import cached_get_processor
-from vllm.utils import is_list_of
+from vllm.multimodal.parse import ImageProcessorItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo,
+ MultiModalDataItems,
+ MultiModalFieldConfig,
+ PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.sequence import IntermediateTensors
# yapf: disable
from .idefics2_vision_model import (
@@ -77,307 +77,253 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
"""
-class Idefics3ProcessorSize(NamedTuple):
- """Hashable wrapper for unhashable `size` dict of Idefics3Processor."""
- # NOTE: cached_get_processor/cached_get_image_processor uses lru_cache,
- # we need to use NamedTuple instead of TypedDict to avoid hashing issues.
- longest_edge: int
-
- def __contains__(self, key: str) -> bool:
- return key in self._asdict() and getattr(self, key) is not None
-
- def __getitem__(self, key: str) -> int:
- return getattr(self, key)
-
-
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
-def get_mm_processor_kwargs(size: Optional[Dict[str, int]] = None) -> Dict:
- mm_processor_kwargs = {}
- if size:
- mm_processor_kwargs["size"] = Idefics3ProcessorSize(**size)
- return mm_processor_kwargs
-
-
-def input_mapper_for_idefics3(
- ctx: InputContext,
- data: object,
- *,
- size: Optional[Dict[str, int]] = None,
-):
- model_config = ctx.model_config
- mm_processor_kwargs = get_mm_processor_kwargs(size)
- image_processor = cached_get_image_processor(
- model_config.model,
- trust_remote_code=model_config.trust_remote_code,
- **mm_processor_kwargs)
- if image_processor is None:
- raise RuntimeError("No HuggingFace processor is available "
- "to process the image object")
-
- if isinstance(data, Image.Image):
- images = [[data]]
- elif is_list_of(data, Image.Image):
- images = [data]
- else:
- raise TypeError(f"Invalid image type: {type(data)}")
-
- try:
- batch_data = image_processor(images,
- return_tensors="pt",
- return_row_col_info=True).data
- except Exception:
- logger.error("Failed to process image (%s)", data)
- raise
-
- return MultiModalKwargs(batch_data)
-
-
-def _resize_output_size(height: int,
- width: int,
- max_len: Optional[int] = None,
- min_len: Optional[int] = 1,
- max_size: Optional[int] = None) -> Tuple[int, int]:
- # Set default value for max_len if not provided
- max_len = max(height, width) if max_len is None else max_len
- aspect_ratio = width / height
-
- # Handle the maximum size constraint
- if max_size is not None:
- max_len = min(max_len, max_size)
-
- # Adjust dimensions according to the aspect ratio
- if width >= height:
- width = max_len
- height = int(width / aspect_ratio)
- else:
- height = max_len
- width = int(height * aspect_ratio)
-
- # Ensure both width and height are even (if needed)
- height += 1 if height % 2 != 0 else 0
- width += 1 if width % 2 != 0 else 0
-
- # Ensure dimensions are not smaller than the minimum length
- height = max(height, min_len)
- width = max(width, min_len)
-
- return height, width
-
-
-def _get_resize_output_image_size(
- image_size: Tuple[int, int],
- resolution_max_side: int,
- max_image_size: int = 1820,
-) -> Tuple[int, int]:
- if resolution_max_side > max_image_size:
- raise ValueError(
- "`resolution_max_side` cannot be larger than `max_image_size`")
-
- height, width = image_size
-
- # Find the output size, when rescaling the longest edge to max_len and
- # preserving the aspect ratio
- height, width = _resize_output_size(height,
- width,
- max_len=resolution_max_side)
-
- return height, width
-
-
-def _prompt_split_image(image_seq_len: int, image_rows: int, image_cols: int,
- fake_token_around_image: str, image_token: str,
- global_img_token: str) -> str:
- """
- Prompt with expanded image tokens for when the image is split
- into patches.
- """
- text_split_images = ""
- for n_h in range(image_rows):
- for n_w in range(image_cols):
- text_split_images += (fake_token_around_image +
- f"" +
- image_token * image_seq_len)
- text_split_images += "\n"
-
- text_split_images += "\n" + _prompt_single_image(
- image_seq_len=image_seq_len,
- fake_token_around_image=fake_token_around_image,
- image_token=image_token,
- global_img_token=global_img_token)
- return text_split_images
-
-
-def _prompt_single_image(image_seq_len: int, fake_token_around_image: str,
- image_token: str, global_img_token: str):
- """Prompt with expanded image tokens for a single image."""
- return (fake_token_around_image + global_img_token +
- image_token * image_seq_len + fake_token_around_image)
-
-
-def _get_image_prompt_string(image_rows: int, image_cols: int,
- image_seq_len: int, fake_token_around_image: str,
- image_token: str, global_img_token: str):
- if image_rows == 0 and image_cols == 0:
- return _prompt_single_image(
- image_seq_len=image_seq_len,
- fake_token_around_image=fake_token_around_image,
- image_token=image_token,
- global_img_token=global_img_token,
- )
- return _prompt_split_image(image_seq_len, image_rows, image_cols,
- fake_token_around_image, image_token,
- global_img_token)
-
-
-def input_processor_for_idefics3(ctx: InputContext,
- inputs: DecoderOnlyInputs,
- *,
- size: Optional[Dict[str, int]] = None):
- multi_modal_data = inputs.get("multi_modal_data")
- if multi_modal_data is None or "image" not in multi_modal_data:
- return inputs
-
- model_config = ctx.model_config
- mm_processor_kwargs = get_mm_processor_kwargs(size)
- processor = cached_get_processor(model_config.model, **mm_processor_kwargs)
- image_processor = processor.image_processor
- tokenizer = processor.tokenizer
- size = image_processor.size['longest_edge']
- max_image_size = image_processor.max_image_size['longest_edge']
-
- image_data = multi_modal_data["image"]
- if isinstance(image_data, Image.Image):
- image_list = [image_data]
- elif is_list_of(image_data, Image.Image):
- image_list = image_data
- else:
- raise TypeError(f"Invalid image type: {type(image_data)}")
-
- image_rows = []
- image_cols = []
- for image in image_list:
- height, width = _get_resize_output_image_size(image.size, size)
-
- rows = math.ceil(height / max_image_size)
- cols = math.ceil(width / max_image_size)
- image_rows.append(rows)
- image_cols.append(cols)
- image_rows = [image_rows]
- image_cols = [image_cols]
-
- n_images_in_text = []
-
- text = inputs.get("prompt")
- if text is None:
- prompt_token_ids = inputs.get("prompt_token_ids", [])
- assert prompt_token_ids
- text = tokenizer.decode(prompt_token_ids)
-
- if isinstance(text, str):
- text = [text]
- elif not isinstance(text, list) and not isinstance(text[0], str):
- raise ValueError("Invalid input text. Please provide a string, "
- "or a list of strings")
-
- fake_image_token = processor.fake_image_token.content
- image_token = processor.image_token.content
- global_img_token = processor.global_image_tag
-
- prompt_strings = []
- for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
- n_images_in_text.append(sample.count(image_token))
-
- # Replace the image token with fake tokens around the expanded
- # image token sequence of length `image_seq_len`
- image_prompt_strings = []
- for n_rows, n_cols in zip(sample_rows, sample_cols):
- image_prompt_string = _get_image_prompt_string(
- n_rows,
- n_cols,
- processor.image_seq_len,
- image_token=image_token,
- fake_token_around_image=fake_image_token,
- global_img_token=global_img_token,
- )
- image_prompt_strings.append(image_prompt_string)
-
- split_sample = sample.split(image_token)
- if len(split_sample) == 0:
- raise ValueError("The image token should be present in the text.")
-
- # Place in the image prompt strings where the image tokens are
- sample = split_sample[0]
- for i, image_prompt_string in enumerate(image_prompt_strings):
- sample += image_prompt_string + split_sample[i + 1]
- prompt_strings.append(sample)
+class Idefics3ProcessingInfo(BaseProcessingInfo):
- prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids
+ def get_hf_processor(
+ self,
+ *,
+ size: Optional[Dict[str, int]] = None) -> Idefics3Processor:
+ if size is not None:
+ return self.ctx.get_hf_processor(Idefics3Processor, size=size)
- return token_inputs(
- prompt_token_ids=prompt_token_ids,
- prompt=prompt_strings[0],
- multi_modal_data=multi_modal_data,
- )
+ return self.ctx.get_hf_processor(Idefics3Processor)
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": None}
-def _get_max_num_image_patch(image_processor: Idefics3ImageProcessor) -> int:
- size = image_processor.size['longest_edge']
- max_image_size = image_processor.max_image_size['longest_edge']
- resized_height, resized_width = size, size
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ hf_processor = self.get_hf_processor()
+ image_processor: Idefics3ImageProcessor = hf_processor.image_processor
+ grid_w, grid_h = self._get_image_feature_grid_size(
+ image_width=image_processor.size['longest_edge'],
+ image_height=image_processor.size['longest_edge'],
+ )
+ num_image_token = (grid_w * grid_h + 1) * hf_processor.image_seq_len
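+ # e.g. with the default Idefics3 size (longest_edge = 4 * 364) the grid is
+ # 4x4, so there are (4 * 4 + 1) * image_seq_len image tokens.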
+ # Calculate the non-image-token length
+ # NOTE: <row_i_col_j> and <global-img> are special tokens for SmolVLM
+ # but not for Idefics3, so we tokenize them to get their actual length.
+ tokenizer = self.get_tokenizer()
+ tile_token_len = len(tokenizer.tokenize("<row_1_col_1>"))
+ glob_token_len = len(tokenizer.tokenize(hf_processor.global_image_tag))
+ # A linebreak and <fake_token_around_image> always cost 1 token each.
+ fake_token_len = lb_len = 1
+ non_image_token = (grid_w * grid_h) * (
+ tile_token_len + fake_token_len) + glob_token_len + (
+ grid_h + 1) * lb_len + fake_token_len
+ return {"image": num_image_token + non_image_token}
+
+ def _resize_output_size(self,
+ *,
+ height: int,
+ width: int,
+ max_len: Optional[int] = None,
+ min_len: Optional[int] = 1,
+ max_size: Optional[int] = None) -> tuple[int, int]:
+ # Set default value for max_len if not provided
+ max_len = max(height, width) if max_len is None else max_len
+ aspect_ratio = width / height
+
+ # Handle the maximum size constraint
+ if max_size is not None:
+ max_len = min(max_len, max_size)
+
+ # Adjust dimensions according to the aspect ratio
+ if width >= height:
+ width = max_len
+ height = int(width / aspect_ratio)
+ else:
+ height = max_len
+ width = int(height * aspect_ratio)
- grid_h = resized_height // max_image_size
- grid_w = resized_width // max_image_size
- return (grid_h * grid_w + 1)
+ # Ensure both width and height are even (if needed)
+ height += height % 2
+ width += width % 2
+ # Ensure dimensions are not smaller than the minimum length
+ height = max(height, min_len)
+ width = max(width, min_len)
-def get_max_idefics3_image_tokens(ctx: InputContext,
- *,
- size: Optional[Dict[str,
- int]] = None) -> int:
- model_config = ctx.model_config
- mm_processor_kwargs = get_mm_processor_kwargs(size)
- processor = cached_get_processor(model_config.model, **mm_processor_kwargs)
- image_seq_len = processor.image_seq_len
- image_processor = processor.image_processor
+ return height, width
- max_num_image_patches = _get_max_num_image_patch(image_processor)
+ def _get_resize_output_image_size(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ resolution_max_side: int,
+ ) -> tuple[int, int]:
+ hf_processor = self.get_hf_processor()
+ image_processor: Idefics3ImageProcessor = hf_processor.image_processor
+ max_image_size = image_processor.size['longest_edge']
+ if resolution_max_side > max_image_size:
+ raise ValueError(
+ "`resolution_max_side` cannot be larger than `max_image_size`")
+
+ height, width = image_height, image_width
+
+ # Find the output size, when rescaling the longest edge to max_len and
+ # preserving the aspect ratio
+ height, width = self._resize_output_size(height=height,
+ width=width,
+ max_len=resolution_max_side)
+ return height, width
+
+ def _get_image_feature_grid_size(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ size: Optional[dict[str, object]] = None,
+ ) -> tuple[int, int]:
+ hf_processor = self.get_hf_processor(size=size)
+ image_processor: Idefics3ImageProcessor = hf_processor.image_processor
+ max_image_size = image_processor.max_image_size['longest_edge']
+ size = image_processor.size['longest_edge']
+ assert size % max_image_size == 0, (
+ "`longest_edge` in the image processor's `size` must be divisible by "
+ "`longest_edge` in its `max_image_size`; this may be caused by an "
+ "incorrect mm_processor_kwargs override.")
+
+ resized_height, resized_width = self._get_resize_output_image_size(
+ image_width=image_width,
+ image_height=image_height,
+ resolution_max_side=size,
+ )
+ if resized_height > max_image_size or resized_width > max_image_size:
+ grid_h = math.ceil(resized_height / max_image_size)
+ grid_w = math.ceil(resized_width / max_image_size)
+ else:
+ grid_h = grid_w = 0
+ return grid_w, grid_h
- return max_num_image_patches * image_seq_len
+class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
+ ):
-def dummy_data_for_idefics3(
- ctx: InputContext,
+ def get_dummy_processor_inputs(
+ self,
seq_len: int,
mm_counts: Mapping[str, int],
- *,
- size: Optional[Dict[str, int]] = None) -> DummyData:
- hf_config = ctx.get_hf_config()
- num_images = mm_counts["image"]
+ ) -> ProcessorInputs:
+ num_images = mm_counts.get("image", 0)
+ hf_processor = self.info.get_hf_processor()
+ image_processor: Idefics3ImageProcessor = hf_processor.image_processor
+ longest_edge = image_processor.max_image_size['longest_edge']
+ image_token: str = hf_processor.image_token.content
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=longest_edge,
+ height=longest_edge,
+ num_images=num_images)
+ }
+
+ return ProcessorInputs(
+ prompt_text=image_token * num_images,
+ mm_data=mm_data,
+ )
- mm_processor_kwargs = get_mm_processor_kwargs(size)
- processor = cached_get_processor(ctx.model_config.model,
- **mm_processor_kwargs)
- max_num_image_patches = _get_max_num_image_patch(processor.image_processor)
- image_seq_len = processor.image_seq_len
- max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images
- if seq_len - max_llm_image_tokens < 0:
- raise RuntimeError(
- f"Idefics3 cannot process {num_images} images in a prompt, "
- "please increase max_model_len or reduce image limit by "
- "--limit-mm-per-prompt.")
+class Idefics3MultimodalProcessor(
+ BaseMultiModalProcessor[Idefics3ProcessingInfo]):
- seq_data = SequenceData.from_prompt_token_counts(
- (hf_config.image_token_id, max_llm_image_tokens),
- (0, seq_len - max_llm_image_tokens))
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ if mm_data:
+ processed_outputs = super()._call_hf_processor(
+ prompt, mm_data, mm_kwargs)
+ image_grids = [
+ self.info._get_image_feature_grid_size(
+ image_width=img.width,
+ image_height=img.height,
+ **mm_kwargs,
+ ) for img in mm_data["images"]
+ ]
+ image_patches = list(map(lambda x: math.prod(x) + 1, image_grids))
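+ # Regroup the stacked patches so that each image keeps exactly its own
+ # grid_w * grid_h + 1 patches (its tiles plus the global image).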
+ for key in ("pixel_values", "pixel_attention_mask"):
+ data = processed_outputs.pop(key)
+ data = data.flatten(0, 1).split(image_patches)
+ processed_outputs[key] = data
+ else:
+ tokenizer = self.info.get_tokenizer()
+ processed_outputs = tokenizer(prompt,
+ add_special_tokens=True,
+ return_tensors="pt")
+ return processed_outputs
- width = height = hf_config.vision_config.image_size
- image = Image.new("RGB", (width, height), color=0)
- mm_data = {"image": [image] if num_images == 1 else [image] * num_images}
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(
+ pixel_values=MultiModalFieldConfig.batched("image"),
+ pixel_attention_mask=MultiModalFieldConfig.batched("image"),
+ image_embeds=MultiModalFieldConfig.batched("image"),
+ )
- return DummyData(seq_data, mm_data)
+ def _get_prompt_replacements(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> list[PromptReplacement]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+ image_token = hf_processor.image_token.content
+ fake_image_token = hf_processor.fake_image_token.content
+ global_img_token = hf_processor.global_image_tag
+ image_seq_len = hf_processor.image_seq_len
+ grid_placeholder = ""
+
+ p_img = image_token * image_seq_len
+ global_img_placeholder = fake_image_token + global_img_token + p_img
+ tile_img_placeholder = fake_image_token + grid_placeholder + p_img
+
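+ # A split image expands to row-major tile placeholders (with a line break
+ # after each grid row) followed by the global image placeholder; an
+ # unsplit image gets just the global placeholder. Both end with a trailing
+ # fake image token.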
+ def get_replacement_idefics3(item_idx: int) -> str:
+ images = mm_items.get_items("image", ImageProcessorItems)
+
+ image_size = images.get_image_size(item_idx)
+ grid_w, grid_h = self.info._get_image_feature_grid_size(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ **hf_processor_mm_kwargs,
+ )
+ if grid_w == 0 and grid_h == 0:
+ image_placeholder = global_img_placeholder
+ else:
+ tiles_placeholder = list[str]()
+ for i in range(grid_h):
+ for j in range(grid_w):
+ placeholder_per_tile = tile_img_placeholder.format(
+ n_h=i + 1, n_w=j + 1)
+ tiles_placeholder.append(placeholder_per_tile)
+ # Add line break if it is the last tile in the row
+ if j == grid_w - 1:
+ tiles_placeholder.append("\n")
+
+ image_placeholder = "".join(
+ [*tiles_placeholder, "\n", global_img_placeholder])
+ return image_placeholder + fake_image_token
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=image_token,
+ replacement=get_replacement_idefics3,
+ )
+ ]
class Idefics3SimpleMLP(nn.Module):
@@ -453,7 +399,7 @@ class Idefics3Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
- config = vllm_config.model_config.hf_config
+ config: Idefics3Config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
@@ -541,15 +487,13 @@ def _image_pixels_to_features(
self,
pixel_values: torch.Tensor,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
- ) -> torch.Tensor:
+ ) -> NestedTensors:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
- batch_size, num_images, num_channels, height, width = pixel_values.shape
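+ # Each entry's dim 0 is that image's patch count (its tiles plus the
+ # global image).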
+ num_patches = [x.size(0) for x in pixel_values]
pixel_values = pixel_values.to(
dtype=self.vision_model.embeddings.patch_embedding.weight.dtype
) # fp16 compatibility
- pixel_values = pixel_values.view(batch_size * num_images,
- *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
@@ -567,8 +511,6 @@ def _image_pixels_to_features(
)
else:
# Remove padding images from the mask
- pixel_attention_mask = pixel_attention_mask.view(
- batch_size * num_images, *pixel_attention_mask.shape[2:])
pixel_attention_mask = pixel_attention_mask[
real_images_inds].contiguous()
@@ -587,10 +529,10 @@ def _image_pixels_to_features(
patch_attention_mask=patch_attention_mask,
)
- return image_hidden_states
+ return image_hidden_states.split(num_patches)
def _process_image_pixels(
- self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor:
+ self, inputs: Idefics3ImagePixelInputs) -> NestedTensors:
assert self.vision_model is not None
pixel_values = inputs["data"]
@@ -605,7 +547,9 @@ def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
assert self.vision_model is not None
image_features = self._process_image_pixels(image_input)
- return self.connector(image_features)
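+ # Concatenate the per-image patch tensors for a single connector pass,
+ # then split the output back into one tensor per image.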
+ num_patches = [x.size(0) for x in image_features]
+ image_features = torch.cat(image_features)
+ return self.connector(image_features).split(num_patches)
def get_input_embeddings(
self,
@@ -634,10 +578,10 @@ def forward(
return hidden_states
-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_idefics3)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_idefics3_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3)
+@MULTIMODAL_REGISTRY.register_processor(
+ Idefics3MultimodalProcessor,
+ info=Idefics3ProcessingInfo,
+ dummy_inputs=Idefics3DummyInputsBuilder)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA):
packed_modules_mapping = {
@@ -689,7 +633,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
if self.config.text_config.tie_word_embeddings:
self.lm_head.weight = self.model.text_model.wte.weight
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
- self.sampler = Sampler()
+ self.sampler = get_sampler()
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self.model._parse_and_validate_image_input(**kwargs)