Skip to content

Commit 5a105c6

Browse files
committed
refactor(//core)!: Introducing a binding convention that will address
determinism issues with TensorRT The binding convention now looks for bindings by name and reorders outputs to match the order expected by PyTorch. BREAKING CHANGE: This changes the "ABI" of compiled TRTorch programs and the runtime and breaks backwards compatibility between the runtime in 0.1.0+ and programs compiled pre-0.1.0 Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 7cfcca4 commit 5a105c6

File tree

6 files changed

+48
-71
lines changed

6 files changed

+48
-71
lines changed

core/conversion/conversion.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,12 @@ void AddInputs(ConversionCtx* ctx,
143143
for (size_t i = 0; i < input_tensors.size(); i++) {
144144
auto in = input_tensors[i];
145145
auto dims = input_dims[i];
146+
std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
146147
LOG_INFO(ctx->logger,
147148
"Adding Input " << in->debugName() \
148-
<< " (conversion.AddInputs)");
149+
<< " named " << name << " in engine (conversion.AddInputs)");
149150
LOG_DEBUG(ctx->logger, "Input shape set to " << dims.input_shape);
150-
auto trt_in = ctx->net->addInput(in->debugName().c_str(),
151+
auto trt_in = ctx->net->addInput(name.c_str(),
151152
ctx->input_type, dims.input_shape);
152153
TRTORCH_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
153154

@@ -160,6 +161,7 @@ void AddInputs(ConversionCtx* ctx,
160161
}
161162

162163
ctx->value_tensor_map[in] = trt_in;
164+
ctx->num_inputs += 1;
163165
}
164166

165167
TRTORCH_CHECK(profile->isValid(), "Optimization profile is invalid, please check the input range provided (conversion.AddInputs)");
@@ -174,14 +176,17 @@ void AddInputs(ConversionCtx* ctx,
174176

175177
void MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outputs) {
176178
for (auto out : outputs) {
179+
std::string name = std::string("output_") + std::to_string(ctx->num_outputs);
177180
auto it = ctx->value_tensor_map.find(out);
178181
// Leaves the potential for unused outputs to be populated with nullptr "safely"
179182
TRTORCH_CHECK(it != ctx->value_tensor_map.end() && it->second,
180183
"No corresponding output TRT Tensor found for TorchScript output: " << out->debugName());
181184
auto out_tensor = it->second;
185+
out_tensor->setName(name.c_str());
182186
ctx->net->markOutput(*out_tensor);
183187
LOG_INFO(ctx->logger,
184-
"Marking Output " << out->debugName() << " (ctx.MarkOutput)");
188+
"Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)");
189+
ctx->num_outputs += 1;
185190
}
186191
}
187192

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ struct ConversionCtx {
4242

4343
~ConversionCtx();
4444

45+
uint64_t num_inputs = 0;
46+
uint64_t num_outputs = 0;
4547
bool input_is_dynamic = false;
4648
nvinfer1::IBuilder* builder;
4749
nvinfer1::INetworkDefinition* net;

core/execution/TRTEngine.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,20 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine)
4242
uint64_t outputs = 0;
4343

4444
for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
45+
std::string name = cuda_engine->getBindingName(x);
46+
std::string idx_s = name.substr(name.find("_") + 1);
47+
uint64_t idx = static_cast<uint64_t>(std::stoi(idx_s));
48+
4549
if(cuda_engine->bindingIsInput(x)) {
4650
inputs++;
51+
in_binding_map[x] = idx;
4752
} else {
4853
outputs++;
54+
out_binding_map[x] = idx;
4955
}
5056
}
5157
num_io = std::make_pair(inputs, outputs);
58+
5259
}
5360

5461
TRTEngine& TRTEngine::operator=(const TRTEngine& other) {

core/execution/execution.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ struct TRTEngine : torch::CustomClassHolder {
2222
std::string name;
2323
util::logging::TRTorchLogger logger;
2424

25+
std::unordered_map<uint64_t, uint64_t> in_binding_map;
26+
std::unordered_map<uint64_t, uint64_t> out_binding_map;
27+
2528
~TRTEngine();
2629
TRTEngine(std::string serialized_engine);
2730
TRTEngine(std::string mod_name, std::string serialized_engine);
@@ -30,7 +33,7 @@ struct TRTEngine : torch::CustomClassHolder {
3033
//c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
3134
};
3235

33-
std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pair<uint64_t, uint64_t> io, std::vector<at::Tensor>& inputs);
36+
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);
3437

