
Commit b23e977

refactoring input prep again to allow out-of-tree models to work with quickstart, trtllm-bench, etc.
Signed-off-by: Rakib Hasan <[email protected]>
1 parent e968f98 commit b23e977

16 files changed, +566 -161 lines

examples/llm-api/quickstart_multimodal.py

Lines changed: 35 additions & 9 deletions
@@ -4,8 +4,9 @@
 
 from quickstart_advanced import add_llm_args, setup_llm
 
-from tensorrt_llm.inputs import (ALL_SUPPORTED_MULTIMODAL_MODELS,
-                                 default_multimodal_input_loader)
+from tensorrt_llm.inputs import default_multimodal_input_loader
+from tensorrt_llm.inputs.registry import MULTIMODAL_PLACEHOLDER_REGISTRY
+from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir
 
 example_medias_and_prompts = {
     "image": {
@@ -79,10 +80,11 @@
 
 
 def add_multimodal_args(parser):
-    parser.add_argument("--model_type",
-                        type=str,
-                        choices=ALL_SUPPORTED_MULTIMODAL_MODELS,
-                        help="Model type.")
+    parser.add_argument(
+        "--model_type",
+        type=str,
+        choices=MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types(),
+        help="Model type.")
     parser.add_argument("--modality",
                         type=str,
                         choices=[
@@ -108,6 +110,18 @@ def add_multimodal_args(parser):
                         type=str,
                         default="cpu",
                         help="The device to have the input on.")
+    parser.add_argument(
+        "--custom_module_dirs",
+        type=str,
+        nargs="+",
+        default=None,
+        help=
+        ("Paths to an out-of-tree model directory which should be imported."
+         " This is useful to load a custom model. The directory should have a structure like:"
+         " <model_name>"
+         " ├── __init__.py"
+         " ├── <model_name>.py"
+         " └── <sub_dirs>"))
     return parser
 
 
@@ -140,6 +154,15 @@ def parse_arguments():
 
 def main():
     args = parse_arguments()
+    if args.custom_module_dirs is not None:
+        for custom_module_dir in args.custom_module_dirs:
+            try:
+                import_custom_module_from_dir(custom_module_dir)
+            except Exception as e:
+                print(
+                    f"Failed to import custom module from {custom_module_dir}: {e}"
+                )
+                raise e
 
     lora_config = None
     if args.load_lora:
@@ -159,16 +182,19 @@ def main():
         model_type = args.model_type
     else:
         model_type = json.load(
-            open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
-    assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, f"Unsupported model_type: {model_type}"
+            open(os.path.join(str(llm._hf_model_dir),
+                              'config.json')))['model_type']
+    assert model_type in MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types(), \
+        f"Unsupported model_type: {model_type} found!\n" \
+        f"Supported types: {MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types()}"
 
     # set prompts and media to example prompts and images if they are not provided
     if args.prompt is None:
         args.prompt = example_medias_and_prompts[args.modality]["prompt"]
     if args.media is None:
         args.media = example_medias_and_prompts[args.modality]["media"]
     inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer,
-                                             model_dir=llm._hf_model_dir,
+                                             model_dir=str(llm._hf_model_dir),
                                              model_type=model_type,
                                              modality=args.modality,
                                              prompts=args.prompt,
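
The new --custom_module_dirs flag and the import loop above are what let an out-of-tree model plug into quickstart_multimodal.py (and, per the commit message, trtllm-bench). Below is a minimal sketch of such a package, following the directory layout described in the help text; the package, model, and processor names (my_vlm, MyVLM, MyVLMInputProcessor) are hypothetical, while the decorators and metadata types are the ones this commit uses.

# my_vlm/__init__.py  (hypothetical out-of-tree package; pass its directory via
# --custom_module_dirs so import_custom_module_from_dir() imports it at startup)
from .my_vlm import MyVLM  # noqa: F401  -- importing runs the registration decorators

# my_vlm/my_vlm.py  (sketch only; registration mirrors the in-tree models changed below)
from tensorrt_llm._torch.models.modeling_utils import register_auto_model  # import path assumed
from tensorrt_llm.inputs import (MultimodalPlaceholderMetadata,
                                 MultimodalPlaceholderPlacement,
                                 register_input_processor)


class MyVLMInputProcessor:
    ...  # hypothetical InputProcessor implementation


@register_auto_model("MyVLMForConditionalGeneration")  # architecture name from the checkpoint's config.json
@register_input_processor(
    MyVLMInputProcessor,
    model_type="my_vlm",
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={"image": "<image>"},
        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
    ))
class MyVLM:
    ...  # hypothetical model implementation

Once the directory is passed via --custom_module_dirs, "my_vlm" shows up in MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types() and therefore passes the model_type check above.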

tensorrt_llm/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -62,6 +62,9 @@ def _add_trt_llm_dll_directory():
 from .sampling_params import SamplingParams
 from .version import __version__
 
+# Lazy import to avoid circular dependency on lora_manager
+import tensorrt_llm._torch.models as torch_models  # isort:skip
+
 __all__ = [
     'AutoConfig',
     'AutoModelForCausalLM',
@@ -82,6 +85,7 @@ def _add_trt_llm_dll_directory():
     'default_trtnet',
     'precision',
     'net_guard',
+    'torch_models',
     'Network',
     'Mapping',
     'MnnvlMemory',
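
Eagerly importing tensorrt_llm._torch.models here means that importing the package runs every @register_input_processor decorator, so the placeholder registry is populated up front for the quickstart and trtllm-bench. A small sketch of that effect, assuming only the registry call already used in quickstart_multimodal.py:

import tensorrt_llm  # now also pulls in tensorrt_llm._torch.models via the lazy import above
from tensorrt_llm.inputs.registry import MULTIMODAL_PLACEHOLDER_REGISTRY

# These registered types back both the --model_type choices and the config.json
# validation in quickstart_multimodal.py; out-of-tree types appear here as well
# once their module has been imported via --custom_module_dirs.
print(MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types())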

tensorrt_llm/_torch/models/modeling_gemma3vl.py

Lines changed: 10 additions & 2 deletions
@@ -10,7 +10,9 @@
     BaseWeightMapper
 
 from ..._utils import nvtx_range
-from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
+from ...inputs import (ExtraProcessedInputs, InputProcessor,
+                       MultimodalPlaceholderMetadata,
+                       MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...logger import logger
 from ...sampling_params import SamplingParams
@@ -137,7 +139,13 @@ def forward(self, vision_outputs: torch.Tensor):
 
 
 @register_auto_model("Gemma3ForConditionalGeneration")
-@register_input_processor(Gemma3InputProcessor, model_type="gemma3")
+@register_input_processor(
+    Gemma3InputProcessor,
+    model_type="gemma3",
+    placeholder_metadata=MultimodalPlaceholderMetadata(
+        placeholder_map={"image": "<start_of_image>"},
+        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
+    ))
 class Gemma3VLM(PreTrainedModel):
 
     def __init__(self, model_config: ModelConfig[Gemma3Config]):
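
Each registration now also declares where the modality placeholder sits relative to the user text. A toy illustration of what BEFORE_TEXT versus AFTER_TEXT means for the final prompt string (this is not the loader's actual code, and the separator between placeholder and text is an assumption):

# Toy sketch only; default_multimodal_input_loader performs the real prompt assembly.
def place_placeholder(placeholder: str, text: str, before_text: bool, sep: str = " ") -> str:
    return f"{placeholder}{sep}{text}" if before_text else f"{text}{sep}{placeholder}"

# Gemma3 above registers "<start_of_image>" with BEFORE_TEXT placement:
print(place_placeholder("<start_of_image>", "Describe this picture.", before_text=True))
# -> <start_of_image> Describe this picture.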

tensorrt_llm/_torch/models/modeling_hyperclovax.py

Lines changed: 20 additions & 2 deletions
@@ -15,7 +15,9 @@
 
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
-from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
+from ...inputs import (ExtraProcessedInputs, InputProcessor,
+                       MultimodalPlaceholderMetadata,
+                       MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...logger import logger
 from ...sampling_params import SamplingParams
@@ -961,7 +963,23 @@ def forward(self, multimodal_params: List[MultimodalParams]):
 
 
 @register_auto_model("HCXVisionForCausalLM")
-@register_input_processor(HCXVisionInputProcessor, model_type="hyperclovax_vlm")
+@register_input_processor(
+    HCXVisionInputProcessor,
+    model_type="hyperclovax_vlm",
+    placeholder_metadata=MultimodalPlaceholderMetadata(
+        placeholder_map={
+            "image":
+            ('<im_end>\n<|im_start|>user (mime) \n'
+             '{"type": "image/jpeg", "filename": ""}<|im_end|>\n'
+             '<|im_start|>user (vector)\n<|dummy3|><|im_end|>\n'
+             '<|im_start|>image/aux\n'
+             '다음 중 ocr은 사진에서 검출된 글자이고, lens_keyword는 사진에서 추출된 '
+             'keyword와 bbox 위치입니다.bbox는 0~1 사이로 정규화된 [x1, y1, x2, y2]의 '
+             '형태입니다. 참고하여 답변하세요. '
+             '{"ocr": "", "lens_keywords": "", "lens_local_keywords": ""}')
+        },
+        placeholder_placement=MultimodalPlaceholderPlacement.AFTER_TEXT,
+    ))
 class HCXVisionForCausalLM(PreTrainedModel):
 
     def __init__(self, model_config: ModelConfig):

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 10 additions & 2 deletions
@@ -20,7 +20,9 @@
 from tensorrt_llm.lora_manager import HfLoraLoader
 from tensorrt_llm.models.convert_utils import split_matrix_tp
 
-from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
+from ...inputs import (ExtraProcessedInputs, InputProcessor,
+                       MultimodalPlaceholderMetadata,
+                       MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
@@ -1168,7 +1170,13 @@ def __call__(
 
 
 @register_auto_model("Llama4ForConditionalGeneration")
-@register_input_processor(Llama4InputProcessor, model_type="llama4")
+@register_input_processor(
+    Llama4InputProcessor,
+    model_type="llama4",
+    placeholder_metadata=MultimodalPlaceholderMetadata(
+        placeholder_map={"image": "<|image|>"},
+        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
+    ))
 class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model,
                                                                  Llama4Config]):

tensorrt_llm/_torch/models/modeling_llava_next.py

Lines changed: 10 additions & 2 deletions
@@ -14,7 +14,9 @@
 
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
-from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
+from ...inputs import (ExtraProcessedInputs, InputProcessor,
+                       MultimodalPlaceholderMetadata,
+                       MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...llmapi.utils import download_hf_model
 from ...logger import logger
@@ -263,7 +265,13 @@ def forward(self, multimodal_params: List[MultimodalParams]):
 
 
 @register_auto_model("LlavaNextForConditionalGeneration")
-@register_input_processor(LlavaNextInputProcessor, model_type="llava_next")
+@register_input_processor(
+    LlavaNextInputProcessor,
+    model_type="llava_next",
+    placeholder_metadata=MultimodalPlaceholderMetadata(
+        placeholder_map={"image": "<image>"},
+        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
+    ))
 class LlavaNextModel(PreTrainedModel):
     config_class = LlavaNextConfig

tensorrt_llm/_torch/models/modeling_mistral.py

Lines changed: 17 additions & 3 deletions
@@ -29,7 +29,9 @@
 from tensorrt_llm._torch.speculative import SpecMetadata
 from tensorrt_llm.functional import PositionEmbeddingType
 from tensorrt_llm.inputs import (ExtraProcessedInputs, InputProcessor,
-                                 TextPrompt, register_input_processor)
+                                 MultimodalPlaceholderMetadata,
+                                 MultimodalPlaceholderPlacement, TextPrompt,
+                                 register_input_processor)
 from tensorrt_llm.llmapi import SamplingParams
 from tensorrt_llm.logger import logger
 
@@ -269,8 +271,20 @@ def __call__(
 
 
 @register_auto_model("Mistral3ForConditionalGeneration")
-# The below informs the registry which input registry to create for this in `tensorrt_llm/llmapi/llm.py`.
-@register_input_processor(Mistral3InputProcessor, model_type="mistral3")
+@register_input_processor(
+    Mistral3InputProcessor,
+    model_type="mistral3",
+    placeholder_metadata=MultimodalPlaceholderMetadata(
+        placeholder_map={
+            "image": "[IMG]",
+        },
+        # NOTE: for mistral3 multimodal models, it does not strictly have to be before the text.
+        # Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/
+        # src/mistral_common/tokens/tokenizers/base.py#L326
+        # However, accuracy tests show that the model generates higher quality output when the image
+        # precedes the text (the relative difference can be as much as ~30% for both vLLM and TRT-LLM).
+        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
+    ))
 class Mistral3VLM(PreTrainedModel):
     """Mistral3VLM implementation for TRTLLM.

tensorrt_llm/_torch/models/modeling_phi4mm.py

Lines changed: 14 additions & 2 deletions
@@ -10,7 +10,9 @@
 from PIL import Image
 
 from ...executor.request import LoRARequest
-from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
+from ...inputs import (ExtraProcessedInputs, InputProcessor,
+                       MultimodalPlaceholderMetadata,
+                       MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...logger import logger
 from ...lora_manager import LoraConfig
@@ -138,7 +140,17 @@ def __call__(
 
 
 @register_auto_model("Phi4MMForCausalLM")
-@register_input_processor(Phi4MMInputProcessor, model_type="phi4mm")
+@register_input_processor(
+    Phi4MMInputProcessor,
+    model_type="phi4mm",
+    placeholder_metadata=MultimodalPlaceholderMetadata(
+        placeholder_map={
+            "image": "<|image_{0}|>",
+            "audio": "<|audio_{0}|>",
+        },
+        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
+        placeholders_separator="",
+    ))
 class Phi4MMForCausalLM(transformers.PreTrainedModel):
 
     _supports_flash_attn_2 = True

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 21 additions & 3 deletions
@@ -12,7 +12,9 @@
 
 from ..._utils import nvtx_range_debug
 from ...functional import RopeEmbeddingUtils, RotaryScalingType
-from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
+from ...inputs import (ExtraProcessedInputs, InputProcessor,
+                       MultimodalPlaceholderMetadata,
+                       MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...logger import logger
 from ...sampling_params import SamplingParams
@@ -645,7 +647,16 @@ def forward(
 
 
 @register_auto_model("Qwen2VLForConditionalGeneration")
-@register_input_processor(Qwen2VLInputProcessorBase, model_type="qwen2_vl")
+@register_input_processor(
+    Qwen2VLInputProcessorBase,
+    model_type="qwen2_vl",
+    placeholder_metadata=MultimodalPlaceholderMetadata(
+        placeholder_map={
+            "image": "<|vision_start|><|image_pad|><|vision_end|>",
+            "video": "<|vision_start|><|video_pad|><|vision_end|>"
+        },
+        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
+    ))
 class Qwen2VLModel(Qwen2VLModelBase):
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
@@ -657,7 +668,14 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
 
 
 @register_auto_model("Qwen2_5_VLForConditionalGeneration")
-@register_input_processor(Qwen2VLInputProcessorBase, model_type="qwen2_5_vl")
+@register_input_processor(
+    Qwen2VLInputProcessorBase,
+    model_type="qwen2_5_vl",
+    placeholder_metadata=MultimodalPlaceholderMetadata(
+        placeholder_map={
+            "image": "<|vision_start|><|image_pad|><|vision_end|>",
+            "video": "<|vision_start|><|video_pad|><|vision_end|>"
+        }))
 class Qwen2_5_VLModel(Qwen2VLModelBase):
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,

tensorrt_llm/_torch/models/modeling_vila.py

Lines changed: 13 additions & 2 deletions
@@ -35,7 +35,9 @@
     PreTrainedModel)
 
 from ..._utils import nvtx_range
-from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
+from ...inputs import (ExtraProcessedInputs, InputProcessor,
+                       MultimodalPlaceholderMetadata,
+                       MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...logger import logger
 from ...sampling_params import SamplingParams
@@ -1118,7 +1120,16 @@ def __call__(
 
 
 @register_auto_model(VilaConfig.model_architecture)
-@register_input_processor(VilaInputProcessor, model_type="llava_llama")
+@register_input_processor(
+    VilaInputProcessor,
+    model_type="llava_llama",
+    placeholder_metadata=MultimodalPlaceholderMetadata(
+        placeholder_map={
+            "image": "<image>",
+            "video": "<vila/video>"
+        },
+        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
+    ))
 class VilaModel(PreTrainedModel):
     config_class = VilaConfig
