Skip to content

Commit 14691e7

Browse files
peri044narendasan
authored andcommitted
fix: Using TensorRT 8 new API calls
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent dad25f6 commit 14691e7

File tree

5 files changed

+19
-20
lines changed

5 files changed

+19
-20
lines changed

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,9 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
129129
}
130130

131131
ConversionCtx::~ConversionCtx() {
132-
builder->destroy();
133-
net->destroy();
134-
cfg->destroy();
132+
delete builder;
133+
delete net;
134+
delete cfg;
135135
for (auto ptr : builder_resources) {
136136
free(ptr);
137137
}
@@ -149,14 +149,11 @@ torch::jit::IValue* ConversionCtx::AssociateValueAndIValue(const torch::jit::Val
149149
}
150150

151151
std::string ConversionCtx::SerializeEngine() {
152-
auto engine = builder->buildEngineWithConfig(*net, *cfg);
153-
if (!engine) {
154-
TRTORCH_THROW_ERROR("Building TensorRT engine failed");
152+
auto serialized_network = builder->buildSerializedNetwork(*net, *cfg);
153+
if (!serialized_network) {
154+
TRTORCH_THROW_ERROR("Building serialized network failed in TensorRT");
155155
}
156-
auto serialized_engine = engine->serialize();
157-
engine->destroy();
158-
auto engine_str = std::string((const char*)serialized_engine->data(), serialized_engine->size());
159-
serialized_engine->destroy();
156+
auto engine_str = std::string((const char*)serialized_network->data(), serialized_network->size());
160157
return engine_str;
161158
}
162159

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ struct BuilderSettings {
3232
bool strict_types = false;
3333
bool truncate_long_and_double = false;
3434
Device device;
35-
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
35+
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kSTANDARD;
3636
nvinfer1::IInt8Calibrator* calibrator = nullptr;
3737
uint64_t num_min_timing_iters = 2;
3838
uint64_t num_avg_timing_iters = 1;

core/runtime/TRTEngine.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
4343

4444
rt = nvinfer1::createInferRuntime(util::logging::get_logger());
4545

46+
rt = nvinfer1::createInferRuntime(logger);
4647
name = slugify(mod_name) + "_engine";
4748

4849
cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size());
@@ -84,9 +85,10 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
8485
}
8586

8687
TRTEngine::~TRTEngine() {
87-
exec_ctx->destroy();
88-
cuda_engine->destroy();
89-
rt->destroy();
88+
delete exec_ctx;
89+
delete cuda_engine;
90+
delete rt;
91+
9092
}
9193

9294
// TODO: Implement a call method

core/util/trt_util.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DeviceType
8787

8888
inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::EngineCapability& cap) {
8989
switch (cap) {
90-
case nvinfer1::EngineCapability::kDEFAULT:
90+
case nvinfer1::EngineCapability::kSTANDARD:
9191
return stream << "Default";
92-
case nvinfer1::EngineCapability::kSAFE_GPU:
92+
case nvinfer1::EngineCapability::kSAFETY:
9393
return stream << "Safe GPU";
94-
case nvinfer1::EngineCapability::kSAFE_DLA:
94+
case nvinfer1::EngineCapability::kDLA_STANDALONE:
9595
return stream << "Safe DLA";
9696
default:
9797
return stream << "Unknown Engine Capability Setting";

cpp/api/src/compile_spec.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,14 +386,14 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
386386

387387
switch (external.capability) {
388388
case CompileSpec::EngineCapability::kSAFE_GPU:
389-
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSAFE_GPU;
389+
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSAFETY;
390390
break;
391391
case CompileSpec::EngineCapability::kSAFE_DLA:
392-
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSAFE_DLA;
392+
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kDLA_STANDALONE;
393393
break;
394394
case CompileSpec::EngineCapability::kDEFAULT:
395395
default:
396-
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kDEFAULT;
396+
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSTANDARD;
397397
}
398398

399399
internal.convert_info.engine_settings.device.gpu_id = external.device.gpu_id;

0 commit comments

Comments
 (0)