[Bug]: Different Image Size support with Phi-3-Vision and torchvision dependency #5767


Description

@CatherineSue

Your current environment

I encountered a few issues while running Phi-3-vision with vLLM built from the current main branch.

  1. Dependency:
    torchvision is a dependency of image_processing_phi3_v.py.
    Currently it is only listed in requirements-test.txt, not requirements-common.txt, yet importing the image processor requires torchvision to be available at import time, so it should probably be moved to requirements-common.txt.

  2. Different Image Size Support

I built a vLLM Docker image based on the latest main branch.
I use the following script to resize stop_sign.jpg to several sizes before sending requests to the vLLM API server.

python send_phi3v_request.py
import base64
import requests
import time
import random
import numpy as np  # To calculate the mean
import json
import os
from PIL import Image

# Parameters
num_iterations = 20  # Number of times to repeat the request

# To store latencies for each iteration
latencies = []
ttfts = []
output_processing_times = []
output_throughputs = []

# Function to encode the image
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def get_image_size(image_path):
    with Image.open(image_path) as img:
        # Get the dimensions
        width, height = img.size
    return width, height


def resize_image(image_path, sizes, output_dir):
    # Open the image
    with Image.open(image_path) as img:
        # Iterate over the desired sizes
        for size in sizes:
            # Resize the image
            width, height = img.size
            print(f"Image original size: {width, height}")
            resized_img = img.resize(size)
            # Create a file name for the resized image
            base_name = os.path.basename(image_path)
            name, ext = os.path.splitext(base_name)
            resized_img_name = f"{name}_{size[0]}x{size[1]}{ext}"
            resized_img_path = os.path.join(output_dir, resized_img_name)
            # Save the resized image
            resized_img.save(resized_img_path)
            print(f"Saved resized image: {resized_img_path}")


# Path to your image
image_path = "/home/changsu/images/stop_sign.jpg"

sizes = [(256, 256), (384, 384), (512, 512), (1024, 1024)]
output_dir = 'resized_images'
os.makedirs(output_dir, exist_ok=True)
resize_image(image_path, sizes, output_dir)

image_files = []
directory_path = "/home/changsu/resized_images"
for root, dirs, filenames in os.walk(directory_path):
    for filename in filenames:
        image_files.append(os.path.join(root, filename))
        
        
# Loop for the specified number of iterations
for _ in range(num_iterations):
    # Send one request per resized image
    for image_idx, image in enumerate(image_files):
        # Getting the base64 string
        base64_image = encode_image(image)

        payload = {
            "model": "/models/Phi-3-vision-128k-instruct",
            # "model": "/models/llava-v1.6-mistral-7b-hf",
            # "model": "/models/llava-v1.6-34b-hf",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "What's the content of the image?"
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        }
                    ]
                }
            ],
            "max_tokens": 300,
            "temperature": 0.0,
            "stream": True,
        }

        # metrics
        ttft = 0
        total_request_time = 0
        tokens_received = 0
        time_to_next_token = []
        generated_text = ""
        output_throughput = 0
        output_start_time = 0
        output_processing_time = 0

        # Start the timer
        start_time = time.monotonic()

        # Send the POST request
        response = requests.post(
            "http://localhost:8000/v1/chat/completions",
            # "http://localhost:9922/v1/chat/completions",
            headers={"Content-Type": "application/json"},
            json=payload,
            stream=True,
        )
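
The rest of the script (consuming the streamed response and recording the metrics) is cut off above. A minimal sketch of what that loop might look like, assuming vLLM's OpenAI-compatible server emits `data: {...}` SSE chunks terminated by `data: [DONE]` (this part is my reconstruction, not the original script):

        # Consume the SSE stream: record time-to-first-token and total latency.
        first_token_time = None
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            chunk = raw_line.decode("utf-8")
            if not chunk.startswith("data: "):
                continue
            data = chunk[len("data: "):]
            if data == "[DONE]":
                break
            delta = json.loads(data)["choices"][0]["delta"]
            token = delta.get("content") or ""
            if token and first_token_time is None:
                first_token_time = time.monotonic()
                ttft = first_token_time - start_time
                ttfts.append(ttft)
            generated_text += token
            tokens_received += 1  # counts stream chunks as tokens

        total_request_time = time.monotonic() - start_time
        latencies.append(total_request_time)
        if first_token_time is not None and tokens_received > 0:
            output_processing_time = total_request_time - ttft
            output_processing_times.append(output_processing_time)
            output_throughput = tokens_received / max(output_processing_time, 1e-9)
            output_throughputs.append(output_throughput)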

Server Launch Command

docker run -tid --gpus \"device=0\" --shm-size 5g \
        -p 8086:8000 -v /mnt/data/models:/models \
        --ulimit nofile=65535:65535 \
        -v $(pwd)/entrypoint.sh:/entrypoint.sh \
        --entrypoint /entrypoint.sh \
        --name vllm-main-phi3-v-1gpu-p8086 \
        vllm:main-phi3-v \
        --tensor-parallel-size=1 \
        --model=/models/Phi-3-vision-128k-instruct \
        --image-input-type="pixel_values" \
        --image-feature-size=1921 \
        --image-token-id=32044 \
        --image-input-shape="1, 3, 1008, 1344" \
        --trust-remote-code # Needed only for Phi-3-vision, as its config must be loaded from the HF model files
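
As a quick sanity check that the server is reachable before running the script (note that the container maps host port 8086 to the server's port 8000, while the script above posts to localhost:8000, so the port may need adjusting depending on where the script runs):

curl http://localhost:8086/v1/models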

🐛 Describe the bug

Running the script against the server raises the following error:

Traceback (most recent call last):
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 43, in _log_task_completion
    return_value = task.result()
                   ^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 550, in run_engine_loop
    has_requests_in_progress = await self.engine_step()
                               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 523, in engine_step
    request_outputs = await self.engine.step_async()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 236, in step_async
    output = await self.model_executor.execute_model_async(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/executor/gpu_executor.py", line 117, in execute_model_async
    output = await make_async(self.driver_worker.execute_model
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/worker/worker.py", line 281, in execute_model
    output = self.model_runner.execute_model(seq_group_metadata_list,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/worker/model_runner.py", line 749, in execute_model
    hidden_states = model_executable(
                    ^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/model_executor/models/phi3v.py", line 315, in forward
    inputs_embeds = self.vision_embed_tokens(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/model_executor/models/phi3v.py", line 253, in forward
    hidden_states[positions[idx, 0],
RuntimeError: The expanded size of the tensor (1937) must match the existing size (2509) at non-singleton dimension 0.  Target sizes: [1937, 3072].  Tensor sizes: [2509, 3072]

Debugging Progress

Upon checking, I didn't see any resizing applied to images input to Phi-3-vision; see https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/phi3v.py#L271. Is this intended? Since the number of image feature tokens produced by the vision encoder depends on the input resolution, skipping the resize lets the embedding count (2509 here) diverge from the preallocated placeholder length (1937), which is exactly the mismatch in the traceback above.
Llava-Next does resize its inputs, and by doing something similar to https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py#L94-L104, I was able to run my test script successfully.
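
For reference, a minimal sketch of the kind of resize I applied, modeled on the Llava-Next lines linked above (the helper name and the bilinear mode are my own choices, not the actual vLLM code; the target matches the --image-input-shape from the launch command):

import torch
import torch.nn.functional as F

def _resize_pixel_values(pixel_values: torch.Tensor,
                         target_hw: tuple = (1008, 1344)) -> torch.Tensor:
    # pixel_values: (batch, channels, height, width). Resize to the
    # configured --image-input-shape so the vision encoder emits the
    # expected number of feature tokens.
    if tuple(pixel_values.shape[-2:]) != target_hw:
        pixel_values = F.interpolate(pixel_values, size=target_hw,
                                     mode="bilinear", align_corners=False)
    return pixel_values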
