From cef67ee34e9d535589e0cb5ac583b01800356a11 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Wed, 22 Jan 2025 09:54:48 +0000
Subject: [PATCH 1/3] fix

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/llava_onevision.py | 34 ++++++++++---
 vllm/model_executor/models/qwen2_vl.py        | 50 +++++++++++++------
 2 files changed, 63 insertions(+), 21 deletions(-)

diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index 6faa79f65d8d..1a4c0bd71626 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -871,13 +871,35 @@ def forward(
         if intermediate_tensors is not None:
             inputs_embeds = None
 
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
+        # NOTE: In v1, inputs_embeds is always generated at model runner from
+        # `get_multimodal_embeddings` and `get_input_embeddings`, this
+        # condition is only for v0 compatibility.
         elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(input_ids,
-                                                      multimodal_embeddings)
-            input_ids = None
+            image_input = self._parse_and_validate_image_input(**kwargs)
+            video_input = self._parse_and_validate_video_input(**kwargs)
+
+            if image_input is None and video_input is None:
+                inputs_embeds = None
+            else:
+                inputs_embeds = self.get_input_embeddings(input_ids)
+                if image_input is not None:
+                    image_embeds = self._process_image_input(image_input)
+                    inputs_embeds = merge_multimodal_embeddings(
+                        input_ids,
+                        inputs_embeds,
+                        image_embeds,
+                        placeholder_token_id=self.config.image_token_index,
+                    )
+
+                if video_input is not None:
+                    video_embeds = self._process_video_pixels(video_input)
+                    inputs_embeds = merge_multimodal_embeddings(
+                        input_ids,
+                        inputs_embeds,
+                        video_embeds,
+                        placeholder_token_id=self.config.video_token_index,
+                    )
+                input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 34d5c8ad089a..0b83304ee075 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -1301,22 +1301,42 @@ def forward(
         if intermediate_tensors is not None:
             inputs_embeds = None
 
-        # NOTE: In v1, inputs_embeds is always generated at model runner, this
-        # condition is for v0 compatibility.
+        # NOTE: In v1, inputs_embeds is always generated at model runner from
+        # `get_multimodal_embeddings` and `get_input_embeddings`, this
+        # condition is only for v0 compatibility.
         elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-
-            # We need to check for usage of mrope here in case there is
-            # multimodal data.
-            # TODO (ywang96): move this to model runner in V1.
-            if multimodal_embeddings is not None and uses_mrope(self.config):
-                assert positions.ndim == 2 and positions.size(0) == 3, (
-                    "multimodal section rotary embedding requires "
-                    f"(3, seq_len) positions, but got {positions.size()}")
-
-            inputs_embeds = self.get_input_embeddings(input_ids,
-                                                      multimodal_embeddings)
-            input_ids = None
+
+            image_input = self._parse_and_validate_image_input(**kwargs)
+            video_input = self._parse_and_validate_video_input(**kwargs)
+
+            if image_input is None and video_input is None:
+                inputs_embeds = None
+            else:
+                if uses_mrope(self.config):
+                    assert positions.ndim == 2 and positions.size(0) == 3, (
+                        "multimodal section rotary embedding requires "
+                        f"(3, seq_len) positions, but got {positions.size()}")
+
+                inputs_embeds = self.get_input_embeddings(input_ids)
+
+                if image_input is not None:
+                    image_embeds = self._process_image_input(image_input)
+                    inputs_embeds = merge_multimodal_embeddings(
+                        input_ids,
+                        inputs_embeds,
+                        image_embeds,
+                        placeholder_token_id=self.config.image_token_id,
+                    )
+
+                if video_input is not None:
+                    video_embeds = self._process_video_input(video_input)
+                    inputs_embeds = merge_multimodal_embeddings(
+                        input_ids,
+                        inputs_embeds,
+                        video_embeds,
+                        placeholder_token_id=self.config.video_token_id,
+                    )
+                input_ids = None
 
         hidden_states = self.language_model.model(
             input_ids=input_ids,
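Patch 1 reworks the v0 compatibility branch: instead of a single get_multimodal_embeddings call, each modality is parsed and validated separately, encoded, and scattered into the text embeddings at its placeholder tokens. Below is a minimal, self-contained sketch of that placeholder-merge idea; merge_by_placeholder, its shapes, and the 1-D input_ids layout are illustrative assumptions, not vLLM's merge_multimodal_embeddings implementation.

    import torch

    def merge_by_placeholder(input_ids: torch.Tensor,
                             inputs_embeds: torch.Tensor,
                             mm_embeds: torch.Tensor,
                             placeholder_token_id: int) -> torch.Tensor:
        # Positions whose token id equals the placeholder id receive the
        # multimodal embeddings, in order of appearance.
        mask = input_ids == placeholder_token_id  # (seq_len,) bool
        assert int(mask.sum()) == mm_embeds.shape[0], (
            "placeholder count must match the number of multimodal embeddings")
        out = inputs_embeds.clone()
        out[mask] = mm_embeds.to(dtype=inputs_embeds.dtype)
        return out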
From 57b1febb623089734823b39695a38dd42e5c7475 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Wed, 22 Jan 2025 10:52:24 +0000
Subject: [PATCH 2/3] simplify

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/llava_onevision.py | 50 +++++++++++-------
 vllm/model_executor/models/qwen2_vl.py        | 52 +++++++++++--------
 2 files changed, 63 insertions(+), 39 deletions(-)

diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index 1a4c0bd71626..e1d908318894 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -852,6 +852,34 @@ def get_input_embeddings(
             [self.config.image_token_index, self.config.video_token_index])
         return inputs_embeds
 
+    def get_input_embeddings_v0(
+        self,
+        input_ids: torch.Tensor,
+        image_input: Optional[NestedTensors] = None,
+        video_input: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+
+        inputs_embeds = self.get_input_embeddings(input_ids)
+        if image_input is not None:
+            image_embeds = self._process_image_input(image_input)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                image_embeds,
+                placeholder_token_id=self.config.image_token_index,
+            )
+
+        if video_input is not None:
+            video_embeds = self._process_video_pixels(video_input)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                video_embeds,
+                placeholder_token_id=self.config.video_token_index,
+            )
+
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -881,24 +909,10 @@ def forward(
             if image_input is None and video_input is None:
                 inputs_embeds = None
             else:
-                inputs_embeds = self.get_input_embeddings(input_ids)
-                if image_input is not None:
-                    image_embeds = self._process_image_input(image_input)
-                    inputs_embeds = merge_multimodal_embeddings(
-                        input_ids,
-                        inputs_embeds,
-                        image_embeds,
-                        placeholder_token_id=self.config.image_token_index,
-                    )
-
-                if video_input is not None:
-                    video_embeds = self._process_video_pixels(video_input)
-                    inputs_embeds = merge_multimodal_embeddings(
-                        input_ids,
-                        inputs_embeds,
-                        video_embeds,
-                        placeholder_token_id=self.config.video_token_index,
-                    )
+                inputs_embeds = self.get_input_embeddings_v0(
+                    input_ids,
+                    image_input=image_input,
+                    video_input=video_input)
                 input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 0b83304ee075..3c7223a83c38 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -1268,6 +1268,33 @@ def get_input_embeddings(
             [self.config.image_token_id, self.config.video_token_id])
         return inputs_embeds
 
+    def get_input_embeddings_v0(
+        self,
+        input_ids: torch.Tensor,
+        image_input: Optional[tuple[torch.Tensor, ...]] = None,
+        video_input: Optional[tuple[torch.Tensor, ...]] = None,
+    ) -> torch.Tensor:
+
+        inputs_embeds = self.get_input_embeddings(input_ids)
+        if image_input is not None:
+            image_embeds = self._process_image_input(image_input)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                image_embeds,
+                placeholder_token_id=self.config.image_token_id,
+            )
+
+        if video_input is not None:
+            video_embeds = self._process_video_input(video_input)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                video_embeds,
+                placeholder_token_id=self.config.video_token_id,
+            )
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1305,7 +1332,6 @@ def forward(
         # `get_multimodal_embeddings` and `get_input_embeddings`, this
         # condition is only for v0 compatibility.
         elif inputs_embeds is None:
-
             image_input = self._parse_and_validate_image_input(**kwargs)
             video_input = self._parse_and_validate_video_input(**kwargs)
 
@@ -1316,26 +1342,10 @@ def forward(
                     assert positions.ndim == 2 and positions.size(0) == 3, (
                         "multimodal section rotary embedding requires "
                         f"(3, seq_len) positions, but got {positions.size()}")
-
-                inputs_embeds = self.get_input_embeddings(input_ids)
-
-                if image_input is not None:
-                    image_embeds = self._process_image_input(image_input)
-                    inputs_embeds = merge_multimodal_embeddings(
-                        input_ids,
-                        inputs_embeds,
-                        image_embeds,
-                        placeholder_token_id=self.config.image_token_id,
-                    )
-
-                if video_input is not None:
-                    video_embeds = self._process_video_input(video_input)
-                    inputs_embeds = merge_multimodal_embeddings(
-                        input_ids,
-                        inputs_embeds,
-                        video_embeds,
-                        placeholder_token_id=self.config.video_token_id,
-                    )
+                inputs_embeds = self.get_input_embeddings_v0(
+                    input_ids,
+                    image_input=image_input,
+                    video_input=video_input)
                 input_ids = None
 
         hidden_states = self.language_model.model(
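Patch 2 then deduplicates that branch: the per-modality merge moves into a single helper, get_input_embeddings_v0, which both models' forward passes call. Condensed, the v0 flow after this patch reads roughly as follows (a paraphrase of the diff above, not additional code from the PR):

    image_input = self._parse_and_validate_image_input(**kwargs)
    video_input = self._parse_and_validate_video_input(**kwargs)

    if image_input is None and video_input is None:
        inputs_embeds = None  # text-only: the language model embeds input_ids
    else:
        inputs_embeds = self.get_input_embeddings_v0(
            input_ids, image_input=image_input, video_input=video_input)
        input_ids = None  # the merged embeddings now carry all information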
From 012851bb0d85bfd21268b5e3a2d4bbef4edc33f6 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Wed, 22 Jan 2025 11:01:05 +0000
Subject: [PATCH 3/3] typing

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/llava_onevision.py | 5 ++---
 vllm/model_executor/models/qwen2_vl.py        | 7 +++----
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index e1d908318894..5b0f35b08646 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -816,7 +816,7 @@ def apply_pooling(self, image_features, stride=2):
         return image_feature
 
     def get_multimodal_embeddings(
-            self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
+            self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
             return None
@@ -842,8 +842,7 @@ def get_multimodal_embeddings(
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[List[Tuple[NestedTensors,
-                                                   str]]] = None,
+        multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 3c7223a83c38..2519b93f4fd4 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -55,7 +55,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors, VideoItem)
+                                    VideoItem)
 from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
                                    MultiModalDataItems, MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -1231,7 +1231,7 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
         return modalities
 
     def get_multimodal_embeddings(
-            self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
+            self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]:
 
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
@@ -1258,8 +1258,7 @@ def get_multimodal_embeddings(
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[List[Tuple[NestedTensors,
-                                                   str]]] = None,
+        multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
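Patch 3 tightens the annotations: get_multimodal_embeddings and get_input_embeddings now advertise Optional[tuple[torch.Tensor, ...]] rather than List[Tuple[NestedTensors, str]], matching what they actually return, and the now-unused NestedTensors import is dropped from qwen2_vl.py. Below is a sketch of the v1 contract these signatures describe; run_v1_step is a hypothetical stand-in for the model runner, not vLLM API:

    from typing import Optional

    import torch

    def run_v1_step(model, input_ids: torch.Tensor,
                    **mm_kwargs) -> torch.Tensor:
        # One embedding tensor per multimodal item, or None for text-only.
        mm_embeds: Optional[tuple[torch.Tensor, ...]] = (
            model.get_multimodal_embeddings(**mm_kwargs))
        # The tuple (or None) is merged into the text token embeddings at
        # the placeholder positions.
        return model.get_input_embeddings(input_ids, mm_embeds)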