@@ -562,7 +562,23 @@ def __init__(
         self.opset_version = _target(opset_version) if opset_version is not None else None
         self._prog = mil.Program()
 
+        self.src_model_has_all_fp16_weights = False
+
         if isinstance(loaded_model, torch.jit.ScriptModule):
+            # src_model_has_all_fp16_weights will be True
+            # if the model has at least one trainable layer
+            # and all of its trainable layers have the fp16 dtype,
+            # e.g. if pytorch_model.half() has been explicitly used.
+            num_trainable_layers = 0
+            num_trainable_fp16_layers = 0
+            for param in loaded_model.parameters():
+                if param.requires_grad:
+                    num_trainable_layers += 1
+                    if param.dtype == torch.float16:
+                        num_trainable_fp16_layers += 1
+            if num_trainable_layers > 0:
+                self.src_model_has_all_fp16_weights = num_trainable_layers == num_trainable_fp16_layers
+
             self.context = TranscriptionContext(frontend=TorchFrontend.TORCHSCRIPT)
             self.graph = InternalTorchIRGraph.from_torchscript(
                 torchscript=loaded_model, inputs=self.inputs, cut_at_symbols=cut_at_symbols
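
For context on the flag introduced above: after pytorch_model.half(), every trainable parameter reports torch.float16, which is exactly what the detection loop counts. A minimal standalone sketch, using a hypothetical toy model that is not part of this change:

import torch

# Hypothetical toy model, used only to illustrate the detection logic above.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).half()

num_trainable_layers = 0
num_trainable_fp16_layers = 0
for param in model.parameters():
    if param.requires_grad:
        num_trainable_layers += 1
        if param.dtype == torch.float16:
            num_trainable_fp16_layers += 1

# After .half(), all trainable parameters are fp16, so the flag would be True.
src_model_has_all_fp16_weights = (
    num_trainable_layers > 0 and num_trainable_layers == num_trainable_fp16_layers
)
print(src_model_has_all_fp16_weights)  # True
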
@@ -1140,6 +1156,11 @@ def convert(self) -> Program:
             user_names = list(ssa_func_inputs.keys())
             internal_names = list(self.graph.inputs.keys())
             internal_names.extend(user_names[len(internal_names) :])
+            input_dtypes = []
+            for torch_name, ssa_name in zip(internal_names, user_names):
+                input_var = ssa_func.inputs[ssa_name]
+                input_dtypes.append(input_var.dtype)
+            all_fp16_inputs = all(x == types.fp16 for x in input_dtypes)
             for torch_name, ssa_name in zip(internal_names, user_names):
                 input_var = ssa_func.inputs[ssa_name]
                 if self.context.frontend == TorchFrontend.TORCHSCRIPT:
@@ -1151,7 +1172,7 @@ def convert(self) -> Program:
                     # So here we perform the "cast input to fp32" step
                     if (
                         types.is_tensor(input_var.sym_type) or types.is_scalar(input_var.sym_type)
-                    ) and input_var.dtype == types.fp16:
+                    ) and input_var.dtype == types.fp16 and not (all_fp16_inputs and self.src_model_has_all_fp16_weights):
                         # This cast should have placeholder scope
                         with mb.scope(
                             ScopeInfo(
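
Taken together with the input-dtype collection in the previous hunk, the updated condition above means the fp16-to-fp32 input cast is skipped only when every network input is fp16 and the source model's trainable weights are all fp16. A minimal sketch of that guard as a standalone predicate; the helper name is hypothetical and the boolean logic simply mirrors the diff:

def should_cast_input_to_fp32(
    input_is_fp16: bool,
    all_fp16_inputs: bool,
    src_model_has_all_fp16_weights: bool,
) -> bool:
    # Mirrors the updated condition in convert(): cast fp16 inputs to fp32
    # unless all inputs and all trainable weights are already fp16.
    return input_is_fp16 and not (all_fp16_inputs and src_model_has_all_fp16_weights)

# fp16 input feeding a model that still has fp32 weights: cast as before.
assert should_cast_input_to_fp32(True, True, False) is True
# fp16 input feeding a fully fp16 model (e.g. after .half()): keep fp16, skip the cast.
assert should_cast_input_to_fp32(True, True, True) is False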