diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 59d2c5d6c0..1751302404 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -4,6 +4,7 @@
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
 
 import numpy as np
+import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
@@ -25,7 +26,6 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
-import tensorrt as trt
 from packaging import version
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -313,8 +313,10 @@ def run(
         )
         timing_cache = self._create_timing_cache(builder_config, existing_cache)
 
-        engine = self.builder.build_serialized_network(self.ctx.net, builder_config)
-        assert engine
+        serialized_engine = self.builder.build_serialized_network(
+            self.ctx.net, builder_config
+        )
+        assert serialized_engine
 
         serialized_cache = (
             bytearray(timing_cache.serialize())
@@ -324,10 +326,10 @@ def run(
         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
-        _LOGGER.info(f"TRT Engine uses: {engine.nbytes} bytes of Memory")
+        _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
 
         return TRTInterpreterResult(
-            engine, self._input_names, self._output_names, serialized_cache
+            serialized_engine, self._input_names, self._output_names, serialized_cache
         )
 
     def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index 0c152e15f1..1fcb765b47 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -29,7 +29,7 @@ class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
 
     def __init__(
         self,
-        engine: trt.ICudaEngine,
+        engine: bytes,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
         target_device: Device = Device._current_device(),
@@ -60,9 +60,9 @@ def _initialize(self) -> None:
         self.engine = runtime.deserialize_cuda_engine(self.engine)
         self.context = self.engine.create_execution_context()
 
-        assert (
-            self.engine.num_io_tensors // self.engine.num_optimization_profiles
-        ) == (len(self.input_names) + len(self.output_names))
+        assert self.engine.num_io_tensors == (
+            len(self.input_names) + len(self.output_names)
+        )
 
         self.input_dtypes = [
             dtype._from(self.engine.get_tensor_dtype(input_name))