diff --git a/docs/source/features/compatibility_matrix.md b/docs/source/features/compatibility_matrix.md
index b0018ebccf5b..ee5db70c7d5c 100644
--- a/docs/source/features/compatibility_matrix.md
+++ b/docs/source/features/compatibility_matrix.md
@@ -297,7 +297,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar
* ✅
* ✅
* ?
- * [✗](gh-issue:7968>)
+ * [✗](gh-issue:7968)
* ?
* ✅
*
diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py
index 1ad56241535b..c6d5244318a3 100644
--- a/tests/models/decoder_only/language/test_models.py
+++ b/tests/models/decoder_only/language/test_models.py
@@ -26,6 +26,9 @@
"google/gemma-1.1-2b-it", # gemma
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
+ pytest.param(
+ "THUDM/chatglm3-6b", # ChatGLM (text-only)
+ ),
pytest.param(
"meta-llama/Llama-3.2-1B-Instruct", # llama
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
@@ -43,6 +46,9 @@
"microsoft/phi-2", # phi
marks=[pytest.mark.core_model],
),
+ pytest.param(
+ "Qwen/Qwen-7B", # qwen (text-only)
+ ),
pytest.param(
"Qwen/Qwen2.5-0.5B-Instruct", # qwen2
marks=[pytest.mark.core_model],
@@ -68,6 +74,10 @@ def test_models(
) -> None:
with hf_runner(model, dtype=dtype) as hf_model:
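+        # The HF ChatGLM3 implementation does not expose
+        # get_output_embeddings(), which the greedy-logprobs helper relies on
+        # to turn hidden states into logits, so point it at the output layer.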
+ if model.startswith("THUDM/chatglm3"):
+ hf_model.model.get_output_embeddings = lambda: \
+ hf_model.model.transformer.output_layer
+
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 8658e60bc5b2..a56a9e2beef2 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -89,7 +89,7 @@ def _test_processing_correctness(
mm_data = {
k:
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
- for _ in range(rng.randint(limit))]
+ for _ in range(rng.randint(limit + 1))]
for k, limit in limit_mm_per_prompt.items()
}
diff --git a/tests/multimodal/utils.py b/tests/multimodal/utils.py
index 9a336b7e60ff..40fcfeeeac7d 100644
--- a/tests/multimodal/utils.py
+++ b/tests/multimodal/utils.py
@@ -17,10 +17,7 @@ def random_video(
min_wh: int,
max_wh: int,
):
- # Temporary workaround for https://github.com/huggingface/transformers/issues/35412
num_frames = rng.randint(min_frames, max_frames)
- num_frames = (num_frames // 2) * 2
-
w, h = rng.randint(min_wh, max_wh, size=(2, ))
return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 9ee9e9ca8009..153c85cfb214 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -4,8 +4,8 @@
# https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace
-from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple,
- TypedDict, Union)
+from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
+ Union)
import torch
from torch import nn
@@ -19,7 +19,6 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -37,12 +36,10 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
-from vllm.multimodal.parse import ImageSize, MultiModalDataItems
+from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BatchFeature,
- BoundPromptReplacement,
MultiModalFieldConfig,
- PlaceholderFeaturesInfo,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -53,39 +50,6 @@
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
-logger = init_logger(__name__)
-
-IMAGE_TOKEN_ID = 151329
-
-
-def build_normalization_transform(image_size: int) -> transforms.Compose:
- """
- Build a normalization transform which can be applied to one or
- more input images from which we want to extract visual features.
-
- Args:
- image_size: size of the image to be processed for visual embeddings.
-
- Returns:
- Callable transform for normalizing and resizing one RGB image.
- """
-
- return transforms.Compose([
- transforms.Resize(
- (image_size, image_size),
- interpolation=InterpolationMode.BICUBIC,
- ),
- transforms.ToTensor(),
- transforms.Normalize(
- (0.48145466, 0.4578275, 0.40821073),
- (0.26862954, 0.26130258, 0.27577711),
- ),
- ])
-
-
-def calculate_image_placeholder(vision_config):
- return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2
-
class GLMImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
@@ -109,9 +73,20 @@ def __init__(
self.config = config
self.tokenizer = tokenizer
- if hasattr(self.config, "vision_config"):
- self.image_transform = build_normalization_transform(
- config.vision_config["image_size"])
+ if vision_config := getattr(config, "vision_config", None):
+ image_size = vision_config["image_size"]
+
+ self.image_transform = transforms.Compose([
+ transforms.Resize(
+ (image_size, image_size),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711),
+ ),
+ ])
else:
self.image_transform = None
@@ -150,9 +125,19 @@ def __call__(
class GLM4VProcessingInfo(BaseProcessingInfo):
- def __init__(self, ctx):
- super().__init__(ctx)
- self._pre_calculate()
+ def get_tokenizer(self):
+ tokenizer = self.ctx.tokenizer
+ assert isinstance(tokenizer, PreTrainedTokenizer)
+ return tokenizer
+
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(ChatGLMConfig)
+
+ def get_hf_processor(self) -> GLM4VProcessor:
+ return GLM4VProcessor(
+ self.get_hf_config(),
+ self.get_tokenizer(),
+ )
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
@@ -162,27 +147,21 @@ def get_mm_max_tokens_per_item(
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
+ return {"image": self.get_num_image_feature_tokens()}
- return {"image": self.image_token_num + 2}
-
- def _pre_calculate(self):
+ def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
- vision_config = hf_config.vision_config
- self.image_token_num = calculate_image_placeholder(vision_config)
- self.image_size = vision_config["image_size"]
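+        # Text-only checkpoints have no vision tower; otherwise the EVA2CLIP
+        # tower halves the patch grid in each dimension, e.g. image_size=1120
+        # and patch_size=14 give (1120 // 14 // 2) ** 2 = 1600 image tokens
+        # (sizes shown are illustrative).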
+ if not (vision_config := getattr(hf_config, "vision_config", None)):
+ return 0
- def get_num_image_tokens(self) -> int:
- return self.image_token_num + 2
+ image_size = vision_config["image_size"]
+ patch_size = vision_config["patch_size"]
+ grid_length = image_size // patch_size // 2
+ return grid_length * grid_length
- def get_image_size(self) -> ImageSize:
-
- return ImageSize(height=self.image_size, width=self.image_size)
-
- def get_hf_processor(self) -> GLM4VProcessor:
- return GLM4VProcessor(
- self.get_hf_config(),
- self.get_tokenizer(),
- )
+ def get_num_image_feature_tokens(self) -> int:
+ # EVA2CLIPModel has embeddings for boi and eoi tokens as well
+ return self.get_num_image_tokens() + 2
class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
@@ -192,8 +171,12 @@ def get_dummy_processor_inputs(
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
+ hf_config = self.info.get_hf_config()
+ if not (vision_config := getattr(hf_config, "vision_config", None)):
+ return ProcessorInputs(prompt_text="", mm_data={})
+
+ target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0)
- target_width, target_height = self.info.get_image_size()
mm_data = {
"image":
@@ -201,9 +184,11 @@ def get_dummy_processor_inputs(
height=target_height,
num_images=num_images)
}
- text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
+
+ base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
+
return ProcessorInputs(
- prompt_text=text,
+ prompt_text=base_text * num_images,
mm_data=mm_data,
)
@@ -223,47 +208,28 @@ def _get_prompt_replacements(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
+ hf_config = self.info.get_hf_config()
+ if not hasattr(hf_config, "vision_config"):
+ return []
+
+ boi_token_id = hf_config.boi_token_id
+ image_token_id = hf_config.pad_token_id
+ eoi_token_id = hf_config.eoi_token_id
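+        # Each image appears in the tokenized prompt as a single boi/pad/eoi
+        # triplet; expand the lone pad token to one pad token per image
+        # feature.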
def get_replacement(item_idx: int):
- image_tokens = self.info.image_token_num
- return [IMAGE_TOKEN_ID] * image_tokens
+ num_image_tokens = self.info.get_num_image_tokens()
+ image_tokens = [image_token_id] * num_image_tokens
+
+ return [boi_token_id] + image_tokens + [eoi_token_id]
return [
PromptReplacement(
modality="image",
- target=[IMAGE_TOKEN_ID],
+ target=[boi_token_id, image_token_id, eoi_token_id],
replacement=get_replacement,
),
]
- def _apply_prompt_replacements(
- self,
- token_ids: list[int],
- mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
- mm_item_counts: Mapping[str, int],
- ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
- token_ids, text, placeholders = super()._apply_prompt_replacements(
- token_ids=token_ids,
- mm_prompt_repls=mm_prompt_repls,
- mm_item_counts=mm_item_counts,
- )
- hf_config = self.info.get_hf_config()
- boi_token_id = hf_config.boi_token_id
- eoi_token_id = hf_config.eoi_token_id
- placeholders = {
- modality: [
- PlaceholderFeaturesInfo(
- modality=p.modality,
- item_idx=p.item_idx,
- start_idx=p.start_idx - 1,
- tokens=[boi_token_id] + p.tokens + [eoi_token_id],
- ) for p in ps
- ]
- for modality, ps in placeholders.items()
- }
-
- return token_ids, text, placeholders
-
class GLMAttention(nn.Module):
@@ -618,7 +584,7 @@ def get_input_embeddings(
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=[
self.config.boi_token_id,
- IMAGE_TOKEN_ID,
+ self.config.pad_token_id,
self.config.eoi_token_id,
],
)
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 897066124314..4b8aeaddbdd3 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -63,18 +63,6 @@
logger = init_logger(__name__)
-# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad;
-# for the time being, these tags are not considered as special at encoding
-# time. This may change as VLLMs multimodal API changes in the future.
-IMG_START = "
"
-IMG_END = ""
-IMG_PAD = ""
-# Image context is fixed at 256 for all images
-MAX_QWEN_IMG_TOKENS = 256
-# Image normalization params
-CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
-CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
-
class QwenImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
@@ -622,25 +610,6 @@ def forward(
return hidden_states
-def build_normalization_transform(image_size: int) -> transforms.Compose:
- """
- Build a normalization transform which can be applied to one or
- more input images from which we want to extract visual features.
-
- Args:
- image_size: size of the image to be processed for visual embeddings.
-
- Returns:
- Callable transform for normalizing and resizing one RGB image.
- """
- return transforms.Compose([
- transforms.Resize((image_size, image_size),
- interpolation=InterpolationMode.BICUBIC),
- transforms.ToTensor(),
- transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
- ])
-
-
@lru_cache(maxsize=1)
def _get_tokenizer_without_image_pad(
tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
@@ -716,16 +685,34 @@ def __init__(
self.config = config
self.tokenizer = tokenizer
- if hasattr(self.config, "visual"):
- self.image_transform = build_normalization_transform(
- config.visual["image_size"])
+ if vision_config := getattr(self.config, "visual", None):
+ image_size = vision_config["image_size"]
+
+ self.image_transform = transforms.Compose([
+ transforms.Resize(
+ (image_size, image_size),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711),
+ ),
+ ])
else:
self.image_transform = None
- special_tokens: dict[str,
- int] = tokenizer.special_tokens # type: ignore
- self.img_start_id = special_tokens[IMG_START]
- self.img_end_id = special_tokens[IMG_END]
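+    # The image tag strings (<img>, </img>, <imgpad>) are defined by the
+    # Qwen-VL tokenizer's remote code, so read them from the tokenizer
+    # instead of hard-coding them here.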
+ @property
+ def image_start_tag(self) -> str:
+ return self.tokenizer.image_start_tag # type: ignore
+
+ @property
+ def image_end_tag(self) -> str:
+ return self.tokenizer.image_end_tag # type: ignore
+
+ @property
+ def image_pad_tag(self) -> str:
+ return self.tokenizer.image_pad_tag # type: ignore
def __call__(
self,
@@ -787,7 +774,14 @@ def get_mm_max_tokens_per_item(
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
- return MAX_QWEN_IMG_TOKENS
+ hf_config = self.get_hf_config()
+ if not (vision_config := getattr(hf_config, "visual", None)):
+ return 0
+
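+        # The visual encoder halves the patch grid in each dimension, so e.g.
+        # image_size=448 and patch_size=14 give (448 // 14 // 2) ** 2 = 256
+        # tokens, matching the fixed constant this replaces.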
+ image_size = vision_config["image_size"]
+ patch_size = vision_config["patch_size"]
+ grid_length = image_size // patch_size // 2
+ return grid_length * grid_length
class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
@@ -798,10 +792,12 @@ def get_dummy_processor_inputs(
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.info.get_hf_config()
- if not hasattr(hf_config, "visual"):
+ if not (vision_config := getattr(hf_config, "visual", None)):
return ProcessorInputs(prompt_text="", mm_data={})
- vision_config = hf_config.visual
+ processor = self.info.get_hf_processor()
+ img_start = processor.image_start_tag
+ img_end = processor.image_end_tag
target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0)
@@ -814,7 +810,7 @@ def get_dummy_processor_inputs(
}
return ProcessorInputs(
- prompt_text="".join(f"Picture {i}: {IMG_START}{IMG_END}\n"
+ prompt_text="".join(f"Picture {i}: {img_start}{img_end}\n"
for i in range(1, num_images + 1)),
mm_data=mm_data,
)
@@ -869,13 +865,18 @@ def _get_prompt_replacements(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
+ hf_config = self.info.get_hf_config()
+ if not hasattr(hf_config, "visual"):
+ return []
+
tokenizer = self.info.get_tokenizer()
special_tokens: dict[str,
int] = tokenizer.special_tokens # type: ignore
- img_start_id = special_tokens[IMG_START]
- img_end_id = special_tokens[IMG_END]
- img_pad_id = special_tokens[IMG_PAD]
+ processor = self.info.get_hf_processor()
+ img_start_id = special_tokens[processor.image_start_tag]
+ img_end_id = special_tokens[processor.image_end_tag]
+ img_pad_id = special_tokens[processor.image_pad_tag]
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [img_pad_id] * num_image_tokens
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 34ae7b8c9469..f2071eaff481 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -885,14 +885,10 @@ def get_num_frames_with_most_features(self, seq_len: int) -> int:
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
- num_frames = min(max(max_total_frames // max(max_videos, 1), 1),
- _MAX_FRAMES_PER_VIDEO)
+ max_frames_per_video = min(max_total_frames // max(max_videos, 1),
+ _MAX_FRAMES_PER_VIDEO)
- # Temporary workaround for https://github.com/huggingface/transformers/issues/35412
- if num_frames > 1 and num_frames % 2 == 1:
- num_frames += 1
-
- return num_frames
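+        # Clamp to the per-video frame cap first, then make sure at least
+        # one frame is always profiled.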
+ return max(max_frames_per_video, 1)
def get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self.get_image_size_with_most_features()