diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 14c25ec8ab..8846497348 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -668,14 +668,19 @@ def aten_ops_softmax( @dynamo_tensorrt_converter( - torch.ops.aten.split.Tensor, capability_validator=has_static_shapes_in_args([1]) + torch.ops.aten.split.Tensor, + capability_validator=has_static_shapes_in_args([1]), + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( - torch.ops.aten.split.sizes, capability_validator=has_static_shapes_in_args([1]) + torch.ops.aten.split.sizes, + capability_validator=has_static_shapes_in_args([1]), + supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( torch.ops.aten.split_with_sizes.default, capability_validator=has_static_shapes_in_args([1]), + supports_dynamic_shapes=True, ) def aten_ops_split( ctx: ConversionContext, diff --git a/tests/py/dynamo/conversion/test_split_aten.py b/tests/py/dynamo/conversion/test_split_aten.py index 142f9b337c..aa26340452 100644 --- a/tests/py/dynamo/conversion/test_split_aten.py +++ b/tests/py/dynamo/conversion/test_split_aten.py @@ -119,6 +119,7 @@ def forward(self, input): @parameterized.expand( [ ("select_split_size_or_sections_dim_dynamic_shape", 2, 1), + ("select_split_size_or_sections_non_divisible_dim_dynamic_shape", 3, 1), ] ) def test_split_dynamic(self, _, split_size_or_tensor, dim): @@ -132,9 +133,37 @@ def forward(self, input): input_specs = [ Input( - shape=(1, 10, -1), dtype=torch.float32, - shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))], + min_shape=[1, 10, 1], + opt_shape=[1, 10, 10], + max_shape=[1, 10, 10], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + ) + + @parameterized.expand( + [ + ("select_split_size_or_sections_dim_dynamic_shape_on_first_axis", 2, 1), + ] + ) + def test_split_dynamic_first_axis_dynamic(self, _, split_size_or_tensor, dim): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.split.Tensor(input, split_size_or_tensor, dim) + return out + + input_specs = [ + Input( + dtype=torch.float32, + min_shape=[1, 10, 10], + opt_shape=[3, 10, 10], + max_shape=[5, 10, 10], ), ] self.run_test_with_dynamic_shape(