Description
Your current environment
I encountered a few issues while running Phi-3-vision with vLLM built from the current main branch.
Dependency: torchvision
torchvision is a dependency of image_processing_phi3_v.py. Currently it is only included in requirements-test.txt, not requirements-common.txt, but importing the image processor also requires torchvision to be available at import time.
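A quick way to confirm this in a serving environment is a standalone import check (a hypothetical snippet for illustration; in practice the failure surfaces when vLLM imports the Phi-3-vision image processor):

```python
import importlib.util

# Hypothetical check: image_processing_phi3_v.py needs torchvision at import
# time, but only requirements-test.txt currently pulls it in.
if importlib.util.find_spec("torchvision") is None:
    raise ImportError(
        "torchvision is required by image_processing_phi3_v.py but is not "
        "listed in requirements-common.txt")
```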
Different Image Size Support
I built a vLLM Docker image based on the latest main branch, and I use the following script (send_phi3v_request.py) to resize stop_sign.jpg before sending requests to the vLLM API server:
```python
import base64
import requests
import time
import random
import numpy as np  # To calculate the mean
import json
import os
from PIL import Image

# Parameters
num_iterations = 20  # Number of times to repeat the request

# To store latencies for each iteration
latencies = []
ttfts = []
output_processing_times = []
output_throughputs = []

# Function to encode the image
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

def get_image_size(image_path):
    with Image.open(image_path) as img:
        # Get the dimensions
        width, height = img.size
        return width, height

def resize_image(image_path, sizes, output_dir):
    # Open the image
    with Image.open(image_path) as img:
        # Iterate over the desired sizes
        for size in sizes:
            # Resize the image
            width, height = img.size
            print(f"Image original size: {width, height}")
            resized_img = img.resize(size)
            # Create a file name for the resized image
            base_name = os.path.basename(image_path)
            name, ext = os.path.splitext(base_name)
            resized_img_name = f"{name}_{size[0]}x{size[1]}{ext}"
            resized_img_path = os.path.join(output_dir, resized_img_name)
            # Save the resized image
            resized_img.save(resized_img_path)
            print(f"Saved resized image: {resized_img_path}")

# Path to your image
image_path = "/home/changsu/images/stop_sign.jpg"
sizes = [(256, 256), (384, 384), (512, 512), (1024, 1024)]
output_dir = 'resized_images'
os.makedirs(output_dir, exist_ok=True)
resize_image(image_path, sizes, output_dir)

# Collect the resized images
image_files = []
directory_path = "/home/changsu/resized_images"
for root, dirs, filenames in os.walk(directory_path):
    for filename in filenames:
        image_files.append(os.path.join(root, filename))

# Loop for the specified number of iterations
for _ in range(num_iterations):
    for image_idx, image in enumerate(image_files):
        # Getting the base64 string
        base64_image = encode_image(image)
        payload = {
            "model": "/models/Phi-3-vision-128k-instruct",
            # "model": "/models/llava-v1.6-mistral-7b-hf",
            # "model": "/models/llava-v1.6-34b-hf",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "What's the content of the image?"
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        }
                    ]
                }
            ],
            "max_tokens": 300,
            "temperature": 0.0,
            "stream": True,
        }

        # Metrics
        ttft = 0
        total_request_time = 0
        tokens_received = 0
        time_to_next_token = []
        generated_text = ""
        output_throughput = 0
        output_start_time = 0
        output_processing_time = 0

        # Start the timer
        start_time = time.monotonic()
        # Send the POST request
        response = requests.post(
            "http://localhost:8000/v1/chat/completions",
            # "http://localhost:9922/v1/chat/completions",
            headers={"Content-Type": "application/json"},
            json=payload,
            stream=True,
        )
```
Server Launch Command
```bash
docker run -tid --gpus "device=0" --shm-size 5g \
    -p 8086:8000 -v /mnt/data/models:/models \
    --ulimit nofile=65535:65535 \
    -v $(pwd)/entrypoint.sh:/entrypoint.sh \
    --entrypoint /entrypoint.sh \
    --name vllm-main-phi3-v-1gpu-p8086 \
    vllm:main-phi3-v \
    --tensor-parallel-size=1 \
    --model=/models/Phi-3-vision-128k-instruct \
    --image-input-type="pixel_values" \
    --image-feature-size=1921 \
    --image-token-id=32044 \
    --image-input-shape="1, 3, 1008, 1344" \
    --trust-remote-code  # Needed for Phi-3-vision only, as its config must be loaded from the HF model files
```
🐛 Describe the bug
Running the script raises the following error:
```text
Traceback (most recent call last):
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 43, in _log_task_completion
    return_value = task.result()
                   ^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 550, in run_engine_loop
    has_requests_in_progress = await self.engine_step()
                               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 523, in engine_step
    request_outputs = await self.engine.step_async()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 236, in step_async
    output = await self.model_executor.execute_model_async(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/executor/gpu_executor.py", line 117, in execute_model_async
    output = await make_async(self.driver_worker.execute_model
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/worker/worker.py", line 281, in execute_model
    output = self.model_runner.execute_model(seq_group_metadata_list,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/worker/model_runner.py", line 749, in execute_model
    hidden_states = model_executable(
                    ^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/model_executor/models/phi3v.py", line 315, in forward
    inputs_embeds = self.vision_embed_tokens(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/vllm/lib/python3.11/site-packages/vllm/model_executor/models/phi3v.py", line 253, in forward
    hidden_states[positions[idx, 0],
RuntimeError: The expanded size of the tensor (1937) must match the existing size (2509) at non-singleton dimension 0. Target sizes: [1937, 3072]. Tensor sizes: [2509, 3072]
```
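The two sizes in the error line up with Phi-3-vision's image-token count. As a sanity check, here is the token-count formula paraphrased from image_processing_phi3_v.py (my paraphrase, so treat it as an approximation of the processor's logic), evaluated for the configured input shape and for a larger padded image:

```python
# Paraphrased from the num_img_tokens computation in image_processing_phi3_v.py:
# the HD transform pads the image to multiples of 336 and splits it into crops.
def num_img_tokens(padded_h: int, padded_w: int) -> int:
    h_crop, w_crop = padded_h // 336, padded_w // 336
    return (h_crop * w_crop + 1) * 144 + 1 + (h_crop + 1) * 12

print(num_img_tokens(1008, 1344))  # 1921: matches --image-feature-size=1921
print(num_img_tokens(1344, 1344))  # 2509: matches the "existing size" in the error,
                                   # plausibly the 1024x1024 test image padded up
```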
Debugging Progress
Upon checking, I didn't see any resizing applied to images before they are fed to Phi-3-vision; see https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/phi3v.py#L271. Is this intended? The mismatch above suggests the number of placeholder positions reserved for image features (per --image-feature-size) doesn't match the number of embeddings actually produced for an unresized image.
I saw that LLaVA-NeXT does resize its inputs. By doing something similar to https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py#L94-L104 for Phi-3-vision, I was able to run my test script successfully.
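For illustration, a rough sketch of what such a resize shim could look like for Phi-3-vision. This mirrors the idea in the LLaVA-NeXT lines linked above but is my own hypothetical helper, not the actual vLLM code; `resize_pixel_values` and its call site are assumptions:

```python
import torch
import torchvision.transforms.functional as TF

def resize_pixel_values(pixel_values: torch.Tensor,
                        target_h: int, target_w: int) -> torch.Tensor:
    """Hypothetical helper: force incoming image tensors to the static
    (height, width) the server was launched with (--image-input-shape),
    similar to the resize LLaVA-NeXT applies to its image input."""
    *_, h, w = pixel_values.shape  # (..., C, H, W)
    if (h, w) != (target_h, target_w):
        pixel_values = TF.resize(pixel_values, [target_h, target_w])
    return pixel_values

# e.g., invoked before the vision tower with the configured shape:
# pixel_values = resize_pixel_values(pixel_values, 1008, 1344)
```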