Commit 810be54

Add unit test and add back attach_emb for llava model
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent 712e95e commit 810be54

3 files changed: +322 -9 lines changed

tensorrt_llm/_torch/models/modeling_llava_next.py

Lines changed: 137 additions & 9 deletions
@@ -1,6 +1,6 @@
 import copy
 import os
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Dict
 
 import numpy as np
 import torch
@@ -118,6 +118,128 @@ def get_num_tokens_per_image(
         )
         return unpadded_feature_size + newline_feature_size + base_feature_size
 
+    def _postprocess(self, input_ids, mm_features):
+        # Define model specific variables here before shared logic
+        mm_tokens = torch.tensor([self.model_config.image_token_index
+                                  ]).to(input_ids.device)
+        model_hidden_size = self.model_config.text_config.hidden_size
+        vocab_size = self.model_config.text_config.vocab_size
+        start_len = end_len = 0  # for llava, need not append start/end token around each image token
+        # End model specific variables
+
+        ## find mm token positions in input_ids
+        mm_token_positions = torch.where(torch.isin(input_ids, mm_tokens))[0]
+        num_medias = num_mm_tokens = len(mm_token_positions)
+        if num_medias > 1 and isinstance(mm_features, torch.Tensor):
+            mm_features = list(
+                mm_features.split(mm_features.shape[0] // num_medias))
+
+        if isinstance(mm_features, torch.Tensor):
+            # 1 prompt + 1 media
+            # "split" means what a single mm_token in the input_ids should represent
+            # image: one split --> one frame
+            # video: one split --> N frames
+            num_frames, mm_feature_length, mm_hidden_dim = mm_features.shape
+            mm_lengths_per_split = [mm_feature_length * num_frames]
+            mm_lengths_per_frame = [mm_feature_length]
+        elif isinstance(mm_features, list):
+            # 1 prompt + N media
+            num_frames = len(mm_features) if mm_features[0].dim() == 2 else sum(
+                [f.shape[0] for f in mm_features])
+            mm_lengths_per_split = [
+                f.shape[0] if f.dim() == 2 else f.shape[0] * f.shape[1]
+                for f in mm_features
+            ]
+            mm_lengths_per_frame = [
+                f.shape[0] if f.dim() == 2 else f.shape[1] for f in mm_features
+            ]
+            mm_hidden_dim = mm_features[0].shape[-1]
+            mm_features = torch.cat(mm_features, dim=0)
+        else:
+            raise ValueError(
+                f"Invalid multimodal features type: {type(mm_features)}")
+        mm_total_length = sum(mm_lengths_per_split)
+        assert mm_hidden_dim == model_hidden_size, "Multimodal embedding_dim must match model hidden_size"
+
+        ## split input_ids into segments by isolating mm tokens
+        mm_split_positions = torch.cat(
+            [mm_token_positions, mm_token_positions + 1]).unique()
+        input_ids_splits = list(input_ids.tensor_split(mm_split_positions.cpu(
+        )))  # len(input_ids_splits) = num_segments after mm tokens are isolated
+        mm_ids_splits = list(
+            torch.arange(vocab_size,
+                         vocab_size + mm_total_length,
+                         device=input_ids.device).split(mm_lengths_per_split)
+        )  # len(mm_ids_splits) = num_mm_segments
+
+        for i, mm_ids in enumerate(mm_ids_splits):
+            mm_ids = mm_ids.reshape(-1, mm_lengths_per_frame[i])
+            mm_ids_splits[i] = mm_ids.flatten()
+
+        ## replace mm token ids with the expanded out-of-vocab ids
+        mm_split_idx = 0
+        for i, split in enumerate(input_ids_splits):
+            if torch.isin(split, mm_tokens).any().item():
+                input_ids_splits[i] = mm_ids_splits[mm_split_idx]
+                mm_split_idx += 1
+        assert mm_split_idx == len(
+            mm_ids_splits), "All mm_ids_splits should be consumed"
+
+        ## concat text & mm input_ids, wrap mm feature in prompt tuning config
+        fused_input_ids = torch.cat(input_ids_splits).to(
+            device=input_ids.device)
+        fused_length = len(input_ids) + mm_total_length + num_frames * (
+            start_len + end_len) - num_medias
+        assert len(
+            fused_input_ids
+        ) == fused_length, f"Fused input_ids length {len(fused_input_ids)} should match the sum of text and multimodal embedding lengths {fused_length}"
+
+        # [num_frames, feature_length, hidden_dim] -> [num_frames * feature_length, hidden_dim]
+        mm_features = mm_features.view(-1, mm_features.shape[-1])
+        return fused_input_ids, mm_features
+
+
+    def attach_multimodal_embeddings(
+            self, inputs: TextPrompt,
+            multimodal_embedding: Dict[str, List[torch.Tensor]],
+            sampling_params: SamplingParams
+    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
+        """
+        Attach pre-processed multimodal embeddings into text token stream for LlavaNext model.
+        This method skips vision processing and works with externally provided embeddings.
+        It replaces/expands image placeholders in the text with appropriate tokens and prepares
+        the embeddings for model forward pass.
+        Args:
+            inputs: Text prompt containing image placeholders
+            multimodal_embedding: Dictionary containing pre-processed image embedding data
+        Returns:
+            Tuple of (token_ids, extra_processed_inputs) where:
+            - token_ids: List of processed token IDs with image placeholders
+            - extra_processed_inputs: Optional dictionary containing multimodal embeddings
+        """
+        text_prompt = inputs.get("prompt")
+        if not text_prompt:
+            raise ValueError("Text prompt is required but not provided")
+
+
+
+        if not isinstance(multimodal_embedding, dict):
+            raise ValueError("multimodal_embedding must be a dictionary")
+
+        if 'image' not in multimodal_embedding:
+            raise ValueError(
+                "Only image modality is supported for external multimodal embedding"
+            )
+
+        input_ids = self.tokenizer(
+            text_prompt, return_tensors="pt").input_ids[0]
+        mm_features = torch.stack(multimodal_embedding['image'])
+        fused_input_ids, mm_features = self._postprocess(input_ids, mm_features)
+        multimodal_data = {}
+        multimodal_data["multimodal_embedding"] = mm_features
+        return fused_input_ids.to(torch.int32).tolist(), {
+            "multimodal_data": multimodal_data
+        }
 
     @torch.inference_mode()
     def __call__(
@@ -158,9 +280,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
                  **kwargs) -> None:
         super().__init__()
         self.model_config = model_config
-        pretrained_config = model_config.pretrained_config
+        self.pretrained_config = model_config.pretrained_config
         self.device = f"cuda:{model_config.mapping.rank}"
-        model_path = pretrained_config._name_or_path
+        model_path = self.pretrained_config._name_or_path
 
         # Determine the actual local path for model files
         if os.path.isdir(model_path):
@@ -200,7 +322,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
             self.vision_tower = hf_vision_tower.to(self.device)
         else:
             vision_model_config = ModelConfig(
-                pretrained_config=model_config.pretrained_config.vision_config,
+                pretrained_config=self.pretrained_config.vision_config,
                 attn_backend="TRTLLM")
             self.vision_tower = CLIPVisionModel(vision_model_config).to(
                 self.device).to(self.dtype)
@@ -210,13 +332,13 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         self.mm_projector = hf_mm_projector
         self.image_newline = hf_image_newline
         self.vision_feature_select_strategy = getattr(
-            model_config.pretrained_config, "vision_feature_select_strategy",
+            self.pretrained_config, "vision_feature_select_strategy",
            "default")
 
         self.post_config()
 
     def post_config(self):
-        self.config = self.model_config.pretrained_config.vision_config
+        self.config = self.pretrained_config.vision_config
 
     # Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L284
     def pack_image_features(self,
@@ -234,7 +356,7 @@ def pack_image_features(self,
 
             num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                 image_sizes[image_idx],
-                self.model_config.pretrained_config.image_grid_pinpoints,
+                self.pretrained_config.image_grid_pinpoints,
                 self.config.image_size,
             )
 
@@ -296,7 +418,7 @@ def forward(self, multimodal_params: List[MultimodalParams]):
         image_num_patches = [
             image_size_to_num_patches(
                 image_size=imsize,
-                grid_pinpoints=self.model_config.pretrained_config.image_grid_pinpoints,
+                grid_pinpoints=self.pretrained_config.image_grid_pinpoints,
                 patch_size=self.config.image_size,
             ) for imsize in image_sizes
         ]
@@ -396,7 +518,13 @@ def forward(
         mm_embeds = []
         if len(multimodal_params) > 0:
             if not DISAGG:
-                mm_embeds = self.mm_encoder.forward(multimodal_params)
+                if multimodal_params[0].multimodal_data.get("multimodal_embedding", None) is not None:
+                    mm_embeds = [
+                        multimodal_param.multimodal_data["multimodal_embedding"]
+                        for multimodal_param in multimodal_params
+                    ]
+                else:
+                    mm_embeds = self.mm_encoder.forward(multimodal_params)
             else:
                 mm_embeds = [
                     multimodal_param.multimodal_data["multimodal_embedding"]
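
The new _postprocess helper replaces each image placeholder in input_ids with a run of out-of-vocab token IDs (vocab_size .. vocab_size + N - 1, one per embedding vector), so those positions can later be swapped for the pre-computed embeddings. The following standalone sketch illustrates that expansion with toy values; the token IDs, vocab size, and feature shape are illustrative placeholders, not the model's real configuration:

    import torch

    image_token_index = 32000   # placeholder id for the image token (illustrative)
    vocab_size = 32064          # text vocabulary size (illustrative)

    input_ids = torch.tensor([1, 3148, 32000, 8453, 2])   # prompt with one image token
    mm_features = torch.randn(5, 4096)                    # 5 pre-computed embedding vectors

    # isolate the image token, then splice in vocab_size .. vocab_size+4 in its place
    positions = torch.where(input_ids == image_token_index)[0]
    splits = list(input_ids.tensor_split(torch.cat([positions, positions + 1]).unique()))
    mm_ids = torch.arange(vocab_size, vocab_size + mm_features.shape[0])
    splits = [mm_ids if (s == image_token_index).any() else s for s in splits]
    fused_input_ids = torch.cat(splits)
    # -> [1, 3148, 32064, 32065, 32066, 32067, 32068, 8453, 2]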

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 2 additions & 0 deletions
@@ -175,6 +175,8 @@ def _mangle_executor_config(executor_config: ExecutorConfig):
         pytorch_backend_config.load_format = LoadFormat.VISION_ONLY
         # TODO: add comment and print warning here
         pytorch_backend_config.disable_overlap_scheduler = True
+        # TODO: add comment here to infer it by max_num_images and image_token_sizen
+        executor_config.max_num_tokens = 16384
 
 def _get_mapping(executor_config: ExecutorConfig) -> Mapping:
     if executor_config.mapping is None:
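
The hardcoded executor_config.max_num_tokens = 16384 above is flagged by the TODO for later derivation from max_num_images and the per-image token size. A rough, hypothetical sketch of what such a derivation could look like; the helper name and the numbers are illustrative only and are not part of this commit:

    def infer_max_num_tokens(max_num_images: int,
                             tokens_per_image: int,
                             text_budget: int = 2048) -> int:
        # Worst-case multimodal token count plus an allowance for text tokens.
        return max_num_images * tokens_per_image + text_budget

    # e.g. a LLaVA-NeXT-style anyres layout can expand one image to a few thousand tokens
    print(infer_max_num_tokens(max_num_images=4, tokens_per_image=2880))  # 13568
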
New unit test file

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
+import os
+import pytest
+import copy
+import json
+
+from tensorrt_llm import MultimodalEncoder
+from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
+from tensorrt_llm.llmapi.llm import LLM, SamplingParams
+from tensorrt_llm.llmapi import KvCacheConfig
+from tensorrt_llm.inputs import default_multimodal_input_loader
+
+example_images = [
+    "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
+    "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
+]
+
+
+@pytest.fixture(scope="function")
+def multimodal_model_config():
+    """Get multimodal model configuration similar to integration tests"""
+    # You can extend this to support multiple models or get from environment
+    model_configs = {
+        'llava-v1.6-mistral-7b-hf': {
+            'model_name': 'llava-v1.6-mistral-7b-hf',
+            'hf_model_dir': 'llava-hf/llava-v1.6-mistral-7b-hf',  # HuggingFace model ID
+        }
+    }
+
+    return model_configs['llava-v1.6-mistral-7b-hf']
+
+
+@pytest.mark.parametrize("model_key", [
+    "llava-v1.6-mistral-7b-hf",
+])
+def test_single_image_chat(model_key, multimodal_model_config):
+    """Test processing single image using disaggregated encoder + LLM API.
+
+    This test verifies that disaggregated multimodal generation produces identical
+    results to standard multimodal generation by comparing outputs.
+    """
+    # Get model configuration
+    if model_key != "llava-v1.6-mistral-7b-hf":
+        pytest.skip(f"Skipping test for {model_key} - only testing llava-v1.6-mistral-7b-hf for now")
+
+    # Extract model information from config
+    model_name = multimodal_model_config['model_name']
+    encoder_model_dir = multimodal_model_config['hf_model_dir']
+
+    # Test configuration
+    max_tokens = 64
+    free_gpu_memory_fraction = 0.6
+    max_batch_size = 1
+
+    # Test data - OpenAI chat completion format
+    prompts = ["Describe the natural environment in the image."]
+    media = [example_images[0]]
+
+    # Create OpenAI chat messages format
+    messages_list = []
+    for prompt, image_url in zip(prompts, media):
+        messages = [{
+            "role": "user",
+            "content": [{
+                "type": "text",
+                "text": prompt
+            }, {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            }]
+        }]
+        messages_list.append(messages)
+
+    # Sampling configuration
+    sampling_params = SamplingParams(max_tokens=max_tokens)
+    kv_cache_config = KvCacheConfig(
+        enable_block_reuse=False,
+        free_gpu_memory_fraction=free_gpu_memory_fraction,
+    )
+
+    # Step 1: Process multimodal data using disaggregated encoder
+    encoder = None
+    llm = None
+
+    try:
+        # Step 1: Initialize encoder
+        encoder = MultimodalEncoder(model=encoder_model_dir, max_batch_size=max_batch_size)
+
+        # Step 2: Initialize LLM and prepare inputs
+        llm = LLM(
+            model=encoder_model_dir,
+            backend='pytorch',
+            kv_cache_config=kv_cache_config,
+            trust_remote_code=True
+        )
+
+        # Load model configuration
+        config_path = os.path.join(llm._hf_model_dir, 'config.json')
+        assert os.path.exists(config_path), f"Model config not found at {config_path}"
+
+        with open(config_path, 'r') as f:
+            model_config = json.load(f)
+        model_type = model_config['model_type']
+
+        # Prepare multimodal inputs
+        inputs = default_multimodal_input_loader(
+            tokenizer=llm.tokenizer,
+            model_dir=llm._hf_model_dir,
+            model_type=model_type,
+            modality="image",
+            prompts=prompts,
+            media=media,
+            image_data_format="pt"
+        )
+
+        # Validate inputs structure
+        assert len(inputs) == len(prompts), f"Expected {len(prompts)} inputs, got {len(inputs)}"
+        # Step 3: Generate reference output with raw multimodal inputs
+        outputs_ref = llm.generate(inputs, sampling_params=sampling_params)
+
+        # Validate reference outputs
+        assert outputs_ref is not None, "Reference generation returned None"
+        assert len(outputs_ref) == len(prompts), f"Expected {len(prompts)} reference outputs, got {len(outputs_ref)}"
+        for i, output in enumerate(outputs_ref):
+            assert len(output.outputs) > 0, f"Reference generation has no output text for input {i}"
+
+        # Step 4: Prepare inputs for disaggregated multimodal generation
+        encoder_outputs = encoder.generate(inputs)
+        inputs = default_multimodal_input_loader(
+            tokenizer=llm.tokenizer,
+            model_dir=llm._hf_model_dir,
+            model_type=model_type,
+            modality="image",
+            prompts=prompts,
+            mm_embeddings=[SharedTensorContainer.from_dict(output.mm_embedding_handle).get_local_view() for output in encoder_outputs],
+            image_data_format="pt"
+        )
+
+        # Step 5: Generate output using disaggregated multimodal parameters
+        # Note: For batch processing, we need to match mm_params with inputs
+        outputs = llm.generate(inputs, sampling_params=sampling_params)
+
+        # Validate disaggregated outputs
+        assert len(outputs) == len(prompts), f"Expected {len(prompts)} disaggregated outputs, got {len(outputs)}"
+        for i, output in enumerate(outputs):
+            assert len(output.outputs) > 0, f"Disaggregated generation has no output text for input {i}"
+
+        # Step 6: Compare outputs - they should match exactly
+        assert len(outputs_ref) == len(outputs), f"Number of outputs don't match: {len(outputs_ref)} vs {len(outputs)}"
+
+        for i, (ref_output, test_output) in enumerate(zip(outputs_ref, outputs)):
+            # Compare prompts
+            assert ref_output.prompt == test_output.prompt, \
+                f"Prompts don't match for output {i}:\nReference: {ref_output.prompt!r}\nTest: {test_output.prompt!r}"
+
+            # Compare number of generated outputs
+            assert len(ref_output.outputs) == len(test_output.outputs), \
+                f"Number of generated outputs don't match for output {i}: {len(ref_output.outputs)} vs {len(test_output.outputs)}"
+
+            # Compare generated text and other attributes
+            for j, (ref_gen, test_gen) in enumerate(zip(ref_output.outputs, test_output.outputs)):
+                assert ref_gen.text == test_gen.text, \
+                    f"Generated text doesn't match for output {i}, generation {j}:\nReference: {ref_gen.text!r}\nTest: {test_gen.text!r}"
+
+                # Compare token IDs if available
+                if hasattr(ref_gen, 'token_ids') and hasattr(test_gen, 'token_ids'):
+                    assert ref_gen.token_ids == test_gen.token_ids, \
+                        f"Token IDs don't match for output {i}, generation {j}"
+
+                # Compare log probabilities if available
+                if hasattr(ref_gen, 'logprobs') and hasattr(test_gen, 'logprobs'):
+                    assert ref_gen.logprobs == test_gen.logprobs, \
+                        f"Log probabilities don't match for output {i}, generation {j}"
+
+    finally:
+        # Cleanup resources
+        if encoder is not None:
+            del encoder
+        if llm is not None:
+            del llm
+
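
Assuming a standard pytest setup, the new test can be selected by name; it needs a GPU and network access to fetch the llava-hf/llava-v1.6-mistral-7b-hf checkpoint and the example images:

    python -m pytest -k test_single_image_chat -s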
