
Commit fff1b80

Committed Mar 20, 2024
chore: Upgrade to TRT 10.0
chore: updates to trt api
chore: trt 10 fixes
chore: more fixes
1 parent 4ae6ab9 commit fff1b80

File tree

20 files changed: +205 additions, -182 deletions


‎core/conversion/converters/impl/constant_pad.cpp

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
     util::toDims(c10::IntArrayRef(stride)));
 TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
 slice_layer->setName((util::node_info(n) + "_slice").c_str());
-slice_layer->setMode(nvinfer1::SliceMode::kFILL);
+slice_layer->setMode(nvinfer1::SampleMode::kFILL);
 slice_layer->setInput(4, *value_itensor);

 if (ctx->input_is_dynamic) {
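TensorRT 10 renames the slice-layer mode enum: nvinfer1::SliceMode becomes nvinfer1::SampleMode, with the same kFILL/kCLAMP/kREFLECT/kWRAP values, and the Python binding mirrors this as trt.SampleMode. A minimal sketch of the new spelling in Python, assuming a network `net` and an input ITensor `inp` already exist (both names are placeholders):

import tensorrt as trt

# `net` (INetworkDefinition) and `inp` (ITensor) are assumed to exist.
pad_layer = net.add_slice(inp, (0, 0), (6, 6), (1, 1))  # start, shape, stride
pad_layer.mode = trt.SampleMode.FILL  # formerly trt.SliceMode.FILL
# The fill value is supplied as the slice layer's fifth input (index 4),
# matching the C++ setInput(4, *value_itensor) call in the hunk above.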

‎core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 3 additions & 3 deletions
@@ -61,7 +61,7 @@ nvinfer1::ILayer* add_bias_layer(
 auto* sliceLayer = ctx->net->addSlice(*input_tensor, dummy, dummy, stride);
 sliceLayer->setInput(1, *start);
 sliceLayer->setInput(2, *size);
-sliceLayer->setMode(nvinfer1::SliceMode::kFILL);
+sliceLayer->setMode(nvinfer1::SampleMode::kFILL);
 nvinfer1::ITensor* slice_output = sliceLayer->getOutput(0);

 nvinfer1::Dims constantDims;
@@ -194,7 +194,7 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
 nvinfer1::IConvolutionLayer* convLayer =
     ctx->net->addConvolutionNd(*in, num_output_maps, filter_dim, kernel_weights, bias.data);
 convLayer->setStrideNd(stride);
-convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
+convLayer->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN);
 convLayer->setPaddingNd(padding);
 convLayer->setPostPadding(out_padding);
 convLayer->setDilationNd(dilation);
@@ -293,7 +293,7 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
 TORCHTRT_CHECK(conv, "Unable to create convolution layer from node: " << *n);

 conv->setStrideNd(stride);
-conv->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
+conv->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN);
 conv->setPaddingNd(padding);
 conv->setPostPadding(out_padding);
 conv->setDilationNd(dilation);
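The Caffe-style rounding modes (kCAFFE_ROUND_DOWN / kCAFFE_ROUND_UP) are removed in TensorRT 10, so the converter now requests kEXPLICIT_ROUND_DOWN and keeps passing explicit pre- and post-padding. A rough Python equivalent of the new layer settings, assuming `net`, an input tensor `inp`, and a placeholder OIHW weight array `w`:

import numpy as np
import tensorrt as trt

# `net` and `inp` are assumed to exist; `w` is a placeholder 8x3x3x3 kernel.
w = np.ones((8, 3, 3, 3), dtype=np.float32)
conv = net.add_convolution_nd(inp, 8, (3, 3), trt.Weights(w))
conv.stride_nd = (1, 1)
conv.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_DOWN  # CAFFE_ROUND_DOWN is gone in TRT 10
conv.padding_nd = (1, 1)
conv.dilation_nd = (1, 1)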

‎core/conversion/converters/impl/interpolate.cpp

Lines changed: 25 additions & 25 deletions
@@ -72,7 +72,7 @@ void resize_layer_size(
 nvinfer1::ITensor* in,
 std::vector<int64_t> out_shape,
 std::vector<float> scales,
-nvinfer1::ResizeMode mode,
+nvinfer1::InterpolationMode mode,
 bool align_corners = false) {
 TORCHTRT_CHECK((out_shape.size() > 0) ^ (scales.size() > 0), "only one of out_shape or scales should be defined");
 auto resize_layer = ctx->net->addResize(*in);
@@ -141,7 +141,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
 float scale = args[2].IValue()->toDouble();
 std::vector<float> padded_scales(in_shape.size(), 1);
 padded_scales[padded_scales.size() - 1] = scale;
-resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST);
 } else {
 // Case 2: user uses output size
 auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -150,7 +150,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =

 auto out_shape = in_shape;
 std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
+resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST);
 }

 return true;
@@ -172,7 +172,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
 float scale = scale_factors[0];
 std::vector<float> padded_scales(in_shape.size(), 1);
 padded_scales[padded_scales.size() - 1] = scale;
-resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST);
 } else {
 // Case 2: user uses output size
 auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -181,7 +181,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =

 auto out_shape = in_shape;
 std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
+resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST);
 }

 return true;
@@ -203,7 +203,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
 std::vector<float> padded_scales(in_shape.size(), 1);
 padded_scales[padded_scales.size() - 2] = scale_h;
 padded_scales[padded_scales.size() - 1] = scale_w;
-resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST);
 } else {
 // Case 2: user uses output size
 auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -212,7 +212,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =

 auto out_shape = in_shape;
 std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
+resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST);
 }

 return true;
@@ -236,7 +236,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
 std::vector<float> padded_scales(in_shape.size(), 1);
 padded_scales[padded_scales.size() - 2] = scale_h;
 padded_scales[padded_scales.size() - 1] = scale_w;
-resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST);
 } else {
 // Case 2: user uses output size
 auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -245,7 +245,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =

 auto out_shape = in_shape;
 std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
+resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST);
 }

 return true;
@@ -270,7 +270,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
 padded_scales[padded_scales.size() - 3] = scale_d;
 padded_scales[padded_scales.size() - 2] = scale_h;
 padded_scales[padded_scales.size() - 1] = scale_w;
-resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST);
 } else {
 // Case 2: user uses output size
 auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -279,7 +279,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =

 auto out_shape = in_shape;
 std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
+resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST);
 }

 return true;
@@ -306,7 +306,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
 padded_scales[padded_scales.size() - 3] = scale_d;
 padded_scales[padded_scales.size() - 2] = scale_h;
 padded_scales[padded_scales.size() - 1] = scale_w;
-resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kNEAREST);
+resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kNEAREST);
 } else {
 // Case 2: user uses output size
 auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -315,7 +315,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =

 auto out_shape = in_shape;
 std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kNEAREST);
+resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kNEAREST);
 }

 return true;
@@ -336,7 +336,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
 float scale = args[3].IValue()->toDouble();
 std::vector<float> padded_scales(in_shape.size(), 1);
 padded_scales[padded_scales.size() - 1] = scale;
-resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners);
 } else {
 // Case 2: user uses output size
 auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -345,7 +345,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =

 auto out_shape = in_shape;
 std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
+resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners);
 }

 return true;
@@ -368,7 +368,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
 float scale = scale_factors[0];
 std::vector<float> padded_scales(in_shape.size(), 1);
 padded_scales[padded_scales.size() - 1] = scale;
-resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners);
 } else {
 // Case 2: user uses output size
 auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -377,7 +377,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =

 auto out_shape = in_shape;
 std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
+resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners);
 }

 return true;
@@ -400,7 +400,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
 std::vector<float> padded_scales(in_shape.size(), 1);
 padded_scales[padded_scales.size() - 2] = scale_h;
 padded_scales[padded_scales.size() - 1] = scale_w;
-resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners);
 } else {
 // Case 2: user uses output size
 auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -410,7 +410,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =

 auto out_shape = in_shape;
 std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
+resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners);
 }

 return true;
@@ -435,7 +435,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
 std::vector<float> padded_scales(in_shape.size(), 1);
 padded_scales[padded_scales.size() - 2] = scale_h;
 padded_scales[padded_scales.size() - 1] = scale_w;
-resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners);
 } else {
 // Case 2: user uses output size
 auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -445,7 +445,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =

 auto out_shape = in_shape;
 std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
+resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners);
 }

 return true;
@@ -470,7 +470,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
 padded_scales[padded_scales.size() - 3] = scale_d;
 padded_scales[padded_scales.size() - 2] = scale_h;
 padded_scales[padded_scales.size() - 1] = scale_w;
-resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners);
 } else {
 // Case 2: user uses output size
 auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -480,7 +480,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =

 auto out_shape = in_shape;
 std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
+resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners);
 }

 return true;
@@ -507,7 +507,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =
 padded_scales[padded_scales.size() - 3] = scale_d;
 padded_scales[padded_scales.size() - 2] = scale_h;
 padded_scales[padded_scales.size() - 1] = scale_w;
-resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::ResizeMode::kLINEAR, align_corners);
+resize_layer_size(ctx, n, in, {}, padded_scales, nvinfer1::InterpolationMode::kLINEAR, align_corners);
 } else {
 // Case 2: user uses output size
 auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
@@ -517,7 +517,7 @@ auto interpolate_registrations TORCHTRT_UNUSED =

 auto out_shape = in_shape;
 std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::ResizeMode::kLINEAR, align_corners);
+resize_layer_size(ctx, n, in, out_shape, {}, nvinfer1::InterpolationMode::kLINEAR, align_corners);
 }

 return true;
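The other pervasive rename in this commit is the resize enum: nvinfer1::ResizeMode becomes nvinfer1::InterpolationMode (trt.InterpolationMode in Python), while kNEAREST and kLINEAR keep their meaning. A small sketch of the Python spelling, again assuming placeholder `net` and `inp` objects:

import tensorrt as trt

# `net` and `inp` are assumed to exist; upscale an NCHW tensor 2x in H and W.
resize = net.add_resize(inp)
resize.scales = [1.0, 1.0, 2.0, 2.0]
resize.resize_mode = trt.InterpolationMode.NEAREST  # formerly trt.ResizeMode.NEAREST
# For the align_corners=True linear paths above, the coordinate transform is set separately:
# resize.coordinate_transformation = trt.ResizeCoordinateTransformation.ALIGN_CORNERS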

‎core/conversion/converters/impl/linear.cpp

Lines changed: 17 additions & 13 deletions
@@ -40,22 +40,26 @@ auto linear_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pat
 in = in_shuffle->getOutput(0);
 }

-auto w_tensor = args[1].IValue()->toTensor();
-Weights w = Weights(ctx, w_tensor);
+// Convert w_tensor to ITensor
+auto weight = args[1].IValue()->toTensor();
+auto weight_tensor = tensor_to_const(ctx, weight, util::node_info(n) + "_weight");
+auto mm_layer = ctx->net->addMatrixMultiply(
+    *in, nvinfer1::MatrixOperation::kNONE, *weight_tensor, nvinfer1::MatrixOperation::kNONE);

-nvinfer1::ILayer* new_layer;
-if (!args[2].IValue()->isNone()) {
-  Weights b(ctx, args[2].IValue()->toTensor());
-  new_layer = ctx->net->addFullyConnected(*in, w.num_output_maps, w.data, b.data);
-} else {
-  LOG_DEBUG("There is no bias for the linear layer");
-  new_layer = ctx->net->addFullyConnected(*in, w.num_output_maps, w.data, Weights().data);
-}
+TORCHTRT_CHECK(mm_layer, "Unable to create linear layer from node: " << *n);
+mm_layer->setName(util::node_info(n).c_str());

-TORCHTRT_CHECK(new_layer, "Unable to create linear layer from node: " << *n);
+auto mm_output = mm_layer->getOutput(0);

-new_layer->setName(util::node_info(n).c_str());
-auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
+if (!args[2].IValue()->isNone()) {
+  // Convert bias to ITensor
+  auto bias = args[2].IValue()->toTensor();
+  auto bias_tensor = tensor_to_const(ctx, bias, util::node_info(n) + "_bias");
+  auto bias_add_layer = add_elementwise(
+      ctx, nvinfer1::ElementWiseOperation::kSUM, mm_output, bias_tensor, util::node_info(n) + "_bias_add");
+  mm_output = bias_add_layer->getOutput(0);
+}
+auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_output);

 LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
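IFullyConnectedLayer (addFullyConnected) no longer exists in TensorRT 10, so aten::linear is lowered to a matrix multiply plus an optional elementwise bias add, as in the hunk above. A condensed Python sketch of the same pattern; `net` and `inp` are placeholders, and the weight is assumed to already be laid out so that no transpose operation is needed:

import numpy as np
import tensorrt as trt

# Placeholder (in_features, out_features) weight and (out_features,) bias.
weight = np.random.rand(32, 16).astype(np.float32)
bias = np.random.rand(16).astype(np.float32)

w_const = net.add_constant(weight.shape, trt.Weights(weight)).get_output(0)
mm = net.add_matrix_multiply(
    inp, trt.MatrixOperation.NONE, w_const, trt.MatrixOperation.NONE
)
b_const = net.add_constant((1, 16), trt.Weights(bias)).get_output(0)
out = net.add_elementwise(
    mm.get_output(0), b_const, trt.ElementWiseOperation.SUM
).get_output(0)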

‎core/runtime/TRTEngine.cpp

Lines changed: 35 additions & 12 deletions
@@ -120,16 +120,25 @@ TRTEngine::TRTEngine(
 } else {
   uint64_t inputs_size = _in_binding_names.size();
   in_binding_names.resize(inputs_size);
-  for (size_t pyt_idx = 0; pyt_idx < inputs_size; pyt_idx++) {
+  for (uint64_t pyt_idx = 0; pyt_idx < inputs_size; pyt_idx++) {
     auto binding_name = _in_binding_names[pyt_idx];
-    auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str());
-    std::string engine_binded_name = cuda_engine->getIOTensorName(trt_idx);
-    TORCHTRT_CHECK(
-        (binding_name == engine_binded_name),
-        "Could not find a TensorRT engine binding for input named " << binding_name);
+    // Check if the binding name provided is in the list of engine's bindings
+    // by iterating through nbIOTensors and verify it is an input binding
+    bool is_binding = false, is_input = false;
+    int32_t trt_idx;
+    for (int32_t idx = 0; idx < cuda_engine->getNbIOTensors(); idx++) {
+      std::string curr_bind_name = cuda_engine->getIOTensorName(idx);
+      if (curr_bind_name == binding_name) {
+        is_binding = true;
+        trt_idx = idx;
+        if (cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT) {
+          is_input = true;
+        }
+      }
+    }
+    TORCHTRT_CHECK(is_binding, "Could not find a TensorRT engine binding for input named " << binding_name);
     TORCHTRT_CHECK(
-        (cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT),
-        "Binding " << binding_name << " specified as input but found as output in TensorRT engine");
+        is_input, "Binding " << binding_name << " specified as input but found as output in TensorRT engine");
     LOG_DEBUG(
         "Input binding name: " << binding_name << " has TensorRT binding index: " << trt_idx
                                << ", Torch binding index: " << pyt_idx);
@@ -141,11 +150,25 @@ TRTEngine::TRTEngine(
   out_binding_names.resize(outputs);
   for (size_t pyt_idx = 0; pyt_idx < outputs; pyt_idx++) {
     auto binding_name = _out_binding_names[pyt_idx];
-    auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str());
-    TORCHTRT_CHECK((trt_idx != -1), "Could not find a TensorRT engine binding for output named " << binding_name);
+    // Check if the binding name provided is in the list of engine's bindings
+    // by iterating through nbIOTensors and verify it is an output binding
+    bool is_binding = false, is_output = false;
+    int32_t trt_idx;
+    for (int32_t idx = 0; idx < cuda_engine->getNbIOTensors(); idx++) {
+      std::string curr_bind_name = cuda_engine->getIOTensorName(idx);
+      if (curr_bind_name == binding_name) {
+        is_binding = true;
+        trt_idx = idx;
+        if (cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kOUTPUT) {
+          is_output = true;
+        }
+      }
+    }
+
+    TORCHTRT_CHECK(is_binding, "Could not find a TensorRT engine binding for output named " << binding_name);
     TORCHTRT_CHECK(
-        !(cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT),
-        "Binding " << binding_name << " specified as output but found as input in TensorRT engine");
+        is_output, "Binding " << binding_name << " specified as output but found as input in TensorRT engine");
+
     LOG_DEBUG(
         "Output binding name: " << binding_name << " has TensorRT binding index: " << trt_idx
                                 << ", Torch binding index: " << inputs_size + pyt_idx);
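getBindingIndex() and the other binding-index APIs are gone in TRT 10, which is why the runtime now scans the engine's IO tensors by name and checks each tensor's IO mode. A simplified Python variant of that lookup, assuming a deserialized `engine`:

import tensorrt as trt

def find_io_tensor_index(engine: trt.ICudaEngine, binding_name: str) -> int:
    """Return the IO-tensor index whose name matches, or -1 if absent."""
    for idx in range(engine.num_io_tensors):
        if engine.get_tensor_name(idx) == binding_name:
            return idx
    return -1

# The input/output check is a separate query on the tensor name:
# is_input = engine.get_tensor_mode(binding_name) == trt.TensorIOMode.INPUT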

‎cpp/include/torch_tensorrt/ptq.h

Lines changed: 4 additions & 4 deletions
@@ -21,10 +21,10 @@
 #include "torch_tensorrt/macros.h"

 #ifndef DOXYGEN_SHOULD_SKIP_THIS
-namespace nvinfer1 {
-class IInt8Calibrator;
-class IInt8EntropyCalibrator2;
-} // namespace nvinfer1
+// namespace nvinfer1 {
+// class IInt8Calibrator;
+// class IInt8EntropyCalibrator2;
+// } // namespace nvinfer1

 namespace torch_tensorrt {
 namespace ptq {

‎py/torch_tensorrt/csrc/torch_tensorrt_py.cpp

Lines changed: 5 additions & 0 deletions
@@ -2,6 +2,7 @@
 #include "pybind11/stl.h"

 #include "ATen/core/jit_type.h"
+#include "NvInferRuntimeBase.h"
 #include "Python.h"
 #include "core/compiler.h"
 #include "core/conversion/conversion.h"
@@ -77,6 +78,10 @@ class pyIInt8Calibrator : public pyCalibratorTrampoline<nvinfer1::IInt8Calibrato
   using Derived = pyCalibratorTrampoline<nvinfer1::IInt8Calibrator>;
   using Derived::Derived;

+  nvinfer1::InterfaceInfo getInterfaceInfo() const noexcept override {
+    return nvinfer1::InterfaceInfo{"PYTHON CALIBRATOR", 1, 0};
+  }
+
   nvinfer1::CalibrationAlgoType getAlgorithm() noexcept override {
     try {
       PYBIND11_OVERLOAD_PURE_NAME(

‎py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 1 addition & 1 deletion
@@ -634,7 +634,7 @@ def convert_module_to_trt_engine(
     import io

     with io.BytesIO() as engine_bytes:
-        engine_bytes.write(interpreter_result.engine.serialize())
+        engine_bytes.write(interpreter_result.engine)
         engine_bytearray = engine_bytes.getvalue()

     return engine_bytearray

‎py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 3 additions & 3 deletions
@@ -172,7 +172,7 @@ def run(

         if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
-                trt.ProfilingVerbosity.VERBOSE
+                trt.ProfilingVerbosity.DETAILED
                 if self.compilation_settings.debug
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )
@@ -252,7 +252,7 @@ def run(
         if tactic_sources is not None:
            builder_config.set_tactic_sources(tactic_sources=tactic_sources)

-        engine = self.builder.build_engine(self.ctx.net, builder_config)
+        engine = self.builder.build_serialized_network(self.ctx.net, builder_config)
         assert engine

         serialized_cache = (
@@ -263,7 +263,7 @@ def run(
         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
-        _LOGGER.info(f"TRT Engine uses: {engine.device_memory_size} bytes of Memory")
+        _LOGGER.info(f"TRT Engine uses: {engine.nbytes} bytes of Memory")

         return TRTInterpreterResult(
             engine, self._input_names, self._output_names, serialized_cache
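build_engine() was removed in TRT 10: build_serialized_network() returns an IHostMemory blob (hence the engine.nbytes size log), and a trt.Runtime deserializes it only when an executable engine is actually needed. A minimal sketch of that flow, assuming `network` and `builder_config` were created from the same `builder` elsewhere:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
# `network` and `builder_config` are assumed to have been built from `builder`.
serialized = builder.build_serialized_network(network, builder_config)  # IHostMemory
print(f"serialized engine size: {serialized.nbytes} bytes")

runtime = trt.Runtime(logger)
engine = runtime.deserialize_cuda_engine(serialized)  # ICudaEngine, when execution is needed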

‎py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 2 additions & 1 deletion
@@ -87,8 +87,9 @@ def convert_module(
     from torch_tensorrt.dynamo.runtime import TorchTensorRTModule

     with io.BytesIO() as engine_bytes:
-        engine_bytes.write(interpreter_result.engine.serialize())
+        engine_bytes.write(interpreter_result.engine)
         engine_str = engine_bytes.getvalue()
+
     return TorchTensorRTModule(
         serialized_engine=engine_str,
         name=name,

‎py/torch_tensorrt/dynamo/conversion/impl/conv.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def convNd(
         )

     # Process weight terms
-    if ctx.net.has_explicit_precision or isinstance(weight, TRTTensor):
+    if isinstance(weight, TRTTensor):
         weight = get_trt_tensor(ctx, weight, f"{name}_weight")
         # Append new dimension (unsqueeze) if the convolution is 1d
         if is_conv1d:

‎py/torch_tensorrt/dynamo/conversion/impl/deconv.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def deconvNd(
         )

     # Process weight terms
-    if ctx.net.has_explicit_precision or isinstance(weight, TRTTensor):
+    if isinstance(weight, TRTTensor):
         weight = get_trt_tensor(ctx, weight, f"{name}_weight")
         # Append new dimension (unsqueeze) if the deconvolution is 1d
         if is_deconv1d:

‎py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 0 additions & 11 deletions
@@ -147,17 +147,6 @@ def convert_binary_elementwise(
         ctx, rhs_val, trt_promoted_type, name, target, source_ir
     )

-    # Check the limitation in the doc string.
-    if ctx.net.has_implicit_batch_dimension:
-        if is_lhs_trt_tensor and not is_rhs_trt_tensor:
-            assert len(lhs_val.shape) >= len(
-                rhs_val.shape
-            ), f"{lhs_val.shape} >= {rhs_val.shape}"
-        elif not is_lhs_trt_tensor and is_rhs_trt_tensor:
-            assert len(rhs_val.shape) >= len(
-                lhs_val.shape
-            ), f"{rhs_val.shape} >= {lhs_val.shape}"
-
     lhs_val, rhs_val = broadcast(
         ctx.net, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
     )
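The deleted block only ran for implicit-batch networks, which TensorRT 10 no longer supports: every network is explicit-batch, so ctx.net.has_implicit_batch_dimension can never be true and broadcasting is handled uniformly. For reference, a network has been created in explicit-batch mode like this since earlier TRT releases (the creation flag is deprecated in TRT 10 because explicit batch is the only behavior left):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
# EXPLICIT_BATCH is effectively the only mode in TRT 10; the flag remains for older versions.
flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flags)
# network.has_implicit_batch_dimension is always False here.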

‎py/torch_tensorrt/dynamo/conversion/impl/pad.py

Lines changed: 4 additions & 4 deletions
@@ -53,7 +53,7 @@ def constant_padNd(
     )
     value_const = get_trt_tensor(ctx, value, f"{name}_value", input.dtype)
     layer.set_input(4, value_const)
-    layer.mode = trt.SliceMode.FILL
+    layer.mode = trt.SampleMode.FILL

     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
@@ -91,7 +91,7 @@ def reflection_padNd(
         shape=tuple(new_shape),
         stride=tuple(stride_list),
     )
-    layer.mode = trt.SliceMode.REFLECT
+    layer.mode = trt.SampleMode.REFLECT

     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
@@ -129,7 +129,7 @@ def replication_padNd(
         shape=tuple(new_shape),
         stride=tuple(stride_list),
     )
-    layer.mode = trt.SliceMode.CLAMP
+    layer.mode = trt.SampleMode.CLAMP

     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
@@ -167,7 +167,7 @@ def circular_padNd(
         shape=tuple(new_shape),
         stride=tuple(stride_list),
     )
-    layer.mode = trt.SliceMode.WRAP
+    layer.mode = trt.SampleMode.WRAP

     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)

‎py/torch_tensorrt/dynamo/conversion/impl/permutation.py

Lines changed: 2 additions & 2 deletions
@@ -66,7 +66,7 @@ def roll(
            shape=shape,
            stride=stride,
        )
-        layer.mode = trt.SliceMode.WRAP
+        layer.mode = trt.SampleMode.WRAP
         set_layer_name(layer, target, f"{name}_slice_wrap", source_ir)
         return layer.get_output(0)

@@ -83,7 +83,7 @@ def roll(
            shape=flatten_shape,
            stride=stride,
        )
-        layer.mode = trt.SliceMode.WRAP
+        layer.mode = trt.SampleMode.WRAP
         set_layer_name(layer, target, f"{name}_slice_wrap", source_ir)
         output = layer.get_output(0)
         output = impl.shuffle.reshape(

‎py/torch_tensorrt/dynamo/conversion/impl/upsample.py

Lines changed: 3 additions & 3 deletions
@@ -29,14 +29,14 @@ def upsample(
         resize_layer.scales = [1.0, 1.0] + list(scale_factors)
     else:
         raise RuntimeError(
-            f"At least one of out_shape and scale_factors should be specified."
+            "At least one of out_shape and scale_factors should be specified."
         )

     # interpolate mode
     if resize_mode == "nearest" or None:
-        resize_layer.resize_mode = trt.ResizeMode.NEAREST
+        resize_layer.resize_mode = trt.InterpolationMode.NEAREST
     elif resize_mode == "bilinear":
-        resize_layer.resize_mode = trt.ResizeMode.LINEAR
+        resize_layer.resize_mode = trt.InterpolationMode.LINEAR
         if align_corners is None or not align_corners:
             raise RuntimeError(
                 f"Interpolation works differently is align_corners is False for {resize_mode} mode in PyTorch and TensorRT."

‎py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 70 additions & 72 deletions
@@ -2,7 +2,7 @@

 import logging
 from contextlib import nullcontext
-from typing import Any, Dict, List, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import tensorrt as trt
 import torch
@@ -55,73 +55,69 @@ def __init__(

     def _initialize(self) -> None:
         self.initialized = True
+        logger = trt.Logger()
+        runtime = trt.Runtime(logger)
+        self.engine = runtime.deserialize_cuda_engine(self.engine)
         self.context = self.engine.create_execution_context()

         # Indices of inputs/outputs in the trt engine bindings, in the order
         # as they are in the original PyTorch model.
-        self.input_binding_indices_in_order: Sequence[int] = [
-            self.engine.get_binding_index(name) for name in self.input_names
-        ]
-        self.output_binding_indices_in_order: Sequence[int] = [
-            self.engine.get_binding_index(name) for name in self.output_names
-        ]
-        primary_input_outputs = set()
-        primary_input_outputs.update(self.input_binding_indices_in_order)
-        primary_input_outputs.update(self.output_binding_indices_in_order)
-        self.hidden_output_binding_indices_in_order: Sequence[int] = []
-        self.hidden_output_names: Sequence[str] = []
-        for i in range(
-            self.engine.num_bindings // self.engine.num_optimization_profiles
-        ):
-            if i not in primary_input_outputs:
-                self.hidden_output_binding_indices_in_order.append(i)
-                self.hidden_output_names.append(self.engine.get_binding_name(i))

-        assert (self.engine.num_bindings // self.engine.num_optimization_profiles) == (
+        # TODO: Verify if the following is required especially the hidden outputs
+        # self.input_binding_indices_in_order: Sequence[int] = [
+        #     self.engine.get_binding_index(name) for name in self.input_names
+        # ]
+        # self.output_binding_indices_in_order: Sequence[int] = [
+        #     self.engine.get_binding_index(name) for name in self.output_names
+        # ]
+        # primary_input_outputs = set()
+        # primary_input_outputs.update(self.input_binding_indices_in_order)
+        # primary_input_outputs.update(self.output_binding_indices_in_order)
+        # self.hidden_output_binding_indices_in_order: Sequence[int] = []
+        # self.hidden_output_names: Sequence[str] = []
+        # for i in range(
+        #     self.engine.num_bindings // self.engine.num_optimization_profiles
+        # ):
+        #     if i not in primary_input_outputs:
+        #         self.hidden_output_binding_indices_in_order.append(i)
+        #         self.hidden_output_names.append(self.engine.get_binding_name(i))
+
+        assert (
+            self.engine.num_io_tensors // self.engine.num_optimization_profiles
+        ) == (
             len(self.input_names)
             + len(self.output_names)
-            + len(self.hidden_output_names)
+            # + len(self.hidden_output_names) #TODO: Verify if this is required
        )

        self.input_dtypes = [
            unified_dtype_converter(
-                self.engine.get_binding_dtype(idx), Frameworks.TORCH
+                self.engine.get_tensor_dtype(input_name), Frameworks.TORCH
            )
-            for idx in self.input_binding_indices_in_order
+            for input_name in self.input_names
        ]
-        self.input_shapes: Sequence[Sequence[int]] = [
-            tuple(self.engine.get_binding_shape(idx))
-            for idx in self.input_binding_indices_in_order
+        self.input_shapes = [
+            self.engine.get_tensor_shape(input_name) for input_name in self.input_names
        ]
        self.output_dtypes = [
            unified_dtype_converter(
-                self.engine.get_binding_dtype(idx), Frameworks.TORCH
+                self.engine.get_tensor_dtype(output_name), Frameworks.TORCH
            )
-            for idx in self.output_binding_indices_in_order
+            for output_name in self.output_names
        ]
        self.output_shapes = [
-            (
-                tuple(self.engine.get_binding_shape(idx))
-                if self.engine.has_implicit_batch_dimension
-                else tuple()
-            )
-            for idx in self.output_binding_indices_in_order
-        ]
-        self.hidden_output_dtypes = [
-            unified_dtype_converter(
-                self.engine.get_binding_dtype(idx), Frameworks.TORCH
-            )
-            for idx in self.hidden_output_binding_indices_in_order
-        ]
-        self.hidden_output_shapes = [
-            (
-                tuple(self.engine.get_binding_shape(idx))
-                if self.engine.has_implicit_batch_dimension
-                else tuple()
-            )
-            for idx in self.hidden_output_binding_indices_in_order
+            self.engine.get_tensor_shape(output_name)
+            for output_name in self.output_names
        ]

+        # TODO: Verify what this is for ?
+        # self.hidden_output_dtypes = [
+        #     unified_dtype_converter(
+        #         self.engine.get_binding_dtype(idx), Frameworks.TORCH
+        #     )
+        #     for idx in self.hidden_output_binding_indices_in_order
+        # ]
+
     def _check_initialized(self) -> None:
         if not self.initialized:
             raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
@@ -217,12 +213,12 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
         ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

         contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
-        bindings: List[Any] = [None] * (
-            len(self.input_names)
-            + len(self.output_names)
-            + len(self.hidden_output_names)
-        )
-
+        bindings = []
+        # [None] * (
+        #     len(self.input_names)
+        #     + len(self.output_names)
+        #     # + len(self.hidden_output_names) # TODO: Verify if this is required
+        # )
         for i, input_name in enumerate(self.input_names):
             if not contiguous_inputs[i].is_cuda:
                 logger.warning(
@@ -241,11 +237,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                 contiguous_inputs[i].dtype == self.input_dtypes[i]
             ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

-            idx = self.input_binding_indices_in_order[i]
-            bindings[idx] = contiguous_inputs[i].data_ptr()
-
-            self.context.set_binding_shape(
-                idx, tuple(contiguous_inputs[i].shape)
+            bindings.append(contiguous_inputs[i].data_ptr())
+            self.context.set_input_shape(
+                input_name, tuple(contiguous_inputs[i].shape)
             )

         with (
@@ -258,26 +252,32 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
             # create output tensors
             outputs: List[torch.Tensor] = []

-            for i, idx in enumerate(self.output_binding_indices_in_order):
-                shape = tuple(self.context.get_binding_shape(idx))
+            for i, output_name in enumerate(self.output_names):
+                shape = tuple(self.context.get_tensor_shape(output_name))

                 output = torch.empty(
                     size=shape,
                     dtype=self.output_dtypes[i],
                     device=torch.cuda.current_device(),
                 )
+                bindings.append(output.data_ptr())
                 outputs.append(output)
-                bindings[idx] = output.data_ptr()

-            for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
-                shape = tuple(self.context.get_binding_shape(idx))
+            # TODO: Check what is this for ?
+            # for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
+            #     shape = tuple(self.context.get_binding_shape(idx))

-                output = torch.empty(
-                    size=shape,
-                    dtype=self.hidden_output_dtypes[i],
-                    device=torch.cuda.current_device(),
-                )
-                bindings[idx] = output.data_ptr()
+            #     output = torch.empty(
+            #         size=shape,
+            #         dtype=self.hidden_output_dtypes[i],
+            #         device=torch.cuda.current_device(),
+            #     )
+
+            # Assign tensor address appropriately
+            for idx in range(self.engine.num_io_tensors):
+                self.context.set_tensor_address(
+                    self.engine.get_tensor_name(idx), bindings[idx]
+                )

         with (
             torch.autograd.profiler.record_function(
@@ -286,9 +286,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
             if self.profiling_enabled
             else nullcontext()
         ):
-            self.context.execute_async_v2(
-                bindings, torch.cuda.current_stream().cuda_stream
-            )
+            self.context.execute_async_v3(torch.cuda.current_stream().cuda_stream)

         if len(outputs) == 1:
             return outputs[0]
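The Python runtime drops the positional bindings list of execute_async_v2: in TRT 10 each IO tensor's device address is registered by name with set_tensor_address (and dynamic input shapes with set_input_shape), after which execute_async_v3 takes only the CUDA stream handle. A condensed sketch of the new call sequence, assuming a deserialized `engine` and a dict `buffers` mapping tensor names to CUDA torch tensors:

import tensorrt as trt
import torch

# `engine` (trt.ICudaEngine) and `buffers` (name -> torch.Tensor on GPU) are assumed to exist.
context = engine.create_execution_context()
for idx in range(engine.num_io_tensors):
    name = engine.get_tensor_name(idx)
    if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
        context.set_input_shape(name, tuple(buffers[name].shape))
    context.set_tensor_address(name, buffers[name].data_ptr())

context.execute_async_v3(torch.cuda.current_stream().cuda_stream)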

‎py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 12 additions & 15 deletions
@@ -3,30 +3,27 @@
 import math
 import operator
 import warnings
-from typing import cast, Dict, Optional, Sequence, Tuple, Union
+from typing import Dict, Optional, Sequence, Tuple, Union, cast

 import numpy as np

 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
-
-from ..converter_registry import tensorrt_converter
-
-from ..tracer.acc_tracer import acc_ops
-from ..types import *  # noqa: F403
 from torch.fx.immutable_collections import immutable_list
 from torch.fx.node import Argument, Target
-
-from ..utils import get_dynamic_dims, unified_dtype_converter, Frameworks
-
-from .converter_utils import *  # noqa: F403
+from torch_tensorrt.fx.converters.impl import activation, convolution
 from torch_tensorrt.fx.passes.lower_basic_pass import (
     trt_transposed_linear,
     trt_transposed_matmul,
 )
 from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
-from torch_tensorrt.fx.converters.impl import activation, convolution
+
+from ..converter_registry import tensorrt_converter
+from ..tracer.acc_tracer import acc_ops
+from ..types import *  # noqa: F403
+from ..utils import Frameworks, get_dynamic_dims, unified_dtype_converter
+from .converter_utils import *  # noqa: F403

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -323,7 +320,7 @@ def acc_ops_pad_with_slice_layer(
     )

     layer.set_input(4, value_const)
-    layer.mode = trt.SliceMode.FILL
+    layer.mode = trt.SampleMode.FILL
     set_layer_name(layer, target, name)

     return layer.get_output(0)
@@ -840,7 +837,7 @@ def acc_ops_tile(
     shapes = [1] * len(dims)
     strides = [1] * len(dims)
     layer = network.add_slice(input_val, starts, shapes, strides)
-    layer.mode = trt.SliceMode.WRAP
+    layer.mode = trt.SampleMode.WRAP
     set_layer_name(layer, target, name)

     if has_dynamic_shape(input_val.shape):  # type: ignore[union-attr]
@@ -3536,9 +3533,9 @@ def acc_ops_interpolate(
         layer.scales = [1, 1] + list(scale_factor)

     if mode.lower() in ["linear", "bilinear", "trilinear"]:
-        layer.resize_mode = trt.ResizeMode.LINEAR
+        layer.resize_mode = trt.InterpolationMode.LINEAR
     else:
-        layer.resize_mode = trt.ResizeMode.NEAREST
+        layer.resize_mode = trt.InterpolationMode.NEAREST

     if (align_corners is not None) and align_corners:
         layer.coordinate_transformation = (

‎py/torch_tensorrt/fx/utils.py

Lines changed: 13 additions & 5 deletions
@@ -1,18 +1,21 @@
 from enum import Enum
-from typing import Dict, List, Optional, Callable, Union
+from typing import Callable, Dict, List, Optional, Union
+
 import numpy as np
-from packaging import version

 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
 from functorch import make_fx
 from functorch.experimental import functionalize
+from torch_tensorrt._utils import sanitized_torch_version
 from torch_tensorrt.fx.passes.lower_basic_pass import (
     replace_op_with_indices,
     run_const_fold,
 )
-from torch_tensorrt._utils import sanitized_torch_version
+
+from packaging import version
+
 from .types import Shape, TRTDataType

@@ -45,6 +48,11 @@ class Frameworks(Enum):
         Frameworks.TORCH: torch.float32,
         Frameworks.TRT: trt.float32,
     },
+    trt.bool: {
+        Frameworks.NUMPY: bool,
+        Frameworks.TORCH: torch.bool,
+        Frameworks.TRT: trt.bool,
+    },
 }

 if trt.__version__ >= "7.0":
@@ -89,10 +97,10 @@ def unified_dtype_converter(
         The equivalent data type in the requested framework.
     """
     assert to in Frameworks, f"Expected valid Framework for translation, got {to}"
-
+    trt_major_version = int(trt.__version__.split(".")[0])
     if dtype in (np.int8, torch.int8, trt.int8):
         return DataTypeEquivalence[trt.int8][to]
-    elif trt.__version__ >= "7.0" and dtype in (np.bool_, torch.bool, trt.bool):
+    elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool):
         return DataTypeEquivalence[trt.bool][to]
     elif dtype in (np.int32, torch.int32, trt.int32):
         return DataTypeEquivalence[trt.int32][to]
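Besides the bool entry, this hunk fixes a subtle versioning bug: string comparison is lexicographic, so "10.0.1" >= "7.0" is False, and the old trt.__version__ >= "7.0" guard would have started failing exactly when TRT 10 arrived. Comparing the parsed major version (or using packaging.version, which the file already imports) avoids that:

from packaging import version

trt_version = "10.0.1"  # example value of trt.__version__
print(trt_version >= "7.0")                                 # False: lexicographic string comparison
print(int(trt_version.split(".")[0]) >= 7)                  # True: integer major-version comparison
print(version.parse(trt_version) >= version.parse("7.0"))   # True as well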

‎tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py

Lines changed: 3 additions & 5 deletions
@@ -25,12 +25,10 @@ def forward(self, a, b):
             symbolic_traced_gm, "forward", inputs=[input_data_0, input_data_1]
         )

-        # Deserialize the TensorRT engine
-        with trt.Logger() as logger, trt.Runtime(logger) as runtime:
-            engine = runtime.deserialize_cuda_engine(trt_engine_str)
-
         # Inference on TRT Engine
-        py_trt_module = PythonTorchTensorRTModule(engine, ["a", "b"], ["output0"])
+        py_trt_module = PythonTorchTensorRTModule(
+            trt_engine_str, ["a", "b"], ["output0"]
+        )
         trt_output = py_trt_module(input_data_0, input_data_1).cpu()

         # Inference on PyTorch model
