Merged
135 changes: 133 additions & 2 deletions tensorrt_llm/_torch/models/modeling_llama.py
@@ -1,6 +1,6 @@
import copy
import os
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from PIL.Image import Image
@@ -978,7 +978,10 @@ def __init__(self,
self.tokenizer = tokenizer
self.vocab_size = model_config.text_config.vocab_size
self.image_token_index = model_config.image_token_index

self.fake_image_token = self.processor.fake_image_token
self.image_token = self.processor.img_patch_token
self.image_token_start_index = self.model_config.boi_token_index
self.image_token_end_index = self.model_config.eoi_token_index
self.encoder = nn.ModuleDict({
"vision_model":
Llama4VisionModel(model_config.vision_config),
@@ -987,6 +990,134 @@ def __init__(self,
}).cuda()
load_sharded_checkpoint(self.encoder, model_path, strict=False)

def attach_multimodal_embeddings(
self, inputs: TextPrompt, multimodal_embedding: Dict[str,
List[Dict[str,
Any]]],
sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
"""
        Attach pre-processed multimodal embeddings to the text token stream for the Llama4 model.

This method skips vision processing and works with externally provided embeddings.
It replaces/expands image placeholders in the text with appropriate tokens and prepares
the embeddings for model forward pass.

Args:
inputs: Text prompt containing image placeholders
multimodal_embedding: Dictionary containing pre-processed image embedding data with special token information.
Consider adding metadata fields (e.g., model_type, model_name, version) for validation.
Returns:
Tuple of (token_ids, extra_processed_inputs) where:
- token_ids: List of processed token IDs with image placeholders
- extra_processed_inputs: Optional dictionary containing multimodal embeddings
"""
text_prompt = inputs.get("prompt")
if not text_prompt:
raise ValueError("Text prompt is required but not provided")

if not isinstance(multimodal_embedding, dict):
raise ValueError("multimodal_embedding must be a dictionary")

if 'image' not in multimodal_embedding:
raise ValueError(
"Only image modality is supported for external multimodal embedding"
)

mm_embedding_info = multimodal_embedding['image']
if not mm_embedding_info or not isinstance(mm_embedding_info[0], dict):
raise ValueError(
"Llama4 image embedding must contain special token information")

# Extract embedding components
try:
mm_embeddings = [
mm_embedding['mm_embeddings']
for mm_embedding in mm_embedding_info
]
mm_embedding_special_tokens = [
mm_embedding['image_special_tokens']
for mm_embedding in mm_embedding_info
]
mm_embedding_special_offsets = [
mm_embedding['image_special_token_offsets']
for mm_embedding in mm_embedding_info
]
except KeyError as e:
raise ValueError(
f"Missing required key in multimodal embedding: {e}")

# Validate embedding dimensions
model_hidden_size = self.model_config.text_config.hidden_size
for i, embedding in enumerate(mm_embeddings):
if embedding.shape[-1] != model_hidden_size:
raise ValueError(
f"Multimodal embedding {i} hidden size {embedding.shape[-1]} "
f"must match model hidden size {model_hidden_size}")

# Count image placeholders (number of images) in the prompt
total_placeholders = text_prompt.count(self.fake_image_token)
if total_placeholders == 0:
raise ValueError(
"No image placeholders found in the prompt, but multimodal embedding was provided"
)

if total_placeholders != len(mm_embeddings):
raise ValueError(
f"Number of image placeholders ({total_placeholders}) "
f"does not match number of embeddings ({len(mm_embeddings)})")

# Process prompt with image embeddings
prompt_splits = text_prompt.split(self.fake_image_token)
new_prompt_parts = []

for local_image_index, split_part in enumerate(prompt_splits):
new_prompt_parts.append(split_part)

if local_image_index < total_placeholders:
# Calculate total tokens for this image
num_tokens = len(mm_embeddings[local_image_index]) + len(
mm_embedding_special_tokens[local_image_index])

# Create image token sequence
image_tokens = [self.image_token] * num_tokens

# Replace special tokens with actual decoded tokens
for offset, token_id in zip(
mm_embedding_special_offsets[local_image_index],
mm_embedding_special_tokens[local_image_index]):
                    if offset < 0 or offset >= len(image_tokens):
                        raise ValueError(
                            f"Image special token offset {offset} is out of range for the image token sequence of length {len(image_tokens)}"
                        )
                    image_tokens[offset] = self.tokenizer.decode([token_id])

# Join tokens without spaces
image_str = "".join(image_tokens)
new_prompt_parts.append(image_str)

# Combine all parts and tokenize
processed_text = "".join(new_prompt_parts)
kwargs = {}
if sampling_params.truncate_prompt_tokens is not None:
kwargs = dict(truncation=True,
max_length=sampling_params.truncate_prompt_tokens)
text_inputs = self.tokenizer(
processed_text,
return_tensors="pt",
add_special_tokens=sampling_params.add_special_tokens,
**kwargs)
token_ids = text_inputs.input_ids.squeeze()

# Replace image token indices with out-of-vocabulary tokens
token_ids[token_ids == self.image_token_index] = self.vocab_size + 1
# Concatenate all multimodal embeddings
multimodal_data = {}
multimodal_data["multimodal_embedding"] = torch.cat(mm_embeddings,
dim=0)
return token_ids.tolist(), {"multimodal_data": multimodal_data}

@torch.inference_mode()
def __call__(
self, inputs: TextPrompt, sampling_params: SamplingParams
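
For context, here is a minimal sketch of the payload the new attach_multimodal_embeddings method expects. The caller objects, hidden size, patch count, and token ids below are illustrative assumptions, not part of this change; the placeholder string in the prompt must be the processor's fake_image_token.

    import torch

    hidden_size = 5120                     # assumed; must equal model_config.text_config.hidden_size
    image_start_id, image_end_id = 7, 8    # illustrative special-token ids for one image

    image_entry = {
        # 144 patch embeddings for one image (count is illustrative)
        "mm_embeddings": torch.randn(144, hidden_size),
        # extra special tokens interleaved with the patch tokens
        "image_special_tokens": [image_start_id, image_end_id],
        # offsets of those special tokens inside the expanded image span
        # (valid range here: 0 .. 144 + 2 - 1)
        "image_special_token_offsets": [0, 145],
    }
    multimodal_embedding = {"image": [image_entry]}  # one entry per placeholder in the prompt

    # Assumed caller objects: `processor` is this input processor, `sp` a SamplingParams.
    # prompt = f"Describe the image: {processor.fake_image_token}"
    # token_ids, extra = processor.attach_multimodal_embeddings(
    #     {"prompt": prompt}, multimodal_embedding, sp)
    # extra["multimodal_data"]["multimodal_embedding"] holds the concatenated embeddings.
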
44 changes: 43 additions & 1 deletion tensorrt_llm/_torch/models/modeling_llava_next.py
@@ -1,6 +1,6 @@
import copy
import os
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
@@ -195,6 +195,48 @@ def _postprocess(self, input_ids, mm_features):
mm_features = mm_features.view(-1, mm_features.shape[-1])
return fused_input_ids, mm_features

def attach_multimodal_embeddings(
self, inputs: TextPrompt,
multimodal_embedding: Dict[str, List[torch.Tensor]],
sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
"""
        Attach pre-processed multimodal embeddings to the text token stream for the LlavaNext model.

This method skips vision processing and works with externally provided embeddings.
It replaces/expands image placeholders in the text with appropriate tokens and prepares
the embeddings for model forward pass.

Args:
inputs: Text prompt containing image placeholders
multimodal_embedding: Dictionary containing pre-processed image embedding data
Returns:
Tuple of (token_ids, extra_processed_inputs) where:
- token_ids: List of processed token IDs with image placeholders
- extra_processed_inputs: Optional dictionary containing multimodal embeddings
"""
text_prompt = inputs.get("prompt")
if not text_prompt:
raise ValueError("Text prompt is required but not provided")

if not isinstance(multimodal_embedding, dict):
raise ValueError("multimodal_embedding must be a dictionary")

if 'image' not in multimodal_embedding:
raise ValueError(
"Only image modality is supported for external multimodal embedding"
)

input_ids = self.tokenizer(
text_prompt, return_tensors="pt").input_ids[0].to(self.device)
mm_features = torch.stack(multimodal_embedding['image'])
fused_input_ids, mm_features = self._postprocess(input_ids, mm_features)
multimodal_data = {}
multimodal_data["multimodal_embedding"] = mm_features
return fused_input_ids.to(torch.int32).tolist(), {
"multimodal_data": multimodal_data
}

@torch.inference_mode()
def __call__(
self, inputs: TextPrompt, sampling_params: SamplingParams
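
Compared with the Llama4 path above, the LlavaNext variant accepts a plain list of per-image feature tensors and reuses the existing _postprocess fusion logic. A minimal sketch of the expected payload (shapes and names are illustrative assumptions):

    import torch

    hidden_size = 4096    # assumed; must match the language model's hidden size
    num_patches = 576     # illustrative number of image feature tokens per image

    multimodal_embedding = {
        # one tensor per image placeholder in the prompt; all tensors must share a shape
        # because they are combined with torch.stack() before _postprocess
        "image": [torch.randn(num_patches, hidden_size)]
    }

    # Assumed caller objects: `processor` is the LlavaNext input processor, `sp` a SamplingParams.
    # token_ids, extra = processor.attach_multimodal_embeddings(
    #     {"prompt": prompt_with_image_placeholder}, multimodal_embedding, sp)
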
11 changes: 11 additions & 0 deletions tensorrt_llm/executor/worker.py
@@ -500,6 +500,17 @@ def _deduce_max_tokens(request: GenerationRequest,

if self._is_pytorch_backend and request.multimodal_params is not None:
if request.multimodal_params.multimodal_data is not None:
                # Convert handles back to tensors, the inverse of `to_handle` in `llm.generate_async`;
                # this is a no-op for values under non-selected keys.
request.multimodal_params.to_tensor(
"multimodal_data", key="multimodal_embedding")
embedding = request.multimodal_params.multimodal_data.get(
"multimodal_embedding")
if embedding is not None and embedding.is_cuda:
# make sure the embedding resides on the local device
request.multimodal_params.multimodal_data[
"multimodal_embedding"] = embedding.to("cuda")

executor_request.py_multimodal_data = request.multimodal_params.multimodal_data

if self._is_pytorch_backend and request.sampling_params.logits_processor:
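
Condensed, the worker-side lines above are the receiving half of a handle round trip; the sending half happens before the request crosses the process boundary. A sketch, with `params` standing in for the request's MultimodalParams:

    # Sender side (around llm.generate_async), before IPC:
    params.to_handle("multimodal_data", key="multimodal_embedding")   # tensor -> serializable handle

    # Receiver side (the worker snippet above), after IPC:
    params.to_tensor("multimodal_data", key="multimodal_embedding")   # handle -> local tensor view
    embedding = params.multimodal_data.get("multimodal_embedding")
    if embedding is not None and embedding.is_cuda:
        # move the restored view onto this worker's CUDA device
        params.multimodal_data["multimodal_embedding"] = embedding.to("cuda")
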
146 changes: 146 additions & 0 deletions tensorrt_llm/inputs/multimodal.py
@@ -234,6 +234,152 @@ def _to_device(
f"MultimodalParams: Unsupported element '{element}' to move to device. "
f"Supported elements: 'multimodal_data', 'multimodal_input'")

def to_handle(self, element: str, key: Optional[str] = None) -> None:
"""Convert multimodal data to tensor handle.

Converts torch.Tensor objects to SharedTensorContainer handles (serializable dictionaries)
        for efficient IPC. This operation is performed in place.

Args:
element: Element to convert ("multimodal_data" or "multimodal_input")
key: Specific key to convert. If None, converts all tensor values in multimodal_data.
Defaults to None.

Example:
# Convert all tensors in multimodal_data to handles
params.to_handle("multimodal_data", key=None)

# Convert only multimodal_embedding section tensors to handles
params.to_handle("multimodal_data", key="multimodal_embedding")
"""
# Lazy import to avoid circular dependency
from tensorrt_llm._torch.shared_tensor import SharedTensorContainer

def _to_tensor_handle(data):
for k, v in data.items():
if isinstance(v, torch.Tensor):
# Convert tensor to handle
handle = SharedTensorContainer.from_tensor(v).dump_to_dict()
data[k] = handle
elif isinstance(v, dict):
_to_tensor_handle(v)
elif isinstance(v, list):
for i, item in enumerate(v):
if isinstance(item, torch.Tensor):
handle = SharedTensorContainer.from_tensor(
item).dump_to_dict()
v[i] = handle

if element == "multimodal_data":
if self.multimodal_data is None:
return
if key is None:
_to_tensor_handle(self.multimodal_data)
else:
if key not in self.multimodal_data:
return # no-op if key not found

value = self.multimodal_data[key]
if isinstance(value, torch.Tensor):
handle = SharedTensorContainer.from_tensor(
value).dump_to_dict()
self.multimodal_data[key] = handle
elif isinstance(value, dict):
_to_tensor_handle(value)
else:
raise ValueError(
f"Unsupported value type for multimodal_data: {type(value)}"
)
elif element == "multimodal_input":
# No-op for multimodal_input
return
else:
raise ValueError(
f"Unsupported element '{element}' to convert to handle.")

def to_tensor(self, element: str, key: Optional[str] = None) -> None:
"""Convert multimodal tensor handles back to tensors. This is the dual operation to to_handle.

Converts SharedTensorContainer handles (serializable dictionaries) back to torch.Tensor objects
for local computation. This function performs in-place modifications to the multimodal_data.

Args:
element: Element to convert ("multimodal_data" or "multimodal_input")
key: Specific key to convert. If None, converts all tensor handles in multimodal_data.
Defaults to None.

Example:
# Convert all handles back to tensors
params.to_tensor("multimodal_data", key=None)

# Convert only multimodal_embedding section handles back to tensors
params.to_tensor("multimodal_data", key="multimodal_embedding")
"""
# Lazy import to avoid circular dependency
from tensorrt_llm._torch.shared_tensor import SharedTensorContainer

def _to_tensor(data):
for k, v in data.items():
if isinstance(v, dict) and 'method_key' in v:
# This is a tensor handle (dict with method_key)
try:
tensor = SharedTensorContainer.from_dict(
v).get_local_view()
data[k] = tensor
except Exception as e:
raise ValueError(
f"Failed to convert handle to tensor for key '{k}': {e}"
)
elif isinstance(v, dict):
_to_tensor(v)
elif isinstance(v, list):
for i, item in enumerate(v):
if isinstance(item, dict) and 'method_key' in item:
try:
tensor = SharedTensorContainer.from_dict(
item).get_local_view()
v[i] = tensor
except Exception as e:
raise ValueError(
f"Failed to convert handle to tensor in list at index {i}: {e}"
)

if element == "multimodal_data":
if self.multimodal_data is None:
return

if key is None:
_to_tensor(self.multimodal_data)
else:
if key not in self.multimodal_data:
return # no-op if key not found

value = self.multimodal_data[key]
if isinstance(
value, dict
) and 'method_key' in value: # This is a tensor handle
try:
tensor = SharedTensorContainer.from_dict(
value).get_local_view()
self.multimodal_data[key] = tensor
except Exception as e:
raise ValueError(
f"Failed to convert handle to tensor for key '{key}': {e}"
)
elif isinstance(value, dict):
_to_tensor(value)
else:
raise ValueError(
f"Unsupported value type for multimodal_data: {type(value)}"
)

elif element == "multimodal_input":
# No-op for multimodal_input
return
else:
raise ValueError(
f"Unsupported element '{element}' to convert to tensor.")

def strip_for_context(self) -> None:
"""Strip multimodal data for context processing.

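
Taken together, to_handle and to_tensor give a round trip for nested multimodal payloads. A minimal sketch, assuming MultimodalParams can be constructed with a multimodal_data dict as the field names in this diff suggest (keys, shapes, and device are illustrative):

    import torch
    from tensorrt_llm.inputs.multimodal import MultimodalParams  # module path taken from this diff

    params = MultimodalParams(multimodal_data={
        "multimodal_embedding": torch.randn(16, 4096, device="cuda"),
        "aux": {"scores": torch.randn(16, device="cuda")},  # nested dicts/lists are converted recursively
    })

    params.to_handle("multimodal_data")  # tensors -> plain-dict handles; the object is now cheap to serialize
    # ... ship `params` (e.g. pickled over the executor's IPC channel) to the worker process ...
    params.to_tensor("multimodal_data")  # handles -> local torch.Tensor views on the receiving side
    assert isinstance(params.multimodal_data["aux"]["scores"], torch.Tensor)
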