diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 98f18319f1e4..1b80c801d5be 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -1004,7 +1004,7 @@ See [this page](#generative-models) for more information on how to use generativ * `microsoft/Phi-4-multimodal-instruct`, etc. * ✅︎ * - * + * ✅︎ - * `PixtralForConditionalGeneration` * Pixtral * T + I+ diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 077b5c762a25..e3c75d5cb6a9 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -89,7 +89,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData: engine_args = EngineArgs( model=model_path, trust_remote_code=True, - max_model_len=4096, + max_model_len=12800, max_num_seqs=2, enable_lora=True, max_lora_rank=320, diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 510a043ce421..bd7035b7615a 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -814,10 +814,13 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData: engine_args = EngineArgs( model=model_path, trust_remote_code=True, - max_model_len=4096, + max_model_len=5120, max_num_seqs=2, + max_num_batched_tokens=12800, enable_lora=True, max_lora_rank=320, + # Note - mm_processor_kwargs can also be passed to generate/chat calls + mm_processor_kwargs={"dynamic_hd": 16}, limit_mm_per_prompt={"image": 1}, ) diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index e2e14d16228a..f165ea9efa10 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -503,11 +503,13 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData: engine_args = EngineArgs( model=model_path, trust_remote_code=True, - max_model_len=10000, + max_model_len=4096, max_num_seqs=2, limit_mm_per_prompt={"image": len(image_urls)}, enable_lora=True, max_lora_rank=320, + # Note - mm_processor_kwargs can also be passed to generate/chat calls + mm_processor_kwargs={"dynamic_hd": 4}, ) placeholders = "".join(f"<|image_{i}|>" diff --git a/requirements/docs.txt b/requirements/docs.txt index 416ca503b36c..99fb87def6dd 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -18,6 +18,7 @@ transformers mistral_common >= 1.5.4 aiohttp starlette +scipy openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index bd1dcba6a995..e9dcba8ec089 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -1,14 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 import json -from typing import Optional +from typing import Any, Optional import numpy as np import pytest import pytest_asyncio from transformers import AutoModel, AutoTokenizer -from vllm.multimodal.audio import resample_audio +from vllm.multimodal.audio import resample_audio_librosa from vllm.sequence import SampleLogprobs from ....conftest import HfRunner, VllmRunner @@ -43,6 +43,18 @@ def audio(request): return AudioAsset(request.param) +def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]: + """Convert kwargs to CLI args.""" + args = [] + for key, value in params_kwargs.items(): + if isinstance(value, bool): + if value: + args.append(f"--{key.replace('_','-')}") + else: + args.append(f"--{key.replace('_','-')}={value}") + return args + + @pytest.fixture(params=[ pytest.param({}, marks=pytest.mark.cpu_model), pytest.param(CHUNKED_PREFILL_KWARGS), @@ -52,10 +64,7 @@ def server(request, audio_assets): "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager", "--limit-mm-per-prompt", json.dumps({"audio": len(audio_assets)}), "--trust-remote-code" - ] + [ - f"--{key.replace('_','-')}={value}" - for key, value in request.param.items() - ] + ] + params_kwargs_to_cli_args(request.param) with RemoteOpenAIServer(MODEL_NAME, args, @@ -136,9 +145,9 @@ def run_test( [hf_prompt], max_tokens, num_logprobs=num_logprobs, - audios=[(resample_audio(audio[0], - orig_sr=audio[1], - target_sr=16000), 16000)]) + audios=[(resample_audio_librosa(audio[0], + orig_sr=audio[1], + target_sr=16000), 16000)]) for _, hf_prompt, audio in prompts_and_audios ] diff --git a/tests/models/decoder_only/vision_language/test_phi4mm.py b/tests/models/decoder_only/vision_language/test_phi4mm.py index 3cd830015076..11460a1a8d2b 100644 --- a/tests/models/decoder_only/vision_language/test_phi4mm.py +++ b/tests/models/decoder_only/vision_language/test_phi4mm.py @@ -181,7 +181,7 @@ def patch_hf_processor(*args, ], ) @pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [4096]) +@pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, @@ -225,7 +225,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, ], ) @pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [10000]) +@pytest.mark.parametrize("max_model_len", [25600]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, @@ -258,7 +258,7 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, @pytest.mark.parametrize("model", models) @pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [10000]) +@pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index a8f21ff919b7..d56638f051f2 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -274,6 +274,7 @@ def _test_processing_correctness_mistral( "nvidia/NVLM-D-72B", "google/paligemma-3b-mix-224", "google/paligemma2-3b-ft-docci-448", + "microsoft/Phi-4-multimodal-instruct", "mistralai/Pixtral-12B-2409", "mistral-community/pixtral-12b", "Qwen/Qwen-VL-Chat", diff --git a/tests/models/multimodal/processing/test_phi4mm.py b/tests/models/multimodal/processing/test_phi4mm.py new file mode 100644 index 000000000000..797986adba4a --- /dev/null +++ b/tests/models/multimodal/processing/test_phi4mm.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for phi4mm's multimodal preprocessing kwargs.""" +import pytest + +from vllm.multimodal import MULTIMODAL_REGISTRY + +from ....conftest import _ImageAssets +from ...utils import build_model_context + + +@pytest.mark.parametrize("model_id", ["microsoft/Phi-4-multimodal-instruct"]) +# yapf: disable +@pytest.mark.parametrize( + ("mm_processor_kwargs", "expected_toks_per_img"), + [ + ({"dynamic_hd": 4}, 1329), + ({"dynamic_hd": 16}, 4433), + # the default num_crops of phi-4-multimodal is 36 + ({}, 9585), + ]) +# yapf: enable +@pytest.mark.parametrize("num_imgs", [1, 2]) +@pytest.mark.parametrize("kwargs_on_init", [True, False]) +def test_processor_override( + image_assets: _ImageAssets, + model_id: str, + mm_processor_kwargs: dict[str, int], + expected_toks_per_img: int, + num_imgs: int, + kwargs_on_init: bool, +): + """Ensure Phi4MMMultiModalProcessor handles dynamic_hd properly.""" + # Avoid initializing CUDA early + from vllm.model_executor.models.phi4mm import _IMAGE_PLACEHOLDER_TOKEN_ID + + ctx = build_model_context( + model_id, + mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None, + limit_mm_per_prompt={"image": num_imgs}, + ) + processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) + hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs + + # Build the image str / prompt based on the number of images we pass + img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) + prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" + + image_size = ctx.get_hf_config( + ).embd_layer["image_embd_layer"]["crop_size"] + dummy_image_size = (image_size * 7, image_size * 7) + dummy_image = image_assets[0].pil_image.resize(dummy_image_size) + mm_data = {"image": [dummy_image] * num_imgs} + + processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + + # Ensure we have the right number of placeholders per num_crops size + img_tok_count = processed_inputs["prompt_token_ids"].count( + _IMAGE_PLACEHOLDER_TOKEN_ID) + assert img_tok_count == expected_toks_per_img * num_imgs diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index e505a4592e2d..0b662f1a7ec3 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -482,11 +482,8 @@ def _placeholder_str(self, modality: ModalityStr, if modality in ("image", "image_embeds"): if model_type == "chatglm": return "<|begin_of_image|><|endoftext|><|end_of_image|>" - if model_type == "phi3_v": - # Workaround since this token is not defined in the tokenizer + if model_type in ("phi3_v", "phi4mm"): return f"<|image_{current_count}|>" - if model_type == "phi4mm": - return "<|endoftext10|>" # 200010 (see vocab.json in hf model) if model_type in ("minicpmo", "minicpmv"): return "(./)" if model_type in ("blip-2", "florence2", "fuyu", "paligemma", @@ -522,7 +519,7 @@ def _placeholder_str(self, modality: ModalityStr, if model_type == "ultravox": return "<|audio|>" if model_type == "phi4mm": - return "<|endoftext11|>" # 200011 (see vocab.json in hf model) + return f"<|audio_{current_count}|>" if model_type in ("qwen2_audio", "qwen2_5_omni"): return (f"Audio {current_count}: " f"<|audio_bos|><|AUDIO|><|audio_eos|>") diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 7f41ad2359df..5b43871b7591 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -327,7 +327,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: Optional[ProcessorMixin], + processor: Optional[ProcessorMixin] = None, ) -> int: if processor is None: processor = self.get_hf_processor() diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index ec19797f8875..cdd762f5fec3 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -1,41 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 import math -import re -from functools import lru_cache -from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Tuple, - TypedDict, Union) +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union import numpy as np -import scipy.signal import torch import torch.nn as nn -import torchvision.transforms as T -from PIL import Image -from transformers import PretrainedConfig, SiglipVisionConfig -from transformers.utils import logging +from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin, + SequenceFeatureExtractor, SiglipVisionConfig) from vllm.config import VllmConfig from vllm.distributed import get_pp_group -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext) -from vllm.inputs.data import TokenInputs, token_inputs from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors -from vllm.sequence import IntermediateTensors, SequenceData -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems, + ImageProcessorItems, ImageSize, + MultiModalDataItems, MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of from .idefics2_vision_model import Idefics2VisionTransformer -from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsV0Only +from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal from .phi4mm_audio import AudioEmbedding -from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, + merge_multimodal_embeddings) # <|endoftext10|> (see vocab.json in hf model) _IMAGE_PLACEHOLDER_TOKEN_ID = 200010 @@ -43,115 +44,19 @@ _AUDIO_PLACEHOLDER_TOKEN_ID = 200011 _AUDIO_MAX_SOUNDFILE_SIZE = 241_000 -DUMMY_SAMPLING_FREQUENCY = 16_000 # kHz - -DYNAMIC_HD = 16 -AUDIO_TOKEN_PATTERN = r"<\|audio_(\d+)\|>" -IMAGE_TOKEN_PATTERN = r"<\|image_(\d+)\|>" SIGLIP_NAME = "siglip-so400m-patch14-448" VISION_ENCODER_TO_PROCESSING_CONFIG = { 'siglip-so400m-patch14-448': { - 'dynamic_hd': 16, 'vit_image_size': 448, 'vit_patch_size': 14, 'token_compression_factor': 2, }, } -logger = logging.get_logger(__name__) -# This is a workaround to prevent text (user input) + audio + image -# from being used in the same prompt. -# It includes token ids for "/n" and tokens in added_tokens_decoder -# from the tokenizer_confg.json file. -NON_USER_INPUT_TOKENS = { - 198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022, - 200023, 200024, 200025, 200026, 200027, 200028 -} -def get_max_dummy_image(ctx: InputContext): - hf_config = ctx.get_hf_config() - vision_encoder_name = hf_config.img_processor - if vision_encoder_name is None: - vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] - dynamic_hd_size = prepro_config['dynamic_hd'] - vit_image_size = prepro_config['vit_image_size'] - - max_side = vit_image_size * dynamic_hd_size - dummy_image = dummy_image_for_phi4mm(vit_image_size, max_side) - return dummy_image - - -# image token length -def get_max_phi4mm_image_tokens(ctx: InputContext): - dummy_image = get_max_dummy_image(ctx) - - hf_config = ctx.get_hf_config() - vision_encoder_name = hf_config.img_processor - if vision_encoder_name is None: - vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] - dynamic_hd_size = prepro_config['dynamic_hd'] - vit_image_size = prepro_config['vit_image_size'] - vit_patch_size = prepro_config['vit_patch_size'] - token_compression_factor = prepro_config['token_compression_factor'] - - image_num_tokens = _compute_num_image_tokens(dummy_image, dynamic_hd_size, - vit_image_size, - vit_patch_size, - token_compression_factor) - return image_num_tokens - - -def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, - image_size): - best_ratio_diff = float('inf') - best_ratio = (1, 1) - area = width * height - for ratio in target_ratios: - target_aspect_ratio = ratio[0] / ratio[1] - ratio_diff = abs(aspect_ratio - target_aspect_ratio) - if ratio_diff < best_ratio_diff: - best_ratio_diff = ratio_diff - best_ratio = ratio - elif ratio_diff == best_ratio_diff: - if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: - best_ratio = ratio - return best_ratio - - -def _find_target_aspect_ratio(image, image_size, max_num, min_num): - orig_width, orig_height = image.size - - w_crop_num = math.ceil(orig_width / float(image_size)) - h_crop_num = math.ceil(orig_height / float(image_size)) - if w_crop_num * h_crop_num > max_num: - aspect_ratio = orig_width / orig_height - - # calculate the existing image aspect ratio - target_ratios = set((i, j) for i in range(1, max_num + 1) - for j in range(1, max_num + 1) - if i * j <= max_num and i * j >= min_num) - target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) - - # find the closest aspect ratio to the target - target_aspect_ratio = find_closest_aspect_ratio( - aspect_ratio, target_ratios, orig_width, orig_height, image_size) - - # calculate the target width and height - target_width = image_size * target_aspect_ratio[0] - target_height = image_size * target_aspect_ratio[1] - logger.debug("target_aspect_ratio: %s", target_aspect_ratio) - else: - target_width = image_size * w_crop_num - target_height = image_size * h_crop_num - target_aspect_ratio = (w_crop_num, h_crop_num) - return target_aspect_ratio, target_height, target_width - - -def _get_padding_size(image, target_height, target_width): - orig_width, orig_height = image.size +def _get_padding_size(orig_width: int, orig_height: int, target_height: int, + target_width: int): ratio_width = target_width / orig_width ratio_height = target_height / orig_height @@ -164,181 +69,6 @@ def _get_padding_size(image, target_height, target_width): return padding_height, padding_width -def dynamic_preprocess(image, - min_num=1, - max_num=12, - image_size=384, - mask_size=27): - target_aspect_ratio, target_height, target_width =\ - _find_target_aspect_ratio( - image, image_size, max_num, min_num) - padding_height, padding_width = _get_padding_size(image, target_height, - target_width) - - # Calculate the ratio - orig_width, orig_height = image.size - ratio_width = target_width / orig_width - ratio_height = target_height / orig_height - if ratio_width < ratio_height: - new_size = (target_width, int(orig_height * ratio_width)) - else: - new_size = (int(orig_width * ratio_height), target_height) - - attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]), - int(mask_size * target_aspect_ratio[0]))) - if padding_width >= 14: - attention_mask[:, -math.floor(padding_width / 14):] = 0 - if padding_height >= 14: - attention_mask[-math.floor(padding_height / 14):, :] = 0 - assert attention_mask.sum( - ) > 0, f'attention mask is empty {attention_mask}' - - if min(new_size[1], target_height) < 10 or min(new_size[0], - target_width) < 10: - raise ValueError(f'the aspect ratio is very extreme {new_size}') - - image = T.functional.resize( - image, - [new_size[1], new_size[0]], - ) - - resized_img = T.functional.pad(image, - [0, 0, padding_width, padding_height], - fill=[255, 255, 255]) - - return resized_img, attention_mask - - -def pad_to_max_num_crops(images, max_crops=5): - """ - images: B x 3 x H x W, B<=max_crops - """ - B, _, H, W = images.shape - if max_crops > B: - pad = torch.zeros(max_crops - B, - 3, - H, - W, - dtype=images.dtype, - device=images.device) - images = torch.cat([images, pad], dim=0) - return images - - -def pad_mask_to_max_num_crops(masks, max_crops=5): - B, H, W = masks.shape - if max_crops > B: - pad = torch.ones(max_crops - B, - H, - W, - dtype=masks.dtype, - device=masks.device) - masks = torch.cat([masks, pad], dim=0) - return masks - - -def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size): - - # Basic settings. - img_processor = T.Compose([ - T.ToTensor(), - T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ]) - # Dynamic HD - base_resolution = vit_resolution - images = [image.convert('RGB') for image in images] - # cover 384 and 448 resolution - mask_resolution = base_resolution // vit_patch_size - elems, image_attention_masks = [], [] - for im in images: - elem, attention_mask = dynamic_preprocess(im, - max_num=dynamic_hd_size, - image_size=base_resolution, - mask_size=mask_resolution) - elems.append(elem) - image_attention_masks.append(attention_mask) - hd_images = [img_processor(im) for im in elems] - global_image = [ - torch.nn.functional.interpolate( - im.unsqueeze(0).float(), - size=(base_resolution, base_resolution), - mode='bicubic', - ).to(im.dtype) for im in hd_images - ] - shapes = [[im.size(1), im.size(2)] for im in hd_images] - mask_shapes = [[mask.size(0), mask.size(1)] - for mask in image_attention_masks] - global_attention_mask = [ - torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images - ] - hd_images_reshape = [ - im.reshape(1, 3, h // base_resolution, base_resolution, - w // base_resolution, base_resolution).permute( - 0, 2, 4, 1, 3, 5).reshape(-1, 3, base_resolution, - base_resolution).contiguous() - for im, (h, w) in zip(hd_images, shapes) - ] - attention_masks_reshape = [ - mask.reshape(1, h // mask_resolution, mask_resolution, - w // mask_resolution, mask_resolution).permute( - 0, 1, 3, 2, 4).reshape(-1, mask_resolution, - mask_resolution).contiguous() - for mask, (h, w) in zip(image_attention_masks, mask_shapes) - ] - # NOTE token compression is hard coded here, and odd numbers seems to fail - downsample_attention_masks = [ - mask[:, 0::2, - 0::2].reshape(1, h // mask_resolution, w // mask_resolution, - mask_resolution // 2 + mask_resolution % 2, - mask_resolution // 2 + mask_resolution % 2).permute( - 0, 1, 3, 2, 4) - for mask, (h, w) in zip(attention_masks_reshape, mask_shapes) - ] - downsample_attention_masks = [ - mask.reshape(mask.size(1) * mask.size(2), - mask.size(3) * mask.size(4)) - for mask in downsample_attention_masks - ] - # NOTE hard coded number of tokens - num_img_tokens = [ - 256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16 - for mask in downsample_attention_masks - ] - - hd_images_reshape = [ - torch.cat([_global_image] + [_im], dim=0) - for _global_image, _im in zip(global_image, hd_images_reshape) - ] - hd_masks_reshape = [ - torch.cat([_global_mask] + [_mask], - dim=0) for _global_mask, _mask in zip( - global_attention_mask, attention_masks_reshape) - ] - max_crops = max([img.size(0) for img in hd_images_reshape]) - image_transformed = [ - pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape - ] - image_transformed = torch.stack(image_transformed, dim=0) - mask_transformed = [ - pad_mask_to_max_num_crops(mask, max_crops) \ - for mask in hd_masks_reshape - ] - mask_transformed = torch.stack(mask_transformed, dim=0) - - returned_input_image_embeds = image_transformed - returned_image_sizes = torch.tensor(shapes, dtype=torch.long) - returned_image_attention_mask = mask_transformed - returned_num_img_tokens = num_img_tokens - - data = { - "pixel_values": returned_input_image_embeds, - "image_sizes": returned_image_sizes, - "image_attention_mask": returned_image_attention_mask, - "num_img_tokens": returned_num_img_tokens, - } - return data - - def get_navit_vision_model(layer_idx: int = -1, **kwargs): vision_config = { "hidden_size": 1152, @@ -492,7 +222,7 @@ def get_img_features(self, def forward(self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor, - image_attention_mask: torch.Tensor) -> torch.FloatTensor: + image_attention_mask: torch.Tensor) -> list[torch.FloatTensor]: """ process image and return vision embeddings. @@ -656,785 +386,528 @@ def forward(self, pixel_values: torch.FloatTensor, for _output_img in output_imgs: img_feature_proj = self.img_projection( _output_img.to(target_device).to(target_dtype)) - img_set_tensor.append(img_feature_proj) + img_set_tensor.append(img_feature_proj.squeeze(0)) return img_set_tensor -class Phi4MMAudioFeatureInputs(TypedDict): - type: Literal["audio_features"] - data: Tuple[NestedTensors] - """Shape: `((batch_size, num_audios, 80, M), )""" - - -class Phi4MMAudioEmbeddingInputs(TypedDict): - type: Literal["audio_embeds"] - data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" - - -Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] - - -def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None): - """Create a Mel filter-bank the same as SpeechLib FbankFC. - - Args: - sample_rate (int): Sample rate in Hz. number > 0 [scalar] - n_fft (int): FFT size. int > 0 [scalar] - n_mel (int): Mel filter size. int > 0 [scalar] - fmin (float): lowest frequency (in Hz). If None use 0.0. - float >= 0 [scalar] - fmax: highest frequency (in Hz). If None use sample_rate / 2. - float >= 0 [scalar] - - Returns - out (numpy.ndarray): Mel transform matrix - [shape=(n_mels, 1 + n_fft/2)] +class Phi4MMImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: Union[torch.Tensor, List[torch.Tensor]] """ + Shape: + `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - bank_width = int(n_fft // 2 + 1) - if fmax is None: - fmax = sample_rate / 2 - if fmin is None: - fmin = 0 - assert fmin >= 0, "fmin cannot be negative" - assert (fmin < fmax <= - sample_rate / 2), "fmax must be between (fmin, samplerate / 2]" - - def mel(f): - return 1127.0 * np.log(1.0 + f / 700.0) - - def bin2mel(fft_bin): - return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0)) - - def f2bin(f): - return int((f * n_fft / sample_rate) + 0.5) - - # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1] - klo = f2bin(fmin) + 1 - khi = f2bin(fmax) - - khi = max(khi, klo) - - # Spec 2: SpeechLib uses triangles in Mel space - mlo = mel(fmin) - mhi = mel(fmax) - m_centers = np.linspace(mlo, mhi, n_mels + 2) - ms = (mhi - mlo) / (n_mels + 1) - - matrix = np.zeros((n_mels, bank_width), dtype=np.float32) - for m in range(0, n_mels): - left = m_centers[m] - center = m_centers[m + 1] - right = m_centers[m + 2] - for fft_bin in range(klo, khi): - mbin = bin2mel(fft_bin) - if left < mbin < right: - matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms - - return matrix - - -class LogFbankProcessor: - - def __init__(self): - - self._eightk_method = "fillzero" - self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T - - self._hamming400 = np.hamming(400) # for 16k audio - self._hamming200 = np.hamming(200) # for 8k audio + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ - def extract_spectrogram(self, wav, fs): - """Extract spectrogram features from waveform. - Args: - wav (1D array): waveform of the input - fs (int): sampling rate of the waveform, 16000 or 8000. - If fs=8000, the waveform will be resampled to 16000Hz. - Output: - log_fbank (2D array): a TxD matrix of log Mel filterbank features. - D=80, and T is the number of frames. - """ - if wav.ndim > 1: - wav = np.squeeze(wav) + image_sizes: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` - # by default, we extract the mean if stereo - if len(wav.shape) == 2: - wav = wav.mean(1) + This should be in `(height, width)` format. + """ - # Resample to 16000 or 8000 if needed - if fs > 16000: - wav = scipy.signal.resample_poly(wav, 1, fs // 16000) - fs = 16000 - elif 8000 < fs < 16000: - wav = scipy.signal.resample_poly(wav, 1, fs // 8000) - fs = 8000 - elif fs < 8000: - raise RuntimeError(f"Unsupported sample rate {fs}") - - if fs == 8000: - if self._eightk_method == "resample": - # Input audio is 8 kHz. Convert to 16 kHz before feature - # extraction - wav = scipy.signal.resample_poly(wav, 2, 1) - fs = 16000 - # Do nothing here for fillzero method - elif fs != 16000: - # Input audio is not a supported sample rate. - raise RuntimeError( - f"Input data using an unsupported sample rate: {fs}") - - preemphasis = 0.97 - - if fs == 8000: - n_fft = 256 - win_length = 200 - hop_length = 80 - fft_window = self._hamming200 - elif fs == 16000: - n_fft = 512 - win_length = 400 - hop_length = 160 - fft_window = self._hamming400 - - # Spec 1: SpeechLib cut remaining sample insufficient for a hop - n_batch = (wav.shape[0] - win_length) // hop_length + 1 - # Here we don't use stride_tricks since the input array may not satisfy - # memory layout requirement and we need writeable output - # Here we only use list of views before copy to destination - # so it is more efficient than broadcasting - y_frames = np.array( - [ - wav[_stride:_stride + win_length] - for _stride in range(0, hop_length * n_batch, hop_length) - ], - dtype=np.float32, - ) + num_img_tokens: list[int] + """Shape: `(batch_size * num_images)`""" - # Spec 2: SpeechLib applies preemphasis within each batch - y_frames_prev = np.roll(y_frames, 1, axis=1) - y_frames_prev[:, 0] = y_frames_prev[:, 1] - y_frames = (y_frames - preemphasis * y_frames_prev) * 32768 + image_attention_mask: torch.Tensor + """Shape: `(batch_size * num_images, H_mask, W_mask)`""" - S = np.fft.rfft(fft_window * y_frames, n=n_fft, - axis=1).astype(np.complex64) - if fs == 8000: - # Need to pad the output to look like 16 kHz data but with zeros in - # the 4 to 8 kHz bins. - frames, bins = S.shape - padarray = np.zeros((frames, bins)) - S = np.concatenate((S[:, 0:-1], padarray), - axis=1) # Nyquist bin gets set to zero +class Phi4MMImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: Union[torch.Tensor, List[torch.Tensor]] + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - spec = np.abs(S).astype(np.float32) - return spec + `hidden_size` must match the hidden size of language model backbone. + """ - def extract_features(self, wav, fs): - """Extract log filterbank features from waveform. - Args: - wav (1D array): waveform of the input - fs (int): sampling rate of the waveform, 16000 or 8000. - If fs=8000, the waveform will be resampled to 16000Hz. - Output: - log_fbank (2D array): a TxD matrix of log Mel filterbank features. - D=80, and T is the number of frames. - """ - spec = self.extract_spectrogram(wav, fs) - spec_power = spec**2 - fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None) - log_fbank = np.log(fbank_power).astype(np.float32) +class Phi4MMAudioFeatureInputs(TypedDict): + type: Literal["audio_features"] + data: Union[torch.Tensor, List[torch.Tensor]] + """Shape: `(batch_size * num_audios, 80, M)""" - return log_fbank +class Phi4MMAudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + data: NestedTensors + """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" -@lru_cache -def audio_feature_extractor() -> LogFbankProcessor: - # Creates an instance of the audio processor, needed to extract the - # the audio features from the sound file - # LRU cache ensures that we only make one copy - return LogFbankProcessor() +Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs] +Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] -def _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size, - vit_patch_size, token_compression_factor): - """ - compute the number of tokens an image is expected to take up considering - the image encoder architecture and exclude output features containing - only padding pixels - for siglip, vit_image_size=448, vit_patch_size=14, so output will be - 32x32 feature map - NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 - """ - assert vit_image_size % vit_patch_size == 0, \ - "vit_image_size must be divisible by vit_patch_size" - assert vit_image_size // vit_patch_size % token_compression_factor == 0, \ - "vit_image_size // vit_patch_size must be divisible by "\ - "token_compression_factor" - - target_aspect_ratio, target_height, target_width = ( - _find_target_aspect_ratio(image, - vit_image_size, - dynamic_hd_size, - min_num=1)) - assert target_aspect_ratio[ - 0] * vit_image_size == target_width, \ - f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" - assert target_aspect_ratio[ - 1] * vit_image_size == target_height, \ - f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" - assert (target_height % vit_image_size == 0 - and target_width % vit_image_size == 0) - - padding_height, padding_width = _get_padding_size(image, target_height, - target_width) - assert padding_width == 0 or padding_height == 0, \ - "padding_width or padding_height must be 0" - - target_feat_width = target_width // vit_patch_size - target_feat_height = target_height // vit_patch_size - if padding_width >= vit_patch_size: - assert padding_height == 0, "padding_height not 0" - non_pad_feat_width = target_feat_width - math.floor( - padding_width / vit_patch_size) - non_pad_feat_height = target_feat_height - elif padding_height >= vit_patch_size: - assert padding_width == 0, "padding_width not 0" - non_pad_feat_height = target_feat_height - math.floor( - padding_height / vit_patch_size) - non_pad_feat_width = target_feat_width - else: - # small padding shorter than a vit patch - non_pad_feat_width = target_feat_width - non_pad_feat_height = target_feat_height - - feat_width = non_pad_feat_width // token_compression_factor - feat_height = non_pad_feat_height // token_compression_factor - # NOTE it's possible that the non-padding feature is not divisible - if non_pad_feat_width % token_compression_factor != 0: - feat_width += 1 - if non_pad_feat_height % token_compression_factor != 0: - feat_height += 1 - num_hd_patch_tokens = feat_width * feat_height - num_hd_newline_tokens = feat_height - vit_feature_size = vit_image_size // vit_patch_size - num_global_image_tokens = (vit_feature_size // token_compression_factor)**2 - num_sep_tokens = 1 - num_global_image_newline_tokens = \ - vit_feature_size // token_compression_factor - - return (num_global_image_tokens + num_sep_tokens + num_hd_patch_tokens + - num_hd_newline_tokens + num_global_image_newline_tokens) - - -def compute_logfbank_output_size(wav_length: int, fs: int) -> Tuple[int, int]: +def cat_with_pad(tensors, dim, padding_value=0): """ - Compute the output size of the `extract_features` method. - - Args: - wav_length (int): Length of the input waveform in samples. - fs (int): Sampling rate of the waveform, either 16000 or 8000. - - Returns: - tuple (int, int): Output size as (T, D), where: - T: Number of time frames. - D: Number of Mel filterbank bins (80). + cat along dim, while pad to max for all other dims """ + ndim = tensors[0].dim() + assert all( + t.dim() == ndim for t in + tensors[1:]), "All tensors must have the same number of dimensions" - # Resample to 16000 or 8000 if needed - if fs > 16000: - wav_length //= fs // 16000 - fs = 16000 - elif 8000 <= fs < 16000: - # We'll resample to 16K from 8K - wav_length *= 2 - fs = 16000 - elif fs < 8000: - raise RuntimeError(f"Unsupported sample rate {fs}") + out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] + out_size[dim] = sum(t.shape[dim] for t in tensors) + output = tensors[0].new_full(out_size, padding_value) - # Spectrogram parameters for 16 kHz - win_length = 400 # Frame length in samples - hop_length = 160 # Frame shift in samples - mel_bins = 80 # Number of mel filterbank bins + index = 0 + for t in tensors: + # Create a slice list where every dimension except dim is full slice + slices = [slice(0, t.shape[d]) for d in range(ndim)] + # Update only the concat dimension slice + slices[dim] = slice(index, index + t.shape[dim]) - # Calculate number of frames (T) - T = (wav_length - win_length) // hop_length + 1 - if T < 1: - raise ValueError("Waveform too short for given parameters.") + output[slices] = t + index += t.shape[dim] - # Return time frames (T) and mel bins (D) - return T, mel_bins + return output -def _get_audio_embed_sizes(audios, ctx: InputContext): - """ - Get the audio embedding sizes for each audio file. +class Phi4MMProcessingInfo(BaseProcessingInfo): - Args: - audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of - waveform and sample rate. - ctx (InputContext): Input context. + def get_hf_processor( + self, + *, + dynamic_hd: Optional[int] = None, + **kwargs: object, + ) -> ProcessorMixin: + if dynamic_hd is not None: + kwargs["dynamic_hd"] = dynamic_hd - Returns: - List[int]: List of audio embedding sizes. - """ - audio_embed_sizes = [] - for audio in audios: - audio_data, sf = audio - audio_frames, _ = compute_logfbank_output_size(len(audio_data), sf) - audio_embed_size = _compute_audio_embed_size(ctx.get_hf_config(), - audio_frames) - audio_embed_sizes.append(audio_embed_size) - return audio_embed_sizes + return self.ctx.get_hf_processor(**kwargs) + @property + def image_tokens(self) -> list[str]: + return [f"<|image_{i+1}|>" for i in range(100)] -def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""): - """ - The following will search for `<|audio_{idx}|>` tokens and - return a mapping of audio placeholder tokens to audio placeholder token ids - based on the size of the audio embeddings. + @property + def audio_tokens(self) -> list[str]: + return [f"<|audio_{i+1}|>" for i in range(100)] - Args: - audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of - waveform and sample rate. - ctx (InputContext): Input context. - prompt_str (str): The prompt string. + def get_dynamic_hd( + self, + processor: Optional[ProcessorMixin] = None, + ) -> int: + if processor is None: + processor = self.get_hf_processor() + image_processor = processor.image_processor + return image_processor.dynamic_hd - Returns: - Dict[str, List[int]]: Mapping of audio placeholder tokens to audio - placeholder token ids. + def get_feature_extractor(self) -> SequenceFeatureExtractor: + return self.get_hf_processor().audio_processor - """ - if len(audios) == 0: - return {} - - audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) - audio_ids = re.findall(AUDIO_TOKEN_PATTERN, prompt_str) - audio_ids = [int(audio_id) for audio_id in audio_ids] - assert len(audio_ids) == len( - audio_embed_sizes - ), "Number of audio tokens and audio features do not match" - assert tuple(audio_ids) == tuple(range(1, - len(audio_ids) + - 1)), "Audio ids are not in order!" - audio_id_to_input_ids = { - f"<|audio_{audio_id}|>": - [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size - for audio_id, audio_embed_size in zip(audio_ids, audio_embed_sizes) - } + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None, "image": None} - return audio_id_to_input_ids - - -def _count_image_tokens(images, ctx: InputContext): - hf_config = ctx.get_hf_config() - vision_encoder_name = hf_config.img_processor - if vision_encoder_name is None: - vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] - dynamic_hd_size = prepro_config['dynamic_hd'] - vit_image_size = prepro_config['vit_image_size'] - vit_patch_size = prepro_config['vit_patch_size'] - token_compression_factor = prepro_config['token_compression_factor'] - - image_token_counts = [ - _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size, - vit_patch_size, token_compression_factor) - for image in images - ] - return image_token_counts - - -def _get_image_id_to_input_ids(images, prompt, ctx: InputContext): - if len(images) == 0: - return {} - - image_ids = re.findall(IMAGE_TOKEN_PATTERN, prompt) - image_ids = [int(image_id) for image_id in image_ids] - assert len(image_ids) == len( - set(image_ids)), "Duplicate image tokens in prompt" - assert len(images) == len( - image_ids), "Number of images and image tokens in prompt do not match" - - # NOTE the following assertion is not strictly necessary - assert tuple(image_ids) == tuple(range(1, - len(image_ids) + - 1)), "Image ids are not in order" - - image_token_counts = _count_image_tokens(images, ctx) - image_id_to_input_ids = { - f"<|image_{image_id}|>": [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_tokens - for image_id, num_tokens in zip(image_ids, image_token_counts) - } - return image_id_to_input_ids + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return { + "image": self.get_max_image_tokens(), + "audio": self.get_max_audio_tokens(), + } + def get_max_audio_tokens(self) -> int: + sr = self.get_feature_extractor().sampling_rate + num_frames = self.get_audio_num_frames(_AUDIO_MAX_SOUNDFILE_SIZE, sr) + return self._compute_audio_embed_size(num_frames) -def input_processor_for_phi4mm(ctx: InputContext, - inputs: DecoderOnlyInputs) -> TokenInputs: - """ - Implements the input processor, which transforms the input prompt ids - to include the audio placeholder token. This will become the `input_ids` - in `forward` for the model. + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return self.get_num_image_tokens(image_width=target_width, + image_height=target_height) - Args: - ctx (InputContext): Input context. - inputs (DecoderOnlyInputs): The inputs (e.g. prompt, prompt_token_ids) - to process. + def _find_target_aspect_ratio( + self, + orig_width: int, + orig_height: int, + image_size: int, + max_num: int, + min_num: int, + ): + w_crop_num = math.ceil(orig_width / float(image_size)) + h_crop_num = math.ceil(orig_height / float(image_size)) + if w_crop_num * h_crop_num > max_num: + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set((i, j) for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + image_processor = self.get_hf_processor().image_processor + target_aspect_ratio = image_processor.find_closest_aspect_ratio( + aspect_ratio, + target_ratios, + orig_width, + orig_height, + image_size, + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + else: + target_width = image_size * w_crop_num + target_height = image_size * h_crop_num + target_aspect_ratio = (w_crop_num, h_crop_num) + return target_aspect_ratio, target_height, target_width - Returns: - TokenInputs: Processed inputs - """ - multi_modal_data = inputs.get("multi_modal_data") - if (multi_modal_data is None or - ("audio" not in multi_modal_data and "image" not in multi_modal_data)): - # pure text input, so no need to do pre-processing - return inputs - - prompt_str = inputs.get("prompt") - prompt_token_ids = inputs.get("prompt_token_ids") - # for offline_inference, we will get str input and we parse MM special - # tokens from it - # (ignore prompt_token_ids) - # for OAI server, we will get prompt_token_ids, where MM special tokens - # are already parsed - - if 'audio' in multi_modal_data: - audios = multi_modal_data["audio"] - - if not isinstance(audios, list): - audios = [audios] - if prompt_str is not None: - audio_id_to_input_ids = _get_audio_id_to_input_ids( - audios, ctx, prompt_str=prompt_str) - audio_embed_sizes = [] - elif prompt_token_ids is not None: - audio_id_to_input_ids = {} - audio_embed_sizes = _get_audio_embed_sizes(audios, ctx) - else: - audio_id_to_input_ids = {} - audio_embed_sizes = [] - - if 'image' in multi_modal_data: - # PIL Image or list of PIL Images - images = multi_modal_data["image"] - if not isinstance(images, list): - images = [images] - if prompt_str is not None: - image_id_to_input_ids = _get_image_id_to_input_ids( - images, prompt_str, ctx) - image_token_counts = [] - elif prompt_token_ids is not None: - image_id_to_input_ids = {} - image_token_counts = _count_image_tokens(images, ctx) - else: - image_id_to_input_ids = {} - image_token_counts = [] - - # Handle the case where the prompt is a string and we need to manually - # tokenize it. - # In this case, the `audio_id_to_input_ids` dict will be mapping from - # an audio placeholder - # string (e.g. `<|audio_1|>`) to the audio placeholder tokens for the - # given audio length. - if prompt_str: - pattern = r"(<\|image_\d+\|>|<\|audio_\d+\|>)" - prompt_chunk_strings = re.split(pattern, prompt_str) - prompt_chunk_strings = [s for s in prompt_chunk_strings if s != ""] - - # Create the new input_ids with the placeholder image and audio - # tokens inserted - tokenizer = cached_tokenizer_from_config(ctx.model_config) - input_ids = [] - has_imag, has_audio, has_user_text_input = False, False, False - for prompt_chunk_string in prompt_chunk_strings: - if re.match(IMAGE_TOKEN_PATTERN, prompt_chunk_string): - input_ids.extend(image_id_to_input_ids[prompt_chunk_string]) - has_imag = True - elif re.match(AUDIO_TOKEN_PATTERN, prompt_chunk_string): - input_ids.extend(audio_id_to_input_ids[prompt_chunk_string]) - has_audio = True - else: - curr_token_ids = tokenizer(prompt_chunk_string).input_ids - if not has_user_text_input: - for token_id in curr_token_ids: - if token_id not in NON_USER_INPUT_TOKENS: - has_user_text_input = True - break - input_ids.extend(curr_token_ids) - if has_audio and has_imag and has_user_text_input: - raise ValueError( - "Phi4MMForCausalLM does not support text + audio + image" + - " inputs in the same prompt") - # Handle the case where the prompt is already tokenized - else: - assert prompt_token_ids is not None, \ - "If string prompt isn't provided, prompt_token_ids must be" - - i = 0 - input_ids = prompt_token_ids - # only needed for later assertion - img_cnt, audio_cnt, user_text_input_cnt = 0, 0, 0 - image_token_count_iter = iter(image_token_counts) - audio_embed_size_iter = iter(audio_embed_sizes) - while i < len(input_ids): - token_id = input_ids[i] - if token_id == _AUDIO_PLACEHOLDER_TOKEN_ID: - token_count = next(audio_embed_size_iter) - audio_cnt += 1 - elif token_id == _IMAGE_PLACEHOLDER_TOKEN_ID: - token_count = next(image_token_count_iter) - img_cnt += 1 - else: - user_text_input_cnt += 1 if token_id not in \ - NON_USER_INPUT_TOKENS else 0 - i += 1 - continue - tokens = [token_id] * token_count - input_ids = input_ids[:i] + tokens + input_ids[i + 1:] - i += token_count - - if audio_cnt > 0 and img_cnt > 0 and user_text_input_cnt > 0: - raise ValueError( - "Phi4MMForCausalLM does not support text + audio + image" + - " inputs in the same prompt") - # If the below assertion fails, it might be that input pure-text - # messages contain image/audio special tokens literally - # (<|endoftext10|>, <|endoftext11|>). - assert (img_cnt == len(image_token_counts)), ( - f"Number of image tokens in prompt_token_ids ({img_cnt}) " - f"does not match number of images ({len(image_token_counts)})") - assert (audio_cnt == len(audio_embed_sizes)), ( - f"Number of audio tokens in prompt_token_ids ({audio_cnt}) " - f"does not match number of audios ({len(audio_embed_sizes)})") - - # NOTE: Create a defensive copy of the original inputs - return token_inputs( - prompt_token_ids=input_ids, - prompt=prompt_str, - multi_modal_data=multi_modal_data, - ) + def _compute_num_image_tokens( + self, + orig_width: int, + orig_height: int, + dynamic_hd_size: int, + vit_image_size: int, + vit_patch_size: int, + token_compression_factor: int = 2, + ): + """ + compute the number of tokens an image is expected to take up considering + the image encoder architecture and exclude output features containing + only padding pixels + for siglip, vit_image_size=448, vit_patch_size=14, so output will be + 32x32 feature map + NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 + """ + assert vit_image_size % vit_patch_size == 0, ( + "vit_image_size must be divisible by vit_patch_size") + assert (vit_image_size // vit_patch_size % + token_compression_factor == 0), ( + "vit_image_size // vit_patch_size must be divisible by " + "token_compression_factor") + + target_aspect_ratio, target_height, target_width = ( + self._find_target_aspect_ratio(orig_width, + orig_height, + vit_image_size, + dynamic_hd_size, + min_num=1)) + assert target_aspect_ratio[0] * vit_image_size == target_width, ( + f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}") + assert target_aspect_ratio[1] * vit_image_size == target_height, ( + f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}") + assert (target_height % vit_image_size == 0 + and target_width % vit_image_size == 0) + + padding_height, padding_width = _get_padding_size( + orig_width, orig_height, target_height, target_width) + assert padding_width == 0 or padding_height == 0, \ + "padding_width or padding_height must be 0" + + target_feat_width = target_width // vit_patch_size + target_feat_height = target_height // vit_patch_size + if padding_width >= vit_patch_size: + assert padding_height == 0, "padding_height not 0" + non_pad_feat_width = target_feat_width - math.floor( + padding_width / vit_patch_size) + non_pad_feat_height = target_feat_height + elif padding_height >= vit_patch_size: + assert padding_width == 0, "padding_width not 0" + non_pad_feat_height = target_feat_height - math.floor( + padding_height / vit_patch_size) + non_pad_feat_width = target_feat_width + else: + # small padding shorter than a vit patch + non_pad_feat_width = target_feat_width + non_pad_feat_height = target_feat_height + + feat_width = non_pad_feat_width // token_compression_factor + feat_height = non_pad_feat_height // token_compression_factor + # NOTE it's possible that the non-padding feature is not divisible + if non_pad_feat_width % token_compression_factor != 0: + feat_width += 1 + if non_pad_feat_height % token_compression_factor != 0: + feat_height += 1 + num_hd_patch_tokens = feat_width * feat_height + num_hd_newline_tokens = feat_height + vit_feature_size = vit_image_size // vit_patch_size + num_global_image_tokens = (vit_feature_size // + token_compression_factor)**2 + num_sep_tokens = 1 + num_global_image_newline_tokens = \ + vit_feature_size // token_compression_factor + + return (num_global_image_tokens + num_sep_tokens + + num_hd_patch_tokens + num_hd_newline_tokens + + num_global_image_newline_tokens) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[ProcessorMixin] = None, + ) -> int: + hf_config = self.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[ + vision_encoder_name] + vit_image_size = prepro_config['vit_image_size'] + vit_patch_size = prepro_config['vit_patch_size'] + token_compression_factor = prepro_config['token_compression_factor'] + + dynamic_hd_size = self.get_dynamic_hd(processor=processor) + + image_num_tokens = self._compute_num_image_tokens( + image_width, + image_height, + dynamic_hd_size=dynamic_hd_size, + vit_image_size=vit_image_size, + vit_patch_size=vit_patch_size, + token_compression_factor=token_compression_factor, + ) -def _compute_audio_embed_size(hf_config, audio_frames): - """ - Compute the audio embedding size based on the audio frames and - compression rate. - """ - compression_rate = hf_config.embd_layer['audio_embd_layer'][ - 'compression_rate'] - # NOTE: this is a hard-coded value but might be configurable in the future - qformer_compression_rate = 1 - integer = audio_frames // compression_rate - remainder = audio_frames % compression_rate + return image_num_tokens - result = integer if remainder == 0 else integer + 1 + def get_image_size_with_most_features( + self, + processor: Optional[ProcessorMixin] = None, + ) -> ImageSize: + hf_config = self.get_hf_config() + vision_encoder_name = hf_config.img_processor + if vision_encoder_name is None: + vision_encoder_name = SIGLIP_NAME + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[ + vision_encoder_name] + vit_image_size = prepro_config['vit_image_size'] + + max_side = vit_image_size * self.get_dynamic_hd(processor=processor) + return ImageSize(height=max_side, width=vit_image_size) + + def get_audio_num_frames(self, audio_len: int, sr: float) -> int: + """ + Compute the output size of the `extract_features` method. - integer = result // qformer_compression_rate - remainder = result % qformer_compression_rate - result = integer if remainder == 0 else integer + 1 # qformer compression + Args: + audio_len (int): Length of the input waveform in samples. + sr (float): Sampling rate of the waveform, either 16000 or 8000. - return result + Returns: + tuple (int, int): Output size as (T, D), where: + T: Number of time frames. + D: Number of Mel filterbank bins (80). + """ + # Resample to 16000 or 8000 if needed + if sr > 16000: + audio_len //= sr // 16000 + elif 8000 <= sr < 16000: + # We'll resample to 16K from 8K + audio_len *= 2 + elif sr < 8000: + raise RuntimeError(f"Unsupported sample rate {sr}") + + # Spectrogram parameters for 16 kHz + win_length = 400 # Frame length in samples + hop_length = 160 # Frame shift in samples + + # Calculate number of frames (T) + num_frames = (audio_len - win_length) // hop_length + 1 + if num_frames < 1: + raise ValueError("Waveform too short for given parameters.") + + # Return time frames (T) + return num_frames + + def _compute_audio_embed_size(self, audio_frames: int) -> int: + """ + Compute the audio embedding size based on the audio frames and + compression rate. + """ + hf_config = self.get_hf_config() + compression_rate = hf_config.embd_layer['audio_embd_layer'][ + 'compression_rate'] + # NOTE: this is a hard-coded value but might be configurable + # in the future + qformer_compression_rate = 1 + integer = audio_frames // compression_rate + remainder = audio_frames % compression_rate -def get_max_phi4mm_audio_tokens(ctx: InputContext) -> int: - return 10000 + result = integer if remainder == 0 else integer + 1 + integer = result // qformer_compression_rate + remainder = result % qformer_compression_rate + # qformer compression + result = integer if remainder == 0 else integer + 1 -def dummy_audio_for_phi4mm(audio_count: int) -> dict: - """ - Create dummy audio data for the Phi4MM model, which is used for profiling. + return result - Args: - audio_count (int): Number of audio samples. - Returns: - dict: Dummy audio data. - """ - dummy_audio = np.full((_AUDIO_MAX_SOUNDFILE_SIZE, ), 0.0) - return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count +class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + num_images = mm_counts.get("image", 0) -def dummy_image_for_phi4mm(width: int, height: int): - image = Image.new('RGB', (width, height), color='black') - return image + image_tokens: list[str] = self.info.image_tokens[:num_images] + audio_tokens: list[str] = self.info.audio_tokens[:num_audios] + return "".join(image_tokens + audio_tokens) -def dummy_data_for_phi4mm(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]) -> DummyData: - """ - Create dummy sequence (input_ids) and audio data for the Phi4MM model, - which is used for profiling. + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + num_images = mm_counts.get("image", 0) - In this case, the sequence data is a bunch of 0s with a number of audio - tokens that correspond to the audio embed size of the - _AUDIO_MAX_SOUNDFILE_SIZE. + target_width, target_height = \ + self.info.get_image_size_with_most_features() - Args: - ctx (InputContext): Input context. - seq_len (int): Length of the sequence. - mm_counts (Mapping[str, int]): Multi-modal counts. + target_width, target_height = \ + self.info.get_image_size_with_most_features() - Returns: - Tuple: Dummy sequence data and dummy audio data. - """ - audio_count = mm_counts["audio"] - audio_frames, _ = compute_logfbank_output_size(_AUDIO_MAX_SOUNDFILE_SIZE, - DUMMY_SAMPLING_FREQUENCY) - audio_feature_size = _compute_audio_embed_size(ctx.get_hf_config(), - audio_frames) - - image_count = mm_counts["image"] - dummy_image = get_max_dummy_image(ctx) - max_image_tokens = get_max_phi4mm_image_tokens(ctx) - total_image_tokens = image_count * max_image_tokens - - if seq_len - audio_feature_size * audio_count - total_image_tokens < 0: - raise RuntimeError( - f"Phi4MM cannot process {audio_count} audios and {image_count}" - f"images in a prompt, please increase max_model_len to be at" - f" larger than " - f"{audio_feature_size * audio_count + total_image_tokens}" - " or reduce audio/image limit by --limit-mm-per-prompt.") - - if audio_feature_size * audio_count > total_image_tokens: - seq_data = SequenceData.from_prompt_token_counts( - (_AUDIO_PLACEHOLDER_TOKEN_ID, audio_feature_size * audio_count), - (0, seq_len - audio_feature_size * audio_count), - ) mm_data = { - "audio": dummy_audio_for_phi4mm(audio_count), + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "audio": + self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE, + num_audios=num_audios), } - else: - seq_data = SequenceData.from_prompt_token_counts( - (_IMAGE_PLACEHOLDER_TOKEN_ID, total_image_tokens), - (0, seq_len - total_image_tokens), - ) - mm_data = { - "image": [dummy_image] * image_count, - } - return DummyData(seq_data, mm_data) + return mm_data -def input_mapper_for_phi4mm_audio(ctx: InputContext, - data: object) -> MultiModalKwargs: - """ - This function is used to create the MultiModalKwargs for the Phi4MM - (audio) model. - Specifically, for audio, we extract the audio features from the sound - file and create pairs of audio features and audio embed lengths (the - latter of which is used to repeat the audio placeholder token in the - input prompt IDs). - These pairs are used, downstream, in `_audio_features_to_embeddings` - (via `_process_audio_input`). - - Note that the incoming audio data (each entry in `data`) is a tuple of - the audio data and the sampling frequency (e.g. from soundfile.read). - - Args: - ctx (InputContext): Input context. - data (object): Audio data. - - Returns: - MultiModalKwargs: Multi-modal inputs. - """ - if not isinstance(data, list): - data = [data] - - if len(data) == 0: - return MultiModalKwargs() - - audio_features = [] - for audio_input in data: - if not isinstance(audio_input, tuple): - raise NotImplementedError( - f"Unsupported data type: {type(audio_input)}") - - audio, sf = audio_input - feature_extractor = audio_feature_extractor() - single_audio_features = feature_extractor.extract_features(audio, sf) - feat_stride = (1 if not hasattr(feature_extractor, "stride") else - feature_extractor.stride) - audio_frames = len(single_audio_features) * feat_stride - single_audio_embed_size = _compute_audio_embed_size( - ctx.get_hf_config(), audio_frames) - single_audio_feature_audio_len_pair = ( - single_audio_features, - [single_audio_embed_size], - ) - audio_features.append(single_audio_feature_audio_len_pair) - return MultiModalKwargs({"audio_features": audio_features}) - - -def input_mapper_for_phi4mm_image(ctx: InputContext, data: object): - if not isinstance(data, list): - data = [data] - # data: list of PIL images - if len(data) == 0: - return MultiModalKwargs() - hf_config = ctx.get_hf_config() - vision_encoder_name = hf_config.img_processor - if vision_encoder_name is None: - vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] - dynamic_hd_size = prepro_config['dynamic_hd'] - vit_image_size = prepro_config['vit_image_size'] - vit_patch_size = prepro_config['vit_patch_size'] - - image_input_dict = preprocess(data, dynamic_hd_size, vit_image_size, - vit_patch_size) - return MultiModalKwargs({ - "pixel_values": - image_input_dict["pixel_values"], - "image_sizes": - image_input_dict["image_sizes"], - "image_attention_mask": - image_input_dict["image_attention_mask"], - "num_img_tokens": - image_input_dict["num_img_tokens"], - }) +class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): -def cat_with_pad(tensors, dim, padding_value=0): - """ - cat along dim, while pad to max for all other dims - """ - ndim = tensors[0].dim() - assert all( - t.dim() == ndim for t in - tensors[1:]), "All tensors must have the same number of dimensions" + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate, + audio_resample_method="scipy") - out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] - out_size[dim] = sum(t.shape[dim] for t in tensors) - output = tensors[0].new_full(out_size, padding_value) + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + sr = self.info.get_feature_extractor().sampling_rate + if (audio_data := mm_data.get("audios", [])): + mm_data['audios'] = [(data, sr) for data in audio_data] + + processed_outputs = super()._call_hf_processor(prompt, mm_data, + mm_kwargs) + + num_img_tokens = [ + self.info.get_num_image_tokens(image_width=img_size[0], + image_height=img_size[1]) + for img_size in processed_outputs["image_sizes"] + ] + processed_outputs["num_img_tokens"] = num_img_tokens - index = 0 - for t in tensors: - # Create a slice list where every dimension except dim is full slice - slices = [slice(0, t.shape[d]) for d in range(ndim)] - # Update only the concat dimension slice - slices[dim] = slice(index, index + t.shape[dim]) + audio_features = processed_outputs['input_audio_embeds'] + feature_sizes = [ + self.info.get_audio_num_frames(len(audio), sr) + for audio in audio_data + ] + processed_outputs['input_audio_embeds'] = [ + audio_features[idx, :size] + for idx, size in enumerate(feature_sizes) + ] - output[slices] = t - index += t.shape[dim] + return processed_outputs - return output + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_image_embeds=MultiModalFieldConfig.batched("image"), + image_attention_mask=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + num_img_tokens=MultiModalFieldConfig.batched("image"), + input_audio_embeds=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + image_tokens: list[str] = self.info.image_tokens # type: ignore + audio_tokens: list[str] = self.info.audio_tokens # type: ignore + feature_extractor = self.info.get_feature_extractor() + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + def get_image_replacement_phi4mm(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + + image_tokens = [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens + + return image_tokens + + def get_audio_replacement_phi4mm(item_idx: int): + audios = mm_items.get_items("audio", AudioProcessorItems) + # TODO(Isotr0py): support embedding inputs + audio_len = audios.get_audio_length(item_idx) + audio_frames = self.info.get_audio_num_frames( + audio_len, feature_extractor.sampling_rate) + audio_embed_size = self.info._compute_audio_embed_size( + audio_frames) + + audio_tokens = [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size + + return audio_tokens + + num_images = mm_items.get_count("image", strict=False) + num_audios = mm_items.get_count("audio", strict=False) + + image_repl = [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_image_replacement_phi4mm, + ) for image_token in image_tokens[:num_images] + ] + audio_repl = [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_audio_replacement_phi4mm, + ) for audio_token in audio_tokens[:num_audios] + ] + return image_repl + audio_repl -@MULTIMODAL_REGISTRY.register_input_mapper("audio", - input_mapper_for_phi4mm_audio) -@MULTIMODAL_REGISTRY.register_input_mapper("image", - input_mapper_for_phi4mm_image) -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "audio", get_max_phi4mm_audio_tokens) -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "image", get_max_phi4mm_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4mm) -@INPUT_REGISTRY.register_input_processor(input_processor_for_phi4mm) -class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal, - SupportsV0Only): +@MULTIMODAL_REGISTRY.register_processor( + Phi4MMMultiModalProcessor, + info=Phi4MMProcessingInfo, + dummy_inputs=Phi4MMDummyInputsBuilder, +) +class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): """ Implements the Phi-4-multimodal-instruct model in vLLM. """ @@ -1518,48 +991,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() - - def _audio_features_to_embeddings( - self, - input_ids: torch.Tensor, - input_features: List[torch.Tensor], - audio_input_sizes: torch.Tensor, - audio_projection_mode: str, - ) -> torch.Tensor: - """ - Convert audio features to embeddings, which are used as input to the - model (via `inputs_embeds`). - - Args: - input_ids (torch.Tensor): Input IDs (the prompt in this case). - input_features (list[torch.Tensor]): Input features (the audio - embeddings). - audio_input_sizes (list[torch.Tensor]): Audio input sizes (the - audio embed lengths to use for padding the audio placeholder token - in the input prompt IDs). - """ - # The audio projection can either be a single linear or Sequential, - # so handle both cases - if isinstance(self.embed_tokens_extend.audio_projection, - nn.Sequential): - target_dtype = self.embed_tokens_extend.audio_projection[ - 0].bias.dtype - else: - target_dtype = self.embed_tokens_extend.audio_projection.bias.dtype - - audio_input = [ - input.unsqueeze(0).to(target_dtype) for input in input_features - ] - kwargs = { - "wte": self.model.embed_tokens, - 'audio_projection_mode': audio_projection_mode - } - audio_embeddings = self.embed_tokens_extend(input_ids, audio_input, - audio_input_sizes, - **kwargs) - audio_embeddings = audio_embeddings.to(target_dtype) - return audio_embeddings + self.sampler = get_sampler() def _parse_and_validate_audio_input( self, **kwargs: object) -> Optional[Phi4MMAudioInputs]: @@ -1574,7 +1006,7 @@ def _parse_and_validate_audio_input( Returns: Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs. """ - audio_features = kwargs.pop("audio_features", None) + audio_features = kwargs.pop("input_audio_embeds", None) audio_embeds = kwargs.pop("audio_embeds", None) if audio_features is None and audio_embeds is None: @@ -1586,7 +1018,7 @@ def _parse_and_validate_audio_input( f"Got type: {type(audio_features)}") return Phi4MMAudioFeatureInputs(type="audio_features", - data=audio_features) + data=flatten_bn(audio_features)) if audio_embeds is not None: if not isinstance(audio_embeds, (torch.Tensor, list)): @@ -1598,8 +1030,7 @@ def _parse_and_validate_audio_input( raise AssertionError("This line should be unreachable.") - def _process_audio_input(self, input_ids: torch.Tensor, - audio_input: Phi4MMAudioInputs, + def _process_audio_input(self, audio_input: Phi4MMAudioInputs, audio_projection_mode: str) -> NestedTensors: """ Create the audio embeddings from the audio input, where the audio input @@ -1607,8 +1038,6 @@ def _process_audio_input(self, input_ids: torch.Tensor, created by `input_mapper_for_phi4mm_audio`. Args: - input_ids (torch.Tensor): Input IDs (the prompt in this case, - before the audio token replication). audio_input (Phi4MMAudioInputs): Audio input. Returns: @@ -1620,21 +1049,20 @@ def _process_audio_input(self, input_ids: torch.Tensor, audio_features = audio_input["data"] # (e.g. multiple examples) and the second dim is the multi-audio dim # (e.g. multiple audios in the same example) - audio_feature = [i[0] for j in audio_features for i in j] - audio_feature_len = [i[1].item() for j in audio_features for i in j] - # Add the batch dim via `squeeze` - return self._audio_features_to_embeddings( - input_ids.unsqueeze(0), - audio_feature, - audio_feature_len, - audio_projection_mode, - ).squeeze(0) + dtype = next(self.embed_tokens_extend.parameters()).dtype + audio_embeds = [ + self.embed_tokens_extend( + features.to(dtype), + audio_projection_mode=audio_projection_mode, + ) for features in audio_features + ] + return audio_embeds def _parse_and_validate_image_input(self, **kwargs: object) -> Optional[Dict]: - pixel_values: Optional[Dict] = kwargs.get("pixel_values") - if pixel_values is None: + input_image_embeds: NestedTensors = kwargs.get("input_image_embeds") + if input_image_embeds is None: return None image_sizes = kwargs.get("image_sizes") @@ -1643,23 +1071,24 @@ def _parse_and_validate_image_input(self, assert image_sizes is not None and image_attention_mask is not None\ and num_img_tokens is not None, "Missing image inputs" - if isinstance(pixel_values, list): - assert pixel_values[0].dim() == 5, "Incorrect image inputs" + if is_list_of(input_image_embeds, torch.Tensor): + assert all(p.dim() == 5 + for p in input_image_embeds), "Incorrect image inputs" # list len is batch_size. # each tensor has dimension: num_img_per_example, num_hd_patches, # channels, height, width. # need to pad along num_hd_patches. # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w. - pixel_values = cat_with_pad(pixel_values, dim=0) - elif isinstance(pixel_values, torch.Tensor): + input_image_embeds = cat_with_pad(input_image_embeds, dim=0) + elif isinstance(input_image_embeds, torch.Tensor): # dimension: batch_size, num_img_per_example, num_hd_patches, # channels, height, width. # we flatten first 2 dims to make it a single large batch for # SigLIP Encoder. - assert pixel_values.dim() == 6, "Incorrect image inputs" - pixel_values = pixel_values.flatten(0, 1) + assert input_image_embeds.dim() == 6, "Incorrect image inputs" + input_image_embeds = input_image_embeds.flatten(0, 1) else: - raise ValueError("Incorrect pixel_values inputs") + raise ValueError("Incorrect input_image_embeds inputs") if isinstance(image_attention_mask, list): image_attention_mask = cat_with_pad(image_attention_mask, dim=0) @@ -1685,80 +1114,140 @@ def _parse_and_validate_image_input(self, else: raise ValueError("Incorrect image_attention_mask inputs") - return { - 'pixel_values': pixel_values, - 'image_sizes': image_sizes, - 'image_attention_mask': image_attention_mask, - 'num_img_tokens': num_img_tokens, - } + return Phi4MMImagePixelInputs( + type="pixel_values", + data=input_image_embeds, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + num_img_tokens=num_img_tokens, + ) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("input_image_embeds", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("input_audio_embeds", + "audio_embeds") and "audios" not in modalities: + modalities["audios"] = self._parse_and_validate_audio_input( + **kwargs) + + return modalities + + def _process_image_input( + self, image_input: Phi4MMImagePixelInputs) -> list[torch.Tensor]: + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + dtype = next(self.vision_encoder.parameters()).dtype + pixel_values = image_input['data'].to(dtype) + image_sizes = image_input['image_sizes'] + image_attention_mask = image_input['image_attention_mask'] + image_embeds = self.vision_encoder(pixel_values, image_sizes, + image_attention_mask) + return image_embeds + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None - def merge_image_features_to_inputs_embeds( + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + audio_projection_mode = 'speech' + for modality in modalities: + # make sure process images first + if modality == "images": + audio_projection_mode = "vision" + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(vision_embeddings) + if modality == "audios": + audio_input = modalities["audios"] + audio_embeddings = self._process_audio_input( + audio_input, audio_projection_mode=audio_projection_mode) + multimodal_embeddings += tuple(audio_embeddings) + + return multimodal_embeddings + + def get_input_embeddings( self, input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - image_set_tensors: List[torch.Tensor], - ): - position_tuple = (input_ids == _IMAGE_PLACEHOLDER_TOKEN_ID).nonzero( - as_tuple=True) - - assert all([t.shape[0] == 1 for t in image_set_tensors - ]), 'img_set_tensor should have shape (1, N_tokens, C)' - # Shape: (merged_N_tokens, C) - image_set_tensor = torch.cat(image_set_tensors, dim=1).squeeze(0) - image_set_tensor = image_set_tensor.to(inputs_embeds.dtype).to( - inputs_embeds.device) - merged_embeds = inputs_embeds.index_put( - indices=position_tuple, - values=image_set_tensor, - accumulate=False, - ) - return merged_embeds + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.embed_tokens(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID]) + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + image_input: Optional[Phi4MMImagePixelInputs] = None, + audio_input: Optional[Phi4MMAudioFeatureInputs] = None, + ) -> torch.Tensor: + audio_projection_mode = 'speech' + inputs_embeds = self.get_input_embeddings(input_ids) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=_IMAGE_PLACEHOLDER_TOKEN_ID, + ) + audio_projection_mode = 'vision' + + if audio_input is not None: + audio_embeds = self._process_audio_input( + audio_input, audio_projection_mode=audio_projection_mode) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + audio_embeds, + placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN_ID, + ) + return inputs_embeds def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> torch.Tensor: if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - # Each entry in this is a pair of audio_features and audio_embed - # lengths + + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs) - image_inputs = self._parse_and_validate_image_input(**kwargs) - - has_audio = audio_input is not None - has_image = image_inputs is not None - - if has_audio: - audio_projection_mode = 'vision' if has_image else 'speech' - inputs_embeds = self._process_audio_input( - input_ids, audio_input, audio_projection_mode) - - if has_image: - dtype = self.vision_encoder.img_processor.embeddings.\ - patch_embedding.weight.dtype - pixel_values = image_inputs['pixel_values'].to(dtype) - image_sizes = image_inputs['image_sizes'] - image_attention_mask = image_inputs['image_attention_mask'] - image_set_tensors = self.vision_encoder( - pixel_values, image_sizes, image_attention_mask) - if not has_audio: - inputs_embeds = self.model.embed_tokens(input_ids) - - inputs_embeds = self.merge_image_features_to_inputs_embeds( - input_ids, inputs_embeds, image_set_tensors) - - if has_image or has_audio: - # multi-modal input, we have set inputs_embeds properly in - # previous steps - input_ids = None - else: - # text-only, we keep using original input_ids + + if image_input is None and audio_input is None: inputs_embeds = None + else: + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + audio_input=audio_input) + input_ids = None hidden_states = self.model( input_ids, diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py index db90848f9809..34a7a73d057a 100644 --- a/vllm/model_executor/models/phi4mm_audio.py +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -1159,8 +1159,11 @@ def get_audio_features( input_embeds: torch.FloatTensor, audio_attention_mask: torch.Tensor = None, audio_projection_mode: str = "speech", - ): - + ) -> torch.FloatTensor: + """ + arguments: + input_embeds: audio features (B, T, D) B: num audios in a sequence + """ if self.freeze_audio_processor: with torch.no_grad(): audio_features, masks = self.encoder(input_embeds, @@ -1210,62 +1213,20 @@ def get_audio_features( def forward( self, - input_ids: torch.LongTensor, - input_embeds: torch.FloatTensor, - audio_embed_sizes, - **kwargs, + audio_features: torch.FloatTensor, + audio_attention_mask: torch.Tensor = None, + audio_projection_mode: str = "speech", ) -> torch.FloatTensor: """ arguments: - input_ids: input text ids (B, U) - input_embeds: audio features (B, T, D) B: num audios in a sequence + audio_features: audio features (T, D) + + returns: + audio_embeds: audio embeddings (num_audio_tokens, hidden_dim) """ - assert input_embeds is not None and len(input_embeds) == len( - audio_embed_sizes) - - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - with torch.no_grad(): - positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero( - as_tuple=False) - - if not isinstance(input_embeds, list): - input_embeds = [input_embeds] - - audio_projection_mode = kwargs.get("audio_projection_mode", "speech") - audio_set_tensor = [ - self.get_audio_features( - input_embed, audio_projection_mode=audio_projection_mode) - for input_embed in input_embeds - ] - - with torch.no_grad(): - input_ids.clamp_min_(0).clamp_max_(self.vocab_size) - - if "wte" in kwargs: - # we use the token embedding layer from the huggingface model, this - # is REQUIRED to make sure we are using the loaded weights. - hidden_states = kwargs["wte"](input_ids) - else: - # otherwise, we use token embedding in pretrained mixformer from - # phi team - hidden_states = self.wte(input_ids) - - if len(positions.tolist()) > 0: - assert sum(audio_embed_sizes) == len( - positions - ), "please ensure the encoder outputs have the same length as"\ - " defined in input_ids!" - idx = 0 - for i in range(len(audio_embed_sizes)): - cnt = audio_embed_sizes[i] - assert audio_set_tensor[i].shape[0] == 1 - hidden_states[ - positions[idx, 0], - positions[idx, 1]:positions[idx, 1] + cnt, - ] = (audio_set_tensor[i][0, :audio_embed_sizes[i], :].to( - hidden_states.dtype).to(hidden_states.device)) - idx += cnt - - return hidden_states + audio_embeds = self.get_audio_features( + audio_features.unsqueeze(0), + audio_attention_mask=audio_attention_mask, + audio_projection_mode=audio_projection_mode, + ) + return audio_embeds.squeeze(0) diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index f379ec1682a3..70a912c9c9ef 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 - import base64 from io import BytesIO from pathlib import Path +from typing import Literal, Optional import numpy as np import numpy.typing as npt @@ -43,7 +43,7 @@ def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: "There is no default maximum multimodal tokens") -def resample_audio( +def resample_audio_librosa( audio: npt.NDArray[np.floating], *, orig_sr: float, @@ -52,6 +52,55 @@ def resample_audio( return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) +def resample_audio_scipy( + audio: npt.NDArray[np.floating], + *, + orig_sr: float, + target_sr: float, +): + # lazy import scipy.signal, otherwise it will crash doc build. + import scipy.signal + + if orig_sr > target_sr: + return scipy.signal.resample_poly(audio, 1, orig_sr // target_sr) + elif orig_sr < target_sr: + return scipy.signal.resample_poly(audio, target_sr // orig_sr, 1) + return audio + + +class AudioResampler: + """Resample audio data to a target sample rate.""" + + def __init__( + self, + target_sr: Optional[float] = None, + method: Literal["librosa", "scipy"] = "librosa", + ): + self.target_sr = target_sr + self.method = method + + def resample( + self, + audio: npt.NDArray[np.floating], + *, + orig_sr: float, + ) -> npt.NDArray[np.floating]: + if self.target_sr is None: + raise RuntimeError("Audio resampling is not supported when " + "`target_sr` is not provided") + if self.method == "librosa": + return resample_audio_librosa(audio, + orig_sr=orig_sr, + target_sr=self.target_sr) + elif self.method == "scipy": + return resample_audio_scipy(audio, + orig_sr=orig_sr, + target_sr=self.target_sr) + else: + raise ValueError(f"Invalid resampling method: {self.method}. " + "Supported methods are 'librosa' and 'scipy'.") + + class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]: diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index fc5a294564e3..9707b9cfcf8b 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -3,8 +3,8 @@ from abc import ABC, abstractmethod from collections import UserDict from collections.abc import Callable, Iterator, Mapping, Sequence -from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar, - Union) +from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional, + TypeVar, Union) import numpy as np import torch @@ -14,7 +14,7 @@ from vllm.utils import is_list_of -from .audio import resample_audio +from .audio import AudioResampler from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs, VideoItem) @@ -308,10 +308,18 @@ class MultiModalDataParser: items to the model's expected sampling rate. """ - def __init__(self, *, target_sr: Optional[float] = None) -> None: + def __init__( + self, + *, + target_sr: Optional[float] = None, + audio_resample_method: Literal["librosa", "scipy"] = "librosa", + ) -> None: super().__init__() - self.target_sr = target_sr + self.audio_resampler = AudioResampler( + target_sr=target_sr, + method=audio_resample_method, + ) def _is_embeddings( self, data: object @@ -374,15 +382,8 @@ def _parse_audio_data( if orig_sr is None: new_audio = audio else: - target_sr = self.target_sr - if target_sr is None: - raise RuntimeError( - "Audio resampling is not supported when " - "`target_sr` is not provided") - - new_audio = resample_audio(audio, - orig_sr=orig_sr, - target_sr=target_sr) + new_audio = self.audio_resampler.resample(audio, + orig_sr=orig_sr) new_audios.append(new_audio)