Commit 60073a7
[None][feat] Support SharedTensor on MultimodalParams (#6254)
Signed-off-by: yechank <[email protected]>
1 parent b6baa9e commit 60073a7
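
This commit leans on MultimodalParams' handle/tensor conversion: multimodal tensors are packed into shared-memory handles on the producer side (`to_handle` in `llm.generate_async`, per the comment removed in worker.py below) and unpacked again on the executor worker with `to_tensor`. A minimal sketch of that round trip, assuming the import path and assuming `to_handle` takes the same container argument as `to_tensor`:

    import torch

    # Import path is an assumption; MultimodalParams is the class this commit touches.
    from tensorrt_llm.inputs.multimodal import MultimodalParams

    # Producer side (e.g. llm.generate_async): convert tensors into shareable
    # handles so they can cross the process boundary without a full copy.
    params = MultimodalParams(
        multimodal_data={"multimodal_embedding": torch.randn(16, 4096)})
    params.to_handle("multimodal_data")   # tensors -> SharedTensor handles

    # Consumer side (executor worker): convert handles back into real tensors.
    params.to_tensor("multimodal_data")
    embedding = params.multimodal_data["multimodal_embedding"]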

File tree

5 files changed, +254 -220 lines changed


tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 10 additions & 6 deletions
@@ -1244,11 +1244,13 @@ def _prepare_tp_inputs(
             multimodal_params = MultimodalParams(
                 multimodal_data=request.py_multimodal_data,
                 multimodal_runtime=py_multimodal_runtime)
-            multimodal_params.to_device("multimodal_data",
-                                        "cuda",
-                                        pin_memory=True)
 
             if multimodal_params.has_content():
+                multimodal_params.to_device("multimodal_data",
+                                            "cuda",
+                                            pin_memory=True)
+                #re-assign the multimodal_data to the request after to_device for generation requests
+                request.py_multimodal_data = multimodal_params.multimodal_data
                 multimodal_params_list.append(multimodal_params)
 
             request.py_batch_idx = request.py_seq_slot
@@ -1282,10 +1284,12 @@ def _prepare_tp_inputs(
             multimodal_params = MultimodalParams(
                 multimodal_data=request.py_multimodal_data)
             multimodal_params.strip_for_generation()
-            multimodal_params.to_device("multimodal_data",
-                                        "cuda",
-                                        pin_memory=True)
             if multimodal_params.has_content():
+                multimodal_params.to_device("multimodal_data",
+                                            "cuda",
+                                            pin_memory=True)
+                # re-assign the multimodal_data to the request after strip_for_generation for another generation request,
+                request.py_multimodal_data = multimodal_params.multimodal_data
                 multimodal_params_list.append(multimodal_params)
         extend_requests += extend_dummy_requests
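
Read as straight code, the updated context-request path in `_prepare_tp_inputs` amounts to the following sketch (it mirrors the hunks above; surrounding names such as `request`, `py_multimodal_runtime` and `multimodal_params_list` are taken from the existing loop):

    multimodal_params = MultimodalParams(
        multimodal_data=request.py_multimodal_data,
        multimodal_runtime=py_multimodal_runtime)

    if multimodal_params.has_content():
        # Move tensors to the GPU only when there is an actual multimodal payload.
        multimodal_params.to_device("multimodal_data", "cuda", pin_memory=True)
        # Write the device-resident dict back so later generation steps reuse it.
        request.py_multimodal_data = multimodal_params.multimodal_data
        multimodal_params_list.append(multimodal_params)

The generation-request hunk follows the same pattern, with `strip_for_generation()` applied before the `has_content()` check.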

tensorrt_llm/executor/worker.py

Lines changed: 3 additions & 12 deletions
@@ -487,7 +487,7 @@ def _deduce_max_tokens(request: GenerationRequest,
             lora_config=lora_config,
             prompt_tuning_config=prompt_tuning_config,
             multimodal_input=multimodal_input,
-            #NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`.
+            # NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`.
             multimodal_embedding=None,
             mrope_config=None,
             logits_post_processor_name=(
@@ -503,17 +503,8 @@ def _deduce_max_tokens(request: GenerationRequest,
 
         if self._is_pytorch_backend and request.multimodal_params is not None:
             if request.multimodal_params.multimodal_data is not None:
-                # Convert back to tensor, as opposite to `to_handle` in `llm.generate_async`
-                # for values with non-selected keys, it's no-op
-                request.multimodal_params.to_tensor(
-                    "multimodal_data", key="multimodal_embedding")
-                embedding = request.multimodal_params.multimodal_data.get(
-                    "multimodal_embedding")
-                if embedding is not None and embedding.is_cuda:
-                    # make sure the embedding resides on the local device
-                    request.multimodal_params.multimodal_data[
-                        "multimodal_embedding"] = embedding.to("cuda")
-
+                # NOTE: Deserialize SharedTensor handle to actual tensor
+                request.multimodal_params.to_tensor("multimodal_data")
                 executor_request.py_multimodal_data = request.multimodal_params.multimodal_data
 
         if self._is_pytorch_backend and request.sampling_params.logits_processor:
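
On the worker side the per-key handling is gone: instead of converting only the `multimodal_embedding` entry and then forcing it onto the local CUDA device, the whole `multimodal_data` container is deserialized with a single `to_tensor` call before being attached to the executor request. A condensed sketch of the new path, assuming (as the removed comment stated for the old per-key variant) that `to_tensor` leaves non-handle entries untouched:

    if self._is_pytorch_backend and request.multimodal_params is not None:
        if request.multimodal_params.multimodal_data is not None:
            # Turn every SharedTensor handle in the container back into a tensor.
            request.multimodal_params.to_tensor("multimodal_data")
            executor_request.py_multimodal_data = (
                request.multimodal_params.multimodal_data)

Device placement is no longer done here; per the model_engine.py hunks above, it now happens in `_prepare_tp_inputs` via `to_device(..., "cuda", pin_memory=True)` once the params are known to have content.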
