Skip to content

Commit f9235ed

Browse files
committed
_TorchTensorRTModule Serialization Fix
1 parent 60863a3 commit f9235ed

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import copy
55
import logging
66
import pickle
7-
from typing import Any, List, Optional, Tuple, Union
7+
from typing import Any, List, Optional, Set, Tuple, Union
88

99
import torch
1010
from torch_tensorrt._Device import Device
@@ -227,6 +227,11 @@ def setup_engine(self) -> None:
227227

228228
def encode_metadata(self, metadata: Any) -> str:
229229
metadata = copy.deepcopy(metadata)
230+
metadata["settings"].torch_executed_ops = (
231+
TorchTensorRTModule.serialize_aten_ops(
232+
metadata["settings"].torch_executed_ops
233+
)
234+
)
230235
dumped_metadata = pickle.dumps(metadata)
231236
encoded_metadata = base64.b64encode(dumped_metadata).decode("utf-8")
232237
return encoded_metadata
@@ -235,8 +240,21 @@ def encode_metadata(self, metadata: Any) -> str:
235240
def decode_metadata(encoded_metadata: bytes) -> Any:
236241
dumped_metadata = base64.b64decode(encoded_metadata.encode("utf-8"))
237242
metadata = pickle.loads(dumped_metadata)
243+
metadata["settings"].torch_executed_ops = (
244+
TorchTensorRTModule.deserialize_aten_ops(
245+
metadata["settings"].torch_executed_ops
246+
)
247+
)
238248
return metadata
239249

250+
@staticmethod
251+
def serialize_aten_ops(aten_ops: Set[torch._ops.OpOverload]) -> Set[str]:
252+
return {str(op) for op in aten_ops}
253+
254+
@staticmethod
255+
def deserialize_aten_ops(aten_ops: Set[str]) -> Set[torch._ops.OpOverload]:
256+
return {eval("torch.ops." + str(v)) for v in aten_ops}
257+
240258
def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt:
241259
if self.engine:
242260
return (

0 commit comments

Comments
 (0)