chore: Set return type of compilation to ExportedProgram [release/2.2] #2607

Merged: 16 commits, merged on Jan 31, 2024 (changes from all commits)
49 changes: 31 additions & 18 deletions docsrc/user_guide/saving_models.rst
@@ -14,14 +14,18 @@ Saving models compiled with Torch-TensorRT varies slightly with the `ir` that ha
Dynamo IR
-------------

Starting with the 2.1 release of Torch-TensorRT, we are switching the default compilation to be dynamo based.
The output of `ir=dynamo` compilation is a `torch.fx.GraphModule` object. There are two ways to save these objects
The output of `ir=dynamo` compilation of Torch-TensorRT is a `torch.export.ExportedProgram` object by default.
In addition, we provide a new parameter `output_format` in the `CompilationSettings` object, which can be set before compilation.
The `output_format` can take the following options:

a) Converting to Torchscript
* `exported_program` (or `ep`): This is the default; returns an `ExportedProgram` object.
* `torchscript` (or `ts`): Returns a TorchScript module.
* `graph_module` (or `fx`): Returns a `torch.fx.GraphModule`, which can be traced into Torchscript to save to disk.

a) Torchscript
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

`torch.fx.GraphModule` objects cannot be serialized directly. Hence we use `torch.jit.trace` to convert this into a `ScriptModule` object which can be saved to disk.
The following code illustrates this approach.
If you set `output_format="torchscript"`, compilation returns a `ScriptModule` which can be serialized via `torch.jit.save`.

.. code-block:: python
@@ -30,9 +34,9 @@ The following code illustrates this approach.
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
trt_traced_model = torch.jit.trace(trt_gm, inputs)
torch.jit.save(trt_traced_model, "trt_model.ts")
# trt_ts is a torch.jit.ScriptModule object
trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="torchscript")
torch.jit.save(trt_ts, "trt_model.ts")
# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
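# Run inference with the loaded TorchScript module (mirrors the ExportedProgram example below)
model(*inputs)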
@@ -41,8 +45,7 @@ The following code illustrates this approach.
b) ExportedProgram
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

`torch.export.ExportedProgram` is a new format introduced in Pytorch 2.1. After we compile a Pytorch module using Torch-TensorRT, the resultant
`torch.fx.GraphModule` along with additional metadata can be used to create `ExportedProgram` which can be saved and loaded from disk.
`torch.export.ExportedProgram`, a new format introduced in Pytorch 2.X, is the default return type of Torch-TensorRT compilation.

.. code-block:: python
@@ -51,26 +54,36 @@ b) ExportedProgram
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
# Transform and create an exported program
trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
torch.export.save(trt_exp_program, "trt_model.ep")
# trt_ep is a torch.export.ExportedProgram object
trt_ep = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch.export.save(trt_ep, "trt_model.ep")
# Later, you can load it and run inference
model = torch.export.load("trt_model.ep")
model(*inputs)
`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes and stitches all the nodes together.
This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
c) GraphModule
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. note:: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue : https://github.com/pytorch/TensorRT/issues/2341
We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`.
Internally, the partitioning, lowering, and conversion phases operate on GraphModule objects. These can be either traced into Torchscript modules or
exported into `ExportedProgram` objects.

.. code-block:: python
import torch
import torch_tensorrt
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_gm is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="graph_module")
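
The returned `trt_gm` can then be persisted like any other GraphModule; a minimal sketch using tracing (reusing the `inputs` defined above):

.. code-block:: python

# Trace the compiled GraphModule into Torchscript and save it to disk
trt_traced = torch.jit.trace(trt_gm, inputs)
torch.jit.save(trt_traced, "trt_model.ts")
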
Torchscript IR
-------------

In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT was using the Torchscript IR.
This behavior stays the same in 2.X versions as well.
For `ir=ts`, this behavior stays the same in 2.X versions as well.

.. code-block:: python
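# A minimal sketch of the typical ir="ts" workflow, assuming the same MyModel and inputs
# as in the Dynamo examples above.
import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]

# Compile through the Torchscript frontend; the result is a torch.jit.ScriptModule
trt_ts_module = torch_tensorrt.compile(model, ir="ts", inputs=inputs)
torch.jit.save(trt_ts_module, "trt_model.ts")

# Later, load the serialized module and run inference
model = torch.jit.load("trt_model.ts").cuda()
model(*inputs)
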
4 changes: 3 additions & 1 deletion examples/int8/training/vgg16/vgg16.py
@@ -3,10 +3,12 @@
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""

from functools import reduce

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce


class VGG(nn.Module):
10 changes: 6 additions & 4 deletions py/torch_tensorrt/_Device.py
@@ -32,12 +32,14 @@ class Device(object):
allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
"""

device_type: Optional[
trt.DeviceType
] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
device_type: Optional[trt.DeviceType] = (
None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
)
gpu_id: int = -1 #: Device ID for target GPU
dla_core: int = -1 #: Core ID for target DLA core
allow_gpu_fallback: bool = False #: Whether falling back to GPU if DLA cannot support an op should be allowed
allow_gpu_fallback: bool = (
False #: Whether falling back to GPU if DLA cannot support an op should be allowed
)

def __init__(self, *args: Any, **kwargs: Any):
"""__init__ Method for torch_tensorrt.Device
12 changes: 6 additions & 6 deletions py/torch_tensorrt/_Input.py
@@ -28,12 +28,12 @@ class _ShapeMode(Enum):
STATIC = 0
DYNAMIC = 1

shape_mode: Optional[
_ShapeMode
] = None #: Is input statically or dynamically shaped
shape: Optional[
Tuple[int, ...] | Dict[str, Tuple[int, ...]]
] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
shape_mode: Optional[_ShapeMode] = (
None #: Is input statically or dynamically shaped
)
shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
)
dtype: _enums.dtype = (
_enums.dtype.unknown
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
15 changes: 11 additions & 4 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -27,6 +27,7 @@
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
OPTIMIZATION_LEVEL,
OUTPUT_FORMAT,
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
REFIT,
@@ -38,6 +39,7 @@
VERSION_COMPATIBLE,
WORKSPACE_SIZE,
)
from torch_tensorrt.dynamo._exporter import export
from torch_tensorrt.dynamo.conversion import (
CompilationSettings,
convert_module,
@@ -88,6 +90,7 @@ def compile(
use_python_runtime: bool = USE_PYTHON_RUNTIME,
use_fast_partitioner: bool = USE_FAST_PARTITIONER,
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
output_format: str = OUTPUT_FORMAT,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -144,6 +147,7 @@ def compile(
use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization
use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -200,9 +204,9 @@ def compile(
"device": device,
"workspace_size": workspace_size,
"min_block_size": min_block_size,
"torch_executed_ops": torch_executed_ops
if torch_executed_ops is not None
else set(),
"torch_executed_ops": (
torch_executed_ops if torch_executed_ops is not None else set()
),
"pass_through_build_failures": pass_through_build_failures,
"max_aux_streams": max_aux_streams,
"version_compatible": version_compatible,
@@ -219,11 +223,14 @@ def compile(
"dla_sram_size": dla_sram_size,
"dla_local_dram_size": dla_local_dram_size,
"dla_global_dram_size": dla_global_dram_size,
"output_format": output_format,
}

settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)
return compile_module(gm, inputs, settings)
trt_gm = compile_module(gm, inputs, settings)
trt_result = export(trt_gm, torch_inputs, output_format)
return trt_result


def compile_module(
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -24,6 +24,7 @@
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
REFIT = False
REQUIRE_FULL_COMPILATION = False
OUTPUT_FORMAT = "exported_program"


def default_device() -> Device:
138 changes: 88 additions & 50 deletions py/torch_tensorrt/dynamo/_exporter.py
@@ -1,4 +1,3 @@
import copy
import operator
from typing import Any, Dict, Sequence, Tuple, cast

@@ -19,50 +18,43 @@
def export(
gm: torch.fx.GraphModule,
inputs: Sequence[torch.Tensor],
*,
ir: str = "torchscript",
output_format: str,
) -> ExportedProgram:
"""Export a program (``torch.fx.GraphModule``) for serialization with the TensorRT engines embedded.
> Note: When ExportedProgram becomes stable, this function will get merged into ``torch_tensorrt.dynamo.compile``
"""Export the result of TensorRT compilation into the desired output format.
Arguments:
src_gm (torch.fx.GraphModule): Source module, generated by torch.export (The module provided to ``torch_tensorrt.dynamo.compile``)
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
Keyword Arguments:
inputs (Any): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
to select device type. ::
input=[
torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
torch_tensorrt.Input(
min_shape=(1, 224, 224, 3),
opt_shape=(1, 512, 512, 3),
max_shape=(1, 1024, 1024, 3),
dtype=torch.int32
format=torch.channel_last
), # Dynamic input shape for input #2
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
ir (str): torchscript | exported_program. Based on the provided ir, the output type would be a torchscript or exported program.
inputs (torch.Tensor): Torch input tensors
output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
"""
if ir == "torchscript":
if output_format == "torchscript" or output_format == "ts":
return torch.jit.trace(gm, inputs)
elif ir == "exported_program":
elif output_format == "exported_program" or output_format == "ep":
patched_module = transform(gm, inputs)
exp_program = create_trt_exp_program(patched_module)

return exp_program
elif output_format == "graph_module" or output_format == "fx":
return gm
else:
raise ValueError(
f"Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program"
f"Invalid output format {output_format} specified. Supported options include exported_program (or) ep | torchscript (or) ts | graph_module (or) fx"
)


def transform(
gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
"""
Transforms the graphmodule by inlining Pytorch and TensorRT submodules.
Inlining collapses submodules into nodes which is necessary for torch.export
serialization.
Arguments:
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
inputs (torch.Tensor): Torch input tensors
Returns an inlined torch.fx.GraphModule
"""
# Run shape analysis
_, outputs_map = partitioning.run_shape_analysis(gm, inputs)

@@ -72,10 +64,6 @@ def transform(
# Inline pytorch submodules
inline_torch_modules(gm)

# Lift constant buffers and parameters in the graph
# torch.export serialization expects them to be lifted
lift_constant_pass(gm)

# Clean the graph
gm.delete_all_unused_submodules()
gm.graph.eliminate_dead_code()
@@ -84,34 +72,80 @@ def transform(
return gm


def lift_constant_pass(trt_gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule:
"""
Given an unlifted fx.GraphModule, lift all parameters, buffers into placeholders.
Arguments:
gm (torch.fx.GraphModule): Unlifted GraphModule which contains parameters and buffers as get_attr nodes.
graph_signature (torch.export.ExportGraphSignature): Instance of ExportGraphSignature class created for the output ExportedProgram.
After lifting, this graph_signature will be modified with the parameters and buffers added appropriately.
Returns:
A lifted fx.GraphModule, modified graph_signature and a new state_dict
"""
# Get the state_dict of graph_module. This is different from exported_program.state_dict
# exp_program.state_dict contains parameters and buffers whereas a graph_module's state_dict
# has all parameters registered as torch.tensors.
state_dict = gm.state_dict()

fake_mode = detect_fake_mode(
tuple(
node.meta["val"] for node in trt_gm.graph.nodes if node.op == "placeholder"
)
tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder")
)
assert fake_mode is not None

# Locate the user input to insert new placeholders before them
first_user_input = None
for node in trt_gm.graph.nodes:
if node.op == "placeholder":
for node in gm.graph.nodes:
if node.op == "placeholder" and node.name in graph_signature.user_inputs:
first_user_input = node
break

for node in trt_gm.graph.nodes:
# At first the user_inputs are only present in the graph_signature.input_specs and hence non_user_input_idx=0
# The input_specs should be of the form [params, buffers, constant_tensors, user_inputs]
non_user_input_idx = 0
for node in gm.graph.nodes:
if node.op == "get_attr":
constant_tensor = getattr(trt_gm, node.target)
with trt_gm.graph.inserting_before(first_user_input):
const_placeholder_node = trt_gm.graph.placeholder(node.target)
const_placeholder_node.meta = copy.deepcopy(node.meta)
constant_tensor = getattr(gm, node.target)
input_kind = InputKind.CONSTANT_TENSOR

# state_dict has these parameters/buffers as torch.Tensors. We override them as torch.nn.Parameter/torch.Tensors respectively.
for name, _ in gm.named_parameters():
if node.target == name:
input_kind = InputKind.PARAMETER
state_dict[name] = constant_tensor
break
for name, _ in gm.named_buffers():
if node.target == name:
input_kind = InputKind.BUFFER
state_dict[name] = constant_tensor
break

# Replace get_attr nodes with placeholder nodes and copy metadata.
with gm.graph.inserting_before(first_user_input):
const_placeholder_node = gm.graph.placeholder(node.target)
for k, v in node.meta.items():
const_placeholder_node.meta[k] = v
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
constant_tensor
)
node.replace_all_uses_with(const_placeholder_node)
trt_gm.graph.erase_node(node)
gm.graph.erase_node(node)

# Add these parameters/buffers/constants to the existing graph signature
# before user inputs. These specs are looked up in the state_dict during ExportedProgram creation.
graph_signature.input_specs.insert(
non_user_input_idx,
InputSpec(
kind=input_kind,
arg=TensorArgument(name=const_placeholder_node.name),
target=node.target,
),
)
non_user_input_idx += 1

gm.graph.eliminate_dead_code()
gm.graph.lint()

trt_gm.graph.eliminate_dead_code()
trt_gm.graph.lint()
return trt_gm
return gm, graph_signature, state_dict


def get_duplicate_nodes(
@@ -140,7 +174,7 @@ def get_duplicate_nodes(
def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Inline a submodule within the parent graph (gm). All `call_module` nodes
should be replaced by their submodule nodes.
should be replaced by their nodes in the submodule.
"""
# Clean the graph
gm.graph.eliminate_dead_code()
@@ -165,7 +199,6 @@ def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:

# Copy all nodes in the submodule into gm and
# store the output node of this submodule which is now present in gm

submodule_output = gm.graph.graph_copy(submodule.graph, val_map)

# Get their references (since we copied) in the parent graph (gm)
@@ -227,6 +260,7 @@ def create_trt_exp_program(
"""Creates a new Exported Program. This function takes an torch.fx.GraphModule which has TRT engines
and constructs an Exported Program object with the new IO node names and state_dict
"""

input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
output_nodes = [node for node in gm.graph.nodes if node.op == "output"]
assert output_nodes
@@ -245,8 +279,12 @@ def create_trt_exp_program(
input_specs=input_specs, output_specs=output_specs
)

# Lift parameters/buffers/constants in the graph
# torch.export serialization expects them to be lifted
gm, trt_graph_signature, state_dict = lift(gm, trt_graph_signature)

trt_exp_program = ExportedProgram(
gm, gm.graph, trt_graph_signature, gm.state_dict(), {}, [], [], []
gm, gm.graph, trt_graph_signature, state_dict, {}, [], [], []
)

return trt_exp_program
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -17,6 +17,7 @@
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
OPTIMIZATION_LEVEL,
OUTPUT_FORMAT,
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
REFIT,
@@ -64,6 +65,7 @@ class CompilationSettings:
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
"""

precision: torch.dtype = PRECISION
@@ -89,3 +91,4 @@ class CompilationSettings:
dla_sram_size: int = DLA_SRAM_SIZE
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE
output_format: str = OUTPUT_FORMAT
12 changes: 6 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -28,9 +28,9 @@

_LOGGER: logging.Logger = logging.getLogger(__name__)

TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
Callable[[torch.fx.GraphModule], None]
] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
)


class UnsupportedOperatorException(RuntimeError):
@@ -92,9 +92,9 @@ def __init__(
self._cur_node: Optional[torch.fx.Node] = None
self._input_names: List[str] = []
self._output_names: List[str] = []
self._itensor_to_tensor_meta: Dict[
trt.tensorrt.ITensor, TensorMetadata
] = dict()
self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
dict()
)
self.compilation_settings = compilation_settings

# Data types for TRT Module output Tensors
6 changes: 2 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -324,13 +324,11 @@ def get_trt_tensor(


@overload
def get_positive_dim(dim: int, dim_size: int) -> int:
...
def get_positive_dim(dim: int, dim_size: int) -> int: ...


@overload
def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
...
def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


def get_positive_dim(
12 changes: 6 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -7,9 +7,9 @@

aten = torch.ops.aten

_core_aten_decompositions: Dict[
OpOverload, Callable[[Any], Any]
] = core_aten_decompositions()
_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
core_aten_decompositions()
)
torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
aten._adaptive_avg_pool2d_backward,
aten.addcdiv,
@@ -180,9 +180,9 @@
}


ENABLED_TORCH_DECOMPOSITIONS: Dict[
OpOverload, Callable[[Any], Any]
] = get_torch_decompositions(torch_enabled_decompositions)
ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
get_torch_decompositions(torch_enabled_decompositions)
)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


@@ -27,12 +27,10 @@ def lower_efficient_attention(
return gm


def efficient_attention_replacement() -> (
Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]
):
def efficient_attention_replacement() -> Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]:
"""Constructs the original and replacement functions for efficient attention"""

# Original graph
10 changes: 4 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py
@@ -22,12 +22,10 @@ def lower_linear(
return gm


def linear_replacement() -> (
Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]
):
def linear_replacement() -> Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]:
"""Constructs the original and replacement functions for linear"""

# Original graph
10 changes: 4 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -22,12 +22,10 @@ def view_to_reshape(
return gm


def view_replacement() -> (
Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
]
):
def view_replacement() -> Tuple[
torch.fx.GraphModule,
Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
]:
"""Constructs the original and replacement functions for view"""

# Original graph
57 changes: 37 additions & 20 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -6,6 +6,7 @@

import tensorrt as trt
import torch
import torch_tensorrt
from torch.nn import Module
from torch_tensorrt._Device import Device
from torch_tensorrt.dynamo.runtime.tools import (
@@ -15,8 +16,6 @@
)
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

import torch_tensorrt

logger = logging.getLogger(__name__)


@@ -101,9 +100,11 @@ def _initialize(self) -> None:
for idx in self.output_binding_indices_in_order
]
self.output_shapes = [
tuple(self.engine.get_binding_shape(idx))
if self.engine.has_implicit_batch_dimension
else tuple()
(
tuple(self.engine.get_binding_shape(idx))
if self.engine.has_implicit_batch_dimension
else tuple()
)
for idx in self.output_binding_indices_in_order
]
self.hidden_output_dtypes = [
@@ -113,9 +114,11 @@ def _initialize(self) -> None:
for idx in self.hidden_output_binding_indices_in_order
]
self.hidden_output_shapes = [
tuple(self.engine.get_binding_shape(idx))
if self.engine.has_implicit_batch_dimension
else tuple()
(
tuple(self.engine.get_binding_shape(idx))
if self.engine.has_implicit_batch_dimension
else tuple()
)
for idx in self.hidden_output_binding_indices_in_order
]

@@ -167,9 +170,11 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
self.context = self.engine.create_execution_context()

def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
with torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:Forward"
) if self.profiling_enabled else nullcontext():
with (
torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
if self.profiling_enabled
else nullcontext()
):
self._check_initialized()

# If in safe mode, check at each iteration for for whether a switch is required
@@ -200,9 +205,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
inputs = tuple([tensor.to(device) for tensor in inputs])
logger.warning(f"Moved all input Tensors to cuda:{device_id}")

with torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:ProcessInputs"
) if self.profiling_enabled else nullcontext():
with (
torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:ProcessInputs"
)
if self.profiling_enabled
else nullcontext()
):
assert len(inputs) == len(
self.input_names
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
@@ -239,9 +248,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
idx, tuple(contiguous_inputs[i].shape)
)

with torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:ProcessOutputs"
) if self.profiling_enabled else nullcontext():
with (
torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:ProcessOutputs"
)
if self.profiling_enabled
else nullcontext()
):
# create output tensors
outputs: List[torch.Tensor] = []

@@ -266,9 +279,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
)
bindings[idx] = output.data_ptr()

with torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:TensorRTRuntime"
) if self.profiling_enabled else nullcontext():
with (
torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:TensorRTRuntime"
)
if self.profiling_enabled
else nullcontext()
):
self.context.execute_async_v2(
bindings, torch.cuda.current_stream().cuda_stream
)
38 changes: 16 additions & 22 deletions py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -3,24 +3,22 @@
import math
import operator
import warnings
from typing import cast, Dict, Optional, Sequence, Tuple, Union
from typing import Dict, Optional, Sequence, Tuple, Union, cast

import numpy as np

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
from torch.fx.immutable_collections import immutable_list
from torch.fx.node import Argument, Target
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.converters.impl import activation, convolution

from ..converter_registry import tensorrt_converter

from ..types import * # noqa: F403
from torch.fx.immutable_collections import immutable_list
from torch.fx.node import Argument, Target

from .converter_utils import * # noqa: F403
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
from torch_tensorrt.fx.converters.impl import activation, convolution

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -317,21 +315,17 @@ def aten_ops_max_poolnd(
kwargs_new = {
"input": args[0],
"kernel_size": args[1],
"stride": args[2]
if len(args) > 2
else (None, None)
if len(args[1]) == 2
else (None, None, None),
"padding": args[3]
if len(args) > 3
else (0, 0)
if len(args[1]) == 2
else (0, 0, 0),
"dilation": args[4]
if len(args) > 4
else (1, 1)
if len(args[1]) == 2
else (1, 1, 1),
"stride": (
args[2]
if len(args) > 2
else (None, None) if len(args[1]) == 2 else (None, None, None)
),
"padding": (
args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
),
"dilation": (
args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
),
"ceil_mode": args[5] if len(args) > 5 else False,
}
return acc_ops_converters.acc_ops_max_poolnd(
14 changes: 7 additions & 7 deletions py/torch_tensorrt/fx/fx2trt.py
@@ -17,13 +17,13 @@
from .converter_registry import CONVERTERS
from .input_tensor_spec import InputTensorSpec
from .observer import Observer
from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks
from .utils import Frameworks, LowerPrecision, get_dynamic_dims, unified_dtype_converter

_LOGGER: logging.Logger = logging.getLogger(__name__)

TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
Callable[[torch.fx.GraphModule], None]
] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
)


class TRTInterpreterResult(NamedTuple):
@@ -75,9 +75,9 @@ def __init__(
self._cur_node_name: Optional[str] = None
self._input_names: List[str] = []
self._output_names: List[str] = []
self._itensor_to_tensor_meta: Dict[
trt.tensorrt.ITensor, TensorMetadata
] = dict()
self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
dict()
)

def validate_input_specs(self):
for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
23 changes: 12 additions & 11 deletions py/torch_tensorrt/fx/lower.py
@@ -16,7 +16,6 @@
from .passes.pass_utils import PassFunc, validate_inference
from .tools.timing_cache_utils import TimingCacheManager
from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting

from .tracer.acc_tracer import acc_tracer
from .trt_module import TRTModule
from .utils import LowerPrecision
@@ -126,9 +125,11 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
input_specs=self.lower_setting.input_specs,
explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
explicit_precision=self.lower_setting.explicit_precision,
logger_level=trt.Logger.VERBOSE
if self.lower_setting.verbose_log
else trt.Logger.WARNING,
logger_level=(
trt.Logger.VERBOSE
if self.lower_setting.verbose_log
else trt.Logger.WARNING
),
)

interp_result: TRTInterpreterResult = interpreter.run(
@@ -138,9 +139,11 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
strict_type_constraints=self.lower_setting.strict_type_constraints,
algorithm_selector=algo_selector,
timing_cache=cache_data,
profiling_verbosity=trt.ProfilingVerbosity.DETAILED
if self.lower_setting.verbose_profile
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
profiling_verbosity=(
trt.ProfilingVerbosity.DETAILED
if self.lower_setting.verbose_profile
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
),
tactic_sources=self.lower_setting.tactic_sources,
)

@@ -297,10 +300,8 @@ def do_lower(module: nn.Module, inputs: Input) -> nn.Module:
# handle inputs with custom types. By default, just handle
# tensors and NoneType.
if fp16_conversion_fn is None:
conversion_fn = (
lambda x: x.half()
if x is not None and x.dtype == torch.float32
else x
conversion_fn = lambda x: (
x.half() if x is not None and x.dtype == torch.float32 else x
)
else:
conversion_fn = fp16_conversion_fn
7 changes: 3 additions & 4 deletions py/torch_tensorrt/fx/passes/lower_basic_pass.py
@@ -11,7 +11,6 @@
from torch.fx.experimental.const_fold import split_const_subgraphs

from ..observer import observable

from ..tracer.acc_tracer import acc_ops
from ..tracer.acc_tracer.acc_utils import get_attr
from .pass_utils import log_before_after, validate_inference
@@ -538,9 +537,9 @@ def get_reshape_batch_size_inferred_source(
)
if not reshape_batch_size:
continue
reshape_batch_size_inferred_source: Optional[
fx.Node
] = get_reshape_batch_size_inferred_source(reshape_batch_size)
reshape_batch_size_inferred_source: Optional[fx.Node] = (
get_reshape_batch_size_inferred_source(reshape_batch_size)
)
if not reshape_batch_size_inferred_source:
continue

23 changes: 12 additions & 11 deletions py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -5,19 +5,17 @@

import torch
from torch import nn
from torch.fx.passes.pass_manager import inplace_wrapper, PassManager
from torch.fx.passes.pass_manager import PassManager, inplace_wrapper
from torch.fx.passes.shape_prop import ShapeProp
from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
from torch.fx.passes.splitter_base import SplitResult, generate_inputs_for_submodules
from torch_tensorrt.fx.passes.pass_utils import apply_bfloat_float_conversion
from torch_tensorrt.fx.utils import LowerPrecision

from ..input_tensor_spec import generate_input_specs

from ..lower_setting import LowerSetting
from ..observer import Observer
from ..passes.remove_duplicate_output_args import remove_duplicate_output_args
from .graph_opts import common_subexpression_elimination

from .lower_basic_pass import ( # noqa
fix_clamp_numerical_limits_to_fp16,
fix_reshape_batch_dim,
@@ -26,7 +24,6 @@
run_const_fold,
)


_LOGGER: logging.Logger = logging.getLogger(__name__)


@@ -196,9 +193,11 @@ def lower_func(split_result: SplitResult) -> nn.Module:
self.lower_setting.input_specs = generate_input_specs(
submod_inputs,
self.lower_setting,
additional_submodule_inputs[submod_name]
if additional_submodule_inputs
else None,
(
additional_submodule_inputs[submod_name]
if additional_submodule_inputs
else None
),
)
lowered_module = self._lower_func(
submod, submod_inputs, self.lower_setting, submod_name
@@ -236,9 +235,11 @@ def lower_func(split_result: SplitResult) -> nn.Module:
lowering_start_time = datetime.datetime.now()

self.lower_setting.additional_inputs = (
additional_submodule_inputs[submod_name]
if additional_submodule_inputs
else None,
(
additional_submodule_inputs[submod_name]
if additional_submodule_inputs
else None
),
)

lowered_module = self._lower_func(
4 changes: 1 addition & 3 deletions py/torch_tensorrt/fx/passes/pass_utils.py
@@ -195,9 +195,7 @@ def pass_with_validation(
kwargs2["rtol"] = rtol
if atol:
kwargs2["atol"] = atol
kwargs2[
"msg"
] = (
kwargs2["msg"] = (
lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
)
# If tensors are on different devices, make sure to compare
16 changes: 10 additions & 6 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_split.py
@@ -23,9 +23,11 @@ def forward(self, x):
Split(),
inputs,
expected_ops={
acc_ops.split
if isinstance(split_size_or_sections, int)
else acc_ops.slice_tensor
(
acc_ops.split
if isinstance(split_size_or_sections, int)
else acc_ops.slice_tensor
)
},
test_explicit_batch_dim=False,
)
@@ -70,9 +72,11 @@ def forward(self, x):
Split(),
input_specs,
expected_ops={
acc_ops.split
if isinstance(split_size_or_sections, int)
else acc_ops.slice_tensor
(
acc_ops.split
if isinstance(split_size_or_sections, int)
else acc_ops.slice_tensor
)
},
)

7 changes: 3 additions & 4 deletions py/torch_tensorrt/fx/tools/common_fx2trt.py
@@ -7,7 +7,6 @@
import tensorrt as trt
import torch
import torch.fx

import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
from torch.fx.experimental.normalize import NormalizeArgs
@@ -154,9 +153,9 @@ def run_test_custom_compare_results(
self.assert_has_op(mod, expected_ops)

interpreter_result = interpreter.run(
lower_precision=LowerPrecision.FP16
if fp16_mode
else LowerPrecision.FP32
lower_precision=(
LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
)
)
trt_mod = TRTModule(
interpreter_result.engine,
18 changes: 11 additions & 7 deletions py/torch_tensorrt/fx/trt_module.py
@@ -4,7 +4,7 @@
import tensorrt as trt
import torch

from .utils import unified_dtype_converter, Frameworks
from .utils import Frameworks, unified_dtype_converter


class TRTModule(torch.nn.Module):
@@ -69,9 +69,11 @@ def _initialize(self):
for idx in self.output_binding_indices_in_order
]
self.output_shapes = [
tuple(self.engine.get_binding_shape(idx))
if self.engine.has_implicit_batch_dimension
else tuple()
(
tuple(self.engine.get_binding_shape(idx))
if self.engine.has_implicit_batch_dimension
else tuple()
)
for idx in self.output_binding_indices_in_order
]
self.hidden_output_dtypes: Sequence[torch.dtype] = [
@@ -81,9 +83,11 @@ def _initialize(self):
for idx in self.hidden_output_binding_indices_in_order
]
self.hidden_output_shapes = [
tuple(self.engine.get_binding_shape(idx))
if self.engine.has_implicit_batch_dimension
else tuple()
(
tuple(self.engine.get_binding_shape(idx))
if self.engine.has_implicit_batch_dimension
else tuple()
)
for idx in self.hidden_output_binding_indices_in_order
]

9 changes: 4 additions & 5 deletions py/torch_tensorrt/ts/_compile_spec.py
@@ -3,6 +3,7 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Set

import tensorrt as trt
import torch
import torch_tensorrt._C.ts as _ts_C
from torch_tensorrt import _C, _enums
@@ -11,8 +12,6 @@
from torch_tensorrt.logging import Level, log
from torch_tensorrt.ts._Input import TorchScriptInput

import tensorrt as trt


def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input:
clone = torch.classes.tensorrt._Input()
@@ -406,9 +405,9 @@ def TensorRTCompileSpec(
"device": device,
"disable_tf32": disable_tf32, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
"sparse_weights": sparse_weights, # Enable sparsity for convolution and fully connected layers.
"enabled_precisions": enabled_precisions
if enabled_precisions is not None
else set(), # Enabling FP16 kernels
"enabled_precisions": (
enabled_precisions if enabled_precisions is not None else set()
), # Enabling FP16 kernels
"refit": refit, # enable refit
"debug": debug, # enable debuggable engine
"capability": capability, # Restrict kernel selection to safe gpu kernels or safe dla kernels
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -177,7 +177,7 @@ skip = [

[tool.black]
#line-length = 120
target-versions = ["py38", "py39", "py310", "py311", "py312"]
target-version = ["py38", "py39", "py310", "py311"]
force-exclude = """
elu_converter/setup.py
"""
152 changes: 74 additions & 78 deletions tests/py/dynamo/models/test_export_serde.py
@@ -42,8 +42,7 @@ def forward(self, x):
}

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
torch.export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")

@@ -94,8 +93,7 @@ def forward(self, x):
}

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
torch.export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
# Check Pyt and TRT exported program outputs
@@ -114,15 +112,15 @@ def forward(self, x):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
msg=f"test_base_full_compile_multiple_outputs deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
def test_base_full_compile_save_load(ir):
def test_no_compile(ir):
"""
This tests export save and load functionality on a base model
with multiple outputs which is fully TRT convertible
This tests export serde functionality on a model
which won't convert to TRT because of min_block_size=5 constraint
"""

class MyModule(torch.nn.Module):
@@ -147,31 +145,30 @@ def forward(self, x):
)
],
"ir": ir,
"min_block_size": 1,
"debug": True,
}

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
torch.export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")

# Check Pyt and TRT exported program outputs
outputs_pyt = model(input)
outputs_trt = trt_exp_program(input)
# Check Pyt and TRT exported program outputs
for idx in range(len(outputs_pyt)):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
msg=f"test_no_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Check Pyt and deserialized TRT exported program outputs
outputs_trt_deser = deser_trt_exp_program(input)
for idx in range(len(outputs_pyt)):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
msg=f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@@ -210,8 +207,7 @@ def forward(self, x):
}

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
torch.export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")

@@ -221,20 +217,20 @@ def forward(self, x):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
msg=f"test_hybrid_relu_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

outputs_trt_deser = deser_trt_exp_program(input)
for idx in range(len(outputs_pyt)):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
msg=f"test_hybrid_relu_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
def test_resnet18_save_load(ir):
def test_resnet18(ir):
"""
This tests export save and load functionality on Resnet18 model
"""
@@ -252,8 +248,7 @@ def test_resnet18_save_load(ir):
}

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_gm = torchtrt.dynamo.compile(exp_program, **compile_spec)
trt_exp_program = torchtrt.dynamo.export(trt_gm, [input], ir="exported_program")
trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
torch.export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")

@@ -262,69 +257,70 @@ def test_resnet18_save_load(ir):
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

outputs_trt_deser = deser_trt_exp_program(input)

cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


# Enable this test once this issue is resolved https://github.com/pytorch/TensorRT/issues/2341
# @pytest.mark.unit
# def test_hybrid_conv_fallback(ir):
# """
# This tests export save and load functionality on a hybrid
# model where a conv (a weighted layer) has been forced to fallback to Pytorch.
# """

# class MyModule(torch.nn.Module):
# def __init__(self):
# super().__init__()
# self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
# self.relu = torch.nn.ReLU()

# def forward(self, x):
# conv = self.conv(x)
# relu = self.relu(conv)
# mul = relu * 0.5
# return mul

# model = MyModule().eval().cuda()
# input = torch.randn((1, 3, 224, 224)).to("cuda")

# compile_spec = {
# "inputs": [
# torchtrt.Input(
# input.shape, dtype=torch.float, format=torch.contiguous_format
# )
# ],
# "ir": ir,
# "min_block_size": 1,
# "torch_executed_ops": "torch.ops.aten.convolution.default",
# }

# trt_exp_program = torchtrt.compile(model, **compile_spec)
# torch.export.save(trt_exp_program, "/tmp/trt.ep")
# deser_trt_exp_program = torch.export.load("/tmp/trt.ep")

# outputs_pyt = model(input)
# outputs_trt = trt_exp_program(input)
# for idx in range(len(outputs_pyt)):
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )

# outputs_trt_deser = deser_trt_exp_program(input)
# for idx in range(len(outputs_pyt)):
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_base_full_compile_save_load TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )
@pytest.mark.unit
def test_hybrid_conv_fallback(ir):
"""
This tests export save and load functionality on a hybrid
model where a conv (a weighted layer) has been forced to fallback to Pytorch.
"""

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
conv = self.conv(x)
relu = self.relu(conv)
mul = relu * 0.5
return mul

model = MyModule().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

compile_spec = {
"inputs": [
torchtrt.Input(
input.shape, dtype=torch.float, format=torch.contiguous_format
)
],
"ir": ir,
"min_block_size": 1,
"torch_executed_ops": {"torch.ops.aten.convolution.default"},
}

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
torch.export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")

outputs_pyt = model(input)
outputs_trt = trt_exp_program(input)

for idx in range(len(outputs_pyt)):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

outputs_trt_deser = deser_trt_exp_program(input)
for idx in range(len(outputs_pyt)):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)
60 changes: 60 additions & 0 deletions tests/py/dynamo/models/test_output_format.py
@@ -0,0 +1,60 @@
import unittest

import pytest
import torch
import torch_tensorrt as torchtrt

assertions = unittest.TestCase()


@pytest.mark.unit
def test_output_format(ir):
"""
This tests output_format type in the compilation setting
"""

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
out = self.conv(x)
out = self.relu(out)
return out

model = MyModule().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

trt_ep = torchtrt.compile(model, ir=ir, inputs=[input], min_block_size=1)
assertions.assertTrue(
isinstance(trt_ep, torch.export.ExportedProgram),
msg=f"test_output_format output type does not match with torch.export.ExportedProgram",
)

trt_ts = torchtrt.compile(
model,
ir=ir,
inputs=[input],
min_block_size=1,
output_format="torchscript",
)
assertions.assertTrue(
isinstance(trt_ts, torch.jit.ScriptModule),
msg=f"test_output_format output type does not match with torch.jit.ScriptModule",
)

trt_gm = torchtrt.compile(
model,
ir=ir,
inputs=[input],
min_block_size=1,
output_format="graph_module",
)
assertions.assertTrue(
isinstance(trt_gm, torch.fx.GraphModule),
msg=f"test_output_format output type does not match with torch.fx.GraphModule",
)
# Clean up model env
torch._dynamo.reset()