
Commit f848e9d

Skip casting model inputs to fp32 if weights and inputs are all fp16
1 parent 0e292a0 commit f848e9d

2 files changed: 46 additions & 1 deletion

coremltools/converters/mil/frontend/torch/converter.py

Lines changed: 22 additions & 1 deletion
@@ -562,7 +562,23 @@ def __init__(
         self.opset_version = _target(opset_version) if opset_version is not None else None
         self._prog = mil.Program()
 
+        self.src_model_has_all_fp16_weights = False
+
         if isinstance(loaded_model, torch.jit.ScriptModule):
+            # src_model_has_all_fp16_weights will be True
+            # if the model has at least one trainable layer
+            # and all trainable layers have the fp16 dtype,
+            # e.g. if pytorch_model.half() has been explicitly used.
+            num_trainable_layers = 0
+            num_trainable_fp16_layers = 0
+            for param in loaded_model.parameters():
+                if param.requires_grad:
+                    num_trainable_layers += 1
+                    if param.dtype == torch.float16:
+                        num_trainable_fp16_layers += 1
+            if num_trainable_layers > 0:
+                self.src_model_has_all_fp16_weights = num_trainable_layers == num_trainable_fp16_layers
+
             self.context = TranscriptionContext(frontend=TorchFrontend.TORCHSCRIPT)
             self.graph = InternalTorchIRGraph.from_torchscript(
                 torchscript=loaded_model, inputs=self.inputs, cut_at_symbols=cut_at_symbols
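
In isolation, the weight check added to __init__ amounts to the following. This is a minimal sketch of the same logic outside the converter, assuming an ordinary torch.nn.Module; the helper name is illustrative and not part of the commit.

import torch

def all_trainable_weights_are_fp16(model: torch.nn.Module) -> bool:
    # True only if the model has at least one trainable parameter and every
    # trainable parameter is already stored as float16 (e.g. after model.half()).
    trainable = [p for p in model.parameters() if p.requires_grad]
    return len(trainable) > 0 and all(p.dtype == torch.float16 for p in trainable)

model = torch.nn.Linear(10, 4)
print(all_trainable_weights_are_fp16(model))         # False: weights are fp32
print(all_trainable_weights_are_fp16(model.half()))  # True: .half() casts weights to fp16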
@@ -1140,6 +1156,11 @@ def convert(self) -> Program:
             user_names = list(ssa_func_inputs.keys())
             internal_names = list(self.graph.inputs.keys())
             internal_names.extend(user_names[len(internal_names) :])
+            input_dtypes = []
+            for torch_name, ssa_name in zip(internal_names, user_names):
+                input_var = ssa_func.inputs[ssa_name]
+                input_dtypes.append(input_var.dtype)
+            all_fp16_inputs = all(x == types.fp16 for x in input_dtypes)
             for torch_name, ssa_name in zip(internal_names, user_names):
                 input_var = ssa_func.inputs[ssa_name]
                 if self.context.frontend == TorchFrontend.TORCHSCRIPT:
@@ -1151,7 +1172,7 @@ def convert(self) -> Program:
                     # So here we perform the "cast input to fp32" step
                     if (
                         types.is_tensor(input_var.sym_type) or types.is_scalar(input_var.sym_type)
-                    ) and input_var.dtype == types.fp16:
+                    ) and input_var.dtype == types.fp16 and not (all_fp16_inputs and self.src_model_has_all_fp16_weights):
                         # This cast should have placeholder scope
                         with mb.scope(
                             ScopeInfo(
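
Taken together, the two changes mean the converter now skips the usual fp16-to-fp32 input cast only when every graph input and every trainable weight is fp16. A small sketch of that decision, with a hypothetical helper name and plain booleans standing in for the converter's state:

def should_cast_input_to_fp32(input_is_fp16: bool,
                              all_inputs_fp16: bool,
                              all_weights_fp16: bool) -> bool:
    # Previously every fp16 input was cast to fp32 on entry; now the cast is
    # skipped when the whole model can stay in fp16 end to end.
    return input_is_fp16 and not (all_inputs_fp16 and all_weights_fp16)

assert should_cast_input_to_fp32(True, True, False)      # fp32 weights: still cast
assert not should_cast_input_to_fp32(True, True, True)   # all-fp16 weights and inputs: skip the cast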

coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py

Lines changed: 24 additions & 0 deletions
@@ -1191,6 +1191,30 @@ def forward(self, x, y):
                 result[name], expected.detach().numpy(), rtol=rtol, atol=atol
             )
 
+    @staticmethod
+    @pytest.mark.parametrize(
+        "backend",
+        backends,
+    )
+    def test_torch_fp16_model_with_fp16_inputs(torch_model, backend):
+        if backend[0] == "neuralnetwork":
+            pytest.skip(
+                "Input float16 needs target >= iOS16, which doesn't support neuralnetwork."
+            )
+        traced_torch_model = torch.jit.trace(torch_model.half(), torch.rand(1, 10).half())
+        ct.convert(
+            traced_torch_model,
+            source="pytorch",
+            inputs=[
+                ct.TensorType(
+                    shape=(1, 10),
+                )
+            ],
+            outputs=[ct.TensorType(dtype=np.float16)],
+            convert_to=backend[0],
+            minimum_deployment_target=ct.target.macOS13,
+        )
+
 
 @pytest.fixture
 def int32_input_model():
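
The new test exercises the fp16-model, fp16-input path end to end. A comparable user-level flow might look like the sketch below; the toy model is illustrative, and the input dtype is spelled out explicitly here rather than left to the converter's defaults.

import coremltools as ct
import numpy as np
import torch

# Any TorchScript model whose trainable weights are all fp16 would do.
torch_model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU()).half().eval()
traced = torch.jit.trace(torch_model, torch.rand(1, 10).half())

mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(shape=(1, 10), dtype=np.float16)],
    outputs=[ct.TensorType(dtype=np.float16)],
    convert_to="mlprogram",                       # fp16 I/O requires an ML program target
    minimum_deployment_target=ct.target.macOS13,  # iOS16 / macOS13 or later for fp16 inputs
)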
