import copy
import os
- from typing import List, Optional, Tuple
+ from typing import List, Optional, Tuple, Dict

import numpy as np
import torch
@@ -118,6 +118,128 @@ def get_num_tokens_per_image(
        )
        return unpadded_feature_size + newline_feature_size + base_feature_size

+     def _postprocess(self, input_ids, mm_features):
+         # Define model-specific variables here before the shared logic
+         mm_tokens = torch.tensor([self.model_config.image_token_index
+                                   ]).to(input_ids.device)
+         model_hidden_size = self.model_config.text_config.hidden_size
+         vocab_size = self.model_config.text_config.vocab_size
+         start_len = end_len = 0  # for llava, no start/end tokens are appended around each image token
+         # End model-specific variables
+
+         ## find mm token positions in input_ids
+         mm_token_positions = torch.where(torch.isin(input_ids, mm_tokens))[0]
+         num_medias = num_mm_tokens = len(mm_token_positions)
+         if num_medias > 1 and isinstance(mm_features, torch.Tensor):
+             mm_features = list(
+                 mm_features.split(mm_features.shape[0] // num_medias))
+
+         if isinstance(mm_features, torch.Tensor):
+             # 1 prompt + 1 media
+             # "split" means what a single mm_token in the input_ids should represent
+             # image: one split --> one frame
+             # video: one split --> N frames
+             num_frames, mm_feature_length, mm_hidden_dim = mm_features.shape
+             mm_lengths_per_split = [mm_feature_length * num_frames]
+             mm_lengths_per_frame = [mm_feature_length]
+         elif isinstance(mm_features, list):
+             # 1 prompt + N media
+             num_frames = len(mm_features) if mm_features[0].dim() == 2 else sum(
+                 [f.shape[0] for f in mm_features])
+             mm_lengths_per_split = [
+                 f.shape[0] if f.dim() == 2 else f.shape[0] * f.shape[1]
+                 for f in mm_features
+             ]
+             mm_lengths_per_frame = [
+                 f.shape[0] if f.dim() == 2 else f.shape[1] for f in mm_features
+             ]
+             mm_hidden_dim = mm_features[0].shape[-1]
+             mm_features = torch.cat(mm_features, dim=0)
+         else:
+             raise ValueError(
+                 f"Invalid multimodal features type: {type(mm_features)}")
+         mm_total_length = sum(mm_lengths_per_split)
+         assert mm_hidden_dim == model_hidden_size, "Multimodal embedding_dim must match model hidden_size"
+
+         ## split input_ids into segments by isolating mm tokens
+         mm_split_positions = torch.cat(
+             [mm_token_positions, mm_token_positions + 1]).unique()
+         input_ids_splits = list(input_ids.tensor_split(mm_split_positions.cpu(
+         )))  # len(input_ids_splits) = num_segments after mm tokens are isolated
+         mm_ids_splits = list(
+             torch.arange(vocab_size,
+                          vocab_size + mm_total_length,
+                          device=input_ids.device).split(mm_lengths_per_split)
+         )  # len(mm_ids_splits) = num_mm_segments
+
+         for i, mm_ids in enumerate(mm_ids_splits):
+             mm_ids = mm_ids.reshape(-1, mm_lengths_per_frame[i])
+             mm_ids_splits[i] = mm_ids.flatten()
+
+         ## replace mm token ids with the expanded out-of-vocab ids
+         mm_split_idx = 0
+         for i, split in enumerate(input_ids_splits):
+             if torch.isin(split, mm_tokens).any().item():
+                 input_ids_splits[i] = mm_ids_splits[mm_split_idx]
+                 mm_split_idx += 1
+         assert mm_split_idx == len(
+             mm_ids_splits), "All mm_ids_splits should be consumed"
+
+         ## concat text & mm input_ids, wrap mm feature in prompt tuning config
+         fused_input_ids = torch.cat(input_ids_splits).to(
+             device=input_ids.device)
+         fused_length = len(input_ids) + mm_total_length + num_frames * (
+             start_len + end_len) - num_medias
+         assert len(
+             fused_input_ids
+         ) == fused_length, f"Fused input_ids length {len(fused_input_ids)} should match the sum of text and multimodal embedding lengths {fused_length}"
+
+         # [num_frames, feature_length, hidden_dim] -> [num_frames * feature_length, hidden_dim]
+         mm_features = mm_features.view(-1, mm_features.shape[-1])
+         return fused_input_ids, mm_features
+
+     def attach_multimodal_embeddings(
+             self, inputs: TextPrompt,
+             multimodal_embedding: Dict[str, List[torch.Tensor]],
+             sampling_params: SamplingParams
+     ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
+         """
+         Attach pre-processed multimodal embeddings to the text token stream for the LlavaNext model.
+
+         This method skips vision processing and works with externally provided embeddings.
+         It replaces/expands image placeholders in the text with the appropriate tokens and prepares
+         the embeddings for the model forward pass.
+
+         Args:
+             inputs: Text prompt containing image placeholders.
+             multimodal_embedding: Dictionary containing pre-processed image embedding data.
+
+         Returns:
+             Tuple of (token_ids, extra_processed_inputs) where:
+             - token_ids: List of processed token IDs with image placeholders.
+             - extra_processed_inputs: Optional dictionary containing multimodal embeddings.
+         """
+         text_prompt = inputs.get("prompt")
+         if not text_prompt:
+             raise ValueError("Text prompt is required but not provided")
+
+         if not isinstance(multimodal_embedding, dict):
+             raise ValueError("multimodal_embedding must be a dictionary")
+
+         if 'image' not in multimodal_embedding:
+             raise ValueError(
+                 "Only image modality is supported for external multimodal embedding"
+             )
+
+         input_ids = self.tokenizer(
+             text_prompt, return_tensors="pt").input_ids[0]
+         mm_features = torch.stack(multimodal_embedding['image'])
+         fused_input_ids, mm_features = self._postprocess(input_ids, mm_features)
+         multimodal_data = {}
+         multimodal_data["multimodal_embedding"] = mm_features
+         return fused_input_ids.to(torch.int32).tolist(), {
+             "multimodal_data": multimodal_data
+         }

    @torch.inference_mode()
    def __call__(
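The core trick in `_postprocess` above is that every image placeholder token is replaced by a contiguous block of out-of-vocabulary ids starting at `vocab_size`, one id per multimodal feature, so the runtime can later map those ids back onto rows of the embedding tensor. A standalone toy sketch of that expansion (illustrative values only, not code from this diff):

```python
import torch

# Toy setup: vocab of 100 ids, image placeholder id 99, 4 features per image.
vocab_size, image_token_id, feats_per_image = 100, 99, 4
input_ids = torch.tensor([5, 6, 99, 7, 99, 8])  # two image placeholders

# Isolate each placeholder into its own segment.
positions = torch.where(input_ids == image_token_id)[0]
split_points = torch.cat([positions, positions + 1]).unique()
segments = list(input_ids.tensor_split(split_points))

# Contiguous out-of-vocab ids, one block per placeholder.
oov = torch.arange(vocab_size, vocab_size + feats_per_image * len(positions))
oov_blocks = list(oov.split(feats_per_image))

# Swap each placeholder segment for its expanded block.
block_idx = 0
for i, seg in enumerate(segments):
    if (seg == image_token_id).any():
        segments[i] = oov_blocks[block_idx]
        block_idx += 1

fused = torch.cat(segments)
print(fused.tolist())
# [5, 6, 100, 101, 102, 103, 7, 104, 105, 106, 107, 8]
```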
@@ -158,9 +280,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
                 **kwargs) -> None:
        super().__init__()
        self.model_config = model_config
-         pretrained_config = model_config.pretrained_config
+         self.pretrained_config = model_config.pretrained_config
        self.device = f"cuda:{model_config.mapping.rank}"
-         model_path = pretrained_config._name_or_path
+         model_path = self.pretrained_config._name_or_path

        # Determine the actual local path for model files
        if os.path.isdir(model_path):
@@ -200,7 +322,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
            self.vision_tower = hf_vision_tower.to(self.device)
        else:
            vision_model_config = ModelConfig(
-                 pretrained_config=model_config.pretrained_config.vision_config,
+                 pretrained_config=self.pretrained_config.vision_config,
                attn_backend="TRTLLM")
            self.vision_tower = CLIPVisionModel(vision_model_config).to(
                self.device).to(self.dtype)
@@ -210,13 +332,13 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
        self.mm_projector = hf_mm_projector
        self.image_newline = hf_image_newline
        self.vision_feature_select_strategy = getattr(
-             model_config.pretrained_config, "vision_feature_select_strategy",
+             self.pretrained_config, "vision_feature_select_strategy",
            "default")

        self.post_config()

    def post_config(self):
-         self.config = self.model_config.pretrained_config.vision_config
+         self.config = self.pretrained_config.vision_config

    # Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L284
    def pack_image_features(self,
@@ -234,7 +356,7 @@ def pack_image_features(self,

                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
-                     self.model_config.pretrained_config.image_grid_pinpoints,
+                     self.pretrained_config.image_grid_pinpoints,
                    self.config.image_size,
                )
@@ -296,7 +418,7 @@ def forward(self, multimodal_params: List[MultimodalParams]):
        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
-                 grid_pinpoints=self.model_config.pretrained_config.image_grid_pinpoints,
+                 grid_pinpoints=self.pretrained_config.image_grid_pinpoints,
                patch_size=self.config.image_size,
            ) for imsize in image_sizes
        ]
@@ -396,7 +518,13 @@ def forward(
        mm_embeds = []
        if len(multimodal_params) > 0:
            if not DISAGG:
-                 mm_embeds = self.mm_encoder.forward(multimodal_params)
+                 if multimodal_params[0].multimodal_data.get("multimodal_embedding", None) is not None:
+                     mm_embeds = [
+                         multimodal_param.multimodal_data["multimodal_embedding"]
+                         for multimodal_param in multimodal_params
+                     ]
+                 else:
+                     mm_embeds = self.mm_encoder.forward(multimodal_params)
            else:
                mm_embeds = [
                    multimodal_param.multimodal_data["multimodal_embedding"]
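Taken together, these changes let a caller bypass the vision tower and hand the model precomputed image embeddings. A minimal usage sketch, not part of this diff, assuming an already-constructed LlavaNext input processor bound to `input_processor`, assumed import paths, and illustrative tensor shapes whose last dimension matches the language model's hidden size:

```python
import torch

from tensorrt_llm.inputs import TextPrompt               # assumed import path
from tensorrt_llm.sampling_params import SamplingParams  # assumed import path

# Hypothetical precomputed features: one tensor per <image> placeholder,
# all with the same shape so torch.stack() inside attach_multimodal_embeddings succeeds.
hidden_size = 4096        # must equal the LLM hidden size (illustrative value)
tokens_per_image = 2144   # vision features per image (illustrative value)
image_embeds = [torch.randn(tokens_per_image, hidden_size) for _ in range(2)]

prompt = TextPrompt(
    prompt="USER: <image>\n<image>\nDescribe both images. ASSISTANT:")
token_ids, extra = input_processor.attach_multimodal_embeddings(
    prompt,
    multimodal_embedding={"image": image_embeds},
    sampling_params=SamplingParams(),
)
# token_ids now carries out-of-vocabulary placeholder ids for every image position;
# extra["multimodal_data"]["multimodal_embedding"] holds the flattened [N, hidden_size] features.
```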