From dcd4e092b8b1924c770ac64a31bb98ee036e68a9 Mon Sep 17 00:00:00 2001
From: Gabriel Cuendet <gabriel.cuendet@gmail.com>
Date: Mon, 11 Dec 2023 13:29:02 +0100
Subject: [PATCH] Fix memory leaks

Add missing wrapping of raw pointers in smart pointers, so that the
destructors of the underlying TensorRT objects are called properly

Signed-off-by: Gabriel Cuendet <gabriel.cuendet@cognex.com>
---
 core/conversion/conversionctx/ConversionCtx.cpp | 2 +-
 core/runtime/register_jit_hooks.cpp             | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp
index c9a76602c2..2eb363706f 100644
--- a/core/conversion/conversionctx/ConversionCtx.cpp
+++ b/core/conversion/conversionctx/ConversionCtx.cpp
@@ -164,7 +164,7 @@ void ConversionCtx::RecordNewITensor(const torch::jit::Value* value, nvinfer1::I
 
 std::string ConversionCtx::SerializeEngine() {
 #if NV_TENSORRT_MAJOR > 7
-  auto serialized_network = builder->buildSerializedNetwork(*net, *cfg);
+  auto serialized_network = make_trt(builder->buildSerializedNetwork(*net, *cfg));
   if (!serialized_network) {
     TORCHTRT_THROW_ERROR("Building serialized network failed in TensorRT");
   }
diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp
index 1acc27dda5..5ad0efb3b0 100644
--- a/core/runtime/register_jit_hooks.cpp
+++ b/core/runtime/register_jit_hooks.cpp
@@ -87,7 +87,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
         .def_pickle(
             [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
               // Serialize TensorRT engine
-              auto serialized_trt_engine = self->cuda_engine->serialize();
+              auto serialized_trt_engine = make_trt(self->cuda_engine->serialize());
 
               // Adding device info related meta data to the serialized file
               auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());