1 change: 1 addition & 0 deletions docs/models/supported_models.md
@@ -829,6 +829,7 @@ The following table lists those that are tested in vLLM.

| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------|
| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | | ✅︎ |
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ | ✅︎ |
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ | ✅︎ |
| `*ForConditionalGeneration`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | \* | N/A | \* | \* | \* |
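For orientation, a minimal offline sketch of the CLIP pooling path added to this table, assuming vLLM's `LLM.embed` API together with the `runner="pooling"` option used elsewhere in this change; the image path is hypothetical and this is an illustrative sketch, not a tested snippet.

# Sketch only: text and image embedding with the CLIP entry above.
# Assumes LLM accepts runner="pooling" and dict prompts with
# "multi_modal_data", matching run_clip and test_clip.py in this change.
from PIL import Image
from vllm import LLM

llm = LLM(model="openai/clip-vit-base-patch32",
          runner="pooling",
          max_model_len=77)

text_out = llm.embed("a photo of a cat")
print(len(text_out[0].outputs.embedding))

image_out = llm.embed({
    "prompt": "",  # empty prompt text for image-only input
    "multi_modal_data": {"image": Image.open("cat.jpg")},  # hypothetical path
})
print(len(image_out[0].outputs.embedding))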
34 changes: 31 additions & 3 deletions examples/offline_inference/vision_language_pooling.py
@@ -58,6 +58,30 @@ class ModelRequestData(NamedTuple):
documents: Optional[ScoreMultiModalParam] = None


def run_clip(query: Query) -> ModelRequestData:
if query["modality"] == "text":
prompt = query["text"]
image = None
elif query["modality"] == "image":
prompt = "" # For image input, make sure that the prompt text is empty
image = query["image"]
else:
modality = query["modality"]
raise ValueError(f"Unsupported query modality: '{modality}'")

engine_args = EngineArgs(
model="openai/clip-vit-base-patch32",
runner="pooling",
limit_mm_per_prompt={"image": 1},
)

return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image=image,
)


def run_e5_v(query: Query) -> ModelRequestData:
llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" # noqa: E501

@@ -146,7 +170,8 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData:

processor = AutoProcessor.from_pretrained(
model_id,
# `min_pixels` and `max_pixels` are deprecated
        # `min_pixels` and `max_pixels` are deprecated in the transformers
        # `preprocessor_config.json`, so `size` is used instead
size={"shortest_edge": 3136, "longest_edge": 12845056},
)
processor.chat_template = load_chat_template(
@@ -172,8 +197,10 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData:
model=merged_path,
runner="pooling",
max_model_len=4096,
trust_remote_code=True,
mm_processor_kwargs={"num_crops": 4},
mm_processor_kwargs={
"min_pixels": 3136,
"max_pixels": 12845056,
},
limit_mm_per_prompt={"image": 1},
)

@@ -299,6 +326,7 @@ def run_score(model: str, modality: QueryModality, seed: Optional[int]):


model_example_map = {
"clip": run_clip,
"e5_v": run_e5_v,
"vlm2vec_phi3v": run_vlm2vec_phi3v,
"vlm2vec_qwen2vl": run_vlm2vec_qwen2vl,
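A hypothetical dispatch of the new "clip" entry through `model_example_map` above, assuming the `Query` dict shape consumed by `run_clip`; the script's own CLI presumably performs an equivalent lookup.

# Hypothetical usage inside this example file (keys taken from run_clip).
query = {"modality": "text", "text": "a photo of a cat"}
request_data = model_example_map["clip"](query)
print(request_data.engine_args.model)  # openai/clip-vit-base-patch32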
@@ -1,14 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""Example Python client for multimodal embedding API using vLLM API server
NOTE:
start a supported multimodal embeddings model server with `vllm serve`, e.g.
vllm serve TIGER-Lab/VLM2Vec-Full \
--runner pooling \
--trust-remote-code \
--max-model-len 4096 \
--chat-template examples/template_vlm2vec_phi3v.jinja
"""Example Python client for multimodal embedding API using vLLM API server.

Refer to each `run_*` function for the command to run the server for that model.
"""

import argparse
@@ -47,7 +42,58 @@ def create_chat_embeddings(
)


def run_clip(client: OpenAI, model: str):
"""
Start the server using:

vllm serve openai/clip-vit-base-patch32 \
--runner pooling
"""

response = create_chat_embeddings(
client,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
],
}
],
model=model,
encoding_format="float",
)

print("Image embedding output:", response.data[0].embedding)

response = create_chat_embeddings(
client,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "a photo of a cat"},
],
}
],
model=model,
encoding_format="float",
)

print("Text embedding output:", response.data[0].embedding)


def run_vlm2vec(client: OpenAI, model: str):
"""
Start the server using:

vllm serve TIGER-Lab/VLM2Vec-Full \
--runner pooling \
--trust-remote-code \
--max-model-len 4096 \
--chat-template examples/template_vlm2vec_phi3v.jinja
"""

response = create_chat_embeddings(
client,
messages=[
@@ -103,6 +149,15 @@ def run_vlm2vec(client: OpenAI, model: str):


def run_dse_qwen2_vl(client: OpenAI, model: str):
"""
Start the server using:

vllm serve MrLight/dse-qwen2-2b-mrl-v1 \
--runner pooling \
--trust-remote-code \
--max-model-len 8192 \
--chat-template examples/template_dse_qwen2_vl.jinja
"""
response = create_chat_embeddings(
client,
messages=[
@@ -156,6 +211,7 @@ def run_dse_qwen2_vl(client: OpenAI, model: str):


model_example_map = {
"clip": run_clip,
"vlm2vec": run_vlm2vec,
"dse_qwen2_vl": run_dse_qwen2_vl,
}
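For reference, a sketch of the HTTP request that the `create_chat_embeddings` helper above presumably wraps, assuming the server accepts chat-style messages on the /v1/embeddings route at the default localhost:8000; none of this is taken verbatim from the diff.

# Sketch only: raw HTTP equivalent of create_chat_embeddings, under the
# assumption that /v1/embeddings accepts a "messages" field.
import requests

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "openai/clip-vit-base-patch32",
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": "a photo of a cat"}],
            }
        ],
        "encoding_format": "float",
    },
)
print(response.json()["data"][0]["embedding"][:4])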
138 changes: 138 additions & 0 deletions tests/models/multimodal/pooling/test_clip.py
@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from transformers import CLIPModel

from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ...utils import check_embeddings_close

HF_TEXT_PROMPTS = [
"a photo of a stop sign",
"a photo of a cherry blossom",
]

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign": "",
"cherry_blossom": "",
})

MODELS = ["openai/clip-vit-base-patch32"]


def _run_test(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
input_texts: list[str],
input_images: PromptImageInput,
model: str,
*,
dtype: str,
) -> None:
    # NOTE: Take care of the order: run vLLM first, then HF.
    # vLLM needs a fresh process without CUDA initialization; if HF runs
    # first, CUDA is already initialized, which breaks the multiprocessing
    # backend with the fork start method (the default).
with vllm_runner(model,
runner="pooling",
dtype=dtype,
enforce_eager=True,
max_model_len=77) as vllm_model:
vllm_outputs = vllm_model.embed(input_texts, images=input_images)

with hf_runner(model, dtype=dtype, auto_cls=CLIPModel) as hf_model:
all_inputs = hf_model.get_inputs(input_texts, images=input_images)

all_outputs = []
for inputs in all_inputs:
if "pixel_values" in inputs:
inputs.pop("input_ids")
pooled_output = hf_model.model.get_image_features(
**hf_model.wrap_device(inputs)).squeeze(0)
else:
pooled_output = hf_model.model.get_text_features(
**hf_model.wrap_device(inputs)).squeeze(0)

all_outputs.append(pooled_output.tolist())

hf_outputs = all_outputs

check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_text(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]

_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images, # type: ignore
model,
dtype=dtype,
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_image(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [
(text, asset.pil_image)
for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]

_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images,
model,
dtype=dtype,
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_text_image_no_crash(
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
texts = [HF_TEXT_PROMPTS[0]]
images = [image_assets[0].pil_image]

with vllm_runner(model,
runner="pooling",
dtype=dtype,
enforce_eager=True,
max_model_len=77) as vllm_model:
with pytest.raises(ValueError, match="not both"):
vllm_model.embed(texts, images=images)

# Should still be able to run subsequent requests
vllm_model.embed(texts)
vllm_model.embed([""], images=images)
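The embeddings checked above are plain float lists, so downstream scoring needs nothing vLLM-specific; a small sketch (not part of the test) of text-image similarity with torch:

# Sketch: cosine similarity between a text and an image embedding,
# each a list[float] as returned by vllm_model.embed above.
import torch
import torch.nn.functional as F


def clip_similarity(text_embedding: list[float],
                    image_embedding: list[float]) -> float:
    return F.cosine_similarity(
        torch.tensor(text_embedding),
        torch.tensor(image_embedding),
        dim=0,
    ).item()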
10 changes: 8 additions & 2 deletions tests/models/registry.py
@@ -389,6 +389,7 @@ def check_available_online(
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # noqa: E501
# [Multimodal]
"CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
trust_remote_code=True),
@@ -687,7 +688,11 @@ def get_supported_archs(self) -> Set[str]:
return self.hf_models.keys()

def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
return self.hf_models[model_arch]
try:
return self.hf_models[model_arch]
except KeyError:
raise ValueError(f"No example model defined for {model_arch}; "
f"please update this file.") from None

def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
for info in self.hf_models.values():
@@ -699,7 +704,8 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
if any(extra == model_id for extra in info.extras.values()):
return info

raise ValueError(f"No example model defined for {model_id}")
raise ValueError(f"No example model defined for {model_id}; "
f"please update this file.")


HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
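A quick illustration of the friendlier lookup errors added above, using only the names visible in this hunk; treat it as a sketch rather than part of the test suite.

# Illustrative only; run in a context where HF_EXAMPLE_MODELS is importable.
info = HF_EXAMPLE_MODELS.get_hf_info("CLIPModel")
print(info)  # the _HfExamplesInfo entry for openai/clip-vit-base-patch32

try:
    HF_EXAMPLE_MODELS.get_hf_info("NotARealArch")  # hypothetical architecture
except ValueError as err:
    print(err)  # No example model defined for NotARealArch; please update this file.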
6 changes: 5 additions & 1 deletion vllm/attention/layer.py
@@ -417,12 +417,16 @@ def __init__(
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
):
        # This has no effect; it is only here to make it easier to swap
        # between Attention and MultiHeadAttention.
prefix: str = "",
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix

assert self.num_heads % self.num_kv_heads == 0, \
f"num_heads ({self.num_heads}) is not " \
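A minimal sketch of the new keyword in use, assuming this hunk is vLLM's MultiHeadAttention in vllm/attention/layer.py (as the inline comment suggests); the prefix is stored as layer_name and otherwise unused.

# Sketch: constructing MultiHeadAttention with the new, inert prefix kwarg.
from vllm.attention.layer import MultiHeadAttention

attn = MultiHeadAttention(
    num_heads=16,
    head_size=64,
    scale=64 ** -0.5,
    prefix="vision_model.encoder.layers.0.attn",  # hypothetical layer name
)
assert attn.layer_name == "vision_model.encoder.layers.0.attn"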
2 changes: 1 addition & 1 deletion vllm/model_executor/models/bert.py
@@ -351,7 +351,7 @@ def __init__(
prefix=f"{prefix}.encoder")

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings(input_ids)
return self.embeddings.word_embeddings(input_ids)

def forward(
self,