
Commit 2dd1ba3

feat(//core/execution): Type checking for the executor; it is now the
responsibility of the user to transfer data to the GPU and ensure input
types are correct.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 2f86f84 commit 2dd1ba3

File tree

4 files changed: +24 -6 lines changed

- README.md
- core/execution/TRTEngine.cpp
- core/execution/execution.h
- core/execution/register_trt_op.cpp

README.md

Lines changed: 6 additions & 1 deletion

````diff
@@ -17,14 +17,19 @@ More Information / System Architecture:
 ...
 auto compile_settings = trtorch::ExtraInfo(dims);
 // FP16 execution
-compile_settings.op_precision = torch::kHalf;
+compile_settings.op_precision = torch::kFloat;
 // Compile module
 auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);
 // Run like normal
 auto results = trt_mod.forward({in_tensor});
 ...
 ```
 
+> Notes on running in lower precisions:
+> - Set precision with extra_info.op_precision
+> - The module should be left in FP32 before compilation
+> - In FP16 only input tensors should be converted to FP16, other precisions use FP32
+
 ## Platform Support
 
 | Platform | Support |
````
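Taken together, the notes above describe a three-step FP16 workflow. A minimal sketch, not part of this commit, reusing the `ts_mod`, `dims`, and `in_tensor` names from the README example:

```cpp
// Sketch: FP16 execution per the notes above. The TorchScript module
// itself stays in FP32; only the precision setting and the inputs change.
auto compile_settings = trtorch::ExtraInfo(dims);
compile_settings.op_precision = torch::kHalf;                   // step 1: request FP16 via extra_info.op_precision
auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings); // step 2: compile the FP32 module

// step 3: convert only the input tensors to FP16, already on the GPU
auto fp16_in = in_tensor.to(torch::kCUDA, torch::kHalf);
auto results = trt_mod.forward({fp16_in});
```

For the other precisions (FP32, INT8), inputs stay in FP32, which matches the executor's type check added in register_trt_op.cpp below.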

core/execution/TRTEngine.cpp

Lines changed: 4 additions & 4 deletions (whitespace cleanup only)

```diff
@@ -40,7 +40,7 @@ c10::FunctionSchema GenerateEngineFunctionSchema(EngineID id, nvinfer1::ICudaEng
            }
        }
    }
-
+
    ss << in_ss.str();
    ss << ") -> (";
    ss << out_ss.str();
@@ -56,15 +56,15 @@ TRTEngine::TRTEngine(nvinfer1::ILogger& logger, std::string& serialized_engine)
    : schema(torch::jit::parseSchema("trt::noop() -> ()")) { // Need a better default
 
    rt = nvinfer1::createInferRuntime(logger);
-
+
    cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size());
    // Easy way to get a unique name for each engine, maybe there is a more descriptive way (using something associated with the graph maybe)
    id = reinterpret_cast<EngineID>(cuda_engine);
    exec_ctx = cuda_engine->createExecutionContext();
-
+
    uint64_t inputs = 0;
    uint64_t outputs = 0;
-
+
    for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
        if (cuda_engine->bindingIsInput(x)) {
            inputs++;
```

core/execution/execution.h

Lines changed: 1 addition & 1 deletion (whitespace cleanup only)

```diff
@@ -1,5 +1,5 @@
 #pragma once
-#include <utility>
+#include <utility>
 #include "NvInfer.h"
 #include "ATen/core/function_schema.h"
 
```

core/execution/register_trt_op.cpp

Lines changed: 13 additions & 0 deletions

```diff
@@ -1,5 +1,6 @@
 #include "c10/cuda/CUDAStream.h"
 
+#include "torch/torch.h"
 #include "torch/csrc/jit/custom_operator.h"
 
 #include "core/util/prelude.h"
@@ -15,6 +16,18 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
     std::vector<at::Tensor> contig_inputs{};
     contig_inputs.reserve(inputs.size());
     for (size_t i = 0; i < inputs.size(); i++) {
+        TRTORCH_CHECK(inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
+        auto expected_type = torch::kF32;
+        switch (ctx->getEngine().getBindingDataType(i)) {
+            case nvinfer1::DataType::kHALF:
+                expected_type = torch::kF16;
+                break;
+            case nvinfer1::DataType::kFLOAT:
+            case nvinfer1::DataType::kINT8:
+            default:
+                expected_type = torch::kF32;
+        }
+        TRTORCH_CHECK(inputs[i].dtype() == expected_type, "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
         auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
         auto shape = core::util::toVec(dims);
         contig_inputs.push_back(inputs[i].to(at::kCUDA).view(shape).contiguous());
```
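In effect, RunCudaEngine now fails fast instead of silently converting: a CPU tensor, or an FP32 tensor handed to an FP16 engine, trips a TRTORCH_CHECK before execution. A caller-side sketch of the new contract; `trt_mod` is a hypothetical compiled module whose engine bindings are kHALF, not a name from this diff:

```cpp
auto in_tensor = torch::randn({1, 3, 224, 224});
// in_tensor lives on the CPU in FP32: under this commit it would fail the
// is_cuda() check rather than being moved and converted for the caller.

auto ok_input = in_tensor.to(torch::kCUDA, torch::kHalf); // caller moves + converts
auto results = trt_mod.forward({ok_input});               // passes both checks
```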
