Skip to content

Commit 037dc22

Browse files
committed
Fix Issue apple#2583: Dynamic padding in torch.nn.functional.pad
Modified _array_construct to handle dynamic padding values: Creates proper Var objects using mb.concat instead of Python lists + Fixes AttributeError when converting models with x.size(-1) padding
1 parent be33582 commit 037dc22

File tree

2 files changed

+231
-1
lines changed

2 files changed

+231
-1
lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -903,10 +903,54 @@ def _array_construct(context, node, array_type):
903903
const = mb.const(val=val, name=node.name)
904904
context.add(const)
905905
else:
906+
# Previously:
906907
# If at least one input to the construct op is non-const, collect
907908
# the inputs and add them directly to the context. Ops that use this
908909
# node's output will take the list directly as input.
909-
context.add(array_type(inputs), node.name)
910+
#
911+
# Previous code inside the else condition:
912+
# context.add(array_type(inputs), node.name)
913+
914+
# Fix for Issue #2583: Dynamic padding values in torch.nn.functional.pad
915+
# GitHub Link: https://github.com/apple/coremltools/issues/2583
916+
#
917+
# Verified with:
918+
# - coremltools 8.3.0
919+
# - torch 2.6.0
920+
#
921+
# Problem:
922+
# When padding values contain runtime-determined sizes (e.g., x.size(-1)),
923+
# the original code would return a Python list instead of the Var object.
924+
# This breaks downstream operations like pad() which expect a Var with
925+
# a .val attribute, causing: AttributeError: 'list' object has no attribute 'val'
926+
# See detailed analysis in issue #2583 for MIL representation differences
927+
# between static and dynamic padding cases.
928+
#
929+
# Root cause:
930+
# The condition `inp.can_be_folded_to_const()` returns False for dynamic
931+
# values, causing len(scalar_inputs) != len(inputs). This triggered the
932+
# else branch which added a raw Python list to the context instead of
933+
# creating a proper Var object.
934+
#
935+
# Solution:
936+
# Create a proper Var object using mb.concat that can handle both
937+
# constant and dynamic values, maintaining the expected Var interface
938+
# throughout the conversion pipeline.
939+
940+
# Convert all inputs to tensors (scalars become 1D tensors)
941+
tensor_inputs = []
942+
for inp in inputs:
943+
if len(inp.shape) == 0: # It is a scalar object
944+
# Convert scalar to 1D tensor
945+
tensor_inp = mb.expand_dims(x=inp, axes=[0])
946+
tensor_inputs.append(tensor_inp)
947+
else:
948+
tensor_inputs.append(inp)
949+
950+
# Concatenate into a single tensor Var with shape (n,)
951+
# This creates a proper Var that will contain the correct values at runtime
952+
stacked_var = mb.concat(values=tensor_inputs, axis=0, name=node.name)
953+
context.add(stacked_var)
910954

911955

