-
-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[VLM][Bugfix] Multi-modal processor compatible with V1 multi-input #11674
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
DarkLight1337
merged 3 commits into
vllm-project:main
from
DarkLight1337:fix-v1-multi-mm
Jan 2, 2025
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -2,7 +2,8 @@ | |||||||||||||
from collections import UserDict, defaultdict | ||||||||||||||
from collections.abc import Mapping, Sequence | ||||||||||||||
from dataclasses import dataclass | ||||||||||||||
from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final | ||||||||||||||
from typing import (Any, Literal, Optional, TypedDict, TypeVar, Union, cast, | ||||||||||||||
final) | ||||||||||||||
|
||||||||||||||
import numpy as np | ||||||||||||||
import torch | ||||||||||||||
|
@@ -11,7 +12,7 @@ | |||||||||||||
from transformers import BatchFeature | ||||||||||||||
from typing_extensions import NotRequired, TypeAlias | ||||||||||||||
|
||||||||||||||
from vllm.utils import JSONTree, is_list_of, json_map_leaves | ||||||||||||||
from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves | ||||||||||||||
|
||||||||||||||
_T = TypeVar("_T") | ||||||||||||||
|
||||||||||||||
|
@@ -160,11 +161,8 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: | |||||||||||||
|
||||||||||||||
|
||||||||||||||
@dataclass(frozen=True) | ||||||||||||||
class MultiModalFieldItem: | ||||||||||||||
""" | ||||||||||||||
Contains metadata and data in :class:`MultiModalKwargs` | ||||||||||||||
corresponding to a data item in :class:`MultiModalDataItems`. | ||||||||||||||
""" | ||||||||||||||
class MultiModalFieldElem: | ||||||||||||||
"""Contains metadata and data of an item in :class:`MultiModalKwargs`.""" | ||||||||||||||
field: "BaseMultiModalField" | ||||||||||||||
data: NestedTensors | ||||||||||||||
|
||||||||||||||
|
@@ -186,34 +184,34 @@ class BaseMultiModalField(ABC): | |||||||||||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: | ||||||||||||||
raise NotImplementedError | ||||||||||||||
|
||||||||||||||
def _build_item(self, data: NestedTensors) -> MultiModalFieldItem: | ||||||||||||||
return MultiModalFieldItem(self, data) | ||||||||||||||
def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem: | ||||||||||||||
return MultiModalFieldElem(self, data) | ||||||||||||||
|
||||||||||||||
def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem: | ||||||||||||||
"""Merge multiple instances of :class:`MultiModalFieldItem` together.""" | ||||||||||||||
def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem: | ||||||||||||||
"""Merge multiple instances of :class:`MultiModalFieldElem` together.""" | ||||||||||||||
fields = [item.field for item in batch] | ||||||||||||||
if len(set(fields)) > 1: | ||||||||||||||
raise ValueError(f"Cannot merge different {fields=}") | ||||||||||||||
|
||||||||||||||
data = self._reduce_data([item.data for item in batch]) | ||||||||||||||
|
||||||||||||||
return self._build_item(data) | ||||||||||||||
return self._build_elem(data) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
@dataclass(frozen=True) | ||||||||||||||
class MultiModalBatchedField(BaseMultiModalField): | ||||||||||||||
""" | ||||||||||||||
A :class:`BaseMultiModalField` implementation where an item is obtained by | ||||||||||||||
directly indexing into the first dimension of the underlying data. | ||||||||||||||
A :class:`BaseMultiModalField` implementation where an element in the batch | ||||||||||||||
is obtained by indexing into the first dimension of the underlying data. | ||||||||||||||
""" | ||||||||||||||
|
||||||||||||||
def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]: | ||||||||||||||
return [self._build_item(item) for item in batch] | ||||||||||||||
def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]: | ||||||||||||||
return [self._build_elem(item) for item in batch] | ||||||||||||||
|
||||||||||||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: | ||||||||||||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): | ||||||||||||||
first_shape = batch[0].shape | ||||||||||||||
if all(item.shape == first_shape for item in batch): | ||||||||||||||
if all(elem.shape == first_shape for elem in batch): | ||||||||||||||
return torch.stack(batch) | ||||||||||||||
|
||||||||||||||
return batch | ||||||||||||||
|
@@ -222,24 +220,24 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: | |||||||||||||
@dataclass(frozen=True) | ||||||||||||||
class MultiModalFlatField(BaseMultiModalField): | ||||||||||||||
""" | ||||||||||||||
A :class:`BaseMultiModalField` implementation where an item is obtained by | ||||||||||||||
slicing along the first dimension of the underlying data. | ||||||||||||||
A :class:`BaseMultiModalField` implementation where an element in the batch | ||||||||||||||
is obtained by slicing along the first dimension of the underlying data. | ||||||||||||||
""" | ||||||||||||||
|
||||||||||||||
def build_items( | ||||||||||||||
def build_elems( | ||||||||||||||
self, | ||||||||||||||
batch: NestedTensors, | ||||||||||||||
slices: Sequence[slice], | ||||||||||||||
) -> list[MultiModalFieldItem]: | ||||||||||||||
return [self._build_item(batch[slice_]) for slice_ in slices] | ||||||||||||||
) -> list[MultiModalFieldElem]: | ||||||||||||||
return [self._build_elem(batch[slice_]) for slice_ in slices] | ||||||||||||||
|
||||||||||||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: | ||||||||||||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): | ||||||||||||||
first_shape = batch[0].shape | ||||||||||||||
if all(item.shape[1:] == first_shape[1:] for item in batch): | ||||||||||||||
if all(elem.shape[1:] == first_shape[1:] for elem in batch): | ||||||||||||||
return torch.concat(batch) | ||||||||||||||
|
||||||||||||||
return [elem for item in batch for elem in item] | ||||||||||||||
return [e for elem in batch for e in elem] | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class MultiModalFieldConfig: | ||||||||||||||
|
@@ -267,115 +265,111 @@ def __init__( | |||||||||||||
) -> None: | ||||||||||||||
super().__init__() | ||||||||||||||
|
||||||||||||||
self._field_cls = field_cls | ||||||||||||||
self._modality = modality | ||||||||||||||
self._field_config = field_config | ||||||||||||||
self.field_cls = field_cls | ||||||||||||||
self.modality = modality | ||||||||||||||
self.field_config = field_config | ||||||||||||||
|
||||||||||||||
def build_items( | ||||||||||||||
def build_elems( | ||||||||||||||
self, | ||||||||||||||
key: str, | ||||||||||||||
batch: NestedTensors, | ||||||||||||||
) -> list[MultiModalFieldItem]: | ||||||||||||||
field = self._field_cls(key=key, modality=self._modality) | ||||||||||||||
return field.build_items(batch, **self._field_config) # type: ignore | ||||||||||||||
) -> Sequence[MultiModalFieldElem]: | ||||||||||||||
field = self.field_cls(key=key, modality=self.modality) | ||||||||||||||
return field.build_elems(batch, **self.field_config) # type: ignore | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class MultiModalKwargs(UserDict[str, NestedTensors]): | ||||||||||||||
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): | ||||||||||||||
""" | ||||||||||||||
A collection of :class:`MultiModalFieldElem` | ||||||||||||||
corresponding to a data item in :class:`MultiModalDataItems`. | ||||||||||||||
""" | ||||||||||||||
A dictionary that represents the keyword arguments to | ||||||||||||||
:meth:`~torch.nn.Module.forward`. | ||||||||||||||
|
||||||||||||||
The metadata :code:`items_by_key` defines how to split batched keyword | ||||||||||||||
arguments corresponding to each data item in :class:`MultiModalDataItems`: | ||||||||||||||
@staticmethod | ||||||||||||||
def from_elems(elems: Sequence[MultiModalFieldElem]): | ||||||||||||||
return MultiModalKwargsItem({elem.field.key: elem for elem in elems}) | ||||||||||||||
|
||||||||||||||
- For a keyword argument, we can access the :code:`i` th item in the batch | ||||||||||||||
via :code:`items_by_key[key][i]`. | ||||||||||||||
- We can gather the keyword arguments belonging to a modality by finding | ||||||||||||||
the keys with items that belong to that modality, then accessing | ||||||||||||||
the :code:`i` th item in the batch for each such key. | ||||||||||||||
@property | ||||||||||||||
def modality(self) -> str: | ||||||||||||||
modalities = {elem.field.modality for elem in self.data.values()} | ||||||||||||||
assert len(modalities) == 1, f"Found different modalities={modalities}" | ||||||||||||||
return next(iter(modalities)) | ||||||||||||||
|
||||||||||||||
Example: | ||||||||||||||
|
||||||||||||||
.. code-block:: python | ||||||||||||||
|
||||||||||||||
# All items belong to the "image" modality | ||||||||||||||
items_by_key={ | ||||||||||||||
"pixel_values": [a, b, c, d], # "image" modality | ||||||||||||||
"image_grid_thw": [e, f, g, h], # "image" modality | ||||||||||||||
"pixel_values_video": [h, i, j], # "video" modality | ||||||||||||||
"video_grid_thw": [k, l, m], # "video" modality | ||||||||||||||
} | ||||||||||||||
# NOTE: UserDict is for V0 compatibility. | ||||||||||||||
# V1 should access individual items via `get_item`. | ||||||||||||||
class MultiModalKwargs(UserDict[str, NestedTensors]): | ||||||||||||||
""" | ||||||||||||||
A dictionary that represents the keyword arguments to | ||||||||||||||
:meth:`~torch.nn.Module.forward`. | ||||||||||||||
|
||||||||||||||
- The keyword arguments belonging to the first image are | ||||||||||||||
:code:`{"pixel_values": a, "image_grid_thw": e}`. | ||||||||||||||
- The keyword arguments belonging to the second video are | ||||||||||||||
:code:`{"pixel_values_video": i, "video_grid_thw": l}`. | ||||||||||||||
The metadata :code:`items` enables us to obtain the keyword arguments | ||||||||||||||
corresponding to each data item in :class:`MultiModalDataItems`, via | ||||||||||||||
:meth:`get_item` and :meth:`get_items`. | ||||||||||||||
""" | ||||||||||||||
|
||||||||||||||
@staticmethod | ||||||||||||||
def from_hf_inputs( | ||||||||||||||
hf_inputs: BatchFeature, | ||||||||||||||
config_by_key: Mapping[str, MultiModalFieldConfig], | ||||||||||||||
*, | ||||||||||||||
enable_sanity_checks: bool = False, | ||||||||||||||
): | ||||||||||||||
# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key` | ||||||||||||||
# We assume that those fields are not used in vLLM | ||||||||||||||
items_by_key = { | ||||||||||||||
key: config.build_items(key, batch) | ||||||||||||||
for key, config in config_by_key.items() | ||||||||||||||
if (batch := hf_inputs.get(key)) is not None | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
return MultiModalKwargs.from_items_by_key( | ||||||||||||||
items_by_key, | ||||||||||||||
enable_sanity_checks=enable_sanity_checks, | ||||||||||||||
) | ||||||||||||||
elems_by_key = dict[str, Sequence[MultiModalFieldElem]]() | ||||||||||||||
keys_by_modality = defaultdict[str, set[str]](set) | ||||||||||||||
for key, config in config_by_key.items(): | ||||||||||||||
batch = hf_inputs.get(key) | ||||||||||||||
if batch is not None: | ||||||||||||||
elems = config.build_elems(key, batch) | ||||||||||||||
if len(elems) > 0: | ||||||||||||||
elems_by_key[key] = elems | ||||||||||||||
keys_by_modality[config.modality].add(key) | ||||||||||||||
|
||||||||||||||
items = list[MultiModalKwargsItem]() | ||||||||||||||
for modality, keys in keys_by_modality.items(): | ||||||||||||||
elems_in_modality = {k: elems_by_key[k] for k in keys} | ||||||||||||||
batch_sizes = {k: len(v) for k, v in elems_in_modality.items()} | ||||||||||||||
|
||||||||||||||
if len(set(batch_sizes.values())) > 1: | ||||||||||||||
raise ValueError( | ||||||||||||||
f"Cannot merge different batch sizes for {modality=}! " | ||||||||||||||
f"Found: {batch_sizes=}") | ||||||||||||||
|
||||||||||||||
batch_size = next(iter(batch_sizes.values())) | ||||||||||||||
for item_idx in range(batch_size): | ||||||||||||||
elems = [v[item_idx] for v in elems_in_modality.values()] | ||||||||||||||
items.append(MultiModalKwargsItem.from_elems(elems)) | ||||||||||||||
|
||||||||||||||
return MultiModalKwargs.from_items(items) | ||||||||||||||
|
||||||||||||||
@staticmethod | ||||||||||||||
def from_items_by_key( | ||||||||||||||
items_by_key: Mapping[str, list[MultiModalFieldItem]], | ||||||||||||||
*, | ||||||||||||||
enable_sanity_checks: bool = False, | ||||||||||||||
) -> "MultiModalKwargs": | ||||||||||||||
def from_items(items: Sequence[MultiModalKwargsItem]): | ||||||||||||||
"""Construct a new :class:`MultiModalKwargs` from multiple items.""" | ||||||||||||||
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) | ||||||||||||||
for item in items: | ||||||||||||||
for key, elem in item.items(): | ||||||||||||||
elems_by_key[key].append(elem) | ||||||||||||||
|
||||||||||||||
data = { | ||||||||||||||
key: items[0].field.reduce(items).data | ||||||||||||||
for key, items in items_by_key.items() if len(items) > 0 | ||||||||||||||
key: elems[0].field.reduce(elems).data | ||||||||||||||
for key, elems in elems_by_key.items() if len(elems) > 0 | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
return MultiModalKwargs(data, | ||||||||||||||
items_by_key=items_by_key, | ||||||||||||||
enable_sanity_checks=enable_sanity_checks) | ||||||||||||||
return MultiModalKwargs(data, items=items) | ||||||||||||||
|
||||||||||||||
def __init__( | ||||||||||||||
self, | ||||||||||||||
data: Mapping[str, NestedTensors], | ||||||||||||||
*, | ||||||||||||||
items_by_key: Mapping[str, list[MultiModalFieldItem]] = {}, | ||||||||||||||
enable_sanity_checks: bool = False, | ||||||||||||||
items: Optional[Sequence[MultiModalKwargsItem]] = None, | ||||||||||||||
) -> None: | ||||||||||||||
super().__init__(data) | ||||||||||||||
|
||||||||||||||
# Shallow copy to avoid footgun in case a defaultdict is passed in | ||||||||||||||
self._items_by_key = dict(items_by_key) | ||||||||||||||
items_by_modality = full_groupby(items or [], key=lambda x: x.modality) | ||||||||||||||
self._items_by_modality = dict(items_by_modality) | ||||||||||||||
|
||||||||||||||
keys_by_modality = defaultdict[str, set[str]](set) | ||||||||||||||
for key, items in items_by_key.items(): | ||||||||||||||
for item in items: | ||||||||||||||
keys_by_modality[item.field.modality].add(key) | ||||||||||||||
|
||||||||||||||
self._keys_by_modality = dict(keys_by_modality) | ||||||||||||||
|
||||||||||||||
if enable_sanity_checks: | ||||||||||||||
for modality, keys in keys_by_modality.items(): | ||||||||||||||
items_in_modality = {k: items_by_key[k] for k in keys} | ||||||||||||||
batch_sizes = {k: len(v) for k, v in items_in_modality.items()} | ||||||||||||||
batch_size = next(iter(batch_sizes.values()), 0) | ||||||||||||||
assert all(bs == batch_size | ||||||||||||||
for bs in batch_sizes.values()), dict( | ||||||||||||||
modality=modality, | ||||||||||||||
batch_sizes=batch_sizes, | ||||||||||||||
items_by_key=items_by_key) | ||||||||||||||
@property | ||||||||||||||
def modalities(self): | ||||||||||||||
return self._items_by_modality.keys() | ||||||||||||||
Comment on lines
+370
to
+372
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
This is probably more intuitive? |
||||||||||||||
|
||||||||||||||
@staticmethod | ||||||||||||||
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: | ||||||||||||||
|
@@ -452,58 +446,44 @@ def as_kwargs( | |||||||||||||
def __eq__(self, other: object) -> bool: | ||||||||||||||
if not isinstance(other, self.__class__): | ||||||||||||||
return False | ||||||||||||||
if self._items_by_key != other._items_by_key: | ||||||||||||||
if self._items_by_modality != other._items_by_modality: | ||||||||||||||
return False | ||||||||||||||
|
||||||||||||||
ks = self.keys() | ||||||||||||||
return (ks == other.keys() | ||||||||||||||
and all(nested_tensors_equal(self[k], other[k]) for k in ks)) | ||||||||||||||
|
||||||||||||||
def get_item(self, key: str, item_index: int) -> MultiModalFieldItem: | ||||||||||||||
return self._items_by_key[key][item_index] | ||||||||||||||
def _validate_modality(self, method_name: str, modality: str) -> None: | ||||||||||||||
if not self._items_by_modality: | ||||||||||||||
raise RuntimeError( | ||||||||||||||
f"`{method_name}` is not supported when " | ||||||||||||||
"MultiModalKwargs is not initialized with `items`") | ||||||||||||||
|
||||||||||||||
def get_items_by_modality( | ||||||||||||||
self, | ||||||||||||||
modality: str, | ||||||||||||||
item_index: int, | ||||||||||||||
) -> Mapping[str, MultiModalFieldItem]: | ||||||||||||||
""" | ||||||||||||||
Get the keyword arguments corresponding to an item identified by | ||||||||||||||
its modality and index. | ||||||||||||||
""" | ||||||||||||||
if modality not in self._keys_by_modality: | ||||||||||||||
available_modalities = set(self._keys_by_modality.keys()) | ||||||||||||||
if modality not in self._items_by_modality: | ||||||||||||||
available_modalities = set(self._items_by_modality.keys()) | ||||||||||||||
raise KeyError(f"Modality {modality!r} not found. " | ||||||||||||||
f"Available modalities: {available_modalities}") | ||||||||||||||
|
||||||||||||||
keys_to_gather = self._keys_by_modality[modality] | ||||||||||||||
def get_item_count(self, modality: str) -> int: | ||||||||||||||
"""Get the number of items belonging to a modality.""" | ||||||||||||||
self._validate_modality("get_item_count", modality) | ||||||||||||||
return len(self._items_by_modality[modality]) | ||||||||||||||
|
||||||||||||||
return { | ||||||||||||||
key: self.get_item(key, item_index) | ||||||||||||||
for key in keys_to_gather if key in self | ||||||||||||||
} | ||||||||||||||
def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem: | ||||||||||||||
""" | ||||||||||||||
Get the keyword arguments corresponding to an item identified by | ||||||||||||||
its modality and index. | ||||||||||||||
""" | ||||||||||||||
self._validate_modality("get_item", modality) | ||||||||||||||
return self._items_by_modality[modality][item_index] | ||||||||||||||
|
||||||||||||||
@staticmethod | ||||||||||||||
def from_items_by_modality( | ||||||||||||||
items_by_modality: Mapping[str, list[Mapping[str, | ||||||||||||||
MultiModalFieldItem]]], | ||||||||||||||
*, | ||||||||||||||
enable_sanity_checks: bool = False, | ||||||||||||||
) -> "MultiModalKwargs": | ||||||||||||||
def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]: | ||||||||||||||
""" | ||||||||||||||
Construct a new :class:`MultiModalKwargs` from multiple items returned | ||||||||||||||
by :meth:`get_fields_by_modality`. | ||||||||||||||
Get the keyword arguments corresponding to each item belonging to | ||||||||||||||
a modality. | ||||||||||||||
""" | ||||||||||||||
items_by_key = defaultdict[str, list[MultiModalFieldItem]](list) | ||||||||||||||
for fields in items_by_modality.values(): | ||||||||||||||
for field in fields: | ||||||||||||||
for k, v in field.items(): | ||||||||||||||
items_by_key[k].append(v) | ||||||||||||||
|
||||||||||||||
return MultiModalKwargs.from_items_by_key( | ||||||||||||||
items_by_key, | ||||||||||||||
enable_sanity_checks=enable_sanity_checks, | ||||||||||||||
) | ||||||||||||||
self._validate_modality("get_items", modality) | ||||||||||||||
return self._items_by_modality[modality] | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]] | ||||||||||||||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.