diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 98f18319f1e4..1b80c801d5be 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -1004,7 +1004,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `microsoft/Phi-4-multimodal-instruct`, etc.
* ✅︎
*
- *
+ * ✅︎
- * `PixtralForConditionalGeneration`
* Pixtral
* T + I+
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
index 077b5c762a25..e3c75d5cb6a9 100644
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -89,7 +89,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
engine_args = EngineArgs(
model=model_path,
trust_remote_code=True,
- max_model_len=4096,
+ max_model_len=12800,
max_num_seqs=2,
enable_lora=True,
max_lora_rank=320,
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 510a043ce421..bd7035b7615a 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -814,10 +814,13 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
engine_args = EngineArgs(
model=model_path,
trust_remote_code=True,
- max_model_len=4096,
+ max_model_len=5120,
max_num_seqs=2,
+ max_num_batched_tokens=12800,
enable_lora=True,
max_lora_rank=320,
+ # Note - mm_processor_kwargs can also be passed to generate/chat calls
+ mm_processor_kwargs={"dynamic_hd": 16},
limit_mm_per_prompt={"image": 1},
)
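The comment added above points at per-request overrides of the processor kwargs. Below is a minimal offline sketch of that pattern; the prompt, image path, and sampling settings are illustrative, the vision-LoRA request from the full example is omitted for brevity, and the per-request `mm_processor_kwargs` key is assumed to be supported in the prompt dict.

```python
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="microsoft/Phi-4-multimodal-instruct",
    trust_remote_code=True,
    max_model_len=5120,
    max_num_seqs=2,
    enable_lora=True,
    max_lora_rank=320,
    limit_mm_per_prompt={"image": 1},
)

outputs = llm.generate(
    {
        "prompt": "<|user|>\n<|image_1|>\nDescribe the image.<|end|>\n<|assistant|>\n",
        "multi_modal_data": {"image": Image.open("example.jpg")},  # illustrative path
        # Per-request override of the engine-level mm_processor_kwargs.
        "mm_processor_kwargs": {"dynamic_hd": 16},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```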
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index e2e14d16228a..f165ea9efa10 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -503,11 +503,13 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
engine_args = EngineArgs(
model=model_path,
trust_remote_code=True,
- max_model_len=10000,
+ max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)},
enable_lora=True,
max_lora_rank=320,
+ # Note - mm_processor_kwargs can also be passed to generate/chat calls
+ mm_processor_kwargs={"dynamic_hd": 4},
)
placeholders = "".join(f"<|image_{i}|>"
diff --git a/requirements/docs.txt b/requirements/docs.txt
index 416ca503b36c..99fb87def6dd 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -18,6 +18,7 @@ transformers
mistral_common >= 1.5.4
aiohttp
starlette
+scipy
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py
index bd1dcba6a995..e9dcba8ec089 100644
--- a/tests/models/decoder_only/audio_language/test_ultravox.py
+++ b/tests/models/decoder_only/audio_language/test_ultravox.py
@@ -1,14 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
import json
-from typing import Optional
+from typing import Any, Optional
import numpy as np
import pytest
import pytest_asyncio
from transformers import AutoModel, AutoTokenizer
-from vllm.multimodal.audio import resample_audio
+from vllm.multimodal.audio import resample_audio_librosa
from vllm.sequence import SampleLogprobs
from ....conftest import HfRunner, VllmRunner
@@ -43,6 +43,18 @@ def audio(request):
return AudioAsset(request.param)
+def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]:
+ """Convert kwargs to CLI args."""
+ args = []
+ for key, value in params_kwargs.items():
+ if isinstance(value, bool):
+ if value:
+ args.append(f"--{key.replace('_','-')}")
+ else:
+ args.append(f"--{key.replace('_','-')}={value}")
+ return args
+
+
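For reference, a quick sketch of what the helper above produces: boolean kwargs become bare flags (matching `store_true` CLI options) and everything else becomes `--key=value`. The kwargs dict here is illustrative, not the real `CHUNKED_PREFILL_KWARGS`.

```python
example_kwargs = {"enable_chunked_prefill": True, "max_num_seqs": 2}

assert params_kwargs_to_cli_args(example_kwargs) == [
    "--enable-chunked-prefill",
    "--max-num-seqs=2",
]
```

A boolean set to `False` yields no argument at all, which is the behavior `store_true` flags expect.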
@pytest.fixture(params=[
pytest.param({}, marks=pytest.mark.cpu_model),
pytest.param(CHUNKED_PREFILL_KWARGS),
@@ -52,10 +64,7 @@ def server(request, audio_assets):
"--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
"--limit-mm-per-prompt",
json.dumps({"audio": len(audio_assets)}), "--trust-remote-code"
- ] + [
- f"--{key.replace('_','-')}={value}"
- for key, value in request.param.items()
- ]
+ ] + params_kwargs_to_cli_args(request.param)
with RemoteOpenAIServer(MODEL_NAME,
args,
@@ -136,9 +145,9 @@ def run_test(
[hf_prompt],
max_tokens,
num_logprobs=num_logprobs,
- audios=[(resample_audio(audio[0],
- orig_sr=audio[1],
- target_sr=16000), 16000)])
+ audios=[(resample_audio_librosa(audio[0],
+ orig_sr=audio[1],
+ target_sr=16000), 16000)])
for _, hf_prompt, audio in prompts_and_audios
]
diff --git a/tests/models/decoder_only/vision_language/test_phi4mm.py b/tests/models/decoder_only/vision_language/test_phi4mm.py
index 3cd830015076..11460a1a8d2b 100644
--- a/tests/models/decoder_only/vision_language/test_phi4mm.py
+++ b/tests/models/decoder_only/vision_language/test_phi4mm.py
@@ -181,7 +181,7 @@ def patch_hf_processor(*args,
],
)
@pytest.mark.parametrize("dtype", [target_dtype])
-@pytest.mark.parametrize("max_model_len", [4096])
+@pytest.mark.parametrize("max_model_len", [12800])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
@@ -225,7 +225,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
],
)
@pytest.mark.parametrize("dtype", [target_dtype])
-@pytest.mark.parametrize("max_model_len", [10000])
+@pytest.mark.parametrize("max_model_len", [25600])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
@@ -258,7 +258,7 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
-@pytest.mark.parametrize("max_model_len", [10000])
+@pytest.mark.parametrize("max_model_len", [12800])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index a8f21ff919b7..d56638f051f2 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -274,6 +274,7 @@ def _test_processing_correctness_mistral(
"nvidia/NVLM-D-72B",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
+ "microsoft/Phi-4-multimodal-instruct",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b",
"Qwen/Qwen-VL-Chat",
diff --git a/tests/models/multimodal/processing/test_phi4mm.py b/tests/models/multimodal/processing/test_phi4mm.py
new file mode 100644
index 000000000000..797986adba4a
--- /dev/null
+++ b/tests/models/multimodal/processing/test_phi4mm.py
@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for phi4mm's multimodal preprocessing kwargs."""
+import pytest
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+
+from ....conftest import _ImageAssets
+from ...utils import build_model_context
+
+
+@pytest.mark.parametrize("model_id", ["microsoft/Phi-4-multimodal-instruct"])
+# yapf: disable
+@pytest.mark.parametrize(
+ ("mm_processor_kwargs", "expected_toks_per_img"),
+ [
+ ({"dynamic_hd": 4}, 1329),
+ ({"dynamic_hd": 16}, 4433),
+        # the default dynamic_hd of phi-4-multimodal is 36
+ ({}, 9585),
+ ])
+# yapf: enable
+@pytest.mark.parametrize("num_imgs", [1, 2])
+@pytest.mark.parametrize("kwargs_on_init", [True, False])
+def test_processor_override(
+ image_assets: _ImageAssets,
+ model_id: str,
+ mm_processor_kwargs: dict[str, int],
+ expected_toks_per_img: int,
+ num_imgs: int,
+ kwargs_on_init: bool,
+):
+ """Ensure Phi4MMMultiModalProcessor handles dynamic_hd properly."""
+ # Avoid initializing CUDA early
+ from vllm.model_executor.models.phi4mm import _IMAGE_PLACEHOLDER_TOKEN_ID
+
+ ctx = build_model_context(
+ model_id,
+ mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
+ limit_mm_per_prompt={"image": num_imgs},
+ )
+ processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+ hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs
+
+ # Build the image str / prompt based on the number of images we pass
+ img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
+ prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
+
+ image_size = ctx.get_hf_config(
+ ).embd_layer["image_embd_layer"]["crop_size"]
+ dummy_image_size = (image_size * 7, image_size * 7)
+ dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
+ mm_data = {"image": [dummy_image] * num_imgs}
+
+ processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
+
+    # Ensure we have the right number of placeholders per dynamic_hd setting
+ img_tok_count = processed_inputs["prompt_token_ids"].count(
+ _IMAGE_PLACEHOLDER_TOKEN_ID)
+ assert img_tok_count == expected_toks_per_img * num_imgs
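For context, the expected counts above follow from `Phi4MMProcessingInfo._compute_num_image_tokens` applied to the 3136x3136 dummy image (7x7 tiles of 448px): with a square aspect ratio, the chosen crop grids are (2, 2), (4, 4), and (6, 6) for dynamic_hd 4, 16, and 36 respectively. A small sketch of that arithmetic, assuming an un-padded square input and token_compression_factor=2:

```python
def expected_phi4mm_image_tokens(grid: int, vit_feat: int = 32) -> int:
    """Token count for an un-padded square image tiled into a grid x grid HD crop."""
    hd_tokens = (grid * vit_feat // 2) ** 2       # compressed HD patch tokens
    hd_newlines = grid * vit_feat // 2            # one newline token per HD row
    global_tokens = (vit_feat // 2) ** 2          # compressed global-image tokens
    global_newlines = vit_feat // 2
    separator = 1
    return hd_tokens + hd_newlines + global_tokens + global_newlines + separator


assert [expected_phi4mm_image_tokens(g) for g in (2, 4, 6)] == [1329, 4433, 9585]
```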
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index e505a4592e2d..0b662f1a7ec3 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -482,11 +482,8 @@ def _placeholder_str(self, modality: ModalityStr,
if modality in ("image", "image_embeds"):
if model_type == "chatglm":
return "<|begin_of_image|><|endoftext|><|end_of_image|>"
- if model_type == "phi3_v":
- # Workaround since this token is not defined in the tokenizer
+ if model_type in ("phi3_v", "phi4mm"):
return f"<|image_{current_count}|>"
- if model_type == "phi4mm":
- return "<|endoftext10|>" # 200010 (see vocab.json in hf model)
if model_type in ("minicpmo", "minicpmv"):
return "(./)"
if model_type in ("blip-2", "florence2", "fuyu", "paligemma",
@@ -522,7 +519,7 @@ def _placeholder_str(self, modality: ModalityStr,
if model_type == "ultravox":
return "<|audio|>"
if model_type == "phi4mm":
- return "<|endoftext11|>" # 200011 (see vocab.json in hf model)
+ return f"<|audio_{current_count}|>"
if model_type in ("qwen2_audio", "qwen2_5_omni"):
return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>")
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 7f41ad2359df..5b43871b7591 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -327,7 +327,7 @@ def get_num_image_tokens(
*,
image_width: int,
image_height: int,
- processor: Optional[ProcessorMixin],
+ processor: Optional[ProcessorMixin] = None,
) -> int:
if processor is None:
processor = self.get_hf_processor()
diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py
index ec19797f8875..cdd762f5fec3 100644
--- a/vllm/model_executor/models/phi4mm.py
+++ b/vllm/model_executor/models/phi4mm.py
@@ -1,41 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
import math
-import re
-from functools import lru_cache
-from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Tuple,
- TypedDict, Union)
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
import numpy as np
-import scipy.signal
import torch
import torch.nn as nn
-import torchvision.transforms as T
-from PIL import Image
-from transformers import PretrainedConfig, SiglipVisionConfig
-from transformers.utils import logging
+from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
+ SequenceFeatureExtractor, SiglipVisionConfig)
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
- InputContext)
-from vllm.inputs.data import TokenInputs, token_inputs
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
-from vllm.sequence import IntermediateTensors, SequenceData
-from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalKwargs, NestedTensors)
+from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems,
+ ImageProcessorItems, ImageSize,
+ MultiModalDataItems, MultiModalDataParser)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptReplacement,
+ PromptUpdate)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.utils import is_list_of
from .idefics2_vision_model import Idefics2VisionTransformer
-from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsV0Only
+from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .phi4mm_audio import AudioEmbedding
-from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
+ merge_multimodal_embeddings)
# <|endoftext10|> (see vocab.json in hf model)
_IMAGE_PLACEHOLDER_TOKEN_ID = 200010
@@ -43,115 +44,19 @@
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011
_AUDIO_MAX_SOUNDFILE_SIZE = 241_000
-DUMMY_SAMPLING_FREQUENCY = 16_000 # kHz
-
-DYNAMIC_HD = 16
-AUDIO_TOKEN_PATTERN = r"<\|audio_(\d+)\|>"
-IMAGE_TOKEN_PATTERN = r"<\|image_(\d+)\|>"
SIGLIP_NAME = "siglip-so400m-patch14-448"
VISION_ENCODER_TO_PROCESSING_CONFIG = {
'siglip-so400m-patch14-448': {
- 'dynamic_hd': 16,
'vit_image_size': 448,
'vit_patch_size': 14,
'token_compression_factor': 2,
},
}
-logger = logging.get_logger(__name__)
-# This is a workaround to prevent text (user input) + audio + image
-# from being used in the same prompt.
-# It includes token ids for "/n" and tokens in added_tokens_decoder
-# from the tokenizer_confg.json file.
-NON_USER_INPUT_TOKENS = {
- 198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022,
- 200023, 200024, 200025, 200026, 200027, 200028
-}
-def get_max_dummy_image(ctx: InputContext):
- hf_config = ctx.get_hf_config()
- vision_encoder_name = hf_config.img_processor
- if vision_encoder_name is None:
- vision_encoder_name = SIGLIP_NAME
- prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name]
- dynamic_hd_size = prepro_config['dynamic_hd']
- vit_image_size = prepro_config['vit_image_size']
-
- max_side = vit_image_size * dynamic_hd_size
- dummy_image = dummy_image_for_phi4mm(vit_image_size, max_side)
- return dummy_image
-
-
-# image token length
-def get_max_phi4mm_image_tokens(ctx: InputContext):
- dummy_image = get_max_dummy_image(ctx)
-
- hf_config = ctx.get_hf_config()
- vision_encoder_name = hf_config.img_processor
- if vision_encoder_name is None:
- vision_encoder_name = SIGLIP_NAME
- prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name]
- dynamic_hd_size = prepro_config['dynamic_hd']
- vit_image_size = prepro_config['vit_image_size']
- vit_patch_size = prepro_config['vit_patch_size']
- token_compression_factor = prepro_config['token_compression_factor']
-
- image_num_tokens = _compute_num_image_tokens(dummy_image, dynamic_hd_size,
- vit_image_size,
- vit_patch_size,
- token_compression_factor)
- return image_num_tokens
-
-
-def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
- image_size):
- best_ratio_diff = float('inf')
- best_ratio = (1, 1)
- area = width * height
- for ratio in target_ratios:
- target_aspect_ratio = ratio[0] / ratio[1]
- ratio_diff = abs(aspect_ratio - target_aspect_ratio)
- if ratio_diff < best_ratio_diff:
- best_ratio_diff = ratio_diff
- best_ratio = ratio
- elif ratio_diff == best_ratio_diff:
- if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
- best_ratio = ratio
- return best_ratio
-
-
-def _find_target_aspect_ratio(image, image_size, max_num, min_num):
- orig_width, orig_height = image.size
-
- w_crop_num = math.ceil(orig_width / float(image_size))
- h_crop_num = math.ceil(orig_height / float(image_size))
- if w_crop_num * h_crop_num > max_num:
- aspect_ratio = orig_width / orig_height
-
- # calculate the existing image aspect ratio
- target_ratios = set((i, j) for i in range(1, max_num + 1)
- for j in range(1, max_num + 1)
- if i * j <= max_num and i * j >= min_num)
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
- # find the closest aspect ratio to the target
- target_aspect_ratio = find_closest_aspect_ratio(
- aspect_ratio, target_ratios, orig_width, orig_height, image_size)
-
- # calculate the target width and height
- target_width = image_size * target_aspect_ratio[0]
- target_height = image_size * target_aspect_ratio[1]
- logger.debug("target_aspect_ratio: %s", target_aspect_ratio)
- else:
- target_width = image_size * w_crop_num
- target_height = image_size * h_crop_num
- target_aspect_ratio = (w_crop_num, h_crop_num)
- return target_aspect_ratio, target_height, target_width
-
-
-def _get_padding_size(image, target_height, target_width):
- orig_width, orig_height = image.size
+def _get_padding_size(orig_width: int, orig_height: int, target_height: int,
+ target_width: int):
ratio_width = target_width / orig_width
ratio_height = target_height / orig_height
@@ -164,181 +69,6 @@ def _get_padding_size(image, target_height, target_width):
return padding_height, padding_width
-def dynamic_preprocess(image,
- min_num=1,
- max_num=12,
- image_size=384,
- mask_size=27):
- target_aspect_ratio, target_height, target_width =\
- _find_target_aspect_ratio(
- image, image_size, max_num, min_num)
- padding_height, padding_width = _get_padding_size(image, target_height,
- target_width)
-
- # Calculate the ratio
- orig_width, orig_height = image.size
- ratio_width = target_width / orig_width
- ratio_height = target_height / orig_height
- if ratio_width < ratio_height:
- new_size = (target_width, int(orig_height * ratio_width))
- else:
- new_size = (int(orig_width * ratio_height), target_height)
-
- attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]),
- int(mask_size * target_aspect_ratio[0])))
- if padding_width >= 14:
- attention_mask[:, -math.floor(padding_width / 14):] = 0
- if padding_height >= 14:
- attention_mask[-math.floor(padding_height / 14):, :] = 0
- assert attention_mask.sum(
- ) > 0, f'attention mask is empty {attention_mask}'
-
- if min(new_size[1], target_height) < 10 or min(new_size[0],
- target_width) < 10:
- raise ValueError(f'the aspect ratio is very extreme {new_size}')
-
- image = T.functional.resize(
- image,
- [new_size[1], new_size[0]],
- )
-
- resized_img = T.functional.pad(image,
- [0, 0, padding_width, padding_height],
- fill=[255, 255, 255])
-
- return resized_img, attention_mask
-
-
-def pad_to_max_num_crops(images, max_crops=5):
- """
- images: B x 3 x H x W, B<=max_crops
- """
- B, _, H, W = images.shape
- if max_crops > B:
- pad = torch.zeros(max_crops - B,
- 3,
- H,
- W,
- dtype=images.dtype,
- device=images.device)
- images = torch.cat([images, pad], dim=0)
- return images
-
-
-def pad_mask_to_max_num_crops(masks, max_crops=5):
- B, H, W = masks.shape
- if max_crops > B:
- pad = torch.ones(max_crops - B,
- H,
- W,
- dtype=masks.dtype,
- device=masks.device)
- masks = torch.cat([masks, pad], dim=0)
- return masks
-
-
-def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size):
-
- # Basic settings.
- img_processor = T.Compose([
- T.ToTensor(),
- T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
- ])
- # Dynamic HD
- base_resolution = vit_resolution
- images = [image.convert('RGB') for image in images]
- # cover 384 and 448 resolution
- mask_resolution = base_resolution // vit_patch_size
- elems, image_attention_masks = [], []
- for im in images:
- elem, attention_mask = dynamic_preprocess(im,
- max_num=dynamic_hd_size,
- image_size=base_resolution,
- mask_size=mask_resolution)
- elems.append(elem)
- image_attention_masks.append(attention_mask)
- hd_images = [img_processor(im) for im in elems]
- global_image = [
- torch.nn.functional.interpolate(
- im.unsqueeze(0).float(),
- size=(base_resolution, base_resolution),
- mode='bicubic',
- ).to(im.dtype) for im in hd_images
- ]
- shapes = [[im.size(1), im.size(2)] for im in hd_images]
- mask_shapes = [[mask.size(0), mask.size(1)]
- for mask in image_attention_masks]
- global_attention_mask = [
- torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images
- ]
- hd_images_reshape = [
- im.reshape(1, 3, h // base_resolution, base_resolution,
- w // base_resolution, base_resolution).permute(
- 0, 2, 4, 1, 3, 5).reshape(-1, 3, base_resolution,
- base_resolution).contiguous()
- for im, (h, w) in zip(hd_images, shapes)
- ]
- attention_masks_reshape = [
- mask.reshape(1, h // mask_resolution, mask_resolution,
- w // mask_resolution, mask_resolution).permute(
- 0, 1, 3, 2, 4).reshape(-1, mask_resolution,
- mask_resolution).contiguous()
- for mask, (h, w) in zip(image_attention_masks, mask_shapes)
- ]
- # NOTE token compression is hard coded here, and odd numbers seems to fail
- downsample_attention_masks = [
- mask[:, 0::2,
- 0::2].reshape(1, h // mask_resolution, w // mask_resolution,
- mask_resolution // 2 + mask_resolution % 2,
- mask_resolution // 2 + mask_resolution % 2).permute(
- 0, 1, 3, 2, 4)
- for mask, (h, w) in zip(attention_masks_reshape, mask_shapes)
- ]
- downsample_attention_masks = [
- mask.reshape(mask.size(1) * mask.size(2),
- mask.size(3) * mask.size(4))
- for mask in downsample_attention_masks
- ]
- # NOTE hard coded number of tokens
- num_img_tokens = [
- 256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16
- for mask in downsample_attention_masks
- ]
-
- hd_images_reshape = [
- torch.cat([_global_image] + [_im], dim=0)
- for _global_image, _im in zip(global_image, hd_images_reshape)
- ]
- hd_masks_reshape = [
- torch.cat([_global_mask] + [_mask],
- dim=0) for _global_mask, _mask in zip(
- global_attention_mask, attention_masks_reshape)
- ]
- max_crops = max([img.size(0) for img in hd_images_reshape])
- image_transformed = [
- pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape
- ]
- image_transformed = torch.stack(image_transformed, dim=0)
- mask_transformed = [
- pad_mask_to_max_num_crops(mask, max_crops) \
- for mask in hd_masks_reshape
- ]
- mask_transformed = torch.stack(mask_transformed, dim=0)
-
- returned_input_image_embeds = image_transformed
- returned_image_sizes = torch.tensor(shapes, dtype=torch.long)
- returned_image_attention_mask = mask_transformed
- returned_num_img_tokens = num_img_tokens
-
- data = {
- "pixel_values": returned_input_image_embeds,
- "image_sizes": returned_image_sizes,
- "image_attention_mask": returned_image_attention_mask,
- "num_img_tokens": returned_num_img_tokens,
- }
- return data
-
-
def get_navit_vision_model(layer_idx: int = -1, **kwargs):
vision_config = {
"hidden_size": 1152,
@@ -492,7 +222,7 @@ def get_img_features(self,
def forward(self, pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
- image_attention_mask: torch.Tensor) -> torch.FloatTensor:
+ image_attention_mask: torch.Tensor) -> list[torch.FloatTensor]:
"""
process image and return vision embeddings.
@@ -656,785 +386,528 @@ def forward(self, pixel_values: torch.FloatTensor,
for _output_img in output_imgs:
img_feature_proj = self.img_projection(
_output_img.to(target_device).to(target_dtype))
- img_set_tensor.append(img_feature_proj)
+ img_set_tensor.append(img_feature_proj.squeeze(0))
return img_set_tensor
-class Phi4MMAudioFeatureInputs(TypedDict):
- type: Literal["audio_features"]
- data: Tuple[NestedTensors]
- """Shape: `((batch_size, num_audios, 80, M), )"""
-
-
-class Phi4MMAudioEmbeddingInputs(TypedDict):
- type: Literal["audio_embeds"]
- data: NestedTensors
- """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
-
-
-Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs]
-
-
-def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
- """Create a Mel filter-bank the same as SpeechLib FbankFC.
-
- Args:
- sample_rate (int): Sample rate in Hz. number > 0 [scalar]
- n_fft (int): FFT size. int > 0 [scalar]
- n_mel (int): Mel filter size. int > 0 [scalar]
- fmin (float): lowest frequency (in Hz). If None use 0.0.
- float >= 0 [scalar]
- fmax: highest frequency (in Hz). If None use sample_rate / 2.
- float >= 0 [scalar]
-
- Returns
- out (numpy.ndarray): Mel transform matrix
- [shape=(n_mels, 1 + n_fft/2)]
+class Phi4MMImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ data: Union[torch.Tensor, List[torch.Tensor]]
"""
+ Shape:
+ `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
- bank_width = int(n_fft // 2 + 1)
- if fmax is None:
- fmax = sample_rate / 2
- if fmin is None:
- fmin = 0
- assert fmin >= 0, "fmin cannot be negative"
- assert (fmin < fmax <=
- sample_rate / 2), "fmax must be between (fmin, samplerate / 2]"
-
- def mel(f):
- return 1127.0 * np.log(1.0 + f / 700.0)
-
- def bin2mel(fft_bin):
- return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
-
- def f2bin(f):
- return int((f * n_fft / sample_rate) + 0.5)
-
- # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
- klo = f2bin(fmin) + 1
- khi = f2bin(fmax)
-
- khi = max(khi, klo)
-
- # Spec 2: SpeechLib uses triangles in Mel space
- mlo = mel(fmin)
- mhi = mel(fmax)
- m_centers = np.linspace(mlo, mhi, n_mels + 2)
- ms = (mhi - mlo) / (n_mels + 1)
-
- matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
- for m in range(0, n_mels):
- left = m_centers[m]
- center = m_centers[m + 1]
- right = m_centers[m + 2]
- for fft_bin in range(klo, khi):
- mbin = bin2mel(fft_bin)
- if left < mbin < right:
- matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
-
- return matrix
-
-
-class LogFbankProcessor:
-
- def __init__(self):
-
- self._eightk_method = "fillzero"
- self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T
-
- self._hamming400 = np.hamming(400) # for 16k audio
- self._hamming200 = np.hamming(200) # for 8k audio
+ Note that `num_patches` may be different per batch and image,
+ in which case the data is passed as a list instead of a batched tensor.
+ """
- def extract_spectrogram(self, wav, fs):
- """Extract spectrogram features from waveform.
- Args:
- wav (1D array): waveform of the input
- fs (int): sampling rate of the waveform, 16000 or 8000.
- If fs=8000, the waveform will be resampled to 16000Hz.
- Output:
- log_fbank (2D array): a TxD matrix of log Mel filterbank features.
- D=80, and T is the number of frames.
- """
- if wav.ndim > 1:
- wav = np.squeeze(wav)
+ image_sizes: torch.Tensor
+ """
+ Shape: `(batch_size * num_images, 2)`
- # by default, we extract the mean if stereo
- if len(wav.shape) == 2:
- wav = wav.mean(1)
+ This should be in `(height, width)` format.
+ """
- # Resample to 16000 or 8000 if needed
- if fs > 16000:
- wav = scipy.signal.resample_poly(wav, 1, fs // 16000)
- fs = 16000
- elif 8000 < fs < 16000:
- wav = scipy.signal.resample_poly(wav, 1, fs // 8000)
- fs = 8000
- elif fs < 8000:
- raise RuntimeError(f"Unsupported sample rate {fs}")
-
- if fs == 8000:
- if self._eightk_method == "resample":
- # Input audio is 8 kHz. Convert to 16 kHz before feature
- # extraction
- wav = scipy.signal.resample_poly(wav, 2, 1)
- fs = 16000
- # Do nothing here for fillzero method
- elif fs != 16000:
- # Input audio is not a supported sample rate.
- raise RuntimeError(
- f"Input data using an unsupported sample rate: {fs}")
-
- preemphasis = 0.97
-
- if fs == 8000:
- n_fft = 256
- win_length = 200
- hop_length = 80
- fft_window = self._hamming200
- elif fs == 16000:
- n_fft = 512
- win_length = 400
- hop_length = 160
- fft_window = self._hamming400
-
- # Spec 1: SpeechLib cut remaining sample insufficient for a hop
- n_batch = (wav.shape[0] - win_length) // hop_length + 1
- # Here we don't use stride_tricks since the input array may not satisfy
- # memory layout requirement and we need writeable output
- # Here we only use list of views before copy to destination
- # so it is more efficient than broadcasting
- y_frames = np.array(
- [
- wav[_stride:_stride + win_length]
- for _stride in range(0, hop_length * n_batch, hop_length)
- ],
- dtype=np.float32,
- )
+ num_img_tokens: list[int]
+ """Shape: `(batch_size * num_images)`"""
- # Spec 2: SpeechLib applies preemphasis within each batch
- y_frames_prev = np.roll(y_frames, 1, axis=1)
- y_frames_prev[:, 0] = y_frames_prev[:, 1]
- y_frames = (y_frames - preemphasis * y_frames_prev) * 32768
+ image_attention_mask: torch.Tensor
+ """Shape: `(batch_size * num_images, H_mask, W_mask)`"""
- S = np.fft.rfft(fft_window * y_frames, n=n_fft,
- axis=1).astype(np.complex64)
- if fs == 8000:
- # Need to pad the output to look like 16 kHz data but with zeros in
- # the 4 to 8 kHz bins.
- frames, bins = S.shape
- padarray = np.zeros((frames, bins))
- S = np.concatenate((S[:, 0:-1], padarray),
- axis=1) # Nyquist bin gets set to zero
+class Phi4MMImageEmbeddingInputs(TypedDict):
+ type: Literal["image_embeds"]
+ data: Union[torch.Tensor, List[torch.Tensor]]
+ """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
- spec = np.abs(S).astype(np.float32)
- return spec
+ `hidden_size` must match the hidden size of language model backbone.
+ """
- def extract_features(self, wav, fs):
- """Extract log filterbank features from waveform.
- Args:
- wav (1D array): waveform of the input
- fs (int): sampling rate of the waveform, 16000 or 8000.
- If fs=8000, the waveform will be resampled to 16000Hz.
- Output:
- log_fbank (2D array): a TxD matrix of log Mel filterbank features.
- D=80, and T is the number of frames.
- """
- spec = self.extract_spectrogram(wav, fs)
- spec_power = spec**2
- fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
- log_fbank = np.log(fbank_power).astype(np.float32)
+class Phi4MMAudioFeatureInputs(TypedDict):
+ type: Literal["audio_features"]
+ data: Union[torch.Tensor, List[torch.Tensor]]
+ """Shape: `(batch_size * num_audios, 80, M)"""
- return log_fbank
+class Phi4MMAudioEmbeddingInputs(TypedDict):
+ type: Literal["audio_embeds"]
+ data: NestedTensors
+ """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
-@lru_cache
-def audio_feature_extractor() -> LogFbankProcessor:
- # Creates an instance of the audio processor, needed to extract the
- # the audio features from the sound file
- # LRU cache ensures that we only make one copy
- return LogFbankProcessor()
+Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs]
+Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs]
-def _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size,
- vit_patch_size, token_compression_factor):
- """
- compute the number of tokens an image is expected to take up considering
- the image encoder architecture and exclude output features containing
- only padding pixels
- for siglip, vit_image_size=448, vit_patch_size=14, so output will be
- 32x32 feature map
- NOTE right now, Phi4MM uses hard-coded token_compression_factor=2
- """
- assert vit_image_size % vit_patch_size == 0, \
- "vit_image_size must be divisible by vit_patch_size"
- assert vit_image_size // vit_patch_size % token_compression_factor == 0, \
- "vit_image_size // vit_patch_size must be divisible by "\
- "token_compression_factor"
-
- target_aspect_ratio, target_height, target_width = (
- _find_target_aspect_ratio(image,
- vit_image_size,
- dynamic_hd_size,
- min_num=1))
- assert target_aspect_ratio[
- 0] * vit_image_size == target_width, \
- f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}"
- assert target_aspect_ratio[
- 1] * vit_image_size == target_height, \
- f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}"
- assert (target_height % vit_image_size == 0
- and target_width % vit_image_size == 0)
-
- padding_height, padding_width = _get_padding_size(image, target_height,
- target_width)
- assert padding_width == 0 or padding_height == 0, \
- "padding_width or padding_height must be 0"
-
- target_feat_width = target_width // vit_patch_size
- target_feat_height = target_height // vit_patch_size
- if padding_width >= vit_patch_size:
- assert padding_height == 0, "padding_height not 0"
- non_pad_feat_width = target_feat_width - math.floor(
- padding_width / vit_patch_size)
- non_pad_feat_height = target_feat_height
- elif padding_height >= vit_patch_size:
- assert padding_width == 0, "padding_width not 0"
- non_pad_feat_height = target_feat_height - math.floor(
- padding_height / vit_patch_size)
- non_pad_feat_width = target_feat_width
- else:
- # small padding shorter than a vit patch
- non_pad_feat_width = target_feat_width
- non_pad_feat_height = target_feat_height
-
- feat_width = non_pad_feat_width // token_compression_factor
- feat_height = non_pad_feat_height // token_compression_factor
- # NOTE it's possible that the non-padding feature is not divisible
- if non_pad_feat_width % token_compression_factor != 0:
- feat_width += 1
- if non_pad_feat_height % token_compression_factor != 0:
- feat_height += 1
- num_hd_patch_tokens = feat_width * feat_height
- num_hd_newline_tokens = feat_height
- vit_feature_size = vit_image_size // vit_patch_size
- num_global_image_tokens = (vit_feature_size // token_compression_factor)**2
- num_sep_tokens = 1
- num_global_image_newline_tokens = \
- vit_feature_size // token_compression_factor
-
- return (num_global_image_tokens + num_sep_tokens + num_hd_patch_tokens +
- num_hd_newline_tokens + num_global_image_newline_tokens)
-
-
-def compute_logfbank_output_size(wav_length: int, fs: int) -> Tuple[int, int]:
+def cat_with_pad(tensors, dim, padding_value=0):
"""
- Compute the output size of the `extract_features` method.
-
- Args:
- wav_length (int): Length of the input waveform in samples.
- fs (int): Sampling rate of the waveform, either 16000 or 8000.
-
- Returns:
- tuple (int, int): Output size as (T, D), where:
- T: Number of time frames.
- D: Number of Mel filterbank bins (80).
+    Concatenate tensors along `dim`, padding every other dim to the max size.
"""
+ ndim = tensors[0].dim()
+ assert all(
+ t.dim() == ndim for t in
+ tensors[1:]), "All tensors must have the same number of dimensions"
- # Resample to 16000 or 8000 if needed
- if fs > 16000:
- wav_length //= fs // 16000
- fs = 16000
- elif 8000 <= fs < 16000:
- # We'll resample to 16K from 8K
- wav_length *= 2
- fs = 16000
- elif fs < 8000:
- raise RuntimeError(f"Unsupported sample rate {fs}")
+ out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
+ out_size[dim] = sum(t.shape[dim] for t in tensors)
+ output = tensors[0].new_full(out_size, padding_value)
- # Spectrogram parameters for 16 kHz
- win_length = 400 # Frame length in samples
- hop_length = 160 # Frame shift in samples
- mel_bins = 80 # Number of mel filterbank bins
+ index = 0
+ for t in tensors:
+ # Create a slice list where every dimension except dim is full slice
+ slices = [slice(0, t.shape[d]) for d in range(ndim)]
+ # Update only the concat dimension slice
+ slices[dim] = slice(index, index + t.shape[dim])
- # Calculate number of frames (T)
- T = (wav_length - win_length) // hop_length + 1
- if T < 1:
- raise ValueError("Waveform too short for given parameters.")
+ output[slices] = t
+ index += t.shape[dim]
- # Return time frames (T) and mel bins (D)
- return T, mel_bins
+ return output
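A small illustration of the padding semantics (not part of the diff), assuming `cat_with_pad` is imported from this module: tensors are concatenated along `dim` while every other dimension is padded up to the largest size seen.

```python
import torch

from vllm.model_executor.models.phi4mm import cat_with_pad

a = torch.ones(3, 5)  # e.g. 3 image tokens with hidden size 5
b = torch.ones(2, 4)  # fewer tokens and a smaller last dim

out = cat_with_pad([a, b], dim=0, padding_value=0)
print(out.shape)  # torch.Size([5, 5]); b is zero-padded from width 4 to 5
```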
-def _get_audio_embed_sizes(audios, ctx: InputContext):
- """
- Get the audio embedding sizes for each audio file.
+class Phi4MMProcessingInfo(BaseProcessingInfo):
- Args:
- audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of
- waveform and sample rate.
- ctx (InputContext): Input context.
+ def get_hf_processor(
+ self,
+ *,
+ dynamic_hd: Optional[int] = None,
+ **kwargs: object,
+ ) -> ProcessorMixin:
+ if dynamic_hd is not None:
+ kwargs["dynamic_hd"] = dynamic_hd
- Returns:
- List[int]: List of audio embedding sizes.
- """
- audio_embed_sizes = []
- for audio in audios:
- audio_data, sf = audio
- audio_frames, _ = compute_logfbank_output_size(len(audio_data), sf)
- audio_embed_size = _compute_audio_embed_size(ctx.get_hf_config(),
- audio_frames)
- audio_embed_sizes.append(audio_embed_size)
- return audio_embed_sizes
+ return self.ctx.get_hf_processor(**kwargs)
+ @property
+ def image_tokens(self) -> list[str]:
+ return [f"<|image_{i+1}|>" for i in range(100)]
-def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""):
- """
- The following will search for `<|audio_{idx}|>` tokens and
- return a mapping of audio placeholder tokens to audio placeholder token ids
- based on the size of the audio embeddings.
+ @property
+ def audio_tokens(self) -> list[str]:
+ return [f"<|audio_{i+1}|>" for i in range(100)]
- Args:
- audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples of
- waveform and sample rate.
- ctx (InputContext): Input context.
- prompt_str (str): The prompt string.
+ def get_dynamic_hd(
+ self,
+ processor: Optional[ProcessorMixin] = None,
+ ) -> int:
+ if processor is None:
+ processor = self.get_hf_processor()
+ image_processor = processor.image_processor
+ return image_processor.dynamic_hd
- Returns:
- Dict[str, List[int]]: Mapping of audio placeholder tokens to audio
- placeholder token ids.
+ def get_feature_extractor(self) -> SequenceFeatureExtractor:
+ return self.get_hf_processor().audio_processor
- """
- if len(audios) == 0:
- return {}
-
- audio_embed_sizes = _get_audio_embed_sizes(audios, ctx)
- audio_ids = re.findall(AUDIO_TOKEN_PATTERN, prompt_str)
- audio_ids = [int(audio_id) for audio_id in audio_ids]
- assert len(audio_ids) == len(
- audio_embed_sizes
- ), "Number of audio tokens and audio features do not match"
- assert tuple(audio_ids) == tuple(range(1,
- len(audio_ids) +
- 1)), "Audio ids are not in order!"
- audio_id_to_input_ids = {
- f"<|audio_{audio_id}|>":
- [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size
- for audio_id, audio_embed_size in zip(audio_ids, audio_embed_sizes)
- }
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"audio": None, "image": None}
- return audio_id_to_input_ids
-
-
-def _count_image_tokens(images, ctx: InputContext):
- hf_config = ctx.get_hf_config()
- vision_encoder_name = hf_config.img_processor
- if vision_encoder_name is None:
- vision_encoder_name = SIGLIP_NAME
- prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name]
- dynamic_hd_size = prepro_config['dynamic_hd']
- vit_image_size = prepro_config['vit_image_size']
- vit_patch_size = prepro_config['vit_patch_size']
- token_compression_factor = prepro_config['token_compression_factor']
-
- image_token_counts = [
- _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size,
- vit_patch_size, token_compression_factor)
- for image in images
- ]
- return image_token_counts
-
-
-def _get_image_id_to_input_ids(images, prompt, ctx: InputContext):
- if len(images) == 0:
- return {}
-
- image_ids = re.findall(IMAGE_TOKEN_PATTERN, prompt)
- image_ids = [int(image_id) for image_id in image_ids]
- assert len(image_ids) == len(
- set(image_ids)), "Duplicate image tokens in prompt"
- assert len(images) == len(
- image_ids), "Number of images and image tokens in prompt do not match"
-
- # NOTE the following assertion is not strictly necessary
- assert tuple(image_ids) == tuple(range(1,
- len(image_ids) +
- 1)), "Image ids are not in order"
-
- image_token_counts = _count_image_tokens(images, ctx)
- image_id_to_input_ids = {
- f"<|image_{image_id}|>": [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_tokens
- for image_id, num_tokens in zip(image_ids, image_token_counts)
- }
- return image_id_to_input_ids
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ return {
+ "image": self.get_max_image_tokens(),
+ "audio": self.get_max_audio_tokens(),
+ }
+ def get_max_audio_tokens(self) -> int:
+ sr = self.get_feature_extractor().sampling_rate
+ num_frames = self.get_audio_num_frames(_AUDIO_MAX_SOUNDFILE_SIZE, sr)
+ return self._compute_audio_embed_size(num_frames)
-def input_processor_for_phi4mm(ctx: InputContext,
- inputs: DecoderOnlyInputs) -> TokenInputs:
- """
- Implements the input processor, which transforms the input prompt ids
- to include the audio placeholder token. This will become the `input_ids`
- in `forward` for the model.
+ def get_max_image_tokens(self) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
+ return self.get_num_image_tokens(image_width=target_width,
+ image_height=target_height)
- Args:
- ctx (InputContext): Input context.
- inputs (DecoderOnlyInputs): The inputs (e.g. prompt, prompt_token_ids)
- to process.
+ def _find_target_aspect_ratio(
+ self,
+ orig_width: int,
+ orig_height: int,
+ image_size: int,
+ max_num: int,
+ min_num: int,
+ ):
+ w_crop_num = math.ceil(orig_width / float(image_size))
+ h_crop_num = math.ceil(orig_height / float(image_size))
+ if w_crop_num * h_crop_num > max_num:
+ aspect_ratio = orig_width / orig_height
+
+ # calculate the existing image aspect ratio
+ target_ratios = set((i, j) for i in range(1, max_num + 1)
+ for j in range(1, max_num + 1)
+ if i * j <= max_num and i * j >= min_num)
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # find the closest aspect ratio to the target
+ image_processor = self.get_hf_processor().image_processor
+ target_aspect_ratio = image_processor.find_closest_aspect_ratio(
+ aspect_ratio,
+ target_ratios,
+ orig_width,
+ orig_height,
+ image_size,
+ )
+
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ else:
+ target_width = image_size * w_crop_num
+ target_height = image_size * h_crop_num
+ target_aspect_ratio = (w_crop_num, h_crop_num)
+ return target_aspect_ratio, target_height, target_width
- Returns:
- TokenInputs: Processed inputs
- """
- multi_modal_data = inputs.get("multi_modal_data")
- if (multi_modal_data is None or
- ("audio" not in multi_modal_data and "image" not in multi_modal_data)):
- # pure text input, so no need to do pre-processing
- return inputs
-
- prompt_str = inputs.get("prompt")
- prompt_token_ids = inputs.get("prompt_token_ids")
- # for offline_inference, we will get str input and we parse MM special
- # tokens from it
- # (ignore prompt_token_ids)
- # for OAI server, we will get prompt_token_ids, where MM special tokens
- # are already parsed
-
- if 'audio' in multi_modal_data:
- audios = multi_modal_data["audio"]
-
- if not isinstance(audios, list):
- audios = [audios]
- if prompt_str is not None:
- audio_id_to_input_ids = _get_audio_id_to_input_ids(
- audios, ctx, prompt_str=prompt_str)
- audio_embed_sizes = []
- elif prompt_token_ids is not None:
- audio_id_to_input_ids = {}
- audio_embed_sizes = _get_audio_embed_sizes(audios, ctx)
- else:
- audio_id_to_input_ids = {}
- audio_embed_sizes = []
-
- if 'image' in multi_modal_data:
- # PIL Image or list of PIL Images
- images = multi_modal_data["image"]
- if not isinstance(images, list):
- images = [images]
- if prompt_str is not None:
- image_id_to_input_ids = _get_image_id_to_input_ids(
- images, prompt_str, ctx)
- image_token_counts = []
- elif prompt_token_ids is not None:
- image_id_to_input_ids = {}
- image_token_counts = _count_image_tokens(images, ctx)
- else:
- image_id_to_input_ids = {}
- image_token_counts = []
-
- # Handle the case where the prompt is a string and we need to manually
- # tokenize it.
- # In this case, the `audio_id_to_input_ids` dict will be mapping from
- # an audio placeholder
- # string (e.g. `<|audio_1|>`) to the audio placeholder tokens for the
- # given audio length.
- if prompt_str:
- pattern = r"(<\|image_\d+\|>|<\|audio_\d+\|>)"
- prompt_chunk_strings = re.split(pattern, prompt_str)
- prompt_chunk_strings = [s for s in prompt_chunk_strings if s != ""]
-
- # Create the new input_ids with the placeholder image and audio
- # tokens inserted
- tokenizer = cached_tokenizer_from_config(ctx.model_config)
- input_ids = []
- has_imag, has_audio, has_user_text_input = False, False, False
- for prompt_chunk_string in prompt_chunk_strings:
- if re.match(IMAGE_TOKEN_PATTERN, prompt_chunk_string):
- input_ids.extend(image_id_to_input_ids[prompt_chunk_string])
- has_imag = True
- elif re.match(AUDIO_TOKEN_PATTERN, prompt_chunk_string):
- input_ids.extend(audio_id_to_input_ids[prompt_chunk_string])
- has_audio = True
- else:
- curr_token_ids = tokenizer(prompt_chunk_string).input_ids
- if not has_user_text_input:
- for token_id in curr_token_ids:
- if token_id not in NON_USER_INPUT_TOKENS:
- has_user_text_input = True
- break
- input_ids.extend(curr_token_ids)
- if has_audio and has_imag and has_user_text_input:
- raise ValueError(
- "Phi4MMForCausalLM does not support text + audio + image" +
- " inputs in the same prompt")
- # Handle the case where the prompt is already tokenized
- else:
- assert prompt_token_ids is not None, \
- "If string prompt isn't provided, prompt_token_ids must be"
-
- i = 0
- input_ids = prompt_token_ids
- # only needed for later assertion
- img_cnt, audio_cnt, user_text_input_cnt = 0, 0, 0
- image_token_count_iter = iter(image_token_counts)
- audio_embed_size_iter = iter(audio_embed_sizes)
- while i < len(input_ids):
- token_id = input_ids[i]
- if token_id == _AUDIO_PLACEHOLDER_TOKEN_ID:
- token_count = next(audio_embed_size_iter)
- audio_cnt += 1
- elif token_id == _IMAGE_PLACEHOLDER_TOKEN_ID:
- token_count = next(image_token_count_iter)
- img_cnt += 1
- else:
- user_text_input_cnt += 1 if token_id not in \
- NON_USER_INPUT_TOKENS else 0
- i += 1
- continue
- tokens = [token_id] * token_count
- input_ids = input_ids[:i] + tokens + input_ids[i + 1:]
- i += token_count
-
- if audio_cnt > 0 and img_cnt > 0 and user_text_input_cnt > 0:
- raise ValueError(
- "Phi4MMForCausalLM does not support text + audio + image" +
- " inputs in the same prompt")
- # If the below assertion fails, it might be that input pure-text
- # messages contain image/audio special tokens literally
- # (<|endoftext10|>, <|endoftext11|>).
- assert (img_cnt == len(image_token_counts)), (
- f"Number of image tokens in prompt_token_ids ({img_cnt}) "
- f"does not match number of images ({len(image_token_counts)})")
- assert (audio_cnt == len(audio_embed_sizes)), (
- f"Number of audio tokens in prompt_token_ids ({audio_cnt}) "
- f"does not match number of audios ({len(audio_embed_sizes)})")
-
- # NOTE: Create a defensive copy of the original inputs
- return token_inputs(
- prompt_token_ids=input_ids,
- prompt=prompt_str,
- multi_modal_data=multi_modal_data,
- )
+ def _compute_num_image_tokens(
+ self,
+ orig_width: int,
+ orig_height: int,
+ dynamic_hd_size: int,
+ vit_image_size: int,
+ vit_patch_size: int,
+ token_compression_factor: int = 2,
+ ):
+ """
+        Compute the number of tokens an image is expected to take up, given
+        the image encoder architecture, excluding output features that contain
+        only padding pixels.
+        For SigLIP, vit_image_size=448 and vit_patch_size=14, so the output is
+        a 32x32 feature map.
+        NOTE: right now, Phi4MM uses a hard-coded token_compression_factor=2.
+ """
+ assert vit_image_size % vit_patch_size == 0, (
+ "vit_image_size must be divisible by vit_patch_size")
+ assert (vit_image_size // vit_patch_size %
+ token_compression_factor == 0), (
+ "vit_image_size // vit_patch_size must be divisible by "
+ "token_compression_factor")
+
+ target_aspect_ratio, target_height, target_width = (
+ self._find_target_aspect_ratio(orig_width,
+ orig_height,
+ vit_image_size,
+ dynamic_hd_size,
+ min_num=1))
+ assert target_aspect_ratio[0] * vit_image_size == target_width, (
+ f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}")
+ assert target_aspect_ratio[1] * vit_image_size == target_height, (
+ f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}")
+ assert (target_height % vit_image_size == 0
+ and target_width % vit_image_size == 0)
+
+ padding_height, padding_width = _get_padding_size(
+ orig_width, orig_height, target_height, target_width)
+ assert padding_width == 0 or padding_height == 0, \
+ "padding_width or padding_height must be 0"
+
+ target_feat_width = target_width // vit_patch_size
+ target_feat_height = target_height // vit_patch_size
+ if padding_width >= vit_patch_size:
+ assert padding_height == 0, "padding_height not 0"
+ non_pad_feat_width = target_feat_width - math.floor(
+ padding_width / vit_patch_size)
+ non_pad_feat_height = target_feat_height
+ elif padding_height >= vit_patch_size:
+ assert padding_width == 0, "padding_width not 0"
+ non_pad_feat_height = target_feat_height - math.floor(
+ padding_height / vit_patch_size)
+ non_pad_feat_width = target_feat_width
+ else:
+ # small padding shorter than a vit patch
+ non_pad_feat_width = target_feat_width
+ non_pad_feat_height = target_feat_height
+
+ feat_width = non_pad_feat_width // token_compression_factor
+ feat_height = non_pad_feat_height // token_compression_factor
+ # NOTE it's possible that the non-padding feature is not divisible
+ if non_pad_feat_width % token_compression_factor != 0:
+ feat_width += 1
+ if non_pad_feat_height % token_compression_factor != 0:
+ feat_height += 1
+ num_hd_patch_tokens = feat_width * feat_height
+ num_hd_newline_tokens = feat_height
+ vit_feature_size = vit_image_size // vit_patch_size
+ num_global_image_tokens = (vit_feature_size //
+ token_compression_factor)**2
+ num_sep_tokens = 1
+ num_global_image_newline_tokens = \
+ vit_feature_size // token_compression_factor
+
+ return (num_global_image_tokens + num_sep_tokens +
+ num_hd_patch_tokens + num_hd_newline_tokens +
+ num_global_image_newline_tokens)
+
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ processor: Optional[ProcessorMixin] = None,
+ ) -> int:
+ hf_config = self.get_hf_config()
+ vision_encoder_name = hf_config.img_processor
+ if vision_encoder_name is None:
+ vision_encoder_name = SIGLIP_NAME
+ prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[
+ vision_encoder_name]
+ vit_image_size = prepro_config['vit_image_size']
+ vit_patch_size = prepro_config['vit_patch_size']
+ token_compression_factor = prepro_config['token_compression_factor']
+
+ dynamic_hd_size = self.get_dynamic_hd(processor=processor)
+
+ image_num_tokens = self._compute_num_image_tokens(
+ image_width,
+ image_height,
+ dynamic_hd_size=dynamic_hd_size,
+ vit_image_size=vit_image_size,
+ vit_patch_size=vit_patch_size,
+ token_compression_factor=token_compression_factor,
+ )
-def _compute_audio_embed_size(hf_config, audio_frames):
- """
- Compute the audio embedding size based on the audio frames and
- compression rate.
- """
- compression_rate = hf_config.embd_layer['audio_embd_layer'][
- 'compression_rate']
- # NOTE: this is a hard-coded value but might be configurable in the future
- qformer_compression_rate = 1
- integer = audio_frames // compression_rate
- remainder = audio_frames % compression_rate
+ return image_num_tokens
- result = integer if remainder == 0 else integer + 1
+ def get_image_size_with_most_features(
+ self,
+ processor: Optional[ProcessorMixin] = None,
+ ) -> ImageSize:
+ hf_config = self.get_hf_config()
+ vision_encoder_name = hf_config.img_processor
+ if vision_encoder_name is None:
+ vision_encoder_name = SIGLIP_NAME
+ prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[
+ vision_encoder_name]
+ vit_image_size = prepro_config['vit_image_size']
+
+ max_side = vit_image_size * self.get_dynamic_hd(processor=processor)
+ return ImageSize(height=max_side, width=vit_image_size)
+
+ def get_audio_num_frames(self, audio_len: int, sr: float) -> int:
+ """
+        Compute the number of frames produced by the log-Mel feature extraction.
- integer = result // qformer_compression_rate
- remainder = result % qformer_compression_rate
- result = integer if remainder == 0 else integer + 1 # qformer compression
+ Args:
+ audio_len (int): Length of the input waveform in samples.
+ sr (float): Sampling rate of the waveform, either 16000 or 8000.
- return result
+ Returns:
+            int: Number of time frames (T).
+ """
+ # Resample to 16000 or 8000 if needed
+ if sr > 16000:
+ audio_len //= sr // 16000
+ elif 8000 <= sr < 16000:
+ # We'll resample to 16K from 8K
+ audio_len *= 2
+ elif sr < 8000:
+ raise RuntimeError(f"Unsupported sample rate {sr}")
+
+ # Spectrogram parameters for 16 kHz
+ win_length = 400 # Frame length in samples
+ hop_length = 160 # Frame shift in samples
+
+ # Calculate number of frames (T)
+ num_frames = (audio_len - win_length) // hop_length + 1
+ if num_frames < 1:
+ raise ValueError("Waveform too short for given parameters.")
+
+ # Return time frames (T)
+ return num_frames
+
+ def _compute_audio_embed_size(self, audio_frames: int) -> int:
+ """
+ Compute the audio embedding size based on the audio frames and
+ compression rate.
+ """
+ hf_config = self.get_hf_config()
+ compression_rate = hf_config.embd_layer['audio_embd_layer'][
+ 'compression_rate']
+ # NOTE: this is a hard-coded value but might be configurable
+ # in the future
+ qformer_compression_rate = 1
+ integer = audio_frames // compression_rate
+ remainder = audio_frames % compression_rate
-def get_max_phi4mm_audio_tokens(ctx: InputContext) -> int:
- return 10000
+ result = integer if remainder == 0 else integer + 1
+ integer = result // qformer_compression_rate
+ remainder = result % qformer_compression_rate
+ # qformer compression
+ result = integer if remainder == 0 else integer + 1
-def dummy_audio_for_phi4mm(audio_count: int) -> dict:
- """
- Create dummy audio data for the Phi4MM model, which is used for profiling.
+ return result
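A hedged worked example of the two methods above for the profiling-sized clip (`_AUDIO_MAX_SOUNDFILE_SIZE` samples at 16 kHz); the compression_rate of 8 is an assumed config value for illustration, not read from the checkpoint:

```python
# 241_000 samples at 16 kHz with a 400-sample window and 160-sample hop.
audio_len, win_length, hop_length = 241_000, 400, 160
num_frames = (audio_len - win_length) // hop_length + 1       # 1504 frames

compression_rate = 8           # assumed embd_layer["audio_embd_layer"]["compression_rate"]
qformer_compression_rate = 1
embed_size = -(-num_frames // compression_rate)               # ceil division -> 188
embed_size = -(-embed_size // qformer_compression_rate)       # unchanged for rate 1
print(num_frames, embed_size)                                 # 1504 188
```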
- Args:
- audio_count (int): Number of audio samples.
- Returns:
- dict: Dummy audio data.
- """
- dummy_audio = np.full((_AUDIO_MAX_SOUNDFILE_SIZE, ), 0.0)
- return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count
+class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_audios = mm_counts.get("audio", 0)
+ num_images = mm_counts.get("image", 0)
-def dummy_image_for_phi4mm(width: int, height: int):
- image = Image.new('RGB', (width, height), color='black')
- return image
+ image_tokens: list[str] = self.info.image_tokens[:num_images]
+ audio_tokens: list[str] = self.info.audio_tokens[:num_audios]
+ return "".join(image_tokens + audio_tokens)
-def dummy_data_for_phi4mm(ctx: InputContext, seq_len: int,
- mm_counts: Mapping[str, int]) -> DummyData:
- """
- Create dummy sequence (input_ids) and audio data for the Phi4MM model,
- which is used for profiling.
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> MultiModalDataDict:
+ num_audios = mm_counts.get("audio", 0)
+ num_images = mm_counts.get("image", 0)
- In this case, the sequence data is a bunch of 0s with a number of audio
- tokens that correspond to the audio embed size of the
- _AUDIO_MAX_SOUNDFILE_SIZE.
+ target_width, target_height = \
+ self.info.get_image_size_with_most_features()
- Args:
- ctx (InputContext): Input context.
- seq_len (int): Length of the sequence.
- mm_counts (Mapping[str, int]): Multi-modal counts.
- Returns:
- Tuple: Dummy sequence data and dummy audio data.
- """
- audio_count = mm_counts["audio"]
- audio_frames, _ = compute_logfbank_output_size(_AUDIO_MAX_SOUNDFILE_SIZE,
- DUMMY_SAMPLING_FREQUENCY)
- audio_feature_size = _compute_audio_embed_size(ctx.get_hf_config(),
- audio_frames)
-
- image_count = mm_counts["image"]
- dummy_image = get_max_dummy_image(ctx)
- max_image_tokens = get_max_phi4mm_image_tokens(ctx)
- total_image_tokens = image_count * max_image_tokens
-
- if seq_len - audio_feature_size * audio_count - total_image_tokens < 0:
- raise RuntimeError(
- f"Phi4MM cannot process {audio_count} audios and {image_count}"
- f"images in a prompt, please increase max_model_len to be at"
- f" larger than "
- f"{audio_feature_size * audio_count + total_image_tokens}"
- " or reduce audio/image limit by --limit-mm-per-prompt.")
-
- if audio_feature_size * audio_count > total_image_tokens:
- seq_data = SequenceData.from_prompt_token_counts(
- (_AUDIO_PLACEHOLDER_TOKEN_ID, audio_feature_size * audio_count),
- (0, seq_len - audio_feature_size * audio_count),
- )
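+ # Profiling data: use the longest supported audio clip and the image
+ # size with the most features so the worst-case token budget is reserved.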
mm_data = {
- "audio": dummy_audio_for_phi4mm(audio_count),
+ "image":
+ self._get_dummy_images(width=target_width,
+ height=target_height,
+ num_images=num_images),
+ "audio":
+ self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE,
+ num_audios=num_audios),
}
- else:
- seq_data = SequenceData.from_prompt_token_counts(
- (_IMAGE_PLACEHOLDER_TOKEN_ID, total_image_tokens),
- (0, seq_len - total_image_tokens),
- )
- mm_data = {
- "image": [dummy_image] * image_count,
- }
- return DummyData(seq_data, mm_data)
+ return mm_data
-def input_mapper_for_phi4mm_audio(ctx: InputContext,
- data: object) -> MultiModalKwargs:
- """
- This function is used to create the MultiModalKwargs for the Phi4MM
- (audio) model.
- Specifically, for audio, we extract the audio features from the sound
- file and create pairs of audio features and audio embed lengths (the
- latter of which is used to repeat the audio placeholder token in the
- input prompt IDs).
- These pairs are used, downstream, in `_audio_features_to_embeddings`
- (via `_process_audio_input`).
-
- Note that the incoming audio data (each entry in `data`) is a tuple of
- the audio data and the sampling frequency (e.g. from soundfile.read).
-
- Args:
- ctx (InputContext): Input context.
- data (object): Audio data.
-
- Returns:
- MultiModalKwargs: Multi-modal inputs.
- """
- if not isinstance(data, list):
- data = [data]
-
- if len(data) == 0:
- return MultiModalKwargs()
-
- audio_features = []
- for audio_input in data:
- if not isinstance(audio_input, tuple):
- raise NotImplementedError(
- f"Unsupported data type: {type(audio_input)}")
-
- audio, sf = audio_input
- feature_extractor = audio_feature_extractor()
- single_audio_features = feature_extractor.extract_features(audio, sf)
- feat_stride = (1 if not hasattr(feature_extractor, "stride") else
- feature_extractor.stride)
- audio_frames = len(single_audio_features) * feat_stride
- single_audio_embed_size = _compute_audio_embed_size(
- ctx.get_hf_config(), audio_frames)
- single_audio_feature_audio_len_pair = (
- single_audio_features,
- [single_audio_embed_size],
- )
- audio_features.append(single_audio_feature_audio_len_pair)
- return MultiModalKwargs({"audio_features": audio_features})
-
-
-def input_mapper_for_phi4mm_image(ctx: InputContext, data: object):
- if not isinstance(data, list):
- data = [data]
- # data: list of PIL images
- if len(data) == 0:
- return MultiModalKwargs()
- hf_config = ctx.get_hf_config()
- vision_encoder_name = hf_config.img_processor
- if vision_encoder_name is None:
- vision_encoder_name = SIGLIP_NAME
- prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name]
- dynamic_hd_size = prepro_config['dynamic_hd']
- vit_image_size = prepro_config['vit_image_size']
- vit_patch_size = prepro_config['vit_patch_size']
-
- image_input_dict = preprocess(data, dynamic_hd_size, vit_image_size,
- vit_patch_size)
- return MultiModalKwargs({
- "pixel_values":
- image_input_dict["pixel_values"],
- "image_sizes":
- image_input_dict["image_sizes"],
- "image_attention_mask":
- image_input_dict["image_attention_mask"],
- "num_img_tokens":
- image_input_dict["num_img_tokens"],
- })
+class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
-def cat_with_pad(tensors, dim, padding_value=0):
- """
- cat along dim, while pad to max for all other dims
- """
- ndim = tensors[0].dim()
- assert all(
- t.dim() == ndim for t in
- tensors[1:]), "All tensors must have the same number of dimensions"
+ def _get_data_parser(self) -> MultiModalDataParser:
+ feature_extractor = self.info.get_feature_extractor()
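+ # Use scipy-based resampling here (assumption: to stay closer to the
+ # upstream Phi-4-MM audio preprocessing) instead of the librosa default.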
+ return MultiModalDataParser(target_sr=feature_extractor.sampling_rate,
+ audio_resample_method="scipy")
- out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
- out_size[dim] = sum(t.shape[dim] for t in tensors)
- output = tensors[0].new_full(out_size, padding_value)
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ if not mm_data:
+ prompt_ids = self.info.get_tokenizer().encode(prompt)
+ prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
+ return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+ sr = self.info.get_feature_extractor().sampling_rate
+ if (audio_data := mm_data.get("audios", [])):
+ mm_data['audios'] = [(data, sr) for data in audio_data]
+
+ processed_outputs = super()._call_hf_processor(prompt, mm_data,
+ mm_kwargs)
+
+ num_img_tokens = [
+ self.info.get_num_image_tokens(image_width=img_size[0],
+ image_height=img_size[1])
+ for img_size in processed_outputs["image_sizes"]
+ ]
+ processed_outputs["num_img_tokens"] = num_img_tokens
- index = 0
- for t in tensors:
- # Create a slice list where every dimension except dim is full slice
- slices = [slice(0, t.shape[d]) for d in range(ndim)]
- # Update only the concat dimension slice
- slices[dim] = slice(index, index + t.shape[dim])
+ audio_features = processed_outputs['input_audio_embeds']
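+ # The HF processor pads every audio in the batch to the longest clip;
+ # slice each feature tensor back to its true frame count so it matches
+ # the embed sizes computed from the raw audio lengths.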
+ feature_sizes = [
+ self.info.get_audio_num_frames(len(audio), sr)
+ for audio in audio_data
+ ]
+ processed_outputs['input_audio_embeds'] = [
+ audio_features[idx, :size]
+ for idx, size in enumerate(feature_sizes)
+ ]
- output[slices] = t
- index += t.shape[dim]
+ return processed_outputs
- return output
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(
+ input_image_embeds=MultiModalFieldConfig.batched("image"),
+ image_attention_mask=MultiModalFieldConfig.batched("image"),
+ image_sizes=MultiModalFieldConfig.batched("image"),
+ num_img_tokens=MultiModalFieldConfig.batched("image"),
+ input_audio_embeds=MultiModalFieldConfig.batched("audio"),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> Sequence[PromptUpdate]:
+ image_tokens: list[str] = self.info.image_tokens # type: ignore
+ audio_tokens: list[str] = self.info.audio_tokens # type: ignore
+ feature_extractor = self.info.get_feature_extractor()
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+ def get_image_replacement_phi4mm(item_idx: int):
+ images = mm_items.get_items(
+ "image", (ImageEmbeddingItems, ImageProcessorItems))
+
+ if isinstance(images, ImageEmbeddingItems):
+ num_image_tokens = images.get_feature_size(item_idx)
+ else:
+ image_size = images.get_image_size(item_idx)
+ num_image_tokens = self.info.get_num_image_tokens(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ processor=hf_processor,
+ )
+
+ image_tokens = [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens
+
+ return image_tokens
+
+ def get_audio_replacement_phi4mm(item_idx: int):
+ audios = mm_items.get_items("audio", AudioProcessorItems)
+ # TODO(Isotr0py): support embedding inputs
+ audio_len = audios.get_audio_length(item_idx)
+ audio_frames = self.info.get_audio_num_frames(
+ audio_len, feature_extractor.sampling_rate)
+ audio_embed_size = self.info._compute_audio_embed_size(
+ audio_frames)
+
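+ # The number of placeholder tokens must match the number of audio
+ # embeddings the encoder will produce for this clip, or the merge
+ # into inputs_embeds would misalign.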
+ audio_tokens = [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size
+
+ return audio_tokens
+
+ num_images = mm_items.get_count("image", strict=False)
+ num_audios = mm_items.get_count("audio", strict=False)
+
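+ # Phi-4-MM numbers its placeholders per item (e.g. <|image_1|>,
+ # <|audio_1|>), so one PromptReplacement is registered per
+ # placeholder string rather than one per modality.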
+ image_repl = [
+ PromptReplacement(
+ modality="image",
+ target=image_token,
+ replacement=get_image_replacement_phi4mm,
+ ) for image_token in image_tokens[:num_images]
+ ]
+ audio_repl = [
+ PromptReplacement(
+ modality="audio",
+ target=audio_token,
+ replacement=get_audio_replacement_phi4mm,
+ ) for audio_token in audio_tokens[:num_audios]
+ ]
+ return image_repl + audio_repl
-@MULTIMODAL_REGISTRY.register_input_mapper("audio",
- input_mapper_for_phi4mm_audio)
-@MULTIMODAL_REGISTRY.register_input_mapper("image",
- input_mapper_for_phi4mm_image)
-@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
- "audio", get_max_phi4mm_audio_tokens)
-@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
- "image", get_max_phi4mm_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4mm)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_phi4mm)
-class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
- SupportsV0Only):
+@MULTIMODAL_REGISTRY.register_processor(
+ Phi4MMMultiModalProcessor,
+ info=Phi4MMProcessingInfo,
+ dummy_inputs=Phi4MMDummyInputsBuilder,
+)
+class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
"""
Implements the Phi-4-multimodal-instruct model in vLLM.
"""
@@ -1518,48 +991,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
- self.sampler = Sampler()
-
- def _audio_features_to_embeddings(
- self,
- input_ids: torch.Tensor,
- input_features: List[torch.Tensor],
- audio_input_sizes: torch.Tensor,
- audio_projection_mode: str,
- ) -> torch.Tensor:
- """
- Convert audio features to embeddings, which are used as input to the
- model (via `inputs_embeds`).
-
- Args:
- input_ids (torch.Tensor): Input IDs (the prompt in this case).
- input_features (list[torch.Tensor]): Input features (the audio
- embeddings).
- audio_input_sizes (list[torch.Tensor]): Audio input sizes (the
- audio embed lengths to use for padding the audio placeholder token
- in the input prompt IDs).
- """
- # The audio projection can either be a single linear or Sequential,
- # so handle both cases
- if isinstance(self.embed_tokens_extend.audio_projection,
- nn.Sequential):
- target_dtype = self.embed_tokens_extend.audio_projection[
- 0].bias.dtype
- else:
- target_dtype = self.embed_tokens_extend.audio_projection.bias.dtype
-
- audio_input = [
- input.unsqueeze(0).to(target_dtype) for input in input_features
- ]
- kwargs = {
- "wte": self.model.embed_tokens,
- 'audio_projection_mode': audio_projection_mode
- }
- audio_embeddings = self.embed_tokens_extend(input_ids, audio_input,
- audio_input_sizes,
- **kwargs)
- audio_embeddings = audio_embeddings.to(target_dtype)
- return audio_embeddings
+ self.sampler = get_sampler()
def _parse_and_validate_audio_input(
self, **kwargs: object) -> Optional[Phi4MMAudioInputs]:
@@ -1574,7 +1006,7 @@ def _parse_and_validate_audio_input(
Returns:
Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs.
"""
- audio_features = kwargs.pop("audio_features", None)
+ audio_features = kwargs.pop("input_audio_embeds", None)
audio_embeds = kwargs.pop("audio_embeds", None)
if audio_features is None and audio_embeds is None:
@@ -1586,7 +1018,7 @@ def _parse_and_validate_audio_input(
f"Got type: {type(audio_features)}")
return Phi4MMAudioFeatureInputs(type="audio_features",
- data=audio_features)
+ data=flatten_bn(audio_features))
if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)):
@@ -1598,8 +1030,7 @@ def _parse_and_validate_audio_input(
raise AssertionError("This line should be unreachable.")
- def _process_audio_input(self, input_ids: torch.Tensor,
- audio_input: Phi4MMAudioInputs,
+ def _process_audio_input(self, audio_input: Phi4MMAudioInputs,
audio_projection_mode: str) -> NestedTensors:
"""
Create the audio embeddings from the audio input, where the audio input
@@ -1607,8 +1038,6 @@ def _process_audio_input(self, input_ids: torch.Tensor,
created by `input_mapper_for_phi4mm_audio`.
Args:
- input_ids (torch.Tensor): Input IDs (the prompt in this case,
- before the audio token replication).
audio_input (Phi4MMAudioInputs): Audio input.
Returns:
@@ -1620,21 +1049,20 @@ def _process_audio_input(self, input_ids: torch.Tensor,
audio_features = audio_input["data"]
# (e.g. multiple examples) and the second dim is the multi-audio dim
# (e.g. multiple audios in the same example)
- audio_feature = [i[0] for j in audio_features for i in j]
- audio_feature_len = [i[1].item() for j in audio_features for i in j]
- # Add the batch dim via `squeeze`
- return self._audio_features_to_embeddings(
- input_ids.unsqueeze(0),
- audio_feature,
- audio_feature_len,
- audio_projection_mode,
- ).squeeze(0)
+ dtype = next(self.embed_tokens_extend.parameters()).dtype
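+ # Each audio clip is embedded independently; embed_tokens_extend
+ # returns a (num_audio_tokens, hidden_dim) tensor per clip.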
+ audio_embeds = [
+ self.embed_tokens_extend(
+ features.to(dtype),
+ audio_projection_mode=audio_projection_mode,
+ ) for features in audio_features
+ ]
+ return audio_embeds
def _parse_and_validate_image_input(self,
**kwargs: object) -> Optional[Dict]:
- pixel_values: Optional[Dict] = kwargs.get("pixel_values")
- if pixel_values is None:
+ input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
+ if input_image_embeds is None:
return None
image_sizes = kwargs.get("image_sizes")
@@ -1643,23 +1071,24 @@ def _parse_and_validate_image_input(self,
assert image_sizes is not None and image_attention_mask is not None\
and num_img_tokens is not None, "Missing image inputs"
- if isinstance(pixel_values, list):
- assert pixel_values[0].dim() == 5, "Incorrect image inputs"
+ if is_list_of(input_image_embeds, torch.Tensor):
+ assert all(p.dim() == 5
+ for p in input_image_embeds), "Incorrect image inputs"
# list len is batch_size.
# each tensor has dimension: num_img_per_example, num_hd_patches,
# channels, height, width.
# need to pad along num_hd_patches.
# mask size num_img_per_prompt, num_hd_patches, feat_h, feat_w.
- pixel_values = cat_with_pad(pixel_values, dim=0)
- elif isinstance(pixel_values, torch.Tensor):
+ input_image_embeds = cat_with_pad(input_image_embeds, dim=0)
+ elif isinstance(input_image_embeds, torch.Tensor):
# dimension: batch_size, num_img_per_example, num_hd_patches,
# channels, height, width.
# we flatten first 2 dims to make it a single large batch for
# SigLIP Encoder.
- assert pixel_values.dim() == 6, "Incorrect image inputs"
- pixel_values = pixel_values.flatten(0, 1)
+ assert input_image_embeds.dim() == 6, "Incorrect image inputs"
+ input_image_embeds = input_image_embeds.flatten(0, 1)
else:
- raise ValueError("Incorrect pixel_values inputs")
+ raise ValueError("Incorrect input_image_embeds inputs")
if isinstance(image_attention_mask, list):
image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
@@ -1685,80 +1114,140 @@ def _parse_and_validate_image_input(self,
else:
raise ValueError("Incorrect image_attention_mask inputs")
- return {
- 'pixel_values': pixel_values,
- 'image_sizes': image_sizes,
- 'image_attention_mask': image_attention_mask,
- 'num_img_tokens': num_img_tokens,
- }
+ return Phi4MMImagePixelInputs(
+ type="pixel_values",
+ data=input_image_embeds,
+ image_sizes=image_sizes,
+ image_attention_mask=image_attention_mask,
+ num_img_tokens=num_img_tokens,
+ )
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ modalities = {}
+
+ # Preserve the order of modalities if there are multiple of them
+ # from the order of kwargs.
+ for input_key in kwargs:
+ if input_key in ("input_image_embeds",
+ "image_embeds") and "images" not in modalities:
+ modalities["images"] = self._parse_and_validate_image_input(
+ **kwargs)
+ if input_key in ("input_audio_embeds",
+ "audio_embeds") and "audios" not in modalities:
+ modalities["audios"] = self._parse_and_validate_audio_input(
+ **kwargs)
+
+ return modalities
+
+ def _process_image_input(
+ self, image_input: Phi4MMImagePixelInputs) -> list[torch.Tensor]:
+ if image_input["type"] == "image_embeds":
+ image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+ else:
+ dtype = next(self.vision_encoder.parameters()).dtype
+ pixel_values = image_input['data'].to(dtype)
+ image_sizes = image_input['image_sizes']
+ image_attention_mask = image_input['image_attention_mask']
+ image_embeds = self.vision_encoder(pixel_values, image_sizes,
+ image_attention_mask)
+ return image_embeds
+
+ def get_multimodal_embeddings(
+ self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+
+ modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+ if not modalities:
+ return None
- def merge_image_features_to_inputs_embeds(
+ # The result multimodal_embeddings is a tuple of tensors, with each
+ # tensor corresponding to a multimodal data item (image or audio).
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ audio_projection_mode = 'speech'
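+ # audio_projection_mode selects which audio projection is applied:
+ # audio-only prompts use 'speech', while prompts that also contain
+ # images switch to 'vision' below.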
+ for modality in modalities:
+ # make sure process images first
+ if modality == "images":
+ audio_projection_mode = "vision"
+ image_input = modalities["images"]
+ vision_embeddings = self._process_image_input(image_input)
+ multimodal_embeddings += tuple(vision_embeddings)
+ if modality == "audios":
+ audio_input = modalities["audios"]
+ audio_embeddings = self._process_audio_input(
+ audio_input, audio_projection_mode=audio_projection_mode)
+ multimodal_embeddings += tuple(audio_embeddings)
+
+ return multimodal_embeddings
+
+ def get_input_embeddings(
self,
input_ids: torch.Tensor,
- inputs_embeds: torch.Tensor,
- image_set_tensors: List[torch.Tensor],
- ):
- position_tuple = (input_ids == _IMAGE_PLACEHOLDER_TOKEN_ID).nonzero(
- as_tuple=True)
-
- assert all([t.shape[0] == 1 for t in image_set_tensors
- ]), 'img_set_tensor should have shape (1, N_tokens, C)'
- # Shape: (merged_N_tokens, C)
- image_set_tensor = torch.cat(image_set_tensors, dim=1).squeeze(0)
- image_set_tensor = image_set_tensor.to(inputs_embeds.dtype).to(
- inputs_embeds.device)
- merged_embeds = inputs_embeds.index_put(
- indices=position_tuple,
- values=image_set_tensor,
- accumulate=False,
- )
- return merged_embeds
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.model.embed_tokens(input_ids)
+ if multimodal_embeddings is not None:
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids, inputs_embeds, multimodal_embeddings,
+ [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID])
+ return inputs_embeds
+
+ def get_input_embeddings_v0(
+ self,
+ input_ids: torch.Tensor,
+ image_input: Optional[Phi4MMImagePixelInputs] = None,
+ audio_input: Optional[Phi4MMAudioFeatureInputs] = None,
+ ) -> torch.Tensor:
+ audio_projection_mode = 'speech'
+ inputs_embeds = self.get_input_embeddings(input_ids)
+ if image_input is not None:
+ image_embeds = self._process_image_input(image_input)
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ image_embeds,
+ placeholder_token_id=_IMAGE_PLACEHOLDER_TOKEN_ID,
+ )
+ audio_projection_mode = 'vision'
+
+ if audio_input is not None:
+ audio_embeds = self._process_audio_input(
+ audio_input, audio_projection_mode=audio_projection_mode)
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ audio_embeds,
+ placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN_ID,
+ )
+ return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> torch.Tensor:
if intermediate_tensors is not None:
- input_ids = None
inputs_embeds = None
- else:
- # Each entry in this is a pair of audio_features and audio_embed
- # lengths
+
+ # NOTE: In v1, inputs_embeds is always generated by the model runner
+ # via `get_multimodal_embeddings` and `get_input_embeddings`; this
+ # condition is only for v0 compatibility.
+ elif inputs_embeds is None:
+ image_input = self._parse_and_validate_image_input(**kwargs)
audio_input = self._parse_and_validate_audio_input(**kwargs)
- image_inputs = self._parse_and_validate_image_input(**kwargs)
-
- has_audio = audio_input is not None
- has_image = image_inputs is not None
-
- if has_audio:
- audio_projection_mode = 'vision' if has_image else 'speech'
- inputs_embeds = self._process_audio_input(
- input_ids, audio_input, audio_projection_mode)
-
- if has_image:
- dtype = self.vision_encoder.img_processor.embeddings.\
- patch_embedding.weight.dtype
- pixel_values = image_inputs['pixel_values'].to(dtype)
- image_sizes = image_inputs['image_sizes']
- image_attention_mask = image_inputs['image_attention_mask']
- image_set_tensors = self.vision_encoder(
- pixel_values, image_sizes, image_attention_mask)
- if not has_audio:
- inputs_embeds = self.model.embed_tokens(input_ids)
-
- inputs_embeds = self.merge_image_features_to_inputs_embeds(
- input_ids, inputs_embeds, image_set_tensors)
-
- if has_image or has_audio:
- # multi-modal input, we have set inputs_embeds properly in
- # previous steps
- input_ids = None
- else:
- # text-only, we keep using original input_ids
+
+ if image_input is None and audio_input is None:
inputs_embeds = None
+ else:
+ inputs_embeds = self.get_input_embeddings_v0(
+ input_ids,
+ image_input=image_input,
+ audio_input=audio_input)
+ input_ids = None
hidden_states = self.model(
input_ids,
diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py
index db90848f9809..34a7a73d057a 100644
--- a/vllm/model_executor/models/phi4mm_audio.py
+++ b/vllm/model_executor/models/phi4mm_audio.py
@@ -1159,8 +1159,11 @@ def get_audio_features(
input_embeds: torch.FloatTensor,
audio_attention_mask: torch.Tensor = None,
audio_projection_mode: str = "speech",
- ):
-
+ ) -> torch.FloatTensor:
+ """
+ arguments:
+ input_embeds: audio features of shape (B, T, D), where B is the number of audios in a sequence
+ """
if self.freeze_audio_processor:
with torch.no_grad():
audio_features, masks = self.encoder(input_embeds,
@@ -1210,62 +1213,20 @@ def get_audio_features(
def forward(
self,
- input_ids: torch.LongTensor,
- input_embeds: torch.FloatTensor,
- audio_embed_sizes,
- **kwargs,
+ audio_features: torch.FloatTensor,
+ audio_attention_mask: torch.Tensor = None,
+ audio_projection_mode: str = "speech",
) -> torch.FloatTensor:
"""
arguments:
- input_ids: input text ids (B, U)
- input_embeds: audio features (B, T, D) B: num audios in a sequence
+ audio_features: audio features (T, D)
+
+ returns:
+ audio_embeds: audio embeddings (num_audio_tokens, hidden_dim)
"""
- assert input_embeds is not None and len(input_embeds) == len(
- audio_embed_sizes)
-
- input_shape = input_ids.size()
- input_ids = input_ids.view(-1, input_shape[-1])
-
- with torch.no_grad():
- positions = (input_ids == _AUDIO_PLACEHOLDER_TOKEN_ID).nonzero(
- as_tuple=False)
-
- if not isinstance(input_embeds, list):
- input_embeds = [input_embeds]
-
- audio_projection_mode = kwargs.get("audio_projection_mode", "speech")
- audio_set_tensor = [
- self.get_audio_features(
- input_embed, audio_projection_mode=audio_projection_mode)
- for input_embed in input_embeds
- ]
-
- with torch.no_grad():
- input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
-
- if "wte" in kwargs:
- # we use the token embedding layer from the huggingface model, this
- # is REQUIRED to make sure we are using the loaded weights.
- hidden_states = kwargs["wte"](input_ids)
- else:
- # otherwise, we use token embedding in pretrained mixformer from
- # phi team
- hidden_states = self.wte(input_ids)
-
- if len(positions.tolist()) > 0:
- assert sum(audio_embed_sizes) == len(
- positions
- ), "please ensure the encoder outputs have the same length as"\
- " defined in input_ids!"
- idx = 0
- for i in range(len(audio_embed_sizes)):
- cnt = audio_embed_sizes[i]
- assert audio_set_tensor[i].shape[0] == 1
- hidden_states[
- positions[idx, 0],
- positions[idx, 1]:positions[idx, 1] + cnt,
- ] = (audio_set_tensor[i][0, :audio_embed_sizes[i], :].to(
- hidden_states.dtype).to(hidden_states.device))
- idx += cnt
-
- return hidden_states
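+ # get_audio_features expects a (B, T, D) batch, so add a batch dim for
+ # the single clip and strip it from the result.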
+ audio_embeds = self.get_audio_features(
+ audio_features.unsqueeze(0),
+ audio_attention_mask=audio_attention_mask,
+ audio_projection_mode=audio_projection_mode,
+ )
+ return audio_embeds.squeeze(0)
diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py
index f379ec1682a3..70a912c9c9ef 100644
--- a/vllm/multimodal/audio.py
+++ b/vllm/multimodal/audio.py
@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
-
import base64
from io import BytesIO
from pathlib import Path
+from typing import Literal, Optional
import numpy as np
import numpy.typing as npt
@@ -43,7 +43,7 @@ def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
"There is no default maximum multimodal tokens")
-def resample_audio(
+def resample_audio_librosa(
audio: npt.NDArray[np.floating],
*,
orig_sr: float,
@@ -52,6 +52,55 @@ def resample_audio(
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
+def resample_audio_scipy(
+ audio: npt.NDArray[np.floating],
+ *,
+ orig_sr: float,
+ target_sr: float,
+):
+ # Lazily import scipy.signal; importing it at module level breaks the doc build.
+ import scipy.signal
+
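+ # NOTE: passing a single up/down factor to resample_poly assumes the
+ # two rates are integer multiples of each other (e.g. 48 kHz -> 16 kHz);
+ # non-integer ratios would need a rational up/down pair.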
+ if orig_sr > target_sr:
+ return scipy.signal.resample_poly(audio, 1, orig_sr // target_sr)
+ elif orig_sr < target_sr:
+ return scipy.signal.resample_poly(audio, target_sr // orig_sr, 1)
+ return audio
+
+
+class AudioResampler:
+ """Resample audio data to a target sample rate."""
+
+ def __init__(
+ self,
+ target_sr: Optional[float] = None,
+ method: Literal["librosa", "scipy"] = "librosa",
+ ):
+ self.target_sr = target_sr
+ self.method = method
+
+ def resample(
+ self,
+ audio: npt.NDArray[np.floating],
+ *,
+ orig_sr: float,
+ ) -> npt.NDArray[np.floating]:
+ if self.target_sr is None:
+ raise RuntimeError("Audio resampling is not supported when "
+ "`target_sr` is not provided")
+ if self.method == "librosa":
+ return resample_audio_librosa(audio,
+ orig_sr=orig_sr,
+ target_sr=self.target_sr)
+ elif self.method == "scipy":
+ return resample_audio_scipy(audio,
+ orig_sr=orig_sr,
+ target_sr=self.target_sr)
+ else:
+ raise ValueError(f"Invalid resampling method: {self.method}. "
+ "Supported methods are 'librosa' and 'scipy'.")
+
+
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py
index fc5a294564e3..9707b9cfcf8b 100644
--- a/vllm/multimodal/parse.py
+++ b/vllm/multimodal/parse.py
@@ -3,8 +3,8 @@
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence
-from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
- Union)
+from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional,
+ TypeVar, Union)
import numpy as np
import torch
@@ -14,7 +14,7 @@
from vllm.utils import is_list_of
-from .audio import resample_audio
+from .audio import AudioResampler
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
ImageItem, ModalityData, MultiModalDataDict,
MultiModalFieldConfig, MultiModalKwargs, VideoItem)
@@ -308,10 +308,18 @@ class MultiModalDataParser:
items to the model's expected sampling rate.
"""
- def __init__(self, *, target_sr: Optional[float] = None) -> None:
+ def __init__(
+ self,
+ *,
+ target_sr: Optional[float] = None,
+ audio_resample_method: Literal["librosa", "scipy"] = "librosa",
+ ) -> None:
super().__init__()
- self.target_sr = target_sr
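+ # Keep "librosa" as the default so existing models are unaffected;
+ # models such as Phi-4-MM can opt into the "scipy" method.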
+ self.audio_resampler = AudioResampler(
+ target_sr=target_sr,
+ method=audio_resample_method,
+ )
def _is_embeddings(
self, data: object
@@ -374,15 +382,8 @@ def _parse_audio_data(
if orig_sr is None:
new_audio = audio
else:
- target_sr = self.target_sr
- if target_sr is None:
- raise RuntimeError(
- "Audio resampling is not supported when "
- "`target_sr` is not provided")
-
- new_audio = resample_audio(audio,
- orig_sr=orig_sr,
- target_sr=target_sr)
+ new_audio = self.audio_resampler.resample(audio,
+ orig_sr=orig_sr)
new_audios.append(new_audio)