912956
@register_torch_op
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright (c) 2020, Apple Inc. All rights reserved.
2+
#
3+
# Use of this source code is governed by a BSD-3-clause license that can be
4+
# found in the LICENSE.txt file or at
5+
# https://opensource.org/licenses/BSD-3-Clause
6+
#
7+
# Test suite for dynamic padding conversion (Issue #2583)
8+
# These tests verify the fix for converting PyTorch pad operations with
9+
# runtime-determined padding values to Core ML.
10+
# The issue occurred in _translate_torch_args() when handling
11+
# dynamic padding values like (1, x.size(-1)).
12+
13+
import pytest
14+
from coremltools._deps import _HAS_TORCH
15+
import numpy as np
16+
17+
# Check if pytorch module is installed
18+
# Also, check if pytorch and coremltools' versions are compatible for this test
19+
if _HAS_TORCH:
20+
import torch
21+
import coremltools as ct
22+
23+
# get package versions
24+
torch_major = int(torch.__version__.split('.')[0])
25+
ct_version_parts = ct.__version__.split('.')
26+
ct_major = int(ct_version_parts[0])
27+
28+
# Run only on PyTorch 2.x and coremltools >= 8.x
29+
_TORCH_COMPATIBLE = torch_major >= 2
30+
_CT_COMPATIBLE = ct_major >= 8
31+
_VERSIONS_COMPATIBLE = _TORCH_COMPATIBLE and _CT_COMPATIBLE
32+
else:
33+
_VERSIONS_COMPATIBLE = False
34+
35+
36+
@pytest.mark.skipif(not _HAS_TORCH, reason="PyTorch not found")
37+
@pytest.mark.skipif(not _VERSIONS_COMPATIBLE, reason="Incompatible versions")
38+
class TestPadDynamicFix:
39+
"""
40+
Test dynamic padding fix for Issue #2583 - torch.nn.functional.pad
41+
with x.size(-1)
42+
"""
43+
44+
@staticmethod
45+
@pytest.mark.parametrize(
46+
"input_size, pad_fn, expected_size, test_name",
47+
[
48+
# Dynamic padding tests
49+
(3, lambda x: (1, x.size(-1)), 7, "dynamic_right"),
50+
(5, lambda x: (0, x.size(-1)), 10, "dynamic_right_only"),
51+
(4, lambda x: (x.size(-1), 0), 8, "dynamic_left_only"),
52+
(2, lambda x: (x.size(-1), x.size(-1)), 6, "both_dynamic"),
53+
]
54+
)
55+
def test_dynamic_padding(input_size, pad_fn, expected_size, test_name):
56+
"""
57+
Test dynamic padding cases where pad values depend on input size
58+
"""
59+
class TestModel(torch.nn.Module):
60+
def forward(self, x):
61+
return torch.nn.functional.pad(x, pad_fn(x))
62+
63+
model = TestModel()
64+
example = torch.rand(input_size)
65+
traced = torch.jit.trace(model, example)
66+
67+
mlmodel = ct.convert(
68+
traced,
69+
inputs=[ct.TensorType(
70+
shape=ct.EnumeratedShapes(
71+
shapes=[[2], [3], [4], [5], [input_size]],
72+
default=[input_size],
73+
),
74+
dtype=np.float32,
75+
name="input"
76+
)],
77+
outputs=[ct.TensorType(name="output", dtype=np.float32)],
78+
convert_to="mlprogram"
79+
)
80+
81+
result = mlmodel.predict({"input": example.numpy()})
82+
assert result["output"].shape[0] == expected_size, \
83+
f"Test '{test_name}' failed: expected shape ({expected_size},)," \
84+
f"got {result['output'].shape}"
85+
86+
@staticmethod
87+
@pytest.mark.parametrize(
88+
"input_size,pad_fn,expected_size,test_name",
89+
[
90+
# Constant padding tests (regression test)
91+
(3, lambda x: (1, 2), 6, "both_constant"),
92+
(4, lambda x: (0, 3), 7, "constant_right_only"),
93+
(5, lambda x: (2, 0), 7, "constant_left_only"),
94+
(2, lambda x: (3, 4), 9, "large_constants"),
95+
]
96+
)
97+
def test_constant_padding(input_size, pad_fn, expected_size, test_name):
98+
"""
99+
Test constant padding cases - regression test
100+
"""
101+
class TestModel(torch.nn.Module):
102+
def forward(self, x):
103+
return torch.nn.functional.pad(x, pad_fn(x))
104+
105+
model = TestModel()
106+
example = torch.rand(input_size)
107+
traced = torch.jit.trace(model, example)
108+
109+
mlmodel = ct.convert(
110+
traced,
111+
inputs=[ct.TensorType(
112+
shape=ct.EnumeratedShapes(
113+
shapes=[[2], [3], [4], [5], [input_size]],
114+
default=[input_size],
115+
),
116+
dtype=np.float32,
117+
name="input"
118+
)],
119+
outputs=[ct.TensorType(name="output", dtype=np.float32)],
120+
convert_to="mlprogram"
121+
)
122+
123+
result = mlmodel.predict({"input": example.numpy()})
124+
output = result["output"]
125+
126+
# Verify shape
127+
assert output.shape[0] == expected_size, \
128+
f"Test '{test_name}' failed: expected shape ({expected_size},)," \
129+
f"got {output.shape}"
130+
131+
# Verify padding values are zeros
132+
pad_config = pad_fn(example)
133+
left_pad, right_pad = pad_config
134+
135+
if left_pad > 0:
136+
assert np.allclose(output[:left_pad], 0.0), \
137+
f"Test '{test_name}' failed: left padding should be zeros"
138+
139+
assert np.allclose(
140+
output[left_pad:left_pad+input_size], example.numpy()
141+
), \
142+
f"Test '{test_name}' failed: original values not preserved"
143+
144+
if right_pad > 0:
145+
assert np.allclose(output[-right_pad:], 0.0), \
146+
f"Test '{test_name}' failed: right padding should be zeros"
147+
148+
@staticmethod
149+
@pytest.mark.parametrize(
150+
"input_size,pad_fn,expected_size,test_name",
151+
[
152+
# Mixed padding tests
153+
(3, lambda x: (2, x.size(-1)), 8, "constant_left_dynamic_right"),
154+
(4, lambda x: (x.size(-1), 3), 11, "dynamic_left_constant_right"),
155+
]
156+
)
157+
def test_mixed_padding(input_size, pad_fn, expected_size, test_name):
158+
"""
159+
Test mixed padding cases with both constant and dynamic values
160+
"""
161+
class TestModel(torch.nn.Module):
162+
def forward(self, x):
163+
return torch.nn.functional.pad(x, pad_fn(x))
164+
165+
model = TestModel()
166+
example = torch.rand(input_size)
167+
traced = torch.jit.trace(model, example)
168+
169+
mlmodel = ct.convert(
170+
traced,
171+
inputs=[ct.TensorType(
172+
shape=ct.EnumeratedShapes(
173+
shapes=[[2], [3], [4], [5], [input_size]],
174+
default=[input_size],
175+
),
176+
dtype=np.float32,
177+
name="input"
178+
)],
179+
outputs=[ct.TensorType(name="output", dtype=np.float32)],
180+
convert_to="mlprogram"
181+
)
182+
183+
result = mlmodel.predict({"input": example.numpy()})
184+
assert result["output"].shape[0] == expected_size, \
185+
f"Test '{test_name}' failed: expected shape ({expected_size},)," \
186+
f"got {result['output'].shape}"

0 commit comments

Comments
 (0)