Commit bea6fa5

Isotr0py authored and DarkLight1337 committed
[Model] Refactor Qwen2-VL to use merged multimodal processor (vllm-project#11258)
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
1 parent 4e14d4d commit bea6fa5
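
In practical terms, the refactor replaces Qwen2-VL's separate input mapper, input processor, and dummy-data hooks with a single Qwen2VLMultiModalProcessor. A rough usage sketch based on the updated test further down (the context `ctx` is an InputProcessingContext and `image` a PIL image, both assumed to be set up as in that test):

    # Sketch only: mirrors how the test below drives the merged processor.
    # `ctx` (InputProcessingContext) and `image` (PIL image) are assumed.
    from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor

    processor = Qwen2VLMultiModalProcessor(ctx)
    processed = processor.apply(
        "<|vision_start|><|image_pad|><|vision_end|>",  # prompt with placeholder
        {"image": [image]},                             # multimodal data
        {"min_pixels": 64**2, "max_pixels": 512**2},    # mm_processor_kwargs
    )
    # processed["prompt_token_ids"] contains the expanded image tokens and
    # processed["mm_kwargs"]["pixel_values"] the preprocessed patches.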

5 files changed: +277 -527 lines changed


examples/offline_inference_vision_language.py

Lines changed: 6 additions & 2 deletions
@@ -447,7 +447,6 @@ def run_qwen_vl(question: str, modality: str):
 
 # Qwen2-VL
 def run_qwen2_vl(question: str, modality: str):
-    assert modality == "image"
 
     model_name = "Qwen/Qwen2-VL-7B-Instruct"
 
@@ -463,8 +462,13 @@ def run_qwen2_vl(question: str, modality: str):
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
 
+    if modality == "image":
+        placeholder = "<|image_pad|>"
+    elif modality == "video":
+        placeholder = "<|video_pad|>"
+
     prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
-              "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
+              f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
               f"{question}<|im_end|>\n"
               "<|im_start|>assistant\n")
     stop_token_ids = None
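
The effect of the change above, as a standalone sketch independent of vLLM (the else-branch is added here only for illustration):

    def build_qwen2_vl_prompt(question: str, modality: str) -> str:
        # Mirrors the placeholder selection introduced in the diff above.
        if modality == "image":
            placeholder = "<|image_pad|>"
        elif modality == "video":
            placeholder = "<|video_pad|>"
        else:
            raise ValueError(f"Unsupported modality: {modality}")

        return ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
                f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
                f"{question}<|im_end|>\n"
                "<|im_start|>assistant\n")

    # e.g. build_qwen2_vl_prompt("Describe the clip.", "video") embeds <|video_pad|>.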
Lines changed: 65 additions & 127 deletions
@@ -1,12 +1,9 @@
 from typing import Any, Dict, Tuple
 
 import pytest
-import torch
-from PIL.Image import Image
 from transformers import AutoTokenizer
 
-from vllm.inputs import InputContext, token_inputs
-from vllm.multimodal import MultiModalRegistry
+from vllm.inputs import InputContext, InputProcessingContext
 
 from .....conftest import _ImageAssets
 from ....utils import build_model_context
@@ -20,22 +17,9 @@
 # NOTE: Qwen2VL supports multiple input modalities, so it registers multiple
 # input mappers.
 @pytest.fixture()
-def image_input_mapper_for_qwen2_vl():
-    from vllm.model_executor.models.qwen2_vl import (
-        image_input_mapper_for_qwen2_vl)
-    return image_input_mapper_for_qwen2_vl
-
-
-@pytest.fixture()
-def input_processor_for_qwen2_vl():
-    from vllm.model_executor.models.qwen2_vl import (
-        input_processor_for_qwen2_vl)
-    return input_processor_for_qwen2_vl
-
-
-@pytest.fixture()
-def qwen2_vl_context() -> InputContext:
-    return build_model_context(model_name=MODEL)
+def processor_for_qwen2_vl():
+    from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
+    return Qwen2VLMultiModalProcessor
 
 
 @pytest.fixture()
@@ -45,123 +29,77 @@ def get_max_qwen2_vl_image_tokens():
     return get_max_qwen2_vl_image_tokens
 
 
-@pytest.fixture()
-def dummy_data_for_qwen2_vl():
-    from vllm.model_executor.models.qwen2_vl import dummy_data_for_qwen2_vl
-    return dummy_data_for_qwen2_vl
-
-
 @pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
     ({}, 1225),
     ({
         MIN_PIXELS: 64**2,
         MAX_PIXELS: 512**2
     }, 324),
 ])
-def test_qwen2_vl_max_image_tokens(get_max_qwen2_vl_image_tokens,
-                                   qwen2_vl_context: InputContext,
-                                   mm_processor_kwargs: Dict[str, Any],
-                                   expected_max_tokens: int):
+@pytest.mark.parametrize("model", [MODEL])
+def test_qwen2_vl_max_image_tokens(
+    get_max_qwen2_vl_image_tokens,
+    model: str,
+    mm_processor_kwargs: Dict[str, Any],
+    expected_max_tokens: int,
+):
     """Ensure that the max token calc handles min/max pixels properly."""
-    actual_max_tokens = get_max_qwen2_vl_image_tokens(qwen2_vl_context,
-                                                      **mm_processor_kwargs)
-    assert actual_max_tokens == expected_max_tokens
-
-
-@pytest.mark.parametrize("mm_processor_kwargs,token_count,img_size", [
-    [{}, 1225, (980, 980)],
-    [{
-        MIN_PIXELS: 64**2,
-        MAX_PIXELS: 512**2
-    }, 324, (504, 504)],
-])
-def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl,
-                             qwen2_vl_context: InputContext,
-                             mm_processor_kwargs: Dict[str, Any],
-                             token_count: int, img_size: Tuple[int, int]):
-    """Ensure that the dummy data handles min/max pixels properly."""
-    seq_len = 3000
-    hf_config = qwen2_vl_context.get_hf_config()
-    image_token_id = hf_config.image_token_id
-
-    # NOTE: video value is required, but isn't actually used
-    # when making the dummy data except for error handling currently
-    dummy_data = dummy_data_for_qwen2_vl(
-        ctx=qwen2_vl_context,
-        seq_len=seq_len,
-        mm_counts={
-            "image": 1,
-            "video": 0
-        },
-        **mm_processor_kwargs,
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        mm_processor_kwargs=None,
     )
-    seq_data = dummy_data.seq_data
-    mm_data = dummy_data.multi_modal_data
-
-    # Ensure we have the right number of placeholders for min/max pixel values
-    assert seq_data.get_token_ids().count(image_token_id) == token_count
 
-    # Ensure the images were resized correctly
-    image = mm_data["image"]
-    assert isinstance(image, Image)
-    assert image.size == img_size
+    actual_max_tokens = get_max_qwen2_vl_image_tokens(
+        InputContext(ctx.model_config), **mm_processor_kwargs)
+    assert actual_max_tokens == expected_max_tokens
 
 
-@pytest.mark.parametrize("mm_processor_kwargs,num_placeholders", [
-    ({}, 1426),
-    ({
-        MIN_PIXELS: 64**2,
-        MAX_PIXELS: 512**2
-    }, 330),
-])
-def test_input_processor(input_processor_for_qwen2_vl,
-                         qwen2_vl_context: InputContext,
-                         image_assets: _ImageAssets, num_placeholders: int,
-                         mm_processor_kwargs: Dict[str, Any]):
-    """Ensure that the image processor handles min/max pixels properly."""
-    tokenizer = AutoTokenizer.from_pretrained(MODEL)
-    prompt = "<|vision_start|><|image_pad|><|vision_end|>"
-
-    image = image_assets[0].pil_image
-    hf_config = qwen2_vl_context.get_hf_config()
-    image_token_id = hf_config.image_token_id
-
-    inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
-                          prompt=prompt,
-                          multi_modal_data={"image": [image]})
-
-    processed_inputs = input_processor_for_qwen2_vl(qwen2_vl_context, inputs,
-                                                    **mm_processor_kwargs)
-    assert processed_inputs["prompt_token_ids"].count(
-        image_token_id) == num_placeholders
-    assert len(processed_inputs["multi_modal_data"]["image"]) == 1
-
-
-@pytest.mark.parametrize("mm_processor_kwargs,pixels_shape", [
-    ({}, [5704, 1176]),
-    ({
-        MIN_PIXELS: 64**2,
-        MAX_PIXELS: 512**2
-    }, [1320, 1176]),
-])
-def test_image_mapper_override(qwen2_vl_context: InputContext,
-                               image_assets: _ImageAssets,
-                               mm_processor_kwargs: Dict[str, Any],
-                               pixels_shape: Tuple[int, int]):
-    """Ensure that the image mapper handles min/max pixels properly."""
-    mm_registry = MultiModalRegistry()
-    mm_registry.init_mm_limits_per_prompt(qwen2_vl_context.model_config)
-
-    image = image_assets[0].pil_image
-
-    mapped_output = mm_registry.map_input(
-        qwen2_vl_context.model_config,
-        {"image": image},
-        mm_processor_kwargs=mm_processor_kwargs,
+@pytest.mark.parametrize(
+    "mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [
+        ({}, 1426, (5704, 1176)),
+        ({
+            MIN_PIXELS: 64**2,
+            MAX_PIXELS: 512**2
+        }, 330, (1320, 1176)),
+    ])
+@pytest.mark.parametrize("model", [MODEL])
+@pytest.mark.parametrize("num_imgs", [1, 2])
+def test_processor_override(
+    processor_for_qwen2_vl,
+    image_assets: _ImageAssets,
+    model: str,
+    mm_processor_kwargs: Dict[str, Any],
+    expected_toks_per_img: int,
+    expected_pixels_shape: Tuple[int, int],
+    num_imgs: int,
+):
+    """Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly."""
+    # Same as the previous test - don't initialize mm_processor_kwargs
+    # in this test and assume that the kwargs will be correctly expanded by
+    # the partial when calling the custom input processor.
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        mm_processor_kwargs=None,
     )
-
-    # Dimension 0 of pixel values should match the product of image_grid_thw
-    actual_pixels_shape = mapped_output["pixel_values"].shape
-    assert list(actual_pixels_shape) == pixels_shape
-    assert actual_pixels_shape[0] == torch.prod(
-        mapped_output["image_grid_thw"])
+    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+    ctx = InputProcessingContext(ctx.model_config, tokenizer)
+    # Build the image str / prompt based on the number of images we pass
+    prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
+    images = [image_assets[0].pil_image] * num_imgs
+
+    mm_data = {"image": images}
+
+    processor = processor_for_qwen2_vl(ctx)
+    processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
+
+    # Ensure we have the right number of placeholders per num_crops size
+    hf_processor = processor._get_hf_processor(**mm_processor_kwargs)
+    image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
+    img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
+    pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape
+
+    assert img_tok_count == expected_toks_per_img * num_imgs
+    assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
+    assert pixel_shape[1] == expected_pixels_shape[1]
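
For reference, the parametrized numbers above can be sanity-checked against Qwen2-VL's default vision geometry (assumed here: 14x14 patches, 3 channels, temporal patch size 2, 2x2 spatial merge); the 980x980 and 504x504 image sizes come from the removed dummy-data test:

    # Illustrative arithmetic only; the geometry values are assumed defaults,
    # not read from the model config in this commit.
    patch_size, channels, temporal_patch_size, merge_size = 14, 3, 2, 2

    # dim 1 of pixel_values is one flattened patch: 3 * 2 * 14 * 14 = 1176
    assert channels * temporal_patch_size * patch_size**2 == 1176

    # 980x980 -> 70x70 patch grid -> 35x35 merged tokens = 1225 (default max)
    assert (980 // (patch_size * merge_size))**2 == 1225

    # MAX_PIXELS = 512**2 caps the image near 504x504 -> 18x18 tokens = 324
    assert (504 // (patch_size * merge_size))**2 == 324

    # dim 0 of pixel_values holds the un-merged patches, i.e. 4x the token count
    assert 1426 * merge_size**2 == 5704 and 330 * merge_size**2 == 1320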

vllm/model_executor/models/qwen2_audio.py

Lines changed: 3 additions & 1 deletion
@@ -164,7 +164,9 @@ def _get_dummy_mm_inputs(
         self,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        audio_len = get_max_qwen2_audio_audio_tokens(self.ctx)
+        feature_extractor = self._get_feature_extractor()
+        sampling_rate = feature_extractor.sampling_rate
+        audio_len = feature_extractor.chunk_length * sampling_rate
 
         audio_count = mm_counts["audio"]
         audio = np.zeros(audio_len)
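
The replaced line ties the dummy audio length to the feature extractor rather than the max-token helper; assuming a Whisper-style extractor with 30-second chunks at 16 kHz (typical for Qwen2-Audio, not verified from this diff), that works out to:

    import numpy as np

    # Assumed feature-extractor defaults; illustrative only.
    chunk_length, sampling_rate = 30, 16_000
    audio_len = chunk_length * sampling_rate  # 480000 samples of dummy audio
    audio = np.zeros(audio_len)
    assert audio.shape == (480_000,)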
