4
4
5
5
from quickstart_advanced import add_llm_args , setup_llm
6
6
7
- from tensorrt_llm .inputs import (ALL_SUPPORTED_MULTIMODAL_MODELS ,
8
- default_multimodal_input_loader )
7
+ from tensorrt_llm .inputs import default_multimodal_input_loader
8
+ from tensorrt_llm .inputs .registry import MULTIMODAL_PLACEHOLDER_REGISTRY
9
+ from tensorrt_llm .tools .importlib_utils import import_custom_module_from_dir
9
10
10
11
example_medias_and_prompts = {
11
12
"image" : {
79
80
80
81
81
82
def add_multimodal_args (parser ):
82
- parser .add_argument ("--model_type" ,
83
- type = str ,
84
- choices = ALL_SUPPORTED_MULTIMODAL_MODELS ,
85
- help = "Model type." )
83
+ parser .add_argument (
84
+ "--model_type" ,
85
+ type = str ,
86
+ choices = MULTIMODAL_PLACEHOLDER_REGISTRY .get_registered_model_types (),
87
+ help = "Model type." )
86
88
parser .add_argument ("--modality" ,
87
89
type = str ,
88
90
choices = [
@@ -108,6 +110,18 @@ def add_multimodal_args(parser):
108
110
type = str ,
109
111
default = "cpu" ,
110
112
help = "The device to have the input on." )
113
+ parser .add_argument (
114
+ "--custom_module_dirs" ,
115
+ type = str ,
116
+ nargs = "+" ,
117
+ default = None ,
118
+ help =
119
+ ("Paths to an out-of-tree model directory which should be imported."
120
+ " This is useful to load a custom model. The directory should have a structure like:"
121
+ " <model_name>"
122
+ " ├── __init__.py"
123
+ " ├── <model_name>.py"
124
+ " └── <sub_dirs>" ))
111
125
return parser
112
126
113
127
@@ -140,6 +154,15 @@ def parse_arguments():
140
154
141
155
def main ():
142
156
args = parse_arguments ()
157
+ if args .custom_module_dirs is not None :
158
+ for custom_module_dir in args .custom_module_dirs :
159
+ try :
160
+ import_custom_module_from_dir (custom_module_dir )
161
+ except Exception as e :
162
+ print (
163
+ f"Failed to import custom module from { custom_module_dir } : { e } "
164
+ )
165
+ raise e
143
166
144
167
lora_config = None
145
168
if args .load_lora :
@@ -159,16 +182,19 @@ def main():
159
182
model_type = args .model_type
160
183
else :
161
184
model_type = json .load (
162
- open (os .path .join (llm ._hf_model_dir , 'config.json' )))['model_type' ]
163
- assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS , f"Unsupported model_type: { model_type } "
185
+ open (os .path .join (str (llm ._hf_model_dir ),
186
+ 'config.json' )))['model_type' ]
187
+ assert model_type in MULTIMODAL_PLACEHOLDER_REGISTRY .get_registered_model_types (), \
188
+ f"Unsupported model_type: { model_type } found!\n " \
189
+ f"Supported types: { MULTIMODAL_PLACEHOLDER_REGISTRY .get_registered_model_types ()} "
164
190
165
191
# set prompts and media to example prompts and images if they are not provided
166
192
if args .prompt is None :
167
193
args .prompt = example_medias_and_prompts [args .modality ]["prompt" ]
168
194
if args .media is None :
169
195
args .media = example_medias_and_prompts [args .modality ]["media" ]
170
196
inputs = default_multimodal_input_loader (tokenizer = llm .tokenizer ,
171
- model_dir = llm ._hf_model_dir ,
197
+ model_dir = str ( llm ._hf_model_dir ) ,
172
198
model_type = model_type ,
173
199
modality = args .modality ,
174
200
prompts = args .prompt ,
0 commit comments