diff --git a/docsrc/user_guide/saving_models.rst b/docsrc/user_guide/saving_models.rst index 6d890d0450..8379b44f0f 100644 --- a/docsrc/user_guide/saving_models.rst +++ b/docsrc/user_guide/saving_models.rst @@ -14,14 +14,18 @@ Saving models compiled with Torch-TensorRT varies slightly with the `ir` that ha Dynamo IR ------------- -Starting with 2.1 release of Torch-TensorRT, we are switching the default compilation to be dynamo based. -The output of `ir=dynamo` compilation is a `torch.fx.GraphModule` object. There are two ways to save these objects +The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.export.ExportedProgram` object by default. +In addition, we provide a new parameter `output_format` in the `CompilationSetting` object provided before compilation. +The `output_format` can take the following options -a) Converting to Torchscript +* `exported_program` (or) `ep` : This is the default. Returns an ExportedProgram +* `torchscript` (or) `ts` : This returns a TorchScript module +* `graph_module` (or) `fx` : This returns a torch.fx.GraphModule which can be traced into Torchscript to save to disk. + +a) Torchscript ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -`torch.fx.GraphModule` objects cannot be serialized directly. Hence we use `torch.jit.trace` to convert this into a `ScriptModule` object which can be saved to disk. -The following code illustrates this approach. +If you set the `output_format="torchscript"`, this will return a `ScriptModule` which can be serialized via torch.jit.save .. code-block:: python @@ -30,9 +34,9 @@ The following code illustrates this approach. model = MyModel().eval().cuda() inputs = [torch.randn((1, 3, 224, 224)).cuda()] - trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule - trt_traced_model = torch.jit.trace(trt_gm, inputs) - torch.jit.save(trt_traced_model, "trt_model.ts") + # trt_ts is a torch.jit.ScriptModule object + trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs, output_format="torchscript") + torch.jit.save(trt_ts, "trt_model.ts") # Later, you can load it and run inference model = torch.jit.load("trt_model.ts").cuda() @@ -41,8 +45,7 @@ The following code illustrates this approach. b) ExportedProgram ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -`torch.export.ExportedProgram` is a new format introduced in Pytorch 2.1. After we compile a Pytorch module using Torch-TensorRT, the resultant -`torch.fx.GraphModule` along with additional metadata can be used to create `ExportedProgram` which can be saved and loaded from disk. +`torch.export.ExportedProgram`, a new format introduced in Pytorch 2.X is the default return type of Torch-TensorRT compilation. .. code-block:: python @@ -51,26 +54,36 @@ b) ExportedProgram model = MyModel().eval().cuda() inputs = [torch.randn((1, 3, 224, 224)).cuda()] - trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule - # Transform and create an exported program - trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs) - torch.export.save(trt_exp_program, "trt_model.ep") + # trt_ep is a torch.export.ExportedProgram object + trt_ep = torch_tensorrt.compile(model, ir="dynamo", inputs) + torch.export.save(trt_ep, "trt_model.ep") # Later, you can load it and run inference model = torch.export.load("trt_model.ep") model(*inputs) -`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes and stiches all the nodes together. -This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes). +c) GraphModule +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. note:: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue : https://github.com/pytorch/TensorRT/issues/2341 +We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`. +Internally, partitioning, lowering, conversion phases operate using GraphModule objects. These can be either traced into a Torchscript modules or +exported into `ExportedProgram` objects +.. code-block:: python + + import torch + import torch_tensorrt + + model = MyModel().eval().cuda() + inputs = [torch.randn((1, 3, 224, 224)).cuda()] + # trt_gm is a torch.fx.GraphModule object + trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs, output_format="graph_module") Torchscript IR ------------- In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT is using Torchscript IR. -This behavior stays the same in 2.X versions as well. +For `ir=ts`, this behavior stays the same in 2.X versions as well. .. code-block:: python diff --git a/examples/int8/training/vgg16/vgg16.py b/examples/int8/training/vgg16/vgg16.py index b371b8e243..379306114b 100644 --- a/examples/int8/training/vgg16/vgg16.py +++ b/examples/int8/training/vgg16/vgg16.py @@ -3,10 +3,12 @@ - [Very Deep Convolutional Networks for Large-Scale Image Recognition]( https://arxiv.org/abs/1409.1556) (ICLR 2015) """ + +from functools import reduce + import torch import torch.nn as nn import torch.nn.functional as F -from functools import reduce class VGG(nn.Module): diff --git a/py/torch_tensorrt/_Device.py b/py/torch_tensorrt/_Device.py index 0f8ce1e392..6f20b6c84c 100644 --- a/py/torch_tensorrt/_Device.py +++ b/py/torch_tensorrt/_Device.py @@ -32,12 +32,14 @@ class Device(object): allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed """ - device_type: Optional[ - trt.DeviceType - ] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified. + device_type: Optional[trt.DeviceType] = ( + None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified. + ) gpu_id: int = -1 #: Device ID for target GPU dla_core: int = -1 #: Core ID for target DLA core - allow_gpu_fallback: bool = False #: Whether falling back to GPU if DLA cannot support an op should be allowed + allow_gpu_fallback: bool = ( + False #: Whether falling back to GPU if DLA cannot support an op should be allowed + ) def __init__(self, *args: Any, **kwargs: Any): """__init__ Method for torch_tensorrt.Device diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 9acb073c62..db36678d17 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -28,12 +28,12 @@ class _ShapeMode(Enum): STATIC = 0 DYNAMIC = 1 - shape_mode: Optional[ - _ShapeMode - ] = None #: Is input statically or dynamically shaped - shape: Optional[ - Tuple[int, ...] | Dict[str, Tuple[int, ...]] - ] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }`` + shape_mode: Optional[_ShapeMode] = ( + None #: Is input statically or dynamically shaped + ) + shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = ( + None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }`` + ) dtype: _enums.dtype = ( _enums.dtype.unknown ) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index e705c069d5..504f4d4491 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -5,6 +5,7 @@ from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union import torch +import torch_tensorrt from torch.export import ExportedProgram from torch.fx.node import Target from torch_tensorrt import _enums @@ -29,6 +30,7 @@ MIN_BLOCK_SIZE, NUM_AVG_TIMING_ITERS, OPTIMIZATION_LEVEL, + OUTPUT_FORMAT, PASS_THROUGH_BUILD_FAILURES, PRECISION, REFIT, @@ -46,6 +48,7 @@ dryrun_stats_display, parse_non_trt_nodes, ) +from torch_tensorrt.dynamo._exporter import export from torch_tensorrt.dynamo.conversion import ( CompilationSettings, UnsupportedOperatorException, @@ -66,8 +69,6 @@ to_torch_tensorrt_device, ) -import torch_tensorrt - logger = logging.getLogger(__name__) @@ -103,8 +104,9 @@ def compile( enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS, dryrun: bool = DRYRUN, hardware_compatible: bool = HARDWARE_COMPATIBLE, + output_format: str = OUTPUT_FORMAT, **kwargs: Any, -) -> torch.fx.GraphModule: +) -> Union[ExportedProgram, torch.jit.ScriptModule, torch.fx.GraphModule]: """Compile a TorchScript module for NVIDIA GPUs using TensorRT Takes a existing TorchScript module and a set of settings to configure the compiler @@ -161,6 +163,7 @@ def compile( enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the grap easier to covert to TensorRT, potentially increasing the amount of graphs run in TensorRT. dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) + output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program" **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -217,9 +220,9 @@ def compile( "device": device, "workspace_size": workspace_size, "min_block_size": min_block_size, - "torch_executed_ops": torch_executed_ops - if torch_executed_ops is not None - else set(), + "torch_executed_ops": ( + torch_executed_ops if torch_executed_ops is not None else set() + ), "pass_through_build_failures": pass_through_build_failures, "max_aux_streams": max_aux_streams, "version_compatible": version_compatible, @@ -238,11 +241,14 @@ def compile( "dla_global_dram_size": dla_global_dram_size, "dryrun": dryrun, "hardware_compatible": hardware_compatible, + "output_format": output_format, } settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) - return compile_module(gm, inputs, settings) + trt_gm = compile_module(gm, inputs, settings) + trt_result = export(trt_gm, torch_inputs, output_format) + return trt_result def compile_module( diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 3d48ab3def..ec038c0dba 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -26,6 +26,7 @@ REQUIRE_FULL_COMPILATION = False DRYRUN = False HARDWARE_COMPATIBLE = False +OUTPUT_FORMAT = "exported_program" def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py index df9150ea2d..c7e2f37795 100644 --- a/py/torch_tensorrt/dynamo/_exporter.py +++ b/py/torch_tensorrt/dynamo/_exporter.py @@ -1,4 +1,3 @@ -import copy import operator from typing import Any, Dict, Sequence, Tuple, cast @@ -19,50 +18,43 @@ def export( gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor], - *, - ir: str = "torchscript", + output_format: str, ) -> ExportedProgram: - """Export a program (``torch.fx.GraphModule``) for serialization with the TensorRT engines embedded. - - > Note: When ExportedProgram becomes stable, this function will get merged into ``torch_tensorrt.dynamo.compile`` + """Export the result of TensorRT compilation into the desired output format. Arguments: - src_gm (torch.fx.GraphModule): Source module, generated by torch.export (The module provided to ``torch_tensorrt.dynamo.compile``) gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile`` - - Keyword Arguments: - inputs (Any): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using - torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum - to select device type. :: - - input=[ - torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1 - torch_tensorrt.Input( - min_shape=(1, 224, 224, 3), - opt_shape=(1, 512, 512, 3), - max_shape=(1, 1024, 1024, 3), - dtype=torch.int32 - format=torch.channel_last - ), # Dynamic input shape for input #2 - torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings - ir (str): torchscript | exported_program. Based on the provided ir, the output type would be a torchscript or exported program. + inputs (torch.Tensor): Torch input tensors + output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program" """ - if ir == "torchscript": + if output_format == "torchscript" or output_format == "ts": return torch.jit.trace(gm, inputs) - elif ir == "exported_program": + elif output_format == "exported_program" or output_format == "ep": patched_module = transform(gm, inputs) exp_program = create_trt_exp_program(patched_module) - return exp_program + elif output_format == "graph_module" or output_format == "fx": + return gm else: raise ValueError( - f"Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program" + f"Invalid output format {output_format} specified. Supported options include exported_program (or) ep | torchscript (or) ts | graph_module (or) fx" ) def transform( gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor] ) -> torch.fx.GraphModule: + """ + Transforms the graphmodule by inlining Pytorch and TensorRT submodules. + Inlining collapses submodules into nodes which is necessary for torch.export + serialization. + + Arguments: + gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile`` + inputs (torch.Tensor): Torch input tensors + + Returns an inlined torch.fx.GraphModule + """ # Run shape analysis _, outputs_map = partitioning.run_shape_analysis(gm, inputs) @@ -72,10 +64,6 @@ def transform( # Inline pytorch submodules inline_torch_modules(gm) - # Lift constant buffers and parameters in the graph - # torch.export serialization expects them to be lifted - lift_constant_pass(gm) - # Clean the graph gm.delete_all_unused_submodules() gm.graph.eliminate_dead_code() @@ -84,34 +72,89 @@ def transform( return gm -def lift_constant_pass(trt_gm: torch.fx.GraphModule) -> torch.fx.GraphModule: +def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule: + """ + Given an unlifted fx.GraphModule, lift all parameters, buffers into placeholders. + Arguments: + gm (torch.fx.GraphModule): Unlifted GraphModule which contains parameters and buffers as get_attr nodes. + graph_signature (torch.export.ExportGraphSignature): Instance of ExportGraphSignature class created for the output ExportedProgram. + After lifting, this graph_signature will be modified with the parameters and buffers added appropriately. + Returns: + A lifted fx.GraphModule, modified graph_signature and a new state_dict + """ + # Get the state_dict of graph_module. This is different from exported_program.state_dict + # exp_program.state_dict contains parameters and buffers whereas a graph_module's state_dict + # has all parameters registered as torch.tensors. + state_dict = gm.state_dict() + fake_mode = detect_fake_mode( - tuple( - node.meta["val"] for node in trt_gm.graph.nodes if node.op == "placeholder" - ) + tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder") ) + assert fake_mode is not None + # Locate the user input to insert new placeholders before them first_user_input = None - for node in trt_gm.graph.nodes: - if node.op == "placeholder": + for node in gm.graph.nodes: + if node.op == "placeholder" and node.name in graph_signature.user_inputs: first_user_input = node break - for node in trt_gm.graph.nodes: + # At first the user_inputs are only present in the graph_signature.input_specs and hence non_user_input_idx=0 + # The input_specs should be of the form [params, buffers, constant_tensors, user_inputs] + non_user_input_idx = 0 + for node in gm.graph.nodes: if node.op == "get_attr": - constant_tensor = getattr(trt_gm, node.target) - with trt_gm.graph.inserting_before(first_user_input): - const_placeholder_node = trt_gm.graph.placeholder(node.target) - const_placeholder_node.meta = copy.deepcopy(node.meta) - const_placeholder_node.meta["val"] = fake_mode.from_tensor( - constant_tensor + if node.target not in state_dict: + raise ValueError( + f"The get_attr node : {node.name} with target: {node.target} value could not be found in state_dict. Please check the input exported_program's graphmodule parameters." + ) + + constant_tensor = state_dict[node.target] + input_kind = InputKind.CONSTANT_TENSOR + + # state_dict has these parameters/buffers as torch.Tensors. We override them as torch.nn.Parameter/torch.Tensors respectively. + for name, _ in gm.named_parameters(): + if node.target == name: + input_kind = InputKind.PARAMETER + state_dict[name] = torch.nn.Parameter(state_dict[name]) + break + for name, _ in gm.named_buffers(): + if node.target == name: + input_kind = InputKind.BUFFER + break + + # Replace get_attr nodes with placeholder nodes and copy metadata. + with gm.graph.inserting_before(first_user_input): + const_placeholder_node = gm.graph.placeholder(node.target) + # Copy the node meta into this new placeholder node + const_placeholder_node.meta = node.meta + const_placeholder_node.meta["val"] = cast( + FakeTensor, + torch.empty_strided( + tuple(constant_tensor.shape), + tuple([1] * len(constant_tensor.shape)), + ), ) + node.replace_all_uses_with(const_placeholder_node) - trt_gm.graph.erase_node(node) + gm.graph.erase_node(node) + + # Add these parameters/buffers/constants to the existing graph signature + # before user inputs. These specs are looked up in the state_dict during ExportedProgram creation. + graph_signature.input_specs.insert( + non_user_input_idx, + InputSpec( + kind=input_kind, + arg=TensorArgument(name=const_placeholder_node.name), + target=node.target, + ), + ) + non_user_input_idx += 1 + + gm.graph.eliminate_dead_code() + gm.graph.lint() - trt_gm.graph.eliminate_dead_code() - trt_gm.graph.lint() - return trt_gm + return gm, graph_signature, state_dict def get_duplicate_nodes( @@ -140,7 +183,7 @@ def get_duplicate_nodes( def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: """ Inline a submodule within the parent graph (gm). All `call_module` nodes - should be replaced by their submodule nodes. + should be replaced by their nodes in the submodule. """ # Clean the graph gm.graph.eliminate_dead_code() @@ -165,7 +208,6 @@ def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: # Copy all nodes in the submodule into gm and # store the output node of this submodule which is now present in gm - submodule_output = gm.graph.graph_copy(submodule.graph, val_map) # Get their references (since we copied) in the parent graph (gm) @@ -197,7 +239,7 @@ def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: gm_node.replace_all_uses_with(submodule_output) # copy the attributes of the submodule into gm (graph_copy doesn't do this) - copy_submodule_attributes(gm, gm_node.name) + copy_submodule_attributes(gm, submodule, gm_node.name) # Erase the pytorch submodule (call_module) node gm.graph.erase_node(gm_node) @@ -205,20 +247,24 @@ def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: return gm -def copy_submodule_attributes(gm: torch.fx.GraphModule, submod_name: str) -> None: +def copy_submodule_attributes( + gm: torch.fx.GraphModule, submodule: torch.fx.GraphModule, submodule_name: str +) -> None: """ - Copy the getattr attriibutes from submodule to parent module gm. - The graph_copy call doesn't do this for us unfortunately. + The submodule parameters are available in the parent gm's state_dict, but they have + the submodule name as a prefix in their keys. For eg: gm.state_dict() would have + _run_on_gpu_0.conv.weight etc. Since we graph copied the submodule into gm, we should + also copy it's parameters and buffers into gm without the submodule namespace as prefix. + _assign_attr does exactly that. It creates a module for eg: conv, adds an attribute weight + to it and adds this conv module as an attribute to parent gm. """ - for param in gm.named_parameters(): - if param[0].startswith(submod_name + "."): - attr_name = param[0].replace(submod_name + ".", "") - gm.register_parameter(attr_name, param[1]) + from torch.export.unflatten import _assign_attr, _AttrKind + + for key, value in submodule.named_parameters(): + _assign_attr(value, gm, key, _AttrKind.PARAMETER) - for buffer in gm.named_buffers(): - if buffer[0].startswith(submod_name + "."): - attr_name = buffer[0].replace(submod_name + ".", "") - gm.register_buffer(attr_name, buffer[1]) + for key, value in submodule.named_buffers(): + _assign_attr(value, gm, key, _AttrKind.BUFFER) def create_trt_exp_program( @@ -227,6 +273,7 @@ def create_trt_exp_program( """Creates a new Exported Program. This function takes an torch.fx.GraphModule which has TRT engines and constructs an Exported Program object with the new IO node names and state_dict """ + input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] output_nodes = [node for node in gm.graph.nodes if node.op == "output"] assert output_nodes @@ -245,8 +292,18 @@ def create_trt_exp_program( input_specs=input_specs, output_specs=output_specs ) + # Lift parameters/buffers/constants in the graph + # torch.export serialization expects them to be lifted + gm, trt_graph_signature, state_dict = lift(gm, trt_graph_signature) + trt_exp_program = ExportedProgram( - gm, gm.graph, trt_graph_signature, gm.state_dict(), {}, [], [], [] + gm, + gm.graph, + trt_graph_signature, + state_dict, + {}, + [], + [], ) return trt_exp_program diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 2420a227d8..c00b049f45 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -19,6 +19,7 @@ MIN_BLOCK_SIZE, NUM_AVG_TIMING_ITERS, OPTIMIZATION_LEVEL, + OUTPUT_FORMAT, PASS_THROUGH_BUILD_FAILURES, PRECISION, REFIT, @@ -70,6 +71,7 @@ class CompilationSettings: TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the ouptut to a file if a string path is specified hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) + output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program" """ precision: torch.dtype = PRECISION @@ -97,3 +99,4 @@ class CompilationSettings: dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE dryrun: Union[bool, str] = DRYRUN hardware_compatible: bool = HARDWARE_COMPATIBLE + output_format: str = OUTPUT_FORMAT diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 5db9fc183e..06ae596ed0 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -28,9 +28,9 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[ - Callable[[torch.fx.GraphModule], None] -] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER") +TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = ( + Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER") +) class UnsupportedOperatorException(RuntimeError): @@ -92,9 +92,9 @@ def __init__( self._cur_node: Optional[torch.fx.Node] = None self._input_names: List[str] = [] self._output_names: List[str] = [] - self._itensor_to_tensor_meta: Dict[ - trt.tensorrt.ITensor, TensorMetadata - ] = dict() + self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = ( + dict() + ) self.compilation_settings = compilation_settings # Data types for TRT Module output Tensors diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index f90c869c15..f9d14917f1 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -324,13 +324,11 @@ def get_trt_tensor( @overload -def get_positive_dim(dim: int, dim_size: int) -> int: - ... +def get_positive_dim(dim: int, dim_size: int) -> int: ... @overload -def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: - ... +def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ... def get_positive_dim( diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index af92a9dc50..de791851db 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -7,9 +7,9 @@ aten = torch.ops.aten -_core_aten_decompositions: Dict[ - OpOverload, Callable[[Any], Any] -] = core_aten_decompositions() +_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = ( + core_aten_decompositions() +) torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = { aten._adaptive_avg_pool2d_backward, aten.addcdiv, @@ -180,9 +180,9 @@ } -ENABLED_TORCH_DECOMPOSITIONS: Dict[ - OpOverload, Callable[[Any], Any] -] = get_torch_decompositions(torch_enabled_decompositions) +ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = ( + get_torch_decompositions(torch_enabled_decompositions) +) TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {} diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 2443e33d50..81a9d76d6e 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -42,11 +42,6 @@ def constant_fold( for node in gm.graph.nodes: # If get_attr node has no users, mark it for deletion if node.op == "get_attr" and len(node.users) == 0: - # If the node's parameter is not a parameter of any other node, remove it - if not any( - other.target == node.target for other in gm.graph.nodes if other != node - ): - delattr(gm, node.target) erased_params.append(node) # Remove unused nodes from the graph diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py index 75ad067a3f..ef2c0531a6 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py @@ -22,12 +22,10 @@ def lower_linear( return gm -def linear_replacement() -> ( - Tuple[ - torch.fx.GraphModule, - Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], - ] -): +def linear_replacement() -> Tuple[ + torch.fx.GraphModule, + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], +]: """Constructs the original and replacement functions for linear""" # Original graph diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py index 74dee9c0c9..161dbbe9df 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py @@ -60,12 +60,10 @@ def lower_scaled_dot_product_attention( return gm -def scaled_dot_product_attention_replacement() -> ( - Tuple[ - Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], - Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], - ] -): +def scaled_dot_product_attention_replacement() -> Tuple[ + Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]], + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], +]: """Constructs the original and replacement functions for efficient attention""" # Efficient Attention original graph diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py index efc836814f..e2ef051f06 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py @@ -22,12 +22,10 @@ def view_to_reshape( return gm -def view_replacement() -> ( - Tuple[ - torch.fx.GraphModule, - Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor], - ] -): +def view_replacement() -> Tuple[ + torch.fx.GraphModule, + Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor], +]: """Constructs the original and replacement functions for view""" # Original graph diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index db45609123..3a66ed3716 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -6,6 +6,7 @@ import tensorrt as trt import torch +import torch_tensorrt from torch.nn import Module from torch_tensorrt._Device import Device from torch_tensorrt.dynamo.runtime.tools import ( @@ -15,8 +16,6 @@ ) from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter -import torch_tensorrt - logger = logging.getLogger(__name__) @@ -101,9 +100,11 @@ def _initialize(self) -> None: for idx in self.output_binding_indices_in_order ] self.output_shapes = [ - tuple(self.engine.get_binding_shape(idx)) - if self.engine.has_implicit_batch_dimension - else tuple() + ( + tuple(self.engine.get_binding_shape(idx)) + if self.engine.has_implicit_batch_dimension + else tuple() + ) for idx in self.output_binding_indices_in_order ] self.hidden_output_dtypes = [ @@ -113,9 +114,11 @@ def _initialize(self) -> None: for idx in self.hidden_output_binding_indices_in_order ] self.hidden_output_shapes = [ - tuple(self.engine.get_binding_shape(idx)) - if self.engine.has_implicit_batch_dimension - else tuple() + ( + tuple(self.engine.get_binding_shape(idx)) + if self.engine.has_implicit_batch_dimension + else tuple() + ) for idx in self.hidden_output_binding_indices_in_order ] @@ -167,9 +170,11 @@ def __setstate__(self, state: Dict[str, Any]) -> None: self.context = self.engine.create_execution_context() def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: - with torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:Forward" - ) if self.profiling_enabled else nullcontext(): + with ( + torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward") + if self.profiling_enabled + else nullcontext() + ): self._check_initialized() # If in safe mode, check at each iteration for for whether a switch is required @@ -200,9 +205,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . inputs = tuple([tensor.to(device) for tensor in inputs]) logger.warning(f"Moved all input Tensors to cuda:{device_id}") - with torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:ProcessInputs" - ) if self.profiling_enabled else nullcontext(): + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessInputs" + ) + if self.profiling_enabled + else nullcontext() + ): assert len(inputs) == len( self.input_names ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}." @@ -239,9 +248,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . idx, tuple(contiguous_inputs[i].shape) ) - with torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:ProcessOutputs" - ) if self.profiling_enabled else nullcontext(): + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessOutputs" + ) + if self.profiling_enabled + else nullcontext() + ): # create output tensors outputs: List[torch.Tensor] = [] @@ -266,9 +279,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . ) bindings[idx] = output.data_ptr() - with torch.autograd.profiler.record_function( - "PythonTorchTensorRTModule:TensorRTRuntime" - ) if self.profiling_enabled else nullcontext(): + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:TensorRTRuntime" + ) + if self.profiling_enabled + else nullcontext() + ): self.context.execute_async_v2( bindings, torch.cuda.current_stream().cuda_stream ) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 17c19eda33..f11e40a6db 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -3,24 +3,22 @@ import math import operator import warnings -from typing import cast, Dict, Optional, Sequence, Tuple, Union +from typing import Dict, Optional, Sequence, Tuple, Union, cast import numpy as np # @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils +from torch.fx.immutable_collections import immutable_list +from torch.fx.node import Argument, Target from torch_tensorrt.fx.converters import acc_ops_converters +from torch_tensorrt.fx.converters.impl import activation, convolution from ..converter_registry import tensorrt_converter - from ..types import * # noqa: F403 -from torch.fx.immutable_collections import immutable_list -from torch.fx.node import Argument, Target - from .converter_utils import * # noqa: F403 -import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils -from torch_tensorrt.fx.converters.impl import activation, convolution _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -317,21 +315,17 @@ def aten_ops_max_poolnd( kwargs_new = { "input": args[0], "kernel_size": args[1], - "stride": args[2] - if len(args) > 2 - else (None, None) - if len(args[1]) == 2 - else (None, None, None), - "padding": args[3] - if len(args) > 3 - else (0, 0) - if len(args[1]) == 2 - else (0, 0, 0), - "dilation": args[4] - if len(args) > 4 - else (1, 1) - if len(args[1]) == 2 - else (1, 1, 1), + "stride": ( + args[2] + if len(args) > 2 + else (None, None) if len(args[1]) == 2 else (None, None, None) + ), + "padding": ( + args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0) + ), + "dilation": ( + args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1) + ), "ceil_mode": args[5] if len(args) > 5 else False, } return acc_ops_converters.acc_ops_max_poolnd( diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py index d7ef976fba..6a29932b1b 100644 --- a/py/torch_tensorrt/fx/fx2trt.py +++ b/py/torch_tensorrt/fx/fx2trt.py @@ -17,13 +17,13 @@ from .converter_registry import CONVERTERS from .input_tensor_spec import InputTensorSpec from .observer import Observer -from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks +from .utils import Frameworks, LowerPrecision, get_dynamic_dims, unified_dtype_converter _LOGGER: logging.Logger = logging.getLogger(__name__) -TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[ - Callable[[torch.fx.GraphModule], None] -] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER") +TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = ( + Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER") +) class TRTInterpreterResult(NamedTuple): @@ -75,9 +75,9 @@ def __init__( self._cur_node_name: Optional[str] = None self._input_names: List[str] = [] self._output_names: List[str] = [] - self._itensor_to_tensor_meta: Dict[ - trt.tensorrt.ITensor, TensorMetadata - ] = dict() + self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = ( + dict() + ) def validate_input_specs(self): for shape, _, _, shape_ranges, has_batch_dim in self.input_specs: diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 5f66519e05..fa148ce6cb 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -16,7 +16,6 @@ from .passes.pass_utils import PassFunc, validate_inference from .tools.timing_cache_utils import TimingCacheManager from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting - from .tracer.acc_tracer import acc_tracer from .trt_module import TRTModule from .utils import LowerPrecision @@ -126,9 +125,11 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult: input_specs=self.lower_setting.input_specs, explicit_batch_dimension=self.lower_setting.explicit_batch_dimension, explicit_precision=self.lower_setting.explicit_precision, - logger_level=trt.Logger.VERBOSE - if self.lower_setting.verbose_log - else trt.Logger.WARNING, + logger_level=( + trt.Logger.VERBOSE + if self.lower_setting.verbose_log + else trt.Logger.WARNING + ), ) interp_result: TRTInterpreterResult = interpreter.run( @@ -138,9 +139,11 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult: strict_type_constraints=self.lower_setting.strict_type_constraints, algorithm_selector=algo_selector, timing_cache=cache_data, - profiling_verbosity=trt.ProfilingVerbosity.DETAILED - if self.lower_setting.verbose_profile - else trt.ProfilingVerbosity.LAYER_NAMES_ONLY, + profiling_verbosity=( + trt.ProfilingVerbosity.DETAILED + if self.lower_setting.verbose_profile + else trt.ProfilingVerbosity.LAYER_NAMES_ONLY + ), tactic_sources=self.lower_setting.tactic_sources, ) @@ -297,10 +300,8 @@ def do_lower(module: nn.Module, inputs: Input) -> nn.Module: # handle inputs with custom types. By default, just handle # tensors and NoneType. if fp16_conversion_fn is None: - conversion_fn = ( - lambda x: x.half() - if x is not None and x.dtype == torch.float32 - else x + conversion_fn = lambda x: ( + x.half() if x is not None and x.dtype == torch.float32 else x ) else: conversion_fn = fp16_conversion_fn diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass.py b/py/torch_tensorrt/fx/passes/lower_basic_pass.py index b203bc82e0..fb75a3e3c3 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass.py @@ -11,7 +11,6 @@ from torch.fx.experimental.const_fold import split_const_subgraphs from ..observer import observable - from ..tracer.acc_tracer import acc_ops from ..tracer.acc_tracer.acc_utils import get_attr from .pass_utils import log_before_after, validate_inference @@ -538,9 +537,9 @@ def get_reshape_batch_size_inferred_source( ) if not reshape_batch_size: continue - reshape_batch_size_inferred_source: Optional[ - fx.Node - ] = get_reshape_batch_size_inferred_source(reshape_batch_size) + reshape_batch_size_inferred_source: Optional[fx.Node] = ( + get_reshape_batch_size_inferred_source(reshape_batch_size) + ) if not reshape_batch_size_inferred_source: continue diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py index 6e6b40d42f..8f3cc576ec 100644 --- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py @@ -5,19 +5,17 @@ import torch from torch import nn -from torch.fx.passes.pass_manager import inplace_wrapper, PassManager +from torch.fx.passes.pass_manager import PassManager, inplace_wrapper from torch.fx.passes.shape_prop import ShapeProp -from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult +from torch.fx.passes.splitter_base import SplitResult, generate_inputs_for_submodules from torch_tensorrt.fx.passes.pass_utils import apply_bfloat_float_conversion from torch_tensorrt.fx.utils import LowerPrecision from ..input_tensor_spec import generate_input_specs - from ..lower_setting import LowerSetting from ..observer import Observer from ..passes.remove_duplicate_output_args import remove_duplicate_output_args from .graph_opts import common_subexpression_elimination - from .lower_basic_pass import ( # noqa fix_clamp_numerical_limits_to_fp16, fix_reshape_batch_dim, @@ -26,7 +24,6 @@ run_const_fold, ) - _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -196,9 +193,11 @@ def lower_func(split_result: SplitResult) -> nn.Module: self.lower_setting.input_specs = generate_input_specs( submod_inputs, self.lower_setting, - additional_submodule_inputs[submod_name] - if additional_submodule_inputs - else None, + ( + additional_submodule_inputs[submod_name] + if additional_submodule_inputs + else None + ), ) lowered_module = self._lower_func( submod, submod_inputs, self.lower_setting, submod_name @@ -236,9 +235,11 @@ def lower_func(split_result: SplitResult) -> nn.Module: lowering_start_time = datetime.datetime.now() self.lower_setting.additional_inputs = ( - additional_submodule_inputs[submod_name] - if additional_submodule_inputs - else None, + ( + additional_submodule_inputs[submod_name] + if additional_submodule_inputs + else None + ), ) lowered_module = self._lower_func( diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py index 0b8578ffba..2de5c23aaf 100644 --- a/py/torch_tensorrt/fx/passes/pass_utils.py +++ b/py/torch_tensorrt/fx/passes/pass_utils.py @@ -195,9 +195,7 @@ def pass_with_validation( kwargs2["rtol"] = rtol if atol: kwargs2["atol"] = atol - kwargs2[ - "msg" - ] = ( + kwargs2["msg"] = ( lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}" ) # If tensors are on different devices, make sure to compare diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py index 29d174d9fd..cf49d028ae 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py @@ -23,9 +23,11 @@ def forward(self, x): Split(), inputs, expected_ops={ - acc_ops.split - if isinstance(split_size_or_sections, int) - else acc_ops.slice_tensor + ( + acc_ops.split + if isinstance(split_size_or_sections, int) + else acc_ops.slice_tensor + ) }, test_explicit_batch_dim=False, ) @@ -70,9 +72,11 @@ def forward(self, x): Split(), input_specs, expected_ops={ - acc_ops.split - if isinstance(split_size_or_sections, int) - else acc_ops.slice_tensor + ( + acc_ops.split + if isinstance(split_size_or_sections, int) + else acc_ops.slice_tensor + ) }, ) diff --git a/py/torch_tensorrt/fx/tools/common_fx2trt.py b/py/torch_tensorrt/fx/tools/common_fx2trt.py index 6d883a4f62..2ddd832c2a 100644 --- a/py/torch_tensorrt/fx/tools/common_fx2trt.py +++ b/py/torch_tensorrt/fx/tools/common_fx2trt.py @@ -7,7 +7,6 @@ import tensorrt as trt import torch import torch.fx - import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer from torch.fx.experimental.normalize import NormalizeArgs @@ -154,9 +153,9 @@ def run_test_custom_compare_results( self.assert_has_op(mod, expected_ops) interpreter_result = interpreter.run( - lower_precision=LowerPrecision.FP16 - if fp16_mode - else LowerPrecision.FP32 + lower_precision=( + LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32 + ) ) trt_mod = TRTModule( interpreter_result.engine, diff --git a/py/torch_tensorrt/fx/trt_module.py b/py/torch_tensorrt/fx/trt_module.py index ab2d9ac348..c5bab21353 100644 --- a/py/torch_tensorrt/fx/trt_module.py +++ b/py/torch_tensorrt/fx/trt_module.py @@ -4,7 +4,7 @@ import tensorrt as trt import torch -from .utils import unified_dtype_converter, Frameworks +from .utils import Frameworks, unified_dtype_converter class TRTModule(torch.nn.Module): @@ -69,9 +69,11 @@ def _initialize(self): for idx in self.output_binding_indices_in_order ] self.output_shapes = [ - tuple(self.engine.get_binding_shape(idx)) - if self.engine.has_implicit_batch_dimension - else tuple() + ( + tuple(self.engine.get_binding_shape(idx)) + if self.engine.has_implicit_batch_dimension + else tuple() + ) for idx in self.output_binding_indices_in_order ] self.hidden_output_dtypes: Sequence[torch.dtype] = [ @@ -81,9 +83,11 @@ def _initialize(self): for idx in self.hidden_output_binding_indices_in_order ] self.hidden_output_shapes = [ - tuple(self.engine.get_binding_shape(idx)) - if self.engine.has_implicit_batch_dimension - else tuple() + ( + tuple(self.engine.get_binding_shape(idx)) + if self.engine.has_implicit_batch_dimension + else tuple() + ) for idx in self.hidden_output_binding_indices_in_order ] diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index b9a84152e1..37f5fb79e3 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -3,6 +3,7 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Set +import tensorrt as trt import torch import torch_tensorrt._C.ts as _ts_C from torch_tensorrt import _C, _enums @@ -11,8 +12,6 @@ from torch_tensorrt.logging import Level, log from torch_tensorrt.ts._Input import TorchScriptInput -import tensorrt as trt - def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input: clone = torch.classes.tensorrt._Input() @@ -406,9 +405,9 @@ def TensorRTCompileSpec( "device": device, "disable_tf32": disable_tf32, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas "sparse_weights": sparse_weights, # Enable sparsity for convolution and fully connected layers. - "enabled_precisions": enabled_precisions - if enabled_precisions is not None - else set(), # Enabling FP16 kernels + "enabled_precisions": ( + enabled_precisions if enabled_precisions is not None else set() + ), # Enabling FP16 kernels "refit": refit, # enable refit "debug": debug, # enable debuggable engine "capability": capability, # Restrict kernel selection to safe gpu kernels or safe dla kernels diff --git a/pyproject.toml b/pyproject.toml index c987ac1f40..5c42700ef8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -177,7 +177,7 @@ skip = [ [tool.black] #line-length = 120 -target-versions = ["py38", "py39", "py310", "py311", "py312"] +target-version = ["py38", "py39", "py310", "py311", "py312"] force-exclude = """ elu_converter/setup.py """ diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index ea7700443c..efa593890e 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -3,11 +3,10 @@ import pytest import timm import torch +import torch_tensorrt as torchtrt import torchvision.models as models from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity -import torch_tensorrt as torchtrt - assertions = unittest.TestCase() @@ -43,8 +42,7 @@ def forward(self, x): } exp_program = torchtrt.dynamo.trace(model, **compile_spec) - trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) - trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") + trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec) torch.export.save(trt_exp_program, "/tmp/trt.ep") deser_trt_exp_program = torch.export.load("/tmp/trt.ep") @@ -95,8 +93,7 @@ def forward(self, x): } exp_program = torchtrt.dynamo.trace(model, **compile_spec) - trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) - trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") + trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec) torch.export.save(trt_exp_program, "/tmp/trt.ep") deser_trt_exp_program = torch.export.load("/tmp/trt.ep") # Check Pyt and TRT exported program outputs @@ -115,15 +112,15 @@ def forward(self, x): cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + msg=f"test_base_full_compile_multiple_outputs deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) @pytest.mark.unit -def test_base_full_compile_save_load(ir): +def test_no_compile(ir): """ - This tests export save and load functionality on a base model - with multiple outputs which is fully TRT convertible + This tests export serde functionality on a model + which won't convert to TRT because of min_block_size=5 constraint """ class MyModule(torch.nn.Module): @@ -148,31 +145,30 @@ def forward(self, x): ) ], "ir": ir, - "min_block_size": 1, + "debug": True, } exp_program = torchtrt.dynamo.trace(model, **compile_spec) - trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) - trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") + trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec) torch.export.save(trt_exp_program, "/tmp/trt.ep") deser_trt_exp_program = torch.export.load("/tmp/trt.ep") - + # Check Pyt and TRT exported program outputs outputs_pyt = model(input) outputs_trt = trt_exp_program(input) - # Check Pyt and TRT exported program outputs for idx in range(len(outputs_pyt)): cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx]) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + msg=f"test_no_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + # Check Pyt and deserialized TRT exported program outputs outputs_trt_deser = deser_trt_exp_program(input) for idx in range(len(outputs_pyt)): cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + msg=f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) @@ -211,8 +207,7 @@ def forward(self, x): } exp_program = torchtrt.dynamo.trace(model, **compile_spec) - trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) - trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") + trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec) torch.export.save(trt_exp_program, "/tmp/trt.ep") deser_trt_exp_program = torch.export.load("/tmp/trt.ep") @@ -222,7 +217,7 @@ def forward(self, x): cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx]) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + msg=f"test_hybrid_relu_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) outputs_trt_deser = deser_trt_exp_program(input) @@ -230,12 +225,12 @@ def forward(self, x): cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + msg=f"test_hybrid_relu_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) @pytest.mark.unit -def test_resnet18_save_load(ir): +def test_resnet18(ir): """ This tests export save and load functionality on Resnet18 model """ @@ -253,8 +248,7 @@ def test_resnet18_save_load(ir): } exp_program = torchtrt.dynamo.trace(model, **compile_spec) - trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) - trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program") + trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec) torch.export.save(trt_exp_program, "/tmp/trt.ep") deser_trt_exp_program = torch.export.load("/tmp/trt.ep") @@ -263,7 +257,7 @@ def test_resnet18_save_load(ir): cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0]) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) outputs_trt_deser = deser_trt_exp_program(input) @@ -271,61 +265,62 @@ def test_resnet18_save_load(ir): cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0]) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) -# Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341 -# @pytest.mark.unit -# def test_hybrid_conv_fallback(ir): -# """ -# This tests export save and load functionality on a hybrid -# model where a conv (a weighted layer) has been forced to fallback to Pytorch. -# """ - -# class MyModule(torch.nn.Module): -# def __init__(self): -# super().__init__() -# self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) -# self.relu = torch.nn.ReLU() - -# def forward(self, x): -# conv = self.conv(x) -# relu = self.relu(conv) -# mul = relu * 0.5 -# return mul - -# model = MyModule().eval().cuda() -# input = torch.randn((1, 3, 224, 224)).to("cuda") - -# compile_spec = { -# "inputs": [ -# torchtrt.Input( -# input.shape, dtype=torch.float, format=torch.contiguous_format -# ) -# ], -# "ir": ir, -# "min_block_size": 1, -# "torch_executed_ops": "torch.ops.aten.convolution.default", -# } - -# trt_exp_program = torchtrt.compile(model, **compile_spec) -# torch.export.save(trt_exp_program, "/tmp/trt.ep") -# deser_trt_exp_program = torch.export.load("/tmp/trt.ep") - -# outputs_pyt = model(input) -# outputs_trt = trt_exp_program(input) -# for idx in range(len(outputs_pyt)): -# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx]) -# assertions.assertTrue( -# cos_sim > COSINE_THRESHOLD, -# msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", -# ) - -# outputs_trt_deser = deser_trt_exp_program(input) -# for idx in range(len(outputs_pyt)): -# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) -# assertions.assertTrue( -# cos_sim > COSINE_THRESHOLD, -# msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", -# ) +@pytest.mark.unit +def test_hybrid_conv_fallback(ir): + """ + This tests export save and load functionality on a hybrid + model where a conv (a weighted layer) has been forced to fallback to Pytorch. + """ + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + conv = self.conv(x) + relu = self.relu(conv) + mul = relu * 0.5 + return mul + + model = MyModule().eval().cuda() + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "ir": ir, + "min_block_size": 1, + "torch_executed_ops": {"torch.ops.aten.convolution.default"}, + } + + exp_program = torchtrt.dynamo.trace(model, **compile_spec) + trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec) + torch.export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch.export.load("/tmp/trt.ep") + + outputs_pyt = model(input) + outputs_trt = trt_exp_program(input) + + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + outputs_trt_deser = deser_trt_exp_program(input) + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) diff --git a/tests/py/dynamo/models/test_output_format.py b/tests/py/dynamo/models/test_output_format.py new file mode 100644 index 0000000000..3d2e747ceb --- /dev/null +++ b/tests/py/dynamo/models/test_output_format.py @@ -0,0 +1,62 @@ +import unittest + +import pytest +import timm +import torch +import torch_tensorrt as torchtrt +import torchvision.models as models +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +assertions = unittest.TestCase() + + +@pytest.mark.unit +def test_output_format(ir): + """ + This tests output_format type in the compilation setting + """ + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + conv = self.conv(x) + relu = self.relu(conv) + mul = relu * 0.5 + return mul + + model = MyModule().eval().cuda() + input = torch.randn((1, 3, 224, 224)).to("cuda") + + trt_ep = torchtrt.compile(model, ir="dynamo", inputs=[input], min_block_size=1) + assertions.assertTrue( + isinstance(trt_ep, torch.export.ExportedProgram), + msg=f"test_output_format output type does not match with torch.export.ExportedProgram", + ) + + trt_ts = torchtrt.compile( + model, + ir="dynamo", + inputs=[input], + min_block_size=1, + output_format="torchscript", + ) + assertions.assertTrue( + isinstance(trt_ts, torch.jit.ScriptModule), + msg=f"test_output_format output type does not match with torch.jit.ScriptModule", + ) + + trt_gm = torchtrt.compile( + model, + ir="dynamo", + inputs=[input], + min_block_size=1, + output_format="graph_module", + ) + assertions.assertTrue( + isinstance(trt_gm, torch.fx.GraphModule), + msg=f"test_output_format output type does not match with torch.fx.GraphModule", + ) diff --git a/tests/py/dynamo/runtime/test_hw_compat.py b/tests/py/dynamo/runtime/test_hw_compat.py index 9ee7206adf..4218cc7de0 100644 --- a/tests/py/dynamo/runtime/test_hw_compat.py +++ b/tests/py/dynamo/runtime/test_hw_compat.py @@ -2,9 +2,8 @@ import unittest import torch -from torch.testing._internal.common_utils import TestCase, run_tests - import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests class TestHardwareCompatibility(TestCase): @@ -24,6 +23,7 @@ def forward(self, x): pass_through_build_failures=True, hardware_compatible=True, use_python_runtime=False, + output_format="graph_module", ) self.assertTrue(optimized_model_hw_compat._run_on_acc_0.hardware_compatible) @@ -41,6 +41,7 @@ def forward(self, x): pass_through_build_failures=True, hardware_compatible=False, use_python_runtime=False, + output_format="graph_module", ) self.assertFalse(