diff --git a/.circleci/config.yml b/.circleci/config.yml
index f1143f25b0..d1e36447d3 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -802,7 +802,7 @@ commands:
       - store_artifacts:
           path: /tmp/testlogs

-  test-dynamo-models_torch_export:
+  test-dynamo-models_export:
     description: "Test the Dynamo models via torch_export path"
     steps:
       - run:
@@ -818,6 +818,20 @@ commands:
       - store_artifacts:
           path: /tmp/testlogs

+  test-dynamo-export_serde:
+    description: "Test the export serialize/deserialize functionality for Dynamo models"
+    steps:
+      - run:
+          name: Run Dynamo models and test export serde with TRT compiled modules
+          command: |
+            cd tests/py/dynamo/models
+            pytest test_export_serde.py --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo
+
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
   test-dynamo-converters:
     description: "Test the Dynamo aten converters"
     steps:
@@ -1122,7 +1136,8 @@ jobs:
             - test-dynamo-backend
             - test-dynamo-shared_utilities
             - test-dynamo-models_torch_compile
-            - test-dynamo-models_torch_export
+            - test-dynamo-models_export
+            - test-dynamo-export_serde

   package-x86_64-linux:
     parameters:
diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml
index 7db6c19636..bfc19cce45 100644
--- a/.github/workflows/build-test.yml
+++ b/.github/workflows/build-test.yml
@@ -141,7 +141,8 @@ jobs:
           cd tests/py/dynamo
           ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest --use-deprecated=legacy-resolver
           ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
-          ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_dyn_models.py
+          ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
+          ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
           popd

   tests-py-torch-compile-be:
diff --git a/docsrc/index.rst b/docsrc/index.rst
index 18fb1185e8..9e98c7a63d 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -42,6 +42,7 @@ User Guide
 * :ref:`getting_started_with_fx`
 * :ref:`ptq`
 * :ref:`runtime`
+* :ref:`saving_models`
 * :ref:`dynamic_shapes`
 * :ref:`use_from_pytorch`
 * :ref:`using_dla`
@@ -55,6 +56,7 @@ User Guide
    user_guide/getting_started_with_fx_path
    user_guide/ptq
    user_guide/runtime
+   user_guide/saving_models
    user_guide/dynamic_shapes
    user_guide/use_from_pytorch
    user_guide/using_dla
