diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 32f3e9deff67..3239463ffb59 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -719,7 +719,7 @@ See [this page](#generative-models) for more information on how to use generativ
   * `THUDM/glm-4v-9b` etc.
   * ✅︎
   * ✅︎
-  *
+  * ✅︎
 - * `H2OVLChatModel`
   * H2OVL
   * T + I<sup>E+</sup>
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 436c36570599..9a4183106cff 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -106,7 +106,9 @@ def run_glm4v(question: str, modality: str):
               trust_remote_code=True,
               enforce_eager=True,
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
-    prompt = question
+    prompt = f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
+        {question}<|assistant|>"
+
     stop_token_ids = [151329, 151336, 151338]
     return llm, prompt, stop_token_ids
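For reference, the new `run_glm4v` prompt template can be exercised end to end along these lines (a minimal sketch, not part of the patch: the model is loaded the same way as in the example, the image path and question are placeholders, and the single-image `multi_modal_data` dict is vLLM's standard offline image input):

```python
from PIL import Image

from vllm import LLM, SamplingParams

# Same template as the example change above: one <|endoftext|> image
# placeholder wrapped in <|begin_of_image|> / <|end_of_image|>.
question = "What is shown in this image?"
prompt = ("<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>"
          f"{question}<|assistant|>")

llm = LLM(model="THUDM/glm-4v-9b",
          max_model_len=2048,
          trust_remote_code=True,
          enforce_eager=True)

outputs = llm.generate(
    {
        "prompt": prompt,
        # "example.jpg" stands in for any local RGB image.
        "multi_modal_data": {"image": Image.open("example.jpg").convert("RGB")},
    },
    SamplingParams(max_tokens=128, stop_token_ids=[151329, 151336, 151338]),
)
print(outputs[0].outputs[0].text)
```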
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 77cf3442df90..8658e60bc5b2 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -147,6 +147,7 @@ def _test_processing_correctness(
     "facebook/chameleon-7b",
     "deepseek-ai/deepseek-vl2-tiny",
     "adept/fuyu-8b",
+    "THUDM/glm-4v-9b",
     "h2oai/h2ovl-mississippi-800m",
     "OpenGVLab/InternVL2-1B",
     "HuggingFaceM4/Idefics3-8B-Llama3",
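The new `THUDM/glm-4v-9b` entry in the processing-correctness list relies on the HF-style processor implemented in the `chatglm.py` changes below. Outside of the registry plumbing, that processor can be poked at roughly like this (illustrative only, and only valid once this patch is applied; the printed shape follows the checkpoint's `vision_config["image_size"]`, and the text keys come from the GLM tokenizer):

```python
from PIL import Image
from transformers import AutoConfig, AutoTokenizer

from vllm.model_executor.models.chatglm import GLM4VProcessor  # added below

config = AutoConfig.from_pretrained("THUDM/glm-4v-9b", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4v-9b",
                                          trust_remote_code=True)
processor = GLM4VProcessor(config, tokenizer)

out = processor(
    text="<|begin_of_image|><|endoftext|><|end_of_image|>Describe the image.",
    images=Image.new("RGB", (512, 512)),
)
print(out.keys())                 # tokenizer outputs plus "pixel_values"
print(out["pixel_values"].shape)  # (1, 3, image_size, image_size)
```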
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index a31648675259..9ee9e9ca8009 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -4,20 +4,21 @@
 # https://github.com/THUDM/CogAgent
 """Inference-only CogAgent model compatible with THUDM weights."""
 from argparse import Namespace
-from array import array
-from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
-                    TypedDict)
+from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple,
+                    TypedDict, Union)

 import torch
-from PIL import Image
 from torch import nn
 from torch.nn import LayerNorm
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from transformers import PreTrainedTokenizer, TensorType
+from transformers.image_utils import ImageInput
+from transformers.tokenization_utils_base import TextInput

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -35,73 +36,55 @@
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs,
-                                    NestedTensors)
-from vllm.multimodal.utils import cached_get_tokenizer
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
-                           SequenceData)
+from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
+from vllm.multimodal.parse import ImageSize, MultiModalDataItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        BaseProcessingInfo, BatchFeature,
+                                        BoundPromptReplacement,
+                                        MultiModalFieldConfig,
+                                        PlaceholderFeaturesInfo,
+                                        PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import ChatGLMConfig

 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
-                    maybe_prefix)
+                    maybe_prefix, merge_multimodal_embeddings)

 logger = init_logger(__name__)

+IMAGE_TOKEN_ID = 151329

-def calculate_image_placeholder(vision_config):
-    return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2

+def build_normalization_transform(image_size: int) -> transforms.Compose:
+    """
+    Build a normalization transform which can be applied to one or
+    more input images from which we want to extract visual features.
+
+    Args:
+        image_size: size of the image to be processed for visual embeddings.
+
+    Returns:
+        Callable transform for normalizing and resizing one RGB image.
+    """

-def mm_input_mapper_for_glmv(
-    ctx: InputContext,
-    data: ModalityData[object],
-) -> Dict:
-    model_config = ctx.model_config
-    tokenizer = cached_get_tokenizer(
-        model_config.tokenizer,
-        trust_remote_code=model_config.trust_remote_code)
-    if tokenizer is None:
-        raise RuntimeError("No HuggingFace processor is available "
-                           "to process the image object")
-    try:
-        raw_batch_data = tokenizer.apply_chat_template(
-            conversation=[{
-                "role": "user",
-                "image": data
-            }],
-            add_generation_prompt=True,
-            tokenize=True,
-            return_tensors="pt",
-            return_dict=True).data
-    except Exception:
-        logger.error("Failed to process image (%s)", data)
-        raise
-    pixel_values = raw_batch_data['images']
-
-    return MultiModalKwargs({'pixel_values': pixel_values})
-
-
-def merge_glm_vision_embeddings(
-    input_ids: torch.Tensor,
-    inputs_embeds: torch.Tensor,
-    vision_embeddings: torch.Tensor,
-    boi_token_id: int,
-    eoi_token_id: int,
-) -> torch.Tensor:
-
-    boi_positions = (input_ids == boi_token_id).nonzero(as_tuple=True)[0]
-    eoi_positions = (input_ids == eoi_token_id).nonzero(as_tuple=True)[0]
-
-    mask = torch.zeros_like(input_ids, dtype=torch.bool)
-
-    for boi_pos, eoi_pos in zip(boi_positions, eoi_positions):
-        assert boi_pos < eoi_pos
-        mask[boi_pos:eoi_pos + 1] = True
-    inputs_embeds[mask] = vision_embeddings.view(-1,
-                                                 vision_embeddings.shape[-1])
-    return inputs_embeds
+    return transforms.Compose([
+        transforms.Resize(
+            (image_size, image_size),
+            interpolation=InterpolationMode.BICUBIC,
+        ),
+        transforms.ToTensor(),
+        transforms.Normalize(
+            (0.48145466, 0.4578275, 0.40821073),
+            (0.26862954, 0.26130258, 0.27577711),
+        ),
+    ])
+
+
+def calculate_image_placeholder(vision_config):
+    return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2


 class GLMImagePixelInputs(TypedDict):
@@ -109,120 +92,177 @@ class GLMImagePixelInputs(TypedDict):
     """Shape: `(batch_size, num_channels, height, width)`"""


-def get_max_glmv_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(ChatGLMConfig)
+class GLM4VProcessor:
+    """
+    This model doesn't define its own HF processor,
+    so we implement our own one here.

-    vision_config = getattr(hf_config, 'vision_config', None)
-    if vision_config is None:
-        return 1
-    elif isinstance(vision_config, dict):
-        return calculate_image_placeholder(vision_config)
+    """

-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
+    def __init__(
+        self,
+        config: ChatGLMConfig,
+        tokenizer: PreTrainedTokenizer,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.tokenizer = tokenizer

-def dummy_data_for_glmv(ctx: InputContext, seq_len: int,
-                        mm_counts: Mapping[str, int]) -> DummyData:
-    hf_config = ctx.get_hf_config(ChatGLMConfig)
-    vision_config = getattr(hf_config, 'vision_config', None)
+        if hasattr(self.config, "vision_config"):
+            self.image_transform = build_normalization_transform(
+                config.vision_config["image_size"])
+        else:
+            self.image_transform = None

-    if vision_config is None:
-        token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
-        seq_data = SequenceData(token_ids)
-        return DummyData(seq_data, None)
-    elif isinstance(vision_config, dict):
-        image_size = vision_config["image_size"]
-        image_placeholder_length = calculate_image_placeholder(vision_config)
-        token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id] +
-                          [0] * image_placeholder_length +
-                          [hf_config.eoi_token_id])
-        token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                           [0] * (seq_len - image_placeholder_length - 2))
-        seq_data = SequenceData(token_ids)
+    def __call__(
+        self,
+        text: Optional[Union[TextInput, list[TextInput]]] = None,
+        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+    ) -> BatchFeature:
+        if text is None:
+            text = []
+        if not isinstance(text, list):
+            text = [text]
+        if images is None:
+            images = []
+        if not isinstance(images, list):
+            images = [images]
+        text_inputs = self.tokenizer(text)
+        if len(images) == 0:
+            image_inputs = {}
+        else:
+            if self.image_transform is None:
+                raise ValueError("This model does not support image inputs")
+
+            pixel_values = [self.image_transform(image) for image in images]
+            image_inputs = {"pixel_values": torch.stack(pixel_values)}
+
+        return BatchFeature(
+            {
+                **text_inputs,
+                **image_inputs,
+            },
+            tensor_type=return_tensors,
+        )

-        mm_data = {
-            "image": Image.new("RGB", (image_size, image_size), color=0)
-        }
-        return DummyData(seq_data, mm_data)

+class GLM4VProcessingInfo(BaseProcessingInfo):

-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
+    def __init__(self, ctx):
+        super().__init__(ctx)
+        self._pre_calculate()
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1}

-def find_all_positions(input_ids: List[int], target: int) -> List[int]:
-    return [index for index, value in enumerate(input_ids) if value == target]
+    def get_mm_max_tokens_per_item(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> Mapping[str, int]:
+        return {"image": self.image_token_num + 2}

-def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
+    def _pre_calculate(self):
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        self.image_token_num = calculate_image_placeholder(vision_config)
+        self.image_size = vision_config["image_size"]

-    hf_config = ctx.get_hf_config(ChatGLMConfig)
-    vision_config = getattr(hf_config, 'vision_config', None)
+    def get_num_image_tokens(self) -> int:
+        return self.image_token_num + 2

-    if vision_config is None:
-        return inputs
-    elif isinstance(vision_config, dict):
-        image_placeholder_length = calculate_image_placeholder(vision_config)
-    else:
-        msg = f"Unsupported vision config: {type(vision_config)}"
-        raise NotImplementedError(msg)
+    def get_image_size(self) -> ImageSize:

-    input_ids = inputs["prompt_token_ids"]
+        return ImageSize(height=self.image_size, width=self.image_size)

-    tokenizer = cached_get_tokenizer(
-        ctx.model_config.model,
-        trust_remote_code=ctx.model_config.trust_remote_code)
+    def get_hf_processor(self) -> GLM4VProcessor:
+        return GLM4VProcessor(
+            self.get_hf_config(),
+            self.get_tokenizer(),
+        )

-    try:
-        raw_batch_data = tokenizer.apply_chat_template(
-            conversation=[{
-                "role": "user",
-                "image": multi_modal_data["image"],
-                "content": inputs['prompt'],
-            }],
-            add_generation_prompt=True,
-            tokenize=True,
-            return_tensors="pt",
-            return_dict=True,
-        ).data
-    except Exception:
-        logger.error("Failed to process content (%s)", inputs['prompt'])
-        raise
-    input_ids = raw_batch_data['input_ids'][0].tolist()
-    boi_token_id = hf_config.boi_token_id
-    eoi_token_id = hf_config.eoi_token_id
-    boi_positions = find_all_positions(input_ids, boi_token_id)
-    eoi_positions = find_all_positions(input_ids, eoi_token_id)
+class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):

-    assert len(boi_positions) == len(eoi_positions)
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        num_images = mm_counts.get("image", 0)
+        target_width, target_height = self.info.get_image_size()

-    new_input_ids = []
-    final_processed_position = 0
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=target_width,
+                                   height=target_height,
+                                   num_images=num_images)
+        }
+        text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
+        return ProcessorInputs(
+            prompt_text=text,
+            mm_data=mm_data,
+        )

-    for boi_position, eoi_position in zip(boi_positions, eoi_positions):
-        assert boi_position < eoi_position
-        new_input_ids.extend(input_ids[final_processed_position:boi_position +
-                                       1])
-        new_input_ids.extend([input_ids[boi_position + 1]] *
-                             image_placeholder_length)
-        final_processed_position = eoi_position
-
-    new_input_ids.extend(input_ids[final_processed_position:])
+class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(pixel_values=MultiModalFieldConfig.batched("image"))
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+
+        def get_replacement(item_idx: int):
+            image_tokens = self.info.image_token_num
+            return [IMAGE_TOKEN_ID] * image_tokens
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=[IMAGE_TOKEN_ID],
+                replacement=get_replacement,
+            ),
+        ]

-    prompt = inputs.get("prompt")
-    if prompt is None:
-        prompt = tokenizer.decode(new_input_ids)
+    def _apply_prompt_replacements(
+        self,
+        token_ids: list[int],
+        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
+        mm_item_counts: Mapping[str, int],
+    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
+        token_ids, text, placeholders = super()._apply_prompt_replacements(
+            token_ids=token_ids,
+            mm_prompt_repls=mm_prompt_repls,
+            mm_item_counts=mm_item_counts,
+        )
+        hf_config = self.info.get_hf_config()
+        boi_token_id = hf_config.boi_token_id
+        eoi_token_id = hf_config.eoi_token_id
+        placeholders = {
+            modality: [
+                PlaceholderFeaturesInfo(
+                    modality=p.modality,
+                    item_idx=p.item_idx,
+                    start_idx=p.start_idx - 1,
+                    tokens=[boi_token_id] + p.tokens + [eoi_token_id],
+                ) for p in ps
+            ]
+            for modality, ps in placeholders.items()
+        }

-    return token_inputs(
-        prompt_token_ids=new_input_ids,
-        prompt=prompt,
-        multi_modal_data=multi_modal_data,
-    )
+        return token_ids, text, placeholders


 class GLMAttention(nn.Module):
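A note on the placeholder accounting used by `GLM4VProcessingInfo` and `GLM4VMultiModalProcessor` above: `calculate_image_placeholder` yields `(image_size // patch_size // 2) ** 2` feature positions per image, `_get_prompt_replacements` expands the single `<|endoftext|>` placeholder (id 151329) into that many copies, and `_apply_prompt_replacements` then widens each recorded placeholder span to include the surrounding `<|begin_of_image|>` / `<|end_of_image|>` tokens already present in the prompt, which is why `get_mm_max_tokens_per_item` reports `image_token_num + 2`. Worked numbers, assuming the glm-4v-9b checkpoint reports `image_size=1120` and `patch_size=14` (substitute whatever the config actually contains):

```python
# Placeholder accounting for one image, using assumed glm-4v-9b config values.
vision_config = {"image_size": 1120, "patch_size": 14}

image_token_num = (vision_config["image_size"] //
                   vision_config["patch_size"] // 2) ** 2
print(image_token_num)      # 1600 copies of token id 151329 after replacement

# The recorded placeholder span also covers the existing <|begin_of_image|>
# and <|end_of_image|> tokens, so each image occupies image_token_num + 2
# positions, matching get_mm_max_tokens_per_item above.
print(image_token_num + 2)  # 1602
```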
@@ -572,12 +612,16 @@ def get_input_embeddings(
     ) -> torch.Tensor:
         inputs_embeds = self.embedding(input_ids)
         if multimodal_embeddings is not None:
-            inputs_embeds = merge_glm_vision_embeddings(
+            inputs_embeds = merge_multimodal_embeddings(
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
-                vision_embeddings=multimodal_embeddings,
-                boi_token_id=self.config.boi_token_id,
-                eoi_token_id=self.config.eoi_token_id)
+                multimodal_embeddings=multimodal_embeddings,
+                placeholder_token_id=[
+                    self.config.boi_token_id,
+                    IMAGE_TOKEN_ID,
+                    self.config.eoi_token_id,
+                ],
+            )
         return inputs_embeds

     def forward(
@@ -593,14 +637,12 @@ def forward(

         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
-        if intermediate_tensors is None and inputs_embeds is None:
+        if intermediate_tensors is not None:
+            inputs_embeds = intermediate_tensors["hidden_states"]
+        elif inputs_embeds is None:
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
             inputs_embeds = self.get_input_embeddings(input_ids,
                                                       vision_embeddings)
-            input_ids = None
-        else:
-            inputs_embeds = intermediate_tensors["hidden_states"]
-
         # Run encoder.
         hidden_states = self.encoder(
             hidden_states=inputs_embeds,
@@ -763,11 +805,21 @@ def get_mm_mapping(self) -> MultiModelKeys:
             connector="transformer.vision.linear_proj",
             tower_model="transformer.vision.transformer")

+    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+        return self.transformer.get_multimodal_embeddings(**kwargs)
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        return self.transformer.get_input_embeddings(input_ids,
+                                                     multimodal_embeddings)
+

-@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
+@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor,
+                                        info=GLM4VProcessingInfo,
+                                        dummy_inputs=GLM4VDummyInputsBuilder)
 class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                          SupportsMultiModal):
     # Ensure that the LoRA support check passes when the class is not