
## 4. Specify processing details

Afterwards, create a subclass of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor` (for decoder-only models) or
{class}`~vllm.multimodal.processing.EncDecMultiModalProcessor` (for encoder-decoder models) to fill in the missing details about HF processing.
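
As a rough sketch (the class names here are hypothetical), the subclass is typically parameterized by the {class}`~vllm.multimodal.processing.BaseProcessingInfo` subclass you defined in Step 2:

```python
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo

class YourProcessingInfo(BaseProcessingInfo):  # from Step 2 (name hypothetical)
    ...

class YourMultiModalProcessor(BaseMultiModalProcessor[YourProcessingInfo]):
    # Fill in the HF processing details here, e.g. `_get_mm_fields_config`
    # and `_get_prompt_updates` (covered in the following sections).
    ...
```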

:::{seealso}
[Multi-Modal Data Processing](#mm-processing)
:::

### (Optional) Encoder-decoder prompt construction

If your model uses an encoder-decoder architecture, you also need to implement `create_encoder_prompt` and
`create_decoder_prompt` to specify how to construct the encoder and decoder prompts from an implicit text/tokens prompt, as the examples below show.

::::{tab-set}
:::{tab-item} Cross-modality example: Mllama
:sync: mllama

For models like Mllama and Whisper, the encoder only accepts processed modality data. However, to support cross-attention
profiling, we still need to provide a "fake" encoder prompt so that the sequence length occupied by the encoder hidden states can be profiled.

In this case, we can treat the encoder prompt as the feature tokens created by prompt updates, and the implicit prompt as the decoder `input_ids`
(the default behavior of `create_decoder_prompt`).

Therefore, we just need to return one image token per image as the encoder prompt, and let the prompt updates construct
the final encoder prompt for us automatically:

```python
def create_encoder_prompt(
    self,
    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
    data = mm_data.get("image", [])
    num_images = 1 if isinstance(data, Image) else len(data)
    image_token_id = self.info.get_hf_config().image_token_index
    # One image token per input image; the prompt updates expand these
    # into the final encoder prompt.
    return [image_token_id] * num_images
```
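
For example, a request carrying three images yields an encoder prompt of three image tokens, which the prompt updates then expand into the full encoder prompt.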

:::

:::{tab-item} Text-only cross-attention example: Florence-2
:sync: florence2

For Florence-2, cross-attention only occurs inside its text backbone (BART), which consumes both
`encoder_input_ids` and the decoder `input_ids` during the forward pass.

In this case, we need to provide appropriate encoder and decoder prompts to make sure that the
correct `encoder_input_ids` and `input_ids` are fed to the BART backbone.

For the encoder part, we can treat the implicit prompt as the encoder prompt, because it is processed
by the HF processor and fed to the encoder directly as `encoder_input_ids`. Therefore,
`create_encoder_prompt` can simply return the original prompt:

```python
def create_encoder_prompt(
    self,
    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
    return prompt
```

Next, let's move to the decoder part and take a look at how HF's BART prepares the decoder token IDs:

```python
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
    return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
```
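
To make the shift concrete, here is a small usage sketch (the token values are illustrative; for BART, `pad_token_id` is 1 and `decoder_start_token_id` equals its EOS token, 2):

```python
import torch
from transformers.models.bart.modeling_bart import shift_tokens_right

labels = torch.tensor([[10, 11, 12]])
# The start token is prepended and the last label drops off the end.
shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)
# -> tensor([[ 2, 10, 11]])
```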

Given the code above, if we don't provide an explicit decoder prompt, the decoder prompt starts out as a single start token;
for BART, `decoder_start_token_id` is its EOS token, so we can implement `create_decoder_prompt` as below:

```python
def create_decoder_prompt(
    self,
    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
    return [self.info.get_hf_config().eos_token_id]
```
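
If your model's decoder start token differs from EOS, the same idea generalizes; here is a minimal sketch (assuming the HF config exposes `decoder_start_token_id`, which is not always the case):

```python
def create_decoder_prompt(
    self,
    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
    config = self.info.get_hf_config()
    # Prefer the dedicated decoder start token; fall back to EOS
    # (for BART they are the same token).
    start_token_id = getattr(config, "decoder_start_token_id", None)
    if start_token_id is None:
        start_token_id = config.eos_token_id
    return [start_token_id]
```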

:::

::::

## 5. Register processor-related classes

After you have defined {class}`~vllm.multimodal.processing.BaseProcessingInfo` (Step 2),