diff --git a/docsrc/user_guide/saving_models.rst b/docsrc/user_guide/saving_models.rst
new file mode 100644
index 0000000000..46fadcb905
--- /dev/null
+++ b/docsrc/user_guide/saving_models.rst
@@ -0,0 +1,77 @@
+.. _saving_models:
+
+Saving models compiled with Torch-TensorRT
+==========================================
+
+Saving models compiled with Torch-TensorRT varies slightly depending on the `ir` that was used for compilation.
+
+1) Dynamo IR
+
+Starting with the 2.1 release of Torch-TensorRT, we are switching the default compilation path to be dynamo-based.
+The output of `ir=dynamo` compilation is a `torch.fx.GraphModule` object. There are two ways to save these objects:
+
+a) Converting to TorchScript
+`torch.fx.GraphModule` objects cannot be serialized directly. Hence we use `torch.jit.trace` to convert them into a `ScriptModule` object, which can be saved to disk.
+The following code illustrates this approach.
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+
+    model = MyModel().eval().cuda()
+    inputs = torch.randn((1, 3, 224, 224)).cuda()
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) # Output is a torch.fx.GraphModule
+    trt_script_model = torch.jit.trace(trt_gm, inputs)
+    torch.jit.save(trt_script_model, "trt_model.ts")
+
+    # Later, you can load it and run inference
+    model = torch.jit.load("trt_model.ts").cuda()
+    model(inputs)
+
+b) ExportedProgram
+`torch.export.ExportedProgram` is a new format introduced in PyTorch 2.1. After we compile a PyTorch module using Torch-TensorRT, the resulting
+`torch.fx.GraphModule`, along with additional metadata, can be used to create an `ExportedProgram`, which can be saved to and loaded from disk.
+Since `create_trt_exp_program` needs the `call_spec` of the traced program, this example uses the `torch_tensorrt.dynamo.trace` and `torch_tensorrt.dynamo.compile` APIs directly.
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+    from torch_tensorrt.dynamo.export import transform, create_trt_exp_program
+
+    model = MyModel().eval().cuda()
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+    exp_program = torch_tensorrt.dynamo.trace(model, inputs) # Output is a torch.export.ExportedProgram
+    trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs=inputs) # Output is a torch.fx.GraphModule
+    # Transform the graph module and create an exported program
+    trt_gm = transform(trt_gm, inputs)
+    trt_exp_program = create_trt_exp_program(trt_gm, exp_program.call_spec, trt_gm.state_dict())
+    torch._export.save(trt_exp_program, "trt_model.ep")
+
+    # Later, you can load it and run inference
+    model = torch._export.load("trt_model.ep")
+    model(*inputs)
+
+`torch_tensorrt.dynamo.export.transform` inlines the submodules within a GraphModule to their corresponding nodes and stitches all the nodes together.
+This is needed because `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
+
+NOTE: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue: https://github.com/pytorch/TensorRT/issues/2341
+
+2) TorchScript IR
+
+In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT is using TorchScript IR.
+This behavior stays the same in 2.X versions as well.
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+
+    model = MyModel().eval().cuda()
+    inputs = torch.randn((1, 3, 224, 224)).cuda()
+    trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object
+    torch.jit.save(trt_ts, "trt_model.ts")
+
+    # Later, you can load it and run inference
+    model = torch.jit.load("trt_model.ts").cuda()
+    model(inputs)
+
diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index 67bf6d523e..9b9f4c00a1 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -224,14 +224,14 @@ def compile(

         # Export the module
         torchtrt_inputs = prepare_inputs(input_list)
-        module = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
-        compiled_aten_module: torch.fx.GraphModule = dynamo_compile(
-            module,
+        exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
+        trt_graph_module = dynamo_compile(
+            exp_program,
             inputs=torchtrt_inputs,
             enabled_precisions=enabled_precisions_set,
             **kwargs,
         )
-        return compiled_aten_module
+        return trt_graph_module
     elif target_ir == _IRType.torch_compile:
         return torch_compile(
             module, enabled_precisions=enabled_precisions_set, **kwargs
diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py
index 8064ac0186..63cc2af10a 100644
--- a/py/torch_tensorrt/dynamo/__init__.py
+++ b/py/torch_tensorrt/dynamo/__init__.py
@@ -7,10 +7,12 @@
 logger = logging.getLogger(__name__)

 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-    from ._settings import *  # noqa: F403
-    from ._SourceIR import SourceIR  # noqa: F403
-    from .aten_tracer import trace  # noqa: F403
-    from .compile import compile  # noqa: F403
-    from .conversion import *  # noqa: F403
-    from .conversion.converter_registry import DYNAMO_CONVERTERS  # noqa: F403
-    from .conversion.converter_registry import dynamo_tensorrt_converter  # noqa: F403
+    from ._settings import *
+    from ._SourceIR import SourceIR
+    from .aten_tracer import trace
+    from .compile import compile
+    from .conversion import *
+    from .conversion.converter_registry import (
+        DYNAMO_CONVERTERS,
+        dynamo_tensorrt_converter,
+    )
diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/aten_tracer.py
index da346635a2..0ef47ff2ef 100644
--- a/py/torch_tensorrt/dynamo/aten_tracer.py
+++ b/py/torch_tensorrt/dynamo/aten_tracer.py
@@ -7,7 +7,10 @@
 import torch
 from torch._export import dynamic_dim, export
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo._defaults import default_device
+from torch_tensorrt.dynamo._defaults import (
+    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    default_device,
+)
 from torch_tensorrt.dynamo.lowering import get_decompositions
 from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device

@@ -75,14 +78,11 @@ def trace(
         trace_inputs.append(torch_inputs[idx])

     experimental_decompositions = kwargs.get(
-        "enable_experimental_decompositions", False
+        "enable_experimental_decompositions", ENABLE_EXPERIMENTAL_DECOMPOSITIONS
     )
     with unittest.mock.patch(
         "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions)
     ):
-        graph_module = export(
-            model, tuple(trace_inputs), constraints=constraints
-        ).module()
+        exp_program = export(model, tuple(trace_inputs), constraints=constraints)

-    logger.debug("Post export graph: " + str(graph_module.graph))
-    return graph_module
+    return exp_program
diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py index
0ef52edd43..5394c1382e 100644 --- a/py/torch_tensorrt/dynamo/compile.py +++ b/py/torch_tensorrt/dynamo/compile.py @@ -1,10 +1,12 @@ from __future__ import annotations +import collections.abc import logging from typing import Any, List, Optional, Sequence, Set, Tuple, Union import torch import torch_tensorrt +from torch.export import ExportedProgram from torch_tensorrt._Device import Device from torch_tensorrt._enums import ( # TODO: Should probabably be the TRT EngineCapability Enum EngineCapability, @@ -34,6 +36,7 @@ from torch_tensorrt.dynamo.lowering import apply_lowering_passes from torch_tensorrt.dynamo.utils import ( get_torch_inputs, + prepare_inputs, set_log_level, to_torch_device, to_torch_tensorrt_device, @@ -43,7 +46,7 @@ def compile( - gm: Any, + exported_program: ExportedProgram, inputs: Any, *, device: Optional[Union[Device, torch.device, str]] = DEVICE, @@ -76,24 +79,23 @@ def compile( if debug: set_log_level(logger.parent, logging.DEBUG) + if not isinstance(inputs, collections.abc.Sequence): + inputs = [inputs] + + # Prepare torch_trt inputs + inputs = prepare_inputs(inputs) + device = to_torch_tensorrt_device(device) + + gm = exported_program.module() + logger.debug("Input graph: " + str(gm.graph)) + # Apply lowering on the graph module torch_inputs = get_torch_inputs(inputs, device) gm = apply_lowering_passes(gm, torch_inputs) + logger.debug("Lowered Input graph: " + str(gm.graph)) enabled_precisions = set(enabled_precisions) - logger.warning( - "The Dynamo backend is an experimental feature, for which only the " - "following arguments are supported: " - "{enabled_precisions, debug, workspace_size, min_block_size, " - "max_aux_streams, version_compatible, optimization_level, " - "torch_executed_ops, pass_through_build_failures, " - "use_fast_partitioner, enable_experimental_decompositions, " - "require_full_compilation}" - ) - - device = to_torch_tensorrt_device(device) - if ( torch.float16 in enabled_precisions or torch_tensorrt.dtype.half in enabled_precisions @@ -207,12 +209,11 @@ def compile_module( # Iterate over all components that can be accelerated # Generate the corresponding TRT Module for those for name, _ in partitioned_module.named_children(): + submodule = getattr(partitioned_module, name) # Criteria for a module to be convertible to TRT if settings.use_fast_partitioner and "_run_on_acc" not in name: continue - submodule = getattr(partitioned_module, name) - # Get the submodule inputs for min, opt, max shapes of the graph inputs submodule_inputs = partitioning.get_submod_inputs( partitioned_module, @@ -239,19 +240,19 @@ def compile_module( name, ) - # Create TRT Module from submodule - trt_mod = convert_module( + # Create TRT engines from submodule + trt_module = convert_module( submodule, submodule_inputs, settings=settings, name=name, ) - trt_modules[name] = trt_mod + trt_modules[name] = trt_module # Replace all FX Modules with TRT Modules - for name, trt_mod in trt_modules.items(): - setattr(partitioned_module, name, trt_mod) + for name, trt_module in trt_modules.items(): + setattr(partitioned_module, name, trt_module) # Reset settings object to user specification after fallback to global partitioning mode if fast_partitioner_failed: diff --git a/py/torch_tensorrt/dynamo/export.py b/py/torch_tensorrt/dynamo/export.py new file mode 100644 index 0000000000..9bd1dbddb3 --- /dev/null +++ b/py/torch_tensorrt/dynamo/export.py @@ -0,0 +1,269 @@ +import copy +import operator +from typing import Any, Dict, Sequence, Tuple, Union, cast + +import torch +from 
torch._export.exported_program import CallSpec +from torch._guards import detect_fake_mode +from torch._subclasses.fake_tensor import FakeTensor +from torch.export import ExportedProgram, ExportGraphSignature +from torch_tensorrt.dynamo import partitioning + + +def transform( + gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor] +) -> torch.fx.GraphModule: + # Run shape analysis + _, outputs_map = partitioning.run_shape_analysis(gm, inputs) + + # Inline TensorRT submodules + inline_trt_modules(gm, outputs_map) + + # Inline pytorch submodules + inline_torch_modules(gm) + + # Lift constant buffers and parameters in the graph + # torch.export serialization expects them to be lifted + lift_constant_pass(gm) + + # Clean the graph + gm.delete_all_unused_submodules() + gm.graph.eliminate_dead_code() + gm.graph.lint() + + return gm + + +def lift_constant_pass(trt_gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + fake_mode = detect_fake_mode( + tuple( + node.meta["val"] for node in trt_gm.graph.nodes if node.op == "placeholder" + ) + ) + + first_user_input = None + for node in trt_gm.graph.nodes: + if node.op == "placeholder": + first_user_input = node + break + + for node in trt_gm.graph.nodes: + if node.op == "get_attr": + constant_tensor = getattr(trt_gm, node.target) + with trt_gm.graph.inserting_before(first_user_input): + const_placeholder_node = trt_gm.graph.placeholder(node.target) + const_placeholder_node.meta = copy.deepcopy(node.meta) + const_placeholder_node.meta["val"] = fake_mode.from_tensor( + constant_tensor + ) + node.replace_all_uses_with(const_placeholder_node) + trt_gm.graph.erase_node(node) + + trt_gm.graph.eliminate_dead_code() + trt_gm.graph.lint() + return trt_gm + + +def get_duplicate_nodes( + gm: torch.fx.GraphModule, submodule: torch.fx.GraphModule +) -> Tuple[Sequence[Any], Sequence[Any]]: + """ + We check if there are duplicate nodes when we copy submodule graph into gm. + Handle the case where the subgraph input placeholders are same as + gm placeholders. This happens when the first submodule in the graph is + a pytorch submodule + """ + submodule_placeholder_inputs = [ + node for node in submodule.graph.nodes if node.op == "placeholder" + ] + submodule_input_node_names = [node.name for node in submodule_placeholder_inputs] + gm_node_names = [node.name for node in gm.graph.nodes] + submodule_duplicate_inputs = [ + node for node in submodule_placeholder_inputs if node.name in gm_node_names + ] + gm_duplicate_inputs = [ + node for node in gm.graph.nodes if node.name in submodule_input_node_names + ] + return submodule_duplicate_inputs, gm_duplicate_inputs + + +def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Inline a submodule within the parent graph (gm). All `call_module` nodes + should be replaced by their submodule nodes. 
+ """ + # Clean the graph + gm.graph.eliminate_dead_code() + gm.graph.lint() + + for gm_node in gm.graph.nodes: + if gm_node.op == "call_module" and "_run_on_gpu" in gm_node.name: + submodule = getattr(gm, gm_node.name) + with gm.graph.inserting_before(gm_node): + # Get inputs of submodule node which are most likely outputs of a previous TRT node + # or a placeholder of the main graph + submodule_inputs = gm_node.args + + submodule_duplicate_inputs, gm_duplicate_inputs = get_duplicate_nodes( + gm, submodule + ) + assert len(submodule_duplicate_inputs) == len(gm_duplicate_inputs) + # Avoid creating new copies of duplicate inputs by creating a mapping + val_map = {} + for i in range(len(submodule_duplicate_inputs)): + val_map[submodule_duplicate_inputs[i]] = gm_duplicate_inputs[i] + + # Copy all nodes in the submodule into gm and + # store the output node of this submodule which is now present in gm + + submodule_output = gm.graph.graph_copy(submodule.graph, val_map) + + # Get their references (since we copied) in the parent graph (gm) + if len(submodule_duplicate_inputs) == 0: + submodule_placeholder_input_names = [ + node.name + for node in submodule.graph.nodes + if node.op == "placeholder" + ] + gm_added_placeholder_inputs = [ + node + for node in gm.graph.nodes + if node.name in submodule_placeholder_input_names + ] + + assert len(submodule_inputs) == len(gm_added_placeholder_inputs) + + # Replace the added placeholder inputs with original inputs to this submodule node + for idx in range(len(gm_added_placeholder_inputs)): + gm_added_placeholder_inputs[idx].replace_all_uses_with( + submodule_inputs[idx] + ) + + # Erase the placeholder input nodes in the gm + for idx in range(len(gm_added_placeholder_inputs)): + gm.graph.erase_node(gm_added_placeholder_inputs[idx]) + + # Replace the pytorch submodule node (call_module) with the inlined subgraph output + gm_node.replace_all_uses_with(submodule_output) + + # copy the attributes of the submodule into gm (graph_copy doesn't do this) + copy_submodule_attributes(submodule, gm, gm_node.name) + + # Erase the pytorch submodule (call_module) node + gm.graph.erase_node(gm_node) + + return gm + + +def copy_submodule_attributes( + submodule: torch.fx.GraphModule, gm: torch.fx.GraphModule, submod_name: str +) -> None: + """ + Copy the getattr attriibutes from submodule to parent module gm. + The graph_copy call doesn't do this for us unfortunately. + """ + for idx, param in enumerate(gm.named_parameters()): + if submod_name in param[0]: + attr_name = param[0].replace(submod_name + ".", "") + gm.register_parameter(attr_name, param[1]) + + for idx, buffer in enumerate(gm.named_buffers()): + if submod_name in buffer[0]: + attr_name = buffer[0].replace(submod_name + ".", "") + gm.register_buffer(attr_name, buffer[1]) + + +def create_trt_exp_program( + gm: torch.fx.GraphModule, + call_spec: CallSpec, + state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]], +) -> ExportedProgram: + """Creates a new Exported Program. 
This function takes an torch.fx.GraphModule which has TRT engines + and constructs an Exported Program object with the new IO node names, call_spec and state_dict + """ + input_node_names = [ + node.name for node in gm.graph.nodes if node.op == "placeholder" + ] + output_node_names = [node.name for node in gm.graph.nodes if node.op == "output"] + param_names = [param[0] for param in gm.named_parameters()] + buffer_names = [buffer[0] for buffer in gm.named_buffers()] + inputs_to_parameters = {} + inputs_to_buffers = {} + for node in gm.graph.nodes: + if node.target in param_names: + inputs_to_parameters[node.name] = node.target + if node.target in buffer_names: + inputs_to_buffers[node.name] = node.target + + trt_graph_signature = ExportGraphSignature( + parameters=param_names, + buffers=buffer_names, + user_inputs=input_node_names, + user_outputs=output_node_names, + inputs_to_parameters=inputs_to_parameters, + inputs_to_buffers=inputs_to_buffers, + buffers_to_mutate={}, + backward_signature=None, + assertion_dep_token=None, + ) + + trt_exp_program = ExportedProgram( + gm, gm.graph, trt_graph_signature, call_spec, state_dict, {}, [], [] + ) + + return trt_exp_program + + +def inline_trt_modules( + gm: torch.fx.GraphModule, outputs_map: Dict[Any, Sequence[Any]] +) -> torch.fx.GraphModule: + """ + Replace TRT submodules with trt engine nodes. + """ + for name, _ in gm.named_children(): + if "_run_on_acc" not in name: + continue + # Get the TRT submodule + trt_module = getattr(gm, name) + + # Ensure the trt module node in the main graph (gm) has inputs + trt_module_node = [node for node in gm.graph.nodes if node.name == name] + assert trt_module_node + trt_module_node = trt_module_node[0] + assert trt_module_node.args + + num_outputs = len(outputs_map[trt_module_node.name]) + # Insert a call_function node to perform inference on TRT engine + with gm.graph.inserting_before(trt_module_node): + trt_node = gm.graph.call_function( + torch.ops.tensorrt.execute_engine.default, + (trt_module_node.args, trt_module.engine), + ) + trt_node.meta["val"] = [] + # Generate meta data for TRT node (a FakeTensor with corresponding output shape) + for idx in range(num_outputs): + trt_node.meta["val"].append( + cast( + FakeTensor, + torch.empty_strided( + tuple(outputs_map[trt_module_node.name][idx]), + tuple([1] * len(outputs_map[trt_module_node.name][idx])), + ), + ) + ) + + if num_outputs == 1: + # Insert getitem nodes as outputs (for export serialization to work) + with gm.graph.inserting_after(trt_node): + getitem_output = gm.graph.call_function(operator.getitem, (trt_node, 0)) + trt_module_node.replace_all_uses_with(getitem_output) + else: + # Multiple outputs case: + # Replace uses of submodule with the trt_node. + # getitem nodes are already added inherently by the partitioner + trt_module_node.replace_all_uses_with(trt_node) + + # Erase the TRT submodule (call_module) node. 
+ gm.graph.erase_node(trt_module_node) + + return gm diff --git a/py/torch_tensorrt/dynamo/partitioning/__init__.py b/py/torch_tensorrt/dynamo/partitioning/__init__.py index 1f9d11b14b..1a8cc94099 100644 --- a/py/torch_tensorrt/dynamo/partitioning/__init__.py +++ b/py/torch_tensorrt/dynamo/partitioning/__init__.py @@ -1,3 +1,3 @@ from ._adjacency_partitioner import partition as fast_partition from ._global_partitioner import partition as global_partition -from .common import get_graph_converter_support, get_submod_inputs +from .common import get_graph_converter_support, get_submod_inputs, run_shape_analysis diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py index 14c068260f..c6eee22ab3 100644 --- a/py/torch_tensorrt/dynamo/partitioning/common.py +++ b/py/torch_tensorrt/dynamo/partitioning/common.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Optional, Sequence, Set, Tuple +from typing import Any, Dict, Optional, Sequence, Set, Tuple import torch from torch.fx.node import _get_qualified_name @@ -16,6 +16,43 @@ } +def run_shape_analysis( + parent_module: torch.fx.GraphModule, inputs: Sequence[Input] +) -> Tuple[Dict[Any, Sequence[Any]], Dict[Any, Sequence[Any]]]: + submod_inputs_shape_map: Dict[Any, Sequence[Any]] = {} + submod_outputs_shape_map: Dict[Any, Sequence[Any]] = {} + sub_inputs: Sequence[torch.Tensor] = [] + sub_outputs: Sequence[torch.Tensor] = [] + + # Register a hook to capture IO shapes for submodules + def get_submodule_io( + self: Any, inputs: Sequence[torch.Tensor], outputs: Sequence[torch.Tensor] + ) -> None: + nonlocal sub_inputs, sub_outputs + sub_inputs = inputs + sub_outputs = outputs + return + + # Iterate through submodules (both Torch and TRT) and store IO shapes + for name, _ in parent_module.named_children(): + submodule = getattr(parent_module, name) + handle = submodule.register_forward_hook(get_submodule_io) + parent_module(*inputs) + handle.remove() + submod_inputs_shape_map[name] = ( + [input.shape for input in sub_inputs] + if isinstance(sub_inputs, (tuple, list)) + else [sub_inputs.shape] + ) + submod_outputs_shape_map[name] = ( + [output.shape for output in sub_outputs] + if isinstance(sub_outputs, (tuple, list)) + else [sub_outputs.shape] + ) + + return submod_inputs_shape_map, submod_outputs_shape_map + + def get_submod_inputs( mod: torch.fx.GraphModule, submod: torch.fx.GraphModule, diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index c997250b5f..0d77f3f712 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -12,6 +12,7 @@ # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry from torch_tensorrt.dynamo.conversion import TRTInterpreter +from torch_tensorrt.dynamo.lowering import apply_lowering_passes from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule from torch_tensorrt.fx.passes.lower_basic_pass_aten import ( compose_bmm, diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py new file mode 100644 index 0000000000..5e0dc7406c --- /dev/null +++ b/tests/py/dynamo/models/test_export_serde.py @@ -0,0 +1,349 @@ +import unittest + +import pytest +import timm +import torch +import torch_tensorrt as torchtrt +import torchvision.models as models +from torch._export.serde.serialize import deserialize, serialize +from torch_tensorrt.dynamo.export import create_trt_exp_program, transform +from 
torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +assertions = unittest.TestCase() + + +@pytest.mark.unit +def test_base_full_compile(ir): + """ + This tests export serde functionality on a base model + which is fully TRT convertible + """ + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + + model = MyModule().eval().cuda() + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "ir": ir, + "min_block_size": 1, + } + + exp_program = torchtrt.dynamo.trace(model, **compile_spec) + trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) + trt_gm = transform(trt_gm, [input]) + trt_exp_program = create_trt_exp_program( + trt_gm, exp_program.call_spec, trt_gm.state_dict() + ) + serialized_prog = serialize(trt_exp_program) + deserialized_prog = deserialize(*serialized_prog) + + # Check Pyt and TRT exported program outputs + cos_sim = cosine_similarity(model(input), trt_exp_program(input)) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + # Check Pyt and deserialized TRT exported program outputs + cos_sim = cosine_similarity(model(input), deserialized_prog(input)) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_base_full_compile_multiple_outputs(ir): + """ + This tests export serde functionality on a base model + with multiple outputs which is fully TRT convertible + """ + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + conv = self.conv(x) + conv = conv * 0.5 + relu = self.relu(conv) + return conv, relu + + model = MyModule().eval().cuda() + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "ir": ir, + "min_block_size": 1, + } + + exp_program = torchtrt.dynamo.trace(model, **compile_spec) + trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) + trt_gm = transform(trt_gm, [input]) + trt_exp_program = create_trt_exp_program( + trt_gm, exp_program.call_spec, trt_gm.state_dict() + ) + + serialized_prog = serialize(trt_exp_program) + deserialized_prog = deserialize(*serialized_prog) + # Check Pyt and TRT exported program outputs + outputs_pyt = model(input) + outputs_trt = trt_exp_program(input) + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + # Check Pyt and deserialized TRT exported program outputs + outputs_trt_deser = deserialized_prog(input) + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_base_full_compile_save_load(ir): + """ + This tests export save and load functionality on a base model + with multiple outputs which is fully TRT convertible + """ + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + conv = self.conv(x) + conv = conv * 0.5 + relu = self.relu(conv) + return conv, relu + + model = MyModule().eval().cuda() + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "ir": ir, + "min_block_size": 1, + } + + exp_program = torchtrt.dynamo.trace(model, **compile_spec) + trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) + trt_gm = transform(trt_gm, [input]) + trt_exp_program = create_trt_exp_program( + trt_gm, exp_program.call_spec, trt_gm.state_dict() + ) + + torch._export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch._export.load("/tmp/trt.ep") + + outputs_pyt = model(input) + outputs_trt = trt_exp_program(input) + # Check Pyt and TRT exported program outputs + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + # Check Pyt and deserialized TRT exported program outputs + outputs_trt_deser = deser_trt_exp_program(input) + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_hybrid_relu_fallback(ir): + """ + This tests export save and load functionality on a hybrid + model with Pytorch and TRT segments. 
Relu (unweighted) layer is forced to + fallback + """ + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + conv = self.conv(x) + relu = self.relu(conv) + mul = relu * 0.5 + return mul + + model = MyModule().eval().cuda() + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "ir": ir, + "min_block_size": 1, + "torch_executed_ops": "torch.ops.aten.relu.default", + } + + exp_program = torchtrt.dynamo.trace(model, **compile_spec) + trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) + trt_gm = transform(trt_gm, [input]) + trt_exp_program = create_trt_exp_program( + trt_gm, exp_program.call_spec, trt_gm.state_dict() + ) + + torch._export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch._export.load("/tmp/trt.ep") + + outputs_pyt = model(input) + outputs_trt = trt_exp_program(input) + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + outputs_trt_deser = deser_trt_exp_program(input) + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +@pytest.mark.unit +def test_resnet18_save_load(ir): + """ + This tests export save and load functionality on Resnet18 model + """ + model = models.resnet18().eval().cuda() + input = torch.randn((1, 3, 224, 224)).to("cuda") + + compile_spec = { + "inputs": [ + torchtrt.Input( + input.shape, dtype=torch.float, format=torch.contiguous_format + ) + ], + "ir": ir, + "min_block_size": 1, + } + + exp_program = torchtrt.dynamo.trace(model, **compile_spec) + trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec) + trt_gm = transform(trt_gm, [input]) + trt_exp_program = create_trt_exp_program( + trt_gm, exp_program.call_spec, trt_gm.state_dict() + ) + torch._export.save(trt_exp_program, "/tmp/trt.ep") + deser_trt_exp_program = torch._export.load("/tmp/trt.ep") + + outputs_pyt = model(input) + outputs_trt = trt_exp_program(input) + cos_sim = cosine_similarity(outputs_pyt, outputs_trt) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + outputs_trt_deser = deser_trt_exp_program(input) + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + +# Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341 +# @pytest.mark.unit +# def test_hybrid_conv_fallback(ir): +# """ +# This tests export save and load functionality on a hybrid +# model where a conv (a weighted layer) has been forced to fallback to Pytorch. 
+# """ + +# class MyModule(torch.nn.Module): +# def __init__(self): +# super().__init__() +# self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) +# self.relu = torch.nn.ReLU() + +# def forward(self, x): +# conv = self.conv(x) +# relu = self.relu(conv) +# mul = relu * 0.5 +# return mul + +# model = MyModule().eval().cuda() +# input = torch.randn((1, 3, 224, 224)).to("cuda") + +# compile_spec = { +# "inputs": [ +# torchtrt.Input( +# input.shape, dtype=torch.float, format=torch.contiguous_format +# ) +# ], +# "ir": ir, +# "min_block_size": 1, +# "torch_executed_ops": "torch.ops.aten.convolution.default", +# } + +# trt_exp_program = torchtrt.compile(model, **compile_spec) +# torch._export.save(trt_exp_program, "/tmp/trt.ep") +# deser_trt_exp_program = torch._export.load("/tmp/trt.ep") + +# outputs_pyt = model(input) +# outputs_trt = trt_exp_program(input) +# for idx in range(len(outputs_pyt)): +# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx]) +# assertions.assertTrue( +# cos_sim > COSINE_THRESHOLD, +# msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", +# ) + +# outputs_trt_deser = deser_trt_exp_program(input) +# for idx in range(len(outputs_pyt)): +# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) +# assertions.assertTrue( +# cos_sim > COSINE_THRESHOLD, +# msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", +# )