diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 59d2c5d6c0..1751302404 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set import numpy as np +import tensorrt as trt import torch import torch.fx from torch.fx.node import _get_qualified_name @@ -25,7 +26,6 @@ from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt from packaging import version _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -313,8 +313,10 @@ def run( ) timing_cache = self._create_timing_cache(builder_config, existing_cache) - engine = self.builder.build_serialized_network(self.ctx.net, builder_config) - assert engine + serialized_engine = self.builder.build_serialized_network( + self.ctx.net, builder_config + ) + assert serialized_engine serialized_cache = ( bytearray(timing_cache.serialize()) @@ -324,10 +326,10 @@ def run( _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) - _LOGGER.info(f"TRT Engine uses: {engine.nbytes} bytes of Memory") + _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory") return TRTInterpreterResult( - engine, self._input_names, self._output_names, serialized_cache + serialized_engine, self._input_names, self._output_names, serialized_cache ) def run_node(self, n: torch.fx.Node) -> torch.fx.Node: diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 0c152e15f1..1fcb765b47 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -29,7 +29,7 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc] def __init__( self, - engine: trt.ICudaEngine, + engine: bytes, input_names: Optional[List[str]] = None, output_names: Optional[List[str]] = None, target_device: Device = Device._current_device(), @@ -60,9 +60,9 @@ def _initialize(self) -> None: self.engine = runtime.deserialize_cuda_engine(self.engine) self.context = self.engine.create_execution_context() - assert ( - self.engine.num_io_tensors // self.engine.num_optimization_profiles - ) == (len(self.input_names) + len(self.output_names)) + assert self.engine.num_io_tensors == ( + len(self.input_names) + len(self.output_names) + ) self.input_dtypes = [ dtype._from(self.engine.get_tensor_dtype(input_name))