diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 9026158126..33f0de0ead 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -687,6 +687,11 @@ def aten_ops_select(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
 def aten_ops_slice(
     ctx: ConversionContext,
     target: Target,
@@ -700,9 +705,9 @@ def aten_ops_slice(
         SourceIR.ATEN,
         name,
         args[0],
-        args[1],
-        args[2],
-        args[3],
+        args_bounds_check(args, 1, replacement=0),
+        args_bounds_check(args, 2, replacement=None),
+        args_bounds_check(args, 3, replacement=None),
         args_bounds_check(args, 4, replacement=1),
     )
 
@@ -877,6 +882,11 @@ def aten_ops_clone_copy_placeholder(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.expand.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
 def aten_ops_expand(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 724200fa2b..4f56ffbd85 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -339,8 +339,8 @@ def get_positive_dim(
 ) -> Union[int, Tuple[int, ...]]:
     """
     Given an integer number or tuple that represents dimension(s) in the array,
-    transform it to a positive integer dim if it's negative. Otherwise, do
-    nothing.
+    transform it to a positive integer dim if it's negative.
+    Otherwise, truncate it to the dimension size
 
     Args:
         dim (Union[int, Sequence[int]]): A integer or Sequence of integers that represent dimension(s) in an array.
@@ -353,7 +353,8 @@ def get_positive_dim(
     def positive_dim(d: int) -> int:
         if d < 0:
             return d % dim_size
-        return d
+        else:
+            return min(d, dim_size)
 
     return (
         positive_dim(dim)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
index 8a77508014..5c9ed2ef9c 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -21,15 +21,17 @@ def slice_op(  # TODO: This should be slice not whatever is in base
     name: str,
     input: TRTTensor,
     dim: int,
-    start: int,
-    stop: int,
+    start: Optional[int],
+    stop: Optional[int],
     step: int,
 ) -> TRTTensor:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"slice_tensor received input {input} that is not part "
-            "of the TensorRT region!"
-        )
+    # Special case for start being None
+    if start is None:
+        start = 0
+
+    # Special case for stop being None
+    if stop is None:
+        stop = input.shape[dim]
 
     dim = get_positive_dim(dim, len(input.shape))
     start = get_positive_dim(start, input.shape[dim])
@@ -39,9 +41,6 @@ def slice_op(  # TODO: This should be slice not whatever is in base
     # Check whether slice target dim is dynamic shape dim
     assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
 
-    if stop == 2**63 - 1:
-        stop = input.shape[dim]
-
     start_slice = [0] * len(input.shape)
     start_slice[dim] = start
     stride_slice = [1] * len(input.shape)
@@ -62,11 +61,6 @@ def expand(
     input_t: TRTTensor,
     shape: Shape,
 ) -> TRTTensor:
-    if not isinstance(input_t, TRTTensor):
-        raise RuntimeError(
-            f"expand received input {input_t} that is not a TensorRT ITensor"
-        )
-
     shape_rank = len(shape)
     initial_tensor_rank = len(input_t.shape)
 
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index a883018c5e..93ad4655b5 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -10,6 +10,7 @@
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
+from .view_to_reshape import view_to_reshape
 
 ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
@@ -19,6 +20,7 @@
         lower_efficient_attention,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
+        view_to_reshape,
     ]
 )
 
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
new file mode 100644
index 0000000000..efc836814f
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -0,0 +1,41 @@
+import logging
+from typing import Callable, List, Sequence, Tuple
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def view_to_reshape(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
+    orig, replacement = view_replacement()
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
+
+    return gm
+
+
+def view_replacement() -> (
+    Tuple[
+        torch.fx.GraphModule,
+        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+    ]
+):
+    """Constructs the original and replacement functions for view"""
+
+    # Original graph
+    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
+        return torch.ops.aten.view.default(input, shape)
+
+    # Replacement graph
+    def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
+        return torch.ops.aten.reshape.default(input, shape)
+
+    return orig, replacement
diff --git a/tests/py/dynamo/conversion/test_slice_aten.py b/tests/py/dynamo/conversion/test_slice_aten.py
index 8c0d6dae42..b332fd3354 100644
--- a/tests/py/dynamo/conversion/test_slice_aten.py
+++ b/tests/py/dynamo/conversion/test_slice_aten.py
@@ -7,14 +7,16 @@
 from .harness import DispatchTestCase
 
 
-class TestSelectConverter(DispatchTestCase):
+class TestSliceConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ("select_dim_start_stop_step", 0, 0, 7, 2),
-            ("select_dim_start_stop_step_offset", 1, 0, 7, 2),
-            ("select_dim_start_stop_step_exact", 1, 0, 10, 2),
-            ("select_dim_start_stop_step_negatives", -3, -2, -1, 1),
-            ("select_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
+            ("slice_dim_start_stop_step", 0, 0, 7, 2),
+            ("slice_dim_start_stop_step_offset", 1, 0, 7, 2),
+            ("slice_dim_start_stop_step_exact", 1, 0, 10, 2),
+            ("slice_dim_start_stop_step_negatives", -3, -2, -1, 1),
+            ("slice_dim_start_stop_step_max_int", 2, 0, 2**63 - 1, 1),
+            ("slice_dim_start_stop_step_past_end", 2, 0, 2048, 1),
+            ("slice_dim_start_stop_step_none", 2, None, None, 1),
         ]
     )
     def test_slice(self, _, dim, start, stop, step):
@@ -32,12 +34,27 @@ def forward(self, input):
             input,
         )
 
+    def test_slice_empty(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                out = torch.ops.aten.slice.Tensor(input)
+                return out
+
+        input = [torch.randn(10, 10, 3, 1)]
+        self.run_test(
+            TestModule(),
+            input,
+        )
+
 
-class TestSelectConverterDynamicShape(DispatchTestCase):
+class TestSliceConverterDynamicShape(DispatchTestCase):
     @parameterized.expand(
         [
-            ("select_dim_start_stop_step", 1, 0, 7, 2),
-            ("select_dim_start_stop_step", 1, 0, 10, 2),
+            ("slice_dim_start_stop_step", 1, 0, 7, 2),
+            ("slice_dim_start_stop_step", 1, 0, 10, 2),
         ]
     )
     def test_slice(self, _, dim, start, stop, step):
diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
index 1bbb54192c..edbe93eddd 100644
--- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py
+++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -267,5 +267,70 @@ def forward(self, q, k, v):
         torch._dynamo.reset()
 
 
+class TestLowerViewToReshape(TestCase):
+    def test_view_to_reshape(self):
+        class ViewToReshape(torch.nn.Module):
+            def forward(self, input):
+                out = torch.ops.aten.view.default(input, (1, 1, -1))
+                return out
+
+        inputs = [
+            torch.rand((3, 4, 5, 32)).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(ViewToReshape())
+        expected_ops = {torch.ops.aten.reshape.default}
+        unexpected_ops = {
+            torch.ops.aten.view.default,
+        }
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"ViewToReshape TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
 if __name__ == "__main__":
     run_tests()
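
Reviewer note: the sketch below is a minimal, plain-Python illustration of the slice-bound semantics this patch introduces, and is not part of the patch itself. It mirrors how the updated slice_op fills in None start/stop values and how the modified get_positive_dim now clamps out-of-range positive indices to the dimension size instead of returning them unchanged. The helper normalize_slice_bounds and the standalone get_positive_dim here are illustrative only and do not touch TensorRT.

# Standalone sketch (assumption: illustrative helpers, not the torch_tensorrt implementation).
from typing import Optional, Tuple


def get_positive_dim(d: int, dim_size: int) -> int:
    # Negative indices wrap around; positive indices are clamped to the dimension size.
    if d < 0:
        return d % dim_size
    else:
        return min(d, dim_size)


def normalize_slice_bounds(
    dim_size: int, start: Optional[int], stop: Optional[int], step: int
) -> Tuple[int, int, int]:
    # Mirrors the new special cases in slice_op: None start -> 0, None stop -> dim size.
    if start is None:
        start = 0
    if stop is None:
        stop = dim_size

    start = get_positive_dim(start, dim_size)
    stop = get_positive_dim(stop, dim_size)
    return start, stop, step


if __name__ == "__main__":
    # On a size-10 dimension, stop values of None, 10, 2048, and 2**63 - 1 all resolve
    # to (0, 10, 1), matching the new "none", "past_end", and "max_int" test cases.
    for stop in (None, 10, 2048, 2**63 - 1):
        print(normalize_slice_bounds(10, None, stop, 1))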