diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md
index 9cbfc32991f0..66ac30bedfcd 100644
--- a/docs/source/contributing/model/multimodal.md
+++ b/docs/source/contributing/model/multimodal.md
@@ -598,8 +598,8 @@ def get_dummy_processor_inputs(
 ## 4. Specify processing details
 
-Afterwards, create a subclass of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`
-to fill in the missing details about HF processing.
+Afterwards, create a subclass of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor` (for decoder-only models) or
+{class}`~vllm.multimodal.processing.EncDecMultiModalProcessor` (for encoder-decoder models) to fill in the missing details about HF processing.
 
 :::{seealso}
 [Multi-Modal Data Processing](#mm-processing)
 :::
@@ -932,6 +932,95 @@ def _get_prompt_updates(
 
 ::::
 
+### (Optional) Encoder-Decoder prompt construction
+
+If your model has an encoder-decoder architecture, you also need to implement `create_encoder_prompt` and
+`create_decoder_prompt` to indicate how to construct the encoder/decoder prompts from an implicit text/tokens prompt.
+
+::::{tab-set}
+:::{tab-item} Cross-modality example: Mllama
+:sync: mllama
+
+For models like Mllama and Whisper, the encoder only accepts processed modality data. However, to support cross-attention
+profiling, we still need to provide a "fake" encoder prompt so that profiling can account for the sequence length occupied by the encoder hidden states.
+
+In this case, we can treat the encoder prompt as the feature tokens created by prompt updates, and the implicit prompt as the decoder
+`input_ids` (the default behavior of `create_decoder_prompt`).
+
+Therefore, we just need to return one image token per input image as the encoder prompt, and let the prompt updates construct
+the final encoder prompt for us automatically:
+
+```python
+def create_encoder_prompt(
+    self,
+    prompt: Union[str, list[int]],
+    mm_data: MultiModalDataDict,
+) -> Union[str, list[int]]:
+    data = mm_data.get("image", [])
+    num_images = 1 if isinstance(data, Image) else len(data)
+    image_token_id = self.info.get_hf_config().image_token_index
+    return [image_token_id] * num_images
+```
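+
+Since the implicit prompt is already used as the decoder `input_ids`, `create_decoder_prompt` usually does not need to be
+overridden for this kind of model. For illustration only, the default behavior described above corresponds roughly to a
+pass-through like the following sketch (the actual base-class implementation may differ in its details):
+
+```python
+def create_decoder_prompt(
+    self,
+    prompt: Union[str, list[int]],
+    mm_data: MultiModalDataDict,
+) -> Union[str, list[int]]:
+    # Use the implicit prompt unchanged as the decoder prompt.
+    return prompt
+```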
+
+:::
+
+:::{tab-item} Text-only example: Florence-2
+:sync: florence2
+
+For Florence-2, cross-attention only occurs in its text backbone (Bart), and the forward pass uses both
+`encoder_input_ids` and the decoder `input_ids`.
+
+In this case, we need to provide an appropriate encoder prompt and decoder prompt to make sure that the correct
+`encoder_input_ids` and `input_ids` are fed to the Bart backbone.
+
+For the encoder part, we can treat the implicit prompt as the encoder prompt, because it will be processed by the
+HF processor and fed to the encoder as `encoder_input_ids` directly. We can simply let `create_encoder_prompt`
+return the original prompt:
+
+```python
+def create_encoder_prompt(
+    self,
+    prompt: Union[str, list[int]],
+    mm_data: MultiModalDataDict,
+) -> Union[str, list[int]]:
+    return prompt
+```
+
+Next, let's move on to the decoder part and take a look at how HF's Bart implementation prepares the decoder token IDs:
+
+```python
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
+    """
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    if pad_token_id is None:
+        raise ValueError("self.model.config.pad_token_id has to be defined.")
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
+
+def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+    return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+```
+
+Given the code above, if we don't provide an explicit decoder prompt, the decoder input consists of just the decoder start
+token, which for Florence-2's Bart backbone is the EOS token. We can therefore implement `create_decoder_prompt` as follows:
+
+```python
+def create_decoder_prompt(
+    self,
+    prompt: Union[str, list[int]],
+    mm_data: MultiModalDataDict,
+) -> Union[str, list[int]]:
+    return [self.info.get_hf_config().eos_token_id]
+```
+
+:::
+::::
+
 ## 5. Register processor-related classes
 
 After you have defined {class}`~vllm.multimodal.processing.BaseProcessingInfo` (Step 2),