Merged
feat: Implement support for exporting Torch-TensorRT compiled graphs using torch.export serde APIs #2249
Changes from all commits (62 commits)
cc42ca3 (peri044) feat: Express TRT engines as nodes instead of modules
afcd5ec (peri044) chore: Fix input nodes to TRT graph
58dcc4f (peri044) chore: prototype
a57f3c0 (peri044) chore: minor change
f1f202e (peri044) feat: Move tracing to use aot export apis
abaf047 (peri044) chore: minor changes
370099f (peri044) chore: minor change
bb1f3cf (peri044) chore: minor changes
3d05b4d (peri044) chore: Rebase with main
8d99be5 (peri044) chore: rebase
0aad214 (peri044) chore: minor logging updates
8899735 (gs-olive) feat: Add preliminary support for freezing tensors in Dynamo
8af2627 (gs-olive) fix: Refactor tensor freezing in Dynamo
f6969be (gs-olive) Key op fixes for failing tests
bad1594 (gs-olive) fix: Add constant folding utility to freezing
db56dd6 (peri044) chore: Move to new export APIs
bf961f5 (peri044) chore: rebase with dynamo_tensor_freeze branch
b13aa82 (gs-olive) feat: Add preliminary support for freezing tensors in Dynamo
dd95620 (gs-olive) fix: Refactor tensor freezing in Dynamo
6bd3c64 (gs-olive) Key op fixes for failing tests
248073f (gs-olive) fix: Add constant folding utility to freezing
3e5f434 (gs-olive) feat: Add preliminary support for freezing tensors in Dynamo
6bf6945 (gs-olive) fix: Refactor tensor freezing in Dynamo
3b6e1e7 (gs-olive) Key op fixes for failing tests
2107d8e (gs-olive) fix: Add constant folding utility to freezing
fd5a41e (peri044) chore: add BERT test case
f047651 (peri044) chore: remove pdb
4862c68 (peri044) chore: rebase with main
0ec68e6 (peri044) chore: rebase with export_prototype branch
1a39cae (peri044) feat: Express TRTengines as nodes
ab76c0d (peri044) chore: rebase
0cac5ad (peri044) chore: refactor
e4df382 (gs-olive) feat: Add preliminary support for freezing tensors in Dynamo
d022f4a (gs-olive) fix: Refactor tensor freezing in Dynamo
9610ba7 (gs-olive) Key op fixes for failing tests
e19aae7 (gs-olive) fix: Add constant folding utility to freezing
2860be6 (peri044) Merge remote-tracking branch 'origin/dynamo_tensor_freeze' into expor…
ae98595 (peri044) chore: refactor code and add test cases for serde
601ff44 (peri044) chore: Add support for hybrid graph save/load by inlining pytorch sub…
1be093f (peri044) chore: rebase with export_prototype
d73ef1c (peri044) chore: minor updates
88328c6 (peri044) chore: minor updates
a4251f1 (peri044) chore: updates
3790362 (peri044) chore: rebase with main
20e2a42 (peri044) chore: updates
e588a17 (peri044) chore: update docs
c457813 (peri044) chore: rebase with main
5d33251 (peri044) chore: uncomment a failing test
47822d6 (peri044) chore: updates
16640d0 (peri044) chore: rebase
7522a71 (peri044) chore: rebase
3bcb02d (peri044) chore: address review comments
07f357c (peri044) chore: fix tests
b2b6373 (peri044) chore: revert harness.py changes
4d82e17 (peri044) chore: fix tests
200b03f (peri044) chore: address review comments
52017d2 (peri044) chore: updates
29073c3 (peri044) chore: updates
c4c6e5c (peri044) chore: rebase with main
fbe929f (peri044) chore: fix tests
9ea829d (peri044) chore: address review comments
f24a646 (peri044) chore: revert fx changes
@@ -0,0 +1,77 @@
.. _runtime:

Saving models compiled with Torch-TensorRT
===========================================

Saving models compiled with Torch-TensorRT varies slightly with the `ir` that has been used for compilation.

1) Dynamo IR

Starting with the 2.1 release of Torch-TensorRT, the default compilation path is Dynamo-based.
The output of `ir=dynamo` compilation is a `torch.fx.GraphModule` object. There are two ways to save these objects:

a) Converting to TorchScript
`torch.fx.GraphModule` objects cannot be serialized directly. Hence we use `torch.jit.trace` to convert them into `ScriptModule` objects, which can be saved to disk.
The following code illustrates this approach.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = torch.randn((1, 3, 224, 224)).cuda()
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=[inputs])  # Output is a torch.fx.GraphModule
    trt_script_model = torch.jit.trace(trt_gm, inputs)
    torch.jit.save(trt_script_model, "trt_model.ts")

    # Later, you can load it and run inference
    model = torch.jit.load("trt_model.ts").cuda()
    model(inputs)
b) ExportedProgram
`torch.export.ExportedProgram` is a new format introduced in PyTorch 2.1. After we compile a PyTorch module using Torch-TensorRT, the resultant
`torch.fx.GraphModule`, along with additional metadata, can be used to create an `ExportedProgram`, which can be saved to and loaded from disk.

.. code-block:: python

    import torch
    import torch_tensorrt
    from torch_tensorrt.dynamo.export import transform, create_exported_program

    model = MyModel().eval().cuda()
    inputs = torch.randn((1, 3, 224, 224)).cuda()
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=[inputs])  # Output is a torch.fx.GraphModule
    # Transform and create an exported program
    trt_gm = transform(trt_gm, inputs)
    # call_spec (the input/output call specification of the traced graph) is
    # assumed to be available from the surrounding export workflow
    trt_exp_program = create_exported_program(trt_gm, call_spec, trt_gm.state_dict())
    torch._export.save(trt_exp_program, "trt_model.ep")

    # Later, you can load it and run inference
    model = torch._export.load("trt_model.ep")
    model(inputs)
`torch_tensorrt.dynamo.export.transform` inlines the submodules within a GraphModule into their corresponding nodes and stitches all the nodes together.
This is needed because `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
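As an illustrative sanity check (not part of the library API), the effect of this inlining can be confirmed by asserting that the transformed graph no longer contains any `call_module` nodes; a minimal sketch, assuming `trt_gm` from the example above:

.. code-block:: python

    # Illustrative check: after transform(), every submodule should have been
    # inlined, so the FX graph should contain no call_module nodes.
    assert all(
        node.op != "call_module" for node in trt_gm.graph.nodes
    ), "found call_module nodes, which torch._export serde cannot round-trip"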
NOTE: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue: https://github.com/pytorch/TensorRT/issues/2341
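Given the experimental status, it can also be worth verifying that the deserialized program reproduces the outputs of the in-memory compiled module. A minimal sketch, reusing `trt_exp_program`, `trt_gm`, and `inputs` from the example above (the tolerances are illustrative, not prescribed by the library):

.. code-block:: python

    # Round-trip check: save, reload, and compare outputs of the reloaded
    # program against the in-memory compiled GraphModule.
    torch._export.save(trt_exp_program, "trt_model.ep")
    reloaded = torch._export.load("trt_model.ep")
    torch.testing.assert_close(reloaded(inputs), trt_gm(inputs), rtol=1e-3, atol=1e-3)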
2) TorchScript IR

In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT was using the TorchScript IR.
This behavior stays the same in 2.X versions as well.
.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = torch.randn((1, 3, 224, 224)).cuda()
    trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=[inputs])  # Output is a ScriptModule object
    torch.jit.save(trt_ts, "trt_model.ts")

    # Later, you can load it and run inference
    model = torch.jit.load("trt_model.ts").cuda()
    model(inputs)