diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
index d8f48644cc00..5513a5f78f1c 100644
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -35,6 +35,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal.image import convert_image_mode
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
 
 logger = logging.getLogger(__name__)
@@ -257,7 +258,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
     if isinstance(image, dict) and "bytes" in image:
         image = Image.open(BytesIO(image["bytes"]))
     if isinstance(image, Image.Image):
-        image = image.convert("RGB")
+        image = convert_image_mode(image, "RGB")
         with io.BytesIO() as image_data:
             image.save(image_data, format="JPEG")
             image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
diff --git a/examples/offline_inference/qwen2_5_omni/only_thinker.py b/examples/offline_inference/qwen2_5_omni/only_thinker.py
index 52b6e977eaa2..deb6f580a447 100644
--- a/examples/offline_inference/qwen2_5_omni/only_thinker.py
+++ b/examples/offline_inference/qwen2_5_omni/only_thinker.py
@@ -11,6 +11,7 @@
 from vllm.assets.audio import AudioAsset
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
+from vllm.multimodal.image import convert_image_mode
 from vllm.utils import FlexibleArgumentParser
 
 
@@ -45,7 +46,8 @@ def get_mixed_modalities_query() -> QueryResult:
                 "audio":
                 AudioAsset("mary_had_lamb").audio_and_sample_rate,
                 "image":
-                ImageAsset("cherry_blossom").pil_image.convert("RGB"),
+                convert_image_mode(
+                    ImageAsset("cherry_blossom").pil_image, "RGB"),
                 "video":
                 VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
             },
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index c54f328c7a38..941fcd381dea 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -19,6 +19,7 @@
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
 from vllm.lora.request import LoRARequest
+from vllm.multimodal.image import convert_image_mode
 from vllm.utils import FlexibleArgumentParser
 
 
@@ -1096,8 +1097,8 @@ def get_multi_modal_input(args):
     """
     if args.modality == "image":
         # Input image and question
-        image = ImageAsset("cherry_blossom") \
-            .pil_image.convert("RGB")
+        image = convert_image_mode(
+            ImageAsset("cherry_blossom").pil_image, "RGB")
         img_questions = [
             "What is the content of this image?",
             "Describe the content of this image in detail.",
diff --git a/tests/models/multimodal/generation/test_interleaved.py b/tests/models/multimodal/generation/test_interleaved.py
index eec84751e450..972db40e8bd6 100644
--- a/tests/models/multimodal/generation/test_interleaved.py
+++ b/tests/models/multimodal/generation/test_interleaved.py
@@ -4,6 +4,7 @@
 
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
+from vllm.multimodal.image import convert_image_mode
 
 models = ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]
 
@@ -26,8 +27,9 @@ def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None:
     give the same result.
     """
-    image_cherry = ImageAsset("cherry_blossom").pil_image.convert("RGB")
-    image_stop = ImageAsset("stop_sign").pil_image.convert("RGB")
+    image_cherry = convert_image_mode(
+        ImageAsset("cherry_blossom").pil_image, "RGB")
+    image_stop = convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB")
     images = [image_cherry, image_stop]
     video = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays
 
diff --git a/tests/models/multimodal/generation/test_phi4mm.py b/tests/models/multimodal/generation/test_phi4mm.py
index 11460a1a8d2b..5a12b5910949 100644
--- a/tests/models/multimodal/generation/test_phi4mm.py
+++ b/tests/models/multimodal/generation/test_phi4mm.py
@@ -12,7 +12,7 @@
 
 from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
-from vllm.multimodal.image import rescale_image_size
+from vllm.multimodal.image import convert_image_mode, rescale_image_size
 from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs
 
@@ -267,7 +267,7 @@ def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
 
     # use the example speech question so that the model outputs are reasonable
     audio = librosa.load(speech_question, sr=None)
-    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+    image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
 
     inputs_vision_speech = [
         (
diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py
index b45a87d94b86..b62720caa9cb 100644
--- a/tests/models/test_oot_registration.py
+++ b/tests/models/test_oot_registration.py
@@ -4,6 +4,7 @@
 
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
+from vllm.multimodal.image import convert_image_mode
 
 from ..utils import create_new_process_for_each_test
 
@@ -58,7 +59,7 @@ def test_oot_registration_embedding(
     assert all(v == 0 for v in output.outputs.embedding)
 
 
-image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
 
 
 @create_new_process_for_each_test()
diff --git a/tests/multimodal/assets/rgba.png b/tests/multimodal/assets/rgba.png
new file mode 100644
index 000000000000..11eb81857a65
Binary files /dev/null and b/tests/multimodal/assets/rgba.png differ
diff --git a/tests/multimodal/test_image.py b/tests/multimodal/test_image.py
new file mode 100644
index 000000000000..56b5475c9ca0
--- /dev/null
+++ b/tests/multimodal/test_image.py
@@ -0,0 +1,36 @@
+# SPDX-License-Identifier: Apache-2.0
+from pathlib import Path
+
+import numpy as np
+from PIL import Image, ImageChops
+
+from vllm.multimodal.image import convert_image_mode
+
+ASSETS_DIR = Path(__file__).parent / "assets"
+assert ASSETS_DIR.exists()
+
+
+def test_rgb_to_rgb():
+    # Start with an RGB image.
+    original_image = Image.open(ASSETS_DIR / "image1.png").convert("RGB")
+    converted_image = convert_image_mode(original_image, "RGB")
+
+    # RGB to RGB should be a no-op.
+    diff = ImageChops.difference(original_image, converted_image)
+    assert diff.getbbox() is None
+
+
+def test_rgba_to_rgb():
+    original_image = Image.open(ASSETS_DIR / "rgba.png")
+    original_image_numpy = np.array(original_image)
+
+    converted_image = convert_image_mode(original_image, "RGB")
+    converted_image_numpy = np.array(converted_image)
+
+    for i in range(original_image_numpy.shape[0]):
+        for j in range(original_image_numpy.shape[1]):
+            # Verify that all transparent pixels are converted to white.
+            if original_image_numpy[i][j][3] == 0:
+                assert converted_image_numpy[i][j][0] == 255
+                assert converted_image_numpy[i][j][1] == 255
+                assert converted_image_numpy[i][j][2] == 255
diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py
index 478184c34b91..f1e45da30eda 100644
--- a/tests/multimodal/test_utils.py
+++ b/tests/multimodal/test_utils.py
@@ -10,6 +10,7 @@
 import pytest
 from PIL import Image, ImageChops
 
+from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import (MediaConnector,
                                    merge_and_sort_multimodal_metadata)
@@ -53,7 +54,7 @@ def get_supported_suffixes() -> tuple[str, ...]:
 
 
 def _image_equals(a: Image.Image, b: Image.Image) -> bool:
-    return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()
+    return (np.asarray(a) == np.asarray(convert_image_mode(b, a.mode))).all()
 
 
 @pytest.mark.asyncio
diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py
index fab44fb6062d..13c37c979dac 100644
--- a/vllm/benchmarks/datasets.py
+++ b/vllm/benchmarks/datasets.py
@@ -13,7 +13,6 @@
 TODO: Implement CustomDataset to parse a JSON file and convert its contents
 into SampleRequest instances, similar to the approach used in ShareGPT.
 """
-
 import base64
 import io
 import json
@@ -33,6 +32,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal.image import convert_image_mode
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
 
 logger = logging.getLogger(__name__)
@@ -259,7 +259,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
     if isinstance(image, dict) and 'bytes' in image:
         image = Image.open(BytesIO(image['bytes']))
     if isinstance(image, Image.Image):
-        image = image.convert("RGB")
+        image = convert_image_mode(image, "RGB")
         with io.BytesIO() as image_data:
             image.save(image_data, format="JPEG")
             image_base64 = base64.b64encode(
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 66e78fcc4e80..f68513553846 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -23,6 +23,7 @@
                                                     InternVisionPatchModel)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargs, NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
@@ -77,7 +78,7 @@ class InternVLImageEmbeddingInputs(TypedDict):
 def build_transform(input_size: int):
     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
     return T.Compose([
-        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+        T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
         T.Resize((input_size, input_size),
                  interpolation=T.InterpolationMode.BICUBIC),
         T.ToTensor(),
diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py
index 91f6c7753c68..eefadda918f6 100644
--- a/vllm/model_executor/models/skyworkr1v.py
+++ b/vllm/model_executor/models/skyworkr1v.py
@@ -24,6 +24,7 @@
                                                     InternVisionPatchModel)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargs, NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
@@ -78,7 +79,7 @@ class SkyworkR1VImageEmbeddingInputs(TypedDict):
 def build_transform(input_size: int):
     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
     return T.Compose([
-        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+        T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
         T.Resize((input_size, input_size),
                  interpolation=T.InterpolationMode.BICUBIC),
         T.ToTensor(),
diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py
index f6ab72f4e9b8..a5a4dcd0b6e1 100644
--- a/vllm/multimodal/hasher.py
+++ b/vllm/multimodal/hasher.py
@@ -10,6 +10,7 @@
 from PIL import Image
 
 from vllm.logger import init_logger
+from vllm.multimodal.image import convert_image_mode
 
 if TYPE_CHECKING:
     from vllm.inputs import TokensPrompt
@@ -35,7 +36,8 @@ def serialize_item(cls, obj: object) -> bytes:
             return np.array(obj).tobytes()
 
         if isinstance(obj, Image.Image):
-            return cls.item_to_bytes("image", np.array(obj.convert("RGBA")))
+            return cls.item_to_bytes("image",
+                                     np.array(convert_image_mode(obj, "RGBA")))
         if isinstance(obj, torch.Tensor):
            return cls.item_to_bytes("tensor", obj.numpy())
         if isinstance(obj, np.ndarray):
diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index 939928bbf108..a63ec0bd8ada 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -22,6 +22,25 @@ def rescale_image_size(image: Image.Image,
     return image
 
 
+# TODO: Support customizable background color to fill in.
+def rgba_to_rgb(
+        image: Image.Image, background_color=(255, 255, 255)) -> Image.Image:
+    """Convert an RGBA image to RGB with filled background color."""
+    assert image.mode == "RGBA"
+    converted = Image.new("RGB", image.size, background_color)
+    converted.paste(image, mask=image.split()[3])  # 3 is the alpha channel
+    return converted
+
+
+def convert_image_mode(image: Image.Image, to_mode: str):
+    if image.mode == to_mode:
+        return image
+    elif image.mode == "RGBA" and to_mode == "RGB":
+        return rgba_to_rgb(image)
+    else:
+        return image.convert(to_mode)
+
+
 class ImageMediaIO(MediaIO[Image.Image]):
 
     def __init__(self, *, image_mode: str = "RGB") -> None:
@@ -32,7 +51,7 @@ def __init__(self, *, image_mode: str = "RGB") -> None:
     def load_bytes(self, data: bytes) -> Image.Image:
         image = Image.open(BytesIO(data))
         image.load()
-        return image.convert(self.image_mode)
+        return convert_image_mode(image, self.image_mode)
 
     def load_base64(self, media_type: str, data: str) -> Image.Image:
         return self.load_bytes(base64.b64decode(data))
@@ -40,7 +59,7 @@ def load_base64(self, media_type: str, data: str) -> Image.Image:
     def load_file(self, filepath: Path) -> Image.Image:
         image = Image.open(filepath)
         image.load()
-        return image.convert(self.image_mode)
+        return convert_image_mode(image, self.image_mode)
 
     def encode_base64(
         self,
@@ -51,7 +70,7 @@ def encode_base64(
         image = media
 
         with BytesIO() as buffer:
-            image = image.convert(self.image_mode)
+            image = convert_image_mode(image, self.image_mode)
             image.save(buffer, image_format)
             data = buffer.getvalue()
 
diff --git a/vllm/transformers_utils/processors/ovis.py b/vllm/transformers_utils/processors/ovis.py
index a35d32999991..f1c6407e1f3a 100644
--- a/vllm/transformers_utils/processors/ovis.py
+++ b/vllm/transformers_utils/processors/ovis.py
@@ -33,6 +33,8 @@
                     Unpack)
 from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
 
+from vllm.multimodal.image import convert_image_mode
+
 __all__ = ['OvisProcessor']
 
 IGNORE_ID = -100
@@ -361,8 +363,8 @@ def _get_best_grid(img, side):
             # pick the partition with maximum covering_ratio and break the tie using #sub_images
             return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]
 
-        if convert_to_rgb and image.mode != 'RGB':
-            image = image.convert('RGB')
+        if convert_to_rgb:
+            image = convert_image_mode(image, 'RGB')
 
         sides = self.get_image_size()
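
Usage sketch (not part of the patch; assumes Pillow is installed and vllm.multimodal.image exposes convert_image_mode as added above): the helper short-circuits when the image is already in the target mode and, for RGBA input, fills transparency with white via rgba_to_rgb rather than PIL's plain convert(), which would leave transparent pixels black.

    from PIL import Image

    from vllm.multimodal.image import convert_image_mode

    # A fully transparent 2x2 RGBA image whose color channels are black.
    rgba = Image.new("RGBA", (2, 2), (0, 0, 0, 0))

    # The helper pastes onto a white background, matching the new
    # test_rgba_to_rgb expectation; Image.convert("RGB") alone would
    # produce black pixels here.
    rgb = convert_image_mode(rgba, "RGB")
    assert rgb.mode == "RGB"
    assert rgb.getpixel((0, 0)) == (255, 255, 255)

    # Converting to the mode the image already has returns it unchanged.
    assert convert_image_mode(rgb, "RGB") is rgb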