From 515037b0bbeb1e72da6b33fde21f8812cd7b8309 Mon Sep 17 00:00:00 2001
From: Hannah Zhang
Date: Mon, 30 Sep 2024 15:21:52 -0700
Subject: [PATCH 1/5] fix: bug fixes for phi3v and ultravox image embedding support

---
 vllm/model_executor/models/phi3v.py    |  7 +---
 vllm/model_executor/models/ultravox.py | 54 +++++++++++++++-----------
 2 files changed, 34 insertions(+), 27 deletions(-)

diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 245381518a7f..859257b07f10 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -439,7 +439,7 @@ def input_processor_for_phi3v(ctx: InputContext,
     elif isinstance(image_data, torch.Tensor):
         num_images, image_feature_size, hidden_size = image_data.shape
     elif is_list_of(image_data, torch.Tensor):
-        image_feature_size = [item.shape[1] for item in image_data]
+        image_feature_size = [item.shape[0] for item in image_data]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
 
@@ -577,9 +577,6 @@ def _parse_and_validate_image_input(
         image_sizes = kwargs.pop("image_sizes", None)
         image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
-            return None
-
         if pixel_values is None and image_embeds is None:
             return None
 
@@ -616,7 +613,7 @@ def _process_image_input(
     ) -> torch.Tensor:
 
         if image_input["type"] == "image_embeds":
-            return image_input["data"]
+            return list(torch.unbind(image_input["data"], dim=0))
 
         assert self.vision_embed_tokens is not None
         image_embeds = self.vision_embed_tokens(image_input["data"],
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 71808eb4c271..c5f95fa11638 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -119,10 +119,13 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
         data = [data]
 
     audio_features = []
+    audio_embeds = []
+    is_audio_embeds = False
     for audio_input in data:
         if not isinstance(audio_input, tuple):
-            raise NotImplementedError(
-                f"Unsupported data type: {type(audio_input)}")
+            is_audio_embeds = True
+            audio_embeds.append(audio_input)
+            continue
 
         (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
         feature_extractor = whisper_feature_extractor(ctx)
@@ -150,7 +153,9 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
         # Remove the batch dimension because we're wrapping it in a list.
         audio_features.append(single_audio_features.squeeze(0))
 
-    return MultiModalInputs({"audio_features": audio_features})
+    return (MultiModalInputs({"audio_embeds": audio_embeds})
+            if is_audio_embeds
+            else MultiModalInputs({"audio_features": audio_features}))
 
 
 def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
@@ -164,25 +169,30 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
         audios = [audios]
 
     audio_token_counts = []
-    for audio_data, sample_rate in audios:
-        audio_length = audio_data.shape[0]
-        if sample_rate != feature_extractor.sampling_rate:
-            # Account for resampling.
-            adjustment = feature_extractor.sampling_rate / sample_rate
-            audio_length = math.ceil(adjustment * audio_length)
-
-        feature_extractor_output_length = math.ceil(
-            (audio_length - (feature_extractor.hop_length - 1)) /
-            feature_extractor.hop_length)
-
-        uv_config = ctx.get_hf_config(UltravoxConfig)
-        audio_num_tokens = min(
-            max(
-                1,
-                math.ceil(feature_extractor_output_length /
-                          (uv_config.stack_factor * 2))),
-            get_ultravox_max_audio_tokens(ctx))
-        audio_token_counts.append(audio_num_tokens)
+    for audio in audios:
+        if isinstance(audio, torch.Tensor):
+            audio_num_tokens = audio.shape[1]
+            audio_token_counts.append(audio_num_tokens)
+        else:
+            audio_data, sample_rate = audio
+            audio_length = audio_data.shape[0]
+            if sample_rate != feature_extractor.sampling_rate:
+                # Account for resampling.
+                adjustment = feature_extractor.sampling_rate / sample_rate
+                audio_length = math.ceil(adjustment * audio_length)
+
+            feature_extractor_output_length = math.ceil(
+                (audio_length - (feature_extractor.hop_length - 1)) /
+                feature_extractor.hop_length)
+
+            uv_config = ctx.get_hf_config(UltravoxConfig)
+            audio_num_tokens = min(
+                max(
+                    1,
+                    math.ceil(feature_extractor_output_length /
+                          (uv_config.stack_factor * 2))),
+                get_ultravox_max_audio_tokens(ctx))
+            audio_token_counts.append(audio_num_tokens)
 
     tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
 

From c5c5c023ad9b0bc0fdb4fe074cc86aef3c6649cc Mon Sep 17 00:00:00 2001
From: Hannah Zhang
Date: Mon, 30 Sep 2024 15:27:30 -0700
Subject: [PATCH 2/5] lint: fixing lint for ultravox

---
 vllm/model_executor/models/ultravox.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index c5f95fa11638..4a77103f0d7e 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -153,8 +153,7 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
         # Remove the batch dimension because we're wrapping it in a list.
         audio_features.append(single_audio_features.squeeze(0))
 
-    return (MultiModalInputs({"audio_embeds": audio_embeds})
-            if is_audio_embeds
+    return (MultiModalInputs({"audio_embeds": audio_embeds}) if is_audio_embeds
             else MultiModalInputs({"audio_features": audio_features}))
 
 
@@ -190,7 +189,7 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
                 max(
                     1,
                     math.ceil(feature_extractor_output_length /
-                          (uv_config.stack_factor * 2))),
+                              (uv_config.stack_factor * 2))),
                 get_ultravox_max_audio_tokens(ctx))
             audio_token_counts.append(audio_num_tokens)
 

From 9a87aa66d959f65f2fbe7f1a2d6a457bbccb6f84 Mon Sep 17 00:00:00 2001
From: Hannah Zhang
Date: Mon, 30 Sep 2024 15:56:13 -0700
Subject: [PATCH 3/5] feat: simplify ultravox input mapper

---
 vllm/model_executor/models/ultravox.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 4a77103f0d7e..b97d4d91772b 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -39,6 +39,7 @@
                                     repeat_and_pad_placeholder_tokens)
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
+from vllm.utils import is_list_of
 
 _AUDIO_PLACEHOLDER_TOKEN = 128002
 _AUDIO_TOKENS_PER_SECOND = 6.25
@@ -118,14 +119,15 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
     if not isinstance(data, list):
         data = [data]
 
+    # If the audio inputs are embeddings, no need for preprocessing
+    if is_list_of(data, torch.Tensor, check="all"):
+        return MultiModalInputs({"audio_embeds": data})
+
     audio_features = []
-    audio_embeds = []
-    is_audio_embeds = False
     for audio_input in data:
         if not isinstance(audio_input, tuple):
-            is_audio_embeds = True
-            audio_embeds.append(audio_input)
-            continue
+            raise NotImplementedError(
+                f"Unsupported data type: {type(audio_input)}")
 
         (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
         feature_extractor = whisper_feature_extractor(ctx)
@@ -153,8 +155,7 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
         # Remove the batch dimension because we're wrapping it in a list.
         audio_features.append(single_audio_features.squeeze(0))
 
-    return (MultiModalInputs({"audio_embeds": audio_embeds}) if is_audio_embeds
-            else MultiModalInputs({"audio_features": audio_features}))
+    return MultiModalInputs({"audio_features": audio_features})
 
 
 def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):

From 1918192e38148472f5f9b18274b34c6cecbdbd1d Mon Sep 17 00:00:00 2001
From: Hannah Zhang
Date: Tue, 1 Oct 2024 10:07:59 -0700
Subject: [PATCH 4/5] fix: fixes for single image embeds input

---
 vllm/model_executor/models/phi3v.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 859257b07f10..b1bb0ba879e4 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -437,7 +437,8 @@ def input_processor_for_phi3v(ctx: InputContext,
                                              input_height=h,
                                              num_crops=num_crops))
     elif isinstance(image_data, torch.Tensor):
-        num_images, image_feature_size, hidden_size = image_data.shape
+        image_feature_size = [image_data.shape[0]]
+        image_data = [image_data]
     elif is_list_of(image_data, torch.Tensor):
         image_feature_size = [item.shape[0] for item in image_data]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
@@ -613,7 +614,15 @@ def _process_image_input(
     ) -> torch.Tensor:
 
         if image_input["type"] == "image_embeds":
-            return list(torch.unbind(image_input["data"], dim=0))
+            image_data = image_input["data"]
+            if is_list_of(image_data, torch.Tensor):
+                # it's already a list of tensors
+                return image_data
+            if len(image_data.shape) == 2:
+                # 2D tensor
+                return image_data
+            # 3D tensor
+            return list(torch.unbind(image_data, dim=0))
 
         assert self.vision_embed_tokens is not None
         image_embeds = self.vision_embed_tokens(image_input["data"],

From f0a04ba4a67f928883222ec35cefd371e086b274 Mon Sep 17 00:00:00 2001
From: Hannah Zhang
Date: Thu, 3 Oct 2024 10:57:48 -0700
Subject: [PATCH 5/5] feat: raise error if phi3v image embeds are not batched 2D tensors

---
 vllm/model_executor/models/phi3v.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index b1bb0ba879e4..f0f065e221ad 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -618,11 +618,13 @@ def _process_image_input(
         if image_input["type"] == "image_embeds":
             image_data = image_input["data"]
             if is_list_of(image_data, torch.Tensor):
                 # it's already a list of tensors
                 return image_data
-            if len(image_data.shape) == 2:
-                # 2D tensor
-                return image_data
-            # 3D tensor
-            return list(torch.unbind(image_data, dim=0))
+            if len(image_data.shape) == 3:
+                # 3D tensor
+                return list(torch.unbind(image_data, dim=0))
+            raise ValueError(
+                "We expect batched 2D tensors; "
+                "this can be either a list of 2D tensors or a single 3D tensor."
+            )
 
         assert self.vision_embed_tokens is not None
         image_embeds = self.vision_embed_tokens(image_input["data"],
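
For reference, a minimal usage sketch of what this series enables: passing precomputed image embeddings to Phi-3-Vision through vLLM's offline API. This example is not taken from the patches; the model id, prompt template, and tensor sizes are assumptions for illustration, and only the multi_modal_data plumbing mirrors the code paths touched above.

# Hypothetical usage sketch (not part of the patch series).
import torch
from vllm import LLM

llm = LLM(model="microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True)

# One image's embeddings as a 2D tensor of (image_feature_size, hidden_size);
# vLLM collates these per request, which is why _process_image_input above
# accepts either a list of 2D tensors or a single batched 3D tensor.
image_embeds = torch.rand(1921, 3072)  # assumed, model-dependent sizes

outputs = llm.generate({
    "prompt": "<|user|>\n<|image_1|>\nWhat is shown in this image?<|end|>\n<|assistant|>\n",
    "multi_modal_data": {"image": image_embeds},
})
print(outputs[0].outputs[0].text)

# The Ultravox changes are analogous: a torch.Tensor passed as
# multi_modal_data={"audio": ...} is treated as precomputed audio embeddings.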