Skip to content

TorchTensorRTModule Serialization Fix #3572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Collection, Optional, Set, Tuple, Union
from typing import Any, Collection, Optional, Set, Tuple, Union

from torch.fx.node import Target
from torch_tensorrt._Device import Device
Expand Down Expand Up @@ -143,6 +143,21 @@ class CompilationSettings:
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU

def __getstate__(self) -> dict[str, Any]:
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
ConverterRegistry,
)

state = self.__dict__.copy()
state["torch_executed_ops"] = {
op if isinstance(op, str) else ConverterRegistry.qualified_name_or_str(op)
for op in state["torch_executed_ops"]
}
return state

def __setstate__(self, state: dict[str, Any]) -> None:
self.__dict__.update(state)


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
"enabled_precisions",
Expand Down
33 changes: 33 additions & 0 deletions tests/py/dynamo/models/test_export_serde.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import os
import tempfile
import unittest
Expand Down Expand Up @@ -372,6 +373,38 @@ def test_resnet18_dynamic(ir):
)


@unittest.skipIf(
not importlib.util.find_spec("torchvision"), "torchvision not installed"
)
def test_resnet18_torch_exec_ops_serde(ir):
"""
This tests export save and load functionality on Resnet18 model
"""
model = models.resnet18().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

compile_spec = {
"inputs": [input],
"ir": ir,
"min_block_size": 1,
"cache_built_engines": False,
"reuse_cached_engines": False,
"torch_executed_ops": {torch.ops.aten.addmm, "torch.ops.aten.add"},
}

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path)
deser_trt_module = torchtrt.load(trt_ep_path).module()
outputs_pyt = deser_trt_module(input)
outputs_trt = trt_module(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
def test_hybrid_conv_fallback(ir):
"""
Expand Down
37 changes: 37 additions & 0 deletions tests/py/dynamo/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,43 @@ def test_resnet18_cpu_offload(ir):
torch._dynamo.reset()


@unittest.skipIf(
not importlib.util.find_spec("torchvision"), "torchvision not installed"
)
def test_resnet18_torch_exec_ops(ir):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@unittest.skipIf(
not importlib.util.find_spec("torchvision"), "torchvision not installed"
)
Same here

model = models.resnet18(pretrained=True).eval().to("cuda")
input = torch.randn((1, 3, 224, 224)).to("cuda")

compile_spec = {
"inputs": [
torchtrt.Input(
min_shape=(1, 3, 224, 224),
opt_shape=(8, 3, 224, 224),
max_shape=(16, 3, 224, 224),
dtype=torch.float32,
)
],
"ir": ir,
"enabled_precisions": {torch.float32, torch.float16, torch.bfloat16},
"min_block_size": 1,
"debug": True,
"output_format": "exported_program",
"cache_built_engines": True,
"reuse_cached_engines": True,
"torch_executed_ops": {torch.ops.aten.matmul, "torch.ops.aten.add"},
}

trt_mod = torchtrt.compile(model, **compile_spec)
cos_sim = cosine_similarity(model(input), trt_mod(input))
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
torch._dynamo.reset()


@pytest.mark.unit
def test_mobilenet_v2(ir):
model = models.mobilenet_v2(pretrained=True).eval().to("cuda")
Expand Down
Loading