From 892f7d007d894e5e0d7b1b37b33aa97e01998191 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 25 Aug 2023 10:38:47 -0700 Subject: [PATCH] fix: Add special cases where input of graph is output - TRT does not allow inputs of graphs to be outputs as well, however many of the scenarios encountered in real models can have this situation come up, especially in cases where the input is cloned or copied and then returned - The current converters will register these operators as a no-op, causing TRT engine building to fail on such inputs - Instead of requiring creation of an identity layer for every case of a clone or copy node, we instead check if that node is the only operator on a placeholder (input) and then insert the identity layer or not, accordingly - Coalesce implementations of clone and to_copy, which are effectively the same operator - Add test cases to validate new behavior - Add new boilerplate converter validator utility to support this case --- .../dynamo/conversion/aten_ops_converters.py | 84 ++++++++++++++----- .../dynamo/conversion/converter_utils.py | 37 ++++++-- .../dynamo/conversion/impl/cast.py | 40 ++++----- .../conversion/impl/elementwise/base.py | 32 ++----- tests/py/dynamo/conversion/test_casts.py | 27 ++++++ 5 files changed, 153 insertions(+), 67 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 42d6165256..dac526c7e0 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1,10 +1,13 @@ import logging -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import torch from torch.fx.node import Argument, Node, Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion.converter_utils import ( + is_only_operator_on_placeholder, +) from torch_tensorrt.fx.types import TRTNetwork, TRTTensor from .converter_registry import dynamo_tensorrt_converter @@ -441,29 +444,59 @@ def aten_ops_permute( ) -def to_copy_dtype_validator(to_copy_node: Node) -> bool: - allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16} - - # Validate input node has convertible kwargs - if "dtype" in to_copy_node.kwargs: - if to_copy_node.kwargs["dtype"] in allowed_casts: - return True +def to_copy_dtype_validator(placeholder_only: bool) -> Callable[[Node], bool]: + """Return validator for to_copy node with placeholder restrictions""" + + def validate_dtype(to_copy_node: Node) -> bool: + """Returns true if the to_copy node can be converted to TRT + + Based on data type being casted to + """ + allowed_casts = { + torch.float, + torch.int32, + torch.bool, + torch.int8, + torch.float16, + } + + # Validate input node has convertible kwargs + if "dtype" in to_copy_node.kwargs: + if to_copy_node.kwargs["dtype"] in allowed_casts: + return True + else: + _LOGGER.debug( + f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}" + ) + return False else: _LOGGER.debug( - f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}" + f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}" ) return False - else: - _LOGGER.debug( - f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}" + + def validator(to_copy_node: Node) -> bool: + """Returns true if the to_copy node can be converted to TRT + and the placeholder restriction is satisfied + """ + # The placeholder restriction is satsfied if placeholder_only is the same + # truth value as is_only_operator_on_placeholder(to_copy_node) + return validate_dtype(to_copy_node) and ( + (not placeholder_only) ^ is_only_operator_on_placeholder(to_copy_node) ) - return False + + return validator @dynamo_tensorrt_converter( - torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator + torch.ops.aten.clone.default, + capability_validator=lambda node: not is_only_operator_on_placeholder(node), ) # type: ignore[misc] -def aten_ops_to_copy_dtype( +@dynamo_tensorrt_converter( + torch.ops.aten._to_copy.default, + capability_validator=to_copy_dtype_validator(placeholder_only=False), +) # type: ignore[misc] +def aten_ops_clone_copy_dtype( network: TRTNetwork, target: Target, args: Tuple[Argument, ...], @@ -476,24 +509,37 @@ def aten_ops_to_copy_dtype( SourceIR.ATEN, name, args[0], - kwargs["dtype"], + kwargs.get("dtype", args[0].dtype), + force_layer=False, ) -@dynamo_tensorrt_converter(torch.ops.aten.clone.default) # type: ignore[misc] -def aten_ops_clone( +@dynamo_tensorrt_converter( + torch.ops.aten.clone.default, + capability_validator=is_only_operator_on_placeholder, +) # type: ignore[misc] +@dynamo_tensorrt_converter( + torch.ops.aten._to_copy.default, + capability_validator=to_copy_dtype_validator(placeholder_only=True), +) # type: ignore[misc] +def aten_ops_clone_copy_placeholder( network: TRTNetwork, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.cast.clone( + # For clone or copy nodes where the input is also the output, + # we need to force cast to ensure a layer is added to the TRT engine + # since TRT engine inputs cannot also be TRT engine outputs + return impl.cast.to_copy( network, target, SourceIR.ATEN, name, args[0], + kwargs.get("dtype", args[0].dtype), + force_layer=True, ) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 1d8dfecf3b..99cf2fa85a 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -45,6 +45,20 @@ def get_node_name(node: torch.fx.Node) -> str: return node_name +def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool: + """Detects whether a call_function node is the only operator on a placeholder""" + # Returns true if the node operates on a placeholder and is a direct output + return ( + node.op == "call_function" + and any( + arg.op == "placeholder" + for arg in node.args + if isinstance(arg, torch.fx.Node) + ) + and any(user.op == "output" for user in list(node.users.keys())) + ) + + def dynamic_unsupported(node: torch.fx.Node) -> bool: # Validate that none of the inputs to the node have Dynamic shapes assert isinstance( @@ -52,12 +66,17 @@ def dynamic_unsupported(node: torch.fx.Node) -> bool: ), "Inputs to validator functions must be FX Nodes" # Check node value itself - if getattr(node.meta["val"], "_has_symbolic_sizes_strides", False): + if ("val" in node.meta) and getattr( + node.meta["val"], "_has_symbolic_sizes_strides", False + ): return False # Check node arguments individually if any( - getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False) + ( + ("val" in arg.meta) + and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False) + ) for arg in node.args if isinstance(arg, torch.fx.Node) ): @@ -65,7 +84,10 @@ def dynamic_unsupported(node: torch.fx.Node) -> bool: # Check node keyword arguments individually if any( - getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False) + ( + ("val" in kwarg.meta) + and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False) + ) for kwarg in node.kwargs.values() if isinstance(kwarg, torch.fx.Node) ): @@ -82,9 +104,12 @@ def cast_trt_tensor( target: Target = "", source_ir: Optional[SourceIR] = None, ) -> TRTTensor: - """ - Given a TRT Tensor, convert that Tensor to the specified dtype + """Given a TRT Tensor, convert that Tensor to the specified dtype + Adds an Identity layer to the network which performs the conversion + if the input's dtype is different from the cast type. Otherwise returns + input unchanged + Args: network (TRTNetwork): A TensorRT network input_val (TRTTensor): A TRT Tensor to cast to a new data type @@ -191,7 +216,7 @@ def extend_attr_to_tuple( if isinstance(val, tuple): return val else: - raise AssertionError(f"Could not extend attribute {val}") + raise AssertionError(f"Object {val} could not be extended to tuple") def cast_int_or_float_to_bool( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cast.py b/py/torch_tensorrt/dynamo/conversion/impl/cast.py index 0c55731169..f31fd9a396 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/cast.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/cast.py @@ -3,7 +3,12 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor +from torch_tensorrt.fx.converters.converter_utils import ( + Frameworks, + unified_dtype_converter, +) from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor LOGGER: logging.Logger = logging.getLogger(__name__) @@ -16,28 +21,25 @@ def to_copy( name: str, input: TRTTensor, dtype: TRTDataType, + force_layer: bool = False, ) -> TRTTensor: if not isinstance(input, TRTTensor): raise RuntimeError( f"to_copy received input {input} that is not a TensorRT ITensor" ) - casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir) - return casted_tensor - - -def clone( - network: TRTNetwork, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input: TRTTensor, -) -> TRTTensor: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"clone received input {input} that is not a TensorRT ITensor" - ) - - LOGGER.debug(f"Evaluating clone on object with name: {name}") - - return input + # If cast is forced, insert identity layer regardless of whether the dtype + # doesn't change + if force_layer: + trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT) + source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN + target_str = ConverterRegistry.qualified_name_or_str(target) + target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}" + + identity_layer = network.add_identity(input) + identity_layer.set_output_type(0, trt_dtype) + identity_layer.name = f"Forced Cast ITensor {input.name} from {input.dtype} to {trt_dtype} - [{target_name}]-[{name}]" + return identity_layer.get_output(0) + else: + casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir) + return casted_tensor diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index 95dcd88a75..b2176653d1 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -11,11 +11,7 @@ cast_trt_tensor, get_trt_tensor, ) -from torch_tensorrt.fx.converters.converter_utils import ( - broadcast, - set_layer_name, - squeeze_left, -) +from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter @@ -96,10 +92,10 @@ def convert_binary_elementwise( is_rhs_trt_tensor = False if isinstance(lhs_val, TRTTensor): - lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY) + lhs_dtype = lhs_val.dtype is_lhs_trt_tensor = True if isinstance(rhs_val, TRTTensor): - rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY) + rhs_dtype = rhs_val.dtype is_rhs_trt_tensor = True if not is_lhs_trt_tensor and not is_rhs_trt_tensor: @@ -124,23 +120,13 @@ def convert_binary_elementwise( # dtype but we don't have a way to detect whether it makes sense for the # scalar to be float or half. Hence we go with the lhs dtype. if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): - rhs_val = np.array([rhs_val], dtype=lhs_dtype) + rhs_val = np.array( + [rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY) + ) if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): - lhs_val = np.array([lhs_val], dtype=rhs_dtype) - - # When lhs is scalar, and rhs has shape [1,], then currently the assert - # will fail because lhs shape has fewer dimensions than rhs shape. This - # happens when using implicit batch dimension, when we removed the 1st - # dimension from input tensor, causing it to have shape [] - a scalar. We - # fix it by reducing the rhs constant with a squeeze_left, so it becomes a - # scalar too. More generally, we squeeze_left on input if it's a constant - # tensor. This is safe because broadcast will pad dimensions on the left - # (prepend) to make lhs and rhs shape compatible. - if network.has_implicit_batch_dimension: - if isinstance(lhs_val, torch.Tensor): - lhs_val = squeeze_left(lhs_val) - if isinstance(rhs_val, torch.Tensor): - rhs_val = squeeze_left(rhs_val) + lhs_val = np.array( + [lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY) + ) lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype) rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype) diff --git a/tests/py/dynamo/conversion/test_casts.py b/tests/py/dynamo/conversion/test_casts.py index d1893b1c46..50d94713c5 100644 --- a/tests/py/dynamo/conversion/test_casts.py +++ b/tests/py/dynamo/conversion/test_casts.py @@ -35,6 +35,19 @@ def forward(self, x): disable_passes=True, ) + def test_clone_direct(self): + class Clone(nn.Module): + def forward(self, x): + return x.clone() + + inputs = [torch.randn((8, 2, 10))] + self.run_test( + Clone(), + inputs, + expected_ops={torch.ops.aten.clone.default}, + disable_passes=True, + ) + class TestToCopyConverter(DispatchTestCase): def test_to_copy_half(self): @@ -83,6 +96,20 @@ def forward(self, x): disable_passes=True, ) + def test_to_copy_direct(self): + class ToCopyFloat(nn.Module): + def forward(self, x): + return x.to(dtype=torch.float, copy=True) + + inputs = [torch.rand((1, 3, 10)).float()] + self.run_test( + ToCopyFloat(), + inputs, + expected_ops={torch.ops.aten._to_copy.default}, + precision=torch.float, + disable_passes=True, + ) + if __name__ == "__main__": run_tests()