@@ -197,25 +197,25 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
197
197
device_map = kwargs .get ("device_map" , "cpu" )
198
198
use_cpu = (True if device_map == torch .device ("cpu" ) or device_map == "cpu" else False )
199
199
use_xpu = (True if device_map == torch .device ("xpu" ) or device_map == "xpu" else False )
200
-
201
- if kwargs .get ("use_llm_runtime" , None ) is not None :
202
- use_neural_speed = kwargs .pop ("use_llm_runtime" , True ) and not use_xpu
203
- logger .warning ("use_llm_runtime is deprecated in version 1.3.2, please use_neural_speed instead." )
204
- elif kwargs .get ("use_neural_speed" , None ) is not None :
205
- use_neural_speed = kwargs .pop ("use_neural_speed" , True ) and not use_xpu
206
- else :
207
- config = transformers .AutoConfig .from_pretrained (pretrained_model_name_or_path ,
208
- trust_remote_code = kwargs .get ('trust_remote_code' , False ))
209
- if hasattr (config , "model_type" ) == False :
210
- logger .error ("Can't get the model_type. Please check the correct model_type" )
211
- exit (0 )
212
-
213
- if config .model_type in cls .model_type_list :
214
- logger .info ("Using Neural Speed..." )
215
- use_neural_speed = True
200
+ use_neural_speed = False
201
+ if not use_xpu :
202
+ if kwargs .get ("use_llm_runtime" , None ) is not None :
203
+ use_neural_speed = kwargs .pop ("use_llm_runtime" , True ) and not use_xpu
204
+ logger .warning ("use_llm_runtime is deprecated in version 1.3.2, please use_neural_speed instead." )
205
+ elif kwargs .get ("use_neural_speed" , None ) is not None :
206
+ use_neural_speed = kwargs .pop ("use_neural_speed" , True ) and not use_xpu
216
207
else :
217
- logger .info ("Using Pytorch..." )
218
- use_neural_speed = False
208
+ config = transformers .AutoConfig .from_pretrained (pretrained_model_name_or_path , ** kwargs )
209
+ if hasattr (config , "model_type" ) == False :
210
+ logger .error ("Can't get the model_type. Please check the correct model_type" )
211
+ exit (0 )
212
+
213
+ if config .model_type in cls .model_type_list :
214
+ logger .info ("Using Neural Speed..." )
215
+ use_neural_speed = True
216
+ else :
217
+ logger .info ("Using Pytorch..." )
218
+ use_neural_speed = False
219
219
220
220
if os .path .isfile (os .path .join (pretrained_model_name_or_path , QUANT_CONFIG )):
221
221
logger .info (
0 commit comments