
chore: fix docs for export [release/2.1] #2448


Merged · 1 commit · Nov 9, 2023
7 changes: 2 additions & 5 deletions docsrc/dynamo/dynamo_export.rst
@@ -1,6 +1,6 @@
.. _dynamo_export:

Compiling ``ExportedPrograms`` with Torch-TensorRT
Compiling Exported Programs with Torch-TensorRT
===============================================
.. currentmodule:: torch_tensorrt.dynamo

@@ -9,8 +9,6 @@ Compiling ``ExportedPrograms`` with Torch-TensorRT
:undoc-members:
:show-inheritance:

Using the Torch-TensorRT Frontend for ``torch.export.ExportedPrograms``
--------------------------------------------------------
PyTorch 2.1 introduced ``torch.export`` APIs which
can export graphs from PyTorch programs into ``ExportedProgram`` objects. The Torch-TensorRT dynamo
frontend compiles these ``ExportedProgram`` objects and optimizes them using TensorRT. Here's a simple
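The example referenced above is collapsed in this diff view. A minimal sketch of the workflow, assuming ``MyModel`` is an ordinary ``nn.Module`` and mirroring the ``torch_tensorrt.dynamo`` APIs used later in this PR, might look like:

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    # Capture the model as an ExportedProgram, then compile it with TensorRT
    exp_program = torch_tensorrt.dynamo.trace(model, inputs)
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)  # torch.fx.GraphModule

    trt_gm(*inputs)  # run inference with the optimized module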
@@ -43,8 +41,7 @@ Some of the frequently used options are as follows:

The complete list of options can be found `here <https://github.com/pytorch/TensorRT/blob/123a486d6644a5bbeeec33e2f32257349acc0b8f/py/torch_tensorrt/dynamo/compile.py#L51-L77>`_.
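As a hedged illustration of the frequently used options mentioned in the hunk header above, reusing the ``exp_program`` and ``inputs`` from the earlier sketch (option names such as ``enabled_precisions``, ``min_block_size``, and ``debug`` are assumptions drawn from the linked signature and may differ across releases):

.. code-block:: python

    import torch
    import torch_tensorrt

    trt_gm = torch_tensorrt.dynamo.compile(
        exp_program,
        inputs=inputs,
        enabled_precisions={torch.float16},  # precisions TensorRT may select from
        min_block_size=5,                    # minimum ops required to form a TRT subgraph
        debug=True,                          # verbose compilation logging
    )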

.. note:: We do not support INT precision currently in Dynamo. Support for this currently exists in
our Torchscript IR. We plan to implement similar support for dynamo in our next release.
.. note:: We do not currently support INT precision in Dynamo. Support for this exists in our Torchscript IR, and we plan to implement similar support for Dynamo in our next release.

Under the hood
--------------
82 changes: 0 additions & 82 deletions docsrc/user_guide/dynamo_export.rst

This file was deleted.

41 changes: 21 additions & 20 deletions docsrc/user_guide/saving_models.rst
@@ -29,14 +29,14 @@ The following code illustrates this approach.
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = torch.randn((1, 3, 224, 224)).cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) # Output is a torch.fx.GraphModule
trt_traced_model = torchtrt.dynamo.serialize(trt_gm, inputs)
trt_traced_model = torch.jit.trace(trt_gm, inputs)
torch.jit.save(trt_traced_model, "trt_model.ts")

# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
model(inputs)
model(*inputs)

b) ExportedProgram
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -50,39 +50,40 @@ b) ExportedProgram
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = torch.randn((1, 3, 224, 224)).cuda()
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) # Output is a torch.fx.GraphModule
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
exp_program = torch_tensorrt.dynamo.trace(model, inputs)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs) # Output is a torch.fx.GraphModule
# Transform and create an exported program
trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs, call_spec, ir="exported_program")
trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs, exp_program.call_spec, ir="exported_program")
torch.export.save(trt_exp_program, "trt_model.ep")

# Later, you can load it and run inference
model = torch.export.load("trt_model.ep")
model(inputs)
model(*inputs)

`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule into their corresponding nodes, stitches all the nodes together, and creates an ExportedProgram.
This is needed because `torch.export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
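To make the constraint concrete: a ``call_module`` node is any node in the ``torch.fx`` graph whose ``op`` is ``"call_module"``. A small, hypothetical helper (not part of the actual exporter) that checks whether a GraphModule still needs inlining before export could look like:

.. code-block:: python

    import torch.fx

    def needs_inlining(gm: torch.fx.GraphModule) -> bool:
        # torch.export serialization cannot round-trip call_module nodes,
        # so they must be inlined into the parent graph before an
        # ExportedProgram is created.
        return any(node.op == "call_module" for node in gm.graph.nodes)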

NOTE: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue : https://github.com/pytorch/TensorRT/issues/2341
.. note:: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue: https://github.com/pytorch/TensorRT/issues/2341


Torchscript IR
--------------

In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT is using Torchscript IR.
This behavior stays the same in 2.X versions as well.
In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT was using the Torchscript IR.
This behavior stays the same in 2.X versions as well.

.. code-block:: python
.. code-block:: python

import torch
import torch_tensorrt
import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = torch.randn((1, 3, 224, 224)).cuda()
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object
torch.jit.save(trt_ts, "trt_model.ts")
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object
torch.jit.save(trt_ts, "trt_model.ts")

# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
model(inputs)
# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
model(*inputs)

2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_exporter.py
@@ -53,7 +53,7 @@ def export(
return exp_program
else:
raise ValueError(
"Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program"
f"Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program"
)
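For reference, a hedged usage sketch of the two accepted ``ir`` values (names and argument order follow the ``saving_models.rst`` example above; the corrected f-string now interpolates the offending value into the message):

.. code-block:: python

    # Serialize as an ExportedProgram (torch.export format)
    trt_exp_program = torch_tensorrt.dynamo.export(
        trt_gm, inputs, exp_program.call_spec, ir="exported_program"
    )

    # ir="torchscript" is the other accepted option per the error message;
    # any other value, e.g. ir="onnx", raises:
    #   ValueError: Invalid ir : onnx provided for serialization.
    #   Options include torchscript | exported_program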

