3 changes: 1 addition & 2 deletions tests/models/multimodal/generation/test_common.py
@@ -215,8 +215,7 @@
"model_impl": "transformers",
},
# FIXME: Investigate mrope issue
marks=[large_gpu_mark(min_gb=32),
pytest.mark.skip(reason="Mrope issue")],
marks=[large_gpu_mark(min_gb=32)],
),
#### Extended model tests
"aria": VLMTestInfo(
86 changes: 39 additions & 47 deletions vllm/model_executor/models/transformers.py
@@ -55,8 +55,7 @@
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsQuant)
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
flatten_bn, make_empty_intermediate_tensors_factory,
maybe_prefix)
make_empty_intermediate_tensors_factory, maybe_prefix)

logger = init_logger(__name__)

@@ -291,19 +290,23 @@ def _get_prompt_updates(

def _get_mm_fields_config(
self,
hf_inputs,
hf_processor_mm_kwargs,
num_image_patches: torch.Tensor = None,
):
hf_inputs: "BatchFeature",
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
# HF Processors always return a mask but vLLM doesn't need it
hf_inputs.pop("attention_mask", None)
num_image_patches = hf_inputs.get("num_image_patches")
mm_fields = {
key: MultiModalFieldConfig.flat_from_sizes("image",
num_image_patches)
for key in hf_inputs
}
mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
"image", num_image_patches)

# Keep `num_patches` as batched, as it always has `bs` as first dim
mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
return mm_fields
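For context on the field config above, here is a minimal sketch (plain torch with hypothetical shapes, not the real `MultiModalFieldConfig` API) of the difference between fields registered with `flat_from_sizes` and fields kept `batched`:

```python
# Sketch: "flat" fields are concatenated along dim 0 and recovered per image
# by splitting with `num_image_patches`; "batched" fields already have one
# entry per image as their first dimension. Shapes below are hypothetical.
import torch

num_image_patches = torch.tensor([3, 5])   # two images in one request
pixel_values = torch.randn(8, 1176)        # 3 + 5 patches, flattened together

# flat_from_sizes-style recovery: per-image slices of the flat tensor
per_image = torch.split(pixel_values, num_image_patches.tolist())
assert [t.shape[0] for t in per_image] == [3, 5]

# batched-style field: first dim is already the number of images
image_grid_thw = torch.tensor([[1, 3, 1], [1, 5, 1]])
assert image_grid_thw.shape[0] == num_image_patches.shape[0]
```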

@@ -318,7 +321,9 @@ def _apply_hf_processor_text_mm(
Apply the HF processor on the prompt text and multi-modal data
together.

In addition, return whether prompt replacements have been applied.
In contrast to the base class, this method always returns
`mm_token_type_ids` from HF processor. Additionally, return whether
prompt replacements have been applied.
"""
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
processor_data["return_mm_token_type_ids"] = True
@@ -330,14 +335,16 @@
tok_kwargs=tokenization_kwargs,
)
processed_data.update(passthrough_data)

prompt_ids, = processed_data.pop("input_ids").tolist()
mm_token_type_ids = processed_data.pop(
"mm_token_type_ids"
) if "mm_token_type_ids" in processed_data else processed_data.pop(
"token_type_ids") # for gemma3 only

return prompt_ids, processed_data, mm_token_type_ids
is_update_applied = self._hf_processor_applies_updates(
prompt_text=prompt_text,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)

return prompt_ids, processed_data, is_update_applied

def apply(
self,
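A tiny illustration of the `prompt_ids, = ...` idiom used in the hunk above: the processor is called with a single prompt, so `input_ids` has batch size 1, and the one-element unpacking both extracts that row and implicitly asserts the assumption (values are hypothetical):

```python
# After .tolist(), input_ids has shape (1, seq_len); unpacking a one-element
# sequence raises ValueError if the batch size is ever not 1.
batched_ids = [[101, 7592, 102]]
prompt_ids, = batched_ids
assert prompt_ids == [101, 7592, 102]
```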
@@ -364,17 +371,21 @@ def apply(
# into string
prompt = hf_processor.decode(prompt)

(prompt_ids, processed_data,
mm_token_type_ids) = self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)

# HF processor will return `mm_token_type_ids` from which
# we can infer mm_placeholders. Until then hardcode to make code run
# Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1
# Bypass cached processor and always apply to the full set of mm inputs
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)

# For gemma3 we check `token_type_ids` as the key
token_type_key = "mm_token_type_ids" if "mm_token_type_ids" \
in processed_data else "token_type_ids"
mm_token_type_ids = processed_data.pop(token_type_key)

# We can infer vLLM style placeholder from token type ids, if we split
# it for each input `mm_data`.
mm_positions = torch.where(mm_token_type_ids == 1)[1]
images = mm_items.get_items("image", ImageProcessorItems)
multimodal_config = self.info.ctx.model_config.multimodal_config
@@ -403,14 +414,11 @@ def apply(
]
mm_placeholders = {"image": ranges}

num_image_patches = torch.tensor(
mm_tokens_per_modality["num_image_patches"]
) if "num_image_patches" in mm_tokens_per_modality else None
processed_data['num_image_patches'] = num_image_patches
processed_data['num_image_patches'] = torch.tensor(
mm_tokens_per_modality["num_image_patches"])
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
processed_data,
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
num_image_patches),
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
)

# Use overrides if provided; fallback to data-dependent hashing.
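As a minimal sketch of the placeholder derivation above (plain torch with made-up values, not vLLM's placeholder-range types): positions where `mm_token_type_ids` is 1 are multimodal tokens, and splitting them by the per-image token counts yields an offset and length per image:

```python
import torch

# bs=1 prompt: text text [img img img] text [img img] text
mm_token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 0, 1, 1, 0]])
tokens_per_image = [3, 2]                       # e.g. reported by the HF processor

mm_positions = torch.where(mm_token_type_ids == 1)[1]
chunks = torch.split(mm_positions, tokens_per_image)
ranges = [(int(c[0]), len(c)) for c in chunks]  # (offset, length) per image
assert ranges == [(2, 3), (6, 2)]
```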
@@ -468,16 +476,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.ignore_unexpected_suffixes.append(".bias")

# Set correct attn and init on "meta" to delay allocating GPU tensors
# TODO: @raushan, use the public `model.set_attn_implementation()`
# method once its checks are fixed in Transformers.
self.text_config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
torch_dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)

self.model.set_attn_implementation({"text_config": "vllm"})
self.pipeline_parallel()
self.tensor_parallel()
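A minimal sketch of the "init on meta" pattern used above (plain torch.nn, not the vLLM/Transformers classes in the diff): constructing the module under the meta device records parameter shapes without allocating real memory, so weights can be materialized later by the weight loader.

```python
import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(4096, 4096)   # hypothetical layer size

assert layer.weight.is_meta         # no real storage allocated yet
# Real weights would be materialized afterwards, e.g. via
# layer.load_state_dict(state_dict, assign=True).
```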

@@ -755,17 +761,6 @@ def compute_logits(
return logits


def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
"""Flatten until a list of tensors can be concatenated then do concat"""

def _can_concat(x: list[torch.Tensor]):
return len(set(map(lambda _x: _x.shape[1:], x))) == 1

if _can_concat(x):
return torch.concat(x)
return flatten_and_concat(flatten_bn(x))


@MULTIMODAL_REGISTRY.register_processor(
MultiModalProcessor,
info=MultiModalProcessingInfo,
@@ -843,9 +838,6 @@ def get_multimodal_embeddings(self, **kwargs):
pixel_values, **kwargs)

if isinstance(vision_embeddings, torch.Tensor):
if isinstance(num_image_patches, list):
num_image_patches = torch.cat(num_image_patches)

if vision_embeddings.ndim == 2:
vision_embeddings = vision_embeddings.unsqueeze(0)
