From 9d026e0b0af8db97cc81704b87d04028d5ddca45 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Fri, 20 Sep 2024 04:32:47 -0400 Subject: [PATCH 1/3] Use named tuple in multi-image example Signed-off-by: Alex-Brooks --- ...e_inference_vision_language_multi_image.py | 60 +++++++++++++------ 1 file changed, 43 insertions(+), 17 deletions(-) diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 454872c62837..f8787f40dbbb 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -4,6 +4,7 @@ by the model. """ from argparse import Namespace +from collections import namedtuple from typing import List from transformers import AutoProcessor, AutoTokenizer @@ -18,6 +19,10 @@ "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg", ] +ModelRequestData = namedtuple( + "ModelRequestData", + ["llm", "prompt", "stop_token_ids", "image_data", "chat_template"]) + def load_qwenvl_chat(question: str, image_urls: List[str]): model_name = "Qwen/Qwen-VL-Chat" @@ -48,7 +53,13 @@ def load_qwenvl_chat(question: str, image_urls: List[str]): stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] - return llm, prompt, stop_token_ids, None, chat_template + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=[fetch_image(url) for url in image_urls], + chat_template=chat_template, + ) def load_phi3v(question: str, image_urls: List[str]): @@ -62,7 +73,14 @@ def load_phi3v(question: str, image_urls: List[str]): for i, _ in enumerate(image_urls, start=1)) prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n" stop_token_ids = None - return llm, prompt, stop_token_ids, None, None + + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None, + ) def load_internvl(question: str, image_urls: List[str]): @@ -93,7 +111,13 @@ def load_internvl(question: str, image_urls: List[str]): stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] - return llm, prompt, stop_token_ids, None, None + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None, + ) def load_qwen2_vl(question, image_urls: List[str]): @@ -143,7 +167,13 @@ def load_qwen2_vl(question, image_urls: List[str]): else: image_data, _ = process_vision_info(messages) - return llm, prompt, stop_token_ids, image_data, None + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=image_data, + chat_template=None, + ) model_example_map = { @@ -155,20 +185,17 @@ def load_qwen2_vl(question, image_urls: List[str]): def run_generate(model, question: str, image_urls: List[str]): - llm, prompt, stop_token_ids, image_data, _ = model_example_map[model]( - question, image_urls) - if image_data is None: - image_data = [fetch_image(url) for url in image_urls] + req_data = model_example_map[model](question, image_urls) sampling_params = SamplingParams(temperature=0.0, max_tokens=128, - stop_token_ids=stop_token_ids) + stop_token_ids=req_data.stop_token_ids) - outputs = llm.generate( + outputs = req_data.llm.generate( { - "prompt": prompt, + "prompt": req_data.prompt, "multi_modal_data": { - "image": image_data + "image": req_data.image_data }, }, sampling_params=sampling_params) @@ -179,13 +206,12 @@ def run_generate(model, question: str, image_urls: List[str]): def run_chat(model: str, question: str, image_urls: List[str]): - llm, _, stop_token_ids, _, chat_template = model_example_map[model]( - question, image_urls) + req_data = model_example_map[model](question, image_urls) sampling_params = SamplingParams(temperature=0.0, max_tokens=128, - stop_token_ids=stop_token_ids) - outputs = llm.chat( + stop_token_ids=req_data.stop_token_ids) + outputs = req_data.llm.chat( [{ "role": "user", @@ -203,7 +229,7 @@ def run_chat(model: str, question: str, image_urls: List[str]): ], }], sampling_params=sampling_params, - chat_template=chat_template, + chat_template=req_data.chat_template, ) for o in outputs: From 6db301fc9f96e2366f91604b73942302bbed3f58 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 22 Sep 2024 02:45:22 -0400 Subject: [PATCH 2/3] Add load model return type annotations Signed-off-by: Alex-Brooks --- examples/offline_inference_vision_language_multi_image.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index f8787f40dbbb..884cea4af09a 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -24,7 +24,7 @@ ["llm", "prompt", "stop_token_ids", "image_data", "chat_template"]) -def load_qwenvl_chat(question: str, image_urls: List[str]): +def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData: model_name = "Qwen/Qwen-VL-Chat" llm = LLM( model=model_name, @@ -62,7 +62,7 @@ def load_qwenvl_chat(question: str, image_urls: List[str]): ) -def load_phi3v(question: str, image_urls: List[str]): +def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: llm = LLM( model="microsoft/Phi-3.5-vision-instruct", trust_remote_code=True, @@ -83,7 +83,7 @@ def load_phi3v(question: str, image_urls: List[str]): ) -def load_internvl(question: str, image_urls: List[str]): +def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData: model_name = "OpenGVLab/InternVL2-2B" llm = LLM( @@ -120,7 +120,7 @@ def load_internvl(question: str, image_urls: List[str]): ) -def load_qwen2_vl(question, image_urls: List[str]): +def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData: try: from qwen_vl_utils import process_vision_info except ModuleNotFoundError: From 2bed8912edf27db05aa1a0447350fd6cec5f4c79 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 22 Sep 2024 07:01:28 -0400 Subject: [PATCH 3/3] Use typing NamedTuple instead of collections Signed-off-by: Alex-Brooks --- ...ffline_inference_vision_language_multi_image.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 884cea4af09a..92ab4f42baa8 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -4,9 +4,9 @@ by the model. """ from argparse import Namespace -from collections import namedtuple -from typing import List +from typing import List, NamedTuple, Optional +from PIL.Image import Image from transformers import AutoProcessor, AutoTokenizer from vllm import LLM, SamplingParams @@ -19,9 +19,13 @@ "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg", ] -ModelRequestData = namedtuple( - "ModelRequestData", - ["llm", "prompt", "stop_token_ids", "image_data", "chat_template"]) + +class ModelRequestData(NamedTuple): + llm: LLM + prompt: str + stop_token_ids: Optional[List[str]] + image_data: List[Image] + chat_template: Optional[str] def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData: