-
Notifications
You must be signed in to change notification settings - Fork 364
TorchTensorRTModule Serialization Fix #3572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py 2025-06-13 19:19:30.680941+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py 2025-06-13 19:19:58.614974+00:00
@@ -225,20 +225,28 @@
return
self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info())
def encode_metadata(self, metadata: Any) -> str:
metadata = copy.deepcopy(metadata)
- metadata["settings"].torch_executed_ops = TorchTensorRTModule.serialize_aten_ops(metadata["settings"].torch_executed_ops)
+ metadata["settings"].torch_executed_ops = (
+ TorchTensorRTModule.serialize_aten_ops(
+ metadata["settings"].torch_executed_ops
+ )
+ )
dumped_metadata = pickle.dumps(metadata)
encoded_metadata = base64.b64encode(dumped_metadata).decode("utf-8")
return encoded_metadata
@staticmethod
def decode_metadata(encoded_metadata: bytes) -> Any:
dumped_metadata = base64.b64decode(encoded_metadata.encode("utf-8"))
metadata = pickle.loads(dumped_metadata)
- metadata["settings"].torch_executed_ops = TorchTensorRTModule.deserialize_aten_ops(metadata["settings"].torch_executed_ops)
+ metadata["settings"].torch_executed_ops = (
+ TorchTensorRTModule.deserialize_aten_ops(
+ metadata["settings"].torch_executed_ops
+ )
+ )
return metadata
@staticmethod
def serialize_aten_ops(aten_ops: Set[torch._ops.OpOverload]) -> Set[str]:
return {str(op) for op in aten_ops}
return metadata | ||
|
||
@staticmethod | ||
def serialize_aten_ops(aten_ops: Set[torch._ops.OpOverload]) -> Set[str]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
call this _serialize_torch_ops
6bf3127
to
f9235ed
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we associate this function with CompileSettings not the TorchTensorRTModule? as something like getstate /setstate
@@ -143,6 +144,16 @@ class CompilationSettings: | |||
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE | |||
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU | |||
|
|||
def __getstate__(self) -> dict[str, Any]: | |||
state = self.__dict__.copy() | |||
state["torch_executed_ops"] = {str(op) for op in state["torch_executed_ops"]} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
consider using ConverterRegistry.qualified_name_or_str(target)
1f50da4
to
0b3edfe
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM can you add a testcase for the serialization? Just save and load a module that uses torch_executed_ops?
0b3edfe
to
2d7c87f
Compare
@@ -84,6 +84,41 @@ def test_resnet18_cpu_offload(ir): | |||
torch._dynamo.reset() | |||
|
|||
|
|||
@pytest.mark.unit | |||
def test_resnet18_torch_exec_ops(ir): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should also test that you can serialize and deserialize this model right?
@@ -372,6 +372,44 @@ def test_resnet18_dynamic(ir): | |||
) | |||
|
|||
|
|||
@pytest.mark.unit | |||
def test_resnet18_dynamic_torch_exec_ops(ir): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this test about dynamic inputs or serialization?
Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: