4
4
import copy
5
5
import logging
6
6
import pickle
7
- from typing import Any , List , Optional , Tuple , Union
7
+ from typing import Any , List , Optional , Set , Tuple , Union
8
8
9
9
import torch
10
10
from torch_tensorrt ._Device import Device
@@ -227,6 +227,11 @@ def setup_engine(self) -> None:
227
227
228
228
def encode_metadata (self , metadata : Any ) -> str :
229
229
metadata = copy .deepcopy (metadata )
230
+ metadata ["settings" ].torch_executed_ops = (
231
+ TorchTensorRTModule .serialize_aten_ops (
232
+ metadata ["settings" ].torch_executed_ops
233
+ )
234
+ )
230
235
dumped_metadata = pickle .dumps (metadata )
231
236
encoded_metadata = base64 .b64encode (dumped_metadata ).decode ("utf-8" )
232
237
return encoded_metadata
@@ -235,8 +240,21 @@ def encode_metadata(self, metadata: Any) -> str:
235
240
def decode_metadata (encoded_metadata : bytes ) -> Any :
236
241
dumped_metadata = base64 .b64decode (encoded_metadata .encode ("utf-8" ))
237
242
metadata = pickle .loads (dumped_metadata )
243
+ metadata ["settings" ].torch_executed_ops = (
244
+ TorchTensorRTModule .deserialize_aten_ops (
245
+ metadata ["settings" ].torch_executed_ops
246
+ )
247
+ )
238
248
return metadata
239
249
250
+ @staticmethod
251
+ def serialize_aten_ops (aten_ops : Set [torch ._ops .OpOverload ]) -> Set [str ]:
252
+ return {str (op ) for op in aten_ops }
253
+
254
+ @staticmethod
255
+ def deserialize_aten_ops (aten_ops : Set [str ]) -> Set [torch ._ops .OpOverload ]:
256
+ return {eval ("torch.ops." + str (v )) for v in aten_ops }
257
+
240
258
def get_extra_state (self ) -> SerializedTorchTensorRTModuleFmt :
241
259
if self .engine :
242
260
return (
0 commit comments