3538
} // namespace execution
3639
} // namespace core

core/execution/register_trt_op.cpp

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,50 +9,43 @@
99
namespace trtorch {
1010
namespace core {
1111
namespace execution {
12-
std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pair<uint64_t, uint64_t> io, std::vector<at::Tensor>& inputs) {
12+
13+
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
14+
LOG_DEBUG("Attempting to run engine (ID: " << compiled_engine->name << ")");
1315
std::vector<void*> gpu_handles;
1416

1517
std::vector<at::Tensor> contig_inputs{};
1618
contig_inputs.reserve(inputs.size());
19+
1720
for (size_t i = 0; i < inputs.size(); i++) {
18-
TRTORCH_CHECK(inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
19-
auto expected_type = util::toATenDType(ctx->getEngine().getBindingDataType(i));
20-
TRTORCH_CHECK(inputs[i].dtype() == expected_type, "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
21-
auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
21+
uint64_t pyt_idx = compiled_engine->in_binding_map[i];
22+
TRTORCH_CHECK(inputs[pyt_idx].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[pyt_idx].device());
23+
auto expected_type = util::toATenDType(compiled_engine->exec_ctx->getEngine().getBindingDataType(i));
24+
TRTORCH_CHECK(inputs[pyt_idx].dtype() == expected_type, "Expected input tensors to have type " << expected_type << ", found type " << inputs[pyt_idx].dtype());
25+
auto dims = core::util::toDimsPad(inputs[pyt_idx].sizes(), 1);
2226
auto shape = core::util::toVec(dims);
23-
contig_inputs.push_back(inputs[i].view(shape).contiguous());
27+
contig_inputs.push_back(inputs[pyt_idx].view(shape).contiguous());
2428
LOG_DEBUG("Input shape: " << dims);
25-
ctx->setBindingDimensions(i, dims);
29+
compiled_engine->exec_ctx->setBindingDimensions(i, dims);
2630
gpu_handles.push_back(contig_inputs.back().data_ptr());
2731
}
2832

29-
TRTORCH_CHECK(ctx->allInputDimensionsSpecified(), "Not enough inputs provided (execution.RunCudaEngine)");
33+
TRTORCH_CHECK(compiled_engine->exec_ctx->allInputDimensionsSpecified(), "Not enough inputs provided (execution.RunCudaEngine)");
3034

31-
std::vector<at::Tensor> outputs;
32-
for (uint64_t o = inputs.size(); o < (io.first + io.second); o++) {
33-
auto out_shape = ctx->getBindingDimensions(o);
35+
std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
36+
for (size_t o = inputs.size(); o < (compiled_engine->num_io.first + compiled_engine->num_io.second); o++) {
37+
uint64_t pyt_idx = compiled_engine->out_binding_map[o];
38+
auto out_shape = compiled_engine->exec_ctx->getBindingDimensions(o);
3439
LOG_DEBUG("Output shape: " << out_shape);
3540
auto dims = core::util::toVec(out_shape);
36-
auto type = util::toATenDType(ctx->getEngine().getBindingDataType(o));
37-
outputs.push_back(at::empty(dims, {at::kCUDA}).to(type).contiguous());
38-
gpu_handles.push_back(outputs[outputs.size() - 1].data_ptr());
41+
auto type = util::toATenDType(compiled_engine->exec_ctx->getEngine().getBindingDataType(o));
42+
std::cout << pyt_idx << std::endl;
43+
outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());
44+
gpu_handles.push_back(outputs[pyt_idx].data_ptr());
3945
}
4046

41-
// Is this the right stream?
4247
c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(inputs[0].device().index());
43-
44-
ctx->enqueueV2(gpu_handles.data(), stream, nullptr);
45-
46-
return outputs;
47-
}
48-
49-
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> engine) {
50-
// Verify calling convention (right to left or left to right)
51-
LOG_DEBUG("Attempting to run engine (ID: " << std::hex << engine->name << ")");
52-
53-
auto io = engine->num_io;
54-
auto ctx = engine->exec_ctx;
55-
auto outputs = RunCudaEngine(ctx, io, inputs);
48+
compiled_engine->exec_ctx->enqueueV2(gpu_handles.data(), stream, nullptr);
5649

5750
return outputs;
5851
}

tests/util/run_graph_engine.cpp

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
#include "c10/cuda/CUDAStream.h"
44
#include "torch/csrc/jit/ir/ir.h"
55
#include "torch/csrc/jit/ir/irparser.h"
6+
#include "torch/custom_class.h"
67
#include "core/conversion/conversion.h"
8+
#include "core/execution/execution.h"
79
#include "cuda_runtime_api.h"
810

911
#include <vector>
@@ -28,7 +30,7 @@ std::vector<core::conversion::InputRange> toInputRangesDynamic(std::vector<at::T
2830
auto opt = core::util::toVec(i.sizes());
2931

3032
std::vector<int64_t> min_range(opt);
31-
std::vector<int64_t> max_range(opt);
33+
std::vector<int64_t> max_range(opt);
3234

3335
min_range[1] = ceil(opt[1]/2.0);
3436
max_range[1] = 2*opt[1];
@@ -40,44 +42,9 @@ std::vector<core::conversion::InputRange> toInputRangesDynamic(std::vector<at::T
4042
}
4143

4244
std::vector<at::Tensor> RunEngine(std::string& eng, std::vector<at::Tensor> inputs) {
43-
auto rt = nvinfer1::createInferRuntime(core::util::logging::get_logger());
44-
auto engine = rt->deserializeCudaEngine(eng.c_str(), eng.size());
45-
auto ctx = engine->createExecutionContext();
46-
47-
std::vector<void*> gpu_handles;
48-
49-
std::vector<at::Tensor> contig_inputs{};
50-
contig_inputs.reserve(inputs.size());
51-
for (size_t i = 0; i < inputs.size(); i++) {
52-
TRTORCH_CHECK(inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
53-
auto expected_type = core::util::toATenDType(ctx->getEngine().getBindingDataType(i));
54-
TRTORCH_CHECK(inputs[i].dtype() == expected_type, "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
55-
auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
56-
auto shape = core::util::toVec(dims);
57-
contig_inputs.push_back(inputs[i].view(shape).contiguous());
58-
LOG_DEBUG("In shape:" << shape);
59-
ctx->setBindingDimensions(i, dims);
60-
gpu_handles.push_back(contig_inputs.back().data_ptr());
61-
}
62-
63-
TRTORCH_CHECK(ctx->allInputDimensionsSpecified(), "Not enough inputs provided (execution.RunCudaEngine)");
64-
65-
std::vector<at::Tensor> outputs;
66-
for (int64_t o = inputs.size(); o < engine->getNbBindings(); o++) {
67-
auto out_shape = ctx->getBindingDimensions(o);
68-
LOG_DEBUG("Output: " << engine->getBindingName(o) << " out shape: " << out_shape);
69-
auto dims = core::util::toVec(out_shape);
70-
auto type = core::util::toATenDType(ctx->getEngine().getBindingDataType(o));
71-
outputs.push_back(at::empty(dims, {at::kCUDA}).to(type).contiguous());
72-
gpu_handles.push_back(outputs[outputs.size() - 1].data_ptr());
73-
}
74-
75-
// Is this the right stream?
76-
c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(inputs[0].device().index());
77-
78-
ctx->enqueueV2(gpu_handles.data(), stream, nullptr);
79-
80-
stream.synchronize();
45+
LOG_DEBUG("Running TRT version");
46+
auto engine_ptr = c10::make_intrusive<trtorch::core::execution::TRTEngine>("test_engine", eng);
47+
auto outputs = trtorch::core::execution::execute_engine(inputs, engine_ptr);
8148
return outputs;
8249
}
8350

0 commit comments

Comments
 (0)