Skip to content

Commit fee840b

Browse files
peri044cehongwang
andauthored
chore: Dynamic support for split (#2871) into main (#2914)
Co-authored-by: cehongwang <[email protected]>
1 parent a8a0797 commit fee840b

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -668,14 +668,19 @@ def aten_ops_softmax(
668668

669669

670670
@dynamo_tensorrt_converter(
671-
torch.ops.aten.split.Tensor, capability_validator=has_static_shapes_in_args([1])
671+
torch.ops.aten.split.Tensor,
672+
capability_validator=has_static_shapes_in_args([1]),
673+
supports_dynamic_shapes=True,
672674
)
673675
@dynamo_tensorrt_converter(
674-
torch.ops.aten.split.sizes, capability_validator=has_static_shapes_in_args([1])
676+
torch.ops.aten.split.sizes,
677+
capability_validator=has_static_shapes_in_args([1]),
678+
supports_dynamic_shapes=True,
675679
)
676680
@dynamo_tensorrt_converter(
677681
torch.ops.aten.split_with_sizes.default,
678682
capability_validator=has_static_shapes_in_args([1]),
683+
supports_dynamic_shapes=True,
679684
)
680685
def aten_ops_split(
681686
ctx: ConversionContext,

tests/py/dynamo/conversion/test_split_aten.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def forward(self, input):
119119
@parameterized.expand(
120120
[
121121
("select_split_size_or_sections_dim_dynamic_shape", 2, 1),
122+
("select_split_size_or_sections_non_divisible_dim_dynamic_shape", 3, 1),
122123
]
123124
)
124125
def test_split_dynamic(self, _, split_size_or_tensor, dim):
@@ -132,9 +133,37 @@ def forward(self, input):
132133

133134
input_specs = [
134135
Input(
135-
shape=(1, 10, -1),
136136
dtype=torch.float32,
137-
shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))],
137+
min_shape=[1, 10, 1],
138+
opt_shape=[1, 10, 10],
139+
max_shape=[1, 10, 10],
140+
),
141+
]
142+
self.run_test_with_dynamic_shape(
143+
TestModule(),
144+
input_specs,
145+
)
146+
147+
@parameterized.expand(
148+
[
149+
("select_split_size_or_sections_dim_dynamic_shape_on_first_axis", 2, 1),
150+
]
151+
)
152+
def test_split_dynamic_first_axis_dynamic(self, _, split_size_or_tensor, dim):
153+
class TestModule(torch.nn.Module):
154+
def __init__(self):
155+
super().__init__()
156+
157+
def forward(self, input):
158+
out = torch.ops.aten.split.Tensor(input, split_size_or_tensor, dim)
159+
return out
160+
161+
input_specs = [
162+
Input(
163+
dtype=torch.float32,
164+
min_shape=[1, 10, 10],
165+
opt_shape=[3, 10, 10],
166+
max_shape=[5, 10, 10],
138167
),
139168
]
140169
self.run_test_with_dynamic_shape(

0 commit comments

Comments
 (0)