14 changes: 13 additions & 1 deletion tests/models/multimodal/generation/test_common.py
@@ -393,7 +393,6 @@
formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
),
limit_mm_per_prompt={"video": 4},
runner_mm_key="videos",
)],
),
"llava_next_video": VLMTestInfo(
@@ -623,6 +622,19 @@
limit_mm_per_prompt={"image": 4},
)],
),
"qwen2_5_omni-mixed-modalities": VLMTestInfo(
models=["Qwen/Qwen2.5-Omni-3B"],
test_type=VLMTestType.CUSTOM_INPUTS,
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForTextToWaveform,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
patch_hf_runner=model_utils.qwen2_5_omni_patch_hf_runner,
custom_test_opts=[CustomTestOptions(
inputs=custom_inputs.mixed_modality_qwen2_5_omni(),
limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
)],
),
# regression test for https://github.com/vllm-project/vllm/issues/15122
"qwen2_5_vl-windows-attention": VLMTestInfo(
models=["Qwen/Qwen2.5-VL-3B-Instruct"],
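The `qwen2_5_omni-mixed-modalities` entry above limits each request to one audio clip, one image, and one video via `limit_mm_per_prompt`. As a rough sketch (assumed from common vLLM usage, not part of this diff), the same configuration maps onto vLLM's offline `LLM` constructor like this:

```python
# Hedged sketch: standing up the same model configuration outside the test
# harness. Values mirror the VLMTestInfo entry above; the keyword arguments
# are assumed to be accepted by the installed vLLM version.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-Omni-3B",
    max_model_len=4096,
    max_num_seqs=2,
    # At most one audio clip, one image, and one video per prompt.
    limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
)
```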
77 changes: 45 additions & 32 deletions tests/models/multimodal/generation/vlm_utils/builders.py
@@ -14,7 +14,8 @@
from .....conftest import ImageTestAssets, VideoTestAssets
from .types import (SINGLE_IMAGE_BASE_PROMPTS, TEST_IMG_PLACEHOLDER,
TEST_VIDEO_PLACEHOLDER, VIDEO_BASE_PROMPT,
ImageSizeWrapper, SizeType, VLMTestInfo)
ImageSizeWrapper, PromptWithMultiModalInput, SizeType,
VLMTestInfo)


def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int],
@@ -68,10 +69,11 @@ def get_model_prompts(base_prompts: Iterable[str],


def build_single_image_inputs_from_test_info(
test_info: VLMTestInfo,
image_assets: ImageTestAssets,
size_wrapper: ImageSizeWrapper,
tmp_path: Optional[PosixPath] = None):
test_info: VLMTestInfo,
image_assets: ImageTestAssets,
size_wrapper: ImageSizeWrapper,
tmp_path: Optional[PosixPath] = None,
) -> list[PromptWithMultiModalInput]:
if test_info.prompt_formatter is None:
raise ValueError(
"Prompt formatter must be set to build single image inputs")
@@ -97,28 +99,32 @@ def build_single_image_inputs_from_test_info(
return build_single_image_inputs(images, model_prompts, size_wrapper)


def build_single_image_inputs(images, model_prompts,
size_wrapper: ImageSizeWrapper):
def build_single_image_inputs(
images, model_prompts,
size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]:
# For every image / prompt pair, get a pair containing two lists of
# length size_factors, where the first contains duplicates of the model
# prompt [str], and the second contains copies of the image after being
# scaled by one of the size factors.
#
# NOTE: rescaling preserves the image aspect ratio.
return [(
[prompt for _ in size_wrapper.data],
[
apply_image_size_scaling(image, size, size_wrapper.type)
for size in size_wrapper.data
],
) for image, prompt in zip(images, model_prompts)]
return [
PromptWithMultiModalInput.create(
prompts=[prompt for _ in size_wrapper.data],
image_data=[
apply_image_size_scaling(image, size, size_wrapper.type)
for size in size_wrapper.data
],
) for image, prompt in zip(images, model_prompts)
]


def build_multi_image_inputs_from_test_info(
test_info: VLMTestInfo,
image_assets: ImageTestAssets,
size_wrapper: ImageSizeWrapper,
tmp_path: Optional[PosixPath] = None):
test_info: VLMTestInfo,
image_assets: ImageTestAssets,
size_wrapper: ImageSizeWrapper,
tmp_path: Optional[PosixPath] = None,
) -> list[PromptWithMultiModalInput]:
if test_info.prompt_formatter is None:
raise ValueError(
"Prompt formatter must be set to build multi image inputs")
@@ -146,15 +152,18 @@ def build_multi_image_inputs_from_test_info(
)


def build_multi_image_inputs(image_lists, model_prompts,
size_wrapper: ImageSizeWrapper):
return [(
[prompt for _ in size_wrapper.data],
[[
apply_image_size_scaling(image, size, size_wrapper.type)
for image in images
] for size in size_wrapper.data],
) for images, prompt in zip(image_lists, model_prompts)]
def build_multi_image_inputs(
image_lists, model_prompts,
size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]:
return [
PromptWithMultiModalInput.create(
prompts=[prompt for _ in size_wrapper.data],
image_data=[[
apply_image_size_scaling(image, size, size_wrapper.type)
for image in images
] for size in size_wrapper.data],
) for images, prompt in zip(image_lists, model_prompts)
]


def build_embedding_inputs_from_test_info(
@@ -195,7 +204,7 @@ def build_video_inputs_from_test_info(
video_assets: VideoTestAssets,
size_wrapper: ImageSizeWrapper,
num_frames: int,
):
) -> list[PromptWithMultiModalInput]:
if test_info.prompt_formatter is None:
raise ValueError("Prompt formatter must be set to build video inputs")
model_prompts = get_model_prompts(
@@ -213,10 +222,14 @@ def build_video_inputs_from_test_info(
video_scaler = (resize_video if size_wrapper.type == SizeType.FIXED_SIZE
else rescale_video_size)

return [(
[prompt for _ in size_wrapper.data],
[video_scaler(video, size) for size in size_wrapper.data],
) for video, prompt in zip(sampled_vids, model_prompts)]
return [
PromptWithMultiModalInput.create(
prompts=[prompt for _ in size_wrapper.data],
video_data=[
video_scaler(video, size) for size in size_wrapper.data
],
) for video, prompt in zip(sampled_vids, model_prompts)
]


def apply_image_size_scaling(image, size: Union[float, tuple[int, int]],
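The builders now return `PromptWithMultiModalInput` objects instead of bare `(prompts, media)` tuples. The real definition lives in `vlm_utils/types.py`, which is not shown in this diff; the sketch below is an assumed, minimal stand-in whose field order matches how `core.py` unpacks each entry:

```python
# Minimal sketch of a compatible container, assuming the real definition in
# vlm_utils/types.py (not part of this diff). Four positional fields in this
# order let core.py unpack each entry as
# `prompts, image_data, video_data, audio_data`.
from typing import Any, NamedTuple, Optional


class PromptWithMultiModalInput(NamedTuple):
    prompts: list[str]
    image_data: Optional[list[Any]] = None
    video_data: Optional[list[Any]] = None
    audio_data: Optional[list[Any]] = None

    @classmethod
    def create(cls, **kwargs: Any) -> "PromptWithMultiModalInput":
        # The builders only pass the modalities they actually produce;
        # the rest default to None.
        return cls(**kwargs)
```

Whatever the real type looks like, it has to support positional unpacking into exactly these four fields, which is what the `core.py` changes below rely on.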
29 changes: 18 additions & 11 deletions tests/models/multimodal/generation/vlm_utils/core.py
@@ -1,24 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
"""Core test implementation to be shared across modalities."""
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional

import torch
from PIL.Image import Image
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from vllm.config import TaskOption
from vllm.transformers_utils.tokenizer import AnyTokenizer

from .....conftest import HfRunner, VllmRunner
from ....registry import HF_EXAMPLE_MODELS
from .types import RunnerOutput
from .types import PromptWithMultiModalInput, RunnerOutput


def run_test(
*,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
inputs: list[tuple[list[str], list[Union[list[Image], Image]]]],
inputs: list[PromptWithMultiModalInput],
model: str,
dtype: str,
max_tokens: int,
@@ -38,7 +37,6 @@ def run_test(
hf_model_kwargs: Optional[dict[str, Any]],
patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]],
task: TaskOption = "auto",
runner_mm_key: str = "images",
distributed_executor_backend: Optional[str] = None,
tensor_parallel_size: int = 1,
vllm_embeddings: Optional[torch.Tensor] = None,
@@ -94,10 +92,16 @@ def run_test(
if stop_str:
vllm_kwargs["stop"] = stop_str

for prompts, media in vllm_inputs:
vllm_kwargs[runner_mm_key] = media
for prompts, image_data, video_data, audio_data in vllm_inputs:
mm_data = dict(images=image_data,
videos=video_data,
audios=audio_data)
vllm_kwargs_with_mm_data = vllm_kwargs | mm_data
vllm_output = vllm_model.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs=num_logprobs, **vllm_kwargs)
prompts,
max_tokens,
num_logprobs=num_logprobs,
**vllm_kwargs_with_mm_data)
vllm_outputs_per_mm.append(vllm_output)

hf_model = hf_runner(model,
@@ -122,14 +126,17 @@ def run_test(
if stop_str:
hf_kwargs["stop_strings"] = stop_str

for prompts, media in inputs:
hf_kwargs[runner_mm_key] = media
for prompts, image_data, video_data, audio_data in inputs:
mm_data = dict(images=image_data,
videos=video_data,
audios=audio_data)
hf_kwargs_with_mm_data = hf_kwargs | mm_data
hf_output = hf_model.generate_greedy_logprobs_limit(
prompts,
max_tokens,
num_logprobs=num_logprobs,
tokenizer=tokenizer,
**hf_kwargs)
**hf_kwargs_with_mm_data)
hf_outputs_per_mm.append(hf_output)

# Apply output processing / sanitation to the vLLM and HF runner results
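`run_test` no longer takes a per-model `runner_mm_key`; every input now carries image, video, and audio lists (with `None` for absent modalities), and the loops above fold them into the runner kwargs with the dict-union operator. A small illustration of what that merge does (example values are made up):

```python
# Illustration only: `vllm_kwargs | mm_data` builds a new dict in which keys
# from mm_data win over same-named keys in vllm_kwargs; neither input is
# mutated.
vllm_kwargs = {"stop": ["<|im_end|>"]}
mm_data = {"images": None, "videos": [b"..."], "audios": None}

merged = vllm_kwargs | mm_data          # Python 3.9+ dict union
assert merged == {**vllm_kwargs, **mm_data}
assert "stop" in vllm_kwargs            # original dict left untouched
```

Passing all three modality kwargs on every call is what makes the old `runner_mm_key` selector unnecessary.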
105 changes: 71 additions & 34 deletions tests/models/multimodal/generation/vlm_utils/custom_inputs.py
@@ -6,13 +6,16 @@
import requests
from PIL import Image

from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.video import (rescale_video_size, resize_video,
sample_frames_from_video)

from .....conftest import IMAGE_ASSETS, VIDEO_ASSETS
from .builders import build_multi_image_inputs, build_single_image_inputs
from .types import ImageSizeWrapper, SizeType
from .types import ImageSizeWrapper, PromptWithMultiModalInput, SizeType


def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]):
@@ -32,24 +35,28 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]):
"<image>\nWhat is the season?",
]
formatted_prompts = [formatter(prompt) for prompt in img_prompts]

return [(
formatted_prompts,
aspect_ratio_images = [
[stop_sign, cherry_blossom],
# Images with different sizes and aspect-ratios
[
rescale_image_size(stop_sign, 0.1),
stop_sign,
],
[
[stop_sign, cherry_blossom],
# Images with different sizes and aspect-ratios
[
rescale_image_size(stop_sign, 0.1),
stop_sign,
],
[
stop_sign,
rescale_image_size(stop_sign, 0.25),
cherry_blossom.resize((183, 488)),
cherry_blossom.resize((488, 183))
],
cherry_blossom,
])]
stop_sign,
rescale_image_size(stop_sign, 0.25),
cherry_blossom.resize((183, 488)),
cherry_blossom.resize((488, 183))
],
cherry_blossom,
]

return [
PromptWithMultiModalInput.create(
prompts=formatted_prompts,
image_data=aspect_ratio_images,
)
]


def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str],
@@ -68,24 +75,28 @@ def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str],
"<video>\nWhy is this video funny?",
]
formatted_prompts = [formatter(prompt) for prompt in video_prompts]

return [(
formatted_prompts,
aspect_ratio_videos = [
[video, video],
# Videos with different sizes and aspect-ratios
[
rescale_video_size(video, 0.1),
video,
],
[
[video, video],
# Videos with different sizes and aspect-ratios
[
rescale_video_size(video, 0.1),
video,
],
[
video,
rescale_video_size(video, 0.25),
resize_video(video, (183, 488)),
resize_video(video, (488, 183))
],
video,
])]
rescale_video_size(video, 0.25),
resize_video(video, (183, 488)),
resize_video(video, (488, 183))
],
video,
]

return [
PromptWithMultiModalInput.create(
prompts=formatted_prompts,
video_data=aspect_ratio_videos,
)
]


def different_patch_input_cases_internvl():
@@ -120,3 +131,29 @@ def windows_attention_image_qwen2_5_vl():

wrapped_sf = ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=[0.5])
return build_single_image_inputs([image], [prompt], wrapped_sf)


def mixed_modality_qwen2_5_omni(size_factor: float = 0.25):
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.")
question = ("What is recited in the audio? "
"What is the content of this image? Why is this video funny?")
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
"<|vision_bos|><|IMAGE|><|vision_eos|>"
"<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n")
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate
video = VideoAsset(name="baby_reading", num_frames=4).np_ndarrays
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
return [
PromptWithMultiModalInput.create(
prompts=[prompt],
image_data=[rescale_image_size(image, size_factor=size_factor)],
video_data=[video],
audio_data=[audio],
)
]
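
`mixed_modality_qwen2_5_omni` bundles one audio clip, one image, and one video behind a single Qwen2.5-Omni chat prompt for the test harness. Outside the harness, the same assets could be fed to vLLM roughly as sketched below; the `generate()` input format is assumed from vLLM's offline multimodal API and is not part of this diff:

```python
# Hedged sketch: feeding the same mixed-modality payload to vLLM directly.
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset

llm = LLM(model="Qwen/Qwen2.5-Omni-3B",
          limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1})

prompt = "..."  # placeholder for the chat prompt built in mixed_modality_qwen2_5_omni()
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {
            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
            "image": ImageAsset("cherry_blossom").pil_image.convert("RGB"),
            "video": VideoAsset(name="baby_reading", num_frames=4).np_ndarrays,
        },
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
```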